dzmu committed on
Commit
57c11ed
·
verified ·
1 Parent(s): 5b08891

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -68
app.py CHANGED
@@ -10,123 +10,129 @@ from gtts import gTTS
10
  import uuid
11
  import tempfile
12
 
13
# Setup device and models (executed once at import time).
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
yolo_model = YOLO('yolov8n.pt').to(device)    # COCO-trained detector (person = class 0)
fashion_model = YOLO('best.pt').to(device)    # NOTE(review): not referenced in this version

# CLIP text prompts describing each style tier; scored against the photo crop.
style_prompts = {
    'drippy': [
        "avant-garde streetwear",
        "high-fashion designer outfit",
        "trendsetting urban attire",
        "luxury sneakers and chic accessories",
        "cutting-edge, bold style",
    ],
    'mid': [
        "casual everyday outfit",
        "modern minimalistic attire",
        "comfortable yet stylish look",
        "simple, relaxed streetwear",
        "balanced, practical fashion",
    ],
    'not_drippy': [
        "disheveled outfit",
        "poorly coordinated fashion",
        "unfashionable, outdated attire",
        "tacky, mismatched ensemble",
        "sloppy, uninspired look",
    ],
}

# Individual garment labels CLIP can attribute the look to.
clothing_prompts = [
    "t-shirt", "dress shirt", "blouse", "hoodie", "jacket", "sweater", "coat",
    "dress", "skirt", "pants", "jeans", "trousers", "shorts", "sneakers",
    "boots", "heels", "sandals", "cap", "hat", "scarf", "gloves", "bag",
    "accessory", "tank-top", "haircut",
]

# Canned replies per tier; "{item}" is substituted with the detected garment.
response_templates = {
    'drippy': [
        "You're Drippy, bruh – fire {item}!",
        "{item} goes crazy, on god!",
        "Certified drippy with that {item}.",
    ],
    'mid': [
        "Drop the {item} and you might get a text back.",
        "It's alright, but I'd upgrade the {item}.",
        "Mid fit alert. That {item} is holding you back.",
    ],
    'not_drippy': [
        "Bro thought that {item} was tuff!",
        "Oh hell nah! Burn that {item}!",
        "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
        "Never walk out the house again with that {item}.",
    ],
}

# Display name shown to the user for each internal category key.
CATEGORY_LABEL_MAP = {"drippy": "drippy", "mid": "mid", "not_drippy": "trash"}

# Flat prompt list fed to CLIP: style prompts first, clothing prompts last
# (get_top_clothing relies on this ordering).
all_prompts = [prompt for cat in style_prompts.values() for prompt in cat] + clothing_prompts
37
 
38
def get_top_clothing(probs, n=3):
    """Return the *n* most probable clothing labels, most likely first.

    `probs` is the full CLIP probability vector over `all_prompts`; the
    clothing entries occupy the tail of that vector.
    """
    start = len(all_prompts) - len(clothing_prompts)
    item_probs = probs[start:]
    # Full descending ranking, then take the first n.
    ranked_desc = np.argsort(item_probs)[::-1]
    return [clothing_prompts[idx] for idx in ranked_desc[:n]]
42
 
