Ludovic commited on
Commit
4474779
·
1 Parent(s): 58f0122
Files changed (1) hide show
  1. app/processing.py +55 -36
app/processing.py CHANGED
@@ -1,16 +1,17 @@
1
  import torch
2
  from PIL import Image, ImageDraw # ImageDraw pour la section de test
3
  import os
4
- import traceback
5
 
6
  # Imports spécifiques pour LLaVA
7
  from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
8
 
9
  # --- Configuration du modèle LLaVA-NeXT ---
10
  LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
11
- # Hash de commit de la branche 'main' de LLaVA au moment des tests.
12
- # Vérifiez le plus récent sur https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/commits/main
13
- LLAVA_REVISION = "082142fd2997099498027732cf8e945044bf48c3"
 
14
 
15
  llava_processor = None
16
  llava_model = None
@@ -25,37 +26,39 @@ else:
25
  device = "cpu"
26
  print(f"Utilisation du device : {device} pour les modèles d'IA.")
27
 
 
 
28
  def load_llava_model():
29
  global llava_processor, llava_model, llava_model_loaded, device
30
  if llava_model_loaded:
31
- print(f"Modèle LLaVA ({LLAVA_MODEL_NAME} rev {LLAVA_REVISION}) déjà chargé.")
32
  return
33
 
34
  try:
35
  print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main')...")
36
  llava_processor = LlavaNextProcessor.from_pretrained(
37
  LLAVA_MODEL_NAME
38
- # PAS DE revision=LLAVA_REVISION ICI POUR LE PROCESSEUR
39
  )
40
  print("Processor LLaVA chargé.")
41
 
42
- print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME} rev {LLAVA_REVISION}) sur '{device}'...")
43
  model_args = {
44
- "revision": LLAVA_REVISION, # Épinglage de la révision
45
- "low_cpu_mem_usage": True,
46
  }
 
47
  if device == "cpu":
48
- # Pas de torch_dtype pour CPU, utilise float32 par défaut pour plus de stabilité
49
  print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
50
  elif device == "cuda":
51
  model_args["torch_dtype"] = torch.float16 # ou torch.bfloat16 si GPU récent (Ampere+)
52
- print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).")
53
  elif device == "mps":
54
- # Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer.
55
- # Laisser float32 par défaut (pas de torch_dtype) est une option.
56
- # Ou essayez float16 :
57
- model_args["torch_dtype"] = torch.float16
58
- print(f"Configuration de LLaVA pour MPS ({model_args['torch_dtype']}).")
59
 
60
  llava_model = LlavaNextForConditionalGeneration.from_pretrained(
61
  LLAVA_MODEL_NAME,
@@ -63,10 +66,10 @@ def load_llava_model():
63
  ).to(device).eval()
64
 
65
  llava_model_loaded = True
66
- print(f"Modèle LLaVA ({LLAVA_MODEL_NAME} rev {LLAVA_REVISION}) chargé avec succès sur '{device}'.")
67
 
68
  except Exception as e:
69
- print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME} rev {LLAVA_REVISION}): {e}")
70
  traceback.print_exc()
71
  llava_model_loaded = False
72
 
@@ -75,10 +78,10 @@ def generate_description_llava(image_path: str) -> str:
75
  global llava_processor, llava_model, llava_model_loaded, device
76
 
77
  if not llava_model_loaded:
78
- print("Modèle LLaVA non chargé. Tentative de chargement...")
79
  load_llava_model()
80
- if not llava_model_loaded:
81
- return "Erreur: Le modèle LLaVA n'a pas pu être chargé."
82
 
83
  if not os.path.exists(image_path):
84
  return f"Erreur: Le fichier image {image_path} n'existe pas."
@@ -86,13 +89,15 @@ def generate_description_llava(image_path: str) -> str:
86
  try:
87
  image = Image.open(image_path).convert("RGB")
88
 
89
- # Choix du prompt (anglais par défaut, comme demandé)
90
  user_prompt = "Describe this image in English with precision and detail."
