|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from PIL import Image |
|
|
import requests |
|
|
from io import BytesIO |
|
|
import json |
|
|
import os |
|
|
from transformers import ViTForImageClassification, ViTConfig |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
from train_categories import PaddingImageProcessor |
|
|
|
|
|
def load_model_and_config(model_path, class_path, config_path):
    """Load the trained ViT classifier, its preprocessing config and class metadata.

    Args:
        model_path: Path to the model weights file (``.../model.safetensors``)
            or to the directory that contains it. The model is loaded from
            that directory, which must also hold the model's ``config.json``.
        class_path: Path to the class-metadata JSON file. It is returned
            as-is (callers read its ``class_columns`` key).
        config_path: Path to the preprocessing JSON file. It must contain
            ``target_size`` and ``padding_color``.

    Returns:
        Tuple ``(model, image_processor, class_info, device)``:
        ``model`` is in eval mode and already moved to ``device``;
        ``device`` is CUDA when available, otherwise CPU.
    """
    # Explicit UTF-8: class names may contain non-ASCII characters, and the
    # platform default encoding is not UTF-8 everywhere (e.g. Windows).
    with open(class_path, 'r', encoding='utf-8') as f:
        class_info = json.load(f)

    with open(config_path, 'r', encoding='utf-8') as f:
        processor_config = json.load(f)

    image_processor = PaddingImageProcessor(
        target_size=processor_config['target_size'],
        padding_color=tuple(processor_config['padding_color']),
    )

    # from_pretrained expects the directory, not the weights file. Only strip
    # a trailing 'model.safetensors' component; a plain str.replace would also
    # mangle any directory whose name contains that text, and fails on Path.
    model_dir = model_path
    if os.path.basename(model_path) == 'model.safetensors':
        model_dir = os.path.dirname(model_path) or '.'

    model = ViTForImageClassification.from_pretrained(model_dir)
    model.eval()  # inference only: disables dropout etc.

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    return model, image_processor, class_info, device
|
|
|
|
|
def download_image(url: str) -> Image.Image:
    """Fetch the resource at ``url`` and decode it as an RGB PIL image.

    The HTTP request uses a 10-second timeout. A non-2xx status is raised
    as an exception by ``raise_for_status()``, and PIL raises if the
    payload cannot be decoded as an image.
    """
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    # Normalise to 3-channel RGB (drops alpha / converts greyscale/palette).
    return Image.open(BytesIO(resp.content)).convert('RGB')
|
|
|
|
|
def _select_predictions(probabilities, class_columns, threshold):
    """Turn per-class probabilities into prediction dicts, falling back to top-1.

    Args:
        probabilities: 1-D NumPy array of per-class probabilities, indexed
            like ``class_columns``.
        class_columns: Class names; ``class_columns[i]`` names probability ``i``.
        threshold: A class is kept when its probability is strictly greater
            than this value.

    Returns:
        List of ``{"class": name, "confidence": float}`` dicts in class order.
        When no class clears the threshold, the single most probable class
        is returned instead, so the list is never empty for non-empty input.
    """
    selected = [
        {"class": class_columns[idx], "confidence": float(prob)}
        for idx, prob in enumerate(probabilities)
        if prob > threshold
    ]
    if not selected:
        best = int(probabilities.argmax())
        selected.append(
            {"class": class_columns[best], "confidence": float(probabilities[best])}
        )
    return selected


