Ludovic commited on
Commit
bb246b9
·
1 Parent(s): 1a6e3f8
Files changed (3) hide show
  1. app/main.py +58 -29
  2. app/processing.py +77 -67
  3. requirements.txt +5 -4
app/main.py CHANGED
@@ -4,7 +4,7 @@ import uuid
4
  import secrets
5
  import zipfile
6
  import io
7
- import re # Toujours utile pour obtenir une extension propre
8
 
9
  from fastapi import Depends, FastAPI, File, UploadFile, Request, HTTPException, status
10
  from fastapi.responses import HTMLResponse, StreamingResponse
@@ -12,7 +12,7 @@ from fastapi.staticfiles import StaticFiles
12
  from fastapi.templating import Jinja2Templates
13
  from fastapi.security import HTTPBasic, HTTPBasicCredentials
14
  from passlib.context import CryptContext
15
- from typing import List
16
 
17
  from . import processing
18
  from . import utils
@@ -23,7 +23,7 @@ security = HTTPBasic()
23
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
24
 
25
  APP_USERNAME_DEFAULT = "admin"
26
- APP_PASSWORD_DEFAULT = "changezceci"
27
 
28
  APP_USERNAME = os.environ.get("APP_USERNAME", APP_USERNAME_DEFAULT)
29
  APP_PASSWORD_RAW = os.environ.get("APP_PASSWORD", APP_PASSWORD_DEFAULT)
@@ -66,14 +66,30 @@ os.makedirs(OUTPUT_CAPTION_DIR, exist_ok=True)
66
  app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
67
  templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
68
 
69
- def get_safe_extension(filename: str) -> str:
70
- """Extrait et nettoie l'extension d'un nom de fichier."""
71
- name, ext = os.path.splitext(filename)
72
- # Garder uniquement les caractères alphanumériques pour l'extension, et s'assurer qu'elle commence par un point.
73
- safe_ext = re.sub(r'[^a-zA-Z0-9]', '', ext).lower()
74
- if not safe_ext: # Si pas d'extension valide trouvée (ex: fichier sans extension ou avec des caractères bizarres)
75
- return ".img" # Extension par défaut
76
- return f".{safe_ext}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
 
79
  @app.get("/", response_class=HTMLResponse)
