dzmu committed on
Commit
1a91629
·
verified ·
1 Parent(s): 57c11ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -100
app.py CHANGED
@@ -8,131 +8,258 @@ from PIL import Image
8
  from ultralytics import YOLO
9
  from gtts import gTTS
10
  import uuid
 
11
  import tempfile
12
 
13
- device = "cuda" if torch.cuda.is_available() else "cpu"
14
- clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
15
- yolo_model = YOLO("yolov8n.pt").to(device)
16
- fashion_model = YOLO("best.pt").to(device) # Your trained model
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  style_prompts = {
19
- "drippy": [
20
- "avant-garde streetwear", "high-fashion designer outfit", "trendsetting urban attire",
21
- "luxury sneakers and chic accessories", "cutting-edge, bold style"
 
 
 
22
  ],
23
- "mid": [
24
- "casual everyday outfit", "modern minimalistic attire", "comfortable yet stylish look",
25
- "simple, relaxed streetwear", "balanced, practical fashion"
 
 
 
26
  ],
27
- "not_drippy": [
28
- "disheveled outfit", "poorly coordinated fashion", "unfashionable, outdated attire",
29
- "tacky, mismatched ensemble", "sloppy, uninspired look"
 
 
 
30
  ]
31
  }
32
 
33
- clothing_prompts = [
34
- "t-shirt", "dress shirt", "blouse", "hoodie", "jacket", "sweater", "coat", "dress", "skirt",
35
- "pants", "jeans", "trousers", "shorts", "sneakers", "boots", "heels", "sandals", "cap", "hat",
36
- "scarf", "gloves", "bag", "accessory", "tank-top", "haircut"
37
- ]
38
 
39
  response_templates = {
40
- "drippy": [
41
- "You're Drippy, bruh – fire {item}!", "{item} goes crazy, on god!",
 
42
  "Certified drippy with that {item}."
43
  ],
44
- "mid": [
45
- "Drop the {item} and you might get a text back.", "It's alright, but I'd upgrade the {item}.",
 
46
  "Mid fit alert. That {item} is holding you back."
47
  ],
48
- "not_drippy": [
49
- "Bro thought that {item} was tuff!", "Oh hell nah! Burn that {item}!",
 
50
  "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
51
  "Never walk out the house again with that {item}."
52
  ]
53
  }
54
 
55
- CATEGORY_LABEL_MAP = {"drippy": "drippy", "mid": "mid", "not_drippy": "trash"}
56
- all_prompts = [p for cat in style_prompts.values() for p in cat] + clothing_prompts
 
 
 
 
57
 
58
- def get_top_clothing(probs, n=3):
59
- clothing_probs = probs[len(all_prompts) - len(clothing_prompts):]
60
- top_indices = np.argsort(clothing_probs)[-n:]
61
- return [clothing_prompts[i] for i in reversed(top_indices)]
62
 
63
- def analyze_outfit(img):
64
- results = yolo_model(img)
65
- boxes = results[0].boxes.xyxy.cpu().numpy()
66
- classes = results[0].boxes.cls.cpu().numpy()
67
- confidences = results[0].boxes.conf.cpu().numpy()
68
 
 
 
 
 
 
 
 
69
  person_indices = np.where(classes == 0)[0]
70
- cropped = img
 
71
  if len(person_indices) > 0:
