dzmu commited on
Commit
76d5e1f
·
verified ·
1 Parent(s): aad2489

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -60
app.py CHANGED
@@ -25,7 +25,6 @@ try:
25
  except Exception as e:
26
  print(f"Error loading CLIP model: {e}")
27
  # Handle error
28
-
29
  try:
30
  yolo_person_model = YOLO(YOLO_PERSON_MODEL_PATH).to(DEVICE)
31
  print(f"YOLO person detection model ({YOLO_PERSON_MODEL_PATH}) loaded successfully.")
@@ -34,13 +33,6 @@ except Exception as e:
34
  # Handle error
35
 
36
  # REMOVED Fashion Model Loading
37
- # try:
38
- # fashion_model = YOLO(YOLO_FASHION_MODEL_PATH).to(DEVICE)
39
- # print(f"YOLO fashion model ({YOLO_FASHION_MODEL_PATH}) loaded successfully.")
40
- # if not hasattr(fashion_model, 'names') or not fashion_model.names:
41
- # print("Warning: Fashion model names not found.")
42
- # except Exception as e:
43
- # print(f"Error loading YOLO fashion model: {e}")
44
 
45
  # --- Prompts and Responses ---
46
  style_prompts = {
@@ -70,6 +62,7 @@ clothing_prompts = [
70
  all_prompts = []
71
  for cat_prompts in style_prompts.values():
72
  all_prompts.extend(cat_prompts)
 
73
  # Record end of style prompts before adding clothing prompts
74
  style_prompts_end_index = len(all_prompts)
75
  all_prompts.extend(clothing_prompts)
@@ -89,41 +82,31 @@ response_templates = {
89
  "Never walk out the house again with that {item}."
90
  ]
91
  }
92
-
93
  CATEGORY_LABEL_MAP = { "drippy": "drippy", "mid": "mid", "not_drippy": "trash" }
94
 
95
  # --- REINSTATED: Function to get top clothing items based on CLIP probabilities ---
96
  def get_top_clothing(probs, n=3):
97
  """Gets the top N clothing items based on CLIP probabilities."""
98
- # Calculate the start index of clothing probabilities in the combined 'probs' array
99
  clothing_probs_start_index = style_prompts_end_index
100
  clothing_probs = probs[clothing_probs_start_index:]
101
-
102
- # Ensure we don't request more items than available prompts
103
  actual_n = min(n, len(clothing_prompts))
104
  if actual_n <= 0:
105
- return ["item"] # Return default if no clothing prompts
106
-
107
- # Get indices of top N probabilities within the clothing_probs slice
108
  top_indices_in_slice = np.argsort(clothing_probs)[-actual_n:]
109
-
110
- # Return the corresponding clothing prompt names in descending order of probability
111
  return [clothing_prompts[i] for i in reversed(top_indices_in_slice)]
112
 
113
-
114
  # --- Core Logic ---
115
  def analyze_outfit(input_img: Image.Image):
116
  if input_img is None:
117
- return "Please upload an image.", None, "Error: No image provided."
 
118
 
119
  img = input_img.copy()
120
-
121
- # 1) YOLO Person Detection (Same as before)
122
  person_results = yolo_person_model(img, verbose=False)
123
  boxes = person_results[0].boxes.xyxy.cpu().numpy()
124
  classes = person_results[0].boxes.cls.cpu().numpy()
125
  confidences = person_results[0].boxes.conf.cpu().numpy()
126
-
127
  person_indices = np.where(classes == 0)[0]
128
  cropped_img = img
129
  if len(person_indices) > 0:
@@ -139,32 +122,23 @@ def analyze_outfit(input_img: Image.Image):
139
  cropped_img = img
140
  else:
141
  print("No person detected by yolo_person_model. Analyzing full image.")
142
- # Decide if you want to proceed or return an error
143
-
144
- # --- REMOVED: YOLO Fashion Detection ---
145
 
146
- # 2) CLIP Analysis (Using ALL prompts - Style + Clothing)
147
- detected_clothing_item = "look" # Default if something goes wrong
148
  try:
149
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(DEVICE)
150
- # --- Use all_prompts for tokenization ---
151
  text_tokens = clip.tokenize(all_prompts).to(DEVICE)
152
 
153
  with torch.no_grad():
154
  logits, _ = clip_model(image_tensor, text_tokens)
155
- # --- Probabilities for ALL prompts ---
156
  all_probs = logits.softmax(dim=-1).cpu().numpy()[0]
157
 
158
- # Calculate average scores for each style category based on their slices in all_probs
159
  drip_len = len(style_prompts['drippy'])
160
  mid_len = len(style_prompts['mid'])
161
- # not_len = len(style_prompts['not_drippy']) # Calculated implicitly below
162
-
163
  drip_score = np.mean(all_probs[0 : drip_len])
164
  mid_score = np.mean(all_probs[drip_len : drip_len + mid_len])
165
- not_score = np.mean(all_probs[drip_len + mid_len : style_prompts_end_index]) # Scores up to end of style prompts
166
 
167
- # Determine the category based on highest average score
168
  if drip_score > mid_score and drip_score > not_score:
169
  category_key = 'drippy'
170
  final_score = drip_score
@@ -179,63 +153,244 @@ def analyze_outfit(input_img: Image.Image):
179
  final_score_str = f"{final_score:.2f}"
180
  print(f"Style analysis: Category={category_label}, Score={final_score_str}")
181
 
182
- # --- REINSTATED: Get clothing item using CLIP probs ---
183
- clothing_items_detected_by_clip = get_top_clothing(all_probs, n=1) # Get top 1 item
184
  if clothing_items_detected_by_clip:
185
  detected_clothing_item = clothing_items_detected_by_clip[0]
186
  print(f"Top clothing item identified by CLIP: {detected_clothing_item}")
187
  else:
188
  print("Warning: CLIP did not identify a top clothing item.")
189
- detected_clothing_item = "fit" # Fallback if get_top_clothing fails
190
-
191
 
192
  except Exception as e:
193
  print(f"Error during CLIP analysis or clothing selection: {e}")
194
- return "Error during analysis.", None, f"Analysis Error: {e}"
 
195
 
196
- # 3) Generate Response and TTS (Same as before, but uses item from CLIP)
197
  try:
198
  response_text = random.choice(response_templates[category_key]).format(item=detected_clothing_item)
199
-
200
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
201
  tts = gTTS(text=response_text, lang='en', tld='com', slow=False)
202
  tts.save(tts_path)
203
  print(f"Generated TTS response: '{response_text}' saved to {tts_path}")
204
 
 
 
205
  category_html = f"""
206
- <div style='text-align: center; padding: 15px; border: 1px solid #eee; border-radius: 8px;'>
207
- <h2 style='color: #333; margin-bottom: 5px;'>Your fit is {category_label.upper()}!</h2>
208
- <p style='font-size: 1.1em; color: #555; margin-top: 0;'>Style Score: {final_score_str}</p>
209
  </div>
210
  """
211
  return category_html, tts_path, response_text
212
 
213
  except Exception as e:
214
  print(f"Error during response/TTS generation: {e}")
215
- category_html = f"<h2>Result: {category_label} (Score: {final_score_str})</h2>"
 
 
 
 
 
216
  return category_html, None, f"Analysis complete ({category_label}), but error generating audio/response."
217
 
218
 
219
- # --- Gradio Interface (Unchanged) ---
220
- with gr.Blocks(css=".gradio-container { max-width: 800px !important; margin: auto !important; } footer { display: none !important; }") as demo:
221
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>💧 DripAI: Rate Your Fit 💧</h1>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  with gr.Row():
223
- with gr.Column(scale=1):
224
  input_image = gr.Image(
225
- type='pil', label="Upload, Paste, or Use Webcam for your Outfit Photo",
226
- sources=['upload', 'webcam', 'clipboard'], height=400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  )
228
- analyze_button = gr.Button("Analyze Outfit", variant="primary", size="lg")
229
- with gr.Column(scale=1):
230
- gr.Markdown("### Analysis Result:")
231
- category_html = gr.HTML(label="Category & Score")
232
- audio_output = gr.Audio(autoplay=True, label="Audio Feedback", streaming=False)
233
- response_box = gr.Textbox(lines=4, label="Text Feedback", interactive=False)
 
234
  analyze_button.click(
235
- fn=analyze_outfit, inputs=[input_image], outputs=[category_html, audio_output, response_box]
 
 
236
  )
237
- gr.Markdown("<p style='text-align: center; color: grey; font-size: 0.9em;'>Upload an image of your outfit and click 'Analyze Outfit'. DripAI will rate your style and identify a key clothing item.</p>")
 
 
238
 
239
  # --- Launch App ---
240
  if __name__ == "__main__":
241
- demo.launch(debug=True) # Assumes debug is helpful on HF too, might remove later
 
25
  except Exception as e:
26
  print(f"Error loading CLIP model: {e}")
27
  # Handle error
 
28
  try:
29
  yolo_person_model = YOLO(YOLO_PERSON_MODEL_PATH).to(DEVICE)
30
  print(f"YOLO person detection model ({YOLO_PERSON_MODEL_PATH}) loaded successfully.")
 
33
  # Handle error
34
 
35
  # REMOVED Fashion Model Loading
 
 
 
 
 
 
 
36
 
37
  # --- Prompts and Responses ---
38
  style_prompts = {
 
62
  all_prompts = []
63
  for cat_prompts in style_prompts.values():
64
  all_prompts.extend(cat_prompts)
65
+
66
  # Record end of style prompts before adding clothing prompts
67
  style_prompts_end_index = len(all_prompts)
68
  all_prompts.extend(clothing_prompts)
 
82
  "Never walk out the house again with that {item}."
83
  ]
84
  }
 
