MODLI commited on
Commit
acd685d
·
verified ·
1 Parent(s): 9c94de5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -114
app.py CHANGED
@@ -27,56 +27,93 @@ app.add_middleware(
27
  )
28
 
29
  # --- ÉTAT DU MODÈLE ---
30
- print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...")
31
  model = None
32
  processor = None
33
- model_loading = False
34
  model_loaded = False
35
  model_error = None
36
 
37
- def load_fashion_model():
38
- global model, processor, model_loading, model_loaded, model_error
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- model_loading = True
41
  try:
42
- from transformers import AutoModel, AutoProcessor
43
-
44
- model_name = "Marqo/Marqo-FashionSigLIP-Classification"
45
 
46
- print("📦 Téléchargement du modèle... (cela peut prendre 5-10 minutes)")
 
47
 
48
- # Charger le modèle SigLIP
49
- model = AutoModel.from_pretrained(
50
- model_name,
51
- cache_dir="/tmp/cache",
52
- torch_dtype=torch.float16,
53
- trust_remote_code=True
54
- )
55
 
56
- processor = AutoProcessor.from_pretrained(
57
- model_name,
58
- trust_remote_code=True
59
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !")
62
  model_loaded = True
63
- model_loading = False
64
 
65
  except Exception as e:
66
- print(f"Erreur chargement modèle: {e}")
67
- model_error = str(e)
68
- model_loading = False
69
- import traceback
70
- traceback.print_exc()
71
 
72
- # Démarrer le chargement IMMÉDIATEMENT
73
- load_fashion_model()
 
 
 
 
 
 
 
 
 
 
74
 
75
- # Catégories de mode
 
 
 
76
  categories = [
77
- "t-shirt", "dress", "jeans", "shirt", "skirt",
78
- "sneakers", "handbag", "jacket", "shorts", "sweater",
79
- "coat", "high heels", "blouse", "boots", "hat"
80
  ]
81
 
82
  @app.get("/")
@@ -84,74 +121,82 @@ def read_root():
84
  return {
85
  "message": "Fashion Classification API is running!",
86
  "status": "OK",
87
- "model_status": "loaded" if model_loaded else "loading" if model_loading else "error",
88
- "model_name": "Marqo-FashionSigLIP-Classification"
 
89
  }
90
 
91
  @app.get("/health")
92
  def health_check():
93
  return {
94
  "model_loaded": model_loaded,
95
- "model_loading": model_loading,
96
  "model_error": model_error,
97
- "status": "ready" if model_loaded else "loading" if model_loading else "error",
98
- "model_name": "Marqo-FashionSigLIP-Classification",
 
 
99
  "timestamp": time.time()
100
  }
101
 
102
  @app.post("/analyze")
103
  async def analyze_image(file: UploadFile = File(...)):
104
- # Vérifier si le modèle est chargé
105
  if not model_loaded:
106
- if model_loading:
107
- raise HTTPException(status_code=423, detail="Model still loading. Please wait 5-10 minutes and check /health")
108
- else:
109
- raise HTTPException(status_code=500, detail=f"Model failed to load: {model_error}")
110
-
111
- if model is None or processor is None:
112
- raise HTTPException(status_code=500, detail="Model not available")
113
 
114
  try:
115
  # Lire et préparer l'image
116
  contents = await file.read()
117
  image = Image.open(io.BytesIO(contents)).convert("RGB")
118
- image = image.resize((384, 384))
119
-
120
- # Traitement avec SigLIP
121
- inputs = processor(
122
- text=categories,
123
- images=image,
124
- return_tensors="pt",
125
- padding=True,
126
- truncation=True,
127
- max_length=64,
128
- )
129
-
130
- device = next(model.parameters()).device
131
- inputs = {k: v.to(device) for k, v in inputs.items()}
132
 
133
- with torch.no_grad():
134
- outputs = model(**inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- logits_per_image = outputs.logits_per_image
137
- probs = torch.sigmoid(logits_per_image)
138
  probs = probs.cpu().numpy()[0]
139
-
140
  predicted_idx = np.argmax(probs)
141
  category_name = categories[predicted_idx]
142
  confidence_score = float(probs[predicted_idx])
143
 
144
- # Analyse couleur
145
  try:
146
- with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
147
- image.save(tmp, format='JPEG')
148
- tmp_path = tmp.name
149
-
150
- color_thief = colorthief.ColorThief(tmp_path)
151
- dominant_color = color_thief.get_color(quality=1)
152
- hex_color = '#%02x%02x%02x' % dominant_color
153
- os.unlink(tmp_path)
154
-
155
  except Exception:
156
  hex_color = "#000000"
157
 
@@ -159,75 +204,58 @@ async def analyze_image(file: UploadFile = File(...)):
159
  "category": category_name,
160
  "confidence": round(confidence_score, 4),
161
  "color_hex": hex_color,
162
- "model": "Marqo-FashionSigLIP-Classification"
163
  }
164
 
165
  except Exception as e:
166
  raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}")
