# Inference endpoint handler: real-vs-AI image classification with Grad-CAM saliency maps.
import base64
import io
import warnings
from typing import Any, Dict

import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from transformers import AutoModelForImageClassification, AutoImageProcessor

# Install a process-wide "ignore" filter: with no category or module given,
# this silences every Python warning, including deprecation notices raised by
# the libraries imported above.
warnings.filterwarnings("ignore")


class EndpointHandler:
    """Classify an image as real or AI-generated and explain the decision.

    The handler loads the ``haywoodsloan/ai-image-detector-deploy`` checkpoint
    once. Each call returns the predicted class, the per-class probabilities,
    a French-language interpretation of the confidence and, when it can be
    computed, a base64-encoded PNG of the Grad-CAM heat map blended over the
    input image.
    """

    # Hub identifier loaded by ``__init__`` and echoed in every response.
    MODEL_NAME = "haywoodsloan/ai-image-detector-deploy"

    def __init__(self, model_dir: str = "haywoodsloan/ai-image-detector-deploy", **kwargs: Any):
        """Load the model and processor and prepare Grad-CAM.

        Args:
            model_dir: Deployment directory passed in by the serving runtime.
                It is only logged; the weights are always loaded from
                ``MODEL_NAME``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        print(f"Initialisation du handler avec le modèle : {self.MODEL_NAME}")
        print(f"Répertoire de déploiement : {model_dir}")

        self.model = AutoModelForImageClassification.from_pretrained(self.MODEL_NAME)
        self.model.eval()
        self.processor = AutoImageProcessor.from_pretrained(self.MODEL_NAME)

        # One Grad-CAM instance is built here and reused by every request, so
        # its hooks are registered on the target layer exactly once.
        self.target_layer = self._find_target_layer()
        self.cam = self._build_cam()

        # NOTE(review): this mapping is hard-coded and never checked against
        # ``self.model.config.id2label`` - confirm that index 1 really is the
        # AI-generated class for this checkpoint.
        self.class_names = {
            0: "Image Réelle",
            1: "Image Générée par IA",
        }

        # Lower bounds of the confidence tiers used by ``_interpret_confidence``
        # (the "faible" entry is stored but not read there: it is the default).
        self.confidence_thresholds = {
            "très_élevée": 0.9,
            "élevée": 0.75,
            "moyenne": 0.6,
            "faible": 0.4,
        }

        print("Handler initialisé avec succès!")

    def _build_cam(self):
        """Create the Grad-CAM object that is reused for every request.

        Grad-CAM is attached to a thin wrapper rather than to ``self.model``
        because it needs the forward pass to return the logits tensor, while
        the classification model returns an output object carrying ``.logits``.
        """

        class _LogitsOnly(torch.nn.Module):
            """Adapter whose forward pass returns only the logits tensor."""

            def __init__(self, model):
                super().__init__()
                self.model = model

            def forward(self, x):
                return self.model(x).logits

        return GradCAM(
            model=_LogitsOnly(self.model),
            target_layers=[self.target_layer],
            reshape_transform=self._reshape_transform,
        )

    @staticmethod
    def _reshape_transform(tensor):
        """Turn token-shaped activations into a channels-first 2-D grid.

        A 3-D tensor is read as (batch, tokens, channels). If the token count
        is a perfect square it is laid out as a square grid; if it is a
        perfect square plus one, the first token is dropped first (assumed to
        be a class token). Tensors of any other rank, and token counts that
        fit neither pattern, are returned unchanged.
        """
        if tensor.dim() != 3:
            return tensor

        batch, tokens, channels = tensor.shape
        for skipped in (0, 1):
            remaining = tokens - skipped
            side = int(round(remaining ** 0.5))
            if side > 0 and side * side == remaining:
                grid = tensor[:, skipped:, :].reshape(batch, side, side, channels)
                return grid.permute(0, 3, 1, 2)
        return tensor

    def _find_target_layer(self):
        """Pick the layer whose activations Grad-CAM will use.

        Known architecture attributes (``vit``, ``swin``, ``backbone``,
        ``convnext``, ``resnet``) are tried first. When none yields a layer,
        the modules are scanned from last to first for a name containing
        "norm" outside the classifier. If that also fails, or any lookup
        raises, a child module of the model is returned as a last resort.
        """
        model = self.model
        try:
            if hasattr(model, "vit"):
                if hasattr(model.vit, "encoder"):
                    return model.vit.encoder.layer[-1].layernorm_before
                if hasattr(model.vit, "layers"):
                    return model.vit.layers[-1].norm1
                # Neither layout matched: fall through to the generic search
                # below instead of implicitly returning None.
            elif hasattr(model, "swin"):
                return model.swin.encoder.layers[-1].blocks[-1].layernorm_before
            elif hasattr(model, "backbone"):
                if hasattr(model.backbone, "layers"):
                    return model.backbone.layers[-1].blocks[-1].norm1
                return list(model.backbone.children())[-2]
            elif hasattr(model, "convnext"):
                return model.convnext.encoder.stages[-1].layers[-1].layernorm
            elif hasattr(model, "resnet"):
                return model.resnet.layer4[-1].bn2

            # Generic search: last module whose name mentions a norm layer
            # ("layernorm" and "batchnorm" both contain "norm").
            for name, module in reversed(list(model.named_modules())):
                lowered = name.lower()
                if "norm" in lowered and "classifier" not in lowered:
                    print(f"Couche cible trouvée : {name}")
                    return module
        except Exception as e:
            print(f"Erreur lors de la recherche de la couche cible: {e}")

        return self._fallback_layer()

    def _fallback_layer(self):
        """Return the second-to-last child of the model, or its only child."""
        children = list(self.model.children())
        return children[-2] if len(children) > 1 else children[-1]

    def _interpret_confidence(self, confidence: float, predicted_class: str) -> str:
        """Build the French explanation attached to a prediction.

        Args:
            confidence: Probability of the predicted class, between 0 and 1.
            predicted_class: Display name of the predicted class.

        Returns:
            A sentence giving the confidence tier and its reliability label,
            followed by advice that depends on the class and the confidence.
        """
        thresholds = self.confidence_thresholds
        tiers = (
            (thresholds["très_élevée"], "très élevée", "Très fiable"),
            (thresholds["élevée"], "élevée", "Fiable"),
            (thresholds["moyenne"], "moyenne", "Moyennement fiable"),
        )
        # Anything below the lowest tier is reported as low confidence.
        level, reliability = "faible", "Peu fiable"
        for minimum, tier_level, tier_reliability in tiers:
            if confidence >= minimum:
                level, reliability = tier_level, tier_reliability
                break

        interpretation = f"Confiance {level} ({confidence:.1%}) - {reliability}. "

        if predicted_class == "Image Générée par IA":
            confident = "L'image présente des caractéristiques typiques d'une génération par IA."
            plausible = "L'image pourrait être générée par IA, mais nécessite une vérification supplémentaire."
        else:
            confident = "L'image semble authentique avec des caractéristiques naturelles."
            plausible = "L'image semble réelle, mais avec quelques éléments à vérifier."

        if confidence >= 0.8:
            interpretation += confident
        elif confidence >= 0.6:
            interpretation += plausible
        else:
            interpretation += "Classification incertaine - analyse manuelle recommandée."

        return interpretation

    @staticmethod
    def _decode_payload(payload):
        """Return the raw image bytes carried by a request payload.

        Strings are treated as base64, with an optional ``data:...,`` URI
        prefix stripped first. Bytes-like objects are returned as ``bytes``.

        Raises:
            TypeError: If the payload is neither a string nor bytes-like.
        """
        if isinstance(payload, str):
            if payload.startswith("data:") and "," in payload:
                payload = payload.split(",", 1)[1]
            return base64.b64decode(payload)
        if isinstance(payload, (bytes, bytearray, memoryview)):
            return bytes(payload)
        raise TypeError(f"Unsupported 'inputs' type: {type(payload).__name__}")

    def _load_image(self, payload):
        """Return the request payload as an RGB PIL image.

        An already-decoded PIL image is accepted as is; anything else is
        decoded with ``_decode_payload`` and opened from memory.
        """
        if isinstance(payload, Image.Image):
            return payload.convert("RGB")
        return Image.open(io.BytesIO(self._decode_payload(payload))).convert("RGB")

    def _render_cam_overlay(self, image, input_tensor, class_index):
        """Return the Grad-CAM overlay as a base64 PNG, or None on failure.

        Saliency is best-effort: any error is logged and reported as ``None``
        so that the classification result can still be returned.
        """
        try:
            targets = [ClassifierOutputTarget(class_index)]
            grayscale_cam = self.cam(input_tensor=input_tensor, targets=targets)[0]

            # The overlay helper needs a float image in [0, 1] with the same
            # height and width as the heat map.
            cam_height, cam_width = grayscale_cam.shape
            resized = image.resize((cam_width, cam_height))
            image_np = np.array(resized).astype(np.float32) / 255.0

            visualization = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)

            buffered = io.BytesIO()
            Image.fromarray(visualization).save(buffered, format="PNG")
            return base64.b64encode(buffered.getvalue()).decode("utf-8")
        except Exception as e:
            print(f"Erreur lors de la génération de Grad-CAM: {e}")
            return None

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Classify one image and attach its saliency map.

        Args:
            data: Request dictionary whose ``"inputs"`` entry holds the image
                as a base64 string (optionally a data URI), raw bytes, or a
                PIL image.

        Returns:
            On success, a dictionary with the prediction, confidence,
            per-class probabilities, interpretation, ``"status": "success"``
            and either the Grad-CAM image or a flag saying it is unavailable.
            On any failure, a dictionary with ``"error"`` and
            ``"status": "error"``; no exception is propagated to the caller.
        """
        try:
            print("Début du traitement de l'image...")
            image = self._load_image(data["inputs"])
            print(f"Image chargée avec succès : {image.size}")

            inputs = self.processor(images=image, return_tensors="pt")

            print("Exécution de la prédiction...")
            with torch.no_grad():
                logits = self.model(**inputs).logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]
            predicted_class = int(torch.argmax(probabilities).item())
            confidence = probabilities[predicted_class].item()

            # The class predicted above is reused as the Grad-CAM target, so
            # no extra forward pass is needed to choose it.
            print("Génération de la carte de saillance Grad-CAM...")
            cam_image_base64 = self._render_cam_overlay(
                image, inputs["pixel_values"], predicted_class
            )

            class_probabilities = {
                self.class_names.get(index, f"Classe {index}"): round(prob, 4)
                for index, prob in enumerate(probabilities.tolist())
            }

            predicted_class_name = self.class_names.get(
                predicted_class, f"Classe {predicted_class}"
            )
            interpretation = self._interpret_confidence(confidence, predicted_class_name)

            # Probability of class index 1, reported whichever class won.
            ai_detection_score = probabilities[1].item() if len(probabilities) > 1 else 0.0

            result = {
                "prediction": predicted_class,
                "predicted_class_name": predicted_class_name,
                "confidence": round(confidence, 4),
                "ai_detection_score": round(ai_detection_score, 4),
                "class_probabilities": class_probabilities,
                "interpretation": interpretation,
                "status": "success",
                "model_used": self.MODEL_NAME,
            }

            if cam_image_base64:
                result["cam_image"] = cam_image_base64
                result["grad_cam_available"] = True
            else:
                result["grad_cam_available"] = False
                result["grad_cam_error"] = "Impossible de générer la carte de saillance"

            print(f"Traitement terminé avec succès! Prédiction: {predicted_class_name}, Confiance: {confidence:.2%}")
            return result

        except Exception as e:
            print(f"Erreur lors du traitement: {e}")
            return {
                "error": str(e),
                "status": "error",
                "model_used": self.MODEL_NAME,
            }


