Spaces:
Sleeping
Sleeping
| import torch | |
| from PIL import Image, ImageDraw # ImageDraw pour la section de test | |
| import os | |
| import traceback # Pour un log d'erreur plus détaillé | |
| # Imports spécifiques pour LLaVA | |
| from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration | |
| # --- Configuration du modèle LLaVA-NeXT --- | |
| LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf" | |
| # LLAVA_REVISION = "main" # On utilise la branche principale par défaut | |
| llava_processor = None | |
| llava_model = None | |
| llava_model_loaded = False | |
| # Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon) | |
| if torch.cuda.is_available(): | |
| device = "cuda" | |
| elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): # Pour Apple Silicon | |
| device = "mps" | |
| else: | |
| device = "cpu" | |
| print(f"Utilisation du device : {device} pour les modèles d'IA.") | |
| # --- Fonctions de chargement et de génération pour LLaVA --- | |
| def load_llava_model(): | |
| global llava_processor, llava_model, llava_model_loaded, device | |
| if llava_model_loaded: | |
| print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}) déjà chargé.") | |
| return | |
| try: | |
| print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME})...") | |
| llava_processor = LlavaNextProcessor.from_pretrained( | |
| LLAVA_MODEL_NAME | |
| # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision | |
| ) | |
| print("Processor LLaVA chargé.") | |
| print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}) sur '{device}'...") | |
| model_args = { | |
| "low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU | |
| # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision | |
| } | |
| if device == "cuda": | |
| model_args["torch_dtype"] = torch.float16 # Précision pour GPU NVIDIA | |
| # Pour les GPU plus récents (Ampere+), torch.bfloat16 peut aussi être une option | |
| print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).") | |
| elif device == "mps": | |
| # Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer | |
| # Si des problèmes de stabilité surviennent avec float16 sur MPS, commenter la ligne suivante | |
| model_args["torch_dtype"] = torch.float16 | |
| print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut float32')}).") | |
| else: # CPU | |
| # Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité | |
| print(f"Configuration de LLaVA pour CPU (float32 par défaut).") | |
| llava_model = LlavaNextForConditionalGeneration.from_pretrained( | |
| LLAVA_MODEL_NAME, | |
| **model_args | |
| ).to(device).eval() # Mettre le modèle en mode évaluation | |
| llava_model_loaded = True | |
| print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}) chargés avec succès sur '{device}'.") | |
| except Exception as e: | |
| print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}") | |
| traceback.print_exc() | |
| llava_model_loaded = False | |
| def generate_description_llava(image_path: str) -> str: | |
| global llava_processor, llava_model, llava_model_loaded, device | |
| if not llava_model_loaded: | |
| print("Modèle LLaVA non chargé. Tentative de chargement à la demande...") | |
| load_llava_model() | |
| if not llava_model_loaded: | |
| return "Erreur: Le modèle LLaVA n'a pas pu être chargé." | |
| if not os.path.exists(image_path): | |
| return f"Erreur: Le fichier image {image_path} n'existe pas." | |
| try: | |
| image = Image.open(image_path).convert("RGB") | |
| # --- PROMPT AMÉLIORÉ --- | |
| user_prompt_fr = ( | |
| "Analyse cette image en tant qu'œuvre d'art. " | |
| "Fournis une description objective, factuelle et très détaillée en français. " | |
| "Concentre-toi sur les éléments visuels : la scène globale, les sujets et personnages (y compris leur apparence, posture, expression, et toute forme de nudité si présente), " | |
| "les objets, l'arrière-plan, les formes, les couleurs, la lumière, la composition et la technique artistique apparente. " | |
| "Évite toute interprétation subjective ou jugement moral, et décris ce qui est visible de manière exhaustive." | |
| ) | |
| # Format de prompt spécifique à LLaVA v1.5+ (Mistral utilise ce format) | |
| prompt_template = f"<s>[INST] <image>\n{user_prompt_fr} [/INST]" | |
| print(f"Préparation des entrées pour LLaVA avec le prompt utilisateur (condensé): \"{user_prompt_fr[:100]}...\"") | |
| # Le processeur gère la tokenisation du texte et le prétraitement de l'image | |
| inputs_on_cpu = llava_processor(text=prompt_template, images=image, return_tensors="pt") | |
| # Déplacer les tenseurs sur le bon device | |
| inputs = {} | |
| for key, value in inputs_on_cpu.items(): | |
| if torch.is_tensor(value): | |
| inputs[key] = value.to(device) | |
| else: | |
| inputs[key] = value | |
| # S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS (si float16/bfloat16) | |
| if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \ | |
| (llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16): | |
| for k_tensor, v_tensor in inputs.items(): | |
| if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor): # Appliquer seulement aux tenseurs flottants | |
| inputs[k_tensor] = v_tensor.to(llava_model.dtype) | |
| input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)} | |
| print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...") | |
| # Paramètres de génération (ajustables si nécessaire) | |
| generation_kwargs = { | |
| "max_new_tokens": 768, # Augmenté légèrement pour des descriptions potentiellement plus longues | |
| "num_beams": 3, # Un peu de beam search peut améliorer la cohérence | |
| "early_stopping": True, | |
| "do_sample": False # Pour des descriptions plus factuelles et moins "créatives" aléatoirement. Mettre True avec temperature si on veut plus de variété. | |
| # "temperature": 0.7, # À utiliser avec do_sample=True si on veut de la créativité | |
| # "top_p": 0.9, # À utiliser avec do_sample=True | |
| } | |
| generated_ids = llava_model.generate(**inputs, **generation_kwargs) | |
| # Pour LLaVA, il est important de décoder uniquement les tokens générés *après* le prompt. | |
| # Certains processeurs/modèles gèrent cela différemment. | |
| # Pour LLaVA-NeXT, le décodage de la séquence complète et le nettoyage du prompt est une approche courante. | |
| # Ou, si l'on connaît la longueur des tokens d'entrée : | |
| # input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1] | |
| # generated_ids_only = generated_ids[0, input_token_len:] | |
| # cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip() | |
| # Approche plus simple : décoder toute la séquence et enlever manuellement le prompt si besoin. | |
| # Souvent, pour les formats [INST]...[/INST], le modèle génère directement la réponse. | |
| full_text = llava_processor.decode(generated_ids[0], skip_special_tokens=True).strip() | |
| # Nettoyage du texte généré pour enlever le prompt s'il est répété (spécifique au format de sortie du modèle) | |
| # Le format "[INST] <image> \n {prompt} [/INST] {réponse}" fait que la réponse est souvent propre. | |
| inst_marker = "[/INST]" | |
| if inst_marker in full_text: | |
| cleaned_text = full_text.split(inst_marker, 1)[-1].strip() | |
| else: | |
| cleaned_text = full_text # Si le marqueur n'est pas là, prendre tout (peut arriver) | |
| print(f"Description (nettoyée) de LLaVA: {cleaned_text[:200]}...") # Log tronqué | |
| return cleaned_text if cleaned_text else "Aucune description textuelle distincte n'a été générée par LLaVA." | |
| except Exception as e: | |
| print(f"Erreur détaillée lors de la génération de la description avec LLaVA:") | |
| traceback.print_exc() | |
| if torch.cuda.is_available() or device == "mps": | |
| if device == "cuda": torch.cuda.empty_cache() | |
| # if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache() | |
| return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}" | |
| # --- Fonctions de gestion du modèle actif --- | |
| ACTIVE_MODEL = "llava" # Pour l'instant, seul LLaVA est configuré | |
| def load_active_model(): | |
| print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}") | |
| if ACTIVE_MODEL == "llava": | |
| load_llava_model() | |
| else: | |
| print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.") | |
| def generate_active_description(image_path: str) -> str: | |
| if ACTIVE_MODEL == "llava": | |
| return generate_description_llava(image_path) | |
| else: | |
| error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description." | |
| print(error_msg) | |
| return error_msg | |
| def is_active_model_loaded() -> bool: | |
| if ACTIVE_MODEL == "llava": | |
| return llava_model_loaded | |
| return False | |
| # --- Section de Test (pour exécution directe de ce fichier) --- | |
| if __name__ == '__main__': | |
| print("Début du test de processing.py...") | |
| dummy_image_name = "dummy_test_image.png" | |
| if not os.path.exists(dummy_image_name): | |
| try: | |
| img = Image.new('RGB', (200, 150), color = 'skyblue') | |
| draw = ImageDraw.Draw(img) | |
| draw.text((10, 10), "Test Image for LLaVA", fill='black') | |
| # Ajouter quelques formes pour le test | |
| draw.ellipse((30, 50, 90, 110), fill='red', outline='black') | |
| draw.rectangle((100, 40, 170, 120), fill='lightgreen', outline='blue') | |
| img.save(dummy_image_name) | |
| print(f"Image de test '{dummy_image_name}' créée.") | |
| except Exception as e_img: | |
| print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}") | |
| if os.path.exists(dummy_image_name): | |
| print(f"Utilisation du modèle actif : {ACTIVE_MODEL} sur device {device}") | |
| print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...") | |
| load_active_model() | |
| if is_active_model_loaded(): | |
| print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...") | |
| description = generate_active_description(dummy_image_name) | |
| print(f"\n--- Description Générée ---") | |
| print(description) | |
| print(f"--------------------------") | |
| else: | |
| print("Le modèle actif n'a pas pu être chargé. Test de description annulé.") | |
| else: | |
| print(f"Image de test '{dummy_image_name}' non trouvée pour le test.") | |
| print("Fin du test de processing.py.") |