dzmu committed on
Commit
aad2489
·
verified ·
1 Parent(s): 6640bd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -120
app.py CHANGED
@@ -5,7 +5,7 @@ import numpy as np
5
  import random
6
  import os
7
  from PIL import Image
8
- from ultralytics import YOLO
9
  from gtts import gTTS
10
  import uuid
11
  import time
@@ -14,7 +14,7 @@ import tempfile
14
  # --- Configuration ---
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  YOLO_PERSON_MODEL_PATH = 'yolov8n.pt' # Standard YOLOv8 for person detection
17
- YOLO_FASHION_MODEL_PATH = 'best.pt' # Your trained fashion model
18
  CLIP_MODEL_NAME = "ViT-B/32"
19
 
20
  # --- Load Models ---
@@ -24,7 +24,7 @@ try:
24
  print(f"CLIP model ({CLIP_MODEL_NAME}) loaded successfully.")
25
  except Exception as e:
26
  print(f"Error loading CLIP model: {e}")
27
- # Handle error appropriately, maybe exit or use a fallback
28
 
29
  try:
30
  yolo_person_model = YOLO(YOLO_PERSON_MODEL_PATH).to(DEVICE)
@@ -33,154 +33,136 @@ except Exception as e:
33
  print(f"Error loading YOLO person model: {e}")
34
  # Handle error
35
 
36
- try:
37
- fashion_model = YOLO(YOLO_FASHION_MODEL_PATH).to(DEVICE)
38
- print(f"YOLO fashion model ({YOLO_FASHION_MODEL_PATH}) loaded successfully.")
39
- # It's crucial that fashion_model.names is populated correctly after loading.
40
- # If it's not, you might need to load names from a corresponding .yaml file.
41
- if not hasattr(fashion_model, 'names') or not fashion_model.names:
42
- print("Warning: Fashion model names not found. Detection might not work correctly.")
43
- # Example: Manually assign if needed (replace with your actual class names)
44
- # fashion_model.names = {0: 't-shirt', 1: 'jeans', 2: 'sneakers', ...}
45
- except Exception as e:
46
- print(f"Error loading YOLO fashion model: {e}")
47
- # Handle error
48
 
49
  # --- Prompts and Responses ---
50
  style_prompts = {
51
  'drippy': [
52
- "avant-garde streetwear",
53
- "high-fashion designer outfit",
54
- "trendsetting urban attire",
55
- "luxury sneakers and chic accessories",
56
- "cutting-edge, bold style"
57
  ],
58
  'mid': [
59
- "casual everyday outfit",
60
- "modern minimalistic attire",
61
- "comfortable yet stylish look",
62
- "simple, relaxed streetwear",
63
- "balanced, practical fashion"
64
  ],
65
  'not_drippy': [
66
- "disheveled outfit",
67
- "poorly coordinated fashion",
68
- "unfashionable, outdated attire",
69
- "tacky, mismatched ensemble",
70
- "sloppy, uninspired look"
71
  ]
72
  }
73
 
74
- # Only style prompts are needed for CLIP now
75
- clip_style_texts = []
76
- for category in style_prompts:
77
- clip_style_texts.extend(style_prompts[category])
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  response_templates = {
80
  'drippy': [
81
- "You're Drippy, bruh – fire {item}!",
82
- "{item} goes crazy, on god!",
83
- "Certified drippy with that {item}."
84
  ],
85
  'mid': [
86
- "Drop the {item} and you might get a text back.",
87
- "It's alright, but I'd upgrade the {item}.",
88
  "Mid fit alert. That {item} is holding you back."
89
  ],
90
  'not_drippy': [
91
- "Bro thought that {item} was tuff!",
92
- "Oh hell nah! Burn that {item}!",
93
  "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
94
  "Never walk out the house again with that {item}."
95
  ]
96
  }
97
 