85
  CATEGORY_LABEL_MAP = { "drippy": "drippy", "mid": "mid", "not_drippy": "trash" }
86
 
87
  # --- REINSTATED: Function to get top clothing items based on CLIP probabilities ---
88
  def get_top_clothing(probs, n=3):
89
  """Gets the top N clothing items based on CLIP probabilities."""
 
90
  clothing_probs_start_index = style_prompts_end_index
91
  clothing_probs = probs[clothing_probs_start_index:]
 
 
92
  actual_n = min(n, len(clothing_prompts))
93
  if actual_n <= 0:
94
+ return ["item"]
 
 
95
  top_indices_in_slice = np.argsort(clothing_probs)[-actual_n:]
 
 
96
  return [clothing_prompts[i] for i in reversed(top_indices_in_slice)]
97
 
 
98
  # --- Core Logic ---
99
  def analyze_outfit(input_img: Image.Image):
100
  if input_img is None:
101
+ return ("<p style='color: #FF5555; text-align: center;'>Please upload an image.</p>",
102
+ None, "Error: No image provided.")
103
 
104
  img = input_img.copy()
105
+ # 1) YOLO Person Detection
 
106
  person_results = yolo_person_model(img, verbose=False)
107
  boxes = person_results[0].boxes.xyxy.cpu().numpy()
