dzmu committed on
Commit
3ef3914
·
verified ·
1 Parent(s): 57bdc44

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -0
app.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import clip
4
+ import numpy as np
5
+ import random
6
+ import os
7
+ from PIL import Image
8
+ from ultralytics import YOLO
9
+ from gtts import gTTS
10
+ import uuid
11
+ import time
12
+ import tempfile
13
+
14
+ # ---- Model loading ----
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
17
+ yolo_model = YOLO('yolov8n.pt').to(device)
18
+ fashion_model = YOLO('best.pt').to(device) # Adjust the path to your custom model
19
+
20
+ # ---- Style prompts ----
21
+ style_prompts = {
22
+ 'drippy': [
23
+ "avant-garde streetwear",
24
+ "high-fashion designer outfit",
25
+ "trendsetting urban attire",
26
+ "luxury sneakers and chic accessories",
27
+ "cutting-edge, bold style"
28
+ ],
29
+ 'mid': [
30
+ "casual everyday outfit",
31
+ "modern minimalistic attire",
32
+ "comfortable yet stylish look",
33
+ "simple, relaxed streetwear",
34
+ "balanced, practical fashion"
35
+ ],
36
+ 'not_drippy': [
37
+ "disheveled outfit",
38
+ "poorly coordinated fashion",
39
+ "unfashionable, outdated attire",
40
+ "tacky, mismatched ensemble",
41
+ "sloppy, uninspired look"
42
+ ]
43
+ }
44
+
45
+ # ---- Clothing prompts + responses ----
46
+ clothing_prompts = [
47
+ "t-shirt", "dress shirt", "blouse", "hoodie", "jacket", "sweater", "coat",
48
+ "dress", "skirt", "pants", "jeans", "trousers", "shorts",
49
+ "sneakers", "boots", "heels", "sandals",
50
+ "cap", "hat", "scarf", "gloves", "bag", "accessory", "tank-top", "haircut"
51
+ ]
52
+
53
+ response_templates = {
54
+ 'drippy': [
55
+ "You're Drippy, bruh – fire {item}!",
56
+ "{item} goes crazy, on god!",
57
+ "Certified drippy with that {item}."
58
+ ],
59
+ 'mid': [
60
+ "Drop the {item} and you might get a text back.",
61
+ "It's alright, but I'd upgrade the {item}.",
62
+ "Mid fit alert. That {item} is holding you back."
63
+ ],
64
+ 'not_drippy': [
65
+ "Bro thought that {item} was tuff!",
66
+ "Oh hell nah! Burn that {item}!",
67
+ "Crimes against fashion, especially that {item}! Also… maybe get a haircut.",
68
+ "Never walk out the house again with that {item}."
69
+ ]
70
+ }
71
+
72
+ # Combine all prompts for CLIP processing
73
+ all_prompts = []
74
+ for cat_prompts in style_prompts.values():
75
+ all_prompts.extend(cat_prompts)
76
+ all_prompts.extend(clothing_prompts)
77
+
78
+ def get_top_clothing(probs, n=3):
79
+ """Retrieve top clothing items from CLIP probabilities."""
80
+ # clothing prompts are at the end of all_prompts
81
+ clothing_probs = probs[len(all_prompts) - len(clothing_prompts):]
82
+ top_indices = np.argsort(clothing_probs)[-n:]
83
+ return [clothing_prompts[i] for i in reversed(top_indices)]
84
+
85
+ # ---- The main function to analyze an uploaded image ----
86
+ def analyze_outfit(img: Image.Image):
87
+ # 1) YOLO detection to find the person region:
88
+ results = yolo_model(img)
89
+ result = results[0]
90
+ boxes = result.boxes.xyxy.cpu().numpy()
91
+ classes = result.boxes.cls.cpu().numpy()
92
+ confidences = result.boxes.conf.cpu().numpy()
93
+
94
+ # find person bounding box
95
+ person_indices = np.where(classes == 0)[0]
96
+ cropped_img = img
97
+ if len(person_indices) > 0:
98
+ max_conf_idx = np.argmax(confidences[person_indices])
99
+ x1, y1, x2, y2 = map(int, boxes[person_indices][max_conf_idx])
100
+ cropped_img = img.crop((x1, y1, x2, y2))
101
+
102
+ # 2) CLIP analysis
103
+ image_tensor = clip_preprocess(cropped_img).unsqueeze(0).to(device)
104
+ text_tokens = clip.tokenize(all_prompts).to(device)
105
+ with torch.no_grad():
106
+ logits, _ = clip_model(image_tensor, text_tokens)
107
+ probs = logits.softmax(dim=-1).cpu().numpy()[0]
108
+
109
+ # style classification
110
+ drip_len = len(style_prompts['drippy'])
111
+ mid_len = len(style_prompts['mid'])
112
+ not_len = len(style_prompts['not_drippy'])
113
+
114
+ drip_score = np.mean(probs[:drip_len])
115
+ mid_score = np.mean(probs[drip_len: drip_len + mid_len])
116
+ not_score = np.mean(probs[drip_len + mid_len: drip_len + mid_len + not_len])
117
+
118
+ if drip_score > mid_score and drip_score > not_score:
119
+ category = 'drippy'
120
+ elif mid_score > not_score:
121
+ category = 'mid'
122
+ else:
123
+ category = 'not_drippy'
124
+
125
+ # clothing items
126
+ clothing_items = get_top_clothing(probs)
127
+ clothing_item = clothing_items[0]
128
+
129
+ # response
130
+ response = random.choice(response_templates[category]).format(item=clothing_item)
131
+
132
+ # 3) (Optional) TTS: generate audio with gTTS
133
+ # Some hosting platforms won't play audio automatically;
134
+ # we'll just return an .mp3 link if you want to do that.
135
+ tts_path = os.path.join(tempfile.gettempdir(), f"drip_{uuid.uuid4().hex}.mp3")
136
+ tts = gTTS(response, lang="en")
137
+ tts.save(tts_path)
138
+
139
+ # Return text info. Gradio can handle audio outputs if needed:
140
+ # return response, tts_path
141
+ return response # Keep it simple and just return the text
142
+
143
+ # ---- Build the Gradio interface ----
144
+ demo = gr.Interface(
145
+ fn=analyze_outfit,
146
+ inputs=gr.Image(type='pil'),
147
+ outputs="text",
148
+ title="Drip Detective 3000",
149
+ description="Upload an image of your outfit to see if it's Drippy, Mid, or Not Drippy."
150
+ )
151
+
152
+ # ---- Launch if running locally ----
153
+ if __name__ == "__main__":
154
+ demo.launch(server_name="0.0.0.0", server_port=7860)