Drazcat-AI commited on
Commit
32a154c
verified
1 Parent(s): 183ec56

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +114 -113
handler.py CHANGED
@@ -1,114 +1,115 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from PIL import Image
4
- import requests
5
- from io import BytesIO
6
- import json
7
- import os
8
- from transformers import ViTForImageClassification, ViTConfig
9
- from huggingface_hub import hf_hub_download
10
-
11
- # Importar el procesador de im谩genes del c贸digo de entrenamiento
12
- from train_categories import PaddingImageProcessor
13
-
14
- def load_model_and_config(model_path):
15
- """Carga el modelo entrenado y su configuraci贸n"""
16
- hf_path = "vit_multiclass_model_best"
17
- # Cargar informaci贸n de las clases
18
- class_info_path = os.path.join(hf_path, 'class_info.json')
19
- with open(class_info_path, 'r') as f:
20
- class_info = json.load(f)
21
-
22
- # Cargar configuraci贸n del procesador
23
- processor_config_path = os.path.join(hf_path, 'processor_config.json')
24
- with open(processor_config_path, 'r') as f:
25
- processor_config = json.load(f)
26
-
27
- # Crear procesador de im谩genes
28
- image_processor = PaddingImageProcessor(
29
- target_size=processor_config['target_size'],
30
- padding_color=tuple(processor_config['padding_color'])
31
- )
32
-
33
- # Cargar modelo
34
- model = ViTForImageClassification.from_pretrained(model_path)
35
- model.eval()
36
-
37
- # Usar GPU si est谩 disponible
38
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
- model = model.to(device)
40
-
41
- return model, image_processor, class_info, device
42
-
43
- def download_image(url: str) -> Image.Image:
44
- """Descarga una imagen desde una URL"""
45
- response = requests.get(url, timeout=10)
46
- response.raise_for_status()
47
- image = Image.open(BytesIO(response.content)).convert('RGB')
48
- return image
49
-
50
- def classify_image(model, image_processor, class_info, device, accuracy):
51
-
52
- # Descargar y procesar imagen
53
- image = download_image(image_url)
54
- processed_image = image_processor(image).unsqueeze(0).to(device)
55
-
56
- # Realizar predicci贸n
57
- with torch.no_grad():
58
- outputs = model(pixel_values=processed_image).logits
59
- probabilities = torch.sigmoid(outputs).cpu().numpy()[0]
60
-
61
- # Obtener clases predichas (umbral 0.5)
62
- predicted_classes = []
63
- for i, prob in enumerate(probabilities):
64
- if prob > accuracy:
65
- class_name = class_info['class_columns'][i]
66
- predicted_classes.append(f"{class_name}: {prob:.3f}")
67
-
68
- # Mostrar resultado
69
- if predicted_classes:
70
- for prediction in predicted_classes:
71
- print(prediction)
72
- return predicted_classes
73
- else:
74
- # Si ninguna clase supera 0.5, mostrar la m谩s probable
75
- max_idx = probabilities.argmax()
76
- max_prob = probabilities[max_idx]
77
- class_name = class_info['class_columns'][max_idx]
78
- print(f"{class_name}: {max_prob:.3f}")
79
- return [class_name, max_prob]
80
-
81
- class EndpointHandler():
82
- def __init__(self, path=""):
83
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
- model_filename = "vit_multiclass_model_best/model.safetensors"
85
- local_path = hf_hub_download(repo_id="Drazcat-AI/categories_peru", filename=model_filename)
86
- self.model, self.image_processor, self.class_info, self.device = load_model_and_config(local_path)
87
-
88
- def predict_objects(self, image_url, accuracy):
89
-
90
- result_df = classify_image(image_url, accuracy)
91
- return result_df
92
-
93
- def __call__(self, event):
94
- if "inputs" not in event:
95
- return {
96
- "statusCode": 400,
97
- "body": json.dumps("Error: Please provide an 'inputs' parameter."),
98
- }
99
- event = event["inputs"]
100
- image_url = event["image_url"]
101
- accuracy = event["accuracy"]
102
- try:
103
- predictions = self.predict_objects(image_url, accuracy)
104
- predictions_json = predictions.to_json(orient='records')
105
-
106
- return {
107
- "statusCode": 200,
108
- "body": json.dumps(predictions_json),
109
- }
110
- except Exception as e:
111
- return {
112
- "statusCode": 500,
113
- "body": json.dumps(f"Error: {str(e)}"),
 
114
  }
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import json
7
+ import os
8
+ from transformers import ViTForImageClassification, ViTConfig
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ # Importar el procesador de im谩genes del c贸digo de entrenamiento
12
+ from train_categories import PaddingImageProcessor
13
+
14
+ def load_model_and_config(model_path):
15
+ """Carga el modelo entrenado y su configuraci贸n"""
16
+ hf_path = "vit_multiclass_model_best"
17
+ # Cargar informaci贸n de las clases
18
+ class_info_path = os.path.join(hf_path, 'class_info.json')
19
+ with open(class_info_path, 'r') as f:
20
+ class_info = json.load(f)
21
+
22
+ # Cargar configuraci贸n del procesador
23
+ processor_config_path = os.path.join(hf_path, 'processor_config.json')
24
+ with open(processor_config_path, 'r') as f:
25
+ processor_config = json.load(f)
26
+
27
+ # Crear procesador de im谩genes
28
+ image_processor = PaddingImageProcessor(
29
+ target_size=processor_config['target_size'],
30
+ padding_color=tuple(processor_config['padding_color'])
31
+ )
32
+
33
+ # Cargar modelo
34
+ model = ViTForImageClassification.from_pretrained(model_path)
35
+ model.eval()
36
+
37
+ # Usar GPU si est谩 disponible
38
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ model = model.to(device)
40
+
41
+ return model, image_processor, class_info, device
42
+
43
+ def download_image(url: str) -> Image.Image:
44
+ """Descarga una imagen desde una URL"""
45
+ response = requests.get(url, timeout=10)
46
+ response.raise_for_status()
47
+ image = Image.open(BytesIO(response.content)).convert('RGB')
48
+ return image
49
+
50
+ def classify_image(model, image_processor, class_info, device, accuracy):
51
+
52
+ # Descargar y procesar imagen
53
+ image = download_image(image_url)
54
+ processed_image = image_processor(image).unsqueeze(0).to(device)
55
+
56
+ # Realizar predicci贸n
57
+ with torch.no_grad():
58
+ outputs = model(pixel_values=processed_image).logits
59
+ probabilities = torch.sigmoid(outputs).cpu().numpy()[0]
60
+
61
+ # Obtener clases predichas (umbral 0.5)
62
+ predicted_classes = []
63
+ for i, prob in enumerate(probabilities):
64
+ if prob > accuracy:
65
+ class_name = class_info['class_columns'][i]
66
+ predicted_classes.append(f"{class_name}: {prob:.3f}")
67
+
68
+ # Mostrar resultado
69
+ if predicted_classes:
70
+ for prediction in predicted_classes:
71
+ print(prediction)
72
+ return predicted_classes
73
+ else:
74
+ # Si ninguna clase supera 0.5, mostrar la m谩s probable
75
+ max_idx = probabilities.argmax()
76
+ max_prob = probabilities[max_idx]
77
+ class_name = class_info['class_columns'][max_idx]
78
+ print(f"{class_name}: {max_prob:.3f}")
79
+ return [class_name, max_prob]
80
+
81
+ class EndpointHandler():
82
+ def __init__(self, path=""):
83
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
84
+ model_filename = "vit_multiclass_model_best/model.safetensors"
85
+ local_path = hf_hub_download(repo_id="Drazcat-AI/categories_peru", filename=model_filename)
86
+ print('MODEL PATH:',local_path)
87
+ self.model, self.image_processor, self.class_info, self.device = load_model_and_config(local_path)
88
+
89
+ def predict_objects(self, image_url, accuracy):
90
+
91
+ result_df = classify_image(image_url, accuracy)
92
+ return result_df
93
+
94
+ def __call__(self, event):
95
+ if "inputs" not in event:
96
+ return {
97
+ "statusCode": 400,
98
+ "body": json.dumps("Error: Please provide an 'inputs' parameter."),
99
+ }
100
+ event = event["inputs"]
101
+ image_url = event["image_url"]
102
+ accuracy = event["accuracy"]
103
+ try:
104
+ predictions = self.predict_objects(image_url, accuracy)
105
+ predictions_json = predictions.to_json(orient='records')
106
+
107
+ return {
108
+ "statusCode": 200,
109
+ "body": json.dumps(predictions_json),
110
+ }
111
+ except Exception as e:
112
+ return {
113
+ "statusCode": 500,
114
+ "body": json.dumps(f"Error: {str(e)}"),
115
  }