Spaces:
Sleeping
Sleeping
Ludovic commited on
Commit ·
bb246b9
1
Parent(s): 1a6e3f8
V5
Browse files- app/main.py +58 -29
- app/processing.py +77 -67
- requirements.txt +5 -4
app/main.py
CHANGED
|
@@ -4,7 +4,7 @@ import uuid
|
|
| 4 |
import secrets
|
| 5 |
import zipfile
|
| 6 |
import io
|
| 7 |
-
import re #
|
| 8 |
|
| 9 |
from fastapi import Depends, FastAPI, File, UploadFile, Request, HTTPException, status
|
| 10 |
from fastapi.responses import HTMLResponse, StreamingResponse
|
|
@@ -12,7 +12,7 @@ from fastapi.staticfiles import StaticFiles
|
|
| 12 |
from fastapi.templating import Jinja2Templates
|
| 13 |
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
| 14 |
from passlib.context import CryptContext
|
| 15 |
-
from typing import List
|
| 16 |
|
| 17 |
from . import processing
|
| 18 |
from . import utils
|
|
@@ -23,7 +23,7 @@ security = HTTPBasic()
|
|
| 23 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 24 |
|
| 25 |
APP_USERNAME_DEFAULT = "admin"
|
| 26 |
-
APP_PASSWORD_DEFAULT = "changezceci"
|
| 27 |
|
| 28 |
APP_USERNAME = os.environ.get("APP_USERNAME", APP_USERNAME_DEFAULT)
|
| 29 |
APP_PASSWORD_RAW = os.environ.get("APP_PASSWORD", APP_PASSWORD_DEFAULT)
|
|
@@ -66,14 +66,30 @@ os.makedirs(OUTPUT_CAPTION_DIR, exist_ok=True)
|
|
| 66 |
app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
|
| 67 |
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
|
| 68 |
|
| 69 |
-
def
|
| 70 |
-
"""
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
@app.get("/", response_class=HTMLResponse)
|
|
@@ -101,7 +117,10 @@ async def upload_images_for_captioning(
|
|
| 101 |
|
| 102 |
zip_buffer = io.BytesIO()
|
| 103 |
files_added_to_zip = 0
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 107 |
for file in files:
|
|
@@ -111,17 +130,28 @@ async def upload_images_for_captioning(
|
|
| 111 |
print(f"Fichier ignoré (type non supporté: {file.content_type}): {file.filename}")
|
| 112 |
continue
|
| 113 |
|
| 114 |
-
#
|
| 115 |
-
original_extension =
|
| 116 |
-
if not original_extension and file.content_type: # Déduction d'extension via MIME type si get_safe_extension échoue
|
| 117 |
-
ext_map = {"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif", "image/webp": ".webp"}
|
| 118 |
-
original_extension = ext_map.get(file.content_type, ".img") # .img comme fallback
|
| 119 |
|
| 120 |
-
#
|
| 121 |
-
|
| 122 |
-
|
| 123 |
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
temp_upload_filename = f"temp_{uuid.uuid4().hex}{original_extension}"
|
| 126 |
temp_file_path = os.path.join(UPLOAD_DIR, temp_upload_filename)
|
| 127 |
|
|
@@ -130,26 +160,25 @@ async def upload_images_for_captioning(
|
|
| 130 |
|
| 131 |
image_description = "Description non générée par défaut."
|
| 132 |
if processing.is_active_model_loaded():
|
| 133 |
-
print(f"Génération de description pour {temp_file_path} (sera {
|
| 134 |
image_description = processing.generate_active_description(temp_file_path)
|
| 135 |
else:
|
| 136 |
print(f"ERREUR: Tentative de génération alors que le modèle {processing.ACTIVE_MODEL} n'est pas chargé.")
|
| 137 |
image_description = f"ERREUR CRITIQUE: Le modèle IA ({processing.ACTIVE_MODEL}) n'est pas disponible."
|
| 138 |
|
| 139 |
-
# Ajouter l'image au ZIP avec son nom
|
| 140 |
-
zf.write(temp_file_path, arcname=
|
| 141 |
|
| 142 |
-
# Ajouter le fichier de description au ZIP
|
| 143 |
-
zf.writestr(
|
| 144 |
|
| 145 |
files_added_to_zip += 1
|
| 146 |
-
image_counter += 1 # Incrémenter pour le prochain fichier
|
| 147 |
|
| 148 |
-
except HTTPException:
|
| 149 |
raise
|
| 150 |
except Exception as e:
|
| 151 |
print(f"Erreur inattendue lors du traitement du fichier {file.filename}: {e}")
|
| 152 |
-
processing.traceback.print_exc()
|
| 153 |
finally:
|
| 154 |
if hasattr(file, 'file') and file.file and not file.file.closed:
|
| 155 |
file.file.close()
|
|
|
|
| 4 |
import secrets
|
| 5 |
import zipfile
|
| 6 |
import io
|
| 7 |
+
import re # Pour nettoyer les noms de fichiers
|
| 8 |
|
| 9 |
from fastapi import Depends, FastAPI, File, UploadFile, Request, HTTPException, status
|
| 10 |
from fastapi.responses import HTMLResponse, StreamingResponse
|
|
|
|
| 12 |
from fastapi.templating import Jinja2Templates
|
| 13 |
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
| 14 |
from passlib.context import CryptContext
|
| 15 |
+
from typing import List, Tuple # Ajout de Tuple pour le type de retour
|
| 16 |
|
| 17 |
from . import processing
|
| 18 |
from . import utils
|
|
|
|
| 23 |
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
| 24 |
|
| 25 |
APP_USERNAME_DEFAULT = "admin"
|
| 26 |
+
APP_PASSWORD_DEFAULT = "changezceci" # Changez ce mot de passe par défaut si vous testez localement
|
| 27 |
|
| 28 |
APP_USERNAME = os.environ.get("APP_USERNAME", APP_USERNAME_DEFAULT)
|
| 29 |
APP_PASSWORD_RAW = os.environ.get("APP_PASSWORD", APP_PASSWORD_DEFAULT)
|
|
|
|
| 66 |
app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
|
| 67 |
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
|
| 68 |
|
| 69 |
+
def sanitize_basename_and_get_ext(filename: str) -> Tuple[str, str]:
|
| 70 |
+
"""
|
| 71 |
+
Nettoie le nom de base d'un fichier et retourne (nom_base_nettoye, .extension_nette).
|
| 72 |
+
Exemple: "Mon Image Bizarre!!.JPEG" -> ("Mon_Image_Bizarre", ".jpeg")
|
| 73 |
+
"""
|
| 74 |
+
name_part, ext_part = os.path.splitext(filename)
|
| 75 |
+
|
| 76 |
+
# Nettoyer le nom de base
|
| 77 |
+
# Garder les alphanumériques, espaces, tirets, underscores, points (non finaux pour le nom)
|
| 78 |
+
safe_name = re.sub(r'[^\w\s.-]', '', name_part).strip()
|
| 79 |
+
# Remplacer les espaces et séquences de points/tirets par un seul underscore
|
| 80 |
+
safe_name = re.sub(r'[\s._-]+', '_', safe_name)
|
| 81 |
+
# Enlever les underscores au début ou à la fin après remplacement
|
| 82 |
+
safe_name = safe_name.strip('_')
|
| 83 |
+
|
| 84 |
+
if not safe_name: # Si le nom est vide après nettoyage (ex: "!!.jpg")
|
| 85 |
+
safe_name = f"image_{uuid.uuid4().hex[:8]}" # Nom de fallback
|
| 86 |
+
|
| 87 |
+
# Nettoyer l'extension
|
| 88 |
+
safe_ext = re.sub(r'[^a-zA-Z0-9]', '', ext_part).lower()
|
| 89 |
+
if not safe_ext:
|
| 90 |
+
safe_ext = "img" # Extension par défaut si l'originale n'est pas valide/présente
|
| 91 |
+
|
| 92 |
+
return safe_name, f".{safe_ext}"
|
| 93 |
|
| 94 |
|
| 95 |
@app.get("/", response_class=HTMLResponse)
|
|
|
|
| 117 |
|
| 118 |
zip_buffer = io.BytesIO()
|
| 119 |
files_added_to_zip = 0
|
| 120 |
+
|
| 121 |
+
# Pour gérer les noms de fichiers dupliqués (basés sur l'original) dans le ZIP
|
| 122 |
+
# Clé: "basename.ext", Valeur: compteur pour ce nom de fichier
|
| 123 |
+
filenames_in_zip_tracker = {}
|
| 124 |
|
| 125 |
with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
|
| 126 |
for file in files:
|
|
|
|
| 130 |
print(f"Fichier ignoré (type non supporté: {file.content_type}): {file.filename}")
|
| 131 |
continue
|
| 132 |
|
| 133 |
+
# Nettoyer le nom de base et obtenir l'extension à partir du nom de fichier original
|
| 134 |
+
original_base_name, original_extension = sanitize_basename_and_get_ext(file.filename)
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
# Gérer les noms dupliqués pour l'image
|
| 137 |
+
# Le nom de base pour la déduplication est le nom original nettoyé sans extension
|
| 138 |
+
dedup_key_base = original_base_name
|
| 139 |
|
| 140 |
+
image_arcname_final = f"{original_base_name}{original_extension}"
|
| 141 |
+
caption_arcname_final = f"{original_base_name}.txt"
|
| 142 |
+
|
| 143 |
+
# Vérifier si cette combinaison nom_base + extension image existe déjà
|
| 144 |
+
# Ou si nom_base + .txt existe (car ils partagent le même nom de base)
|
| 145 |
+
# On numérote le nom de base si nécessaire.
|
| 146 |
+
count = filenames_in_zip_tracker.get(dedup_key_base, 0)
|
| 147 |
+
if count > 0: # Si dedup_key_base a déjà été vu, on ajoute le compteur
|
| 148 |
+
image_arcname_final = f"{original_base_name}({count}){original_extension}"
|
| 149 |
+
caption_arcname_final = f"{original_base_name}({count}).txt"
|
| 150 |
+
|
| 151 |
+
filenames_in_zip_tracker[dedup_key_base] = count + 1
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# Utiliser une extension temporaire basée sur l'original pour le fichier sur disque serveur
|
| 155 |
temp_upload_filename = f"temp_{uuid.uuid4().hex}{original_extension}"
|
| 156 |
temp_file_path = os.path.join(UPLOAD_DIR, temp_upload_filename)
|
| 157 |
|
|
|
|
| 160 |
|
| 161 |
image_description = "Description non générée par défaut."
|
| 162 |
if processing.is_active_model_loaded():
|
| 163 |
+
print(f"Génération de description pour {temp_file_path} (sera {image_arcname_final} dans ZIP) avec le modèle {processing.ACTIVE_MODEL}")
|
| 164 |
image_description = processing.generate_active_description(temp_file_path)
|
| 165 |
else:
|
| 166 |
print(f"ERREUR: Tentative de génération alors que le modèle {processing.ACTIVE_MODEL} n'est pas chargé.")
|
| 167 |
image_description = f"ERREUR CRITIQUE: Le modèle IA ({processing.ACTIVE_MODEL}) n'est pas disponible."
|
| 168 |
|
| 169 |
+
# Ajouter l'image au ZIP avec son nom (original nettoyé, potentiellement dédupliqué)
|
| 170 |
+
zf.write(temp_file_path, arcname=image_arcname_final)
|
| 171 |
|
| 172 |
+
# Ajouter le fichier de description au ZIP (même nom de base, potentiellement dédupliqué, extension .txt)
|
| 173 |
+
zf.writestr(caption_arcname_final, image_description)
|
| 174 |
|
| 175 |
files_added_to_zip += 1
|
|
|
|
| 176 |
|
| 177 |
+
except HTTPException:
|
| 178 |
raise
|
| 179 |
except Exception as e:
|
| 180 |
print(f"Erreur inattendue lors du traitement du fichier {file.filename}: {e}")
|
| 181 |
+
processing.traceback.print_exc()
|
| 182 |
finally:
|
| 183 |
if hasattr(file, 'file') and file.file and not file.file.closed:
|
| 184 |
file.file.close()
|
app/processing.py
CHANGED
|
@@ -8,10 +8,7 @@ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
|
|
| 8 |
|
| 9 |
# --- Configuration du modèle LLaVA-NeXT ---
|
| 10 |
LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
| 11 |
-
#
|
| 12 |
-
# une fois que nous aurons confirmé que tout fonctionne bien avec la version 'main'.
|
| 13 |
-
# Pour l'instant, elle n'est PAS utilisée dans les appels from_pretrained().
|
| 14 |
-
LLAVA_REVISION = "082142fd2997099498027732cf8e945044bf48c3" # Exemple de hash, non utilisé ci-dessous
|
| 15 |
|
| 16 |
llava_processor = None
|
| 17 |
llava_model = None
|
|
@@ -20,7 +17,7 @@ llava_model_loaded = False
|
|
| 20 |
# Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon)
|
| 21 |
if torch.cuda.is_available():
|
| 22 |
device = "cuda"
|
| 23 |
-
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
|
| 24 |
device = "mps"
|
| 25 |
else:
|
| 26 |
device = "cpu"
|
|
@@ -31,42 +28,43 @@ print(f"Utilisation du device : {device} pour les modèles d'IA.")
|
|
| 31 |
def load_llava_model():
|
| 32 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 33 |
if llava_model_loaded:
|
| 34 |
-
print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}
|
| 35 |
return
|
| 36 |
|
| 37 |
try:
|
| 38 |
-
print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME}
|
| 39 |
llava_processor = LlavaNextProcessor.from_pretrained(
|
| 40 |
LLAVA_MODEL_NAME
|
| 41 |
-
#
|
| 42 |
)
|
| 43 |
print("Processor LLaVA chargé.")
|
| 44 |
|
| 45 |
-
print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}
|
| 46 |
model_args = {
|
| 47 |
-
|
| 48 |
-
|
| 49 |
}
|
| 50 |
|
| 51 |
-
if device == "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
|
| 53 |
print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
|
| 54 |
-
elif device == "cuda":
|
| 55 |
-
model_args["torch_dtype"] = torch.float16 # ou torch.bfloat16 si GPU récent (Ampere+)
|
| 56 |
-
print(f"Configuration de LLaVA pour CUDA ({model_args.get('torch_dtype', 'par défaut')}).")
|
| 57 |
-
elif device == "mps":
|
| 58 |
-
# Pour MPS, float16 peut offrir des gains de vitesse. float32 est plus sûr pour commencer.
|
| 59 |
-
# Laisser float32 par défaut (pas de torch_dtype) est une option, ou essayer float16.
|
| 60 |
-
model_args["torch_dtype"] = torch.float16 # Essayons float16 pour MPS
|
| 61 |
-
print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut')}).")
|
| 62 |
|
| 63 |
llava_model = LlavaNextForConditionalGeneration.from_pretrained(
|
| 64 |
LLAVA_MODEL_NAME,
|
| 65 |
**model_args
|
| 66 |
-
).to(device).eval()
|
| 67 |
|
| 68 |
llava_model_loaded = True
|
| 69 |
-
print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}
|
| 70 |
|
| 71 |
except Exception as e:
|
| 72 |
print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
|
|
@@ -78,96 +76,109 @@ def generate_description_llava(image_path: str) -> str:
|
|
| 78 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 79 |
|
| 80 |
if not llava_model_loaded:
|
| 81 |
-
print("Modèle LLaVA non chargé
|
| 82 |
load_llava_model()
|
| 83 |
if not llava_model_loaded:
|
| 84 |
-
return "Erreur: Le modèle LLaVA n'a pas pu être chargé
|
| 85 |
-
|
| 86 |
if not os.path.exists(image_path):
|
| 87 |
return f"Erreur: Le fichier image {image_path} n'existe pas."
|
| 88 |
|
| 89 |
try:
|
| 90 |
image = Image.open(image_path).convert("RGB")
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
print(f"Préparation des entrées pour LLaVA avec le prompt: \"{
|
| 101 |
-
inputs_on_cpu = llava_processor(text=prompt_text, images=image, return_tensors="pt")
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
inputs = {}
|
| 104 |
for key, value in inputs_on_cpu.items():
|
| 105 |
if torch.is_tensor(value):
|
| 106 |
inputs[key] = value.to(device)
|
| 107 |
else:
|
| 108 |
-
inputs[key] = value
|
| 109 |
|
| 110 |
-
# S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS
|
| 111 |
if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
|
| 112 |
(llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
|
| 113 |
for k_tensor, v_tensor in inputs.items():
|
| 114 |
-
if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor):
|
| 115 |
inputs[k_tensor] = v_tensor.to(llava_model.dtype)
|
| 116 |
|
| 117 |
input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)}
|
| 118 |
print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...")
|
| 119 |
|
|
|
|
| 120 |
generation_kwargs = {
|
| 121 |
-
"max_new_tokens": 768,
|
| 122 |
-
"num_beams": 3,
|
| 123 |
-
"early_stopping": True
|
|
|
|
|
|
|
|
|
|
| 124 |
}
|
| 125 |
|
| 126 |
generated_ids = llava_model.generate(**inputs, **generation_kwargs)
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
|
| 131 |
-
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
# Nettoyage
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
if
|
| 137 |
-
cleaned_text =
|
| 138 |
-
|
| 139 |
-
|
| 140 |
|
| 141 |
-
print(f"Description (nettoyée) de LLaVA: {cleaned_text}")
|
| 142 |
-
return cleaned_text if cleaned_text
|
| 143 |
|
| 144 |
except Exception as e:
|
| 145 |
print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
|
| 146 |
traceback.print_exc()
|
| 147 |
if torch.cuda.is_available() or device == "mps":
|
| 148 |
if device == "cuda": torch.cuda.empty_cache()
|
| 149 |
-
# if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache()
|
| 150 |
return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
|
| 151 |
|
| 152 |
|
| 153 |
# --- Fonctions de gestion du modèle actif ---
|
| 154 |
-
ACTIVE_MODEL = "llava"
|
| 155 |
|
| 156 |
def load_active_model():
|
| 157 |
print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
|
| 158 |
if ACTIVE_MODEL == "llava":
|
| 159 |
load_llava_model()
|
| 160 |
-
# Ajoutez d'autres conditions ici si vous réactivez d'autres modèles
|
| 161 |
-
# elif ACTIVE_MODEL == "florence":
|
| 162 |
-
# load_florence_model()
|
| 163 |
else:
|
| 164 |
print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
|
| 165 |
|
| 166 |
def generate_active_description(image_path: str) -> str:
|
| 167 |
if ACTIVE_MODEL == "llava":
|
| 168 |
return generate_description_llava(image_path)
|
| 169 |
-
# elif ACTIVE_MODEL == "florence":
|
| 170 |
-
# return generate_description_florence(image_path)
|
| 171 |
else:
|
| 172 |
error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
|
| 173 |
print(error_msg)
|
|
@@ -176,31 +187,30 @@ def generate_active_description(image_path: str) -> str:
|
|
| 176 |
def is_active_model_loaded() -> bool:
|
| 177 |
if ACTIVE_MODEL == "llava":
|
| 178 |
return llava_model_loaded
|
| 179 |
-
# elif ACTIVE_MODEL == "florence":
|
| 180 |
-
# return florence_model_loaded
|
| 181 |
return False
|
| 182 |
|
| 183 |
# --- Section de Test (pour exécution directe de ce fichier) ---
|
| 184 |
if __name__ == '__main__':
|
| 185 |
print("Début du test de processing.py...")
|
| 186 |
|
| 187 |
-
|
| 188 |
-
dummy_image_name = "dummy_test_image.png" # S'assure qu'elle est bien ignorée par .gitignore si elle est créée
|
| 189 |
if not os.path.exists(dummy_image_name):
|
| 190 |
try:
|
| 191 |
-
# ImageDraw a été importé en haut avec PIL
|
| 192 |
img = Image.new('RGB', (200, 150), color = 'skyblue')
|
| 193 |
draw = ImageDraw.Draw(img)
|
| 194 |
-
draw.text((10, 10), "Test Image", fill='black')
|
|
|
|
|
|
|
|
|
|
| 195 |
img.save(dummy_image_name)
|
| 196 |
print(f"Image de test '{dummy_image_name}' créée.")
|
| 197 |
except Exception as e_img:
|
| 198 |
print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
|
| 199 |
|
| 200 |
if os.path.exists(dummy_image_name):
|
| 201 |
-
print(f"Utilisation du modèle actif : {ACTIVE_MODEL}")
|
| 202 |
print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
|
| 203 |
-
load_active_model()
|
| 204 |
if is_active_model_loaded():
|
| 205 |
print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
|
| 206 |
description = generate_active_description(dummy_image_name)
|
|
|
|
| 8 |
|
| 9 |
# --- Configuration du modèle LLaVA-NeXT ---
|
| 10 |
LLAVA_MODEL_NAME = "llava-hf/llava-v1.6-mistral-7b-hf"
|
| 11 |
+
# LLAVA_REVISION = "main" # On utilise la branche principale par défaut
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
llava_processor = None
|
| 14 |
llava_model = None
|
|
|
|
| 17 |
# Détection du device (CPU, CUDA, ou MPS pour Mac Apple Silicon)
|
| 18 |
if torch.cuda.is_available():
|
| 19 |
device = "cuda"
|
| 20 |
+
elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): # Pour Apple Silicon
|
| 21 |
device = "mps"
|
| 22 |
else:
|
| 23 |
device = "cpu"
|
|
|
|
| 28 |
def load_llava_model():
|
| 29 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 30 |
if llava_model_loaded:
|
| 31 |
+
print(f"Modèle LLaVA ({LLAVA_MODEL_NAME}) déjà chargé.")
|
| 32 |
return
|
| 33 |
|
| 34 |
try:
|
| 35 |
+
print(f"Chargement du processor pour LLaVA ({LLAVA_MODEL_NAME})...")
|
| 36 |
llava_processor = LlavaNextProcessor.from_pretrained(
|
| 37 |
LLAVA_MODEL_NAME
|
| 38 |
+
# revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
|
| 39 |
)
|
| 40 |
print("Processor LLaVA chargé.")
|
| 41 |
|
| 42 |
+
print(f"Chargement du modèle LLaVA ({LLAVA_MODEL_NAME}) sur '{device}'...")
|
| 43 |
model_args = {
|
| 44 |
+
"low_cpu_mem_usage": True, # Utile pour réduire l'utilisation de la RAM CPU
|
| 45 |
+
# revision=LLAVA_REVISION # Décommenter si vous voulez épingler une révision
|
| 46 |
}
|
| 47 |
|
| 48 |
+
if device == "cuda":
|
| 49 |
+
model_args["torch_dtype"] = torch.float16 # Précision pour GPU NVIDIA
|
| 50 |
+
# Pour les GPU plus récents (Ampere+), torch.bfloat16 peut aussi être une option
|
| 51 |
+
print(f"Configuration de LLaVA pour CUDA ({model_args['torch_dtype']}).")
|
| 52 |
+
elif device == "mps":
|
| 53 |
+
# Pour MPS, float16 est souvent utilisé, mais float32 est plus sûr pour commencer
|
| 54 |
+
# Si des problèmes de stabilité surviennent avec float16 sur MPS, commenter la ligne suivante
|
| 55 |
+
model_args["torch_dtype"] = torch.float16
|
| 56 |
+
print(f"Configuration de LLaVA pour MPS ({model_args.get('torch_dtype', 'par défaut float32')}).")
|
| 57 |
+
else: # CPU
|
| 58 |
# Pas de torch_dtype spécifique, PyTorch utilisera float32 par défaut pour plus de stabilité
|
| 59 |
print(f"Configuration de LLaVA pour CPU (float32 par défaut).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
llava_model = LlavaNextForConditionalGeneration.from_pretrained(
|
| 62 |
LLAVA_MODEL_NAME,
|
| 63 |
**model_args
|
| 64 |
+
).to(device).eval() # Mettre le modèle en mode évaluation
|
| 65 |
|
| 66 |
llava_model_loaded = True
|
| 67 |
+
print(f"Modèle LLaVA ET Processeur ({LLAVA_MODEL_NAME}) chargés avec succès sur '{device}'.")
|
| 68 |
|
| 69 |
except Exception as e:
|
| 70 |
print(f"Erreur critique lors du chargement du modèle LLaVA ({LLAVA_MODEL_NAME}): {e}")
|
|
|
|
| 76 |
global llava_processor, llava_model, llava_model_loaded, device
|
| 77 |
|
| 78 |
if not llava_model_loaded:
|
| 79 |
+
print("Modèle LLaVA non chargé. Tentative de chargement à la demande...")
|
| 80 |
load_llava_model()
|
| 81 |
if not llava_model_loaded:
|
| 82 |
+
return "Erreur: Le modèle LLaVA n'a pas pu être chargé."
|
| 83 |
+
|
| 84 |
if not os.path.exists(image_path):
|
| 85 |
return f"Erreur: Le fichier image {image_path} n'existe pas."
|
| 86 |
|
| 87 |
try:
|
| 88 |
image = Image.open(image_path).convert("RGB")
|
| 89 |
|
| 90 |
+
# --- PROMPT AMÉLIORÉ ---
|
| 91 |
+
user_prompt_fr = (
|
| 92 |
+
"Analyse cette image en tant qu'œuvre d'art. "
|
| 93 |
+
"Fournis une description objective, factuelle et très détaillée en français. "
|
| 94 |
+
"Concentre-toi sur les éléments visuels : la scène globale, les sujets et personnages (y compris leur apparence, posture, expression, et toute forme de nudité si présente), "
|
| 95 |
+
"les objets, l'arrière-plan, les formes, les couleurs, la lumière, la composition et la technique artistique apparente. "
|
| 96 |
+
"Évite toute interprétation subjective ou jugement moral, et décris ce qui est visible de manière exhaustive."
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# Format de prompt spécifique à LLaVA v1.5+ (Mistral utilise ce format)
|
| 100 |
+
prompt_template = f"<s>[INST] <image>\n{user_prompt_fr} [/INST]"
|
| 101 |
|
| 102 |
+
print(f"Préparation des entrées pour LLaVA avec le prompt utilisateur (condensé): \"{user_prompt_fr[:100]}...\"")
|
|
|
|
| 103 |
|
| 104 |
+
# Le processeur gère la tokenisation du texte et le prétraitement de l'image
|
| 105 |
+
inputs_on_cpu = llava_processor(text=prompt_template, images=image, return_tensors="pt")
|
| 106 |
+
|
| 107 |
+
# Déplacer les tenseurs sur le bon device
|
| 108 |
inputs = {}
|
| 109 |
for key, value in inputs_on_cpu.items():
|
| 110 |
if torch.is_tensor(value):
|
| 111 |
inputs[key] = value.to(device)
|
| 112 |
else:
|
| 113 |
+
inputs[key] = value
|
| 114 |
|
| 115 |
+
# S'assurer que les types de tenseurs correspondent au modèle sur GPU/MPS (si float16/bfloat16)
|
| 116 |
if (device == "cuda" or device == "mps") and hasattr(llava_model, 'dtype') and \
|
| 117 |
(llava_model.dtype == torch.float16 or llava_model.dtype == torch.bfloat16):
|
| 118 |
for k_tensor, v_tensor in inputs.items():
|
| 119 |
+
if torch.is_tensor(v_tensor) and torch.is_floating_point(v_tensor): # Appliquer seulement aux tenseurs flottants
|
| 120 |
inputs[k_tensor] = v_tensor.to(llava_model.dtype)
|
| 121 |
|
| 122 |
input_dtypes_log = {k: v.dtype for k,v in inputs.items() if torch.is_tensor(v)}
|
| 123 |
print(f"Génération de la description LLaVA pour {image_path} (device: {device}, input dtypes: {input_dtypes_log})...")
|
| 124 |
|
| 125 |
+
# Paramètres de génération (ajustables si nécessaire)
|
| 126 |
generation_kwargs = {
|
| 127 |
+
"max_new_tokens": 768, # Augmenté légèrement pour des descriptions potentiellement plus longues
|
| 128 |
+
"num_beams": 3, # Un peu de beam search peut améliorer la cohérence
|
| 129 |
+
"early_stopping": True,
|
| 130 |
+
"do_sample": False # Pour des descriptions plus factuelles et moins "créatives" aléatoirement. Mettre True avec temperature si on veut plus de variété.
|
| 131 |
+
# "temperature": 0.7, # À utiliser avec do_sample=True si on veut de la créativité
|
| 132 |
+
# "top_p": 0.9, # À utiliser avec do_sample=True
|
| 133 |
}
|
| 134 |
|
| 135 |
generated_ids = llava_model.generate(**inputs, **generation_kwargs)
|
| 136 |
|
| 137 |
+
# Pour LLaVA, il est important de décoder uniquement les tokens générés *après* le prompt.
|
| 138 |
+
# Certains processeurs/modèles gèrent cela différemment.
|
| 139 |
+
# Pour LLaVA-NeXT, le décodage de la séquence complète et le nettoyage du prompt est une approche courante.
|
| 140 |
+
# Ou, si l'on connaît la longueur des tokens d'entrée :
|
| 141 |
+
# input_token_len = inputs.get('input_ids', torch.tensor([])).shape[-1]
|
| 142 |
+
# generated_ids_only = generated_ids[0, input_token_len:]
|
| 143 |
+
# cleaned_text = llava_processor.decode(generated_ids_only, skip_special_tokens=True).strip()
|
| 144 |
|
| 145 |
+
# Approche plus simple : décoder toute la séquence et enlever manuellement le prompt si besoin.
|
| 146 |
+
# Souvent, pour les formats [INST]...[/INST], le modèle génère directement la réponse.
|
| 147 |
+
full_text = llava_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
|
| 148 |
|
| 149 |
+
# Nettoyage du texte généré pour enlever le prompt s'il est répété (spécifique au format de sortie du modèle)
|
| 150 |
+
# Le format "[INST] <image> \n {prompt} [/INST] {réponse}" fait que la réponse est souvent propre.
|
| 151 |
+
inst_marker = "[/INST]"
|
| 152 |
+
if inst_marker in full_text:
|
| 153 |
+
cleaned_text = full_text.split(inst_marker, 1)[-1].strip()
|
| 154 |
+
else:
|
| 155 |
+
cleaned_text = full_text # Si le marqueur n'est pas là, prendre tout (peut arriver)
|
| 156 |
|
| 157 |
+
print(f"Description (nettoyée) de LLaVA: {cleaned_text[:200]}...") # Log tronqué
|
| 158 |
+
return cleaned_text if cleaned_text else "Aucune description textuelle distincte n'a été générée par LLaVA."
|
| 159 |
|
| 160 |
except Exception as e:
|
| 161 |
print(f"Erreur détaillée lors de la génération de la description avec LLaVA:")
|
| 162 |
traceback.print_exc()
|
| 163 |
if torch.cuda.is_available() or device == "mps":
|
| 164 |
if device == "cuda": torch.cuda.empty_cache()
|
| 165 |
+
# if device == "mps" and hasattr(torch, 'mps') and hasattr(torch.mps, 'empty_cache'): torch.mps.empty_cache()
|
| 166 |
return f"Erreur lors de la génération de la description avec LLaVA: {type(e).__name__} - {str(e)}"
|
| 167 |
|
| 168 |
|
| 169 |
# --- Fonctions de gestion du modèle actif ---
|
| 170 |
+
ACTIVE_MODEL = "llava" # Pour l'instant, seul LLaVA est configuré
|
| 171 |
|
| 172 |
def load_active_model():
|
| 173 |
print(f"Tentative de chargement du modèle actif: {ACTIVE_MODEL}")
|
| 174 |
if ACTIVE_MODEL == "llava":
|
| 175 |
load_llava_model()
|
|
|
|
|
|
|
|
|
|
| 176 |
else:
|
| 177 |
print(f"Modèle actif inconnu: {ACTIVE_MODEL}. Aucun modèle ne sera chargé.")
|
| 178 |
|
| 179 |
def generate_active_description(image_path: str) -> str:
|
| 180 |
if ACTIVE_MODEL == "llava":
|
| 181 |
return generate_description_llava(image_path)
|
|
|
|
|
|
|
| 182 |
else:
|
| 183 |
error_msg = f"Erreur: Modèle actif inconnu ({ACTIVE_MODEL}). Impossible de générer une description."
|
| 184 |
print(error_msg)
|
|
|
|
| 187 |
def is_active_model_loaded() -> bool:
|
| 188 |
if ACTIVE_MODEL == "llava":
|
| 189 |
return llava_model_loaded
|
|
|
|
|
|
|
| 190 |
return False
|
| 191 |
|
| 192 |
# --- Section de Test (pour exécution directe de ce fichier) ---
|
| 193 |
if __name__ == '__main__':
|
| 194 |
print("Début du test de processing.py...")
|
| 195 |
|
| 196 |
+
dummy_image_name = "dummy_test_image.png"
|
|
|
|
| 197 |
if not os.path.exists(dummy_image_name):
|
| 198 |
try:
|
|
|
|
| 199 |
img = Image.new('RGB', (200, 150), color = 'skyblue')
|
| 200 |
draw = ImageDraw.Draw(img)
|
| 201 |
+
draw.text((10, 10), "Test Image for LLaVA", fill='black')
|
| 202 |
+
# Ajouter quelques formes pour le test
|
| 203 |
+
draw.ellipse((30, 50, 90, 110), fill='red', outline='black')
|
| 204 |
+
draw.rectangle((100, 40, 170, 120), fill='lightgreen', outline='blue')
|
| 205 |
img.save(dummy_image_name)
|
| 206 |
print(f"Image de test '{dummy_image_name}' créée.")
|
| 207 |
except Exception as e_img:
|
| 208 |
print(f"Impossible de créer l'image de test (vérifiez Pillow) : {e_img}")
|
| 209 |
|
| 210 |
if os.path.exists(dummy_image_name):
|
| 211 |
+
print(f"Utilisation du modèle actif : {ACTIVE_MODEL} sur device {device}")
|
| 212 |
print("Chargement du modèle actif (peut prendre du temps, surtout la première fois)...")
|
| 213 |
+
load_active_model()
|
| 214 |
if is_active_model_loaded():
|
| 215 |
print(f"\nGénération de la description pour l'image de test '{dummy_image_name}'...")
|
| 216 |
description = generate_active_description(dummy_image_name)
|
requirements.txt
CHANGED
|
@@ -4,8 +4,8 @@ python-multipart
|
|
| 4 |
jinja2
|
| 5 |
torch
|
| 6 |
torchvision
|
| 7 |
-
# torchaudio #
|
| 8 |
-
transformers
|
| 9 |
Pillow
|
| 10 |
accelerate
|
| 11 |
einops
|
|
@@ -15,6 +15,7 @@ tiktoken
|
|
| 15 |
# Pour l'authentification Basic Auth
|
| 16 |
python-jose[cryptography]>=3.3.0
|
| 17 |
passlib[bcrypt]>=1.7.4
|
| 18 |
-
bcrypt>=3.2.0,<4.1.0
|
| 19 |
|
| 20 |
-
# bitsandbytes # Optionnel
|
|
|
|
|
|
| 4 |
jinja2
|
| 5 |
torch
|
| 6 |
torchvision
|
| 7 |
+
# torchaudio # Optionnel
|
| 8 |
+
transformers>=4.38.0 # Assurer une version récente pour LLaVA-NeXT et autres modèles récents
|
| 9 |
Pillow
|
| 10 |
accelerate
|
| 11 |
einops
|
|
|
|
| 15 |
# Pour l'authentification Basic Auth
|
| 16 |
python-jose[cryptography]>=3.3.0
|
| 17 |
passlib[bcrypt]>=1.7.4
|
| 18 |
+
bcrypt>=3.2.0,<4.1.0
|
| 19 |
|
| 20 |
+
# bitsandbytes # Optionnel pour la quantification
|
| 21 |
+
sentencepiece # Souvent requis par les tokenizers de transformers
|