Spaces:
Sleeping
Sleeping
File size: 11,356 Bytes
5e23602 4474779 5e23602 bb246b9 5e23602 bb246b9 5e23602 4474779 5e23602 bb246b9 5e23602 bb246b9 5e23602 58f0122 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 4474779 bb246b9 4474779 5e23602 bb246b9 5e23602 bb246b9 5e23602 4474779 5e23602 bb246b9 5e23602 4474779 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 bb246b9 5e23602 4474779 5e23602 bb246b9 5e23602 4474779 bb246b9 5e23602 4474779 5e23602 4474779 bb246b9 5e23602 bb246b9 5e23602 4474779 5e23602 bb246b9 4474779 bb246b9 5e23602 4474779 5e23602 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
import torch
from PIL import Image, ImageDraw # ImageDraw pour la section de test
import os
import traceback # Pour un log d'erreur plus détaillé
# Imports spécifiques pour LLaVA
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
# --- Configuration du modèle LLaVA-NeXT ---
LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
# LLAVA_REVISION = "main" # On utilise la branche principale par défaut
llava_processor = None
llava_model = None
llava_model_loaded = False
# Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon)
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): # Pour Apple Silicon
device = "mps"
else:
device = "cpu"
print(f"Utilisation du device : {device} pour les modèles d'IA.")
# --- Fonctions de chargement et de génération pour LLaVA ---
def load_llava_model():
global llava_processor, llava_model, llava_model_loaded, device
if llava_model_loaded:
print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}) déjà chargé.")
return
try:
print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME})...")
llava_processor = LlavaNextProcessor.from_pretrained(
LLAVA_MODEL_NAME
# revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
)
print("Processor LLaVA chargé.")
print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}) sur '{device}'...")
model_args = {
"low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU
# revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
}
if device == "cuda":
model_args["torch_dtype"] = torch.float16 # Précision pour GPU NVIDIA
# Pour les GPU plus récents (Ampere+), torch.bfloat16 peut aussi être une option
print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).")
elif device == "mps":
# Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer
# Si des problèmes de stabilité surviennent avec float16 sur MPS, commenter la ligne suivante
model_args["torch_dtype"] = torch.float16
print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut float32')}).")
else: # CPU
# Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
llava_model = LlavaNextForConditionalGeneration.from_pretrained(
LLAVA_MODEL_NAME,
**model_args
).to(device).eval() # Mettre le modèle en mode évaluation
llava_model_loaded = True
print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}) chargés avec succès sur '{device}'.")
except Exception as e:
print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
traceback.print_exc()
llava_model_loaded = False
def generate_description_llava(image_path: str) -> str:
global llava_processor, llava_model, llava_model_loaded, device
if not llava_model_loaded:
print("Modèle LLaVA non chargé. Tentative de chargement à la demande...")
load_llava_model()
if not llava_model_loaded:
return "Erreur: Le modèle LLaVA n'a pas pu être chargé."
if not os.path.exists(image_path):
return f"Erreur: Le fichier image {image_path} n'existe pas."
try:
image = Image.open(image_path).convert("RGB")
# --- PROMPT AMÉLIORÉ ---
user_prompt_fr = (
"Analyse cette image en tant qu'œuvre d'art. "
"Fournis une description objective, factuelle et très détaillée en français. "
"Concentre-toi sur les éléments visuels : la scène globale, les sujets et personnages (y compris leur apparence, posture, expression, et toute forme de nudité si présente), "
"les objets, l'arrière-plan, les formes, les couleurs, la lumière, la composition et la technique artistique apparente. "
"Évite toute interprétation subjective ou jugement moral, et décris ce qui est visible de manière exhaustive."
)
# Format de prompt spécifique à LLaVA v1.5+ (Mistral utilise ce format)
prompt_template = f"<s>[INST] <image>\n{user_prompt_fr} [/INST]"
print(f"Préparation des entrées pour LLaVA avec le prompt utilisateur (condensé): \"{user_prompt_fr[:100]}...\"")
# Le processeur gère la tokenisation du texte et le prétraitement de l'image
inputs_on_cpu = llava_processor(text=prompt_template, images=image, return_tensors="pt")
# Déplacer les tenseurs sur le bon device
inputs = {}
for key, value in inputs_on_cpu.items():
if torch.is_tensor(value):
inputs[key] = value.to(device)
else:
inputs[key] = value
# S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS (si float16/bfloat16)
if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
(llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
for k_tensor, v_tensor in inputs.items():
if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor): # Appliquer seulement aux tenseurs flottants
inputs[k_tensor] = v_tensor.to(llava_model.dtype)
input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)}
print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...")
# Paramètres de génération (ajustables si nécessaire)
generation_kwargs = {
"max_new_tokens": 768, # Augmenté légèrement pour des descriptions potentiellement plus longues
"num_beams": 3, # Un peu de beam search peut améliorer la cohérence
"early_stopping": True,
"do_sample": False # Pour des descriptions plus factuelles et moins "créatives" aléatoirement. Mettre True avec temperature si on veut plus de variété.
# "temperature": 0.7, # À utiliser avec do_sample=True si on veut de la créativité
# "top_p": 0.9, # À utiliser avec do_sample=True
}
generated_ids = llava_model.generate(**inputs, **generation_kwargs)
# Pour LLaVA, il est important de décoder uniquement les tokens générés *après* le prompt.
# Certains processeurs/modèles gèrent cela différemment.
# Pour LLaVA-NeXT, le décodage de la séquence complète et le nettoyage du prompt est une approche courante.
# Ou, si l'on connaît la longueur des tokens d'entrée :
# input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
# generated_ids_only = generated_ids[0, input_token_len:]
# cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
# Approche plus simple : décoder toute la séquence et enlever manuellement le prompt si besoin.
# Souvent, pour les formats [INST]...[/INST], le modèle génère directement la réponse.
full_text = llava_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
# Nettoyage du texte généré pour enlever le prompt s'il est répété (spécifique au format de sortie du modèle)
# Le format "[INST] <image> \n {prompt} [/INST] {réponse}" fait que la réponse est souvent propre.
inst_marker = "[/INST]"
if inst_marker in full_text:
cleaned_text = full_text.split(inst_marker, 1)[-1].strip()
else:
cleaned_text = full_text # Si le marqueur n'est pas là, prendre tout (peut arriver)
print(f"Description (nettoyée) de LLaVA: {cleaned_text[:200]}...") # Log tronqué
return cleaned_text if cleaned_text else "Aucune description textuelle distincte n'a été générée par LLaVA."
except Exception as e:
print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
traceback.print_exc()
if torch.cuda.is_available() or device == "mps":
if device == "cuda": torch.cuda.empty_cache()
# if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache()
return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
# --- Fonctions de gestion du modèle actif ---
ACTIVE_MODEL = "llava" # Pour l'instant, seul LLaVA est configuré
def load_active_model():
print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
if ACTIVE_MODEL == "llava":
load_llava_model()
else:
print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
def generate_active_description(image_path: str) -> str:
if ACTIVE_MODEL == "llava":
return generate_description_llava(image_path)
else:
error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
print(error_msg)
return error_msg
def is_active_model_loaded() -> bool:
if ACTIVE_MODEL == "llava":
return llava_model_loaded
return False
# --- Section de Test (pour exécution directe de ce fichier) ---
if __name__ == '__main__':
print("Début du test de processing.py...")
dummy_image_name = "dummy_test_image.png"
if not os.path.exists(dummy_image_name):
try:
img = Image.new('RGB', (200, 150), color = 'skyblue')
draw = ImageDraw.Draw(img)
draw.text((10, 10), "Test Image for LLaVA", fill='black')
# Ajouter quelques formes pour le test
draw.ellipse((30, 50, 90, 110), fill='red', outline='black')
draw.rectangle((100, 40, 170, 120), fill='lightgreen', outline='blue')
img.save(dummy_image_name)
print(f"Image de test '{dummy_image_name}' créée.")
except Exception as e_img:
print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
if os.path.exists(dummy_image_name):
print(f"Utilisation du modèle actif : {ACTIVE_MODEL} sur device {device}")
print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
load_active_model()
if is_active_model_loaded():
print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
description = generate_active_description(dummy_image_name)
print(f"\n--- Description Générée ---")
print(description)
print(f"--------------------------")
else:
print("Le modèle actif n'a pas pu être chargé. Test de description annulé.")
else:
print(f"Image de test '{dummy_image_name}' non trouvée pour le test.")
print("Fin du test de processing.py.") |