"""REST API that extracts data from images with the Qwen2-VL model."""

import logging
import os
import tempfile
from datetime import datetime

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face identifier of the checkpoint served by this API.
MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"

# Initialise FastAPI
app = FastAPI(
    title="Sparrow Qwen2-VL API",
    description="API REST pour extraction de données depuis images via Qwen2-VL",
    version="1.0.0",
)

# Load the model once at import time so every request reuses it.
logger.info("🔄 Chargement du modèle Qwen2-VL-7B-Instruct...")
try:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype="auto",
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    logger.info("✅ Modèle chargé avec succès!")
except Exception as e:
    logger.error(f"❌ Erreur chargement modèle: {e}")
    raise


class ExtractionResponse(BaseModel):
    """Response payload returned by the ``/predict`` endpoint."""

    # Text decoded from the tokens generated by the model.
    result: str
    # Set to "success" when the extraction completed.
    status: str
    # ISO-8601 timestamp taken when the request started being handled.
    timestamp: str


@app.post("/predict", response_model=ExtractionResponse)
async def predict(
    image: UploadFile = File(..., description="Image à analyser"),
    query: str = Form(..., description="Instruction d'extraction"),
):
    """Extract data from an uploaded image according to the query.

    The upload is written to a temporary file, passed to the model together
    with the query as a single-turn chat message, and the generated text is
    returned. The temporary file is removed whether or not the call succeeds.

    Args:
        image: Uploaded file; its content type must start with ``image/``.
        query: Extraction instruction sent to the model as the text prompt.

    Returns:
        ExtractionResponse holding the generated text, a ``"success"`` status
        and the request timestamp.

    Raises:
        HTTPException: 400 when the upload is not an image Pillow can open,
            500 when any other error occurs during processing.
    """
    timestamp = datetime.now().isoformat()
    temp_path = None

    try:
        # content_type may be None when the client sends no Content-Type
        # for the part; treat that like any other non-image upload.
        if not (image.content_type or "").startswith("image/"):
            raise HTTPException(status_code=400, detail="Fichier doit être une image")

        # Persist the upload: the model pipeline is given a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
            # Record the path before writing so the finally-block can still
            # delete the file if reading or writing the upload fails.
            temp_path = tmp_file.name
            tmp_file.write(await image.read())

        logger.info(f"🖼️ Traitement image: {image.filename}")
        logger.info(f"📝 Requête: {query}")

        # Check that Pillow can identify the file, and close the handle
        # right away instead of leaving it open for the whole request.
        try:
            with Image.open(temp_path):
                pass
        except OSError as exc:
            raise HTTPException(
                status_code=400, detail="Fichier doit être une image"
            ) from exc

        # Build the chat messages for the model
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_path},
                    {"type": "text", "text": query},
                ],
            }
        ]

        # Apply the chat template
        text = processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Process the visual information
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare the model inputs
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move the inputs to the device used for inference
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)

        # Generate the answer.
        # NOTE: this is a synchronous call inside an async handler, so the
        # event loop is blocked for the duration of the generation.
        logger.info("🤖 Génération de la réponse...")
        generated_ids = model.generate(**inputs, max_new_tokens=4096)

        # Strip the prompt tokens so only the newly generated ones remain
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the result
        output = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )[0]

        logger.info(f"✅ Extraction réussie: {len(output)} caractères")

        return ExtractionResponse(
            result=output,
            status="success",
            timestamp=timestamp,
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400s above) must reach the client
        # unchanged rather than being rewrapped as a 500 below.
        raise
    except Exception as e:
        # logger.exception keeps the same message and adds the traceback.
        logger.exception(f"❌ Erreur traitement: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Remove the temporary file
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
            logger.info("🧹 Fichier temporaire nettoyé")


@app.get("/health")
def health_check():
    """Report that the API is running, with the model name and device."""
    return {
        "status": "healthy",
        "model": "Qwen2-VL-7B-Instruct",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "timestamp": datetime.now().isoformat(),
    }


@app.get("/info")
def api_info():
    """Return the API name, version, endpoint paths and model identifier."""
    return {
        "name": "Sparrow Qwen2-VL API",
        "version": "1.0.0",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "info": "/info",
        },
        "model": MODEL_ID,
    }


# For compatibility with Gradio (optional)
@app.get("/")
def root():
    """Return a short status message with links to the main endpoints."""
    return JSONResponse({
        "message": "Sparrow Qwen2-VL API is running",
        "docs": "/docs",
        "health": "/health",
        "predict": "/predict",
    })


# Start the server when executed directly
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)