File size: 8,431 Bytes
66791ba
b1cba22
9c94de5
66791ba
 
cff0bfe
9c94de5
b13493b
e3527ed
5cf61c7
 
 
b13493b
 
 
5cf61c7
b13493b
5cf61c7
b13493b
 
 
 
 
 
 
 
 
 
9c94de5
b13493b
b1cba22
 
9c94de5
 
 
b1cba22
b13493b
9c94de5
 
 
b1cba22
b13493b
 
 
 
9c94de5
 
 
b13493b
 
 
9c94de5
b13493b
 
 
 
 
 
 
 
 
9c94de5
 
b13493b
b1cba22
b13493b
9c94de5
 
b13493b
 
 
9c94de5
 
 
 
b13493b
 
 
 
 
b1cba22
1a0da4b
b13493b
9c94de5
 
 
 
 
 
b13493b
 
 
 
9c94de5
 
 
 
 
 
b13493b
b1cba22
5cf61c7
b13493b
9c94de5
 
 
 
 
 
 
b13493b
9c94de5
b1cba22
dbadef3
b13493b
 
 
 
9f55257
9c94de5
b13493b
 
 
 
 
 
 
 
4d5ff3f
b13493b
 
cff0bfe
b13493b
 
 
 
9c94de5
b13493b
 
 
 
 
 
9c94de5
b13493b
 
 
 
 
 
 
 
 
 
9c94de5
b13493b
 
b1cba22
9c94de5
 
 
b13493b
aa56d44
b13493b
b1cba22
9c94de5
b1cba22
9c94de5
b13493b
 
9c94de5
b13493b
 
 
 
9c94de5
 
 
 
 
 
 
b13493b
9c94de5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b13493b
 
 
 
9c94de5
 
 
 
b13493b
 
 
 
9c94de5
 
b13493b
 
9c94de5
 
 
b13493b
 
 
 
b1cba22
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import os
import json
import time
os.environ['HF_HOME'] = '/tmp/cache'
os.environ['TORCH_HOME'] = '/tmp/cache'

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from PIL import Image
import torch
import io
import colorthief
import tempfile
import numpy as np

app = FastAPI(title="Fashion Classification API")

# CORS middleware: every origin, method and request header is accepted,
# credentials are enabled and all response headers are exposed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; the FastAPI CORS documentation advises listing origins
# explicitly when credentials are enabled — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"]
)

# --- MODEL STATE ---
# Module-level globals written by load_fashion_model() and read by the
# endpoints to report and gate on the model's availability.
print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...")
model = None  # model object, set by load_fashion_model()
processor = None  # matching processor, set by load_fashion_model()
model_loading = False  # True while load_fashion_model() is running
model_loaded = False  # True once loading completed successfully
model_error = None  # str(exception) of a failed load, otherwise None

def load_fashion_model():
    """Load the FashionSigLIP model and processor into the module globals.

    Updates ``model``, ``processor``, ``model_loading``, ``model_loaded`` and
    ``model_error``. The model and processor are published together, only
    after both loaded successfully, so the globals never hold a half-loaded
    pair.

    Never raises for ordinary failures: the exception text is stored in
    ``model_error`` and the traceback is printed.
    """
    global model, processor, model_loading, model_loaded, model_error

    # Reset the outcome flags so a repeated call never reports a stale result.
    model_loading = True
    model_loaded = False
    model_error = None
    try:
        from transformers import AutoModel, AutoProcessor

        model_name = "Marqo/Marqo-FashionSigLIP-Classification"
        cache_dir = "/tmp/cache"

        print("📦 Téléchargement du modèle... (cela peut prendre 5-10 minutes)")

        # SECURITY NOTE: trust_remote_code=True executes Python code shipped
        # in the model repository; only use it with a repository you trust.
        loaded_model = AutoModel.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )

        # Same cache directory as the model so all artefacts live together.
        loaded_processor = AutoProcessor.from_pretrained(
            model_name,
            cache_dir=cache_dir,
            trust_remote_code=True,
        )

        model, processor = loaded_model, loaded_processor
        model_loaded = True
        print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !")

    except Exception as e:
        print(f"❌ Erreur chargement modèle: {e}")
        model_error = str(e)
        import traceback
        traceback.print_exc()

    finally:
        # Loading is over whatever the outcome.
        model_loading = False

