File size: 4,053 Bytes
32a154c
 
 
 
 
 
 
 
 
 
 
 
 
260c57a
32cf20a
f3c4705
32cf20a
 
f3c4705
32cf20a
 
 
 
 
 
 
 
 
bd025b7
32cf20a
 
 
 
 
 
 
 
32a154c
 
32cf20a
 
 
 
 
32a154c
b682efe
32a154c
32cf20a
 
 
 
 
 
 
 
 
 
 
557cd98
32cf20a
 
 
 
d354569
32cf20a
 
 
 
 
557cd98
32cf20a
 
 
 
 
 
d354569
c95870d
32a154c
 
32cf20a
 
 
f3c4705
 
 
f4fdd8d
260c57a
32a154c
32cf20a
 
b682efe
32cf20a
32a154c
32cf20a
 
 
 
 
 
 
 
 
 
 
e623cdc
32cf20a
 
 
e623cdc
32cf20a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO
import json
import os
from transformers import ViTForImageClassification, ViTConfig
from huggingface_hub import hf_hub_download

# Import the image processor from the training code
from train_categories import PaddingImageProcessor

def load_model_and_config(model_path, class_path, config_path):
	"""Load the trained ViT classifier together with its metadata.

	Args:
		model_path: Path to the ``model.safetensors`` weights file, or to the
			directory that contains it. The model is loaded from that directory.
		class_path: Path to a JSON file with class metadata; it is parsed and
			returned unchanged.
		config_path: Path to the image-processor JSON config; it must contain
			``target_size`` and ``padding_color``.

	Returns:
		Tuple ``(model, image_processor, class_info, device)``. The model is in
		eval mode and has been moved to ``device`` (CUDA when available, else CPU).
	"""
	# JSON is UTF-8 by specification; without an explicit encoding, open() uses the
	# locale encoding and can fail on non-ASCII class names.
	with open(class_path, 'r', encoding='utf-8') as f:
		class_info = json.load(f)

	with open(config_path, 'r', encoding='utf-8') as f:
		processor_config = json.load(f)

	image_processor = PaddingImageProcessor(
		target_size=processor_config['target_size'],
		# JSON has no tuple type, so the colour arrives as a list.
		padding_color=tuple(processor_config['padding_color']),
	)

	# from_pretrained() is given the directory holding the weights, not the weights file.
	# Only strip a trailing file name, rather than any occurrence of the substring.
	if os.path.basename(model_path) == 'model.safetensors':
		model_dir = os.path.dirname(model_path) or '.'
	else:
		model_dir = model_path
	model = ViTForImageClassification.from_pretrained(model_dir)
	model.eval()

	# Run on the GPU when one is available.
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	model = model.to(device)

	return model, image_processor, class_info, device

def download_image(url: str) -> Image.Image:
	"""Fetch ``url`` over HTTP and decode the response body as an RGB image.

	The request uses a 10-second timeout, and a non-2xx status is raised via
	``raise_for_status()``. The decoded image is converted to RGB (three channels).

	NOTE(review): if the URL comes straight from a client request, consider an
	allowlist and a response-size limit to guard against SSRF and oversized
	downloads.
	"""
	resp = requests.get(url, timeout=10)
	resp.raise_for_status()
	raw = BytesIO(resp.content)
	return Image.open(raw).convert('RGB')