def classify_image(model, image_processor, class_info, device, image_url, accuracy):
    """Download an image, score it with the classifier and return the predicted classes.

    Each prediction is also printed as ``"<class>: <confidence to 3 dp>"``.

    Args:
        model: Classifier returning an object with ``.logits`` when called
            with ``pixel_values=``.
        image_processor: Callable that turns a PIL image into a tensor
            (presumably CHW; a batch dimension is added here).
        class_info: Mapping whose ``'class_columns'`` entry lists class names
            in logit order.
        device: Torch device the model lives on.
        image_url: URL of the image to classify.
        accuracy: Probability threshold (not an accuracy metric). Classes
            whose sigmoid probability is strictly greater than this are
            returned. Numeric strings such as ``"0.5"`` are accepted.

    Returns:
        List of ``{"class": str, "confidence": float}`` dicts. If no class
        exceeds the threshold, only the most probable class is returned.
    """
    # Requests may deliver the threshold as a JSON string; comparing a NumPy
    # float against a str would raise TypeError.
    threshold = float(accuracy)

    image = download_image(image_url)
    # Add a batch dimension of 1 and move the input to the model's device.
    processed_image = image_processor(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values=processed_image).logits
        # Sigmoid (not softmax): each class is scored independently, so any
        # number of classes can clear the threshold.
        probabilities = torch.sigmoid(logits).cpu().numpy()[0]

    predictions = _select_predictions(probabilities, class_info['class_columns'], threshold)
    for pred in predictions:
        print(f"{pred['class']}: {pred['confidence']:.3f}")
    return predictions
|
|
|
|
|
class EndpointHandler:
    """Request handler that classifies an image fetched from a URL.

    Model artifacts are downloaded from the Hugging Face Hub repository
    ``REPO_ID`` (subfolder ``MODEL_SUBDIR``) once, when the handler is built.
    """

    REPO_ID = "Drazcat-AI/categories_peru"
    MODEL_SUBDIR = "vit_multiclass_model_best"
    # Keys that must be present inside the request's "inputs" object.
    REQUIRED_INPUT_KEYS = ("image_url", "accuracy")

    def __init__(self, path=""):
        """Download the model artifacts and load the model.

        Args:
            path: Accepted but unused; artifacts are always fetched from the
                Hub. NOTE(review): presumably supplied by the hosting
                runtime's handler convention - confirm before removing.
        """
        def fetch(filename):
            return hf_hub_download(
                repo_id=self.REPO_ID, filename=f"{self.MODEL_SUBDIR}/{filename}"
            )

        model_path = fetch("model.safetensors")
        class_path = fetch("class_info.json")
        config_path = fetch("processor_config.json")
        # Not read here, but the model is loaded from the directory holding
        # model.safetensors, and from_pretrained needs config.json beside it.
        fetch("config.json")

        self.model, self.image_processor, self.class_info, self.device = load_model_and_config(
            model_path, class_path, config_path
        )

    @staticmethod
    def _response(status_code, body):
        """Build the ``{"statusCode", "body"}`` reply with a JSON-encoded body."""
        return {"statusCode": status_code, "body": json.dumps(body)}

    def predict_objects(self, image_url, accuracy):
        """Classify the image at ``image_url`` using ``accuracy`` as the probability threshold.

        Returns a list of ``{"class": str, "confidence": float}`` dicts.
        """
        predictions = classify_image(
            self.model, self.image_processor, self.class_info, self.device, image_url, accuracy
        )
        return predictions

    def __call__(self, event):
        """Handle a request of the form ``{"inputs": {"image_url": ..., "accuracy": ...}}``.

        Returns:
            ``{"statusCode": int, "body": str}`` where ``body`` is JSON:
            200 with the prediction list on success; 400 when ``inputs`` or
            one of its required keys is missing or ``inputs`` is not an
            object; 500 with the error message when prediction fails.
        """
        if "inputs" not in event:
            return self._response(400, "Error: Please provide an 'inputs' parameter.")

        payload = event["inputs"]
        if not isinstance(payload, dict):
            return self._response(400, "Error: 'inputs' must be an object.")

        # Validate up front so malformed requests get a 400 instead of an
        # uncaught KeyError escaping the handler.
        missing = [key for key in self.REQUIRED_INPUT_KEYS if key not in payload]
        if missing:
            return self._response(
                400,
                f"Error: Missing required field(s) in 'inputs': {', '.join(missing)}.",
            )

        try:
            predictions = self.predict_objects(payload["image_url"], payload["accuracy"])
            return self._response(200, predictions)
        except Exception as e:
            # Handler boundary: report any failure to the caller as a 500.
            return self._response(500, f"Error: {str(e)}")