dzmu committed on
Commit
bade9fb
·
verified ·
1 Parent(s): 56b01f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -96
app.py CHANGED
@@ -8,73 +8,35 @@ from PIL import Image
8
  from ultralytics import YOLO
9
  from gtts import gTTS
10
  import uuid
11
- import time
12
  import tempfile
13
 
 
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
16
-
17
  yolo_model = YOLO('yolov8n.pt').to(device)
18
- fashion_model = YOLO('best.pt').to(device) # If needed
19
 
 
20
  style_prompts = {
21
- 'drippy': [
22
- "avant-garde streetwear",
23
- "high-fashion designer outfit",
24
- "trendsetting urban attire",
25
- "luxury sneakers and chic accessories",
26
- "cutting-edge, bold style"
27
- ],
28
- 'mid': [
29
- "casual everyday outfit",
30
- "modern minimalistic attire",
31
- "comfortable yet stylish look",
32
- "simple, relaxed streetwear",
33
- "balanced, practical fashion"
34
- ],
35
- 'not_drippy': [
36
- "disheveled outfit",
37
- "poorly coordinated fashion",
38
- "unfashionable, outdated attire",
39
- "tacky, mismatched ensemble",
40
- "sloppy, uninspired look"
41
- ]
42
  }
43
 
44
- clothing_prompts = [
45
- "t-shirt", "dress shirt", "blouse", "hoodie", "jacket", "sweater", "coat",
46
- "dress", "skirt", "pants", "jeans", "trousers", "shorts",
47
- "sneakers", "boots", "heels", "sandals",
48
- "cap", "hat", "scarf", "gloves", "bag", "accessory", "tank-top", "haircut"
49
- ]
50
 
51
  response_templates = {
52
- 'drippy': [
53
- "You're Drippy, bruh – fire {item}!",
54
- "{item} goes crazy, on god!",
55
- "Certified drippy with that {item}."
56
- ],
57
- 'mid': [
58
- "Drop the {item} and you might get a text back.",
59
- "It's alright, but I'd upgrade the {item}.",
60
- "Mid fit alert. That {item} is holding you back."
61
- ],
62
- 'not_drippy': [
63
- "Bro thought that {item} was tuff!",
64
- "Oh hell nah! Burn that {item}!",
65
- "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
66
- "Never walk out the house again with that {item}."
67
- ]
68
  }
69
 
70
- # Map "not_drippy" => "trash" in user-facing output
71
  CATEGORY_LABEL_MAP = {
72
  "drippy": "drippy",
73
  "mid": "mid",
74
  "not_drippy": "trash"
75
  }
76
 
77
- # Combine all prompts for CLIP
78
  all_prompts = []
79
  for cat_prompts in style_prompts.values():
80
  all_prompts.extend(cat_prompts)
@@ -86,13 +48,11 @@ def get_top_clothing(probs, n=3):
86
  return [clothing_prompts[i] for i in reversed(top_indices)]
87
 
88
  def analyze_outfit(img: Image.Image):
89
- # 1) YOLO detection
90
  results = yolo_model(img)
91
  boxes = results[0].boxes.xyxy.cpu().numpy()
92
  classes = results[0].boxes.cls.cpu().numpy()
93
  confidences = results[0].boxes.conf.cpu().numpy()
94
 
95
- # Crop if person is found
96
  person_indices = np.where(classes == 0)[0]
97
  cropped_img = img
98
  if len(person_indices) > 0:
@@ -100,73 +60,78 @@ def analyze_outfit(img: Image.Image):
100
  x1, y1, x2, y2 = map(int, boxes[person_indices][max_conf_idx])
101
  cropped_img = img.crop((x1, y1, x2, y2))
102
 
103
- # 2) CLIP analysis
104
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(device)
105
  text_tokens = clip.tokenize(all_prompts).to(device)
106
  with torch.no_grad():
107
  logits, _ = clip_model(image_tensor, text_tokens)
108
  probs = logits.softmax(dim=-1).cpu().numpy()[0]
109
 
110
- # Style classification
111
- drip_len = len(style_prompts['drippy'])
112
- mid_len = len(style_prompts['mid'])
113
- not_len = len(style_prompts['not_drippy'])
114
-
115
- drip_score = np.mean(probs[:drip_len])
116
- mid_score = np.mean(probs[drip_len : drip_len + mid_len])
117
- not_score = np.mean(probs[drip_len + mid_len : drip_len + mid_len + not_len])
118
-
119
- if drip_score > mid_score and drip_score > not_score:
120
- category_key = 'drippy'
121
- final_score = drip_score
122
- elif mid_score > not_score:
123
- category_key = 'mid'
124
- final_score = mid_score
125
- else:
126
- category_key = 'not_drippy'
127
- final_score = not_score
128
 
129
- category_label = CATEGORY_LABEL_MAP[category_key]
 
 
 
 
130
 
131
- # Clothing item
132
  clothing_items = get_top_clothing(probs)
133
  clothing_item = clothing_items[0]
134
-
135
- # Random response
136
  response = random.choice(response_templates[category_key]).format(item=clothing_item)
137
 
138
- # TTS MP3
139
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
140
- tts = gTTS(response, lang="en")
141
- tts.save(tts_path)
142
-
143
- # Round the score
144
- final_score_str = f"{final_score:.2f}"
145
 
146
- # Output HTML for category + numeric score
147
  category_html = f"""
148
- <h2>Your fit is {category_label}!</h2>
149
- <p>Drip Score: {final_score_str}</p>
 
 
150
  """
151
 
152
  return category_html, tts_path, response
153
 