72
- idx = np.argmax(confidences[person_indices])
73
- x1, y1, x2, y2 = map(int, boxes[person_indices][idx])
74
- cropped = img.crop((x1, y1, x2, y2))
75
-
76
- # Run fashion model to get top class label
77
- fashion_results = fashion_model(cropped, verbose=False)
78
- top_item_idx = fashion_results[0].probs.top1
79
- top_item_name = fashion_results[0].names[int(top_item_idx)]
80
-
81
- # CLIP classification
82
- image_tensor = clip_preprocess(cropped).unsqueeze(0).to(device)
83
- text_tokens = clip.tokenize([str(p) for p in all_prompts]).to(device)
84
- with torch.no_grad():
85
- logits, _ = clip_model(image_tensor, text_tokens)
86
- probs = logits.softmax(dim=-1).cpu().numpy()[0]
87
-
88
- drip_len = len(style_prompts["drippy"])
89
- mid_len = len(style_prompts["mid"])
90
- not_len = len(style_prompts["not_drippy"])
91
-
92
- drip_score = np.mean(probs[:drip_len])
93
- mid_score = np.mean(probs[drip_len:drip_len + mid_len])
94
- not_score = np.mean(probs[drip_len + mid_len:])
95
-
96
- if drip_score > mid_score and drip_score > not_score:
97
- cat = "drippy"
98
- final_score = drip_score
99
- elif mid_score > not_score:
100
- cat = "mid"
101
- final_score = mid_score
102
  else:
103
- cat = "not_drippy"
104
- final_score = not_score
105
-
106
- label = CATEGORY_LABEL_MAP[cat]
107
- response = random.choice(response_templates[cat]).format(item=top_item_name)
108
- tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
109
- gTTS(response, lang="en").save(tts_path)
110
-
111
- html = f"""
112
- <div style='padding:1rem; text-align:center;'>
113
- <h2>Your fit is <span style='color:#1f04ff'>{label}</span>!</h2>
114
- <p>Drip Score: <strong>{final_score:.2f}</strong></p>
115
- </div>
116
- """
117
- return html, tts_path, response
118
-
119
- # Gradio UI
120
- with gr.Blocks(css=".container { max-width: 600px; margin: auto; padding: 2rem; }") as demo:
121
- with gr.Group(elem_classes=["container"]):
122
- method = gr.Radio(["Upload Image", "Use Webcam"], label="Choose how to submit your fit", value="Upload Image")
123
- upload = gr.Image(type="pil", label="Upload Image", visible=True)
124
- webcam = gr.Image(type="pil", label="Take Photo", visible=False, sources=["webcam"])
125
- analyze = gr.Button("🔥 Analyze My Fit")
126
- html = gr.HTML()
127
- audio = gr.Audio(autoplay=True, label="")
128
- textbox = gr.Textbox(label="Response", interactive=False, lines=2)
129
-
130
- def toggle_inputs(method):
131
- return gr.update(visible=method == "Upload Image"), gr.update(visible=method == "Use Webcam")
132
-
133
- method.change(toggle_inputs, inputs=method, outputs=[upload, webcam])
134
- analyze.click(fn=analyze_outfit, inputs=upload, outputs=[html, audio, textbox])
135
- analyze.click(fn=analyze_outfit, inputs=webcam, outputs=[html, audio, textbox])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
 
137
  if __name__ == "__main__":
138
- demo.launch()
 
8
  from ultralytics import YOLO
9
  from gtts import gTTS
10
  import uuid
11
+ import time
12
  import tempfile
13
 
14
+ # --- Configuration ---
15
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
+ YOLO_PERSON_MODEL_PATH = 'yolov8n.pt' # Standard YOLOv8 for person detection
17
+ YOLO_FASHION_MODEL_PATH = 'best.pt' # Your trained fashion model
18
+ CLIP_MODEL_NAME = "ViT-B/32"
19
 
