TerenceG commited on
Commit
77fb0e6
·
verified ·
1 Parent(s): 83bfc80

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +0 -477
handler.py CHANGED
@@ -444,483 +444,6 @@ class EndpointHandler:
444
  except:
445
  pass
446
 
447
- # Test de fonctionnement si exécuté directement
448
- if __name__ == "__main__":
449
- print("🧪 TEST DU HANDLER VERIFAI V2 FIXED")
450
- print("=" * 50)
451
-
452
- try:
453
- # Initialisation
454
- handler = EndpointHandler()
455
-
456
- if handler.model is not None:
457
- print("✅ Initialisation réussie")
458
-
459
- # Test avec une image simple
460
- print("🔄 Test avec image de base...")
461
- test_img = Image.new('RGB', (224, 224), color='red')
462
- buffer = io.BytesIO()
463
- test_img.save(buffer, format='JPEG')
464
- test_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
465
-
466
- result = handler({"inputs": test_data})
467
- print(f"📊 Résultat: {result['status']}")
468
- if result['status'] == 'success':
469
- print(f"🎯 Prédiction: {result['predicted_class_name']} ({result['confidence']:.3f})")
470
- print("✅ Handler fonctionnel!")
471
- else:
472
- print(f"❌ Erreur: {result.get('error', 'Inconnue')}")
473
- else:
474
- print("❌ Échec de l'initialisation")
475
-
476
- except Exception as e:
477
- print(f"❌ Erreur de test: {e}")
478
- print(f"🔍 Traceback: {traceback.format_exc()}") from typing import Any, Dict
479
- import torch
480
- from torchvision import transforms
481
- from PIL import Image, ImageDraw, ImageOps
482
- import base64
483
- import io
484
- import numpy as np
485
- from transformers import AutoModelForImageClassification, AutoImageProcessor, AutoConfig
486
- import torch.nn.functional as F
487
- import json
488
- import re
489
- import gc
490
- import sys
491
- import traceback
492
-
493
- # Gestion des dépendances optionnelles
494
- HAS_MATPLOTLIB = False
495
- try:
496
- import matplotlib.pyplot as plt
497
- import matplotlib.cm as cm
498
- HAS_MATPLOTLIB = True
499
- print("✅ Matplotlib disponible - Grad-CAM avancé activé")
500
- except ImportError:
501
- print("⚠️ Matplotlib non disponible - Utilisation de PIL pour Grad-CAM")
502
-
503
- class OptimizedGradCAM:
504
- """Version optimisée de Grad-CAM avec nettoyage automatique"""
505
-
506
- def __init__(self, model, target_layer):
507
- self.model = model
508
- self.target_layer = target_layer
509
- self.gradients = None
510
- self.activations = None
511
- self.hooks = []
512
-
513
- # Enregistrer les hooks avec nettoyage automatique
514
- if target_layer is not None:
515
- hook1 = self.target_layer.register_backward_hook(self.save_gradients)
516
- hook2 = self.target_layer.register_forward_hook(self.save_activations)
517
- self.hooks = [hook1, hook2]
518
- else:
519
- print("⚠️ Aucune couche cible trouvée - Grad-CAM désactivé")
520
-
521
- def save_gradients(self, module, grad_input, grad_output):
522
- if grad_output[0] is not None:
523
- self.gradients = grad_output[0].detach()
524
-
525
- def save_activations(self, module, input, output):
526
- self.activations = output.detach()
527
-
528
- def generate_cam(self, input_tensor, class_idx=None):
529
- """Génère la carte de saillance Grad-CAM"""
530
- if self.target_layer is None:
531
- return None
532
-
533
- try:
534
- # Forward pass
535
- output = self.model(input_tensor)
536
-
537
- if class_idx is None:
538
- class_idx = output.logits.argmax(dim=1).item()
539
-
540
- # Backward pass
541
- self.model.zero_grad()
542
- output.logits[0, class_idx].backward(retain_graph=False)
543
-
544
- if self.gradients is None or self.activations is None:
545
- print("⚠️ Gradients ou activations manquants")
546
- return None
547
-
548
- # Generate CAM
549
- gradients = self.gradients[0] # (C, H, W)
550
- activations = self.activations[0] # (C, H, W)
551
-
552
- # Moyenne globale des gradients
553
- weights = torch.mean(gradients, dim=(1, 2)) # (C,)
554
-
555
- # CAM = somme pondérée des activations
556
- cam = torch.zeros(activations.shape[1:], device=activations.device) # (H, W)
557
- for i, w in enumerate(weights):
558
- cam += w * activations[i, :, :]
559
-
560
- # ReLU et normalisation
561
- cam = F.relu(cam)
562
- if cam.max() > 0:
563
- cam = cam / cam.max()
564
-
565
- return cam.detach().cpu().numpy()
566
-
567
- except Exception as e:
568
- print(f"⚠️ Erreur lors de la génération CAM: {e}")
569
- return None
570
- finally:
571
- # Nettoyage explicite
572
- if self.gradients is not None:
573
- self.gradients = None
574
- if self.activations is not None:
575
- self.activations = None
576
-
577
- def cleanup(self):
578
- """Nettoie les hooks et libère la mémoire"""
579
- for hook in self.hooks:
580
- try:
581
- hook.remove()
582
- except:
583
- pass
584
- self.hooks = []
585
- self.gradients = None
586
- self.activations = None
587
-
588
- def __del__(self):
589
- """Nettoyage automatique lors de la destruction"""
590
- self.cleanup()
591
-
592
- def get_last_conv_layer_safe(model):
593
- """Trouve la dernière couche de convolution de manière sécurisée"""
594
- try:
595
- last_conv = None
596
- conv_layers = []
597
-
598
- for name, module in model.named_modules():
599
- if isinstance(module, (torch.nn.Conv2d, torch.nn.AdaptiveAvgPool2d)):
600
- conv_layers.append((name, module))
601
-
602
- if conv_layers:
603
- last_conv = conv_layers[-1][1]
604
- print(f"✅ Couche cible trouvée: {conv_layers[-1][0]}")
605
- else:
606
- print("⚠️ Aucune couche de convolution trouvée")
607
-
608
- return last_conv
609
- except Exception as e:
610
- print(f"⚠️ Erreur lors de la recherche de couche: {e}")
611
- return None
612
-
613
- def create_gradcam_overlay_pil(original_image, cam_array):
614
- """Crée une superposition Grad-CAM en utilisant PIL (sans matplotlib)"""
615
- try:
616
- if cam_array is None:
617
- return None
618
-
619
- # Convertir CAM en image
620
- cam_normalized = (cam_array * 255).astype(np.uint8)
621
- cam_img = Image.fromarray(cam_normalized, mode='L')
622
-
623
- # Redimensionner au format de l'image originale
624
- cam_resized = cam_img.resize(original_image.size, Image.Resampling.LANCZOS)
625
-
626
- # Créer une heatmap colorée (rouge pour les zones importantes)
627
- # Convertir en RGB et appliquer une colormap simple
628
- cam_array_resized = np.array(cam_resized)
629
-
630
- # Créer une colormap simple (bleu -> rouge)
631
- heatmap = np.zeros((cam_array_resized.shape[0], cam_array_resized.shape[1], 3), dtype=np.uint8)
632
- heatmap[:, :, 0] = cam_array_resized # Rouge
633
- heatmap[:, :, 2] = 255 - cam_array_resized # Bleu inversé
634
-
635
- heatmap_img = Image.fromarray(heatmap, 'RGB')
636
-
637
- # Mélanger avec l'image originale
638
- blended = Image.blend(original_image.convert('RGB'), heatmap_img, alpha=0.4)
639
-
640
- # Convertir en base64
641
- buffer = io.BytesIO()
642
- blended.save(buffer, format='PNG', optimize=True)
643
- buffer.seek(0)
644
-
645
- return base64.b64encode(buffer.getvalue()).decode('utf-8')
646
-
647
- except Exception as e:
648
- print(f"⚠️ Erreur lors de la création de l'overlay PIL: {e}")
649
- return None
650
-
651
- def create_gradcam_overlay_matplotlib(original_image, cam_array):
652
- """Crée une superposition Grad-CAM en utilisant matplotlib (si disponible)"""
653
- try:
654
- if not HAS_MATPLOTLIB or cam_array is None:
655
- return None
656
-
657
- # Redimensionner CAM
658
- cam_resized = np.array(Image.fromarray((cam_array * 255).astype(np.uint8)).resize(
659
- original_image.size, Image.Resampling.LANCZOS
660
- )) / 255.0
661
-
662
- # Créer la figure
663
- fig, ax = plt.subplots(figsize=(8, 8), dpi=100)
664
- ax.imshow(original_image)
665
- ax.imshow(cam_resized, cmap='jet', alpha=0.5)
666
- ax.axis('off')
667
-
668
- # Sauvegarder en base64
669
- buffer = io.BytesIO()
670
- plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0, dpi=100)
671
- plt.close(fig) # Important: fermer la figure
672
- buffer.seek(0)
673
-
674
- return base64.b64encode(buffer.getvalue()).decode('utf-8')
675
-
676
- except Exception as e:
677
- print(f"⚠️ Erreur lors de la création de l'overlay matplotlib: {e}")
678
- if 'fig' in locals():
679
- plt.close(fig)
680
- return None
681
-
682
- class EndpointHandler:
683
- def __init__(self, path=""):
684
- print("🚀 VerifAI Handler V2 FIXED - Initialisation")
685
- print("📋 Modèle: haywoodsloan/ai-image-detector-deploy (Version Corrigée)")
686
-
687
- self.model = None
688
- self.processor = None
689
- self.grad_cam = None
690
- self.model_labels = {}
691
-
692
- try:
693
- # Vérification de la disponibilité du modèle
694
- self.model_name = "haywoodsloan/ai-image-detector-deploy"
695
-
696
- if not self._verify_model_exists():
697
- raise Exception(f"Modèle {self.model_name} non accessible")
698
-
699
- # Chargement du modèle avec gestion d'erreurs
700
- print("🔄 Chargement du modèle...")
701
- self.processor = AutoImageProcessor.from_pretrained(self.model_name)
702
- self.model = AutoModelForImageClassification.from_pretrained(
703
- self.model_name,
704
- torch_dtype=torch.float32 # Force float32 pour la compatibilité
705
- )
706
- self.model.eval()
707
-
708
- # Configuration Grad-CAM sécurisée
709
- target_layer = get_last_conv_layer_safe(self.model)
710
- if target_layer is not None:
711
- self.grad_cam = OptimizedGradCAM(self.model, target_layer)
712
- print("✅ Grad-CAM activé")
713
- else:
714
- print("⚠️ Grad-CAM désactivé (aucune couche compatible)")
715
-
716
- # Récupérer les labels
717
- if hasattr(self.model.config, 'id2label'):
718
- self.model_labels = self.model.config.id2label
719
- else:
720
- self.model_labels = {0: "Real", 1: "Fake"} # Fallback
721
-
722
- print("✅ Modèle chargé avec succès")
723
- print(f"📋 Étiquettes du modèle: {self.model_labels}")
724
- print("🎯 VerifAI Handler V2 FIXED prêt!")
725
-
726
- except Exception as e:
727
- print(f"❌ Erreur lors de l'initialisation: {e}")
728
- print(f"🔍 Traceback: {traceback.format_exc()}")
729
- # Ne pas faire échouer l'initialisation, mais signaler l'erreur
730
- self.model = None
731
- self.processor = None
732
-
733
- def _verify_model_exists(self):
734
- """Vérifie que le modèle existe avant de le charger"""
735
- try:
736
- config = AutoConfig.from_pretrained(self.model_name)
737
- print(f"✅ Modèle {self.model_name} vérifié")
738
- return True
739
- except Exception as e:
740
- print(f"❌ Modèle {self.model_name} non accessible: {e}")
741
- return False
742
-
743
- def _normalize_label(self, label: str) -> str:
744
- """Normalise les étiquettes pour qu'elles soient cohérentes."""
745
- if not isinstance(label, str):
746
- label = str(label)
747
-
748
- label_lower = label.lower()
749
- if re.search(r'real|human|authentic', label_lower):
750
- return "Human"
751
- if re.search(r'fake|generated|ai|artificial', label_lower):
752
- return "AI Generated"
753
- return "Unknown"
754
-
755
- def _cleanup_memory(self):
756
- """Nettoie la mémoire explicitement"""
757
- try:
758
- if torch.cuda.is_available():
759
- torch.cuda.empty_cache()
760
- gc.collect()
761
- except:
762
- pass
763
-
764
- def __call__(self, data):
765
- # Vérification de l'état du handler
766
- if self.model is None or self.processor is None:
767
- return {
768
- "status": "error",
769
- "error": "Handler non initialisé correctement",
770
- "prediction": 0,
771
- "predicted_class_name": "Error",
772
- "confidence": 0.0,
773
- "class_probabilities": {"Human": 0.0, "AI Generated": 0.0},
774
- "cam_image": None,
775
- "version": "2.0-fixed",
776
- "handler_name": "VerifAI Handler V2 FIXED"
777
- }
778
-
779
- try:
780
- # Traitement de l'image avec validation
781
- image_data = data.get("inputs") or data
782
- if not image_data:
783
- raise ValueError("Aucune donnée d'image fournie")
784
-
785
- # Décodage sécurisé de l'image
786
- try:
787
- image_bytes = base64.b64decode(image_data)
788
- image = Image.open(io.BytesIO(image_bytes))
789
-
790
- # Validation et conversion
791
- if image.mode != 'RGB':
792
- image = image.convert('RGB')
793
-
794
- # Validation de la taille
795
- if image.size[0] * image.size[1] > 4096 * 4096:
796
- image = image.resize((1024, 1024), Image.Resampling.LANCZOS)
797
- print("⚠️ Image redimensionnée pour éviter les problèmes de mémoire")
798
-
799
- except Exception as e:
800
- raise ValueError(f"Erreur lors du décodage de l'image: {e}")
801
-
802
- # Prédiction avec gestion d'erreurs
803
- print("🔄 VerifAI V2 FIXED - Analyse en cours...")
804
-
805
- try:
806
- inputs = self.processor(image, return_tensors="pt")
807
-
808
- with torch.no_grad():
809
- outputs = self.model(**inputs)
810
- logits = outputs.logits
811
- probabilities = F.softmax(logits, dim=-1)[0]
812
- predicted_class_id = logits.argmax().item()
813
-
814
- except Exception as e:
815
- raise RuntimeError(f"Erreur lors de l'inférence: {e}")
816
-
817
- # Traitement des résultats
818
- class_probs = {}
819
- for class_id, prob in enumerate(probabilities):
820
- label_str = self.model_labels.get(class_id, f"Class {class_id}")
821
- normalized_label = self._normalize_label(label_str)
822
- if normalized_label != "Unknown":
823
- class_probs[normalized_label] = float(prob)
824
-
825
- # S'assurer que les deux classes existent
826
- class_probs.setdefault("Human", 0.0)
827
- class_probs.setdefault("AI Generated", 0.0)
828
-
829
- prediction_label = self._normalize_label(self.model_labels.get(predicted_class_id, "Unknown"))
830
- confidence = class_probs.get(prediction_label, 0.0)
831
-
832
- # Déterminer l'ID de prédiction pour la compatibilité
833
- prediction_id = 1 if prediction_label == "AI Generated" else 0
834
-
835
- print(f"🔍 VerifAI V2 FIXED Résultat: {prediction_label} (confiance: {confidence:.3f})")
836
-
837
- # Génération du Grad-CAM avec fallback
838
- cam_image_b64 = None
839
- if self.grad_cam is not None:
840
- try:
841
- print("🎨 Génération du Grad-CAM...")
842
- cam = self.grad_cam.generate_cam(inputs['pixel_values'], predicted_class_id)
843
-
844
- if cam is not None:
845
- # Essayer matplotlib d'abord, puis PIL
846
- if HAS_MATPLOTLIB:
847
- cam_image_b64 = create_gradcam_overlay_matplotlib(image, cam)
848
-
849
- if cam_image_b64 is None:
850
- cam_image_b64 = create_gradcam_overlay_pil(image, cam)
851
-
852
- if cam_image_b64:
853
- print("✅ Grad-CAM généré avec succès")
854
- else:
855
- print("⚠️ Échec de la génération Grad-CAM")
856
-
857
- except Exception as e:
858
- print(f"⚠️ Erreur Grad-CAM: {e}")
859
- cam_image_b64 = None
860
-
861
- # Nettoyage mémoire
862
- self._cleanup_memory()
863
-
864
- # Construction de la réponse compatible
865
- return {
866
- "status": "success",
867
- "prediction": prediction_id,
868
- "predicted_class_name": prediction_label,
869
- "confidence": confidence,
870
- "class_probabilities": class_probs,
871
- "cam_image": cam_image_b64,
872
- "model_info": {
873
- "model_name": self.model_name,
874
- "handler_version": "verifai-v2-fixed",
875
- "precision_mode": "high",
876
- "raw_prediction_id": predicted_class_id,
877
- "raw_labels": self.model_labels,
878
- "grad_cam_method": "matplotlib" if HAS_MATPLOTLIB else "pil"
879
- },
880
- "reliability": "TRÈS ÉLEVÉE",
881
- "version": "2.0-fixed",
882
- "handler_name": "VerifAI Handler V2 FIXED",
883
- "deployment_note": "VERIFAI HANDLER V2 FIXED - PRODUCTION READY",
884
- "fixes_applied": [
885
- "Gestion d'erreurs robuste",
886
- "Fallback PIL pour Grad-CAM",
887
- "Nettoyage mémoire automatique",
888
- "Validation d'entrée renforcée"
889
- ]
890
- }
891
-
892
- except Exception as e:
893
- print(f"❌ Erreur dans VerifAI Handler V2 FIXED: {e}")
894
- print(f"🔍 Traceback: {traceback.format_exc()}")
895
-
896
- # Nettoyage en cas d'erreur
897
- self._cleanup_memory()
898
-
899
- return {
900
- "status": "error",
901
- "error": str(e),
902
- "prediction": 0,
903
- "predicted_class_name": "Error",
904
- "confidence": 0.0,
905
- "class_probabilities": {"Human": 0.0, "AI Generated": 0.0},
906
- "cam_image": None,
907
- "version": "2.0-fixed",
908
- "handler_name": "VerifAI Handler V2 FIXED",
909
- "error_details": {
910
- "error_type": type(e).__name__,
911
- "traceback": traceback.format_exc()[-500:] # Dernières 500 chars
912
- }
913
- }
914
-
915
- def __del__(self):
916
- """Nettoyage lors de la destruction de l'instance"""
917
- try:
918
- if hasattr(self, 'grad_cam') and self.grad_cam is not None:
919
- self.grad_cam.cleanup()
920
- self._cleanup_memory()
921
- except:
922
- pass
923
-
924
  # Test de fonctionnement si exécuté directement
925
  if __name__ == "__main__":
926
  print("🧪 TEST DU HANDLER VERIFAI V2 FIXED")
 
444
  except:
445
  pass
446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  # Test de fonctionnement si exécuté directement
448
  if __name__ == "__main__":
449
  print("🧪 TEST DU HANDLER VERIFAI V2 FIXED")