MODLI commited on
Commit
9c94de5
·
verified ·
1 Parent(s): b13493b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -116
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import os
2
  import json
 
3
  os.environ['HF_HOME'] = '/tmp/cache'
4
  os.environ['TORCH_HOME'] = '/tmp/cache'
5
 
6
- from fastapi import FastAPI, File, UploadFile
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from fastapi.responses import HTMLResponse
9
  from PIL import Image
@@ -25,23 +26,30 @@ app.add_middleware(
25
  expose_headers=["*"]
26
  )
27
 
28
- # --- CHARGE LE MODÈLE MARQO FASHIONSIGLIP ---
29
  print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...")
30
  model = None
31
  processor = None
 
 
 
32
 
33
  def load_fashion_model():
34
- global model, processor
 
 
35
  try:
36
  from transformers import AutoModel, AutoProcessor
37
 
38
  model_name = "Marqo/Marqo-FashionSigLIP-Classification"
39
 
40
- # Charger le modèle SigLIP spécialisé fashion
 
 
41
  model = AutoModel.from_pretrained(
42
  model_name,
43
  cache_dir="/tmp/cache",
44
- torch_dtype=torch.float16, # Moins de mémoire
45
  trust_remote_code=True
46
  )
47
 
@@ -51,55 +59,65 @@ def load_fashion_model():
51
  )
52
 
53
  print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !")
54
- print(f"📍 Modèle device: {next(model.parameters()).device}")
 
55
 
56
  except Exception as e:
57
  print(f"❌ Erreur chargement modèle: {e}")
 
 
58
  import traceback
59
  traceback.print_exc()
60
 
61
- # Catégories de mode pour SigLIP (adaptées au modèle)
 
 
 
62
  categories = [
63
  "t-shirt", "dress", "jeans", "shirt", "skirt",
64
  "sneakers", "handbag", "jacket", "shorts", "sweater",
65
  "coat", "high heels", "blouse", "boots", "hat"
66
  ]
67
 
68
- @app.on_event("startup")
69
- async def startup_event():
70
- import threading
71
- thread = threading.Thread(target=load_fashion_model)
72
- thread.daemon = True
73
- thread.start()
74
-
75
  @app.get("/")
76
  def read_root():
77
- return {"message": "Fashion Classification API is running!", "status": "OK"}
 
 
 
 
 
78
 
79
  @app.get("/health")
80
  def health_check():
81
  return {
82
- "model_loaded": model is not None,
83
- "processor_loaded": processor is not None,
84
- "status": "ready" if model and processor else "loading",
85
- "model_name": "Marqo-FashionSigLIP-Classification"
 
 
86
  }
87
 
88
  @app.post("/analyze")
89
  async def analyze_image(file: UploadFile = File(...)):
 
 
 
 
 
 
 
90
  if model is None or processor is None:
91
- return {"error": "Model not loaded yet. Please check /health endpoint."}
92
 
93
  try:
94
  # Lire et préparer l'image
95
  contents = await file.read()
96
  image = Image.open(io.BytesIO(contents)).convert("RGB")
97
-
98
- # Redimensionner pour SigLIP
99
  image = image.resize((384, 384))
100
 
101
- # --- TRAITEMENT AVEC SIGLIP ---
102
- # Préparer les inputs
103
  inputs = processor(
104
  text=categories,
105
  images=image,
@@ -107,30 +125,23 @@ async def analyze_image(file: UploadFile = File(...)):
107
  padding=True,
108
  truncation=True,
109
  max_length=64,
110
- return_overflowing_tokens=False
111
  )
112
 
113
- # Déplacer sur le device du modèle
114
  device = next(model.parameters()).device
115
  inputs = {k: v.to(device) for k, v in inputs.items()}
116
 
117
- # Inférence
118
  with torch.no_grad():
119
  outputs = model(**inputs)
120
 
121
- # SigLIP utilise des logits différents
122
  logits_per_image = outputs.logits_per_image
123
-
124
- # Convertir en probabilités
125
- probs = torch.sigmoid(logits_per_image) # SigLIP utilise sigmoid, pas softmax!
126
  probs = probs.cpu().numpy()[0]
127
 
128
- # Trouver la meilleure catégorie
129
  predicted_idx = np.argmax(probs)
