File size: 5,497 Bytes
90be598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
import torch
import tempfile
import os
import logging
from datetime import datetime

# Configuration logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialiser FastAPI
app = FastAPI(
    title="Sparrow Qwen2-VL API",
    description="API REST pour extraction de données depuis images via Qwen2-VL",
    version="1.0.0"
)

# Charger le modèle au démarrage
logger.info("🔄 Chargement du modèle Qwen2-VL-7B-Instruct...")
try:
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        torch_dtype="auto",
        device_map="auto"
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    logger.info("✅ Modèle chargé avec succès!")
except Exception as e:
    logger.error(f"❌ Erreur chargement modèle: {e}")
    raise

# Modèle de réponse
class ExtractionResponse(BaseModel):
    result: str
    status: str
    timestamp: str

@app.post("/predict", response_model=ExtractionResponse)
async def predict(
    image: UploadFile = File(..., description="Image à analyser"),
    query: str = Form(..., description="Instruction d'extraction")
):
    """
    Extraire des données d'une image selon la requête
    """
    timestamp = datetime.now().isoformat()
    temp_path = None
    
    try:
        # Validation du fichier
        if not image.content_type.startswith('image/'):
            raise HTTPException(status_code=400, detail="Fichier doit être une image")
        
        # Sauvegarder temporairement
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp_file:
            content = await image.read()
            tmp_file.write(content)
            temp_path = tmp_file.name
        
        logger.info(f"🖼️  Traitement image: {image.filename}")
        logger.info(f"📝 Requête: {query}")
        
        # Préparer l'image
        img = Image.open(temp_path)
        
        # Créer les messages pour le modèle
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": temp_path
                    },
                    {
                        "type": "text",
                        "text": query
                    }
                ]
            }
        ]
        
        # Appliquer le template de chat
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        
        # Traiter les informations visuelles
        image_inputs, video_inputs = process_vision_info(messages)
        
        # Préparer les inputs
        inputs = processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        
        # Déplacer sur le bon device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = inputs.to(device)
        
        # Générer la réponse
        logger.info("🤖 Génération de la réponse...")
        generated_ids = model.generate(**inputs, max_new_tokens=4096)
        
        # Nettoyer les tokens
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        
        # Décoder le résultat
        output = processor.batch_decode(
            generated_ids_trimmed, 
            skip_special_tokens=True, 
            clean_up_tokenization_spaces=True
        )[0]
        
        logger.info(f"✅ Extraction réussie: {len(output)} caractères")
        
        return ExtractionResponse(
            result=output,
            status="success",
            timestamp=timestamp
        )
        
    except Exception as e:
        logger.error(f"❌ Erreur traitement: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
        
    finally:
        # Nettoyer le fichier temporaire
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
            logger.info("🧹 Fichier temporaire nettoyé")

@app.get("/health")
def health_check():
    """
    Vérifier que l'API fonctionne
    """
    return {
        "status": "healthy",
        "model": "Qwen2-VL-7B-Instruct",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "timestamp": datetime.now().isoformat()
    }

@app.get("/info")
def api_info():
    """
    Informations sur l'API
    """
    return {
        "name": "Sparrow Qwen2-VL API",
        "version": "1.0.0",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "info": "/info"
        },
        "model": "Qwen/Qwen2-VL-7B-Instruct"
    }

# Pour compatibilité avec Gradio (optionnel)
@app.get("/")
def root():
    return JSONResponse({
        "message": "Sparrow Qwen2-VL API is running",
        "docs": "/docs",
        "health": "/health",
        "predict": "/predict"
    })

# Lancer le serveur si exécuté directement
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)