108
  classes = person_results[0].boxes.cls.cpu().numpy()
109
  confidences = person_results[0].boxes.conf.cpu().numpy()
 
110
  person_indices = np.where(classes == 0)[0]
111
  cropped_img = img
112
  if len(person_indices) > 0:
 
122
  cropped_img = img
123
  else:
124
  print("No person detected by yolo_person_model. Analyzing full image.")
 
 
 
125
 
126
+ # 2) CLIP Analysis
127
+ detected_clothing_item = "look"
128
  try:
129
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(DEVICE)
 
130
  text_tokens = clip.tokenize(all_prompts).to(DEVICE)
131
 
132
  with torch.no_grad():
133
  logits, _ = clip_model(image_tensor, text_tokens)
 
134
  all_probs = logits.softmax(dim=-1).cpu().numpy()[0]
135
 
 
136
  drip_len = len(style_prompts['drippy'])
137
  mid_len = len(style_prompts['mid'])
 
 
138
  drip_score = np.mean(all_probs[0 : drip_len])
139
  mid_score = np.mean(all_probs[drip_len : drip_len + mid_len])
140
+ not_score = np.mean(all_probs[drip_len + mid_len : style_prompts_end_index])
141
 
 
142
  if drip_score > mid_score and drip_score > not_score:
