TerenceG commited on
Commit
424c6c7
·
verified ·
1 Parent(s): 9ad2cff

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +320 -0
handler.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import base64
6
+ import io
7
+ import numpy as np
8
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
9
+ from pytorch_grad_cam import GradCAM
10
+ from pytorch_grad_cam.utils.image import show_cam_on_image
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ class EndpointHandler:
15
+ def __init__(self, model_dir: str = "haywoodsloan/ai-image-detector-deploy", **kwargs: Any):
16
+ """
17
+ Initialise le handler avec le modèle haywoodsloan/ai-image-detector-deploy
18
+ et configure Grad-CAM pour les cartes de saillance.
19
+ """
20
+ print(f"Initialisation du handler avec le modèle : {model_dir}")
21
+
22
+ # Charger le modèle et le processeur
23
+ self.model = AutoModelForImageClassification.from_pretrained(model_dir)
24
+ self.model.eval()
25
+ self.processor = AutoImageProcessor.from_pretrained(model_dir)
26
+
27
+ # Configuration de Grad-CAM - Adaptation pour différentes architectures
28
+ # Le modèle haywoodsloan/ai-image-detector-deploy est généralement basé sur ViT ou CNN
29
+ self.target_layer = self._find_target_layer()
30
+
31
+ self.cam = GradCAM(
32
+ model=self.model,
33
+ target_layers=[self.target_layer]
34
+ )
35
+
36
+ # Mapping des classes pour le détecteur d'IA
37
+ # Configuration basée sur la structure du modèle haywoodsloan
38
+ self.class_names = {
39
+ 0: "Image Réelle",
40
+ 1: "Image Générée par IA"
41
+ }
42
+
43
+ # Seuils de confiance pour l'interprétation
44
+ self.confidence_thresholds = {
45
+ "très_élevée": 0.9,
46
+ "élevée": 0.75,
47
+ "moyenne": 0.6,
48
+ "faible": 0.4
49
+ }
50
+
51
+ print("Handler initialisé avec succès!")
52
+
53
+ def _find_target_layer(self):
54
+ """
55
+ Trouve automatiquement la couche cible appropriée pour Grad-CAM
56
+ selon l'architecture du modèle.
57
+ """
58
+ try:
59
+ # Pour les modèles Vision Transformer (ViT)
60
+ if hasattr(self.model, 'vit'):
61
+ if hasattr(self.model.vit, 'encoder'):
62
+ return self.model.vit.encoder.layer[-1].layernorm_before
63
+ elif hasattr(self.model.vit, 'layers'):
64
+ return self.model.vit.layers[-1].norm1
65
+
66
+ # Pour les modèles Swin Transformer
67
+ elif hasattr(self.model, 'swin'):
68
+ return self.model.swin.encoder.layers[-1].blocks[-1].layernorm_before
69
+
70
+ # Pour les modèles avec backbone
71
+ elif hasattr(self.model, 'backbone'):
72
+ if hasattr(self.model.backbone, 'layers'):
73
+ return self.model.backbone.layers[-1].blocks[-1].norm1
74
+ else:
75
+ # Fallback pour backbone CNN
76
+ return list(self.model.backbone.children())[-2]
77
+
78
+ # Pour les modèles ConvNeXt
79
+ elif hasattr(self.model, 'convnext'):
80
+ return self.model.convnext.encoder.stages[-1].layers[-1].layernorm
81
+
82
+ # Pour les modèles ResNet ou autres architectures CNN
83
+ elif hasattr(self.model, 'resnet'):
84
+ return self.model.resnet.layer4[-1].bn2
85
+
86
+ # Fallback générique - chercher la dernière couche de normalisation
87
+ else:
88
+ # Parcourir tous les modules pour trouver une couche appropriée
89
+ modules = list(self.model.named_modules())
90
+ for name, module in reversed(modules):
91
+ if any(layer_type in name.lower() for layer_type in ['layernorm', 'batchnorm', 'norm']):
92
+ if 'classifier' not in name.lower():
93
+ print(f"Couche cible trouvée : {name}")
94
+ return module
95
+
96
+ # Si aucune couche de normalisation trouvée, utiliser l'avant-dernière couche
97
+ children = list(self.model.children())
98
+ if len(children) > 1:
99
+ return children[-2]
100
+ else:
101
+ return children[-1]
102
+
103
+ except Exception as e:
104
+ print(f"Erreur lors de la recherche de la couche cible: {e}")
105
+ # Fallback final - utiliser la première couche trouvée
106
+ children = list(self.model.children())
107
+ return children[-2] if len(children) > 1 else children[0]
108
+
109
+ def _interpret_confidence(self, confidence: float, predicted_class: str) -> str:
110
+ """
111
+ Interprète le niveau de confiance et génère un message explicatif.
112
+ """
113
+ if confidence >= self.confidence_thresholds["très_élevée"]:
114
+ level = "très élevée"
115
+ reliability = "Très fiable"
116
+ elif confidence >= self.confidence_thresholds["élevée"]:
117
+ level = "élevée"
118
+ reliability = "Fiable"
119
+ elif confidence >= self.confidence_thresholds["moyenne"]:
120
+ level = "moyenne"
121
+ reliability = "Moyennement fiable"
122
+ else:
123
+ level = "faible"
124
+ reliability = "Peu fiable"
125
+
126
+ interpretation = f"Confiance {level} ({confidence:.1%}) - {reliability}. "
127
+
128
+ if predicted_class == "Image Générée par IA":
129
+ if confidence >= 0.8:
130
+ interpretation += "L'image présente des caractéristiques typiques d'une génération par IA."
131
+ elif confidence >= 0.6:
132
+ interpretation += "L'image pourrait être générée par IA, mais nécessite une vérification supplémentaire."
133
+ else:
134
+ interpretation += "Classification incertaine - analyse manuelle recommandée."
135
+ else:
136
+ if confidence >= 0.8:
137
+ interpretation += "L'image semble authentique avec des caractéristiques naturelles."
138
+ elif confidence >= 0.6:
139
+ interpretation += "L'image semble réelle, mais avec quelques éléments à vérifier."
140
+ else:
141
+ interpretation += "Classification incertaine - analyse manuelle recommandée."
142
+
143
+ return interpretation
144
+
145
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
146
+ """
147
+ Traite une image et retourne la prédiction avec la carte de saillance.
148
+
149
+ Args:
150
+ data: Dictionnaire contenant l'image encodée en base64
151
+
152
+ Returns:
153
+ Dictionnaire avec la prédiction, confiance et carte de saillance
154
+ """
155
+ try:
156
+ print("Début du traitement de l'image...")
157
+
158
+ # Décoder l'image depuis une chaîne base64
159
+ if isinstance(data["inputs"], str):
160
+ image_data = base64.b64decode(data["inputs"])
161
+ else:
162
+ # Si c'est déjà des bytes
163
+ image_data = data["inputs"]
164
+
165
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
166
+ print(f"Image chargée avec succès : {image.size}")
167
+
168
+ # Préprocesser l'image pour le modèle
169
+ inputs = self.processor(images=image, return_tensors="pt")
170
+ input_tensor = inputs["pixel_values"]
171
+
172
+ print("Génération de la carte de saillance Grad-CAM...")
173
+ # Générer la carte de saillance avec Grad-CAM
174
+ try:
175
+ # Correction spécifique pour Swin Transformer v2
176
+ # Créer une classe wrapper pour le modèle
177
+ import torch.nn as nn
178
+
179
+ class ModelWrapper(nn.Module):
180
+ def __init__(self, model):
181
+ super().__init__()
182
+ self.model = model
183
+
184
+ def forward(self, x):
185
+ outputs = self.model(x)
186
+ # Extraire les logits des outputs de Swin v2
187
+ return outputs.logits
188
+
189
+ # Créer le wrapper
190
+ wrapped_model = ModelWrapper(self.model)
191
+
192
+ # Créer un nouveau GradCAM avec le modèle wrapper
193
+ from pytorch_grad_cam import GradCAM
194
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
195
+
196
+ wrapped_cam = GradCAM(
197
+ model=wrapped_model,
198
+ target_layers=[self.target_layer]
199
+ )
200
+
201
+ # Obtenir la classe prédite
202
+ with torch.no_grad():
203
+ outputs = self.model(input_tensor)
204
+ predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
205
+
206
+ # Créer le target pour GradCAM
207
+ targets = [ClassifierOutputTarget(predicted_class_idx)]
208
+
209
+ # Générer la carte de saillance
210
+ grayscale_cam = wrapped_cam(input_tensor=input_tensor, targets=targets)[0]
211
+
212
+ # Redimensionner l'image originale pour correspondre à la carte de saillance
213
+ cam_height, cam_width = grayscale_cam.shape
214
+ image_resized = image.resize((cam_width, cam_height))
215
+ image_np = np.array(image_resized).astype(np.float32) / 255.0
216
+
217
+ # Superposer la carte de chaleur sur l'image
218
+ visualization = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)
219
+
220
+ # Convertir la carte de chaleur en base64
221
+ buffered = io.BytesIO()
222
+ Image.fromarray(visualization).save(buffered, format="PNG")
223
+ cam_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
224
+
225
+ except Exception as e:
226
+ print(f"Erreur lors de la génération de Grad-CAM: {e}")
227
+ cam_image_base64 = None
228
+
229
+ print("Exécution de la prédiction...")
230
+ # Obtenir la prédiction du modèle
231
+ with torch.no_grad():
232
+ outputs = self.model(**inputs)
233
+ logits = outputs.logits
234
+ probabilities = torch.nn.functional.softmax(logits, dim=1)
235
+ predicted_class = torch.argmax(probabilities, dim=1).item()
236
+ confidence = probabilities[0][predicted_class].item()
237
+
238
+ # Calculer le score de confiance pour chaque classe
239
+ class_probabilities = {}
240
+ for i, prob in enumerate(probabilities[0].tolist()):
241
+ class_name = self.class_names.get(i, f"Classe {i}")
242
+ class_probabilities[class_name] = round(prob, 4)
243
+
244
+ # Générer l'interprétation
245
+ predicted_class_name = self.class_names.get(predicted_class, f"Classe {predicted_class}")
246
+ interpretation = self._interpret_confidence(confidence, predicted_class_name)
247
+
248
+ # Score de détection d'IA (probabilité que ce soit une IA)
249
+ ai_detection_score = probabilities[0][1].item() if len(probabilities[0]) > 1 else 0.0
250
+
251
+ result = {
252
+ "prediction": predicted_class,
253
+ "predicted_class_name": predicted_class_name,
254
+ "confidence": round(confidence, 4),
255
+ "ai_detection_score": round(ai_detection_score, 4),
256
+ "class_probabilities": class_probabilities,
257
+ "interpretation": interpretation,
258
+ "status": "success",
259
+ "model_used": "haywoodsloan/ai-image-detector-deploy"
260
+ }
261
+
262
+ # Ajouter la carte Grad-CAM si disponible
263
+ if cam_image_base64:
264
+ result["cam_image"] = cam_image_base64
265
+ result["grad_cam_available"] = True
266
+ else:
267
+ result["grad_cam_available"] = False
268
+ result["grad_cam_error"] = "Impossible de générer la carte de saillance"
269
+
270
+ print(f"Traitement terminé avec succès! Prédiction: {predicted_class_name}, Confiance: {confidence:.2%}")
271
+ return result
272
+
273
+ except Exception as e:
274
+ print(f"Erreur lors du traitement: {e}")
275
+ return {
276
+ "error": str(e),
277
+ "status": "error",
278
+ "model_used": "haywoodsloan/ai-image-detector-deploy"
279
+ }
280
+
281
+ # Test local du handler (optionnel)
282
+ if __name__ == "__main__":
283
+ import os
284
+
285
+ try:
286
+ print("Test d'initialisation du handler...")
287
+ handler = EndpointHandler()
288
+ print("Handler initialisé avec succès!")
289
+
290
+ # Test avec une image d'exemple
291
+ test_image_path = "test_image.jpg"
292
+ if os.path.exists(test_image_path):
293
+ print(f"Test avec l'image : {test_image_path}")
294
+ with open(test_image_path, "rb") as f:
295
+ image_bytes = f.read()
296
+
297
+ input_data = {"inputs": base64.b64encode(image_bytes).decode("utf-8")}
298
+ output = handler(input_data)
299
+
300
+ print("\n=== RÉSULTATS DU TEST ===")
301
+ print(f"Statut: {output.get('status', 'N/A')}")
302
+ print(f"Prédiction: {output.get('predicted_class_name', 'N/A')}")
303
+ print(f"Confiance: {output.get('confidence', 0):.2%}")
304
+ print(f"Score de détection IA: {output.get('ai_detection_score', 0):.2%}")
305
+ print(f"Grad-CAM disponible: {output.get('grad_cam_available', False)}")
306
+ print(f"Interprétation: {output.get('interpretation', 'N/A')}")
307
+
308
+ if 'class_probabilities' in output:
309
+ print("\nProbabilités par classe:")
310
+ for class_name, prob in output['class_probabilities'].items():
311
+ print(f" {class_name}: {prob:.2%}")
312
+ else:
313
+ print(f"Aucune image de test trouvée : {test_image_path}")
314
+ print("Placez une image de test dans le répertoire pour tester le handler.")
315
+ print("Vous pouvez utiliser n'importe quel format d'image (JPG, PNG, etc.)")
316
+
317
+ except Exception as e:
318
+ print(f"Erreur lors de l'initialisation ou du test: {e}")
319
+ import traceback
320
+ traceback.print_exc()