Drazcat-AI commited on
Commit
32cf20a
verified
1 Parent(s): 32a154c

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +94 -92
handler.py CHANGED
@@ -12,104 +12,106 @@ from huggingface_hub import hf_hub_download
12
  from train_categories import PaddingImageProcessor
13
 
14
  def load_model_and_config(model_path):
15
- """Carga el modelo entrenado y su configuraci贸n"""
16
- hf_path = "vit_multiclass_model_best"
17
- # Cargar informaci贸n de las clases
18
- class_info_path = os.path.join(hf_path, 'class_info.json')
19
- with open(class_info_path, 'r') as f:
20
- class_info = json.load(f)
21
-
22
- # Cargar configuraci贸n del procesador
23
- processor_config_path = os.path.join(hf_path, 'processor_config.json')
24
- with open(processor_config_path, 'r') as f:
25
- processor_config = json.load(f)
26
-
27
- # Crear procesador de im谩genes
28
- image_processor = PaddingImageProcessor(
29
- target_size=processor_config['target_size'],
30
- padding_color=tuple(processor_config['padding_color'])
31
- )
32
-
33
- # Cargar modelo
34
- model = ViTForImageClassification.from_pretrained(model_path)
35
- model.eval()
36
-
37
- # Usar GPU si est谩 disponible
38
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
- model = model.to(device)
40
-
41
- return model, image_processor, class_info, device
42
 
43
  def download_image(url: str) -> Image.Image:
44
- """Descarga una imagen desde una URL"""
45
- response = requests.get(url, timeout=10)
46
- response.raise_for_status()
47
- image = Image.open(BytesIO(response.content)).convert('RGB')
48
- return image
49
 
50
  def classify_image(model, image_processor, class_info, device, accuracy):
51
 
52
- # Descargar y procesar imagen
53
- image = download_image(image_url)
54
- processed_image = image_processor(image).unsqueeze(0).to(device)
55
-
56
- # Realizar predicci贸n
57
- with torch.no_grad():
58
- outputs = model(pixel_values=processed_image).logits
59
- probabilities = torch.sigmoid(outputs).cpu().numpy()[0]
60
-
61
- # Obtener clases predichas (umbral 0.5)
62
- predicted_classes = []
63
- for i, prob in enumerate(probabilities):
64
- if prob > accuracy:
65
- class_name = class_info['class_columns'][i]
66
- predicted_classes.append(f"{class_name}: {prob:.3f}")
67
-
68
- # Mostrar resultado
69
- if predicted_classes:
70
- for prediction in predicted_classes:
71
- print(prediction)
72
- return predicted_classes
73
- else:
74
- # Si ninguna clase supera 0.5, mostrar la m谩s probable
75
- max_idx = probabilities.argmax()
76
- max_prob = probabilities[max_idx]
77
- class_name = class_info['class_columns'][max_idx]
78
- print(f"{class_name}: {max_prob:.3f}")
79
- return [class_name, max_prob]
80
 
81
  class EndpointHandler():
82
- def __init__(self, path=""):
83
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
- model_filename = "vit_multiclass_model_best/model.safetensors"
85
- local_path = hf_hub_download(repo_id="Drazcat-AI/categories_peru", filename=model_filename)
86
- print('MODEL PATH:',local_path)
87
- self.model, self.image_processor, self.class_info, self.device = load_model_and_config(local_path)
 
 
88
 
89
- def predict_objects(self, image_url, accuracy):
90
-
91
- result_df = classify_image(image_url, accuracy)
92
- return result_df
93
 
94
- def __call__(self, event):
95
- if "inputs" not in event:
96
- return {
97
- "statusCode": 400,
98
- "body": json.dumps("Error: Please provide an 'inputs' parameter."),
99
- }
100
- event = event["inputs"]
101
- image_url = event["image_url"]
102
- accuracy = event["accuracy"]
103
- try:
104
- predictions = self.predict_objects(image_url, accuracy)
105
- predictions_json = predictions.to_json(orient='records')
106
-
107
- return {
108
- "statusCode": 200,
109
- "body": json.dumps(predictions_json),
110
- }
111
- except Exception as e:
112
- return {
113
- "statusCode": 500,
114
- "body": json.dumps(f"Error: {str(e)}"),
115
- }
 
12
  from train_categories import PaddingImageProcessor
13
 
