File size: 11,356 Bytes
5e23602
 
 
4474779
5e23602
 
 
 
 
 
bb246b9
5e23602
 
 
 
 
 
 
 
bb246b9
5e23602
 
 
 
 
4474779
 
5e23602
 
 
bb246b9
5e23602
 
 
bb246b9
5e23602
58f0122
bb246b9
5e23602
 
 
bb246b9
5e23602
bb246b9
 
5e23602
4474779
bb246b9
 
 
 
 
 
 
 
 
 
4474779
5e23602
 
 
 
 
bb246b9
5e23602
 
bb246b9
5e23602
 
4474779
5e23602
 
 
 
 
 
 
 
bb246b9
5e23602
4474779
bb246b9
 
5e23602
 
 
 
 
 
bb246b9
 
 
 
 
 
 
 
 
 
 
5e23602
bb246b9
5e23602
bb246b9
 
 
 
5e23602
 
 
 
 
bb246b9
5e23602
bb246b9
5e23602
 
 
bb246b9
5e23602
 
 
 
 
bb246b9
5e23602
bb246b9
 
 
 
 
 
5e23602
 
 
 
bb246b9
 
 
 
 
 
 
5e23602
bb246b9
 
 
5e23602
bb246b9
 
 
 
 
 
 
5e23602
bb246b9
 
5e23602
 
 
 
4474779
5e23602
bb246b9
5e23602
 
4474779
 
bb246b9
5e23602
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4474779
5e23602
 
4474779
bb246b9
5e23602
 
 
 
bb246b9
 
 
 
5e23602
 
 
4474779
5e23602
 
bb246b9
4474779
bb246b9
5e23602
 
 
 
 
 
 
 
 
4474779
5e23602
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import torch
from PIL import Image, ImageDraw # ImageDraw pour la section de test
import os
import traceback # Pour un log d'erreur plus détaillé

# Imports spécifiques pour LLaVA
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

# --- Configuration du modèle LLaVA-NeXT ---
LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
# LLAVA_REVISION = "main" # On utilise la branche principale par défaut

llava_processor = None
llava_model = None
llava_model_loaded = False

# Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): # Pour Apple Silicon
    device = "mps"
else:
    device = "cpu"
print(f"Utilisation du device : {device} pour les modèles d'IA.")


# --- Fonctions de chargement et de génération pour LLaVA ---
def load_llava_model():
    global llava_processor, llava_model, llava_model_loaded, device
    if llava_model_loaded:
        print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}) déjà chargé.")
        return

    try:
        print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME})...")
        llava_processor = LlavaNextProcessor.from_pretrained(
            LLAVA_MODEL_NAME
            # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
        )
        print("Processor LLaVA chargé.")

        print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}) sur '{device}'...")
        model_args = {
            "low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU
            # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
        }

        if device == "cuda":
            model_args["torch_dtype"] = torch.float16 # Précision pour GPU NVIDIA
            # Pour les GPU plus récents (Ampere+), torch.bfloat16 peut aussi être une option
            print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).")
        elif device == "mps":
            # Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer
            # Si des problèmes de stabilité surviennent avec float16 sur MPS, commenter la ligne suivante
            model_args["torch_dtype"] = torch.float16 
            print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut float32')}).")
        else: # CPU
            # Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
            print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
        
        llava_model = LlavaNextForConditionalGeneration.from_pretrained(
            LLAVA_MODEL_NAME, 
            **model_args
        ).to(device).eval() # Mettre le modèle en mode évaluation
        
        llava_model_loaded = True
        print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}) chargés avec succès sur '{device}'.")

    except Exception as e:
        print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
        traceback.print_exc()
        llava_model_loaded = False