98
- # Map internal category keys to user-facing labels
99
- CATEGORY_LABEL_MAP = {
100
- "drippy": "drippy",
101
- "mid": "mid",
102
- "not_drippy": "trash"
103
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # --- Core Logic ---
106
  def analyze_outfit(input_img: Image.Image):
107
  if input_img is None:
108
  return "Please upload an image.", None, "Error: No image provided."
109
 
110
- img = input_img.copy() # Work on a copy
111
 
112
- # 1) YOLO Person Detection
113
- person_results = yolo_person_model(img, verbose=False) # verbose=False suppresses console output
114
  boxes = person_results[0].boxes.xyxy.cpu().numpy()
115
  classes = person_results[0].boxes.cls.cpu().numpy()
116
  confidences = person_results[0].boxes.conf.cpu().numpy()
117
 
118
- # Find the most confident 'person' detection (class ID 0 for COCO)
119
  person_indices = np.where(classes == 0)[0]
120
- cropped_img = img # Default to full image if no person found
121
-
122
  if len(person_indices) > 0:
123
  max_conf_person_idx = person_indices[np.argmax(confidences[person_indices])]
124
  x1, y1, x2, y2 = map(int, boxes[max_conf_person_idx])
125
- # Ensure crop coordinates are valid
126
  x1, y1 = max(0, x1), max(0, y1)
127
  x2, y2 = min(img.width, x2), min(img.height, y2)
128
  if x1 < x2 and y1 < y2:
129
  cropped_img = img.crop((x1, y1, x2, y2))
 
130
  else:
131
  print("Warning: Invalid person bounding box after clipping. Using full image.")
132
  cropped_img = img
133
- print(f"Person detected and cropped: Box {x1, y1, x2, y2}")
134
  else:
135
  print("No person detected by yolo_person_model. Analyzing full image.")
136
- # Decide if you want to proceed without a person or return an error
137
- # return "Could not detect a person in the image.", None, "Error: Person not found."
138
 
139
- # 2) YOLO Fashion Detection (on the cropped image)
140
- detected_clothing_item = "fit" # Default item if no clothing detected
141
- try:
142
- fashion_results = fashion_model(cropped_img, conf=0.1, verbose=False)
143
- if len(fashion_results[0].boxes) > 0:
144
- fashion_boxes = fashion_results[0].boxes.xyxy.cpu().numpy()
145
- fashion_classes = fashion_results[0].boxes.cls.cpu().numpy()
146
- fashion_confidences = fashion_results[0].boxes.conf.cpu().numpy()
147
- fashion_names = fashion_results[0].names # Dictionary mapping class index to name
148
-
149
- # Get the most confident clothing detection
150
- max_conf_fashion_idx = np.argmax(fashion_confidences)
151
- detected_class_id = int(fashion_classes[max_conf_fashion_idx])
152
-
153
- if fashion_names and detected_class_id in fashion_names:
154
- detected_clothing_item = fashion_names[detected_class_id]
155
- print(f"Most confident clothing item detected: {detected_clothing_item} (Conf: {fashion_confidences[max_conf_fashion_idx]:.2f})")
156
- else:
157
- print(f"Warning: Detected clothing class ID {detected_class_id} not found in fashion model names.")
158
- detected_clothing_item = "clothing item" # Fallback if name mapping fails
159
- else:
160
- print("No clothing items detected by fashion_model on the cropped image.")
161
- detected_clothing_item = "style" # Fallback if nothing specific is found
162
- except Exception as e:
163
- print(f"Error during fashion detection: {e}")
164
- detected_clothing_item = "outfit" # General fallback on error
165
 
166
- # 3) CLIP Style Analysis (on the cropped image)
 
167
  try:
168
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(DEVICE)
169
- text_tokens = clip.tokenize(clip_style_texts).to(DEVICE)
 
170
 
171
  with torch.no_grad():
172
  logits, _ = clip_model(image_tensor, text_tokens)
173
- # Probabilities ONLY for the style prompts
174
- style_probs = logits.softmax(dim=-1).cpu().numpy()[0]
175
 
176
- # Calculate average scores for each style category
177
  drip_len = len(style_prompts['drippy'])
178
  mid_len = len(style_prompts['mid'])
179
- # not_len = len(style_prompts['not_drippy']) # Length of the last section
180
 
181
- drip_score = np.mean(style_probs[0 : drip_len])
182
- mid_score = np.mean(style_probs[drip_len : drip_len + mid_len])
183
- not_score = np.mean(style_probs[drip_len + mid_len :]) # Rest are 'not_drippy'
184
 
185
  # Determine the category based on highest average score
186
  if drip_score > mid_score and drip_score > not_score:
@@ -194,72 +176,66 @@ def analyze_outfit(input_img: Image.Image):
194
  final_score = not_score
195
 
196
  category_label = CATEGORY_LABEL_MAP[category_key]
197
- final_score_str = f"{final_score:.2f}" # Format score
198
  print(f"Style analysis: Category={category_label}, Score={final_score_str}")
199
 
 
 
 
 
 
 
 
 
 
 
200
  except Exception as e:
201
- print(f"Error during CLIP analysis: {e}")
202
- # Handle CLIP error - maybe return a default message
203
- return "Error during style analysis.", None, f"Analysis Error: {e}"
204
 
205
- # 4) Generate Response and TTS
206
  try:
207
- # Select a random response template for the determined category
208
  response_text = random.choice(response_templates[category_key]).format(item=detected_clothing_item)
209
 
210
- # Generate TTS audio
211
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
212
- tts = gTTS(text=response_text, lang='en', tld='com', slow=False) # Use tld='com' for a standard voice
213
  tts.save(tts_path)
214
  print(f"Generated TTS response: '{response_text}' saved to {tts_path}")
215
 
216
- # Output HTML for category + numeric score
217
  category_html = f"""
218
  <div style='text-align: center; padding: 15px; border: 1px solid #eee; border-radius: 8px;'>
219
  <h2 style='color: #333; margin-bottom: 5px;'>Your fit is {category_label.upper()}!</h2>
220
  <p style='font-size: 1.1em; color: #555; margin-top: 0;'>Style Score: {final_score_str}</p>
221
  </div>
222
  """
223
-
224
  return category_html, tts_path, response_text
225
 
226
  except Exception as e:
227
  print(f"Error during response/TTS generation: {e}")
228
- # Fallback if TTS or formatting fails
229
  category_html = f"<h2>Result: {category_label} (Score: {final_score_str})</h2>"
230
  return category_html, None, f"Analysis complete ({category_label}), but error generating audio/response."
231
 
232
 
233
- # --- Gradio Interface ---
234
  with gr.Blocks(css=".gradio-container { max-width: 800px !important; margin: auto !important; } footer { display: none !important; }") as demo:
235
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>💧 DripAI: Rate Your Fit 💧</h1>")
236
-
237
  with gr.Row():
238
  with gr.Column(scale=1):
239
  input_image = gr.Image(
240
- type='pil',
241
- label="Upload, Paste, or Use Webcam for your Outfit Photo",
242
- # Explicitly define sources for better UI clarity
243
- sources=['upload', 'webcam', 'clipboard'],
244
- height=400
245
  )
246
  analyze_button = gr.Button("Analyze Outfit", variant="primary", size="lg")
247
-
248
  with gr.Column(scale=1):
249
  gr.Markdown("### Analysis Result:")
250
- category_html = gr.HTML(label="Category & Score") # Displays HTML output
251
  audio_output = gr.Audio(autoplay=True, label="Audio Feedback", streaming=False)
252
- response_box = gr.Textbox(lines=4, label="Text Feedback", interactive=False) # Make textbox read-only
253
-
254
  analyze_button.click(
255
- fn=analyze_outfit,
256
- inputs=[input_image],
257
- outputs=[category_html, audio_output, response_box],
258
- # show_progress="full" # Optional: Show progress bar during processing
259
  )
260
-
261
  gr.Markdown("<p style='text-align: center; color: grey; font-size: 0.9em;'>Upload an image of your outfit and click 'Analyze Outfit'. DripAI will rate your style and identify a key clothing item.</p>")
262
 
263
  # --- Launch App ---
264
  if __name__ == "__main__":
265
- demo.launch(debug=True) # Enable debug for more detailed logs
 
5
  import random
6
  import os
7
  from PIL import Image
8
+ from ultralytics import YOLO # Still needed for person detection
9
  from gtts import gTTS
10
  import uuid
11
  import time
 
14
  # --- Configuration ---
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
  YOLO_PERSON_MODEL_PATH = 'yolov8n.pt' # Standard YOLOv8 for person detection
17
+ # YOLO_FASHION_MODEL_PATH = 'best.pt' # REMOVED - Not using fashion model anymore
18
  CLIP_MODEL_NAME = "ViT-B/32"
19
 
20
  # --- Load Models ---
 
24
  print(f"CLIP model ({CLIP_MODEL_NAME}) loaded successfully.")
25
  except Exception as e:
26
  print(f"Error loading CLIP model: {e}")
27
+ # Handle error
28
 
29
  try:
30
  yolo_person_model = YOLO(YOLO_PERSON_MODEL_PATH).to(DEVICE)
 
33
  print(f"Error loading YOLO person model: {e}")
34
  # Handle error
35
 
36
+ # REMOVED Fashion Model Loading
37
+ # try:
38
+ # fashion_model = YOLO(YOLO_FASHION_MODEL_PATH).to(DEVICE)
39
+ # print(f"YOLO fashion model ({YOLO_FASHION_MODEL_PATH}) loaded successfully.")
40
+ # if not hasattr(fashion_model, 'names') or not fashion_model.names:
41
+ # print("Warning: Fashion model names not found.")
42
+ # except Exception as e:
43
+ # print(f"Error loading YOLO fashion model: {e}")
 
 
 
 
44
 
45
# --- Prompts and Responses ---
# CLIP text prompts describing each style tier; the category whose prompts
# score highest on average decides the verdict.
style_prompts = {
    'drippy': [
        "avant-garde streetwear",
        "high-fashion designer outfit",
        "trendsetting urban attire",
        "luxury sneakers and chic accessories",
        "cutting-edge, bold style",
    ],
    'mid': [
        "casual everyday outfit",
        "modern minimalistic attire",
        "comfortable yet stylish look",
        "simple, relaxed streetwear",
        "balanced, practical fashion",
    ],
    'not_drippy': [
        "disheveled outfit",
        "poorly coordinated fashion",
        "unfashionable, outdated attire",
        "tacky, mismatched ensemble",
        "sloppy, uninspired look",
    ],
}

# Clothing labels scored by CLIP; the best-matching one is slotted into the
# response template as {item}.
clothing_prompts = [
    "t-shirt", "dress shirt", "blouse", "hoodie", "jacket", "sweater", "coat",
    "dress", "skirt", "pants", "jeans", "trousers", "shorts",
    "sneakers", "boots", "heels", "sandals",
    "cap", "hat", "scarf", "gloves", "bag", "accessory", "tank-top", "haircut"
]

# One flat prompt list for a single CLIP pass: style prompts first (dict
# insertion order), clothing prompts after. style_prompts_end_index marks
# where the style slice ends and the clothing slice begins.
all_prompts = [prompt for prompts in style_prompts.values() for prompt in prompts]
style_prompts_end_index = len(all_prompts)
all_prompts += clothing_prompts
print(f"Total prompts for CLIP: {len(all_prompts)}")
77
 
78
# Spoken/text feedback templates per style tier; each contains an {item}
# placeholder filled with the CLIP-detected clothing label.
response_templates = {
    'drippy': [
        "You're Drippy, bruh – fire {item}!",
        "{item} goes crazy, on god!",
        "Certified drippy with that {item}.",
    ],
    'mid': [
        "Drop the {item} and you might get a text back.",
        "It's alright, but I'd upgrade the {item}.",
        "Mid fit alert. That {item} is holding you back.",
    ],
    'not_drippy': [
        "Bro thought that {item} was tuff!",
        "Oh hell nah! Burn that {item}!",
        "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
        "Never walk out the house again with that {item}.",
    ],
}

# Internal category key -> label shown to the user.
CATEGORY_LABEL_MAP = {
    "drippy": "drippy",
    "mid": "mid",
    "not_drippy": "trash",
}
94
+
95
+ # --- REINSTATED: Function to get top clothing items based on CLIP probabilities ---
96
+ def get_top_clothing(probs, n=3):
97
+ """Gets the top N clothing items based on CLIP probabilities."""
98
+ # Calculate the start index of clothing probabilities in the combined 'probs' array
99
+ clothing_probs_start_index = style_prompts_end_index
100
+ clothing_probs = probs[clothing_probs_start_index:]
101
+
102
+ # Ensure we don't request more items than available prompts
103
+ actual_n = min(n, len(clothing_prompts))
104
+ if actual_n <= 0:
105
+ return ["item"] # Return default if no clothing prompts
106
+
107
+ # Get indices of top N probabilities within the clothing_probs slice
108
+ top_indices_in_slice = np.argsort(clothing_probs)[-actual_n:]
109
+
110
+ # Return the corresponding clothing prompt names in descending order of probability
111
+ return [clothing_prompts[i] for i in reversed(top_indices_in_slice)]
112
+
113
 
114
  # --- Core Logic ---
115
  def analyze_outfit(input_img: Image.Image):
116
  if input_img is None:
117
  return "Please upload an image.", None, "Error: No image provided."
118
 
119
+ img = input_img.copy()
120
 
121
+ # 1) YOLO Person Detection (Same as before)
122
+ person_results = yolo_person_model(img, verbose=False)
123
  boxes = person_results[0].boxes.xyxy.cpu().numpy()
124
  classes = person_results[0].boxes.cls.cpu().numpy()
125
  confidences = person_results[0].boxes.conf.cpu().numpy()
126
 
 
127
  person_indices = np.where(classes == 0)[0]
128
+ cropped_img = img
 
129
  if len(person_indices) > 0:
130
  max_conf_person_idx = person_indices[np.argmax(confidences[person_indices])]
131
  x1, y1, x2, y2 = map(int, boxes[max_conf_person_idx])
 
132
  x1, y1 = max(0, x1), max(0, y1)
133
  x2, y2 = min(img.width, x2), min(img.height, y2)
134
  if x1 < x2 and y1 < y2:
135
  cropped_img = img.crop((x1, y1, x2, y2))
136
+ print(f"Person detected and cropped: Box {x1, y1, x2, y2}")
137
  else:
138
  print("Warning: Invalid person bounding box after clipping. Using full image.")
139
  cropped_img = img
 
140
  else:
141
  print("No person detected by yolo_person_model. Analyzing full image.")
142
+ # Decide if you want to proceed or return an error
 
143
 
144
+ # --- REMOVED: YOLO Fashion Detection ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ # 2) CLIP Analysis (Using ALL prompts - Style + Clothing)
147
+ detected_clothing_item = "look" # Default if something goes wrong
148
  try:
149
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(DEVICE)
150
+ # --- Use all_prompts for tokenization ---
151
+ text_tokens = clip.tokenize(all_prompts).to(DEVICE)
152
 
153
  with torch.no_grad():
154
  logits, _ = clip_model(image_tensor, text_tokens)
155
+ # --- Probabilities for ALL prompts ---
156
+ all_probs = logits.softmax(dim=-1).cpu().numpy()[0]
157
 
158
+ # Calculate average scores for each style category based on their slices in all_probs
159
  drip_len = len(style_prompts['drippy'])
160
  mid_len = len(style_prompts['mid'])
161
+ # not_len = len(style_prompts['not_drippy']) # Calculated implicitly below
162
 
163
+ drip_score = np.mean(all_probs[0 : drip_len])
164
+ mid_score = np.mean(all_probs[drip_len : drip_len + mid_len])
165
+ not_score = np.mean(all_probs[drip_len + mid_len : style_prompts_end_index]) # Scores up to end of style prompts
166
 
167
  # Determine the category based on highest average score
168
  if drip_score > mid_score and drip_score > not_score:
 
176
  final_score = not_score
177
 
178
  category_label = CATEGORY_LABEL_MAP[category_key]
179
+ final_score_str = f"{final_score:.2f}"
180
  print(f"Style analysis: Category={category_label}, Score={final_score_str}")
181
 
182
+ # --- REINSTATED: Get clothing item using CLIP probs ---
183
+ clothing_items_detected_by_clip = get_top_clothing(all_probs, n=1) # Get top 1 item
184
+ if clothing_items_detected_by_clip:
185
+ detected_clothing_item = clothing_items_detected_by_clip[0]
186
+ print(f"Top clothing item identified by CLIP: {detected_clothing_item}")
187
+ else:
188
+ print("Warning: CLIP did not identify a top clothing item.")
189
+ detected_clothing_item = "fit" # Fallback if get_top_clothing fails
190
+
191
+
192
  except Exception as e:
193
+ print(f"Error during CLIP analysis or clothing selection: {e}")
194
+ return "Error during analysis.", None, f"Analysis Error: {e}"
 
195
 
196
+ # 3) Generate Response and TTS (Same as before, but uses item from CLIP)
197
  try:
 
198
  response_text = random.choice(response_templates[category_key]).format(item=detected_clothing_item)
199
 
 
200
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
201
+ tts = gTTS(text=response_text, lang='en', tld='com', slow=False)
202
  tts.save(tts_path)
203
  print(f"Generated TTS response: '{response_text}' saved to {tts_path}")
204
 
 
205
  category_html = f"""
206
  <div style='text-align: center; padding: 15px; border: 1px solid #eee; border-radius: 8px;'>
207
  <h2 style='color: #333; margin-bottom: 5px;'>Your fit is {category_label.upper()}!</h2>
208
  <p style='font-size: 1.1em; color: #555; margin-top: 0;'>Style Score: {final_score_str}</p>
209
  </div>
210
  """
 
211
  return category_html, tts_path, response_text
212
 
213
  except Exception as e:
214
  print(f"Error during response/TTS generation: {e}")
 
215
  category_html = f"<h2>Result: {category_label} (Score: {final_score_str})</h2>"
216
  return category_html, None, f"Analysis complete ({category_label}), but error generating audio/response."
217
 
218
 
219
+ # --- Gradio Interface (Unchanged) ---
220
  with gr.Blocks(css=".gradio-container { max-width: 800px !important; margin: auto !important; } footer { display: none !important; }") as demo:
221
  gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>💧 DripAI: Rate Your Fit 💧</h1>")
 
222
  with gr.Row():
223
  with gr.Column(scale=1):
224
  input_image = gr.Image(
225
+ type='pil', label="Upload, Paste, or Use Webcam for your Outfit Photo",
226
+ sources=['upload', 'webcam', 'clipboard'], height=400
 
 
 
227
  )
228
  analyze_button = gr.Button("Analyze Outfit", variant="primary", size="lg")
 
229
  with gr.Column(scale=1):
230
  gr.Markdown("### Analysis Result:")
231
+ category_html = gr.HTML(label="Category & Score")
232
  audio_output = gr.Audio(autoplay=True, label="Audio Feedback", streaming=False)
233
+ response_box = gr.Textbox(lines=4, label="Text Feedback", interactive=False)
 
234
  analyze_button.click(
235
+ fn=analyze_outfit, inputs=[input_image], outputs=[category_html, audio_output, response_box]
 
 
 
236
  )
 
237
  gr.Markdown("<p style='text-align: center; color: grey; font-size: 0.9em;'>Upload an image of your outfit and click 'Analyze Outfit'. DripAI will rate your style and identify a key clothing item.</p>")
238
 
239
  # --- Launch App ---
240
  if __name__ == "__main__":
241
+ demo.launch(debug=True) # Assumes debug is helpful on HF too, might remove later