Frederic-CellNum commited on
Commit
90be598
·
verified ·
1 Parent(s): 2d11b86

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -0
app.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
+ from pydantic import BaseModel
4
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
5
+ from qwen_vl_utils import process_vision_info
6
+ from PIL import Image
7
+ import torch
8
+ import tempfile
9
+ import os
10
+ import logging
11
+ from datetime import datetime
12
+
13
+ # Configuration logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Initialiser FastAPI
18
+ app = FastAPI(
19
+ title="Sparrow Qwen2-VL API",
20
+ description="API REST pour extraction de données depuis images via Qwen2-VL",
21
+ version="1.0.0"
22
+ )
23
+
24
+ # Charger le modèle au démarrage
25
+ logger.info("🔄 Chargement du modèle Qwen2-VL-7B-Instruct...")
26
+ try:
27
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
28
+ "Qwen/Qwen2-VL-7B-Instruct",
29
+ torch_dtype="auto",
30
+ device_map="auto"
31
+ )
32
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
33
+ logger.info("✅ Modèle chargé avec succès!")
34
+ except Exception as e:
35
+ logger.error(f"❌ Erreur chargement modèle: {e}")
36
+ raise
37
+
38
+ # Modèle de réponse
39
+ class ExtractionResponse(BaseModel):
40
+ result: str
41
+ status: str
42
+ timestamp: str
43
+
44
+ @app.post("/predict", response_model=ExtractionResponse)
45
+ async def predict(
46
+ image: UploadFile = File(..., description="Image à analyser"),
47
+ query: str = Form(..., description="Instruction d'extraction")
48
+ ):
49
+ """
50
+ Extraire des données d'une image selon la requête
51
+ """
52
+ timestamp = datetime.now().isoformat()
53
+ temp_path = None
54
+
55
+ try:
56
+ # Validation du fichier
57
+ if not image.content_type.startswith('image/'):
58
+ raise HTTPException(status_code=400, detail="Fichier doit être une image")
59
+
60
+ # Sauvegarder temporairement
61
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
62
+ content = await image.read()
63
+ tmp_file.write(content)
64
+ temp_path = tmp_file.name
65
+
66
+ logger.info(f"🖼️ Traitement image: {image.filename}")
67
+ logger.info(f"📝 Requête: {query}")
68
+
69
+ # Préparer l'image
70
+ img = Image.open(temp_path)
71
+
72
+ # Créer les messages pour le modèle
73
+ messages = [
74
+ {
75
+ "role": "user",
76
+ "content": [
77
+ {
78
+ "type": "image",
79
+ "image": temp_path
80
+ },
81
+ {
82
+ "type": "text",
83
+ "text": query
84
+ }
85
+ ]
86
+ }
87
+ ]
88
+
89
+ # Appliquer le template de chat
90
+ text = processor.apply_chat_template(
91
+ messages, tokenize=False, add_generation_prompt=True
92
+ )
93
+
94
+ # Traiter les informations visuelles
95
+ image_inputs, video_inputs = process_vision_info(messages)
96
+
97
+ # Préparer les inputs
98
+ inputs = processor(
99
+ text=[text],
100
+ images=image_inputs,
101
+ videos=video_inputs,
102
+ padding=True,
103
+ return_tensors="pt",
104
+ )
105
+
106
+ # Déplacer sur le bon device
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+ inputs = inputs.to(device)
109
+
110
+ # Générer la réponse
111
+ logger.info("🤖 Génération de la réponse...")
112
+ generated_ids = model.generate(**inputs, max_new_tokens=4096)
113
+
114
+ # Nettoyer les tokens
115
+ generated_ids_trimmed = [
116
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
117
+ ]
118
+
119
+ # Décoder le résultat
120
+ output = processor.batch_decode(
121
+ generated_ids_trimmed,
122
+ skip_special_tokens=True,
123
+ clean_up_tokenization_spaces=True
124
+ )[0]
125
+
126
+ logger.info(f"✅ Extraction réussie: {len(output)} caractères")
127
+
128
+ return ExtractionResponse(
129
+ result=output,
130
+ status="success",
131
+ timestamp=timestamp
132
+ )
133
+
134
+ except Exception as e:
135
+ logger.error(f"❌ Erreur traitement: {str(e)}")
136
+ raise HTTPException(status_code=500, detail=str(e))
137
+
138
+ finally:
139
+ # Nettoyer le fichier temporaire
140
+ if temp_path and os.path.exists(temp_path):
141
+ os.remove(temp_path)
142
+ logger.info("🧹 Fichier temporaire nettoyé")
143
+
144
+ @app.get("/health")
145
+ def health_check():
146
+ """
147
+ Vérifier que l'API fonctionne
148
+ """
149
+ return {
150
+ "status": "healthy",
151
+ "model": "Qwen2-VL-7B-Instruct",
152
+ "device": "cuda" if torch.cuda.is_available() else "cpu",
153
+ "timestamp": datetime.now().isoformat()
154
+ }
155
+
156
+ @app.get("/info")
157
+ def api_info():
158
+ """
159
+ Informations sur l'API
160
+ """
161
+ return {
162
+ "name": "Sparrow Qwen2-VL API",
163
+ "version": "1.0.0",
164
+ "endpoints": {
165
+ "predict": "/predict",
166
+ "health": "/health",
167
+ "info": "/info"
168
+ },
169
+ "model": "Qwen/Qwen2-VL-7B-Instruct"
170
+ }
171
+
172
+ # Pour compatibilité avec Gradio (optionnel)
173
+ @app.get("/")
174
+ def root():
175
+ return JSONResponse({
176
+ "message": "Sparrow Qwen2-VL API is running",
177
+ "docs": "/docs",
178
+ "health": "/health",
179
+ "predict": "/predict"
180
+ })
181
+
182
+ # Lancer le serveur si exécuté directement
183
+ if __name__ == "__main__":
184
+ import uvicorn
185
+ uvicorn.run(app, host="0.0.0.0", port=7860)