def generate_description_llava(image_path: str) -> str:
    global llava_processor, llava_model, llava_model_loaded, device

    if not llava_model_loaded:
        print("Modèle LLaVA non chargé. Tentative de chargement à la demande...")
        load_llava_model()
        if not llava_model_loaded: 
            return "Erreur: Le modèle LLaVA n'a pas pu être chargé."

    if not os.path.exists(image_path):
        return f"Erreur: Le fichier image {image_path} n'existe pas."

    try:
        image = Image.open(image_path).convert("RGB")
        
        # --- PROMPT AMÉLIORÉ ---
        user_prompt_fr = (
            "Analyse cette image en tant qu'œuvre d'art. "
            "Fournis une description objective, factuelle et très détaillée en français. "
            "Concentre-toi sur les éléments visuels : la scène globale, les sujets et personnages (y compris leur apparence, posture, expression, et toute forme de nudité si présente), "
            "les objets, l'arrière-plan, les formes, les couleurs, la lumière, la composition et la technique artistique apparente. "
            "Évite toute interprétation subjective ou jugement moral, et décris ce qui est visible de manière exhaustive."
        )
        
        # Format de prompt spécifique à LLaVA v1.5+ (Mistral utilise ce format)
        prompt_template = f"<s>[INST] <image>\n{user_prompt_fr} [/INST]"
        
        print(f"Préparation des entrées pour LLaVA avec le prompt utilisateur (condensé): \"{user_prompt_fr[:100]}...\"")
        
        # Le processeur gère la tokenisation du texte et le prétraitement de l'image
        inputs_on_cpu = llava_processor(text=prompt_template, images=image, return_tensors="pt")
        
        # Déplacer les tenseurs sur le bon device
        inputs = {}
        for key, value in inputs_on_cpu.items():
            if torch.is_tensor(value):
                inputs[key] = value.to(device)
            else:
                inputs[key] = value 
        
        # S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS (si float16/bfloat16)
        if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
           (llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
            for k_tensor, v_tensor in inputs.items():
                if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor): # Appliquer seulement aux tenseurs flottants
                    inputs[k_tensor] = v_tensor.to(llava_model.dtype)
        
        input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)}
        print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...")
        
        # Paramètres de génération (ajustables si nécessaire)
        generation_kwargs = {
            "max_new_tokens": 768,  # Augmenté légèrement pour des descriptions potentiellement plus longues
            "num_beams": 3,         # Un peu de beam search peut améliorer la cohérence
            "early_stopping": True,
            "do_sample": False      # Pour des descriptions plus factuelles et moins "créatives" aléatoirement. Mettre True avec temperature si on veut plus de variété.
            # "temperature": 0.7,   # À utiliser avec do_sample=True si on veut de la créativité
            # "top_p": 0.9,         # À utiliser avec do_sample=True
        }
        
        generated_ids = llava_model.generate(**inputs, **generation_kwargs)
        
        # Pour LLaVA, il est important de décoder uniquement les tokens générés *après* le prompt.
        # Certains processeurs/modèles gèrent cela différemment.
        # Pour LLaVA-NeXT, le décodage de la séquence complète et le nettoyage du prompt est une approche courante.
        # Ou, si l'on connaît la longueur des tokens d'entrée :
        # input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
        # generated_ids_only = generated_ids[0, input_token_len:]
        # cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
        
        # Approche plus simple : décoder toute la séquence et enlever manuellement le prompt si besoin.
        # Souvent, pour les formats [INST]...[/INST], le modèle génère directement la réponse.
        full_text = llava_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
        
        # Nettoyage du texte généré pour enlever le prompt s'il est répété (spécifique au format de sortie du modèle)
        # Le format "[INST] <image> \n {prompt} [/INST] {réponse}" fait que la réponse est souvent propre.
        inst_marker = "[/INST]"
        if inst_marker in full_text:
            cleaned_text = full_text.split(inst_marker, 1)[-1].strip()
        else:
            cleaned_text = full_text # Si le marqueur n'est pas là, prendre tout (peut arriver)

        print(f"Description (nettoyée) de LLaVA: {cleaned_text[:200]}...") # Log tronqué
        return cleaned_text if cleaned_text else "Aucune description textuelle distincte n'a été générée par LLaVA."

    except Exception as e:
        print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
        traceback.print_exc() 
        if torch.cuda.is_available() or device == "mps":
            if device == "cuda": torch.cuda.empty_cache()
            # if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache()
        return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"


# --- Fonctions de gestion du modèle actif ---
ACTIVE_MODEL = "llava" # Pour l'instant, seul LLaVA est configuré

def load_active_model():
    print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
    if ACTIVE_MODEL == "llava":
        load_llava_model()
    else:
        print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")

def generate_active_description(image_path: str) -> str:
    if ACTIVE_MODEL == "llava":
        return generate_description_llava(image_path)
    else:
        error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
        print(error_msg)
        return error_msg

def is_active_model_loaded() -> bool:
    if ACTIVE_MODEL == "llava":
        return llava_model_loaded
    return False

# --- Section de Test (pour exécution directe de ce fichier) ---
if __name__ == '__main__':
    print("Début du test de processing.py...")
    
    dummy_image_name = "dummy_test_image.png" 
    if not os.path.exists(dummy_image_name):
        try:
            img = Image.new('RGB', (200, 150), color = 'skyblue')
            draw = ImageDraw.Draw(img)
            draw.text((10, 10), "Test Image for LLaVA", fill='black')
            # Ajouter quelques formes pour le test
            draw.ellipse((30, 50, 90, 110), fill='red', outline='black')
            draw.rectangle((100, 40, 170, 120), fill='lightgreen', outline='blue')
            img.save(dummy_image_name)
            print(f"Image de test '{dummy_image_name}' créée.")
        except Exception as e_img:
            print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")

    if os.path.exists(dummy_image_name):
        print(f"Utilisation du modèle actif : {ACTIVE_MODEL} sur device {device}")
        print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
        load_active_model() 
        if is_active_model_loaded():
            print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
            description = generate_active_description(dummy_image_name)
            print(f"\n--- Description Générée ---")
            print(description)
            print(f"--------------------------")
        else:
            print("Le modèle actif n'a pas pu être chargé. Test de description annulé.")
    else:
        print(f"Image de test '{dummy_image_name}' non trouvée pour le test.")
    print("Fin du test de processing.py.")