@@ -101,7 +117,10 @@ async def upload_images_for_captioning(
101
 
102
  zip_buffer = io.BytesIO()
103
  files_added_to_zip = 0
104
- image_counter = 1 # Initialiser le compteur pour le nommage séquentiel
 
 
 
105
 
106
  with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
107
  for file in files:
@@ -111,17 +130,28 @@ async def upload_images_for_captioning(
111
  print(f"Fichier ignoré (type non supporté: {file.content_type}): {file.filename}")
112
  continue
113
 
114
- # Obtenir l'extension du fichier original de manière sécurisée
115
- original_extension = get_safe_extension(file.filename)
116
- if not original_extension and file.content_type: # Déduction d'extension via MIME type si get_safe_extension échoue
117
- ext_map = {"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif", "image/webp": ".webp"}
118
- original_extension = ext_map.get(file.content_type, ".img") # .img comme fallback
119
 
120
- # Créer les noms de fichiers séquentiels pour le ZIP
121
- image_filename_in_zip = f"photo{image_counter}{original_extension}"
122
- caption_filename_in_zip = f"photo{image_counter}.txt"
123
 
124
- # Utiliser une extension temporaire basée sur l'original pour le fichier sur disque
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  temp_upload_filename = f"temp_{uuid.uuid4().hex}{original_extension}"
126
  temp_file_path = os.path.join(UPLOAD_DIR, temp_upload_filename)
127
 
@@ -130,26 +160,25 @@ async def upload_images_for_captioning(
130
 
131
  image_description = "Description non générée par défaut."
132
  if processing.is_active_model_loaded():
133
- print(f"Génération de description pour {temp_file_path} (sera {image_filename_in_zip} dans ZIP) avec le modèle {processing.ACTIVE_MODEL}")
134
  image_description = processing.generate_active_description(temp_file_path)
135
  else:
136
  print(f"ERREUR: Tentative de génération alors que le modèle {processing.ACTIVE_MODEL} n'est pas chargé.")
137
  image_description = f"ERREUR CRITIQUE: Le modèle IA ({processing.ACTIVE_MODEL}) n'est pas disponible."
138
 
139
- # Ajouter l'image au ZIP avec son nom séquentiel
140
- zf.write(temp_file_path, arcname=image_filename_in_zip)
141
 
142
- # Ajouter le fichier de description au ZIP avec son nom séquentiel
143
- zf.writestr(caption_filename_in_zip, image_description)
144
 
145
  files_added_to_zip += 1
146
- image_counter += 1 # Incrémenter pour le prochain fichier
147
 
148
- except HTTPException: # Laisser remonter les erreurs HTTP (ex: 503 du chargement modèle)
149
  raise
150
  except Exception as e:
151
  print(f"Erreur inattendue lors du traitement du fichier {file.filename}: {e}")
152
- processing.traceback.print_exc() # Afficher la trace complète pour les erreurs inattendues
153
  finally:
154
  if hasattr(file, 'file') and file.file and not file.file.closed:
155
  file.file.close()
 
4
  import secrets
5
  import zipfile
6
  import io
7
+ import re # Pour nettoyer les noms de fichiers
8
 
9
  from fastapi import Depends, FastAPI, File, UploadFile, Request, HTTPException, status
10
  from fastapi.responses import HTMLResponse, StreamingResponse
 
12
  from fastapi.templating import Jinja2Templates
13
  from fastapi.security import HTTPBasic, HTTPBasicCredentials
14
  from passlib.context import CryptContext
15
+ from typing import List, Tuple # Ajout de Tuple pour le type de retour
16
 
17
  from . import processing
18
  from . import utils
 
23
  pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
24
 
25
  APP_USERNAME_DEFAULT = "admin"
26
+ APP_PASSWORD_DEFAULT = "changezceci" # Changez ce mot de passe par défaut si vous testez localement
27
 
28
  APP_USERNAME = os.environ.get("APP_USERNAME", APP_USERNAME_DEFAULT)
29
  APP_PASSWORD_RAW = os.environ.get("APP_PASSWORD", APP_PASSWORD_DEFAULT)
 
66
  app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
67
  templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
68
 
69
+ def sanitize_basename_and_get_ext(filename: str) -> Tuple[str, str]:
70
+ """
71
+ Nettoie le nom de base d'un fichier et retourne (nom_base_nettoye, .extension_nette).
72
+ Exemple: "Mon Image Bizarre!!.JPEG" -> ("Mon_Image_Bizarre", ".jpeg")
73
+ """
74
+ name_part, ext_part = os.path.splitext(filename)
75
+
76
+ # Nettoyer le nom de base
77
+ # Garder les alphanumériques, espaces, tirets, underscores, points (non finaux pour le nom)
78
+ safe_name = re.sub(r'[^\w\s.-]', '', name_part).strip()
79
+ # Remplacer les espaces et séquences de points/tirets par un seul underscore
80
+ safe_name = re.sub(r'[\s._-]+', '_', safe_name)
81
+ # Enlever les underscores au début ou à la fin après remplacement
82
+ safe_name = safe_name.strip('_')
83
+
84
+ if not safe_name: # Si le nom est vide après nettoyage (ex: "!!.jpg")
85
+ safe_name = f"image_{uuid.uuid4().hex[:8]}" # Nom de fallback
86
+
87
+ # Nettoyer l'extension
88
+ safe_ext = re.sub(r'[^a-zA-Z0-9]', '', ext_part).lower()
89
+ if not safe_ext:
90
+ safe_ext = "img" # Extension par défaut si l'originale n'est pas valide/présente
91
+
92
+ return safe_name, f".{safe_ext}"
93
 
94
 
95
  @app.get("/", response_class=HTMLResponse)
 
117
 
118
  zip_buffer = io.BytesIO()
119
  files_added_to_zip = 0
120
+
121
+ # Pour gérer les noms de fichiers dupliqués (basés sur l'original) dans le ZIP
122
+ # Clé: "basename.ext", Valeur: compteur pour ce nom de fichier
123
+ filenames_in_zip_tracker = {}
124
 
125
  with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
126
  for file in files:
 
130
  print(f"Fichier ignoré (type non supporté: {file.content_type}): {file.filename}")
131
  continue
132
 
133
+ # Nettoyer le nom de base et obtenir l'extension à partir du nom de fichier original
134
+ original_base_name, original_extension = sanitize_basename_and_get_ext(file.filename)
 
 
 
135
 
136
+ # Gérer les noms dupliqués pour l'image
137
+ # Le nom de base pour la déduplication est le nom original nettoyé sans extension
138
+ dedup_key_base = original_base_name
139
 
140
+ image_arcname_final = f"{original_base_name}{original_extension}"
141
+ caption_arcname_final = f"{original_base_name}.txt"
142
+
143
+ # Vérifier si cette combinaison nom_base + extension image existe déjà
144
+ # Ou si nom_base + .txt existe (car ils partagent le même nom de base)
145
+ # On numérote le nom de base si nécessaire.
146
+ count = filenames_in_zip_tracker.get(dedup_key_base, 0)
147
+ if count > 0: # Si dedup_key_base a déjà été vu, on ajoute le compteur
148
+ image_arcname_final = f"{original_base_name}({count}){original_extension}"
149
+ caption_arcname_final = f"{original_base_name}({count}).txt"
150
+
151
+ filenames_in_zip_tracker[dedup_key_base] = count + 1
152
+
153
+
154
+ # Utiliser une extension temporaire basée sur l'original pour le fichier sur disque serveur
155
  temp_upload_filename = f"temp_{uuid.uuid4().hex}{original_extension}"
156
  temp_file_path = os.path.join(UPLOAD_DIR, temp_upload_filename)
157
 
 
160
 
161
  image_description = "Description non générée par défaut."
162
  if processing.is_active_model_loaded():
163
+ print(f"Génération de description pour {temp_file_path} (sera {image_arcname_final} dans ZIP) avec le modèle {processing.ACTIVE_MODEL}")
164
  image_description = processing.generate_active_description(temp_file_path)
165
  else:
166
  print(f"ERREUR: Tentative de génération alors que le modèle {processing.ACTIVE_MODEL} n'est pas chargé.")
167
  image_description = f"ERREUR CRITIQUE: Le modèle IA ({processing.ACTIVE_MODEL}) n'est pas disponible."
168
 
169
+ # Ajouter l'image au ZIP avec son nom (original nettoyé, potentiellement dédupliqué)
170
+ zf.write(temp_file_path, arcname=image_arcname_final)
171
 
172
+ # Ajouter le fichier de description au ZIP (même nom de base, potentiellement dédupliqué, extension .txt)
173
+ zf.writestr(caption_arcname_final, image_description)
174
 
175
  files_added_to_zip += 1
 
176
 
177
+ except HTTPException:
178
  raise
179
  except Exception as e:
180
  print(f"Erreur inattendue lors du traitement du fichier {file.filename}: {e}")
181
+ processing.traceback.print_exc()
182
  finally:
183
  if hasattr(file, 'file') and file.file and not file.file.closed:
184
  file.file.close()
app/processing.py CHANGED
@@ -8,10 +8,7 @@ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
8
 
9
  # --- Configuration du modèle LLaVA-NeXT ---
10
  LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
11
- # La variable LLAVA_REVISION est définie ici au cas nous voudrions épingler une version spécifique plus tard,
12
- # une fois que nous aurons confirmé que tout fonctionne bien avec la version 'main'.
13
- # Pour l'instant, elle n'est PAS utilisée dans les appels from_pretrained().
14
- LLAVA_REVISION = "082142fd2997099498027732cf8e945044bf48c3" # Exemple de hash, non utilisé ci-dessous
15
 
16
  llava_processor = None
17
  llava_model = None
@@ -20,7 +17,7 @@ llava_model_loaded = False
20
  # Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon)
21
  if torch.cuda.is_available():
22
  device = "cuda"
23
- elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
24
  device = "mps"
25
  else:
26
  device = "cpu"
@@ -31,42 +28,43 @@ print(f"Utilisation du device : {device} pour les modèles d'IA.")
31
  def load_llava_model():
32
  global llava_processor, llava_model, llava_model_loaded, device
33
  if llava_model_loaded:
34
- print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main') déjà chargé.")
35
  return
36
 
37
  try:
38
- print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main')...")
39
  llava_processor = LlavaNextProcessor.from_pretrained(
40
  LLAVA_MODEL_NAME
41
- # Pas de 'revision' ici, on charge depuis la branche 'main'
42
  )
43
  print("Processor LLaVA chargé.")
44
 
45
- print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}, depuis branche 'main') sur '{device}'...")
46
  model_args = {
47
- # Pas de 'revision' ici non plus pour le moment
48
- "low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU lors du chargement initial
49
  }
50
 
51
- if device == "cpu":
 
 
 
 
 
 
 
 
 
52
  # Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
53
  print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
54
- elif device == "cuda":
55
- model_args["torch_dtype"] = torch.float16 # ou torch.bfloat16 si GPU récent (Ampere+)
56
- print(f"Configuration de LLaVA pour CUDA ({model_args.get('torch_dtype', 'par défaut')}).")
57
- elif device == "mps":
58
- # Pour MPS, float16 peut offrir des gains de vitesse. float32 est plus sûr pour commencer.
59
- # Laisser float32 par défaut (pas de torch_dtype) est une option, ou essayer float16.
60
- model_args["torch_dtype"] = torch.float16 # Essayons float16 pour MPS
61
- print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut')}).")
62
 
63
  llava_model = LlavaNextForConditionalGeneration.from_pretrained(
64
  LLAVA_MODEL_NAME,
65
  **model_args
66
- ).to(device).eval()
67
 
68
  llava_model_loaded = True
69
- print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}, tous deux depuis branche 'main') chargés avec succès sur '{device}'.")
70
 
