# NOTE: the "Spaces: Sleeping" banner that appeared here was Hugging Face
# Spaces page residue from extraction, not part of the source file.
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from transformers import Qwen2VLForConditionalGeneration, AutoProcessor | |
| from qwen_vl_utils import process_vision_info | |
| from PIL import Image | |
| import torch | |
| import tempfile | |
| import os | |
| import logging | |
| from datetime import datetime | |
# Configure module-level logging so the handlers below can report progress.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the FastAPI application (metadata is surfaced at /docs).
app = FastAPI(
    title="Sparrow Qwen2-VL API",
    description="API REST pour extraction de données depuis images via Qwen2-VL",
    version="1.0.0"
)
# Load the model once at import time so every request reuses the same weights.
logger.info("🔄 Chargement du modèle Qwen2-VL-7B-Instruct...")
try:
    # torch_dtype="auto" / device_map="auto" let transformers choose the
    # precision and device placement (GPU when available) automatically.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        torch_dtype="auto",
        device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    logger.info("✅ Modèle chargé avec succès!")
except Exception as e:
    # Fail fast: the API is useless without the model, so abort startup.
    logger.error(f"❌ Erreur chargement modèle: {e}")
    raise
# Response schema returned by the predict handler.
class ExtractionResponse(BaseModel):
    result: str     # text extracted from the image by the model
    status: str     # "success" on the happy path
    timestamp: str  # ISO-8601 timestamp captured when the request started
# NOTE(review): no @app.post("/predict") decorator is visible on this handler —
# api_info() below advertises a /predict endpoint, so the decorator was likely
# lost during extraction. Verify the route is actually registered.
async def predict(
    image: UploadFile = File(..., description="Image à analyser"),
    query: str = Form(..., description="Instruction d'extraction")
):
    """
    Extract data from an uploaded image according to the query instruction.

    The upload is written to a temporary file (the vision pipeline expects a
    file path), passed through Qwen2-VL together with the text instruction,
    and the decoded generation is returned.

    Args:
        image: Uploaded file; must have an ``image/*`` content type.
        query: Natural-language extraction instruction for the model.

    Returns:
        ExtractionResponse with the generated text, a "success" status, and
        the request start timestamp.

    Raises:
        HTTPException: 400 when the upload is not an image, 500 on any
        processing or generation failure.
    """
    timestamp = datetime.now().isoformat()
    temp_path = None
    try:
        # Reject non-image uploads early (content_type may be absent).
        if not image.content_type or not image.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="Fichier doit être une image")

        # Persist the upload to disk: process_vision_info takes a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
            content = await image.read()
            tmp_file.write(content)
            temp_path = tmp_file.name

        logger.info(f"🖼️ Traitement image: {image.filename}")
        logger.info(f"📝 Requête: {query}")

        # Build the multimodal chat message (image + instruction).
        # (The original also did an unused `Image.open(temp_path)` here, which
        # leaked a file handle — removed.)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": temp_path},
                    {"type": "text", "text": query}
                ]
            }
        ]

        # Render the chat template, leaving the generation prompt open.
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Resolve the image/video inputs referenced by the messages.
        image_inputs, video_inputs = process_vision_info(messages)

        # Tokenize text and vision inputs together.
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move input tensors to the device family the model was loaded on.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)

        logger.info("🤖 Génération de la réponse...")
        generated_ids = model.generate(**inputs, max_new_tokens=4096)

        # Strip the prompt tokens so only the newly generated tokens remain.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode the first (and only) sequence of the batch.
        output = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )[0]

        logger.info(f"✅ Extraction réussie: {len(output)} caractères")

        return ExtractionResponse(
            result=output,
            status="success",
            timestamp=timestamp
        )
    except HTTPException:
        # Bug fix: deliberate HTTP errors (e.g. the 400 above) must propagate
        # as-is; previously the broad handler below converted them to 500s.
        raise
    except Exception as e:
        logger.error(f"❌ Erreur traitement: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always delete the temporary file, on success or failure.
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
            logger.info("🧹 Fichier temporaire nettoyé")
def health_check():
    """Report liveness plus the model name, compute device, and current time."""
    on_gpu = torch.cuda.is_available()
    payload = {
        "status": "healthy",
        "model": "Qwen2-VL-7B-Instruct",
        "device": "cuda" if on_gpu else "cpu",
        "timestamp": datetime.now().isoformat(),
    }
    return payload
def api_info():
    """Describe the API: name, version, model identifier, and route map."""
    routes = {name: f"/{name}" for name in ("predict", "health", "info")}
    return {
        "name": "Sparrow Qwen2-VL API",
        "version": "1.0.0",
        "endpoints": routes,
        "model": "Qwen/Qwen2-VL-7B-Instruct",
    }
# Kept for compatibility with Gradio front-ends (optional).
def root():
    """Return a small JSON landing payload pointing at the useful routes."""
    payload = {
        "message": "Sparrow Qwen2-VL API is running",
        "docs": "/docs",
        "health": "/health",
        "predict": "/predict",
    }
    return JSONResponse(payload)
# Launch a development server when the module is executed directly.
if __name__ == "__main__":
    import uvicorn
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)