130
  category_name = categories[predicted_idx]
131
  confidence_score = float(probs[predicted_idx])
132
 
133
- # --- ANALYSE COULEUR ---
134
  try:
135
  with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
136
  image.save(tmp, format='JPEG')
@@ -139,116 +150,84 @@ async def analyze_image(file: UploadFile = File(...)):
139
  color_thief = colorthief.ColorThief(tmp_path)
140
  dominant_color = color_thief.get_color(quality=1)
141
  hex_color = '#%02x%02x%02x' % dominant_color
142
-
143
  os.unlink(tmp_path)
144
 
145
- except Exception as color_error:
146
- print(f"⚠️ Erreur analyse couleur: {color_error}")
147
  hex_color = "#000000"
148
 
149
- # --- RÉSULTATS DÉTAILLÉS ---
150
- top_categories = []
151
- for i, (cat, prob) in enumerate(zip(categories, probs)):
152
- if prob > 0.1: # Seuil minimal
153
- top_categories.append({
154
- "category": cat,
155
- "score": round(float(prob), 4)
156
- })
157
-
158
- # Trier par score décroissant
159
- top_categories.sort(key=lambda x: x["score"], reverse=True)
160
- top_5 = top_categories[:5]
161
-
162
  return {
163
- "top_prediction": {
164
- "category": category_name,
165
- "confidence": round(confidence_score, 4),
166
- "color_hex": hex_color
167
- },
168
- "top_categories": top_5,
169
  "model": "Marqo-FashionSigLIP-Classification"
170
  }
171
 
172
  except Exception as e:
173
- return {"error": f"Erreur lors de l'analyse: {str(e)}"}
174
 
175
- # Interface de test
176
  @app.get("/test-ui", response_class=HTMLResponse)
177
  async def test_ui():