71
  except Exception as e:
72
  print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
@@ -78,96 +76,109 @@ def generate_description_llava(image_path: str) -> str:
78
  global llava_processor, llava_model, llava_model_loaded, device
79
 
80
  if not llava_model_loaded:
81
- print("Modèle LLaVA non chargé dans generate_description_llava. Tentative de chargement...")
82
  load_llava_model()
83
  if not llava_model_loaded:
84
- return "Erreur: Le modèle LLaVA n'a pas pu être chargé (échec lors de la tentative à la demande)."
85
-
86
  if not os.path.exists(image_path):
87
  return f"Erreur: Le fichier image {image_path} n'existe pas."
88
 
89
  try:
90
  image = Image.open(image_path).convert("RGB")
91
 
92
- # Prompt en anglais par défaut
93
- user_prompt = "Describe this image in English with precision and detail."
94
- # Pour du français :
95
- # user_prompt = "Décris cette image en français avec précision et de manière détaillée."
96
-
97
- # Format de prompt pour LLaVA v1.6
98
- prompt_text = f"<s>[INST] <image>\n{user_prompt} [/INST]"
 
 
 
 
99
 
100
- print(f"Préparation des entrées pour LLaVA avec le prompt: \"{user_prompt}\"")
101
- inputs_on_cpu = llava_processor(text=prompt_text, images=image, return_tensors="pt")
102
 
 
 
 
 
