Spaces:
Sleeping
Sleeping
Ludovic commited on
Commit ·
4474779
1
Parent(s): 58f0122
cor 4
Browse files- app/processing.py +55 -36
app/processing.py
CHANGED
|
@@ -1,16 +1,17 @@
|
|
| 1 |
import torch
|
| 2 |
from PIL import Image, ImageDraw # ImageDraw pour la section de test
|
| 3 |
import os
|
| 4 |
-
import traceback
|
| 5 |
|
| 6 |
# Imports spécifiques pour LLaVA
|
| 7 |
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
|
| 8 |
|
| 9 |
# --- Configuration du modèle LLaVA-NeXT ---
|
| 10 |
LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
| 11 |
-
#
|
| 12 |
-
#
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
llava_processor = None
|
| 16 |
llava_model = None
|
|
@@ -25,37 +26,39 @@ else:
|
|
| 25 |
device = "cpu"
|
| 26 |
print(f"Utilisation du device : {device} pour les modèles d'IA.")
|
| 27 |
|
|
|
|
|
|
|
| 28 |
def load_llava_model():
|
| 29 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 30 |
if llava_model_loaded:
|
| 31 |
-
print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}
|
| 32 |
return
|
| 33 |
|
| 34 |
try:
|
| 35 |
print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main')...")
|
| 36 |
llava_processor = LlavaNextProcessor.from_pretrained(
|
| 37 |
LLAVA_MODEL_NAME
|
| 38 |
-
#
|
| 39 |
)
|
| 40 |
print("Processor LLaVA chargé.")
|
| 41 |
|
| 42 |
-
print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}
|
| 43 |
model_args = {
|
| 44 |
-
|
| 45 |
-
"low_cpu_mem_usage": True,
|
| 46 |
}
|
|
|
|
| 47 |
if device == "cpu":
|
| 48 |
-
# Pas de torch_dtype
|
| 49 |
print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
|
| 50 |
elif device == "cuda":
|
| 51 |
model_args["torch_dtype"] = torch.float16 # ou torch.bfloat16 si GPU récent (Ampere+)
|
| 52 |
-
print(f"Configuration de LLaVA pour CUDA ({model_args
|
| 53 |
elif device == "mps":
|
| 54 |
-
# Pour MPS, float16
|
| 55 |
-
# Laisser float32 par défaut (pas de torch_dtype) est une option.
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
print(f"Configuration de LLaVA pour MPS ({model_args['torch_dtype']}).")
|
| 59 |
|
| 60 |
llava_model = LlavaNextForConditionalGeneration.from_pretrained(
|
| 61 |
LLAVA_MODEL_NAME,
|
|
@@ -63,10 +66,10 @@ def load_llava_model():
|
|
| 63 |
).to(device).eval()
|
| 64 |
|
| 65 |
llava_model_loaded = True
|
| 66 |
-
print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}
|
| 67 |
|
| 68 |
except Exception as e:
|
| 69 |
-
print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}
|
| 70 |
traceback.print_exc()
|
| 71 |
llava_model_loaded = False
|
| 72 |
|
|
@@ -75,10 +78,10 @@ def generate_description_llava(image_path: str) -> str:
|
|
| 75 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 76 |
|
| 77 |
if not llava_model_loaded:
|
| 78 |
-
print("Modèle LLaVA non chargé. Tentative de chargement...")
|
| 79 |
load_llava_model()
|
| 80 |
-
if not llava_model_loaded:
|
| 81 |
-
return "Erreur: Le modèle LLaVA n'a pas pu être chargé."
|
| 82 |
|
| 83 |
if not os.path.exists(image_path):
|
| 84 |
return f"Erreur: Le fichier image {image_path} n'existe pas."
|
|
@@ -86,13 +89,15 @@ def generate_description_llava(image_path: str) -> str:
|
|
| 86 |
try:
|
| 87 |
image = Image.open(image_path).convert("RGB")
|
| 88 |
|
| 89 |
-
#
|
| 90 |
user_prompt = "Describe this image in English with precision and detail."
|
| 91 |
-
#
|
|
|
|
| 92 |
|
|
|
|
| 93 |
prompt_text = f"<s>[INST] <image>\n{user_prompt} [/INST]"
|
| 94 |
|
| 95 |
-
print(f"Préparation des entrées pour LLaVA avec le prompt: {user_prompt}")
|
| 96 |
inputs_on_cpu = llava_processor(text=prompt_text, images=image, return_tensors="pt")
|
| 97 |
|
| 98 |
inputs = {}
|
|
@@ -100,8 +105,9 @@ def generate_description_llava(image_path: str) -> str:
|
|
| 100 |
if torch.is_tensor(value):
|
| 101 |
inputs[key] = value.to(device)
|
| 102 |
else:
|
| 103 |
-
inputs[key] = value
|
| 104 |
|
|
|
|
| 105 |
if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
|
| 106 |
(llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
|
| 107 |
for k_tensor, v_tensor in inputs.items():
|
|
@@ -114,18 +120,18 @@ def generate_description_llava(image_path: str) -> str:
|
|
| 114 |
generation_kwargs = {
|
| 115 |
"max_new_tokens": 768,
|
| 116 |
"num_beams": 3,
|
| 117 |
-
"early_stopping": True
|
| 118 |
}
|
| 119 |
|
| 120 |
generated_ids = llava_model.generate(**inputs, **generation_kwargs)
|
| 121 |
|
| 122 |
input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
|
| 123 |
-
generated_ids_only = generated_ids[0, input_token_len:]
|
| 124 |
|
| 125 |
cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
|
| 126 |
|
| 127 |
-
# Nettoyage supplémentaire si
|
| 128 |
-
inst_marker_space = " [/INST]"
|
| 129 |
inst_marker_no_space = "[/INST]"
|
| 130 |
if cleaned_text.startswith(inst_marker_space):
|
| 131 |
cleaned_text = cleaned_text[len(inst_marker_space):].strip()
|
|
@@ -138,23 +144,30 @@ def generate_description_llava(image_path: str) -> str:
|
|
| 138 |
except Exception as e:
|
| 139 |
print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
|
| 140 |
traceback.print_exc()
|
| 141 |
-
if torch.cuda.is_available() or device == "mps":
|
| 142 |
if device == "cuda": torch.cuda.empty_cache()
|
| 143 |
-
# if device == "mps": torch.mps.empty_cache() #
|
| 144 |
return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
|
| 145 |
|
|
|
|
|
|
|
| 146 |
ACTIVE_MODEL = "llava"
|
| 147 |
|
| 148 |
def load_active_model():
|
| 149 |
print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
|
| 150 |
if ACTIVE_MODEL == "llava":
|
| 151 |
load_llava_model()
|
|
|
|
|
|
|
|
|
|
| 152 |
else:
|
| 153 |
print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
|
| 154 |
|
| 155 |
def generate_active_description(image_path: str) -> str:
|
| 156 |
if ACTIVE_MODEL == "llava":
|
| 157 |
return generate_description_llava(image_path)
|
|
|
|
|
|
|
| 158 |
else:
|
| 159 |
error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
|
| 160 |
print(error_msg)
|
|
@@ -163,25 +176,31 @@ def generate_active_description(image_path: str) -> str:
|
|
| 163 |
def is_active_model_loaded() -> bool:
|
| 164 |
if ACTIVE_MODEL == "llava":
|
| 165 |
return llava_model_loaded
|
|
|
|
|
|
|
| 166 |
return False
|
| 167 |
|
|
|
|
| 168 |
if __name__ == '__main__':
|
| 169 |
print("Début du test de processing.py...")
|
| 170 |
-
|
|
|
|
|
|
|
| 171 |
if not os.path.exists(dummy_image_name):
|
| 172 |
try:
|
|
|
|
| 173 |
img = Image.new('RGB', (200, 150), color = 'skyblue')
|
| 174 |
draw = ImageDraw.Draw(img)
|
| 175 |
draw.text((10, 10), "Test Image", fill='black')
|
| 176 |
img.save(dummy_image_name)
|
| 177 |
print(f"Image de test '{dummy_image_name}' créée.")
|
| 178 |
except Exception as e_img:
|
| 179 |
-
print(f"Impossible de créer l'image de test : {e_img}")
|
| 180 |
|
| 181 |
if os.path.exists(dummy_image_name):
|
| 182 |
print(f"Utilisation du modèle actif : {ACTIVE_MODEL}")
|
| 183 |
-
print("Chargement du modèle actif (peut prendre du temps)...")
|
| 184 |
-
load_active_model()
|
| 185 |
if is_active_model_loaded():
|
| 186 |
print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
|
| 187 |
description = generate_active_description(dummy_image_name)
|
|
@@ -191,5 +210,5 @@ if __name__ == '__main__':
|
|
| 191 |
else:
|
| 192 |
print("Le modèle actif n'a pas pu être chargé. Test de description annulé.")
|
| 193 |
else:
|
| 194 |
-
print(f"Image de test '{dummy_image_name}' non trouvée
|
| 195 |
print("Fin du test de processing.py.")
|
|
|
|
| 1 |
import torch
|
| 2 |
from PIL import Image, ImageDraw # ImageDraw pour la section de test
|
| 3 |
import os
|
| 4 |
+
import traceback # Pour un log d'erreur plus détaillé
|
| 5 |
|
| 6 |
# Imports spécifiques pour LLaVA
|
| 7 |
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
|
| 8 |
|
| 9 |
# --- Configuration du modèle LLaVA-NeXT ---
|
| 10 |
LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
| 11 |
+
# La variable LLAVA_REVISION est définie ici au cas où nous voudrions épingler une version spécifique plus tard,
|
| 12 |
+
# une fois que nous aurons confirmé que tout fonctionne bien avec la version 'main'.
|
| 13 |
+
# Pour l'instant, elle n'est PAS utilisée dans les appels from_pretrained().
|
| 14 |
+
LLAVA_REVISION = "082142fd2997099498027732cf8e945044bf48c3" # Exemple de hash, non utilisé ci-dessous
|
| 15 |
|
| 16 |
llava_processor = None
|
| 17 |
llava_model = None
|
|
|
|
| 26 |
device = "cpu"
|
| 27 |
print(f"Utilisation du device : {device} pour les modèles d'IA.")
|
| 28 |
|
| 29 |
+
|
| 30 |
+
# --- Fonctions de chargement et de génération pour LLaVA ---
|
| 31 |
def load_llava_model():
|
| 32 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 33 |
if llava_model_loaded:
|
| 34 |
+
print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main') déjà chargé.")
|
| 35 |
return
|
| 36 |
|
| 37 |
try:
|
| 38 |
print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main')...")
|
| 39 |
llava_processor = LlavaNextProcessor.from_pretrained(
|
| 40 |
LLAVA_MODEL_NAME
|
| 41 |
+
# Pas de 'revision' ici, on charge depuis la branche 'main'
|
| 42 |
)
|
| 43 |
print("Processor LLaVA chargé.")
|
| 44 |
|
| 45 |
+
print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main') sur '{device}'...")
|
| 46 |
model_args = {
|
| 47 |
+
# Pas de 'revision' ici non plus pour le moment
|
| 48 |
+
"low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU lors du chargement initial
|
| 49 |
}
|
| 50 |
+
|
| 51 |
if device == "cpu":
|
| 52 |
+
# Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
|
| 53 |
print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
|
| 54 |
elif device == "cuda":
|
| 55 |
model_args["torch_dtype"] = torch.float16 # ou torch.bfloat16 si GPU récent (Ampere+)
|
| 56 |
+
print(f"Configuration de LLaVA pour CUDA ({model_args.get('torch_dtype', 'par défaut')}).")
|
| 57 |
elif device == "mps":
|
| 58 |
+
# Pour MPS, float16 peut offrir des gains de vitesse. float32 est plus sûr pour commencer.
|
| 59 |
+
# Laisser float32 par défaut (pas de torch_dtype) est une option, ou essayer float16.
|
| 60 |
+
model_args["torch_dtype"] = torch.float16 # Essayons float16 pour MPS
|
| 61 |
+
print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut')}).")
|
|
|
|
| 62 |
|
| 63 |
llava_model = LlavaNextForConditionalGeneration.from_pretrained(
|
| 64 |
LLAVA_MODEL_NAME,
|
|
|
|
| 66 |
).to(device).eval()
|
| 67 |
|
| 68 |
llava_model_loaded = True
|
| 69 |
+
print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}, tous deux depuis branche 'main') chargés avec succès sur '{device}'.")
|
| 70 |
|
| 71 |
except Exception as e:
|
| 72 |
+
print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
|
| 73 |
traceback.print_exc()
|
| 74 |
llava_model_loaded = False
|
| 75 |
|
|
|
|
| 78 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 79 |
|
| 80 |
if not llava_model_loaded:
|
| 81 |
+
print("Modèle LLaVA non chargé dans generate_description_llava. Tentative de chargement...")
|
| 82 |
load_llava_model()
|
| 83 |
+
if not llava_model_loaded:
|
| 84 |
+
return "Erreur: Le modèle LLaVA n'a pas pu être chargé (échec lors de la tentative à la demande)."
|
| 85 |
|
| 86 |
if not os.path.exists(image_path):
|
| 87 |
return f"Erreur: Le fichier image {image_path} n'existe pas."
|
|
|
|
| 89 |
try:
|
| 90 |
image = Image.open(image_path).convert("RGB")
|
| 91 |
|
| 92 |
+
# Prompt en anglais par défaut
|
| 93 |
user_prompt = "Describe this image in English with precision and detail."
|
| 94 |
+
# Pour du français :
|
| 95 |
+
# user_prompt = "Décris cette image en français avec précision et de manière détaillée."
|
| 96 |
|
| 97 |
+
# Format de prompt pour LLaVA v1.6
|
| 98 |
prompt_text = f"<s>[INST] <image>\n{user_prompt} [/INST]"
|
| 99 |
|
| 100 |
+
print(f"Préparation des entrées pour LLaVA avec le prompt: \"{user_prompt}\"")
|
| 101 |
inputs_on_cpu = llava_processor(text=prompt_text, images=image, return_tensors="pt")
|
| 102 |
|
| 103 |
inputs = {}
|
|
|
|
| 105 |
if torch.is_tensor(value):
|
| 106 |
inputs[key] = value.to(device)
|
| 107 |
else:
|
| 108 |
+
inputs[key] = value # Conserver d'autres types si présents
|
| 109 |
|
| 110 |
+
# S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS
|
| 111 |
if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
|
| 112 |
(llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
|
| 113 |
for k_tensor, v_tensor in inputs.items():
|
|
|
|
| 120 |
generation_kwargs = {
|
| 121 |
"max_new_tokens": 768,
|
| 122 |
"num_beams": 3,
|
| 123 |
+
"early_stopping": True
|
| 124 |
}
|
| 125 |
|
| 126 |
generated_ids = llava_model.generate(**inputs, **generation_kwargs)
|
| 127 |
|
| 128 |
input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
|
| 129 |
+
generated_ids_only = generated_ids[0, input_token_len:] # Extraire seulement les tokens générés
|
| 130 |
|
| 131 |
cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
|
| 132 |
|
| 133 |
+
# Nettoyage supplémentaire si le marqueur [/INST] est toujours présent (peu probable avec ce décodage)
|
| 134 |
+
inst_marker_space = " [/INST]"
|
| 135 |
inst_marker_no_space = "[/INST]"
|
| 136 |
if cleaned_text.startswith(inst_marker_space):
|
| 137 |
cleaned_text = cleaned_text[len(inst_marker_space):].strip()
|
|
|
|
| 144 |
except Exception as e:
|
| 145 |
print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
|
| 146 |
traceback.print_exc()
|
| 147 |
+
if torch.cuda.is_available() or device == "mps":
|
| 148 |
if device == "cuda": torch.cuda.empty_cache()
|
| 149 |
+
# if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache() # Pour PyTorch >= 1.13
|
| 150 |
return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
|
| 151 |
|
| 152 |
+
|
| 153 |
+
# --- Fonctions de gestion du modèle actif ---
|
| 154 |
ACTIVE_MODEL = "llava"
|
| 155 |
|
| 156 |
def load_active_model():
|
| 157 |
print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
|
| 158 |
if ACTIVE_MODEL == "llava":
|
| 159 |
load_llava_model()
|
| 160 |
+
# Ajoutez d'autres conditions ici si vous réactivez d'autres modèles
|
| 161 |
+
# elif ACTIVE_MODEL == "florence":
|
| 162 |
+
# load_florence_model()
|
| 163 |
else:
|
| 164 |
print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
|
| 165 |
|
| 166 |
def generate_active_description(image_path: str) -> str:
|
| 167 |
if ACTIVE_MODEL == "llava":
|
| 168 |
return generate_description_llava(image_path)
|
| 169 |
+
# elif ACTIVE_MODEL == "florence":
|
| 170 |
+
# return generate_description_florence(image_path)
|
| 171 |
else:
|
| 172 |
error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
|
| 173 |
print(error_msg)
|
|
|
|
| 176 |
def is_active_model_loaded() -> bool:
|
| 177 |
if ACTIVE_MODEL == "llava":
|
| 178 |
return llava_model_loaded
|
| 179 |
+
# elif ACTIVE_MODEL == "florence":
|
| 180 |
+
# return florence_model_loaded
|
| 181 |
return False
|
| 182 |
|
| 183 |
+
# --- Section de Test (pour exécution directe de ce fichier) ---
|
| 184 |
if __name__ == '__main__':
|
| 185 |
print("Début du test de processing.py...")
|
| 186 |
+
|
| 187 |
+
# Créer une image de test factice si elle n'existe pas
|
| 188 |
+
dummy_image_name = "dummy_test_image.png" # S'assure qu'elle est bien ignorée par .gitignore si elle est créée
|
| 189 |
if not os.path.exists(dummy_image_name):
|
| 190 |
try:
|
| 191 |
+
# ImageDraw a été importé en haut avec PIL
|
| 192 |
img = Image.new('RGB', (200, 150), color = 'skyblue')
|
| 193 |
draw = ImageDraw.Draw(img)
|
| 194 |
draw.text((10, 10), "Test Image", fill='black')
|
| 195 |
img.save(dummy_image_name)
|
| 196 |
print(f"Image de test '{dummy_image_name}' créée.")
|
| 197 |
except Exception as e_img:
|
| 198 |
+
print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
|
| 199 |
|
| 200 |
if os.path.exists(dummy_image_name):
|
| 201 |
print(f"Utilisation du modèle actif : {ACTIVE_MODEL}")
|
| 202 |
+
print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
|
| 203 |
+
load_active_model() # Tente de charger le modèle
|
| 204 |
if is_active_model_loaded():
|
| 205 |
print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
|
| 206 |
description = generate_active_description(dummy_image_name)
|
|
|
|
| 210 |
else:
|
| 211 |
print("Le modèle actif n'a pas pu être chargé. Test de description annulé.")
|
| 212 |
else:
|
| 213 |
+
print(f"Image de test '{dummy_image_name}' non trouvée pour le test.")
|
| 214 |
print("Fin du test de processing.py.")
|