143
  category_key = 'drippy'
144
  final_score = drip_score
 
153
  final_score_str = f"{final_score:.2f}"
154
  print(f"Style analysis: Category={category_label}, Score={final_score_str}")
155
 
156
+ clothing_items_detected_by_clip = get_top_clothing(all_probs, n=1)
 
157
  if clothing_items_detected_by_clip:
158
  detected_clothing_item = clothing_items_detected_by_clip[0]
159
  print(f"Top clothing item identified by CLIP: {detected_clothing_item}")
160
  else:
161
  print("Warning: CLIP did not identify a top clothing item.")
162
+ detected_clothing_item = "fit"
 
163
 
164
  except Exception as e:
165
  print(f"Error during CLIP analysis or clothing selection: {e}")
166
+ return ("<p style='color: #FF5555;'>Error during analysis.</p>",
167
+ None, f"Analysis Error: {e}")
168
 
169
+ # 3) Generate Response and TTS
170
  try:
171
  response_text = random.choice(response_templates[category_key]).format(item=detected_clothing_item)
 
172
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
173
  tts = gTTS(text=response_text, lang='en', tld='com', slow=False)
174
  tts.save(tts_path)
175
  print(f"Generated TTS response: '{response_text}' saved to {tts_path}")
176
 
177
+ # --- Updated HTML Output ---
178
+ # Simpler structure, relies more on CSS for styling defined below
179
  category_html = f"""
180
+ <div class='results-container'>
181
+ <h2 class='result-category'>RATING: {category_label.upper()}</h2>
182
+ <p class='result-score'>Style Score: {final_score_str}</p>
183
  </div>
184
  """
185
  return category_html, tts_path, response_text
186
 
187
  except Exception as e:
188
  print(f"Error during response/TTS generation: {e}")
189
+ category_html = f"""
190
+ <div class='results-container'>
191
+ <h2 class='result-category'>Result: {category_label.upper()} (Score: {final_score_str})</h2>
192
+ <p class='result-score' style='color: #FFAAAA;'>Error generating audio/full response.</p>
193
+ </div>
194
+ """
195
  return category_html, None, f"Analysis complete ({category_label}), but error generating audio/response."
196
 
197
 
