autocaption-app / app /processing.py
Ludovic
V5
bb246b9
import torch
from PIL import Image, ImageDraw # ImageDraw pour la section de test
import os
import traceback # Pour un log d'erreur plus détaillé
# Imports spécifiques pour LLaVA
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
# --- Configuration du modèle LLaVA-NeXT ---
LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
# LLAVA_REVISION = "main" # On utilise la branche principale par défaut
llava_processor = None
llava_model = None
llava_model_loaded = False
# Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon)
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): # Pour Apple Silicon
device = "mps"
else:
device = "cpu"
print(f"Utilisation du device : {device} pour les modèles d'IA.")
# --- Fonctions de chargement et de génération pour LLaVA ---
def load_llava_model():
global llava_processor, llava_model, llava_model_loaded, device
if llava_model_loaded:
print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}) déjà chargé.")
return
try:
print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME})...")
llava_processor = LlavaNextProcessor.from_pretrained(
LLAVA_MODEL_NAME
# revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
)
print("Processor LLaVA chargé.")
print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}) sur '{device}'...")
model_args = {
"low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU
# revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
}
if device == "cuda":
model_args["torch_dtype"] = torch.float16 # Précision pour GPU NVIDIA
# Pour les GPU plus récents (Ampere+), torch.bfloat16 peut aussi être une option
print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).")
elif device == "mps":
# Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer
# Si des problèmes de stabilité surviennent avec float16 sur MPS, commenter la ligne suivante
model_args["torch_dtype"] = torch.float16
print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut float32')}).")
else: # CPU
# Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
llava_model = LlavaNextForConditionalGeneration.from_pretrained(
LLAVA_MODEL_NAME,
**model_args
).to(device).eval() # Mettre le modèle en mode évaluation
llava_model_loaded = True
print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}) chargés avec succès sur '{device}'.")
except Exception as e:
print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
traceback.print_exc()
llava_model_loaded = False
def generate_description_llava(image_path: str) -> str:
global llava_processor, llava_model, llava_model_loaded, device
if not llava_model_loaded:
print("Modèle LLaVA non chargé. Tentative de chargement à la demande...")
load_llava_model()
if not llava_model_loaded:
return "Erreur: Le modèle LLaVA n'a pas pu être chargé."
if not os.path.exists(image_path):
return f"Erreur: Le fichier image {image_path} n'existe pas."
try:
image = Image.open(image_path).convert("RGB")
# --- PROMPT AMÉLIORÉ ---
user_prompt_fr = (
"Analyse cette image en tant qu'œuvre d'art. "
"Fournis une description objective, factuelle et très détaillée en français. "
"Concentre-toi sur les éléments visuels : la scène globale, les sujets et personnages (y compris leur apparence, posture, expression, et toute forme de nudité si présente), "
"les objets, l'arrière-plan, les formes, les couleurs, la lumière, la composition et la technique artistique apparente. "
"Évite toute interprétation subjective ou jugement moral, et décris ce qui est visible de manière exhaustive."
)
# Format de prompt spécifique à LLaVA v1.5+ (Mistral utilise ce format)
prompt_template = f"<s>[INST] <image>\n{user_prompt_fr} [/INST]"
print(f"Préparation des entrées pour LLaVA avec le prompt utilisateur (condensé): \"{user_prompt_fr[:100]}...\"")
# Le processeur gère la tokenisation du texte et le prétraitement de l'image
inputs_on_cpu = llava_processor(text=prompt_template, images=image, return_tensors="pt")
# Déplacer les tenseurs sur le bon device
inputs = {}
for key, value in inputs_on_cpu.items():
if torch.is_tensor(value):
inputs[key] = value.to(device)
else:
inputs[key] = value
# S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS (si float16/bfloat16)
if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
(llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
for k_tensor, v_tensor in inputs.items():
if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor): # Appliquer seulement aux tenseurs flottants
inputs[k_tensor] = v_tensor.to(llava_model.dtype)
input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)}
print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...")
# Paramètres de génération (ajustables si nécessaire)
generation_kwargs = {
"max_new_tokens": 768, # Augmenté légèrement pour des descriptions potentiellement plus longues
"num_beams": 3, # Un peu de beam search peut améliorer la cohérence
"early_stopping": True,
"do_sample": False # Pour des descriptions plus factuelles et moins "créatives" aléatoirement. Mettre True avec temperature si on veut plus de variété.
# "temperature": 0.7, # À utiliser avec do_sample=True si on veut de la créativité
# "top_p": 0.9, # À utiliser avec do_sample=True
}
generated_ids = llava_model.generate(**inputs, **generation_kwargs)
# Pour LLaVA, il est important de décoder uniquement les tokens générés *après* le prompt.
# Certains processeurs/modèles gèrent cela différemment.
# Pour LLaVA-NeXT, le décodage de la séquence complète et le nettoyage du prompt est une approche courante.
# Ou, si l'on connaît la longueur des tokens d'entrée :
# input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
# generated_ids_only = generated_ids[0, input_token_len:]
# cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
# Approche plus simple : décoder toute la séquence et enlever manuellement le prompt si besoin.
# Souvent, pour les formats [INST]...[/INST], le modèle génère directement la réponse.
full_text = llava_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
# Nettoyage du texte généré pour enlever le prompt s'il est répété (spécifique au format de sortie du modèle)
# Le format "[INST] <image> \n {prompt} [/INST] {réponse}" fait que la réponse est souvent propre.
inst_marker = "[/INST]"
if inst_marker in full_text:
cleaned_text = full_text.split(inst_marker, 1)[-1].strip()
else:
cleaned_text = full_text # Si le marqueur n'est pas là, prendre tout (peut arriver)
print(f"Description (nettoyée) de LLaVA: {cleaned_text[:200]}...") # Log tronqué
return cleaned_text if cleaned_text else "Aucune description textuelle distincte n'a été générée par LLaVA."
except Exception as e:
print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
traceback.print_exc()
if torch.cuda.is_available() or device == "mps":
if device == "cuda": torch.cuda.empty_cache()
# if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache()
return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
# --- Fonctions de gestion du modèle actif ---
ACTIVE_MODEL = "llava" # Pour l'instant, seul LLaVA est configuré
def load_active_model():
print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
if ACTIVE_MODEL == "llava":
load_llava_model()
else:
print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
def generate_active_description(image_path: str) -> str:
if ACTIVE_MODEL == "llava":
return generate_description_llava(image_path)
else:
error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
print(error_msg)
return error_msg
def is_active_model_loaded() -> bool:
if ACTIVE_MODEL == "llava":
return llava_model_loaded
return False
# --- Section de Test (pour exécution directe de ce fichier) ---
if __name__ == '__main__':
print("Début du test de processing.py...")
dummy_image_name = "dummy_test_image.png"
if not os.path.exists(dummy_image_name):
try:
img = Image.new('RGB', (200, 150), color = 'skyblue')
draw = ImageDraw.Draw(img)
draw.text((10, 10), "Test Image for LLaVA", fill='black')
# Ajouter quelques formes pour le test
draw.ellipse((30, 50, 90, 110), fill='red', outline='black')
draw.rectangle((100, 40, 170, 120), fill='lightgreen', outline='blue')
img.save(dummy_image_name)
print(f"Image de test '{dummy_image_name}' créée.")
except Exception as e_img:
print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
if os.path.exists(dummy_image_name):
print(f"Utilisation du modèle actif : {ACTIVE_MODEL} sur device {device}")
print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
load_active_model()
if is_active_model_loaded():
print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
description = generate_active_description(dummy_image_name)
print(f"\n--- Description Générée ---")
print(description)
print(f"--------------------------")
else:
print("Le modèle actif n'a pas pu être chargé. Test de description annulé.")
else:
print(f"Image de test '{dummy_image_name}' non trouvée pour le test.")
print("Fin du test de processing.py.")