91
- # user_prompt = "Décris cette image en français avec précision et de manière détaillée." # Si vous voulez du français
 
92
 
 
93
  prompt_text = f"<s>[INST] <image>\n{user_prompt} [/INST]"
94
 
95
- print(f"Préparation des entrées pour LLaVA avec le prompt: {user_prompt}")
96
  inputs_on_cpu = llava_processor(text=prompt_text, images=image, return_tensors="pt")
97
 
98
  inputs = {}
@@ -100,8 +105,9 @@ def generate_description_llava(image_path: str) -> str:
100
  if torch.is_tensor(value):
101
  inputs[key] = value.to(device)
102
  else:
103
- inputs[key] = value
104
 
 
105
  if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
106
  (llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
107
  for k_tensor, v_tensor in inputs.items():
@@ -114,18 +120,18 @@ def generate_description_llava(image_path: str) -> str:
114
  generation_kwargs = {
115
  "max_new_tokens": 768,
116
  "num_beams": 3,
117
- "early_stopping": True
118
  }
119
 
120
  generated_ids = llava_model.generate(**inputs, **generation_kwargs)
121
 
122
  input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
123
- generated_ids_only = generated_ids[0, input_token_len:]
124
 
125
  cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
126
 
127
- # Nettoyage supplémentaire si nécessaire (ex: enlever des marqueurs résiduels)
128
- inst_marker_space = " [/INST]" # Avec espace avant, comme souvent produit
129
  inst_marker_no_space = "[/INST]"
130
  if cleaned_text.startswith(inst_marker_space):
131
  cleaned_text = cleaned_text[len(inst_marker_space):].strip()
@@ -138,23 +144,30 @@ def generate_description_llava(image_path: str) -> str:
138
  except Exception as e:
139
  print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
140
  traceback.print_exc()
141
- if torch.cuda.is_available() or device == "mps": # Vider le cache si GPU/MPS
142
  if device == "cuda": torch.cuda.empty_cache()
143
- # if device == "mps": torch.mps.empty_cache() # Si disponible et nécessaire
144
  return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
145
 
 
 
146
  ACTIVE_MODEL = "llava"
147
 
148
  def load_active_model():
149
  print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
150
  if ACTIVE_MODEL == "llava":
151
  load_llava_model()
 
 
 
152
  else:
153
  print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
154
 
155
  def generate_active_description(image_path: str) -> str:
156
  if ACTIVE_MODEL == "llava":
157
  return generate_description_llava(image_path)
 
 
158
  else:
159
  error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
160
  print(error_msg)
@@ -163,25 +176,31 @@ def generate_active_description(image_path: str) -> str:
163
  def is_active_model_loaded() -> bool:
164
  if ACTIVE_MODEL == "llava":
165
  return llava_model_loaded
 
 
166
  return False
167
 
 
168
  if __name__ == '__main__':
169
  print("Début du test de processing.py...")
170
- dummy_image_name = "dummy_test_image.png"
 
 
171
  if not os.path.exists(dummy_image_name):
172
  try:
 
173
  img = Image.new('RGB', (200, 150), color = 'skyblue')
174
  draw = ImageDraw.Draw(img)
175
  draw.text((10, 10), "Test Image", fill='black')
176
  img.save(dummy_image_name)
177
  print(f"Image de test '{dummy_image_name}' créée.")
178
  except Exception as e_img:
179
- print(f"Impossible de créer l'image de test : {e_img}")
180
 
181
  if os.path.exists(dummy_image_name):
182
  print(f"Utilisation du modèle actif : {ACTIVE_MODEL}")
183
- print("Chargement du modèle actif (peut prendre du temps)...")
184
- load_active_model()
185
  if is_active_model_loaded():
186
  print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
187
  description = generate_active_description(dummy_image_name)
@@ -191,5 +210,5 @@ if __name__ == '__main__':
191
  else:
192
  print("Le modèle actif n'a pas pu être chargé. Test de description annulé.")
193
  else:
194
- print(f"Image de test '{dummy_image_name}' non trouvée. Test de description annulé.")
195
  print("Fin du test de processing.py.")
 
1
  import torch
2
  from PIL import Image, ImageDraw # ImageDraw pour la section de test
3
  import os
4
+ import traceback # Pour un log d'erreur plus détaillé
5
 
6
  # Imports spécifiques pour LLaVA
7
  from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
8
 
9
  # --- Configuration du modèle LLaVA-NeXT ---
10
  LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
11
+ # La variable LLAVA_REVISION est définie ici au cas nous voudrions épingler une version spécifique plus tard,
12
+ # une fois que nous aurons confirmé que tout fonctionne bien avec la version 'main'.
13
+ # Pour l'instant, elle n'est PAS utilisée dans les appels from_pretrained().
14
+ LLAVA_REVISION = "082142fd2997099498027732cf8e945044bf48c3" # Exemple de hash, non utilisé ci-dessous
15
 
16
  llava_processor = None
17
  llava_model = None
 
26
  device = "cpu"
27
  print(f"Utilisation du device : {device} pour les modèles d'IA.")
28
 
29
+
30
+ # --- Fonctions de chargement et de génération pour LLaVA ---
31
  def load_llava_model():
32
  global llava_processor, llava_model, llava_model_loaded, device
33
  if llava_model_loaded:
34
+ print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main') déjà chargé.")
35
  return
36
 
37
  try:
38
  print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main')...")
39
  llava_processor = LlavaNextProcessor.from_pretrained(
40
  LLAVA_MODEL_NAME
41
+ # Pas de 'revision' ici, on charge depuis la branche 'main'
42
  )
43
  print("Processor LLaVA chargé.")
44
 
45
+ print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main') sur '{device}'...")
46
  model_args = {
47
+ # Pas de 'revision' ici non plus pour le moment
48
+ "low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU lors du chargement initial
49
  }
50
+
51
  if device == "cpu":
52
+ # Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
53
  print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
54
  elif device == "cuda":
55
  model_args["torch_dtype"] = torch.float16 # ou torch.bfloat16 si GPU récent (Ampere+)
56
+ print(f"Configuration de LLaVA pour CUDA ({model_args.get('torch_dtype', 'par défaut')}).")
57
  elif device == "mps":
58
+ # Pour MPS, float16 peut offrir des gains de vitesse. float32 est plus sûr pour commencer.
59
+ # Laisser float32 par défaut (pas de torch_dtype) est une option, ou essayer float16.
60
+ model_args["torch_dtype"] = torch.float16 # Essayons float16 pour MPS
61
+ print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut')}).")
 
62
 
63
  llava_model = LlavaNextForConditionalGeneration.from_pretrained(
64
  LLAVA_MODEL_NAME,
 
66
  ).to(device).eval()
67
 
68
  llava_model_loaded = True
69
+ print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}, tous deux depuis branche 'main') chargés avec succès sur '{device}'.")
70
 