198
+ # --- Elite Fashion / Techno CSS ---
199
+ custom_css = """
200
+ :root {
201
+ --primary-bg-color: #000000;
202
+ --secondary-bg-color: #1A1A1A;
203
+ --text-color: #FFFFFF;
204
+ --accent-color: #1F04FF;
205
+ --border-color: #333333; /* Slightly lighter than secondary bg for subtle definition */
206
+ --input-bg-color: #1A1A1A;
207
+ --button-text-color: #FFFFFF;
208
+ --body-text-size: 16px; /* Base text size */
209
+ }
210
+
211
+ /* --- Global Styles --- */
212
+ body, .gradio-container {
213
+ background-color: var(--primary-bg-color) !important;
214
+ color: var(--text-color) !important;
215
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Open Sans', 'Helvetica Neue', sans-serif; /* Modern font stack */
216
+ font-size: var(--body-text-size);
217
+ }
218
+
219
+ /* Hide default Gradio footer */
220
+ footer { display: none !important; }
221
+
222
+ /* --- Component Styling --- */
223
+ .gr-block { /* General block container */
224
+ background-color: var(--secondary-bg-color) !important;
225
+ border: 1px solid var(--border-color) !important;
226
+ border-radius: 8px !important; /* Slightly rounded corners */
227
+ padding: 15px !important;
228
+ box-shadow: none !important; /* Remove default shadows */
229
+ }
230
+
231
+ /* Input/Output Text Areas & General inputs */
232
+ .gr-input, .gr-output, .gr-textbox textarea, .gr-dropdown select, .gr-checkboxgroup input {
233
+ background-color: var(--input-bg-color) !important;
234
+ color: var(--text-color) !important;
235
+ border: 1px solid var(--border-color) !important;
236
+ border-radius: 5px !important;
237
+ }
238
+ .gr-textbox textarea::placeholder { /* Style placeholder text if needed */
239
+ color: #888888 !important;
240
+ }
241
+
242
+ /* Component Labels */
243
+ .gr-label span, .gr-label .label-text {
244
+ color: var(--text-color) !important;
245
+ font-weight: 500 !important; /* Slightly bolder labels */
246
+ font-size: 0.95em !important;
247
+ margin-bottom: 8px !important; /* Space below label */
248
+ }
249
+
250
+ /* Image Input/Output */
251
+ .gr-image {
252
+ background-color: var(--primary-bg-color) !important; /* Match main background */
253
+ border: 1px dashed var(--border-color) !important; /* Dashed border for drop zone */
254
+ border-radius: 8px !important;
255
+ overflow: hidden; /* Ensure image stays within bounds */
256
+ }
257
+ .gr-image img {
258
+ border-radius: 6px !important; /* Slightly round image corners */
259
+ object-fit: contain; /* Ensure image fits well */
260
+ }
261
+ .gr-image .no-image, .gr-image .upload-button { /* Placeholder text/button inside image component */
262
+ color: #AAAAAA !important;
263
+ }
264
+
265
+ /* Audio Component */
266
+ .gr-audio > div:first-of-type { /* Target the container around the audio player */
267
+ border: 1px solid var(--border-color) !important;
268
+ background-color: var(--secondary-bg-color) !important;
269
+ border-radius: 5px !important;
270
+ padding: 10px !important;
271
+ }
272
+ .gr-audio audio { /* Style the audio player itself */
273
+ width: 100%; /* Make player responsive */
274
+ filter: invert(1) hue-rotate(180deg); /* Basic dark theme for player controls */
275
+ }
276
+
277
+ /* --- Button Styling --- */
278
+ .gr-button { /* General button style reset */
279
+ border: none !important;
280
+ border-radius: 5px !important;
281
+ transition: background-color 0.2s ease, transform 0.1s ease;
282
+ font-weight: 600 !important;
283
+ }
284
+ .gr-button-primary { /* Specific styling for the primary Analyze button */
285
+ background-color: var(--accent-color) !important;
286
+ color: var(--button-text-color) !important;
287
+ font-size: 1.1em !important; /* Make primary button slightly larger */
288
+ padding: 12px 20px !important; /* Adjust padding */
289
+ }
290
+ .gr-button-primary:hover {
291
+ background-color: #482FFF !important; /* Slightly lighter blue on hover */
292
+ transform: scale(1.02); /* Subtle scale effect */
293
+ box-shadow: 0 0 10px var(--accent-color); /* Add a glow effect */
294
+ }
295
+ .gr-button-primary:active {
296
+ transform: scale(0.98); /* Press down effect */
297
+ }
298
+
299
+ /* --- Typography & Content --- */
300
+ h1, h2, h3 {
301
+ color: var(--text-color) !important;
302
+ font-weight: 600; /* Bold headings */
303
+ letter-spacing: 0.5px; /* Add slight letter spacing */
304
+ }
305
+ .prose h1 { /* Target Markdown H1 specifically if needed */
306
+ text-align: center;
307
+ margin-bottom: 25px !important;
308
+ font-size: 2em !important; /* Larger title */
309
+ text-transform: uppercase; /* Uppercase for impact */
310
+ letter-spacing: 1.5px;
311
+ }
312
+ .prose p { /* Target Markdown Paragraph */
313
+ color: #CCCCCC !important; /* Slightly dimmer text for descriptions */
314
+ font-size: 0.95em;
315
+ text-align: center;
316
+ }
317
+
318
+ /* Custom styling for the results HTML block */
319
+ .results-container {
320
+ text-align: center;
321
+ padding: 20px;
322
+ border: 1px solid var(--accent-color); /* Use accent color for border */
323
+ border-radius: 8px;
324
+ background: linear-gradient(145deg, var(--secondary-bg-color), #2a2a2a); /* Subtle gradient */
325
+ }
326
+ .result-category {
327
+ color: var(--accent-color) !important; /* Use accent color for category */
328
+ font-size: 1.5em;
329
+ margin-bottom: 5px;
330
+ font-weight: 700;
331
+ text-transform: uppercase;
332
+ }
333
+ .result-score {
334
+ color: var(--text-color) !important;
335
+ font-size: 1.1em;
336
+ margin-top: 0;
337
+ }
338
+
339
+ /* --- Layout Adjustments --- */
340
+ .gradio-container {
341
+ max-width: 850px !important; /* Slightly wider max-width */
342
+ margin: auto !important;
343
+ padding-top: 30px; /* Add some space at the top */
344
+ }
345
+ .gr-row {
346
+ gap: 25px !important; /* Increase gap between columns */
347
+ }
348
+ """
349
+
350
+
351
+ # --- Gradio Interface (Now using the custom CSS) ---
352
+ with gr.Blocks(css=custom_css, theme=gr.themes.Base(primary_hue="neutral", secondary_hue="neutral", text_size=gr.themes.sizes.text_lg)) as demo: # Use Base theme to minimize default styles
353
+ # Title using Markdown (styled by CSS)
354
+ gr.Markdown("<h1>💧 DripAI: Rate Your Fit 💧</h1>")
355
+
356
  with gr.Row():