103
  inputs = {}
104
  for key, value in inputs_on_cpu.items():
105
  if torch.is_tensor(value):
106
  inputs[key] = value.to(device)
107
  else:
108
- inputs[key] = value # Conserver d'autres types si présents
109
 
110
- # S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS
111
  if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
112
  (llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
113
  for k_tensor, v_tensor in inputs.items():
114
- if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor):
115
  inputs[k_tensor] = v_tensor.to(llava_model.dtype)
116
 
117
  input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)}
118
  print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...")
119
 
 
120
  generation_kwargs = {
121
- "max_new_tokens": 768,
122
- "num_beams": 3,
123
- "early_stopping": True
 
 
 
124
  }
125
 
126
  generated_ids = llava_model.generate(**inputs, **generation_kwargs)
127
 
128
- input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
129
- generated_ids_only = generated_ids[0, input_token_len:] # Extraire seulement les tokens générés
 
 
 
 
 
130
 
131
- cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
 
 
132
 
133
- # Nettoyage supplémentaire si le marqueur [/INST] est toujours présent (peu probable avec ce décodage)
134
- inst_marker_space = " [/INST]"
135
- inst_marker_no_space = "[/INST]"
136
- if cleaned_text.startswith(inst_marker_space):
137
- cleaned_text = cleaned_text[len(inst_marker_space):].strip()
138
- elif cleaned_text.startswith(inst_marker_no_space):
139
- cleaned_text = cleaned_text[len(inst_marker_no_space):].strip()
140
 
141
- print(f"Description (nettoyée) de LLaVA: {cleaned_text}")
142
- return cleaned_text if cleaned_text and cleaned_text.strip() else "Aucune description textuelle distincte n'a été générée par LLaVA."
143
 