167
 
168
- # Interface de test avec statut de chargement
169
  @app.get("/test-ui", response_class=HTMLResponse)
170
  async def test_ui():
 
 
 
 
171
  return f"""
172
  <html>
173
  <head>
174
- <title>FashionSigLIP Detection</title>
175
  <style>
176
  body {{ font-family: Arial, sans-serif; margin: 40px; }}
177
  .container {{ max-width: 600px; margin: 0 auto; }}
178
- form {{ border: 2px dashed #ccc; padding: 30px; text-align: center; }}
179
  .status {{ padding: 15px; margin: 10px 0; border-radius: 5px; }}
180
- .loading {{ background: #fff3cd; color: #856404; }}
181
  .ready {{ background: #d4edda; color: #155724; }}
182
  .error {{ background: #f8d7da; color: #721c24; }}
 
183
  </style>
184
- <script>
185
- function checkStatus() {{
186
- fetch('/health')
187
- .then(response => response.json())
188
- .then(data => {{
189
- const statusDiv = document.getElementById('model-status');
190
- const submitBtn = document.getElementById('submit-btn');
191
-
192
- if (data.model_loaded) {{
193
- statusDiv.innerHTML = '✅ <b>Modèle chargé et prêt !</b>';
194
- statusDiv.className = 'status ready';
195
- submitBtn.disabled = false;
196
- }} else if (data.model_loading) {{
197
- statusDiv.innerHTML = '⏳ <b>Chargement du modèle en cours...</b><br>Cela peut prendre 5-10 minutes';
198
- statusDiv.className = 'status loading';
199
- submitBtn.disabled = true;
200
- setTimeout(checkStatus, 5000); // Re-check dans 5 sec
201
- }} else {{
202
- statusDiv.innerHTML = '❌ <b>Erreur de chargement:</b><br>' + (data.model_error || 'Unknown error');
203
- statusDiv.className = 'status error';
204
- submitBtn.disabled = true;
205
- }}
206
- }});
207
- }}
208
-
209
- // Vérifier le statut au chargement de la page
210
- window.onload = checkStatus;
211
- </script>
212
  </head>
213
  <body>
214
  <div class="container">
215
- <h1>👗 FashionSigLIP Detector</h1>
 
 
 
 
216
 
217
- <div id="model-status" class="status loading">
218
- Vérification du statut du modèle...
 
219
  </div>
220
 
221
  <form action="/analyze" method="post" enctype="multipart/form-data">
222
  <h3>Uploader une image de vêtement :</h3>
223
  <input type="file" name="file" accept="image/*" required>
224
  <br><br>
225
- <input type="submit" id="submit-btn" value="Analyser" disabled>
226
  </form>
227
 
228
  <div style="margin-top: 20px;">
229
- <button onclick="checkStatus()">Actualiser le statut</button>
230
- <button onclick="location.reload()">Rafraîchir la page</button>
 
 
 
 
231
  </div>
232
  </div>
233
  </body>
 
27
  )
28
 
29
  # --- ÉTAT DU MODÈLE ---
30
+ print("⚠️ Démarrage du chargement du modèle...")
31
  model = None
32
  processor = None
 
33
  model_loaded = False
34
  model_error = None
35
 
36
+ # Modèles disponibles (garantis de fonctionner)
37
+ AVAILABLE_MODELS = {
38
+ "siglip-base": {
39
+ "name": "google/siglip-base-patch16-224",
40
+ "type": "siglip",
41
+ "description": "SigLIP base - Excellente précision"
42
+ },
43
+ "clip-fashion": {
44
+ "name": "patrickjohncyh/fashion-clip",
45
+ "type": "clip",
46
+ "description": "CLIP spécialisé mode"
47
+ },
48
+ "openclip": {
49
+ "name": "laion/CLIP-ViT-B-32-laion2B-s34B-b79K",
50
+ "type": "clip",
51
+ "description": "OpenCLIP performant"
52
+ }
53
+ }
54
+
55
+ SELECTED_MODEL = "siglip-base" # ← MODÈLE GARANTI
56
+
57
+ def load_model():
58
+ global model, processor, model_loaded, model_error
59
 
 
60
  try:
61
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer, CLIPModel, CLIPProcessor
 
 
62
 
63
+ model_info = AVAILABLE_MODELS[SELECTED_MODEL]
64
+ model_name = model_info["name"]
65
 
66
+ print(f"📦 Chargement du modèle: {model_name}")
67
+ print(f"📝 Description: {model_info['description']}")
 
 
 
 
 
68
 
69
+ if model_info["type"] == "siglip":
70
+ # Charger SigLIP
71
+ model = AutoModel.from_pretrained(
72
+ model_name,
73
+ cache_dir="/tmp/cache",
74
+ torch_dtype=torch.float16
75
+ )
76
+ processor = AutoProcessor.from_pretrained(model_name)
77
+
78
+ else:
79
+ # Charger CLIP
80
+ model = CLIPModel.from_pretrained(
81
+ model_name,
82
+ cache_dir="/tmp/cache",
83
+ torch_dtype=torch.float16
84
+ )
85
+ processor = CLIPProcessor.from_pretrained(model_name)
86
 
87
+ print(f"✅ Modèle {model_name} chargé avec succès !")
88
  model_loaded = True
 
89
 
90
  except Exception as e:
91
+ model_error = f"Erreur avec {SELECTED_MODEL}: {str(e)}"
92
+ print(f"❌ {model_error}")
93
+ # Essayer le modèle suivant en cas d'erreur
94
+ try_next_model()
 
95
 
96
+ def try_next_model():
97
+ """Essaye le modèle suivant si le premier échoue"""
98
+ global SELECTED_MODEL
99
+ models = list(AVAILABLE_MODELS.keys())
100
+ current_index = models.index(SELECTED_MODEL)
101
+
102
+ if current_index < len(models) - 1:
103
+ SELECTED_MODEL = models[current_index + 1]
104
+ print(f"🔄 Essai du modèle suivant: {SELECTED_MODEL}")
105
+ load_model()
106
+ else:
107
+ print("❌ Tous les modèles ont échoué")
108
 
109
+ # Démarrer le chargement
110
+ load_model()
111
+
112
+ # Catégories de mode adaptées
113
  categories = [
114
+ "a t-shirt", "a dress", "jeans", "a shirt", "a skirt",
115
+ "sneakers", "a handbag", "a jacket", "shorts", "a sweater",
116
+ "a coat", "high heels", "a blouse", "boots", "a hat"
117
  ]
118
 
119
  @app.get("/")
 
121
  return {
122
  "message": "Fashion Classification API is running!",
123
  "status": "OK",
124
+ "model_loaded": model_loaded,
125
+ "current_model": SELECTED_MODEL,
126
+ "model_name": AVAILABLE_MODELS[SELECTED_MODEL]["name"] if model_loaded else "loading"
127
  }
128
 
129
  @app.get("/health")
130
  def health_check():
131
  return {
132
  "model_loaded": model_loaded,
 
133
  "model_error": model_error,
134
+ "current_model": SELECTED_MODEL,
135
+ "model_details": AVAILABLE_MODELS[SELECTED_MODEL] if model_loaded else None,
136
+ "available_models": list(AVAILABLE_MODELS.keys()),
137
+ "status": "ready" if model_loaded else "error",
138
  "timestamp": time.time()
139
  }
140
 
141
  @app.post("/analyze")
142
  async def analyze_image(file: UploadFile = File(...)):
 
143
  if not model_loaded:
144
+ raise HTTPException(status_code=423, detail="Model not loaded yet. Please check /health")
 
 
 
 
 
 
145
 
146
  try:
147
  # Lire et préparer l'image
148
  contents = await file.read()
149
  image = Image.open(io.BytesIO(contents)).convert("RGB")
150
+ image = image.resize((224, 224)) # Taille standard
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ # Traitement selon le type de modèle
153
+ if SELECTED_MODEL == "siglip-base":
154
+ # SigLIP processing
155
+ inputs = processor(
156
+ text=categories,
157
+ images=image,
158
+ return_tensors="pt",
159
+ padding=True,
160
+ truncation=True
161
+ )
162
+
163
+ with torch.no_grad():
164
+ outputs = model(**inputs)
165
+
166
+ logits_per_image = outputs.logits_per_image
167
+ probs = torch.sigmoid(logits_per_image)
168
+
169
+ else:
170
+ # CLIP processing
171
+ inputs = processor(
172
+ text=categories,
173
+ images=image,
174
+ return_tensors="pt",
175
+ padding=True,
176
+ truncation=True
177
+ )
178
+
179
+ with torch.no_grad():
180
+ outputs = model(**inputs)
181
+
182
+ logits_per_image = outputs.logits_per_image
183
+ probs = torch.softmax(logits_per_image, dim=1)
184
 
 
 
185
  probs = probs.cpu().numpy()[0]
 
186
  predicted_idx = np.argmax(probs)
187
  category_name = categories[predicted_idx]
188
  confidence_score = float(probs[predicted_idx])
189
 
190
+ # Analyse couleur simplifiée
191
  try:
192
+ image_rgb = image.convert('RGB')
193
+ small_img = image_rgb.resize((10, 10))
194
+ colors = small_img.getcolors(100)
195
+ if colors:
196
+ dominant_color = max(colors, key=lambda x: x[0])[1]
197
+ hex_color = '#%02x%02x%02x' % dominant_color
198
+ else:
199
+ hex_color = "#000000"
 
200
  except Exception:
201
  hex_color = "#000000"
202
 
 
204
  "category": category_name,
205
  "confidence": round(confidence_score, 4),
206
  "color_hex": hex_color,
207
+ "model": AVAILABLE_MODELS[SELECTED_MODEL]["name"]
208
  }
209
 
210
  except Exception as e:
211
  raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}")
212
 
 
213
  @app.get("/test-ui", response_class=HTMLResponse)
214
  async def test_ui():
215
+ health_status = health_check()
216
+ status_class = "ready" if health_status["model_loaded"] else "error"
217
+ status_text = "✅ Prêt" if health_status["model_loaded"] else "❌ Erreur"
218
+
219
  return f"""