20
+ # --- Load Models ---
21
+ print(f"Using device: {DEVICE}")
22
+ try:
23
+ clip_model, clip_preprocess = clip.load(CLIP_MODEL_NAME, device=DEVICE)
24
+ print(f"CLIP model ({CLIP_MODEL_NAME}) loaded successfully.")
25
+ except Exception as e:
26
+ print(f"Error loading CLIP model: {e}")
27
+ # Handle error appropriately, maybe exit or use a fallback
28
+
29
+ try:
30
+ yolo_person_model = YOLO(YOLO_PERSON_MODEL_PATH).to(DEVICE)
31
+ print(f"YOLO person detection model ({YOLO_PERSON_MODEL_PATH}) loaded successfully.")
32
+ except Exception as e:
33
+ print(f"Error loading YOLO person model: {e}")
34
+ # Handle error
35
+
36
+ try:
37
+ fashion_model = YOLO(YOLO_FASHION_MODEL_PATH).to(DEVICE)
38
+ print(f"YOLO fashion model ({YOLO_FASHION_MODEL_PATH}) loaded successfully.")
39
+ # It's crucial that fashion_model.names is populated correctly after loading.
40
+ # If it's not, you might need to load names from a corresponding .yaml file.
41
+ if not hasattr(fashion_model, 'names') or not fashion_model.names:
42
+ print("Warning: Fashion model names not found. Detection might not work correctly.")
43
+ # Example: Manually assign if needed (replace with your actual class names)
44
+ # fashion_model.names = {0: 't-shirt', 1: 'jeans', 2: 'sneakers', ...}
45
+ except Exception as e:
46
+ print(f"Error loading YOLO fashion model: {e}")
47
+ # Handle error
48
+
49
+ # --- Prompts and Responses ---
50
  style_prompts = {
51
+ 'drippy': [
52
+ "avant-garde streetwear",
53
+ "high-fashion designer outfit",
54
+ "trendsetting urban attire",
55
+ "luxury sneakers and chic accessories",
56
+ "cutting-edge, bold style"
57
  ],
58
+ 'mid': [
59
+ "casual everyday outfit",
60
+ "modern minimalistic attire",
61
+ "comfortable yet stylish look",
62
+ "simple, relaxed streetwear",
63
+ "balanced, practical fashion"
64
  ],
65
+ 'not_drippy': [
66
+ "disheveled outfit",
67
+ "poorly coordinated fashion",
68
+ "unfashionable, outdated attire",
69
+ "tacky, mismatched ensemble",
70
+ "sloppy, uninspired look"
71
  ]
72
  }
73
 
74
+ # Only style prompts are needed for CLIP now
75
+ clip_style_texts = []
76
+ for category in style_prompts:
77
+ clip_style_texts.extend(style_prompts[category])
 
78
 
79
# One-liner praise/roast templates per style tier. Every template contains an
# "{item}" placeholder that gets filled with the detected clothing item name.
response_templates = dict(
    drippy=[
        "You're Drippy, bruh – fire {item}!",
        "{item} goes crazy, on god!",
        "Certified drippy with that {item}.",
    ],
    mid=[
        "Drop the {item} and you might get a text back.",
        "It's alright, but I'd upgrade the {item}.",
        "Mid fit alert. That {item} is holding you back.",
    ],
    not_drippy=[
        "Bro thought that {item} was tuff!",
        "Oh hell nah! Burn that {item}!",
        "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
        "Never walk out the house again with that {item}.",
    ],
)
97
 
98
+ # Map internal category keys to user-facing labels
99
+ CATEGORY_LABEL_MAP = {
100
+ "drippy": "drippy",
101
+ "mid": "mid",
102
+ "not_drippy": "trash"
103
+ }
104
 
105
+ # --- Core Logic ---
106
+ def analyze_outfit(input_img: Image.Image):
107
+ if input_img is None:
108
+ return "Please upload an image.", None, "Error: No image provided."
109
 
110
+ img = input_img.copy() # Work on a copy
 
 
 
 
111
 
112
+ # 1) YOLO Person Detection
113
+ person_results = yolo_person_model(img, verbose=False) # verbose=False suppresses console output
114
+ boxes = person_results[0].boxes.xyxy.cpu().numpy()
115
+ classes = person_results[0].boxes.cls.cpu().numpy()
116
+ confidences = person_results[0].boxes.conf.cpu().numpy()
117
+
118
+ # Find the most confident 'person' detection (class ID 0 for COCO)
119
  person_indices = np.where(classes == 0)[0]
