Spaces:
Sleeping
Sleeping
from fastapi import FastAPI
from pydantic import BaseModel
from PIL import Image
import torch
import base64
import io
from transformers import (
    LlavaProcessor,
    LlavaForConditionalGeneration
)

app = FastAPI()

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"

# Load the processor (bundles the tokenizer and the vision processor).
processor = LlavaProcessor.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True
)

# Fallback chat template if the tokenizer ships without one.
# FIX: LLaVA requires the special "<image>" placeholder in the prompt so the
# vision features can be aligned with the text tokens; the original template
# omitted it, which makes generation with pixel_values fail (or silently
# ignore the image). It is now inserted in every user turn.
template = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}User: <image>\n{{ message['content'] }}\n"
    "{% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }}\n"
    "{% endif %}{% endfor %}Assistant:"
)
if not getattr(processor.tokenizer, "chat_template", None):
    processor.tokenizer.chat_template = template

# Load the LLaVA model in fp16 with low-memory loading and automatic
# device placement (device_map="auto" shards across available devices).
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto"
)
model.eval()
| # Définition du format d'entrée attendu | |
# Request schema for the prediction endpoint.
class InputData(BaseModel):
    prompt: str        # user instruction / question about the image
    image_base64: str  # image as base64, optionally a "data:image/...;base64," URI
@app.post("/predict")  # FIX: the route decorator was missing — the endpoint was never registered
async def predict(data: InputData):
    """Generate a caption for a base64-encoded image with LLaVA.

    Expects ``data.prompt`` (user instruction) and ``data.image_base64``
    (raw base64 or a ``data:image/...;base64,`` URI).  Returns
    ``{"caption": <text>}``; on any failure the error message is returned
    in the same field so the client always gets a JSON body.
    """
    try:
        # 1) Decode the image from base64 (strip a data-URI prefix if present).
        base64_str = data.image_base64
        if base64_str.startswith("data:image"):
            base64_str = base64_str.split(",", 1)[1]
        image_data = base64.b64decode(base64_str)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        print("📷 Taille image :", image.size)
        print("📷 Format image :", image.format)

        # 2) Build the text prompt through the chat template.
        chat_prompt = processor.tokenizer.apply_chat_template(
            [{"role": "user", "content": data.prompt}],
            tokenize=False,
            add_generation_prompt=True
        )
        if not chat_prompt.strip():
            raise ValueError("Le prompt généré est vide après application du chat_template.")
        # FIX: LLaVA needs the "<image>" placeholder token so the vision
        # features can be spliced into the text sequence; guard against a
        # chat template that omits it.
        if "<image>" not in chat_prompt:
            chat_prompt = "<image>\n" + chat_prompt
        print("🟡 Prompt utilisé :", repr(chat_prompt))

        # 3) Preprocess text and image in a single processor call so the
        #    image token is handled consistently (the original tokenized
        #    them separately, then re-tokenized the text a second time
        #    after building `inputs` — dead code, removed; the duplicated
        #    validation checks were also collapsed into one pass).
        inputs = processor(images=image, text=chat_prompt, return_tensors="pt")
        if inputs.get("pixel_values") is None:
            raise ValueError("❌ pixel_values manquant ou None")
        if inputs.get("input_ids") is None:
            raise ValueError("❌ input_ids manquant ou None")
        if inputs.get("attention_mask") is None:
            raise ValueError("❌ attention_mask manquant ou None")
        if inputs["input_ids"].shape[1] == 0:
            raise ValueError("Aucun token texte généré — le prompt est probablement invalide.")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        print("🧪 Inputs:", inputs.keys())

        # 4) Greedy generation.  FIX: `temperature` was dropped — it is
        #    ignored when do_sample=False and only triggers a warning.
        #    no_grad avoids building an autograd graph during inference.
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=800,
                do_sample=False
            )

        # 5) Decode the generated ids back to text.
        decoded = processor.batch_decode(output, skip_special_tokens=True)
        print("🧪 Output brut:", decoded)
        if decoded and isinstance(decoded, list) and decoded[0]:
            response = decoded[0].strip()
        else:
            response = "⚠️ No caption generated (empty output)"
        return {"caption": response}
    except Exception as e:
        # Top-level boundary: report the error in-band (original behavior —
        # the endpoint never raises to the client).
        print("❌ Exception attrapée :", str(e))
        return {"caption": f"❌ Erreur API : {str(e)}"}
if __name__ == "__main__":
    import uvicorn
    # Serve on all interfaces, port 7860 — presumably the Hugging Face
    # Spaces convention; confirm against the deployment config.
    uvicorn.run(app, host="0.0.0.0", port=7860)