220
  <html>
221
  <head>
222
+ <title>Fashion Detection</title>
223
  <style>
224
  body {{ font-family: Arial, sans-serif; margin: 40px; }}
225
  .container {{ max-width: 600px; margin: 0 auto; }}
 
226
  .status {{ padding: 15px; margin: 10px 0; border-radius: 5px; }}
 
227
  .ready {{ background: #d4edda; color: #155724; }}
228
  .error {{ background: #f8d7da; color: #721c24; }}
229
+ .model-info {{ background: #e9ecef; padding: 10px; border-radius: 5px; }}
230
  </style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  </head>
232
  <body>
233
  <div class="container">
234
+ <h1>👗 Fashion Detector</h1>
235
+
236
+ <div class="status {status_class}">
237
+ <b>Statut:</b> {status_text}
238
+ </div>
239
 
240
+ <div class="model-info">
241
+ <b>Modèle:</b> {health_status['current_model']}<br>
242
+ <b>Détails:</b> {AVAILABLE_MODELS[health_status['current_model']]['description'] if health_status['model_loaded'] else 'Chargement...'}
243
  </div>
244
 
245
  <form action="/analyze" method="post" enctype="multipart/form-data">
246
  <h3>Uploader une image de vêtement :</h3>
247
  <input type="file" name="file" accept="image/*" required>
248
  <br><br>
249
+ <input type="submit" value="Analyser" {"disabled" if not health_status["model_loaded"] else ""}>
250
  </form>
251
 
252
  <div style="margin-top: 20px;">
253
+ <h4>Modèles disponibles:</h4>
254
+ <ul>
255
+ <li><b>siglip-base</b>: SigLIP base - Excellente précision</li>
256
+ <li><b>clip-fashion</b>: CLIP spécialisé mode</li>
257
+ <li><b>openclip</b>: OpenCLIP performant</li>
258
+ </ul>
259
  </div>
260
  </div>
261
  </body>