43
def analyze_outfit(img: Image.Image):
    """Rate an outfit photo.

    Returns a tuple ``(category_html, tts_path, response)``:
    rendered HTML verdict, path to a spoken mp3 of the response, and the
    response text itself.

    Pipeline: YOLO person detection -> crop to the most confident person ->
    CLIP scores the crop against style + clothing prompts -> a quip for the
    winning tier is rendered and spoken via gTTS.
    """
    # FIX: the UI's merged gr.State starts as None, so guard against an
    # empty input instead of crashing inside the model calls.
    if img is None:
        msg = "Please provide a photo first."
        return f"<div style='padding:1rem; text-align:center;'>{msg}</div>", None, msg

    results = yolo_model(img)
    boxes = results[0].boxes.xyxy.cpu().numpy()
    classes = results[0].boxes.cls.cpu().numpy()
    confidences = results[0].boxes.conf.cpu().numpy()

    # Class 0 is "person" in the COCO-trained YOLO model.
    person_indices = np.where(classes == 0)[0]
    cropped_img = img
    if len(person_indices) > 0:
        # Crop to the highest-confidence person so CLIP focuses on the outfit.
        max_conf_idx = np.argmax(confidences[person_indices])
        x1, y1, x2, y2 = map(int, boxes[person_indices][max_conf_idx])
        cropped_img = img.crop((x1, y1, x2, y2))

    image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(device)
    text_tokens = clip.tokenize([str(p) for p in all_prompts]).to(device)
    with torch.no_grad():
        logits, _ = clip_model(image_tensor, text_tokens)
        probs = logits.softmax(dim=-1).cpu().numpy()[0]

    # Mean probability per style tier; prompts are laid out tier by tier
    # at the front of all_prompts, clothing prompts at the tail.
    drip_len = len(style_prompts['drippy'])
    mid_len = len(style_prompts['mid'])
    not_len = len(style_prompts['not_drippy'])

    drip_score = np.mean(probs[:drip_len])
    mid_score = np.mean(probs[drip_len:drip_len + mid_len])
    not_score = np.mean(probs[drip_len + mid_len:drip_len + mid_len + not_len])

    if drip_score > mid_score and drip_score > not_score:
        category_key = 'drippy'
        final_score = drip_score
    elif mid_score > not_score:
        category_key = 'mid'
        final_score = mid_score
    else:
        category_key = 'not_drippy'
        final_score = not_score

    category_label = CATEGORY_LABEL_MAP[category_key]
    clothing_items = get_top_clothing(probs)
    clothing_item = clothing_items[0]
    response = random.choice(response_templates[category_key]).format(item=clothing_item)

    # Unique filename avoids clashes between concurrent requests.
    tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
    gTTS(response, lang="en").save(tts_path)

    final_score_str = f"{final_score:.2f}"
    category_html = f"""
    <div style='padding:1rem; text-align:center;'>
        <h2 style='margin-bottom:0.5rem;'>Your fit is <span style='color:#1f04ff'>{category_label}</span>!</h2>
        <p style='font-size:1.1rem;'>Drip Score: <strong>{final_score_str}</strong></p>
    </div>
    """

    return category_html, tts_path, response
97
-
98
# Build and launch the Gradio UI.
with gr.Blocks(css="""
.container { max-width: 600px; margin: 0 auto; padding: 2rem; }
button { background-color: #1f04ff; color: white; border: none; padding: 0.75rem 1.5rem; border-radius: 6px; cursor: pointer; font-weight: bold; }
button:hover { background-color: #1500cc; }
#resultbox { border: 1px solid #e3e3e3; border-radius: 10px; padding: 1rem; background: #fafafa; }
""") as demo:

    with gr.Group(elem_classes=["container"]):
        gr.Markdown("### Choose an input method:")

        with gr.Row():
            webcam_input = gr.Image(label="📸 Take a Photo", image_mode="RGB", type="pil", show_label=False, show_download_button=False, sources=["webcam"])
            upload_input = gr.Image(label="🖼️ Upload Photo", image_mode="RGB", type="pil", show_label=False, show_download_button=False, sources=["upload"])

        analyze_button = gr.Button("🔥 Analyze My Fit")

        # Placeholder outputs
        category_html = gr.HTML()
        audio_output = gr.Audio(autoplay=True, label="")
        response_box = gr.Textbox(label="Response", lines=2, interactive=False)

        # Merge the two image inputs into one workflow
        def merged_input(image1, image2):
            return image1 if image1 is not None else image2

        # FIX: the previous implementation stashed the merged image in a
        # gr.State refreshed only by .change events, so clicking "Analyze"
        # before any change event passed None to analyze_outfit. Read both
        # inputs directly at click time instead.
        def _run(image1, image2):
            return analyze_outfit(merged_input(image1, image2))

        analyze_button.click(fn=_run, inputs=[webcam_input, upload_input], outputs=[category_html, audio_output, response_box])