71
  except Exception as e:
72
+ print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
73
  traceback.print_exc()
74
  llava_model_loaded = False
75
 
 
78
  global llava_processor, llava_model, llava_model_loaded, device
79
 
80
  if not llava_model_loaded:
81
+ print("Modèle LLaVA non chargé dans generate_description_llava. Tentative de chargement...")
82
  load_llava_model()
83
+ if not llava_model_loaded:
84
+ return "Erreur: Le modèle LLaVA n'a pas pu être chargé (échec lors de la tentative à la demande)."
85
 
86
  if not os.path.exists(image_path):
87
  return f"Erreur: Le fichier image {image_path} n'existe pas."
 
89
  try:
90
  image = Image.open(image_path).convert("RGB")
91
 
92
+ # Prompt en anglais par défaut
93
  user_prompt = "Describe this image in English with precision and detail."
94
+ # Pour du français :
95
+ # user_prompt = "Décris cette image en français avec précision et de manière détaillée."
96
 
97
+ # Format de prompt pour LLaVA v1.6
98
  prompt_text = f"<s>[INST] <image>\n{user_prompt} [/INST]"
99
 
100
+ print(f"Préparation des entrées pour LLaVA avec le prompt: \"{user_prompt}\"")
101
  inputs_on_cpu = llava_processor(text=prompt_text, images=image, return_tensors="pt")