178
- return """
179
  <html>
180
  <head>
181
  <title>FashionSigLIP Detection</title>
182
  <style>
183
- body {
184
- font-family: Arial, sans-serif;
185
- margin: 40px;
186
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
187
- color: white;
188
- }
189
- .container {
190
- max-width: 600px;
191
- margin: 0 auto;
192
- background: rgba(255, 255, 255, 0.1);
193
- padding: 30px;
194
- border-radius: 15px;
195
- backdrop-filter: blur(10px);
196
- }
197
- form {
198
- border: 2px dashed rgba(255, 255, 255, 0.3);
199
- padding: 30px;
200
- text-align: center;
201
- margin-bottom: 20px;
202
- }
203
- input[type="file"] {
204
- margin: 15px 0;
205
- padding: 10px;
206
- background: rgba(255, 255, 255, 0.2);
207
- border: none;
208
- border-radius: 5px;
209
- color: white;
210
- }
211
- input[type="submit"] {
212
- background: #ff6b6b;
213
- color: white;
214
- padding: 12px 25px;
215
- border: none;
216
- cursor: pointer;
217
- border-radius: 25px;
218
- font-weight: bold;
219
- transition: background 0.3s;
220
- }
221
- input[type="submit"]:hover {
222
- background: #ee5a52;
223
- }
224
- .result {
225
- margin-top: 20px;
226
- padding: 20px;
227
- background: rgba(255, 255, 255, 0.1);
228
- border-radius: 10px;
229
- }
230
  </style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  </head>
232
  <body>
233
  <div class="container">
234
  <h1>👗 FashionSigLIP Detector</h1>
235
- <p>Powered by Marqo/Marqo-FashionSigLIP-Classification</p>
 
 
 
236
 
237
  <form action="/analyze" method="post" enctype="multipart/form-data">
238
  <h3>Uploader une image de vêtement :</h3>
239
  <input type="file" name="file" accept="image/*" required>
240
- <br>
241
- <input type="submit" value="Analyser la mode 🎯">
242
  </form>
243
 
244
- <div class="result">
245
- <h3>📊 Résultats :</h3>
246
- <p>Les résultats apparaîtront ici après analyse...</p>
247
- </div>
248
-
249
- <div style="margin-top: 20px; font-size: 12px; opacity: 0.7;">
250
- <p>Modèle : Marqo-FashionSigLIP-Classification</p>
251
- <p>Spécialisé dans la classification de vêtements</p>
252
  </div>
253
  </div>
254
  </body>
 
1
  import os
2
  import json
3
+ import time
4
  os.environ['HF_HOME'] = '/tmp/cache'
5
  os.environ['TORCH_HOME'] = '/tmp/cache'
6
 
7
+ from fastapi import FastAPI, File, UploadFile, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from fastapi.responses import HTMLResponse
10
  from PIL import Image
 
26
  expose_headers=["*"]
27
  )
28
 
29
+ # --- ÉTAT DU MODÈLE ---
30
  print("⚠️ Démarrage du chargement du modèle Marqo-FashionSigLIP...")
31
  model = None
32
  processor = None
33
+ model_loading = False
34
+ model_loaded = False
35
+ model_error = None
36
 
37
  def load_fashion_model():
38
+ global model, processor, model_loading, model_loaded, model_error
39
+
40
+ model_loading = True
41
  try:
42
  from transformers import AutoModel, AutoProcessor
43
 
44
  model_name = "Marqo/Marqo-FashionSigLIP-Classification"
45
 
46
+ print("📦 Téléchargement du modèle... (cela peut prendre 5-10 minutes)")
47
+
48
+ # Charger le modèle SigLIP
49
  model = AutoModel.from_pretrained(
50
  model_name,
51
  cache_dir="/tmp/cache",
52
+ torch_dtype=torch.float16,
53
  trust_remote_code=True
54
  )
55
 
 
59
  )
60
 
61
  print("✅ Modèle Marqo-FashionSigLIP chargé avec succès !")
62
+ model_loaded = True
63
+ model_loading = False
64
 
65
  except Exception as e:
66
  print(f"❌ Erreur chargement modèle: {e}")
67
+ model_error = str(e)
68
+ model_loading = False
69
  import traceback
70
  traceback.print_exc()
71
 
72
+ # Démarrer le chargement IMMÉDIATEMENT
73
+ load_fashion_model()
74
+
75
+ # Catégories de mode
76
  categories = [
77
  "t-shirt", "dress", "jeans", "shirt", "skirt",
78
  "sneakers", "handbag", "jacket", "shorts", "sweater",
79
  "coat", "high heels", "blouse", "boots", "hat"
80
  ]
81
 
 
 
 
 
 
 
 
82
  @app.get("/")
83
  def read_root():
84
+ return {
85
+ "message": "Fashion Classification API is running!",
86
+ "status": "OK",
87
+ "model_status": "loaded" if model_loaded else "loading" if model_loading else "error",
88
+ "model_name": "Marqo-FashionSigLIP-Classification"
89
+ }
90
 
91
  @app.get("/health")
92
  def health_check():
93
  return {
94
+ "model_loaded": model_loaded,
95
+ "model_loading": model_loading,
96
+ "model_error": model_error,
97
+ "status": "ready" if model_loaded else "loading" if model_loading else "error",
98
+ "model_name": "Marqo-FashionSigLIP-Classification",
99
+ "timestamp": time.time()
100
  }
101
 
102
  @app.post("/analyze")
103
  async def analyze_image(file: UploadFile = File(...)):
104
+ # Vérifier si le modèle est chargé
105
+ if not model_loaded:
106
+ if model_loading:
107
+ raise HTTPException(status_code=423, detail="Model still loading. Please wait 5-10 minutes and check /health")
108
+ else:
109
+ raise HTTPException(status_code=500, detail=f"Model failed to load: {model_error}")
110
+
111
  if model is None or processor is None:
112
+ raise HTTPException(status_code=500, detail="Model not available")
113
 
114
  try:
115
  # Lire et préparer l'image
116
  contents = await file.read()
117
  image = Image.open(io.BytesIO(contents)).convert("RGB")
 
 
118
  image = image.resize((384, 384))
119
 
120
+ # Traitement avec SigLIP
 
121
  inputs = processor(
122
  text=categories,
123
  images=image,
 
125
  padding=True,
126
  truncation=True,
127
  max_length=64,
 
128
  )
129
 
 
130
  device = next(model.parameters()).device
131
  inputs = {k: v.to(device) for k, v in inputs.items()}
132
 
 
133
  with torch.no_grad():
134
  outputs = model(**inputs)
135
 
 
136
  logits_per_image = outputs.logits_per_image
137
+ probs = torch.sigmoid(logits_per_image)
 
 
138
  probs = probs.cpu().numpy()[0]
139
 
 
140
  predicted_idx = np.argmax(probs)
141
  category_name = categories[predicted_idx]
142
  confidence_score = float(probs[predicted_idx])
143
 
144
+ # Analyse couleur
145
  try:
146
  with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
147
  image.save(tmp, format='JPEG')
 
150
  color_thief = colorthief.ColorThief(tmp_path)
151
  dominant_color = color_thief.get_color(quality=1)
152
  hex_color = '#%02x%02x%02x' % dominant_color
 
153
  os.unlink(tmp_path)
154
 
155
+ except Exception:
 
156
  hex_color = "#000000"
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  return {
159
+ "category": category_name,
160
+ "confidence": round(confidence_score, 4),
161
+ "color_hex": hex_color,
 
 
 
162
  "model": "Marqo-FashionSigLIP-Classification"
163
  }
164
 
165
  except Exception as e:
166
+ raise HTTPException(status_code=500, detail=f"Analysis error: {str(e)}")
167
 
168
+ # Interface de test avec statut de chargement
169
  @app.get("/test-ui", response_class=HTMLResponse)
170
  async def test_ui():
171
+ return f"""
172
  <html>