# Start loading the model immediately.
# NOTE(review): this is a direct, synchronous call made while the module is
# being imported, so the import does not finish until loading has succeeded
# or failed — confirm whether a background load was intended, given the
# "loading" state exposed by the endpoints.
load_fashion_model()

# Fashion categories: candidate text labels scored against the uploaded image
# by /analyze; the predicted index selects an entry of this list.
categories = [
    "t-shirt", "dress", "jeans", "shirt", "skirt", 
    "sneakers", "handbag", "jacket", "shorts", "sweater",
    "coat", "high heels", "blouse", "boots", "hat"
]

@app.get("/")
def read_root():
    """Return a liveness banner together with the model's coarse load state.

    ``model_status`` is "loaded", "loading" or "error", derived from the
    module-level ``model_loaded`` / ``model_loading`` flags.
    """
    if model_loaded:
        load_state = "loaded"
    elif model_loading:
        load_state = "loading"
    else:
        load_state = "error"

    payload = {
        "message": "Fashion Classification API is running!",
        "status": "OK",
        "model_status": load_state,
        "model_name": "Marqo-FashionSigLIP-Classification",
    }
    return payload

@app.get("/health")
def health_check():
    """Report the model's load flags, last load error and a timestamp.

    ``status`` is "ready" when the model is loaded, "loading" while a load is
    in progress and "error" otherwise.
    """
    if model_loaded:
        readiness = "ready"
    elif model_loading:
        readiness = "loading"
    else:
        readiness = "error"

    report = {
        "model_loaded": model_loaded,
        "model_loading": model_loading,
        "model_error": model_error,
        "status": readiness,
        "model_name": "Marqo-FashionSigLIP-Classification",
    }
    report["timestamp"] = time.time()
    return report

def _dominant_color_hex(image):
    """Return the dominant colour of a PIL image as a ``#rrggbb`` string.

    Best effort: any failure is logged and "#000000" is returned instead.
    """
    try:
        # ColorThief accepts a binary file object, so an in-memory JPEG
        # replaces the temporary file (which leaked whenever ColorThief
        # raised before it could be unlinked).
        buffer = io.BytesIO()
        image.save(buffer, format='JPEG')
        buffer.seek(0)
        dominant_color = colorthief.ColorThief(buffer).get_color(quality=1)
        return '#%02x%02x%02x' % dominant_color
    except Exception as exc:
        print(f"⚠️ Dominant colour extraction failed: {exc}")
        return "#000000"


