# verifai-handler-v2 / handler.py
# Commit 2e640a4 (verified) by TerenceG: "Update handler.py"
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import base64
import io
import torch.nn.functional as F
import gc
import json
class EndpointHandler:
    """Endpoint handler that labels an image as human-made or AI-generated.

    The handler loads the image-classification model named by ``model_name``
    once at construction time and is then called with one request payload at
    a time. Every call returns a plain dict whose ``"status"`` key is either
    ``"success"`` or ``"error"``; ``__call__`` itself never raises for bad
    input.

    Accepted payload shapes (possibly nested inside each other): a base64
    string (with or without a ``data:`` URI prefix), raw or base64 ``bytes``,
    an already decoded image object exposing ``convert``, a dict holding the
    image under ``inputs``/``image``/``data``/``input``/``content`` (or, as a
    last resort, its first non-empty value), or a list/tuple whose first
    element is the image.
    """

    VERSION = "5.0-ultra-robust"
    HANDLER_NAME = "VerifAI Handler V5 ULTRA ROBUST"

    # Dict keys probed for the image payload, in priority order.
    _IMAGE_KEYS = ("inputs", "image", "data", "input", "content")
    # Bound on container unwrapping so self-referential payloads terminate.
    _MAX_UNWRAP_DEPTH = 4
    # Images with more pixels than this are shrunk before inference.
    _MAX_PIXELS = 1048576
    _DOWNSCALED_SIZE = (512, 512)
    # Keywords matched as substrings of the lower-cased label.
    _HUMAN_KEYWORDS = ("real", "human", "authentic")
    _AI_KEYWORDS = ("fake", "generated", "artificial")

    def __init__(self, path=""):
        """Load the processor and the model; leave both ``None`` on failure.

        Args:
            path: Accepted for signature compatibility with the caller; it is
                not used because the model is always fetched by
                ``model_name``.
        """
        print("🚀 VerifAI Handler V5 ULTRA ROBUST - Initialisation")
        print("⚡ Version ultra-robuste")
        self.model = None
        self.processor = None
        self.model_labels = {}
        self.model_name = "haywoodsloan/ai-image-detector-deploy"
        try:
            print("🔄 Chargement modèle...")
            self.processor = AutoImageProcessor.from_pretrained(self.model_name)
            self.model = AutoModelForImageClassification.from_pretrained(
                self.model_name,
                torch_dtype=torch.float32
            )
            self.model.eval()
            # Fall back to a generic mapping when the config carries no
            # (or an empty) id2label table.
            labels = getattr(self.model.config, 'id2label', None)
            self.model_labels = labels if labels else {0: "Real", 1: "Fake"}
            print("✅ Modèle chargé")
            print(f"📋 Labels: {self.model_labels}")
            print("🎯 Handler V5 prêt!")
        except Exception as e:
            # A failed load must not crash the process: __call__ reports
            # the handler as uninitialised instead.
            print(f"❌ Erreur: {e}")
            self.model = None
            self.processor = None

    @staticmethod
    def _is_empty(value):
        """Return True for ``None`` and for zero-length builtin containers."""
        if value is None:
            return True
        sized = (str, bytes, bytearray, list, tuple, dict)
        return isinstance(value, sized) and len(value) == 0

    def _pick_from_dict(self, payload):
        """Return the most likely image entry of a dict payload, or ``None``."""
        for key in self._IMAGE_KEYS:
            value = payload.get(key)
            if not self._is_empty(value):
                return value
        # No known key carries data: fall back to the first non-empty value.
        for value in payload.values():
            if not self._is_empty(value):
                print(f"🔍 Utilisation de la première valeur: {type(value)}")
                return value
        return None

    def _extract_image_data(self, data):
        """Unwrap a request payload down to the object carrying the image.

        Args:
            data: The raw payload (see the class docstring for the shapes).

        Returns:
            A ``str``, ``bytes`` or image-like object, or ``None`` when no
            image data could be located. Unknown leaf types are returned as
            ``str(data)``.
        """
        try:
            candidate = data
            for _ in range(self._MAX_UNWRAP_DEPTH):
                if isinstance(candidate, dict):
                    print("📄 Input détecté: dictionnaire")
                    candidate = self._pick_from_dict(candidate)
                elif isinstance(candidate, (list, tuple)):
                    print("📄 Input détecté: liste")
                    candidate = candidate[0] if candidate else None
                else:
                    break
            if isinstance(candidate, (dict, list, tuple)):
                # Still a container after the depth budget: give up.
                return None
            if candidate is None:
                return None
            if isinstance(candidate, str):
                print("📄 Input détecté: string directe")
                return candidate
            print(f"📄 Input détecté: {type(candidate)}")
            if isinstance(candidate, (bytes, bytearray, memoryview)):
                # Keep binary payloads binary; str() would yield "b'...'".
                return bytes(candidate)
            if callable(getattr(candidate, "convert", None)):
                # Duck-typed image object: pass it through untouched.
                return candidate
            return str(candidate)
        except Exception as e:
            print(f"⚠️ Erreur extraction: {e}")
            return None

    def _normalize_label(self, label):
        """Map a raw model label onto "Human", "AI Generated" or "Unknown".

        Long keywords are matched as substrings; the two-letter keyword
        ``ai`` is only matched as a whole token so that labels such as
        "painting" are not mistaken for AI-generated content.
        """
        if not isinstance(label, str):
            label = str(label)
        label_lower = label.lower()
        if any(word in label_lower for word in self._HUMAN_KEYWORDS):
            return "Human"
        if any(word in label_lower for word in self._AI_KEYWORDS):
            return "AI Generated"
        # Split on every non-alphanumeric character to obtain whole tokens.
        tokens = "".join(
            ch if ch.isalnum() else " " for ch in label_lower
        ).split()
        if "ai" in tokens:
            return "AI Generated"
        return "Unknown"

    def _label_for(self, class_id):
        """Return the raw label of ``class_id`` (int or str keyed table)."""
        fallback = f"Class_{class_id}"
        by_text = self.model_labels.get(str(class_id), fallback)
        return self.model_labels.get(class_id, by_text)

    def _class_probabilities(self, probabilities):
        """Sum per-class probabilities under their normalised label.

        Args:
            probabilities: Iterable of floats indexed by class id.

        Returns:
            A dict that always contains "Human" and "AI Generated"; classes
            whose label normalises to "Unknown" are left out.
        """
        totals = {}
        for class_id, prob in enumerate(probabilities):
            name = self._normalize_label(self._label_for(class_id))
            if name != "Unknown":
                # Accumulate: several raw labels may share one class.
                totals[name] = totals.get(name, 0.0) + float(prob)
        totals.setdefault("Human", 0.0)
        totals.setdefault("AI Generated", 0.0)
        return totals

    @staticmethod
    def _decode_base64(text):
        """Decode a base64 string, tolerating a data URI prefix and whitespace.

        Raises:
            ValueError: If a ``data:`` URI has no comma, or the payload is
                not valid base64.
        """
        if text.startswith('data:'):
            _, separator, text = text.partition(',')
            if not separator:
                raise ValueError("URI data sans séparateur ','")
        return base64.b64decode("".join(text.split()))

    def _load_image(self, image_data):
        """Turn extracted image data into an opened image object.

        Raises:
            ValueError: If the data cannot be decoded into an image.
        """
        if not isinstance(image_data, (str, bytes, bytearray)):
            if callable(getattr(image_data, "convert", None)):
                return image_data
            raise ValueError(f"Type d'entrée non supporté: {type(image_data)}")
        try:
            if isinstance(image_data, str):
                image = Image.open(io.BytesIO(self._decode_base64(image_data)))
            else:
                raw = bytes(image_data)
                try:
                    image = Image.open(io.BytesIO(raw))
                except OSError:
                    # Not an image container: assume base64-encoded bytes.
                    decoded = self._decode_base64(raw.decode("ascii"))
                    image = Image.open(io.BytesIO(decoded))
        except Exception as e:
            raise ValueError(f"Erreur décodage base64: {e}") from e
        print(f"✅ Image décodée: {image.size}, mode: {image.mode}")
        return image

    @staticmethod
    def _preview(data, limit=100):
        """Return a short, never-raising textual preview of ``data``."""
        try:
            return str(data)[:limit] + "..." if data else "None"
        except Exception:
            # Truthiness or str() may fail on array-like objects.
            return f"<{type(data).__name__}>"

    def _error_response(self, message, data=None, with_debug=False):
        """Build the error payload shared by every failure path."""
        response = {
            "status": "error",
            "error": message,
            "prediction": 0,
            "predicted_class_name": "Error",
            "confidence": 0.0,
            "class_probabilities": {"Human": 0.0, "AI Generated": 0.0},
            "cam_image": None,
            "version": self.VERSION,
            "handler_name": self.HANDLER_NAME,
        }
        if with_debug:
            response["debug_info"] = {
                "input_type": str(type(data)),
                "input_content": self._preview(data),
            }
        return response

    @staticmethod
    def _release_memory():
        """Free cached accelerator memory and run a garbage collection."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def __call__(self, data):
        """Classify the image contained in ``data``.

        Args:
            data: Request payload (see the class docstring).

        Returns:
            A dict with ``"status"`` set to ``"success"`` or ``"error"``.
            On success it carries the prediction id (1 for "AI Generated",
            else 0), the predicted class name, its confidence, the per-class
            probabilities and diagnostic metadata. On error it carries the
            message and, for request failures, a short preview of the input.
        """
        if self.model is None or self.processor is None:
            return self._error_response("Handler non initialisé")
        try:
            print("🔄 Traitement ultra-robuste...")
            print(f"🔍 Type d'input reçu: {type(data)}")
            image_data = self._extract_image_data(data)
            if self._is_empty(image_data):
                raise ValueError("Aucune donnée image trouvée")
            if isinstance(image_data, (str, bytes)):
                length = len(image_data)
            else:
                length = len(str(image_data))
            print(f"🔍 Données extraites: {type(image_data)}, longueur: {length}")

            image = self._load_image(image_data)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            if image.size[0] * image.size[1] > self._MAX_PIXELS:
                image = image.resize(self._DOWNSCALED_SIZE, Image.Resampling.LANCZOS)
                print("⚠️ Image redimensionnée")

            print("🧠 Inférence...")
            inputs = self.processor(image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.model(**inputs)
            # Only the first batch row is scored: one image per request.
            probabilities = F.softmax(outputs.logits, dim=-1)[0]
            predicted_class_id = int(probabilities.argmax().item())

            class_probs = self._class_probabilities(probabilities.tolist())
            # The predicted class follows the raw argmax so that it stays
            # consistent with "raw_prediction_id" in the response.
            prediction_label = self._normalize_label(
                self._label_for(predicted_class_id)
            )
            confidence = class_probs.get(prediction_label, 0.0)
            prediction_id = 1 if prediction_label == "AI Generated" else 0
            print(f"🎯 Résultat: {prediction_label} ({confidence:.3f})")

            return {
                "status": "success",
                "prediction": prediction_id,
                "predicted_class_name": prediction_label,
                "confidence": confidence,
                "class_probabilities": class_probs,
                "cam_image": None,
                "model_info": {
                    "model_name": self.model_name,
                    "handler_version": "verifai-v5-ultra-robust",
                    "precision_mode": "fast",
                    "raw_prediction_id": predicted_class_id,
                    "raw_labels": self.model_labels
                },
                "version": self.VERSION,
                "handler_name": self.HANDLER_NAME,
                "note": "Version ultra-robuste - gère tous les formats d'entrée",
                "input_analysis": {
                    "original_type": str(type(data)),
                    "extracted_type": str(type(image_data)),
                    "image_size": image.size,
                    "image_mode": image.mode
                }
            }
        except Exception as e:
            print(f"❌ Erreur: {e}")
            return self._error_response(str(e), data, with_debug=True)
        finally:
            # Runs on success and failure alike so a failed inference
            # does not keep its buffers alive until the next request.
            self._release_memory()
# Smoke test: run the handler once per supported payload shape.
if __name__ == "__main__":
    print("🧪 TEST HANDLER V5 ULTRA ROBUST")
    print("=" * 50)
    try:
        handler = EndpointHandler()
        if handler.model is None:
            print("❌ Échec initialisation")
        else:
            print("✅ Initialisation OK")
            # Encode a small solid-red JPEG as the shared test payload.
            jpeg_buffer = io.BytesIO()
            Image.new('RGB', (224, 224), color='red').save(jpeg_buffer, format='JPEG')
            encoded = base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')
            # Same image wrapped as a dict, a bare string and a list.
            payloads = ({"inputs": encoded}, encoded, [encoded])
            for index, payload in enumerate(payloads, 1):
                print(f"\n🔄 Test {index}: {type(payload)}")
                outcome = handler(payload)
                print(f"📊 Statut: {outcome['status']}")
                if outcome['status'] != 'success':
                    print(f"❌ Erreur: {outcome.get('error', 'Inconnue')}")
                    continue
                print(f"🎯 Prédiction: {outcome['predicted_class_name']} ({outcome['confidence']:.3f})")
            print("\n✅ Handler V5 ULTRA ROBUST testé!")
    except Exception as e:
        # Keep the script best-effort: report the failure, do not raise.
        print(f"❌ Erreur test: {e}")