144
  except Exception as e:
145
  print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
146
  traceback.print_exc()
147
  if torch.cuda.is_available() or device == "mps":
148
  if device == "cuda": torch.cuda.empty_cache()
149
- # if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache() # Pour PyTorch >= 1.13
150
  return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
151
 
152
 
153
  # --- Fonctions de gestion du modèle actif ---
154
- ACTIVE_MODEL = "llava"
155
 
156
  def load_active_model():
157
  print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
158
  if ACTIVE_MODEL == "llava":
159
  load_llava_model()
160
- # Ajoutez d'autres conditions ici si vous réactivez d'autres modèles
161
- # elif ACTIVE_MODEL == "florence":
162
- # load_florence_model()
163
  else:
164
  print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
165
 
166
  def generate_active_description(image_path: str) -> str:
167
  if ACTIVE_MODEL == "llava":
168
  return generate_description_llava(image_path)
169
- # elif ACTIVE_MODEL == "florence":
170
- # return generate_description_florence(image_path)
171
  else:
172
  error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
173
  print(error_msg)
@@ -176,31 +187,30 @@ def generate_active_description(image_path: str) -> str:
176
  def is_active_model_loaded() -> bool:
177
  if ACTIVE_MODEL == "llava":
178
  return llava_model_loaded
179
- # elif ACTIVE_MODEL == "florence":
180
- # return florence_model_loaded
181
  return False
182
 
183
  # --- Section de Test (pour exécution directe de ce fichier) ---
184
  if __name__ == '__main__':
185
  print("Début du test de processing.py...")
186
 
187
- # Créer une image de test factice si elle n'existe pas
188
- dummy_image_name = "dummy_test_image.png" # S'assure qu'elle est bien ignorée par .gitignore si elle est créée
189
  if not os.path.exists(dummy_image_name):
190
  try:
191
- # ImageDraw a été importé en haut avec PIL
192
  img = Image.new('RGB', (200, 150), color = 'skyblue')
193
  draw = ImageDraw.Draw(img)
194
- draw.text((10, 10), "Test Image", fill='black')
 
 
 
195
  img.save(dummy_image_name)
196
  print(f"Image de test '{dummy_image_name}' créée.")
197
  except Exception as e_img:
198
  print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
199
 
200
  if os.path.exists(dummy_image_name):
201
- print(f"Utilisation du modèle actif : {ACTIVE_MODEL}")
202
  print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
203
- load_active_model() # Tente de charger le modèle
204
  if is_active_model_loaded():
205
  print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
206
  description = generate_active_description(dummy_image_name)
 
8
 
9
  # --- Configuration du modèle LLaVA-NeXT ---
10
  LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
11
+ # LLAVA_REVISION = "main" # On utilise la branche principale par défaut
 
 
 
12
 
13
  llava_processor = None
14
  llava_model = None
 
17
  # Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon)
18
  if torch.cuda.is_available():
19
  device = "cuda"
20
+ elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): # Pour Apple Silicon
21
  device = "mps"
22
  else:
23
  device = "cpu"
 
28
  def load_llava_model():
29
  global llava_processor, llava_model, llava_model_loaded, device
30
  if llava_model_loaded:
31
+ print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}) déjà chargé.")
32
  return
33
 
34
  try:
35
+ print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME})...")
36
  llava_processor = LlavaNextProcessor.from_pretrained(
37
  LLAVA_MODEL_NAME
38
+ # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
39
  )
40
  print("Processor LLaVA chargé.")
41
 
42
+ print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}) sur '{device}'...")
43
  model_args = {
44
+ "low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU
45
+ # revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
46
  }
47
 
48
+ if device == "cuda":
49
+ model_args["torch_dtype"] = torch.float16 # Précision pour GPU NVIDIA
50
+ # Pour les GPU plus récents (Ampere+), torch.bfloat16 peut aussi être une option
51
+ print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).")
52
+ elif device == "mps":
53
+ # Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer
54
+ # Si des problèmes de stabilité surviennent avec float16 sur MPS, commenter la ligne suivante
55
+ model_args["torch_dtype"] = torch.float16
56
+ print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut float32')}).")
57
+ else: # CPU
58
  # Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
59
  print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
 
 
 
 
 
 
 
 
60
 