154
- ###############################################################################
155
- # Custom Layout with Blocks
156
- ###############################################################################
157
- with gr.Blocks(css=".container {max-width: 800px; margin: 0 auto;}") as demo:
158
- gr.Markdown("## DripAI")
159
- with gr.Group(elem_classes=["container"]):
160
- input_image = gr.Image(
161
- type='pil',
162
- label="Upload your outfit"
163
- )
164
- analyze_button = gr.Button("Analyze Outfit")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- # Output components
167
  category_html = gr.HTML()
168
- audio_output = gr.Audio(autoplay=True, label="Audio Feedback")
169
- response_box = gr.Textbox(lines=3, label="Response")
170
 
171
  analyze_button.click(
172
  fn=analyze_outfit,
 
8
  from ultralytics import YOLO
9
  from gtts import gTTS
10
  import uuid
 
11
  import tempfile
12
 
13
+ # Device and model loading
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
  clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
 
16
  yolo_model = YOLO('yolov8n.pt').to(device)
17
+ fashion_model = YOLO('best.pt').to(device)
18
 
19
+ # Style prompts and templates
20
  style_prompts = {
21
+ 'drippy': [...], # truncated for brevity
22
+ 'mid': [...],
23
+ 'not_drippy': [...]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  }
25
 
26
+ clothing_prompts = [...]
 
 
 
 
 
27
 
28
  response_templates = {
29
+ 'drippy': [...],
30
+ 'mid': [...],
31
+ 'not_drippy': [...]
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  }
33
 
 
34
  CATEGORY_LABEL_MAP = {
35
  "drippy": "drippy",
36
  "mid": "mid",
37
  "not_drippy": "trash"
38
  }
39
 
 
40
  all_prompts = []
41
  for cat_prompts in style_prompts.values():
42
  all_prompts.extend(cat_prompts)
 
48
  return [clothing_prompts[i] for i in reversed(top_indices)]
49
 
50
  def analyze_outfit(img: Image.Image):
 
51
  results = yolo_model(img)
52
  boxes = results[0].boxes.xyxy.cpu().numpy()
53
  classes = results[0].boxes.cls.cpu().numpy()
54
  confidences = results[0].boxes.conf.cpu().numpy()
55
 
 
56
  person_indices = np.where(classes == 0)[0]
57
  cropped_img = img
58
  if len(person_indices) > 0:
 
60
  x1, y1, x2, y2 = map(int, boxes[person_indices][max_conf_idx])
61
  cropped_img = img.crop((x1, y1, x2, y2))
62
 
 
63
  image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(device)
64
  text_tokens = clip.tokenize(all_prompts).to(device)
65
  with torch.no_grad():
66
  logits, _ = clip_model(image_tensor, text_tokens)
67
  probs = logits.softmax(dim=-1).cpu().numpy()[0]
68
 
69
+ drip_score = np.mean(probs[:len(style_prompts['drippy'])])
70
+ mid_score = np.mean(probs[len(style_prompts['drippy']):len(style_prompts['drippy'])+len(style_prompts['mid'])])
71
+ not_score = np.mean(probs[len(style_prompts['drippy'])+len(style_prompts['mid']):])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
+ category_key = max(['drippy', 'mid', 'not_drippy'], key=lambda k: np.mean(
74
+ probs[:len(style_prompts[k])] if k == 'drippy' else
75
+ probs[len(style_prompts['drippy']):len(style_prompts['drippy'])+len(style_prompts['mid'])] if k == 'mid' else
76
+ probs[len(style_prompts['drippy'])+len(style_prompts['mid']):]
77
+ ))
78
 
79
+ category_label = CATEGORY_LABEL_MAP[category_key]
80
  clothing_items = get_top_clothing(probs)
81
  clothing_item = clothing_items[0]
 
 
82
  response = random.choice(response_templates[category_key]).format(item=clothing_item)
83
 
 
84
  tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
85
+ gTTS(response, lang="en").save(tts_path)
86
+ final_score_str = f"{max(drip_score, mid_score, not_score):.2f}"
 
 
 
87
 
 
88
  category_html = f"""
89
+ <div style='text-align: center;'>
90
+ <h2 style='color: #1f04ff;'>Your fit is <b>{category_label.upper()}</b></h2>
91
+ <p style='font-size: 18px;'>Drip Score: <strong>{final_score_str}</strong></p>
92
+ </div>
93
  """
94
 
95
  return category_html, tts_path, response
96
 
97
+ # Gradio interface with cleaner styling
98
+ custom_css = """
99
+ .container {
100
+ max-width: 700px;
101
+ margin: 0 auto;
102
+ font-family: 'Arial', sans-serif;
103
+ }
104
+ button {
105
+ background-color: #1f04ff;
106
+ color: white;
107
+ border-radius: 6px;
108
+ padding: 10px 20px;
109
+ font-size: 16px;
110
+ }
111
+ button:hover {
112
+ background-color: #3c2fff;
113
+ }
114
+ .gradio-container {
115
+ background: #f9f9f9;
116
+ border-radius: 10px;
117
+ padding: 20px;
118
+ box-shadow: 0 4px 10px rgba(0,0,0,0.1);
119
+ }
120
+ """
121
+
122
+ with gr.Blocks(css=custom_css) as demo:
123
+ with gr.Column(elem_classes=["container"]):
124
+ gr.Markdown("""
125
+ # 👟 DripAI
126
+ Upload your outfit to get judged by the algorithm.
127
+ No bias. No mercy. Just drip.
128
+ """)
129
+ input_image = gr.Image(type='pil', label="Upload your outfit")
130
+ analyze_button = gr.Button("Analyze My Fit")
131
 
 
132
  category_html = gr.HTML()
133
+ audio_output = gr.Audio(autoplay=True, label="AI Feedback")
134
+ response_box = gr.Textbox(lines=2, label="Generated Response")
135
 
136
  analyze_button.click(
137
  fn=analyze_outfit,