if __name__ == '__main__':
    demo.launch()
132
-
 
10
  import uuid
11
  import tempfile
12
 
 
13
# Device selection and model initialisation (runs once at import time).
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
yolo_model = YOLO("yolov8n.pt").to(device)   # COCO-trained detector (person = class 0)
fashion_model = YOLO("best.pt").to(device)   # Your trained model

# CLIP text prompts describing each style tier.
style_prompts = {
    "drippy": [
        "avant-garde streetwear",
        "high-fashion designer outfit",
        "trendsetting urban attire",
        "luxury sneakers and chic accessories",
        "cutting-edge, bold style",
    ],
    "mid": [
        "casual everyday outfit",
        "modern minimalistic attire",
        "comfortable yet stylish look",
        "simple, relaxed streetwear",
        "balanced, practical fashion",
    ],
    "not_drippy": [
        "disheveled outfit",
        "poorly coordinated fashion",
        "unfashionable, outdated attire",
        "tacky, mismatched ensemble",
        "sloppy, uninspired look",
    ],
}

# Individual garment labels CLIP can attribute the look to.
clothing_prompts = [
    "t-shirt", "dress shirt", "blouse", "hoodie", "jacket", "sweater", "coat",
    "dress", "skirt", "pants", "jeans", "trousers", "shorts", "sneakers",
    "boots", "heels", "sandals", "cap", "hat", "scarf", "gloves", "bag",
    "accessory", "tank-top", "haircut",
]

# Canned replies per tier; "{item}" is substituted with the detected garment.
response_templates = {
    "drippy": [
        "You're Drippy, bruh fire {item}!",
        "{item} goes crazy, on god!",
        "Certified drippy with that {item}.",
    ],
    "mid": [
        "Drop the {item} and you might get a text back.",
        "It's alright, but I'd upgrade the {item}.",
        "Mid fit alert. That {item} is holding you back.",
    ],
    "not_drippy": [
        "Bro thought that {item} was tuff!",
        "Oh hell nah! Burn that {item}!",
        "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
        "Never walk out the house again with that {item}.",
    ],
}

# Display name shown to the user for each internal category key.
CATEGORY_LABEL_MAP = {"drippy": "drippy", "mid": "mid", "not_drippy": "trash"}

# Flat prompt list fed to CLIP: style prompts first, clothing prompts last
# (get_top_clothing relies on this ordering).
all_prompts = [p for cat in style_prompts.values() for p in cat] + clothing_prompts
57
 
58
def get_top_clothing(probs, n=3):
    """Return the *n* most probable clothing labels, most likely first.

    `probs` covers every entry of `all_prompts`; the clothing prompts sit
    at the tail of that vector, so slice them off before ranking.
    """
    tail = probs[len(all_prompts) - len(clothing_prompts):]
    order = np.argsort(tail)          # ascending by probability
    best_first = order[::-1][:n]      # top n, descending
    return [clothing_prompts[i] for i in best_first]
62
 
