dzmu committed on
Commit
ca2d7f0
·
verified ·
1 Parent(s): bade9fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -71
app.py CHANGED
@@ -10,37 +10,52 @@ from gtts import gTTS
10
  import uuid
11
  import tempfile
12
 
13
- # Device and model loading
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
16
  yolo_model = YOLO('yolov8n.pt').to(device)
17
  fashion_model = YOLO('best.pt').to(device)
18
 
19
- # Style prompts and templates
20
  style_prompts = {
21
- 'drippy': [...], # truncated for brevity
22
- 'mid': [...],
23
- 'not_drippy': [...]
 
 
 
 
 
 
 
 
 
24
  }
25
 
26
- clothing_prompts = [...]
 
 
 
 
27
 
28
  response_templates = {
29
- 'drippy': [...],
30
- 'mid': [...],
31
- 'not_drippy': [...]
 
 
 
 
 
 
 
 
 
 
32
  }
33
 
34
- CATEGORY_LABEL_MAP = {
35
- "drippy": "drippy",
36
- "mid": "mid",
37
- "not_drippy": "trash"
38
- }
39
-
40
- all_prompts = []
41
- for cat_prompts in style_prompts.values():
42
- all_prompts.extend(cat_prompts)
43
- all_prompts.extend(clothing_prompts)
44
 
45
  def get_top_clothing(probs, n=3):
46
  clothing_probs = probs[len(all_prompts) - len(clothing_prompts):]
@@ -61,20 +76,28 @@ def analyze_outfit(img: Image.Image):
61
  cropped_img = img.crop((x1, y1, x2, y2))
62
 
63
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(device)
64
- text_tokens = clip.tokenize(all_prompts).to(device)
65
  with torch.no_grad():
66
  logits, _ = clip_model(image_tensor, text_tokens)
67
  probs = logits.softmax(dim=-1).cpu().numpy()[0]
68
 
69
- drip_score = np.mean(probs[:len(style_prompts['drippy'])])
70
- mid_score = np.mean(probs[len(style_prompts['drippy']):len(style_prompts['drippy'])+len(style_prompts['mid'])])
71
- not_score = np.mean(probs[len(style_prompts['drippy'])+len(style_prompts['mid']):])
72
-
73
- category_key = max(['drippy', 'mid', 'not_drippy'], key=lambda k: np.mean(
74
- probs[:len(style_prompts[k])] if k == 'drippy' else
75
- probs[len(style_prompts['drippy']):len(style_prompts['drippy'])+len(style_prompts['mid'])] if k == 'mid' else
76
- probs[len(style_prompts['drippy'])+len(style_prompts['mid']):]
77
- ))
 
 
 
 
 
 
 
 
78
 
79
  category_label = CATEGORY_LABEL_MAP[category_key]
80
  clothing_items = get_top_clothing(probs)
@@ -83,60 +106,38 @@ def analyze_outfit(img: Image.Image):
83
 
84
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
85
  gTTS(response, lang="en").save(tts_path)
86
- final_score_str = f"{max(drip_score, mid_score, not_score):.2f}"
87
 
 
88
  category_html = f"""
89
- <div style='text-align: center;'>
90
- <h2 style='color: #1f04ff;'>Your fit is <b>{category_label.upper()}</b></h2>
91
- <p style='font-size: 18px;'>Drip Score: <strong>{final_score_str}</strong></p>
92
  </div>
93
  """
94
 
95
  return category_html, tts_path, response
96
 
97
- # Gradio interface with cleaner styling
98
- custom_css = """
99
- .container {
100
- max-width: 700px;
101
- margin: 0 auto;
102
- font-family: 'Arial', sans-serif;
103
- }
104
- button {
105
- background-color: #1f04ff;
106
- color: white;
107
- border-radius: 6px;
108
- padding: 10px 20px;
109
- font-size: 16px;
110
- }
111
- button:hover {
112
- background-color: #3c2fff;
113
- }
114
- .gradio-container {
115
- background: #f9f9f9;
116
- border-radius: 10px;
117
- padding: 20px;
118
- box-shadow: 0 4px 10px rgba(0,0,0,0.1);
119
- }
120
- """
121
-
122
- with gr.Blocks(css=custom_css) as demo:
123
- with gr.Column(elem_classes=["container"]):
124
- gr.Markdown("""
125
- # 👟 DripAI
126
- Upload your outfit to get judged by the algorithm.
127
- No bias. No mercy. Just drip.
128
- """)
129
- input_image = gr.Image(type='pil', label="Upload your outfit")
130
- analyze_button = gr.Button("Analyze My Fit")
131
 
132
  category_html = gr.HTML()
133
- audio_output = gr.Audio(autoplay=True, label="AI Feedback")
134
- response_box = gr.Textbox(lines=2, label="Generated Response")
135
 
136
  analyze_button.click(
137
  fn=analyze_outfit,
138
- inputs=[input_image],
139
  outputs=[category_html, audio_output, response_box],
140
  )
141
 
142
- demo.launch()
 
 
10
  import uuid
11
  import tempfile
12
 
13
# Device and model setup: use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP scores image crops against the text prompts defined below.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

# Two YOLO detectors: stock yolov8n weights plus a custom checkpoint
# ('best.pt' — presumably fine-tuned for fashion items; confirm with training repo).
yolo_model = YOLO('yolov8n.pt').to(device)
fashion_model = YOLO('best.pt').to(device)
18
 