102
 
103
  inputs = {}
 
105
  if torch.is_tensor(value):
106
  inputs[key] = value.to(device)
107
  else:
108
+ inputs[key] = value # Conserver d'autres types si présents
109
 
110
+ # S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS
111
  if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
112
  (llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
113
  for k_tensor, v_tensor in inputs.items():
 
120
  generation_kwargs = {
121
  "max_new_tokens": 768,
122
  "num_beams": 3,
123
+ "early_stopping": True
124
  }
125
 
126
  generated_ids = llava_model.generate(**inputs, **generation_kwargs)
127
 
128
  input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
129
+ generated_ids_only = generated_ids[0, input_token_len:] # Extraire seulement les tokens générés
130
 
131
  cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
132
 
133
+ # Nettoyage supplémentaire si le marqueur [/INST] est toujours présent (peu probable avec ce décodage)
134
+ inst_marker_space = " [/INST]"
135
  inst_marker_no_space = "[/INST]"
136
  if cleaned_text.startswith(inst_marker_space):
137
  cleaned_text = cleaned_text[len(inst_marker_space):].strip()
 
144
  except Exception as e:
145
  print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
146
  traceback.print_exc()
147
+ if torch.cuda.is_available() or device == "mps":
148
  if device == "cuda": torch.cuda.empty_cache()
149
+ # if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache() # Pour PyTorch >= 1.13
150
  return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
151
 
152
+
153
+ # --- Fonctions de gestion du modèle actif ---
154
  ACTIVE_MODEL = "llava"
155
 
156
  def load_active_model():
157
  print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
158
  if ACTIVE_MODEL == "llava":
159
  load_llava_model()
160
+ # Ajoutez d'autres conditions ici si vous réactivez d'autres modèles
161
+ # elif ACTIVE_MODEL == "florence":
162
+ # load_florence_model()
163
  else:
164
  print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
165
 
166
  def generate_active_description(image_path: str) -> str:
167
  if ACTIVE_MODEL == "llava":
168
  return generate_description_llava(image_path)
169
+ # elif ACTIVE_MODEL == "florence":
170
+ # return generate_description_florence(image_path)
171
  else:
172
  error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
173
  print(error_msg)
 
176
  def is_active_model_loaded() -> bool:
177
  if ACTIVE_MODEL == "llava":
178
  return llava_model_loaded
179
+ # elif ACTIVE_MODEL == "florence":
180
+ # return florence_model_loaded
181
  return False
182
 
183
+ # --- Section de Test (pour exécution directe de ce fichier) ---
184
  if __name__ == '__main__':
185
  print("Début du test de processing.py...")
186
+
187
+ # Créer une image de test factice si elle n'existe pas
188
+ dummy_image_name = "dummy_test_image.png" # S'assure qu'elle est bien ignorée par .gitignore si elle est créée
189
  if not os.path.exists(dummy_image_name):
190
  try:
191
+ # ImageDraw a été importé en haut avec PIL
192
  img = Image.new('RGB', (200, 150), color = 'skyblue')
193
  draw = ImageDraw.Draw(img)
194
  draw.text((10, 10), "Test Image", fill='black')
195
  img.save(dummy_image_name)
196
  print(f"Image de test '{dummy_image_name}' créée.")
197
  except Exception as e_img:
198
+ print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
199
 
200
  if os.path.exists(dummy_image_name):
201
  print(f"Utilisation du modèle actif : {ACTIVE_MODEL}")
202
+ print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
203
+ load_active_model() # Tente de charger le modèle
204
  if is_active_model_loaded():
205
  print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
206
  description = generate_active_description(dummy_image_name)
 
210
  else:
211
  print("Le modèle actif n'a pas pu être chargé. Test de description annulé.")
212
  else:
213
+ print(f"Image de test '{dummy_image_name}' non trouvée pour le test.")
214
  print("Fin du test de processing.py.")