63
def analyze_outfit(img):
    """Rate an outfit photo.

    Returns a tuple ``(html, tts_path, response)``: rendered HTML verdict,
    path to a spoken mp3 of the response, and the response text itself.

    Pipeline: YOLO person detection -> crop to the most confident person ->
    fashion model names the top clothing item -> CLIP scores the crop against
    the style prompts -> a quip for the winning tier is rendered and spoken
    via gTTS.
    """
    # Guard: the UI can fire with an empty input (no photo taken/uploaded yet).
    if img is None:
        msg = "Please provide a photo first."
        return f"<div style='padding:1rem; text-align:center;'>{msg}</div>", None, msg

    results = yolo_model(img)
    boxes = results[0].boxes.xyxy.cpu().numpy()
    classes = results[0].boxes.cls.cpu().numpy()
    confidences = results[0].boxes.conf.cpu().numpy()

    # Class 0 is "person" in the COCO-trained YOLO model.
    person_indices = np.where(classes == 0)[0]
    cropped = img
    if len(person_indices) > 0:
        # Crop to the highest-confidence person so CLIP focuses on the outfit.
        idx = np.argmax(confidences[person_indices])
        x1, y1, x2, y2 = map(int, boxes[person_indices][idx])
        cropped = img.crop((x1, y1, x2, y2))

    # Run fashion model to get top class label
    fashion_results = fashion_model(cropped, verbose=False)
    top_item_idx = fashion_results[0].probs.top1
    top_item_name = fashion_results[0].names[int(top_item_idx)]

    # CLIP classification
    image_tensor = clip_preprocess(cropped).unsqueeze(0).to(device)
    text_tokens = clip.tokenize([str(p) for p in all_prompts]).to(device)
    with torch.no_grad():
        logits, _ = clip_model(image_tensor, text_tokens)
        probs = logits.softmax(dim=-1).cpu().numpy()[0]

    # Mean probability per style tier; prompts are laid out tier by tier at
    # the front of all_prompts, clothing prompts at the tail.
    drip_len = len(style_prompts["drippy"])
    mid_len = len(style_prompts["mid"])
    not_len = len(style_prompts["not_drippy"])

    drip_score = np.mean(probs[:drip_len])
    mid_score = np.mean(probs[drip_len:drip_len + mid_len])
    # FIX: bound the slice with not_len — an open-ended slice would also
    # average in all 25 clothing-prompt probabilities, skewing not_score.
    not_score = np.mean(probs[drip_len + mid_len:drip_len + mid_len + not_len])

    if drip_score > mid_score and drip_score > not_score:
        cat = "drippy"
        final_score = drip_score
    elif mid_score > not_score:
        cat = "mid"
        final_score = mid_score
    else:
        cat = "not_drippy"
        final_score = not_score

    label = CATEGORY_LABEL_MAP[cat]
    response = random.choice(response_templates[cat]).format(item=top_item_name)

    # Unique filename avoids clashes between concurrent requests.
    tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
    gTTS(response, lang="en").save(tts_path)

    html = f"""
    <div style='padding:1rem; text-align:center;'>
        <h2>Your fit is <span style='color:#1f04ff'>{label}</span>!</h2>
        <p>Drip Score: <strong>{final_score:.2f}</strong></p>
    </div>
    """
    return html, tts_path, response
118
 
119
# Gradio UI
with gr.Blocks(css=".container { max-width: 600px; margin: auto; padding: 2rem; }") as demo:
    with gr.Group(elem_classes=["container"]):
        method = gr.Radio(["Upload Image", "Use Webcam"], label="Choose how to submit your fit", value="Upload Image")
        upload = gr.Image(type="pil", label="Upload Image", visible=True)
        webcam = gr.Image(type="pil", label="Take Photo", visible=False, sources=["webcam"])
        analyze = gr.Button("🔥 Analyze My Fit")
        html = gr.HTML()
        audio = gr.Audio(autoplay=True, label="")
        textbox = gr.Textbox(label="Response", interactive=False, lines=2)

        def toggle_inputs(method):
            # Show only the input widget matching the selected method.
            return gr.update(visible=method == "Upload Image"), gr.update(visible=method == "Use Webcam")

        def run_analysis(upload_img, webcam_img):
            # FIX: previously two separate click handlers were registered
            # (one per image component); both fired on every click, the one
            # holding None crashed inside analyze_outfit, and the second
            # call's output overwrote the first. Pick whichever input is
            # populated and run the analysis exactly once.
            return analyze_outfit(upload_img if upload_img is not None else webcam_img)

        method.change(toggle_inputs, inputs=method, outputs=[upload, webcam])
        analyze.click(fn=run_analysis, inputs=[upload, webcam], outputs=[html, audio, textbox])

if __name__ == "__main__":
    demo.launch()