Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -32,7 +32,6 @@ processor = None
|
|
| 32 |
def load_marqo_model():
|
| 33 |
global model, processor
|
| 34 |
try:
|
| 35 |
-
# Import différé pour éviter les problèmes de compatibilité
|
| 36 |
from transformers import CLIPProcessor, CLIPModel
|
| 37 |
|
| 38 |
model_name = "Marqo/marqo-fashionCLIP"
|
|
@@ -53,10 +52,10 @@ async def startup_event():
|
|
| 53 |
thread.daemon = True
|
| 54 |
thread.start()
|
| 55 |
|
| 56 |
-
# Catégories fashion simplifiées
|
| 57 |
categories = [
|
| 58 |
-
"
|
| 59 |
-
"
|
| 60 |
]
|
| 61 |
|
| 62 |
@app.get("/")
|
|
@@ -84,43 +83,48 @@ async def analyze_image(file: UploadFile = File(...)):
|
|
| 84 |
# Réduire la taille
|
| 85 |
image.thumbnail((384, 384))
|
| 86 |
|
| 87 |
-
#
|
|
|
|
| 88 |
inputs = processor(
|
| 89 |
-
text=categories,
|
| 90 |
-
images=image,
|
| 91 |
-
return_tensors="pt",
|
| 92 |
-
padding=True,
|
| 93 |
-
truncation=True
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
with torch.no_grad():
|
| 97 |
outputs = model(**inputs)
|
| 98 |
|
|
|
|
| 99 |
logits_per_image = outputs.logits_per_image
|
| 100 |
-
probs =
|
| 101 |
|
| 102 |
predicted_class_idx = probs.argmax(dim=1).item()
|
| 103 |
category_name = categories[predicted_class_idx]
|
| 104 |
confidence_score = probs[0][predicted_class_idx].item()
|
| 105 |
|
| 106 |
-
#
|
| 107 |
try:
|
| 108 |
-
# Sauvegarder l'image dans un fichier temporaire pour ColorThief
|
| 109 |
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
|
| 110 |
image.save(tmp, format='JPEG')
|
| 111 |
tmp_path = tmp.name
|
| 112 |
|
| 113 |
-
# Utiliser ColorThief avec le fichier temporaire
|
| 114 |
color_thief = colorthief.ColorThief(tmp_path)
|
| 115 |
dominant_color = color_thief.get_color(quality=1)
|
| 116 |
hex_color = '#%02x%02x%02x' % dominant_color
|
| 117 |
|
| 118 |
-
# Nettoyer le fichier temporaire
|
| 119 |
os.unlink(tmp_path)
|
| 120 |
|
| 121 |
except Exception as color_error:
|
| 122 |
print(f"Erreur analyse couleur: {color_error}")
|
| 123 |
-
hex_color = "#000000"
|
| 124 |
|
| 125 |
return {
|
| 126 |
"category": category_name,
|
|
@@ -131,7 +135,7 @@ async def analyze_image(file: UploadFile = File(...)):
|
|
| 131 |
except Exception as e:
|
| 132 |
return {"error": f"Erreur lors de l'analyse: {str(e)}"}
|
| 133 |
|
| 134 |
-
# Interface
|
| 135 |
@app.get("/test-ui", response_class=HTMLResponse)
|
| 136 |
async def test_ui():
|
| 137 |
return """
|
|
@@ -155,13 +159,6 @@ async def test_ui():
|
|
| 155 |
<br>
|
| 156 |
<input type="submit" value="Analyser l'image 👗">
|
| 157 |
</form>
|
| 158 |
-
|
| 159 |
-
<div style="margin-top: 30px; padding: 20px; background: #f8f9fa;">
|
| 160 |
-
<h3>📝 Instructions :</h3>
|
| 161 |
-
<p>• Uploader une image claire d'un vêtement</p>
|
| 162 |
-
<p>• Formats supportés : JPG, PNG, WebP</p>
|
| 163 |
-
<p>• Taille recommandée : moins de 2MB</p>
|
| 164 |
-
</div>
|
| 165 |
</div>
|
| 166 |
</body>
|
| 167 |
</html>
|
|
|
|
| 32 |
def load_marqo_model():
|
| 33 |
global model, processor
|
| 34 |
try:
|
|
|
|
| 35 |
from transformers import CLIPProcessor, CLIPModel
|
| 36 |
|
| 37 |
model_name = "Marqo/marqo-fashionCLIP"
|
|
|
|
| 52 |
thread.daemon = True
|
| 53 |
thread.start()
|
| 54 |
|
| 55 |
+
# Catégories fashion simplifiées (moins de texte pour éviter les problèmes de padding)
|
| 56 |
categories = [
|
| 57 |
+
"t-shirt", "dress", "jeans", "shirt", "skirt", "sneakers",
|
| 58 |
+
"handbag", "jacket", "shorts", "sweater", "coat", "high heels"
|
| 59 |
]
|
| 60 |
|
| 61 |
@app.get("/")
|
|
|
|
| 83 |
# Réduire la taille
|
| 84 |
image.thumbnail((384, 384))
|
| 85 |
|
| 86 |
+
# --- CORRECTION DU PADDING ---
|
| 87 |
+
# Préparer les inputs correctement avec padding et truncation
|
| 88 |
inputs = processor(
|
| 89 |
+
text=categories,
|
| 90 |
+
images=image,
|
| 91 |
+
return_tensors="pt",
|
| 92 |
+
padding=True, # ← PADDING ACTIVÉ
|
| 93 |
+
truncation=True, # ← TRUNCATION ACTIVÉE
|
| 94 |
+
max_length=77, # ← LONGUEUR MAXIMALE POUR CLIP
|
| 95 |
+
return_overflowing_tokens=False
|
| 96 |
)
|
| 97 |
|
| 98 |
+
# Déplacer sur le même device que le modèle
|
| 99 |
+
device = next(model.parameters()).device
|
| 100 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 101 |
+
|
| 102 |
with torch.no_grad():
|
| 103 |
outputs = model(**inputs)
|
| 104 |
|
| 105 |
+
# Récupérer les logits et calculer les probabilités
|
| 106 |
logits_per_image = outputs.logits_per_image
|
| 107 |
+
probs = torch.nn.functional.softmax(logits_per_image, dim=1)
|
| 108 |
|
| 109 |
predicted_class_idx = probs.argmax(dim=1).item()
|
| 110 |
category_name = categories[predicted_class_idx]
|
| 111 |
confidence_score = probs[0][predicted_class_idx].item()
|
| 112 |
|
| 113 |
+
# Analyse couleur
|
| 114 |
try:
|
|
|
|
| 115 |
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmp:
|
| 116 |
image.save(tmp, format='JPEG')
|
| 117 |
tmp_path = tmp.name
|
| 118 |
|
|
|
|
| 119 |
color_thief = colorthief.ColorThief(tmp_path)
|
| 120 |
dominant_color = color_thief.get_color(quality=1)
|
| 121 |
hex_color = '#%02x%02x%02x' % dominant_color
|
| 122 |
|
|
|
|
| 123 |
os.unlink(tmp_path)
|
| 124 |
|
| 125 |
except Exception as color_error:
|
| 126 |
print(f"Erreur analyse couleur: {color_error}")
|
| 127 |
+
hex_color = "#000000"
|
| 128 |
|
| 129 |
return {
|
| 130 |
"category": category_name,
|
|
|
|
| 135 |
except Exception as e:
|
| 136 |
return {"error": f"Erreur lors de l'analyse: {str(e)}"}
|
| 137 |
|
| 138 |
+
# Interface de test
|
| 139 |
@app.get("/test-ui", response_class=HTMLResponse)
|
| 140 |
async def test_ui():
|
| 141 |
return """
|
|
|
|
| 159 |
<br>
|
| 160 |
<input type="submit" value="Analyser l'image 👗">
|
| 161 |
</form>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
</div>
|
| 163 |
</body>
|
| 164 |
</html>
|