File size: 4,053 Bytes
32a154c 260c57a 32cf20a f3c4705 32cf20a f3c4705 32cf20a bd025b7 32cf20a 32a154c 32cf20a 32a154c b682efe 32a154c 32cf20a 557cd98 32cf20a d354569 32cf20a 557cd98 32cf20a d354569 c95870d 32a154c 32cf20a f3c4705 f4fdd8d 260c57a 32a154c 32cf20a b682efe 32cf20a 32a154c 32cf20a e623cdc 32cf20a e623cdc 32cf20a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO
import json
import os
from transformers import ViTForImageClassification, ViTConfig
from huggingface_hub import hf_hub_download
# Importar el procesador de imágenes del código de entrenamiento
from train_categories import PaddingImageProcessor
def load_model_and_config(model_path, class_path, config_path):
    """Load the trained ViT classifier together with its JSON side files.

    Args:
        model_path: Path to the ``model.safetensors`` weights file, or to the
            directory that holds the weights and ``config.json``.
        class_path: Path to the class-metadata JSON file. ``classify_image``
            reads its ``'class_columns'`` list (output index -> class name).
        config_path: Path to the processor-config JSON file; must contain
            ``'target_size'`` and ``'padding_color'``.

    Returns:
        Tuple ``(model, image_processor, class_info, device)``. The model is in
        eval mode and has been moved to ``device`` (CUDA when available,
        otherwise CPU).
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(class_path, 'r', encoding='utf-8') as f:
        class_info = json.load(f)
    with open(config_path, 'r', encoding='utf-8') as f:
        processor_config = json.load(f)

    # Build the same padding/resizing preprocessor used during training.
    image_processor = PaddingImageProcessor(
        target_size=processor_config['target_size'],
        padding_color=tuple(processor_config['padding_color']),
    )

    # from_pretrained() takes the directory containing config.json and the
    # weights, not the weights file itself. Strip the file name with a path
    # operation instead of a string replace that could match elsewhere.
    model_dir = model_path
    if os.path.isfile(model_path) or os.path.basename(model_path) == 'model.safetensors':
        model_dir = os.path.dirname(model_path) or os.curdir
    model = ViTForImageClassification.from_pretrained(model_dir)
    model.eval()

    # Use the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    return model, image_processor, class_info, device
def download_image(url: str) -> Image.Image:
    """Fetch ``url`` over HTTP and decode the payload as an RGB PIL image.

    The request times out after 10 seconds. A non-success HTTP status raises
    via ``raise_for_status()``. The image is converted to RGB so that grayscale,
    palette or RGBA inputs all reach the preprocessor with three channels.
    """
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    payload = BytesIO(resp.content)
    return Image.open(payload).convert('RGB')
def classify_image(model, image_processor, class_info, device, image_url, accuracy):
    """Download an image and return the classes whose score exceeds a threshold.

    The model is treated as multi-label: every logit goes through its own
    sigmoid, so several classes can be returned at once. Each selected class
    is also printed as ``"<class>: <prob>"`` with three decimals.

    Args:
        model: Classifier called as ``model(pixel_values=...)`` whose output
            exposes ``.logits``.
        image_processor: Callable turning a PIL image into a tensor; a batch
            dimension is added here before inference.
        class_info: Dict whose ``'class_columns'`` list maps output index to
            class name.
        device: Device the input tensor is moved to.
        image_url: URL of the image to classify.
        accuracy: Probability threshold. A class is kept when its probability
            is strictly greater than this value. Numeric strings (e.g. values
            that arrived via JSON) are accepted.

    Returns:
        List of ``{"class": str, "confidence": float}`` dicts in output-index
        order. If no class clears the threshold, the list holds only the
        single most probable class.

    Raises:
        ValueError / TypeError: if ``accuracy`` cannot be converted to float.
            This is checked before any network access.
    """
    threshold = float(accuracy)
    class_columns = class_info['class_columns']

    # Download and preprocess the image; unsqueeze(0) makes a batch of one.
    image = download_image(image_url)
    pixel_values = image_processor(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
    # [0] selects the single image in the batch.
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]

    predicted_list = [
        {"class": class_columns[i], "confidence": float(prob)}
        for i, prob in enumerate(probabilities)
        if prob > threshold
    ]
    if not predicted_list:
        # Nothing cleared the threshold: report the most probable class instead.
        best = int(probabilities.argmax())
        predicted_list.append(
            {"class": class_columns[best], "confidence": float(probabilities[best])}
        )

    for item in predicted_list:
        print(f"{item['class']}: {item['confidence']:.3f}")
    return predicted_list
class EndpointHandler():
    """Inference handler serving the ViT multi-label category classifier.

    On construction, the model weights and their JSON side files are
    downloaded from the ``Drazcat-AI/categories_peru`` Hub repository. Each
    call then classifies one image fetched from a URL and returns a
    ``{"statusCode": ..., "body": <JSON string>}`` response dict.
    """

    _REPO_ID = "Drazcat-AI/categories_peru"
    _SUBFOLDER = "vit_multiclass_model_best"

    def __init__(self, path=""):
        """Download the model artifacts and load them.

        Args:
            path: Kept for interface compatibility; not used. Artifacts are
                always fetched from ``_REPO_ID``.
        """
        def fetch(name):
            return hf_hub_download(repo_id=self._REPO_ID, filename=f"{self._SUBFOLDER}/{name}")

        model_path = fetch("model.safetensors")
        class_path = fetch("class_info.json")
        config_path = fetch("processor_config.json")
        # The returned path is not needed. from_pretrained() reads config.json
        # from the same directory as the weights, so it only has to exist on disk.
        fetch("config.json")
        self.model, self.image_processor, self.class_info, self.device = (
            load_model_and_config(model_path, class_path, config_path)
        )

    def predict_objects(self, image_url, accuracy):
        """Classify the image at ``image_url`` with probability threshold ``accuracy``.

        Returns the list of ``{"class": ..., "confidence": ...}`` dicts
        produced by ``classify_image``.
        """
        return classify_image(self.model, self.image_processor, self.class_info,
                              self.device, image_url, accuracy)

    @staticmethod
    def _response(status_code, payload):
        """Build a response dict with a JSON-encoded body."""
        return {
            "statusCode": status_code,
            "body": json.dumps(payload),
        }

    def __call__(self, event):
        """Handle one request.

        Expects ``{"inputs": {"image_url": <str>, "accuracy": <number>}}``.
        Returns status 400 for malformed input, 500 if inference fails, and
        200 with the JSON-encoded prediction list otherwise.
        """
        if not isinstance(event, dict) or "inputs" not in event:
            return self._response(400, "Error: Please provide an 'inputs' parameter.")
        inputs = event["inputs"]
        if not isinstance(inputs, dict):
            return self._response(
                400, "Error: 'inputs' must be an object with 'image_url' and 'accuracy'.")
        # Validate before doing any work so that bad requests get a 400
        # instead of an unhandled KeyError.
        missing = [key for key in ("image_url", "accuracy") if key not in inputs]
        if missing:
            return self._response(
                400, f"Error: Missing required field(s): {', '.join(missing)}.")
        image_url = inputs["image_url"]
        try:
            accuracy = float(inputs["accuracy"])
        except (TypeError, ValueError):
            return self._response(400, "Error: 'accuracy' must be a number.")

        try:
            predictions = self.predict_objects(image_url, accuracy)
            return self._response(200, predictions)
        except Exception as e:
            # Request boundary: report any inference failure to the client
            # rather than letting the server process crash.
            return self._response(500, f"Error: {str(e)}")