def classify_image(model, image_processor, class_info, device, image_url, accuracy):
	"""Download an image and run multi-label classification on it.

	Each class score is the sigmoid of its logit. Every class whose score is
	strictly greater than ``accuracy`` is printed and returned. If no class
	qualifies, the single highest-scoring class is printed and returned instead,
	so the result is never empty.

	Args:
		model: Classifier called as ``model(pixel_values=...)``; its output must
			expose ``.logits``.
		image_processor: Callable that turns a PIL image into a tensor; a batch
			dimension of size 1 is added here.
		class_info: Mapping whose ``'class_columns'`` entry lists class names in
			logit order.
		device: torch device that the input tensor is moved to.
		image_url: URL of the image to classify.
		accuracy: Score threshold. Numeric strings (e.g. from a JSON payload) are
			accepted and converted with ``float()``.

	Returns:
		List of ``{"class": name, "confidence": score}`` dicts.

	Raises:
		ValueError / TypeError: if ``accuracy`` cannot be converted to float.
	"""
	# Convert first so a bad threshold fails before any download. This also avoids
	# comparing numpy floats against a str.
	threshold = float(accuracy)

	image = download_image(image_url)
	# Add a leading batch dimension of size 1.
	pixel_values = image_processor(image).unsqueeze(0).to(device)

	with torch.no_grad():
		logits = model(pixel_values=pixel_values).logits
		# Element-wise sigmoid: each class is scored independently (multi-label).
		# [0] selects the single image in the batch.
		probabilities = torch.sigmoid(logits).cpu().numpy()[0]

	class_columns = class_info['class_columns']
	predicted_list = [
		{"class": class_columns[i], "confidence": float(prob)}
		for i, prob in enumerate(probabilities)
		if prob > threshold
	]

	if not predicted_list:
		# Nothing passed the threshold: fall back to the most likely class.
		max_idx = int(probabilities.argmax())
		predicted_list.append(
			{"class": class_columns[max_idx], "confidence": float(probabilities[max_idx])}
		)

	for prediction in predicted_list:
		print(f"{prediction['class']}: {prediction['confidence']:.3f}")
	return predicted_list

class EndpointHandler():
	"""Request handler wrapping the multi-label ViT category classifier.

	The model artefacts are downloaded from the Hugging Face Hub repository
	``REPO_ID`` once, at construction time. Each call then classifies one image URL.
	"""

	# Hub repository and the sub-directory inside it that holds the trained model files.
	REPO_ID = "Drazcat-AI/categories_peru"
	MODEL_DIR = "vit_multiclass_model_best"

	def __init__(self, path=""):
		"""Download the model artefacts and load the model.

		Args:
			path: Unused. Presumably kept to match the handler signature the hosting
				runtime expects — TODO confirm.
		"""
		def fetch(name):
			# Download one file from MODEL_DIR in the repository and return its local path.
			return hf_hub_download(repo_id=self.REPO_ID, filename=f"{self.MODEL_DIR}/{name}")

		model_path = fetch("model.safetensors")
		class_path = fetch("class_info.json")
		config_path = fetch("processor_config.json")
		# The returned path is unused. The download only needs to place config.json
		# beside the weights, presumably so the model loader finds it in that directory.
		fetch("config.json")
		self.model, self.image_processor, self.class_info, self.device = load_model_and_config(
			model_path, class_path, config_path
		)

	def predict_objects(self, image_url, accuracy):
		"""Classify the image at ``image_url`` using score threshold ``accuracy``.

		Returns the list of ``{"class", "confidence"}`` dicts produced by classify_image.
		"""
		return classify_image(
			self.model, self.image_processor, self.class_info, self.device, image_url, accuracy
		)

	@staticmethod
	def _response(status_code, body):
		"""Build a ``{"statusCode", "body"}`` reply, with ``body`` JSON-encoded."""
		return {"statusCode": status_code, "body": json.dumps(body)}

	def __call__(self, event):
		"""Handle one request.

		Expects ``event["inputs"]`` to hold ``image_url`` and a numeric ``accuracy``;
		numeric strings are accepted. Returns a dict with ``statusCode`` and a
		JSON-encoded ``body``:

		- 200 with the prediction list;
		- 400 for a malformed request;
		- 500 if classification fails.
		"""
		if "inputs" not in event:
			return self._response(400, "Error: Please provide an 'inputs' parameter.")
		payload = event["inputs"]

		# Validate before doing any work, so malformed requests get a 400 instead of
		# an uncaught exception (missing key / non-dict payload) or a 500.
		try:
			image_url = payload["image_url"]
			accuracy = float(payload["accuracy"])
		except (KeyError, TypeError, ValueError):
			return self._response(
				400, "Error: 'inputs' must contain 'image_url' and a numeric 'accuracy'."
			)

		try:
			predictions = self.predict_objects(image_url, accuracy)
			return self._response(200, predictions)
		except Exception as e:
			# Top-level boundary: report any failure (download, inference,
			# JSON serialisation) to the caller as a 500.
			return self._response(500, f"Error: {str(e)}")