if __name__ == "__main__":
    # Manual smoke test: build the handler, then run it on ``test_image.jpg``
    # from the current directory if that file exists, and print the result.
    import os
    import traceback

    try:
        print("Test d'initialisation du handler...")
        handler = EndpointHandler()
        print("Handler initialisé avec succès!")

        test_image_path = "test_image.jpg"
        if os.path.exists(test_image_path):
            print(f"Test avec l'image : {test_image_path}")
            with open(test_image_path, "rb") as f:
                image_bytes = f.read()

            # Send the image the way a client would: as a base64 string.
            input_data = {"inputs": base64.b64encode(image_bytes).decode("utf-8")}
            output = handler(input_data)

            print("\n=== RÉSULTATS DU TEST ===")
            print(f"Statut: {output.get('status', 'N/A')}")
            print(f"Prédiction: {output.get('predicted_class_name', 'N/A')}")
            print(f"Confiance: {output.get('confidence', 0):.2%}")
            print(f"Score de détection IA: {output.get('ai_detection_score', 0):.2%}")
            print(f"Grad-CAM disponible: {output.get('grad_cam_available', False)}")
            print(f"Interprétation: {output.get('interpretation', 'N/A')}")

            if 'class_probabilities' in output:
                print("\nProbabilités par classe:")
                for class_name, prob in output['class_probabilities'].items():
                    print(f" {class_name}: {prob:.2%}")
        else:
            print(f"Aucune image de test trouvée : {test_image_path}")
            print("Placez une image de test dans le répertoire pour tester le handler.")
            print("Vous pouvez utiliser n'importe quel format d'image (JPG, PNG, etc.)")

    except Exception as e:
        print(f"Erreur lors de l'initialisation ou du test: {e}")
        traceback.print_exc()