61
  llava_model = LlavaNextForConditionalGeneration.from_pretrained(
62
  LLAVA_MODEL_NAME,
63
  **model_args
64
+ ).to(device).eval() # Mettre le modèle en mode évaluation
65
 
66
  llava_model_loaded = True
67
+ print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}) chargés avec succès sur '{device}'.")
68
 
69
  except Exception as e:
70
  print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
 
76
  global llava_processor, llava_model, llava_model_loaded, device
77
 
78
  if not llava_model_loaded:
79
+ print("Modèle LLaVA non chargé. Tentative de chargement à la demande...")
80
  load_llava_model()
81
  if not llava_model_loaded:
82
+ return "Erreur: Le modèle LLaVA n'a pas pu être chargé."
83
+
84
  if not os.path.exists(image_path):
85
  return f"Erreur: Le fichier image {image_path} n'existe pas."
86
 
87
  try:
88
  image = Image.open(image_path).convert("RGB")
89
 
90
+ # --- PROMPT AMÉLIORÉ ---
91
+ user_prompt_fr = (
92
+ "Analyse cette image en tant qu'œuvre d'art. "
93
+ "Fournis une description objective, factuelle et très détaillée en français. "
94
+ "Concentre-toi sur les éléments visuels : la scène globale, les sujets et personnages (y compris leur apparence, posture, expression, et toute forme de nudité si présente), "
95
+ "les objets, l'arrière-plan, les formes, les couleurs, la lumière, la composition et la technique artistique apparente. "
96
+ "Évite toute interprétation subjective ou jugement moral, et décris ce qui est visible de manière exhaustive."
97
+ )
98
+
99
+ # Format de prompt spécifique à LLaVA v1.5+ (Mistral utilise ce format)
100
+ prompt_template = f"<s>[INST] <image>\n{user_prompt_fr} [/INST]"
101
 
102
+ print(f"Préparation des entrées pour LLaVA avec le prompt utilisateur (condensé): \"{user_prompt_fr[:100]}...\"")
 
103
 
104
+ # Le processeur gère la tokenisation du texte et le prétraitement de l'image
105
+ inputs_on_cpu = llava_processor(text=prompt_template, images=image, return_tensors="pt")
106
+
107
+ # Déplacer les tenseurs sur le bon device
108
  inputs = {}
109
  for key, value in inputs_on_cpu.items():
110
  if torch.is_tensor(value):
111
  inputs[key] = value.to(device)
112
  else:
113
+ inputs[key] = value
114
 
115
+ # S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS (si float16/bfloat16)
116
  if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
117
  (llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
118
  for k_tensor, v_tensor in inputs.items():
119
+ if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor): # Appliquer seulement aux tenseurs flottants
120
  inputs[k_tensor] = v_tensor.to(llava_model.dtype)
121
 
122
  input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)}
123
  print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...")
124
 
125
+ # Paramètres de génération (ajustables si nécessaire)
126
  generation_kwargs = {
127
+ "max_new_tokens": 768, # Augmenté légèrement pour des descriptions potentiellement plus longues
128
+ "num_beams": 3, # Un peu de beam search peut améliorer la cohérence
129
+ "early_stopping": True,
130
+ "do_sample": False # Pour des descriptions plus factuelles et moins "créatives" aléatoirement. Mettre True avec temperature si on veut plus de variété.
131
+ # "temperature": 0.7, # À utiliser avec do_sample=True si on veut de la créativité
132
+ # "top_p": 0.9, # À utiliser avec do_sample=True
133
  }
134
 
135
  generated_ids = llava_model.generate(**inputs, **generation_kwargs)
136
 
137
+ # Pour LLaVA, il est important de décoder uniquement les tokens générés *après* le prompt.
138
+ # Certains processeurs/modèles gèrent cela différemment.
139
+ # Pour LLaVA-NeXT, le décodage de la séquence complète et le nettoyage du prompt est une approche courante.
140
+ # Ou, si l'on connaît la longueur des tokens d'entrée :
141
+ # input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
142
+ # generated_ids_only = generated_ids[0, input_token_len:]
143
+ # cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
144
 
145
+ # Approche plus simple : décoder toute la séquence et enlever manuellement le prompt si besoin.
146
+ # Souvent, pour les formats [INST]...[/INST], le modèle génère directement la réponse.
147
+ full_text = llava_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
148
 
