# app.py — FastAPI service exposing a LLaVA-1.5 image-captioning endpoint.
# (Hugging Face Space revision b390834)
from fastapi import FastAPI
from pydantic import BaseModel
from PIL import Image
import torch
import base64
import io
from transformers import (
LlavaProcessor,
LlavaForConditionalGeneration
)
app = FastAPI()

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

# Load the processor (bundles the tokenizer and the vision image processor).
processor = LlavaProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Fallback chat template, installed only when the tokenizer ships without one.
# FIX: the original template emitted no "<image>" placeholder, so the text
# stream contained zero image tokens while pixel_values was still handed to
# generate() — LLaVA cannot merge the vision features into the sequence in
# that case. The template below follows the documented LLaVA-1.5 prompt
# format: "USER: <image>\n{question} ASSISTANT:".
template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}USER: <image>\n{{ message['content'] }}\n"
    "{% elif message['role'] == 'assistant' %}ASSISTANT: {{ message['content'] }}\n"
    "{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}ASSISTANT:{% endif %}"
)
if not getattr(processor.tokenizer, "chat_template", None):
    processor.tokenizer.chat_template = template

# Load the LLaVA weights in fp16 with low-RAM staging, sharded automatically
# across the available device(s).
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)
model.eval()  # inference only — disables dropout and other training behavior
# Définition du format d'entrée attendu
class InputData(BaseModel):
    """Request body for the /predict endpoint."""
    # Instruction / question to ask the model about the image.
    prompt: str
    # Image payload, base64-encoded; may carry a "data:image/...;base64," prefix.
    image_base64: str
@app.post("/predict")
async def predict(data: InputData):
    """Generate a caption/answer for a base64-encoded image with LLaVA.

    Expects `data.prompt` (instruction text) and `data.image_base64` (the
    image, optionally prefixed with a `data:image/...;base64,` header).
    Returns `{"caption": <generated text>}`. On any failure the error is
    reported inside `caption` rather than as an HTTP error — the original
    best-effort contract is preserved.
    """
    try:
        # 1) Decode the image from base64, stripping an optional data-URL header.
        base64_str = data.image_base64
        if base64_str.startswith("data:image"):
            base64_str = base64_str.split(",", 1)[1]
        image_data = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        # NOTE: image.format is always None after convert(), so it is not logged.
        print("📷 Taille image :", image.size)

        # 2) Render the chat-formatted prompt from the (possibly fallback) template.
        chat_prompt = processor.tokenizer.apply_chat_template(
            [{"role": "user", "content": data.prompt}],
            tokenize=False,
            add_generation_prompt=True
        )
        if not chat_prompt.strip():
            raise ValueError("Le prompt généré est vide après application du chat_template.")
        print("🟡 Prompt utilisé :", repr(chat_prompt))

        # 3) Tokenize text and image TOGETHER: a single processor call lets it
        #    expand/align the <image> placeholder tokens with the pixel
        #    features (tokenizing them separately bypasses that alignment).
        inputs = processor(text=chat_prompt, images=image, return_tensors="pt")
        if inputs.get("input_ids") is None or inputs["input_ids"].shape[1] == 0:
            raise ValueError("Aucun token texte généré — le prompt est probablement invalide.")
        if inputs.get("pixel_values") is None:
            raise ValueError("❌ pixel_values manquant ou None")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        print("🧪 Inputs:", inputs.keys())

        # 4) Greedy decoding. `temperature` is omitted: with do_sample=False it
        #    is ignored and only triggers a warning. no_grad() avoids building
        #    an autograd graph during inference.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=800,
                do_sample=False
            )

        # 5) Decode the generated ids back to text.
        decoded = processor.batch_decode(output, skip_special_tokens=True)
        print("🧪 Output brut:", decoded)
        if decoded and isinstance(decoded, list) and decoded[0]:
            response = decoded[0].strip()
        else:
            response = "⚠️ No caption generated (empty output)"
        return {"caption": response}
    except Exception as e:
        # Best-effort API: surface the error in the payload instead of a 5xx.
        print("❌ Exception attrapée :", str(e))
        return {"caption": f"❌ Erreur API : {str(e)}"}
# Run the ASGI server directly when executed as a script
# (port 7860 — presumably the Hugging Face Spaces convention; confirm).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)