173
  <head>
174
  <title>FashionSigLIP Detection</title>
175
  <style>
176
+ body {{ font-family: Arial, sans-serif; margin: 40px; }}
177
+ .container {{ max-width: 600px; margin: 0 auto; }}
178
+ form {{ border: 2px dashed #ccc; padding: 30px; text-align: center; }}
179
+ .status {{ padding: 15px; margin: 10px 0; border-radius: 5px; }}
180
+ .loading {{ background: #fff3cd; color: #856404; }}
181
+ .ready {{ background: #d4edda; color: #155724; }}
182
+ .error {{ background: #f8d7da; color: #721c24; }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  </style>
184
+ <script>
185
+ function checkStatus() {{
186
+ fetch('/health')
187
+ .then(response => response.json())
188
+ .then(data => {{
189
+ const statusDiv = document.getElementById('model-status');
190
+ const submitBtn = document.getElementById('submit-btn');
191
+
192
+ if (data.model_loaded) {{
193
+ statusDiv.innerHTML = '✅ <b>Modèle chargé et prêt !</b>';
194
+ statusDiv.className = 'status ready';
195
+ submitBtn.disabled = false;
196
+ }} else if (data.model_loading) {{
197
+ statusDiv.innerHTML = '⏳ <b>Chargement du modèle en cours...</b><br>Cela peut prendre 5-10 minutes';
198
+ statusDiv.className = 'status loading';
199
+ submitBtn.disabled = true;
200
+ setTimeout(checkStatus, 5000); // Re-check dans 5 sec
201
+ }} else {{
202
+ statusDiv.innerHTML = '❌ <b>Erreur de chargement:</b><br>' + (data.model_error || 'Unknown error');
203
+ statusDiv.className = 'status error';
204
+ submitBtn.disabled = true;
205
+ }}
206
+ }});
207
+ }}
208
+
209
+ // Vérifier le statut au chargement de la page
210
+ window.onload = checkStatus;
211
+ </script>
212
  </head>
213
  <body>
214
  <div class="container">
215
  <h1>👗 FashionSigLIP Detector</h1>
216
+
217
+ <div id="model-status" class="status loading">
218
+ Vérification du statut du modèle...
219
+ </div>
220
 
221
  <form action="/analyze" method="post" enctype="multipart/form-data">
222
  <h3>Uploader une image de vêtement :</h3>
223
  <input type="file" name="file" accept="image/*" required>
224
+ <br><br>
225
+ <input type="submit" id="submit-btn" value="Analyser" disabled>
226
  </form>
227
 
228
+ <div style="margin-top: 20px;">
229
+ <button onclick="checkStatus()">Actualiser le statut</button>
230
+ <button onclick="location.reload()">Rafraîchir la page</button>
 
 
 
 
 
231
  </div>
232
  </div>
233
  </body>