MODLI committed on
Commit
c3b9201
·
verified ·
1 Parent(s): 7404709

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -38
app.py CHANGED
@@ -10,25 +10,53 @@ from fastapi import FastAPI, HTTPException
10
  from pydantic import BaseModel
11
  from typing import Optional
12
  import uvicorn
 
 
 
 
13
 
14
- # Catégories optimisées
15
  FASHION_CATEGORIES = [
16
- "t-shirt", "dress", "pants", "jacket", "skirt",
17
- "shoes", "bag", "swimwear", "lingerie", "sweater",
18
- "jeans", "coat", "shorts", "blouse", "hat", "top",
19
- "jogging pants", "dress pants", "leggings", "boots",
20
- "sandals", "sneakers", "backpack", "glasses"
 
 
 
 
 
21
  ]
22
 
23
  print("🔧 Loading fashion model...")
24
 
25
- # Modèle principal
26
- fashion_pipe = pipeline(
27
- "zero-shot-image-classification",
28
- model="openai/clip-vit-base-patch32"
29
- )
30
-
31
- print("✅ Model loaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  # Configuration API
34
  API_KEYS = os.environ.get("API_KEYS", "").split(",")
@@ -38,13 +66,30 @@ class ClassificationRequest(BaseModel):
38
  image_data: str
39
  api_key: Optional[str] = None
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  def load_image_from_url(url):
42
  """Charge une image depuis une URL de manière robuste"""
43
  try:
44
  headers = {
45
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
46
  }
47
- response = requests.get(url, headers=headers, timeout=10)
48
  response.raise_for_status()
49
 
50
  # Vérifie que c'est bien une image
@@ -52,7 +97,7 @@ def load_image_from_url(url):
52
  raise ValueError("URL does not point to an image")
53
 
54
  image = Image.open(BytesIO(response.content))
55
- return image.convert('RGB')
56
 
57
  except Exception as e:
58
  raise ValueError(f"❌ Cannot load image from URL: {str(e)}")
@@ -67,46 +112,59 @@ def analyze_fashion_item(image_input, url_input):
67
  image = Image.fromarray(image_input)
68
  else:
69
  image = image_input
 
70
  elif url_input and url_input.strip():
71
  # Utilise l'URL
72
  image = load_image_from_url(url_input.strip())
73
  else:
74
  return "❌ Please upload an image or enter a URL first", None
75
 
76
- # Redimensionnement intelligent
77
- width, height = image.size
78
- if max(width, height) > 1024:
79
- ratio = 1024 / max(width, height)
80
- new_size = (int(width * ratio), int(height * ratio))
81
- image = image.resize(new_size, Image.Resampling.LANCZOS)
82
-
83
- # 🔥 ANALYSE PRINCIPALE
84
- predictions = fashion_pipe(
85
- image,
86
- candidate_labels=FASHION_CATEGORIES,
87
- hypothesis_template="a clear photo of {}",
88
- multi_label=False
89
- )
90
-
91
- # Filtrage des résultats
92
- confident_predictions = [p for p in predictions if p['score'] > 0.1]
 
 
 
 
 
 
93
 
94
  if not confident_predictions:
95
- return "❌ No confident prediction. Try a clearer image.", image
 
 
 
96
 
97
  best_pred = confident_predictions[0]
98
 
99
  # Formatage des résultats
100
- result_text = f"🎯 **Main item**: {best_pred['label']}\n"
101
  result_text += f"**Confidence**: {best_pred['score']*100:.1f}%\n\n"
102
 
103
  if len(confident_predictions) > 1:
104
  result_text += "**Other possibilities**:\n"
105
- for i, pred in enumerate(confident_predictions[1:4], 1):
106
- result_text += f"{i}. {pred['label']} ({pred['score']*100:.1f}%)\n"
107
 
108
- result_text += f"\n💡 **Tip**: This appears to be {best_pred['label']}. "
109
- result_text += "Make sure the item is well-lit and centered."
 
 
 
110
 
111
  return result_text, image
112
 
@@ -128,6 +186,9 @@ with gr.Blocks(
128
  .header { text-align: center; margin-bottom: 30px; }
129
  .input-section { background: #f8f9fa; padding: 20px; border-radius: 10px; }
130
  .output-section { background: white; padding: 20px; border-radius: 10px; }
 
 
 
131
  """
132
  ) as demo:
133
 
@@ -153,6 +214,14 @@ with gr.Blocks(
153
  lines=2
154
  )
155
 
 
 
 
 
 
 
 
 
156
  analyze_btn = gr.Button(
157
  "🔍 Analyze Item",
158
  variant="primary",
@@ -178,6 +247,7 @@ with gr.Blocks(
178
  - Make sure the clothing item is clearly visible
179
  - Well-lit images work best
180
  - Avoid busy backgrounds
 
181
  """)
182
 
183
  # Événement de click
@@ -210,6 +280,7 @@ async def api_classify(request: ClassificationRequest):
210
 
211
  image_bytes = base64.b64decode(request.image_data)
212
  image = Image.open(BytesIO(image_bytes))
 
213
 
214
  # Analyse avec des inputs vides pour URL
215
  result_text, processed_image = analyze_fashion_item(image, "")
 
10
  from pydantic import BaseModel
11
  from typing import Optional
12
  import uvicorn
13
+ import torch
14
+ import torchvision.transforms as transforms
15
+ from torchvision.models import resnet50
16
+ import torch.nn as nn
17
 
18
+ # Catégories fashion plus détaillées et précises
19
  FASHION_CATEGORIES = [
20
+ "t-shirt", "dress", "jeans", "jacket", "skirt",
21
+ "sneakers", "handbag", "swimsuit", "lingerie", "sweater",
22
+ "coat", "shorts", "blouse", "hat", "top",
23
+ "sweatpants", "dress pants", "leggings", "boots",
24
+ "sandals", "heels", "backpack", "sunglasses", "blazer",
25
+ "cardigan", "polo shirt", "hoodie", "vest", "jumpsuit",
26
+ "romper", "crop top", "tank top", "long sleeve shirt",
27
+ "windbreaker", "parka", "trench coat", "leather jacket",
28
+ "denim jacket", "waistcoat", "suit", "tie", "scarf",
29
+ "gloves", "belt", "wallet", "watch", "jewelry"
30
  ]
31
 
32
  print("🔧 Loading fashion model...")
33
 
34
def _load_fashion_pipeline():
    """Build the classification pipeline, falling back to more general models.

    Candidates are tried in order: a fashion-specific checkpoint first, then
    CLIP Large. If both fail, CLIP Base is loaded as a last resort and any
    error it raises propagates to the caller, since no further fallback
    exists.

    Returns:
        The first pipeline object that loads successfully.
    """
    # Same device selection for every candidate: first GPU if CUDA is
    # available, otherwise CPU (-1).
    device = 0 if torch.cuda.is_available() else -1

    candidates = (
        (
            "image-classification",
            "nateraw/fashion-clip",
            "✅ Fashion-CLIP model loaded successfully!",
        ),
        (
            "zero-shot-image-classification",
            "openai/clip-vit-large-patch14",
            "✅ CLIP Large model loaded successfully!",
        ),
    )

    for task, model_name, success_message in candidates:
        try:
            loaded = pipeline(task, model=model_name, device=device)
        except Exception as load_error:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit still stop the process, and report
            # why this candidate was skipped instead of hiding the failure.
            print(f"⚠️ Could not load {model_name}: {load_error}")
        else:
            print(success_message)
            return loaded

    # Last resort: deliberately not wrapped in try/except.
    loaded = pipeline(
        "zero-shot-image-classification",
        model="openai/clip-vit-base-patch32",
        device=device,
    )
    print("✅ CLIP Base model loaded as fallback!")
    return loaded


fashion_pipe = _load_fashion_pipeline()
60
 
61
  # Configuration API
62
  API_KEYS = os.environ.get("API_KEYS", "").split(",")
 
66
  image_data: str
67
  api_key: Optional[str] = None
68
 
69
def preprocess_image(image, max_size=512):
    """Convert an image to RGB and cap its longest side at ``max_size`` pixels.

    Args:
        image: A PIL-style image exposing ``mode``, ``size``, ``convert`` and
            ``resize``.
        max_size: Upper bound, in pixels, for the longer side. Defaults to 512.

    Returns:
        The input object itself when it is already RGB and within the size
        limit; otherwise the converted and/or resized image.
    """
    # Convert to RGB only when needed.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Downscale while keeping the aspect ratio; images are never upscaled.
    width, height = image.size
    longest_side = max(width, height)

    if longest_side > max_size:
        ratio = max_size / longest_side
        # int() truncates, so a very thin image (e.g. 10000x1) would get a
        # 0-pixel dimension, which is not a valid resize target. Clamp to 1.
        new_size = (
            max(1, int(width * ratio)),
            max(1, int(height * ratio)),
        )
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image
85
+
86
  def load_image_from_url(url):
87
  """Charge une image depuis une URL de manière robuste"""
88
  try:
89
  headers = {
90
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
91
  }
92
+ response = requests.get(url, headers=headers, timeout=15)
93
  response.raise_for_status()
94
 
95
  # Vérifie que c'est bien une image
 
97
  raise ValueError("URL does not point to an image")
98
 
99
  image = Image.open(BytesIO(response.content))
100
+ return preprocess_image(image)
101
 
102
  except Exception as e:
103
  raise ValueError(f"❌ Cannot load image from URL: {str(e)}")
 
112
  image = Image.fromarray(image_input)
113
  else:
114
  image = image_input
115
+ image = preprocess_image(image)
116
  elif url_input and url_input.strip():
117
  # Utilise l'URL
118
  image = load_image_from_url(url_input.strip())
119
  else:
120
  return "❌ Please upload an image or enter a URL first", None
121
 
122
+ # 🔥 ANALYSE PRINCIPALE AVEC PARAMÈTRES OPTIMISÉS
123
+ try:
124
+ # Essayer d'abord avec le modèle fashion-clip
125
+ predictions = fashion_pipe(image)
126
+
127
+ # Si c'est le modèle fashion-clip, adapter le format de réponse
128
+ if hasattr(fashion_pipe, 'model') and 'fashion-clip' in str(fashion_pipe.model):
129
+ # Trier par score et formater
130
+ predictions = sorted(predictions, key=lambda x: x['score'], reverse=True)
131
+ confident_predictions = [p for p in predictions if p['score'] > 0.05]
132
+ else:
133
+ # Pour les modèles zero-shot
134
+ predictions = fashion_pipe(
135
+ image,
136
+ candidate_labels=FASHION_CATEGORIES,
137
+ hypothesis_template="a clear photo of {}",
138
+ multi_label=True
139
+ )
140
+ confident_predictions = [p for p in predictions if p['score'] > 0.1]
141
+
142
+ except Exception as model_error:
143
+ print(f"Model error: {model_error}")
144
+ return "❌ Model analysis failed. Please try another image.", image
145
 
146
  if not confident_predictions:
147
+ return "❌ No confident prediction. Try a clearer image with better lighting.", image
148
+
149
+ # Trier par score décroissant
150
+ confident_predictions.sort(key=lambda x: x['score'], reverse=True)
151
 
152
  best_pred = confident_predictions[0]
153
 
154
  # Formatage des résultats
155
+ result_text = f"🎯 **Main item**: {best_pred['label'].title()}\n"
156
  result_text += f"**Confidence**: {best_pred['score']*100:.1f}%\n\n"
157
 
158
  if len(confident_predictions) > 1:
159
  result_text += "**Other possibilities**:\n"
160
+ for i, pred in enumerate(confident_predictions[1:6], 1): # Top 5 seulement
161
+ result_text += f"{i}. {pred['label'].title()} ({pred['score']*100:.1f}%)\n"
162
 
163
+ # Conseils basés sur la confiance
164
+ if best_pred['score'] < 0.7:
165
+ result_text += f"\n💡 **Tip**: Low confidence. Try a clearer image with the item centered and good lighting."
166
+ else:
167
+ result_text += f"\n✅ **High confidence detection**: This is very likely a {best_pred['label']}."
168
 
169
  return result_text, image
170
 
 
186
  .header { text-align: center; margin-bottom: 30px; }
187
  .input-section { background: #f8f9fa; padding: 20px; border-radius: 10px; }
188
  .output-section { background: white; padding: 20px; border-radius: 10px; }
189
+ .success { color: green; }
190
+ .warning { color: orange; }
191
+ .error { color: red; }
192
  """
193
  ) as demo:
194
 
 
214
  lines=2
215
  )
216
 
217
+ gr.Markdown("""
218
+ **📝 Tips for better results:**
219
+ - Use clear, well-lit images
220
+ - Center the clothing item
221
+ - Use plain backgrounds when possible
222
+ - Avoid multiple items in one image
223
+ """)
224
+
225
  analyze_btn = gr.Button(
226
  "🔍 Analyze Item",
227
  variant="primary",
 
247
  - Make sure the clothing item is clearly visible
248
  - Well-lit images work best
249
  - Avoid busy backgrounds
250
+ - For best results, show one item at a time
251
  """)
252
 
253
  # Événement de click
 
280
 
281
  image_bytes = base64.b64decode(request.image_data)
282
  image = Image.open(BytesIO(image_bytes))
283
+ image = preprocess_image(image)
284
 
285
  # Analyse avec des inputs vides pour URL
286
  result_text, processed_image = analyze_fashion_item(image, "")