357
+ with gr.Column(scale=1, min_width=350): # Assign min width for better responsiveness
358
  input_image = gr.Image(
359
+ type='pil',
360
+ label="Upload Your Outfit", # Simpler label
361
+ sources=['upload', 'webcam', 'clipboard'],
362
+ height=450 # Slightly taller image area
363
+ )
364
+ analyze_button = gr.Button(
365
+ "Analyze Outfit",
366
+ variant="primary",
367
+ # size="lg" removed, controlled by CSS
368
+ )
369
+
370
+ with gr.Column(scale=1, min_width=350): # Assign min width
371
+ gr.Markdown("### ANALYSIS RESULTS") # Simple heading
372
+ category_html = gr.HTML(label="Rating & Score") # Label for screen readers/context
373
+ response_box = gr.Textbox(
374
+ lines=3,
375
+ label="Verbal Feedback", # Updated label
376
+ interactive=False
377
  )
378
+ audio_output = gr.Audio(
379
+ autoplay=False, # Changed default to false, user can click play
380
+ label="Audio Feedback",
381
+ streaming=False
382
+ )
383
+
384
+ # Bind the analysis function to the button click
385
  analyze_button.click(
386
+ fn=analyze_outfit,
387
+ inputs=[input_image],
388
+ outputs=[category_html, audio_output, response_box]
389
  )
390
+
391
+ # Footer description text
392
+ gr.Markdown("<p>Upload, paste, or use your webcam to capture your outfit. DripAI evaluates your style.</p>")
393
 
394
  # --- Launch App ---
395
  if __name__ == "__main__":
396
+ demo.launch(debug=True)