19
# Text prompts describing each style tier for CLIP zero-shot scoring.
# NOTE: dict insertion order (drippy, mid, not_drippy) matters — later
# code slices the CLIP probability vector by these list lengths in order.
style_prompts = {
    'drippy': [
        "avant-garde streetwear",
        "high-fashion designer outfit",
        "trendsetting urban attire",
        "luxury sneakers and chic accessories",
        "cutting-edge, bold style",
    ],
    'mid': [
        "casual everyday outfit",
        "modern minimalistic attire",
        "comfortable yet stylish look",
        "simple, relaxed streetwear",
        "balanced, practical fashion",
    ],
    'not_drippy': [
        "disheveled outfit",
        "poorly coordinated fashion",
        "unfashionable, outdated attire",
        "tacky, mismatched ensemble",
        "sloppy, uninspired look",
    ],
}
34
 
35
# Garment vocabulary for CLIP zero-shot tagging. These are appended
# after the style prompts when the joint prompt list is built, so they
# occupy the final len(clothing_prompts) slots of the probability vector.
clothing_prompts = [
    "t-shirt", "dress shirt", "blouse", "hoodie", "jacket",
    "sweater", "coat", "dress", "skirt", "pants",
    "jeans", "trousers", "shorts", "sneakers", "boots",
    "heels", "sandals", "cap", "hat", "scarf",
    "gloves", "bag", "accessory", "tank-top", "haircut",
]
40
 
41
# Canned spoken/written feedback per style tier. Every template carries
# exactly one {item} placeholder filled with a detected clothing item.
response_templates = {
    'drippy': [
        "You're Drippy, bruh – fire {item}!",
        "{item} goes crazy, on god!",
        "Certified drippy with that {item}.",
    ],
    'mid': [
        "Drop the {item} and you might get a text back.",
        "It's alright, but I'd upgrade the {item}.",
        "Mid fit alert. That {item} is holding you back.",
    ],
    'not_drippy': [
        "Bro thought that {item} was tuff!",
        "Oh hell nah! Burn that {item}!",
        "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
        "Never walk out the house again with that {item}.",
    ],
}
56
 
57
# Display labels for each internal category key ("not_drippy" is shown
# to the user as "trash").
CATEGORY_LABEL_MAP = {"drippy": "drippy", "mid": "mid", "not_drippy": "trash"}

# Joint prompt list fed to CLIP: every style prompt (in dict order),
# then the clothing prompts — slice positions downstream rely on this order.
all_prompts = [prompt for tier in style_prompts.values() for prompt in tier] + clothing_prompts
 
 
 
 
 
 
 
 
59
 
60
  def get_top_clothing(probs, n=3):
61
  clothing_probs = probs[len(all_prompts) - len(clothing_prompts):]
 
76
  cropped_img = img.crop((x1, y1, x2, y2))
77
 
78
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(device)
79
+ text_tokens = clip.tokenize([str(p) for p in all_prompts]).to(device)
80
  with torch.no_grad():
81
  logits, _ = clip_model(image_tensor, text_tokens)
82
  probs = logits.softmax(dim=-1).cpu().numpy()[0]
83
 
84
+ drip_len = len(style_prompts['drippy'])
85
+ mid_len = len(style_prompts['mid'])
86
+ not_len = len(style_prompts['not_drippy'])
87
+
88
+ drip_score = np.mean(probs[:drip_len])
89
+ mid_score = np.mean(probs[drip_len:drip_len + mid_len])
90
+ not_score = np.mean(probs[drip_len + mid_len:drip_len + mid_len + not_len])
91
+
92
+ if drip_score > mid_score and drip_score > not_score:
93
+ category_key = 'drippy'
94
+ final_score = drip_score
95
+ elif mid_score > not_score:
96
+ category_key = 'mid'
97
+ final_score = mid_score
98
+ else:
99
+ category_key = 'not_drippy'
100
+ final_score = not_score
101
 
102
  category_label = CATEGORY_LABEL_MAP[category_key]
103
  clothing_items = get_top_clothing(probs)
 
106
 
107
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
108
  gTTS(response, lang="en").save(tts_path)
 
109
 
110
+ final_score_str = f"{final_score:.2f}"
111
  category_html = f"""
112
+ <div style='padding:1rem; text-align:center;'>
113
+ <h2 style='margin-bottom:0.5rem;'>Your fit is <span style='color:#1f04ff'>{category_label}</span>!</h2>
114
+ <p style='font-size:1.1rem;'>Drip Score: <strong>{final_score_str}</strong></p>
115
  </div>
116
  """
117
 
118
  return category_html, tts_path, response
119
 
120
# Gradio UI: the user uploads or snaps an outfit photo, clicks the
# button, and analyze_outfit returns the HTML verdict, an autoplaying
# TTS clip, and the raw response text.
_CUSTOM_CSS = """
.container { max-width: 600px; margin: 0 auto; padding: 2rem; }
button { background-color: #1f04ff; color: white; border: none; padding: 0.75rem 1.5rem; border-radius: 6px; cursor: pointer; font-weight: bold; }
button:hover { background-color: #1500cc; }
#resultbox { border: 1px solid #e3e3e3; border-radius: 10px; padding: 1rem; background: #fafafa; }
"""

with gr.Blocks(css=_CUSTOM_CSS) as demo:
    with gr.Group(elem_classes=["container"]):
        input_image = gr.Image(
            type='pil',
            sources=['upload', 'webcam', 'clipboard'],
            label="Upload or Snap Your Fit",
        )
        analyze_button = gr.Button("🔥 Analyze My Fit")

        # Outputs: HTML verdict, spoken feedback, and the text response.
        category_html = gr.HTML()
        audio_output = gr.Audio(autoplay=True, label="")
        response_box = gr.Textbox(label="Response", lines=2, interactive=False)

        analyze_button.click(
            fn=analyze_outfit,
            inputs=input_image,
            outputs=[category_html, audio_output, response_box],
        )

if __name__ == '__main__':
    demo.launch()