14
  def load_model_and_config(model_path):
15
+ """Carga el modelo entrenado y su configuraci贸n"""
16
+ hf_path = "vit_multiclass_model_best"
17
+ # Cargar informaci贸n de las clases
18
+ class_info_path = os.path.join(model_path, 'class_info.json')
19
+ with open(class_info_path, 'r') as f:
20
+ class_info = json.load(f)
21
+
22
+ # Cargar configuraci贸n del procesador
23
+ processor_config_path = os.path.join(model_path, 'processor_config.json')
24
+ with open(processor_config_path, 'r') as f:
25
+ processor_config = json.load(f)
26
+
27
+ # Crear procesador de im谩genes
28
+ image_processor = PaddingImageProcessor(
29
+ target_size=processor_config['target_size'],
30
+ padding_color=tuple(processor_config['padding_color'])
31
+ )
32
+
33
+ # Cargar modelo
34
+ model = ViTForImageClassification.from_pretrained(model_path)
35
+ model.eval()
36
+
37
+ # Usar GPU si est谩 disponible
38
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ model = model.to(device)
40
+
41
+ return model, image_processor, class_info, device
42
 
43
  def download_image(url: str) -> Image.Image:
44
+ """Descarga una imagen desde una URL"""
45
+ response = requests.get(url, timeout=10)
46
+ response.raise_for_status()
47
+ image = Image.open(BytesIO(response.content)).convert('RGB')
48
+ return image
49
 
50
  def classify_image(model, image_processor, class_info, device, accuracy):
51
 
52
+ # Descargar y procesar imagen
53
+ image = download_image(image_url)
54
+ processed_image = image_processor(image).unsqueeze(0).to(device)
55
+
56
+ # Realizar predicci贸n
57
+ with torch.no_grad():
58
+ outputs = model(pixel_values=processed_image).logits
59
+ probabilities = torch.sigmoid(outputs).cpu().numpy()[0]
60
+
61
+ # Obtener clases predichas (umbral 0.5)
62
+ predicted_classes = []
63
+ for i, prob in enumerate(probabilities):
64
+ if prob > accuracy:
65
+ class_name = class_info['class_columns'][i]
66
+ predicted_classes.append(f"{class_name}: {prob:.3f}")
67
+
68
+ # Mostrar resultado
69
+ if predicted_classes:
70
+ for prediction in predicted_classes:
71
+ print(prediction)
72
+ return predicted_classes
73
+ else:
74
+ # Si ninguna clase supera 0.5, mostrar la m谩s probable
75
+ max_idx = probabilities.argmax()
76
+ max_prob = probabilities[max_idx]
77
+ class_name = class_info['class_columns'][max_idx]
78
+ print(f"{class_name}: {max_prob:.3f}")
79
+ return [class_name, max_prob]
80
 
81
  class EndpointHandler():
82
+ def __init__(self, path=""):
83
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
+ model_filename = "vit_multiclass_model_best/model.safetensors"
85
+ local_path = hf_hub_download(repo_id="Drazcat-AI/categories_peru", filename=model_filename)
86
+ hf_hub_download(repo_id="Drazcat-AI/categories_peru", filename='class_info.json')
87
+ hf_hub_download(repo_id="Drazcat-AI/categories_peru", filename='processor_config.json')
88
+ print('MODEL PATH:',local_path)
89
+ self.model, self.image_processor, self.class_info, self.device = load_model_and_config(local_path)
90
 
91
+ def predict_objects(self, image_url, accuracy):
92
+
93
+ result_df = classify_image(image_url, accuracy)
94
+ return result_df
95
 
96
+ def __call__(self, event):
97
+ if "inputs" not in event:
98
+ return {
99
+ "statusCode": 400,
100
+ "body": json.dumps("Error: Please provide an 'inputs' parameter."),
101
+ }
102
+ event = event["inputs"]
103
+ image_url = event["image_url"]
104
+ accuracy = event["accuracy"]
105
+ try:
106
+ predictions = self.predict_objects(image_url, accuracy)
107
+ predictions_json = predictions.to_json(orient='records')
108
+
109
+ return {
110
+ "statusCode": 200,
111
+ "body": json.dumps(predictions_json),
112
+ }
113
+ except Exception as e:
114
+ return {
115
+ "statusCode": 500,
116
+ "body": json.dumps(f"Error: {str(e)}"),
117
+ }