MODLI commited on
Commit
b6eb828
·
verified ·
1 Parent(s): 9f55257

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -27
app.py CHANGED
@@ -52,10 +52,10 @@ async def startup_event():
52
  thread.daemon = True
53
  thread.start()
54
 
55
- # Catégories fashion simplifiées (moins de texte pour éviter les problèmes de padding)
56
  categories = [
57
  "t-shirt", "dress", "jeans", "shirt", "skirt", "sneakers",
58
- "handbag", "jacket", "shorts", "sweater", "coat", "high heels"
59
  ]
60
 
61
  @app.get("/")
@@ -83,32 +83,39 @@ async def analyze_image(file: UploadFile = File(...)):
83
  # Réduire la taille
84
  image.thumbnail((384, 384))
85
 
86
- # --- CORRECTION DU PADDING ---
87
- # Préparer les inputs correctement avec padding et truncation
88
- inputs = processor(
89
- text=categories,
90
- images=image,
91
- return_tensors="pt",
92
- padding=True, # ← PADDING ACTIVÉ
93
- truncation=True, # ← TRUNCATION ACTIVÉE
94
- max_length=77, # ← LONGUEUR MAXIMALE POUR CLIP
95
- return_overflowing_tokens=False
96
- )
97
-
98
- # Déplacer sur le même device que le modèle
99
- device = next(model.parameters()).device
100
- inputs = {k: v.to(device) for k, v in inputs.items()}
101
 
102
- with torch.no_grad():
103
- outputs = model(**inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- # Récupérer les logits et calculer les probabilités
106
- logits_per_image = outputs.logits_per_image
107
- probs = torch.nn.functional.softmax(logits_per_image, dim=1)
108
 
109
- predicted_class_idx = probs.argmax(dim=1).item()
 
110
  category_name = categories[predicted_class_idx]
111
- confidence_score = probs[0][predicted_class_idx].item()
112
 
113
  # Analyse couleur
114
  try:
@@ -147,18 +154,27 @@ async def test_ui():
147
  .container { max-width: 600px; margin: 0 auto; }
148
  form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
149
  input[type="file"] { margin: 10px 0; }
150
- input[type="submit"] { background: #007bff; color: white; padding: 10px 20px; border: none; cursor: pointer; }
 
 
 
 
151
  </style>
152
  </head>
153
  <body>
154
  <div class="container">
155
- <h1>🎨 Test Fashion Detection</h1>
156
  <form action="/analyze" method="post" enctype="multipart/form-data">
157
  <h3>Uploader une image de vêtement :</h3>
158
  <input type="file" name="file" accept="image/*" required>
159
- <br>
160
  <input type="submit" value="Analyser l'image 👗">
161
  </form>
 
 
 
 
 
162
  </div>
163
  </body>
164
  </html>
 
52
  thread.daemon = True
53
  thread.start()
54
 
55
+ # Catégories fashion (textes plus courts et uniformes)
56
  categories = [
57
  "t-shirt", "dress", "jeans", "shirt", "skirt", "sneakers",
58
+ "handbag", "jacket", "shorts", "sweater", "coat", "heels"
59
  ]
60
 
61
  @app.get("/")
 
83
  # Réduire la taille
84
  image.thumbnail((384, 384))
85
 
86
+ # --- SOLUTION DÉFINITIVE ---
87
+ # Traiter chaque catégorie SÉPARÉMENT pour éviter les problèmes de padding
88
+ similarities = []
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ for category in categories:
91
+ # Préparer les inputs pour UNE catégorie à la fois
92
+ inputs = processor(
93
+ text=[category], # Une seule catégorie
94
+ images=image,
95
+ return_tensors="pt",
96
+ padding=True, # Padding pour une seule phrase
97
+ truncation=True
98
+ )
99
+
100
+ # Déplacer sur le device du modèle
101
+ device = next(model.parameters()).device
102
+ inputs = {k: v.to(device) for k, v in inputs.items()}
103
+
104
+ with torch.no_grad():
105
+ outputs = model(**inputs)
106
+
107
+ # Récupérer le score de similarité
108
+ similarity_score = outputs.logits_per_image.item()
109
+ similarities.append(similarity_score)
110
 
111
+ # Convertir en tensor et calculer les probabilités
112
+ similarities_tensor = torch.tensor(similarities)
113
+ probs = torch.nn.functional.softmax(similarities_tensor, dim=0)
114
 
115
+ # Trouver la catégorie prédite
116
+ predicted_class_idx = probs.argmax().item()
117
  category_name = categories[predicted_class_idx]
118
+ confidence_score = probs[predicted_class_idx].item()
119
 
120
  # Analyse couleur
121
  try:
 
154
  .container { max-width: 600px; margin: 0 auto; }
155
  form { border: 2px dashed #ccc; padding: 30px; text-align: center; }
156
  input[type="file"] { margin: 10px 0; }
157
+ input[type="submit"] {
158
+ background: #007bff; color: white; padding: 10px 20px;
159
+ border: none; cursor: pointer; border-radius: 5px;
160
+ }
161
+ .result { margin-top: 20px; padding: 20px; background: #f0f8ff; }
162
  </style>
163
  </head>
164
  <body>
165
  <div class="container">
166
+ <h1>🎨 Fashion Detection AI</h1>
167
  <form action="/analyze" method="post" enctype="multipart/form-data">
168
  <h3>Uploader une image de vêtement :</h3>
169
  <input type="file" name="file" accept="image/*" required>
170
+ <br><br>
171
  <input type="submit" value="Analyser l'image 👗">
172
  </form>
173
+
174
+ <div class="result">
175
+ <h3>📋 Résultat de l'analyse :</h3>
176
+ <p>Attendez l'upload et le traitement de l'image...</p>
177
+ </div>
178
  </div>
179
  </body>
180
  </html>