149
+ # Nettoyage du texte généré pour enlever le prompt s'il est répété (spécifique au format de sortie du modèle)
150
+ # Le format "[INST] <image> \n {prompt} [/INST] {réponse}" fait que la réponse est souvent propre.
151
+ inst_marker = "[/INST]"
152
+ if inst_marker in full_text:
153
+ cleaned_text = full_text.split(inst_marker, 1)[-1].strip()
154
+ else:
155
+ cleaned_text = full_text # Si le marqueur n'est pas là, prendre tout (peut arriver)
156
 
157
+ print(f"Description (nettoyée) de LLaVA: {cleaned_text[:200]}...") # Log tronqué
158
+ return cleaned_text if cleaned_text else "Aucune description textuelle distincte n'a été générée par LLaVA."
159
 
160
  except Exception as e:
161
  print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
162
  traceback.print_exc()
163
  if torch.cuda.is_available() or device == "mps":
164
  if device == "cuda": torch.cuda.empty_cache()
165
+ # if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache()
166
  return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
167
 
168
 
169
  # --- Fonctions de gestion du modèle actif ---
170
+ ACTIVE_MODEL = "llava" # Pour l'instant, seul LLaVA est configuré
171
 
172
  def load_active_model():
173
  print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
174
  if ACTIVE_MODEL == "llava":
175
  load_llava_model()
 
 
 
176
  else:
177
  print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
178
 
179
  def generate_active_description(image_path: str) -> str:
180
  if ACTIVE_MODEL == "llava":
181
  return generate_description_llava(image_path)
 
 
182
  else:
183
  error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
184
  print(error_msg)
 
187
  def is_active_model_loaded() -> bool:
188
  if ACTIVE_MODEL == "llava":
189
  return llava_model_loaded
 
 
190
  return False
191
 
192
  # --- Section de Test (pour exécution directe de ce fichier) ---
193
  if __name__ == '__main__':
194
  print("Début du test de processing.py...")
195
 
196
+ dummy_image_name = "dummy_test_image.png"
 
197
  if not os.path.exists(dummy_image_name):
198
  try:
 
199
  img = Image.new('RGB', (200, 150), color = 'skyblue')
200
  draw = ImageDraw.Draw(img)
201
+ draw.text((10, 10), "Test Image for LLaVA", fill='black')
202
+ # Ajouter quelques formes pour le test
203
+ draw.ellipse((30, 50, 90, 110), fill='red', outline='black')
204
+ draw.rectangle((100, 40, 170, 120), fill='lightgreen', outline='blue')
205
  img.save(dummy_image_name)
206
  print(f"Image de test '{dummy_image_name}' créée.")
207
  except Exception as e_img:
208
  print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
209
 
210
  if os.path.exists(dummy_image_name):
211
+ print(f"Utilisation du modèle actif : {ACTIVE_MODEL} sur device {device}")
212
  print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
213
+ load_active_model()
214
  if is_active_model_loaded():
215
  print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
216
  description = generate_active_description(dummy_image_name)
requirements.txt CHANGED
@@ -4,8 +4,8 @@ python-multipart
4
  jinja2
5
  torch
6
  torchvision
7
- # torchaudio # Si vous ne l'utilisez pas activement, il peut être omis
8
- transformers
9
  Pillow
10
  accelerate
11
  einops
@@ -15,6 +15,7 @@ tiktoken
15
  # Pour l'authentification Basic Auth
16
  python-jose[cryptography]>=3.3.0
17
  passlib[bcrypt]>=1.7.4
18
- bcrypt>=3.2.0,<4.1.0 # Force une plage compatible pour bcrypt
19
 
20
- # bitsandbytes # Optionnel, décommentez si vous en avez besoin pour la quantification
 
 
4
  jinja2
5
  torch
6
  torchvision
7
+ # torchaudio # Optionnel
8
+ transformers>=4.38.0 # Assurer une version récente pour LLaVA-NeXT et autres modèles récents
9
  Pillow
10
  accelerate
11
  einops
 
15
  # Pour l'authentification Basic Auth
16
  python-jose[cryptography]>=3.3.0
17
  passlib[bcrypt]>=1.7.4
18
+ bcrypt>=3.2.0,<4.1.0
19
 
20
+ # bitsandbytes # Optionnel pour la quantification
21
+ sentencepiece # Souvent requis par les tokenizers de transformers