120
+ cropped_img = img # Default to full image if no person found
121
+
122
  if len(person_indices) > 0:
123
+ max_conf_person_idx = person_indices[np.argmax(confidences[person_indices])]
124
+ x1, y1, x2, y2 = map(int, boxes[max_conf_person_idx])
125
+ # Ensure crop coordinates are valid
126
+ x1, y1 = max(0, x1), max(0, y1)
127
+ x2, y2 = min(img.width, x2), min(img.height, y2)
128
+ if x1 < x2 and y1 < y2:
129
+ cropped_img = img.crop((x1, y1, x2, y2))
130
+ else:
131
+ print("Warning: Invalid person bounding box after clipping. Using full image.")
132
+ cropped_img = img
133
+ print(f"Person detected and cropped: Box {x1, y1, x2, y2}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  else:
135
+ print("No person detected by yolo_person_model. Analyzing full image.")
136
+ # Decide if you want to proceed without a person or return an error
137
+ # return "Could not detect a person in the image.", None, "Error: Person not found."
138
+
139
+ # 2) YOLO Fashion Detection (on the cropped image)
140
+ detected_clothing_item = "fit" # Default item if no clothing detected
141
+ try:
142
+ fashion_results = fashion_model(cropped_img, verbose=False)
143
+ if len(fashion_results[0].boxes) > 0:
144
+ fashion_boxes = fashion_results[0].boxes.xyxy.cpu().numpy()
145
+ fashion_classes = fashion_results[0].boxes.cls.cpu().numpy()
146
+ fashion_confidences = fashion_results[0].boxes.conf.cpu().numpy()
147
+ fashion_names = fashion_results[0].names # Dictionary mapping class index to name
148
+
149
+ # Get the most confident clothing detection
150
+ max_conf_fashion_idx = np.argmax(fashion_confidences)
151
+ detected_class_id = int(fashion_classes[max_conf_fashion_idx])
152
+
153
+ if fashion_names and detected_class_id in fashion_names:
154
+ detected_clothing_item = fashion_names[detected_class_id]
155
+ print(f"Most confident clothing item detected: {detected_clothing_item} (Conf: {fashion_confidences[max_conf_fashion_idx]:.2f})")
156
+ else:
157
+ print(f"Warning: Detected clothing class ID {detected_class_id} not found in fashion model names.")
158
+ detected_clothing_item = "clothing item" # Fallback if name mapping fails
159
+ else:
160
+ print("No clothing items detected by fashion_model on the cropped image.")
161
+ detected_clothing_item = "style" # Fallback if nothing specific is found
162
+ except Exception as e:
163
+ print(f"Error during fashion detection: {e}")
164
+ detected_clothing_item = "outfit" # General fallback on error
165
+
166
+ # 3) CLIP Style Analysis (on the cropped image)
167
+ try:
168
+ image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(DEVICE)
169
+ text_tokens = clip.tokenize(clip_style_texts).to(DEVICE)
170
+
171
+ with torch.no_grad():
172
+ logits, _ = clip_model(image_tensor, text_tokens)
173
+ # Probabilities ONLY for the style prompts
174
+ style_probs = logits.softmax(dim=-1).cpu().numpy()[0]
175
+
176
+ # Calculate average scores for each style category
177
+ drip_len = len(style_prompts['drippy'])
178
+ mid_len = len(style_prompts['mid'])
179
+ # not_len = len(style_prompts['not_drippy']) # Length of the last section
180
+
181
+ drip_score = np.mean(style_probs[0 : drip_len])
182
+ mid_score = np.mean(style_probs[drip_len : drip_len + mid_len])
183
+ not_score = np.mean(style_probs[drip_len + mid_len :]) # Rest are 'not_drippy'
184
+
185
+ # Determine the category based on highest average score
186
+ if drip_score > mid_score and drip_score > not_score:
187
+ category_key = 'drippy'
188
+ final_score = drip_score
189
+ elif mid_score > not_score:
190
+ category_key = 'mid'
191
+ final_score = mid_score
192
+ else:
193
+ category_key = 'not_drippy'
194
+ final_score = not_score
195
+
196
+ category_label = CATEGORY_LABEL_MAP[category_key]
197
+ final_score_str = f"{final_score:.2f}" # Format score
198
+ print(f"Style analysis: Category={category_label}, Score={final_score_str}")
199
+
200
+ except Exception as e:
201
+ print(f"Error during CLIP analysis: {e}")
202
+ # Handle CLIP error - maybe return a default message
203
+ return "Error during style analysis.", None, f"Analysis Error: {e}"
204
+
205
+ # 4) Generate Response and TTS
206
+ try:
207
+ # Select a random response template for the determined category
208
+ response_text = random.choice(response_templates[category_key]).format(item=detected_clothing_item)
209
+
210
+ # Generate TTS audio
211
+ tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
212
+ tts = gTTS(text=response_text, lang='en', tld='com', slow=False) # Use tld='com' for a standard voice
213
+ tts.save(tts_path)
214
+ print(f"Generated TTS response: '{response_text}' saved to {tts_path}")
215
+
216
+ # Output HTML for category + numeric score
217
+ category_html = f"""
218
+ <div style='text-align: center; padding: 15px; border: 1px solid #eee; border-radius: 8px;'>
219
+ <h2 style='color: #333; margin-bottom: 5px;'>Your fit is {category_label.upper()}!</h2>
220
+ <p style='font-size: 1.1em; color: #555; margin-top: 0;'>Style Score: {final_score_str}</p>
221
+ </div>
222
+ """
223
+
224
+ return category_html, tts_path, response_text
225
+
226
+ except Exception as e:
227
+ print(f"Error during response/TTS generation: {e}")
228
+ # Fallback if TTS or formatting fails
229
+ category_html = f"<h2>Result: {category_label} (Score: {final_score_str})</h2>"
230
+ return category_html, None, f"Analysis complete ({category_label}), but error generating audio/response."
231
+
232
+
233
+ # --- Gradio Interface ---
234
+ with gr.Blocks(css=".gradio-container { max-width: 800px !important; margin: auto !important; } footer { display: none !important; }") as demo:
235
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 20px;'>💧 DripAI: Rate Your Fit 💧</h1>")
236
+
237
+ with gr.Row():
238
+ with gr.Column(scale=1):
239
+ input_image = gr.Image(
240
+ type='pil',
241
+ label="Upload, Paste, or Use Webcam for your Outfit Photo",
242
+ # Explicitly define sources for better UI clarity
243
+ sources=['upload', 'webcam', 'clipboard'],
244
+ height=400
245
+ )
246
+ analyze_button = gr.Button("Analyze Outfit", variant="primary", size="lg")
247
+
248
+ with gr.Column(scale=1):
249
+ gr.Markdown("### Analysis Result:")
250
+ category_html = gr.HTML(label="Category & Score") # Displays HTML output
251
+ audio_output = gr.Audio(autoplay=True, label="Audio Feedback", streaming=False)
252
+ response_box = gr.Textbox(lines=4, label="Text Feedback", interactive=False) # Make textbox read-only
253
+
254
+ analyze_button.click(
255
+ fn=analyze_outfit,
256
+ inputs=[input_image],
257
+ outputs=[category_html, audio_output, response_box],
258
+ # show_progress="full" # Optional: Show progress bar during processing
259
+ )
260
+
261
+ gr.Markdown("<p style='text-align: center; color: grey; font-size: 0.9em;'>Upload an image of your outfit and click 'Analyze Outfit'. DripAI will rate your style and identify a key clothing item.</p>")
262
 
263
+ # --- Launch App ---
264
  if __name__ == "__main__":
265
+ demo.launch(debug=True) # Enable debug for more detailed logs