import torch from PIL import Image, ImageDraw # ImageDraw pour la section de test import os import traceback # Pour un log d'erreur plus détaillé # Imports spécifiques pour LLaVA from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration # --- Configuration du modèle LLaVA-NeXT --- LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" # LLAVA_REVISION = "main" # On utilise la branche principale par défaut llava_processor = None llava_model = None llava_model_loaded = False # Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon) if torch.cuda.is_available(): device = "cuda" elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): # Pour Apple Silicon device = "mps" else: device = "cpu" print(f"Utilisation du device : {device} pour les modèles d'IA.") # --- Fonctions de chargement et de génération pour LLaVA --- def load_llava_model(): global llava_processor, llava_model, llava_model_loaded, device if llava_model_loaded: print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}) déjà chargé.") return try: print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME})...") llava_processor = LlavaNextProcessor.from_pretrained( LLAVA_MODEL_NAME # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision ) print("Processor LLaVA chargé.") print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}) sur '{device}'...") model_args = { "low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision } if device == "cuda": model_args["torch_dtype"] = torch.float16 # Précision pour GPU NVIDIA # Pour les GPU plus récents (Ampere+), torch.bfloat16 peut aussi être une option print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).") elif device == "mps": # Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer # Si des problèmes de stabilité surviennent avec float16 sur MPS, commenter la ligne suivante model_args["torch_dtype"] = torch.float16 print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut float32')}).") else: # CPU # Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité print(f"Configuration de LLaVA pour CPU (float32 par défaut).") llava_model = LlavaNextForConditionalGeneration.from_pretrained( LLAVA_MODEL_NAME, **model_args ).to(device).eval() # Mettre le modèle en mode évaluation llava_model_loaded = True print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}) chargés avec succès sur '{device}'.") except Exception as e: print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}") traceback.print_exc() llava_model_loaded = False def generate_description_llava(image_path: str) -> str: global llava_processor, llava_model, llava_model_loaded, device if not llava_model_loaded: print("Modèle LLaVA non chargé. Tentative de chargement à la demande...") load_llava_model() if not llava_model_loaded: return "Erreur: Le modèle LLaVA n'a pas pu être chargé." if not os.path.exists(image_path): return f"Erreur: Le fichier image {image_path} n'existe pas." try: image = Image.open(image_path).convert("RGB") # --- PROMPT AMÉLIORÉ --- user_prompt_fr = ( "Analyse cette image en tant qu'œuvre d'art. " "Fournis une description objective, factuelle et très détaillée en français. " "Concentre-toi sur les éléments visuels : la scène globale, les sujets et personnages (y compris leur apparence, posture, expression, et toute forme de nudité si présente), " "les objets, l'arrière-plan, les formes, les couleurs, la lumière, la composition et la technique artistique apparente. " "Évite toute interprétation subjective ou jugement moral, et décris ce qui est visible de manière exhaustive." ) # Format de prompt spécifique à LLaVA v1.5+ (Mistral utilise ce format) prompt_template = f"[INST] \n{user_prompt_fr} [/INST]" print(f"Préparation des entrées pour LLaVA avec le prompt utilisateur (condensé): \"{user_prompt_fr[:100]}...\"") # Le processeur gère la tokenisation du texte et le prétraitement de l'image inputs_on_cpu = llava_processor(text=prompt_template, images=image, return_tensors="pt") # Déplacer les tenseurs sur le bon device inputs = {} for key, value in inputs_on_cpu.items(): if torch.is_tensor(value): inputs[key] = value.to(device) else: inputs[key] = value # S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS (si float16/bfloat16) if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \ (llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16): for k_tensor, v_tensor in inputs.items(): if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor): # Appliquer seulement aux tenseurs flottants inputs[k_tensor] = v_tensor.to(llava_model.dtype) input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)} print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...") # Paramètres de génération (ajustables si nécessaire) generation_kwargs = { "max_new_tokens": 768, # Augmenté légèrement pour des descriptions potentiellement plus longues "num_beams": 3, # Un peu de beam search peut améliorer la cohérence "early_stopping": True, "do_sample": False # Pour des descriptions plus factuelles et moins "créatives" aléatoirement. Mettre True avec temperature si on veut plus de variété. # "temperature": 0.7, # À utiliser avec do_sample=True si on veut de la créativité # "top_p": 0.9, # À utiliser avec do_sample=True } generated_ids = llava_model.generate(**inputs, **generation_kwargs) # Pour LLaVA, il est important de décoder uniquement les tokens générés *après* le prompt. # Certains processeurs/modèles gèrent cela différemment. # Pour LLaVA-NeXT, le décodage de la séquence complète et le nettoyage du prompt est une approche courante. # Ou, si l'on connaît la longueur des tokens d'entrée : # input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1] # generated_ids_only = generated_ids[0, input_token_len:] # cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip() # Approche plus simple : décoder toute la séquence et enlever manuellement le prompt si besoin. # Souvent, pour les formats [INST]...[/INST], le modèle génère directement la réponse. full_text = llava_processor.decode(generated_ids[0], skip_special_tokens=True).strip() # Nettoyage du texte généré pour enlever le prompt s'il est répété (spécifique au format de sortie du modèle) # Le format "[INST] \n {prompt} [/INST] {réponse}" fait que la réponse est souvent propre. inst_marker = "[/INST]" if inst_marker in full_text: cleaned_text = full_text.split(inst_marker, 1)[-1].strip() else: cleaned_text = full_text # Si le marqueur n'est pas là, prendre tout (peut arriver) print(f"Description (nettoyée) de LLaVA: {cleaned_text[:200]}...") # Log tronqué return cleaned_text if cleaned_text else "Aucune description textuelle distincte n'a été générée par LLaVA." except Exception as e: print(f"Erreur détaillée lors de la génération de la description avec LLaVA:") traceback.print_exc() if torch.cuda.is_available() or device == "mps": if device == "cuda": torch.cuda.empty_cache() # if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache() return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}" # --- Fonctions de gestion du modèle actif --- ACTIVE_MODEL = "llava" # Pour l'instant, seul LLaVA est configuré def load_active_model(): print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}") if ACTIVE_MODEL == "llava": load_llava_model() else: print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.") def generate_active_description(image_path: str) -> str: if ACTIVE_MODEL == "llava": return generate_description_llava(image_path) else: error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description." print(error_msg) return error_msg def is_active_model_loaded() -> bool: if ACTIVE_MODEL == "llava": return llava_model_loaded return False # --- Section de Test (pour exécution directe de ce fichier) --- if __name__ == '__main__': print("Début du test de processing.py...") dummy_image_name = "dummy_test_image.png" if not os.path.exists(dummy_image_name): try: img = Image.new('RGB', (200, 150), color = 'skyblue') draw = ImageDraw.Draw(img) draw.text((10, 10), "Test Image for LLaVA", fill='black') # Ajouter quelques formes pour le test draw.ellipse((30, 50, 90, 110), fill='red', outline='black') draw.rectangle((100, 40, 170, 120), fill='lightgreen', outline='blue') img.save(dummy_image_name) print(f"Image de test '{dummy_image_name}' créée.") except Exception as e_img: print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}") if os.path.exists(dummy_image_name): print(f"Utilisation du modèle actif : {ACTIVE_MODEL} sur device {device}") print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...") load_active_model() if is_active_model_loaded(): print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...") description = generate_active_description(dummy_image_name) print(f"\n--- Description Générée ---") print(description) print(f"--------------------------") else: print("Le modèle actif n'a pas pu être chargé. Test de description annulé.") else: print(f"Image de test '{dummy_image_name}' non trouvée pour le test.") print("Fin du test de processing.py.")