MODLI commited on
Commit
9f55257
·
verified ·
1 Parent(s): e3527ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -32,7 +32,6 @@ processor = None
32
  def load_marqo_model():
33
  global model, processor
34
  try:
35
- # Import différé pour éviter les problèmes de compatibilité
36
  from transformers import CLIPProcessor, CLIPModel
37
 
38
  model_name = "Marqo/marqo-fashionCLIP"
@@ -53,10 +52,10 @@ async def startup_event():
53
  thread.daemon = True
54
  thread.start()
55
 
56
- # Catégories fashion simplifiées
57
  categories = [
58
- "a t-shirt", "a dress", "jeans", "a shirt", "a skirt", "sneakers",
59
- "a handbag", "a jacket", "shorts", "a sweater", "a coat", "high heels"
60
  ]
61
 
62
  @app.get("/")
@@ -84,43 +83,48 @@ async def analyze_image(file: UploadFile = File(...)):
84
  # Réduire la taille
85
  image.thumbnail((384, 384))
86
 
87
- # Analyse avec Marqo fashionCLIP
 
88
  inputs = processor(
89
- text=categories,
90
- images=image,
91
- return_tensors="pt",
92
- padding=True,
93
- truncation=True
 
 
94
  )
95
 
 
 
 
 
96
  with torch.no_grad():
97
  outputs = model(**inputs)
98
 
 
99
  logits_per_image = outputs.logits_per_image
100
- probs = logits_per_image.softmax(dim=1)
101
 
102
  predicted_class_idx = probs.argmax(dim=1).item()
103
  category_name = categories[predicted_class_idx]
104
  confidence_score = probs[0][predicted_class_idx].item()
105
 
106
- # --- CORRECTION DE L'ANALYSE COULEUR ---
107
  try:
108
- # Sauvegarder l'image dans un fichier temporaire pour ColorThief
109
  with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
110
  image.save(tmp, format='JPEG')
111
  tmp_path = tmp.name
112
 
113
- # Utiliser ColorThief avec le fichier temporaire
114
  color_thief = colorthief.ColorThief(tmp_path)
115
  dominant_color = color_thief.get_color(quality=1)
116
  hex_color = '#%02x%02x%02x' % dominant_color
117
 
118
- # Nettoyer le fichier temporaire
119
  os.unlink(tmp_path)
120
 
121
  except Exception as color_error:
122
  print(f"Erreur analyse couleur: {color_error}")
123
- hex_color = "#000000" # Couleur par défaut
124
 
125
  return {
126
  "category": category_name,
@@ -131,7 +135,7 @@ async def analyze_image(file: UploadFile = File(...)):
131
  except Exception as e:
132
  return {"error": f"Erreur lors de l'analyse: {str(e)}"}
133
 
134
- # Interface simple pour tester
135
  @app.get("/test-ui", response_class=HTMLResponse)
136
  async def test_ui():
137
  return """
@@ -155,13 +159,6 @@ async def test_ui():
155
  <br>
156
  <input type="submit" value="Analyser l'image 👗">
157
  </form>
158
-
159
- <div style="margin-top: 30px; padding: 20px; background: #f8f9fa;">
160
- <h3>📝 Instructions :</h3>
161
- <p>• Uploader une image claire d'un vêtement</p>
162
- <p>• Formats supportés : JPG, PNG, WebP</p>
163
- <p>• Taille recommandée : moins de 2MB</p>
164
- </div>
165
  </div>
166
  </body>
167
  </html>
 
32
  def load_marqo_model():
33
  global model, processor
34
  try:
 
35
  from transformers import CLIPProcessor, CLIPModel
36
 
37
  model_name = "Marqo/marqo-fashionCLIP"
 
52
  thread.daemon = True
53
  thread.start()
54
 
55
+ # Catégories fashion simplifiées (moins de texte pour éviter les problèmes de padding)
56
  categories = [
57
+ "t-shirt", "dress", "jeans", "shirt", "skirt", "sneakers",
58
+ "handbag", "jacket", "shorts", "sweater", "coat", "high heels"
59
  ]
60
 
61
  @app.get("/")
 
83
  # Réduire la taille
84
  image.thumbnail((384, 384))
85
 
86
+ # --- CORRECTION DU PADDING ---
87
+ # Préparer les inputs correctement avec padding et truncation
88
  inputs = processor(
89
+ text=categories,
90
+ images=image,
91
+ return_tensors="pt",
92
+ padding=True, # ← PADDING ACTIVÉ
93
+ truncation=True, # ← TRUNCATION ACTIVÉE
94
+ max_length=77, # ← LONGUEUR MAXIMALE POUR CLIP
95
+ return_overflowing_tokens=False
96
  )
97
 
98
+ # Déplacer sur le même device que le modèle
99
+ device = next(model.parameters()).device
100
+ inputs = {k: v.to(device) for k, v in inputs.items()}
101
+
102
  with torch.no_grad():
103
  outputs = model(**inputs)
104
 
105
+ # Récupérer les logits et calculer les probabilités
106
  logits_per_image = outputs.logits_per_image
107
+ probs = torch.nn.functional.softmax(logits_per_image, dim=1)
108
 
109
  predicted_class_idx = probs.argmax(dim=1).item()
110
  category_name = categories[predicted_class_idx]
111
  confidence_score = probs[0][predicted_class_idx].item()
112
 
113
+ # Analyse couleur
114
  try:
 
115
  with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
116
  image.save(tmp, format='JPEG')
117
  tmp_path = tmp.name
118
 
 
119
  color_thief = colorthief.ColorThief(tmp_path)
120
  dominant_color = color_thief.get_color(quality=1)
121
  hex_color = '#%02x%02x%02x' % dominant_color
122
 
 
123
  os.unlink(tmp_path)
124
 
125
  except Exception as color_error:
126
  print(f"Erreur analyse couleur: {color_error}")
127
+ hex_color = "#000000"
128
 
129
  return {
130
  "category": category_name,
 
135
  except Exception as e:
136
  return {"error": f"Erreur lors de l'analyse: {str(e)}"}
137
 
138
+ # Interface de test
139
  @app.get("/test-ui", response_class=HTMLResponse)
140
  async def test_ui():
141
  return """
 
159
  <br>
160
  <input type="submit" value="Analyser l'image 👗">
161
  </form>
 
 
 
 
 
 
 
162
  </div>
163
  </body>
164
  </html>