@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """Classify an uploaded clothing image and extract its dominant colour.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A dict with ``category`` (best entry of ``categories``),
        ``confidence`` (its sigmoid score rounded to 4 decimals),
        ``color_hex`` (dominant colour, "#000000" if extraction failed) and
        ``model``.

    Raises:
        HTTPException: 423 while the model is still loading; 500 if the model
            failed to load, is unavailable, or inference fails; 400 if the
            upload cannot be decoded as an image.
    """
    # Refuse early when the model is not ready.
    if not model_loaded:
        if model_loading:
            raise HTTPException(status_code=423, detail="Model still loading. Please wait 5-10 minutes and check /health")
        raise HTTPException(status_code=500, detail=f"Model failed to load: {model_error}")

    if model is None or processor is None:
        raise HTTPException(status_code=500, detail="Model not available")

    contents = await file.read()

    # Decoding happens outside the inference try-block so that a bad upload
    # is reported as a client error (400) instead of being wrapped as a 500.
    # Pillow raises several unrelated exception types for malformed data,
    # hence the broad catch around this single statement.
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image file: {exc}") from exc

    try:
        image = image.resize((384, 384))

        # Score the image against every category label.
        inputs = processor(
            text=categories,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=64,
        )

        # Move tensors to the model's device. Floating-point tensors are also
        # cast to the model's dtype: the model is loaded in float16 while the
        # processor emits float32 pixel values. Integer tensors (token ids)
        # must keep their dtype.
        reference = next(model.parameters())
        inputs = {
            name: (
                tensor.to(device=reference.device, dtype=reference.dtype)
                if tensor.is_floating_point()
                else tensor.to(reference.device)
            )
            for name, tensor in inputs.items()
        }

        with torch.no_grad():
            outputs = model(**inputs)

        # Compute the sigmoid in float32 regardless of the model dtype.
        probs = torch.sigmoid(outputs.logits_per_image.float()).cpu().numpy()[0]

        predicted_idx = int(np.argmax(probs))
        category_name = categories[predicted_idx]
        confidence_score = float(probs[predicted_idx])

        hex_color = _dominant_color_hex(image)

        return {
            "category": category_name,
            "confidence": round(confidence_score, 4),
            "color_hex": hex_color,
            "model": "Marqo-FashionSigLIP-Classification"
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}") from e

# Test page with a live model-loading status
@app.get("/test-ui", response_class=HTMLResponse)
async def test_ui():
    """Return a static HTML page for manually testing ``/analyze``.

    The page polls ``/health`` to show the model's load status and enables
    the upload form's submit button only once the model is loaded.

    The markup is a plain string (it has no placeholders), so CSS and
    JavaScript braces are written literally.
    """
    return """
    <html>
        <head>
            <title>FashionSigLIP Detection</title>
            <style>
                body { font-family: Arial, sans-serif; margin: 40px; }
                .container { max-width: 600px; margin: 0 auto; }
                form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
                .status { padding: 15px; margin: 10px 0; border-radius: 5px; }
                .loading { background: #fff3cd; color: #856404; }
                .ready { background: #d4edda; color: #155724; }
                .error { background: #f8d7da; color: #721c24; }
            </style>
            <script>
                function checkStatus() {
                    fetch('/health')
                        .then(response => response.json())
                        .then(data => {
                            const statusDiv = document.getElementById('model-status');
                            const submitBtn = document.getElementById('submit-btn');
                            
                            if (data.model_loaded) {
                                statusDiv.innerHTML = '✅ <b>Modèle chargé et prêt !</b>';
                                statusDiv.className = 'status ready';
                                submitBtn.disabled = false;
                            } else if (data.model_loading) {
                                statusDiv.innerHTML = '⏳ <b>Chargement du modèle en cours...</b><br>Cela peut prendre 5-10 minutes';
                                statusDiv.className = 'status loading';
                                submitBtn.disabled = true;
                                setTimeout(checkStatus, 5000); // Re-check dans 5 sec
                            } else {
                                statusDiv.innerHTML = '❌ <b>Erreur de chargement:</b><br>';
                                statusDiv.appendChild(document.createTextNode(data.model_error || 'Unknown error'));
                                statusDiv.className = 'status error';
                                submitBtn.disabled = true;
                            }
                        });
                }
                
                // Vérifier le statut au chargement de la page
                window.onload = checkStatus;
            </script>
        </head>
        <body>
            <div class="container">
                <h1>👗 FashionSigLIP Detector</h1>
                
                <div id="model-status" class="status loading">
                    Vérification du statut du modèle...
                </div>
                
                <form action="/analyze" method="post" enctype="multipart/form-data">
                    <h3>Uploader une image de vêtement :</h3>
                    <input type="file" name="file" accept="image/*" required>
                    <br><br>
                    <input type="submit" id="submit-btn" value="Analyser" disabled>
                </form>
                
                <div style="margin-top: 20px;">
                    <button onclick="checkStatus()">Actualiser le statut</button>
                    <button onclick="location.reload()">Rafraîchir la page</button>
                </div>
            </div>
        </body>
    </html>
    """