Aadharsh commited on
Commit
4a0d425
·
verified ·
1 Parent(s): b50aac3

created app.py

Browse files
Files changed (1) hide show
  1. app.py +414 -0
app.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import cv2
7
+ import base64
8
+ from PIL import Image
9
+ from openai import OpenAI
10
+ import json
11
+ import ast
12
+ import io
13
+ import gradio as gr
14
+ import numpy as np
15
+ import gradio as gr
16
+
17
+ api_token = os.getenv("openai_key")
18
+
19
+
20
+ openai_apikey =api_token
21
+ client = OpenAI(api_key=openai_apikey)
22
+
23
+
24
+ labels = {
25
+ 0: "Background", 1: "Hat", 2: "Hair", 3: "Sunglasses", 4: "Upper-clothes",
26
+ 5: "Skirt", 6: "Pants", 7: "Dress", 8: "Belt", 9: "Left-shoe",
27
+ 10: "Right-shoe", 11: "Face", 12: "Left-leg", 13: "Right-leg",
28
+ 14: "Left-arm", 15: "Right-arm", 16: "Bag", 17: "Scarf"
29
+ }
30
+
31
+
32
+
33
+ seg_processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
34
+ seg_model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
35
+
36
+
37
+
38
+ def encode_image(image_path):
39
+ with open(image_path, "rb") as image_file:
40
+ return base64.b64encode(image_file.read()).decode("utf-8")
41
+
42
+
43
+ def encode_pil_image(image):
44
+ buffered = io.BytesIO()
45
+ format = image.format if image.format else "PNG"
46
+ image.save(buffered, format=format)
47
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
48
+
49
+
50
+
51
+
52
+ def get_gpt_response(prompt,image=None,model=None,JSON=True):
53
+
54
+ if model==None:
55
+ model = "gpt-3.5-turbo"
56
+
57
+ content_list = [ { "type": "text", "text": prompt}]
58
+ if image is not None:
59
+ content_list.append( {"type": "image_url","image_url": {"url": f"data:image/jpeg;base64,{image}", }})
60
+
61
+ completion = client.chat.completions.create(
62
+ model=model,
63
+ response_format= {
64
+ "type": "json_object"
65
+ },
66
+ messages=[{ "role": "user","content": content_list }]
67
+ )
68
+
69
+ return completion.choices[0].message.content
70
+
71
+
72
+
73
+ def get_segmentation(img):
74
+
75
+ inputs = seg_processor(images=img, return_tensors="pt")
76
+ outputs = seg_model(**inputs)
77
+ logits = outputs.logits.cpu()
78
+
79
+ upsampled_logits = nn.functional.interpolate(
80
+ logits,
81
+ size=img.size[::-1],
82
+ mode="bilinear",
83
+ align_corners=False,
84
+ )
85
+
86
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
87
+ pred_seg = pred_seg.cpu().numpy()
88
+
89
+ #Classify any "skin" below upper clothes as upper clothes.
90
+ upper_clothes_mask = (pred_seg == 4)
91
+ min_upper_clothes_row = np.argmax(upper_clothes_mask, axis=0)
92
+ face_mask = (pred_seg == 11)
93
+ rows, cols = np.where(face_mask)
94
+ rows_below_upper_clothes = rows > min_upper_clothes_row[cols]
95
+ pred_seg[rows[rows_below_upper_clothes], cols[rows_below_upper_clothes]] = 4
96
+
97
+
98
+ #get active labels
99
+ active_labels = {key: labels[key] for key in np.unique(pred_seg) if key in labels}
100
+ active_labels
101
+ l = ""
102
+ for key,value in active_labels.items():
103
+ l+= f"{value}: {key} \n"
104
+
105
+ return pred_seg, l
106
+
107
+
108
+
109
+ def erase_regions(model_image, pred_seg, parsed_erasure_labels):
110
+
111
+ image_array = np.array(model_image)
112
+
113
+ resized_pred_seg = cv2.resize(pred_seg, (image_array.shape[1], image_array.shape[0]), interpolation=cv2.INTER_NEAREST)
114
+ erasure_mask = np.isin(resized_pred_seg, parsed_erasure_labels)
115
+
116
+ image_array[erasure_mask] = [128, 128, 128]
117
+ output_image = Image.fromarray(image_array)
118
+
119
+ return output_image
120
+
121
+
122
+
123
+ def _get_detect_prompt(labels,garment_desc):
124
+ detect_prompt = f"""
125
+ Analyze the provided garment description and the human model’s segmentation to determine which regions should be blacked out for a virtual try on task.
126
+
127
+ #### Segmentation Labels
128
+ {labels}
129
+
130
+ #### Rules for Blacking Out Regions
131
+ - **Upper-body garments** (shirts, blouses, jackets): Black out Upper-clothes (4). If it has sleeves (short/long), also black out Left-arm (14) and Right-arm (15).
132
+ - **Lower-body garments** (pants, skirts, shorts): Black out Pants (6). If full-length, also black out Left-leg (12) and Right-leg (13).
133
+ - **Dresses/Jumpsuits**: Black out Upper-clothes (4) and Pants (6). If long-sleeved, add Left-arm (14) and Right-arm (15). If full-length, add Left-leg (12) and Right-leg (13).
134
+ - **Shoes**: Always black out Left-shoe (9) and Right-shoe (10).
135
+ - **Additional rules**:
136
+ - Sleeveless garments: Only black out Upper-clothes (4); keep arms visible.
137
+ - Shorts/Mini-skirts: Only black out Pants (6); keep legs visible.
138
+ - Transparency: Ignore; follow standard rules.
139
+ - Overlapping items: Prioritize the visible garment (e.g., if a dress is worn, black out Pants).
140
+ - Garments with sleeves: Remove the corresponding hand.
141
+ - **Never modify**: Background (0), Hair (2), Face (11).
142
+ - If it's a full-body garment remove upper-clothes (4) and left-leg (12) and right-leg (13) and pants (6)
143
+ - Always check to include pants or not in your reasoning
144
+ - Black out Arms even if its short sleeve
145
+
146
+ Follow these rules precisely and return only the required segmentation labels.
147
+ #### Output Format
148
+ STRICTLY GIVE JSON ONLY, with the follownig schema:
149
+ {{
150
+ - **reasoning**: The reasoning for the decision.
151
+ - **remove_decision**: A comma-separated list of binary values (0 or 1) indicating whether the corresponding segmentation labels should be blacked out or not.
152
+ }}
153
+
154
+
155
+
156
+ #### Examples
157
+
158
+ #### **Long-Sleeve Shirt**
159
+ {{'Garment Type': 'shirt', 'Garment Type Category': 'upper body garment', 'Coverage Areas': ['torso', 'arms'], 'Sleeves': 'long', 'Leg Coverage': 'none', 'Special Features': {{'hood': 'no', 'gloves': 'no', 'transparency': 'no', 'cut-outs': 'no'}}, 'Description': 'A long-sleeve shirt that covers the torso and arms.'}}
160
+
161
+ - **Reasoning**: Black out Upper-clothes (4) because it covers the torso. Since it has long sleeves, both arms (14, 15) are also blacked out.
162
+ - **Output**: {{"reasoning": "Black out Upper-clothes (4) for torso. Long sleeves, so black out Left-arm (14) and Right-arm (15).", "remove_decision": [0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,1,0,0]}}
163
+
164
+ #### **Full-Body Jumpsuit with Hood & Gloves**
165
+ {{'Garment Type': 'jumpsuit', 'Garment Type Category': 'full body garment', 'Coverage Areas': ['torso', 'arms', 'legs', 'hands', 'head'], 'Sleeves': 'long', 'Leg Coverage': 'full-length', 'Special Features': {{'hood': 'yes', 'gloves': 'yes', 'transparency': 'no', 'cut-outs': 'no'}}, 'Description': 'A full-body jumpsuit with a hood and gloves.'}}
166
+
167
+ - **Reasoning**: Black out Upper-clothes (4) for the torso, Pants (6) for the leg portion, and full-length leg coverage requires blacking out Left-leg (12) and Right-leg (13). Since the sleeves are long, black out both arms (14, 15).
168
+ - **Output**: {{"reasoning": "Black out Upper-clothes (4) for torso, Pants (6) for legs, Left-leg (12), Right-leg (13) for full-length legs, Left-arm (14) and Right-arm (15) for long sleeves.", "remove_Decision": [0,0,0,0,1,0,1,0,0,0,0,0,1,1,1,1,0,0]}}
169
+
170
+ #### **Short-Sleeve Crop Top**
171
+ {{'Garment Type': 'crop top', 'Garment Type Category': 'upper body garment', 'Coverage Areas': ['torso'], 'Sleeves': 'short', 'Leg Coverage': 'none', 'Special Features': {{'hood': 'no', 'gloves': 'no', 'transparency': 'no', 'cut-outs': 'no'}}, 'Description': 'A pink sequined crop top featuring a sleeveless design that covers the upper torso.'}}
172
+
173
+ - **Reasoning**: Black out Upper-clothes (4) because it covers the upper torso. Since the top has short sleeves, both arms (14, 15) are blacked out.
174
+ - **Output**: {{"reasoning": "Black out Upper-clothes (4) for torso. Keep arms visible since it's short-sleeve.", "remove_decision": [0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,1,0,0]}}
175
+
176
+ #### **Ankle-Length Dress with Long Sleeves**
177
+ {{'Garment Type': 'dress', 'Garment Type Category': 'full body garment', 'Coverage Areas': ['torso', 'arms', 'legs'], 'Sleeves': 'long', 'Leg Coverage': 'full-length', 'Special Features': {{'hood': 'no', 'gloves': 'no', 'transparency': 'no', 'cut-outs': 'no'}}, 'Description': 'An ankle-length dress with long sleeves that covers the torso and legs.'}}
178
+
179
+ - **Reasoning**: Black out Upper-clothes (4) and Pants (6) because the dress covers both the torso and legs. Since it has long sleeves, black out Left-arm (14) and Right-arm (15). Since it's ankle-length, also black out Left-leg (12) and Right-leg (13).
180
+ - **Output**: {{"reasoning": "Black out Upper-clothes (4) and Pants (6) for torso and legs, Left-leg (12), Right-leg (13) for full-length legs, Left-arm (14) and Right-arm (15) for long sleeves.", "remove_Decision": [0,0,0,0,1,0,1,0,0,0,0,0,1,1,1,1,0,0]}}
181
+
182
+ #### **Mini Skirt**
183
+ {{'Garment Type': 'skirt', 'Garment Type Category': 'lower body garment', 'Coverage Areas': ['lower torso'], 'Sleeves': 'none', 'Leg Coverage': 'short', 'Special Features': {{'hood': 'no', 'gloves': 'no', 'transparency': 'no', 'cut-outs': 'no'}}, 'Description': 'A mini skirt that covers the lower torso.'}}
184
+
185
+ - **Reasoning**: Black out Pants (6) because the mini skirt covers the lower torso. Since it's short in length, the legs remain visible and are not blacked out.
186
+ - **Output**: {{"reasoning": "Black out Pants (6) for lower torso. Keep legs visible since it's short length.", "remove_decision": [0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0]}}
187
+
188
+ #### **Sneakers**
189
+ {{'Garment Type': 'shoes', 'Garment Type Category': 'footwear', 'Coverage Areas': ['feet'], 'Sleeves': 'none', 'Leg Coverage': 'none', 'Special Features': {{'hood': 'no', 'gloves': 'no', 'transparency': 'no', 'cut-outs': 'no'}}, 'Description': 'A pair of sneakers for footwear.'}}
190
+
191
+ - **Reasoning**: Always black out Left-shoe (9) and Right-shoe (10) since they are footwear.
192
+ - **Output**: {{"reasoning": "Black out Left-shoe (9) and Right-shoe (10) for footwear.", "remove_decision": [0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0]}}
193
+
194
+ ### Slip Dress
195
+ {{
196
+ "Garment Type": "slip dress", "Garment Type Category": "dress", "Coverage Areas": ["torso", "upper thighs"], "Sleeves": "none", "Leg Coverage": "partial (upper thighs)",
197
+ "Special Features": {{
198
+ "hood": "no",
199
+ "gloves": "no",
200
+ "transparency": "no",
201
+ "cut-outs": "no"
202
+ }},
203
+ "Description": "A sleeveless slip dress with a red bodice and a pink sequined skirt. It features a plunging neckline and a fitted design that extends to the upper thighs."
204
+ }}
205
+ Output: {{ "reasoning": "Black out Upper-clothes (4) and Pants (6) because the dress covers both the torso and upper thighs. Since it is sleeveless, arms (14, 15) remain visible. Legs (12, 13) are not blacked out because the dress is short and does not cover them fully.",
206
+ "remove_decision": [0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0]}}
207
+
208
+
209
+ {garment_desc}
210
+
211
+ - **Reasoning**:
212
+ - **Output**:
213
+ """
214
+
215
+ return detect_prompt
216
+
217
+
218
+
219
+
220
+ def _get_garment_prompt():
221
+
222
+ garment_prompt = """
223
+ Analyze the given garment image and provide a **detailed structured description**. Always give output only in JSON. Focus on:
224
+
225
+ 1. **Garment Type**: Clearly state what type of garment it is. Possible types:
226
+ - "shirt"
227
+ - "dress"
228
+ - "jacket"
229
+ - "pants"
230
+ - "gloves"
231
+ - "sweater"
232
+ - "skirt"
233
+ - "shorts"
234
+ - "vest"
235
+ - "jumpsuit"
236
+ - "coat"
237
+ - "blouse"
238
+ - "t-shirt"
239
+ - "crop top"
240
+
241
+ 2. **Garment Type Category**: Specify the category of the garment. Possible categories:
242
+ - "upper body garment"
243
+ - "lower body garment"
244
+ - "full body garment"
245
+
246
+ 3. **Coverage Areas**: Specify which body parts the garment covers. Possible areas:
247
+ - "torso"
248
+ - "arms"
249
+ - "legs"
250
+ - "hands"
251
+ - "head"
252
+
253
+ 4. **Sleeves & Length**: If the garment has sleeves, specify if they are:
254
+ - "short"
255
+ - "long"
256
+ - "none"
257
+
258
+ 5. **Leg Coverage**: If the garment covers the legs, specify if it's:
259
+ - "full-length"
260
+ - "knee-length"
261
+ - "short"
262
+ - "none"
263
+
264
+ 6. **Special Features**: Mention any additional details such as:
265
+ - **Hood** → If the garment includes a hood, covering the head. (Possible values: "yes", "no")
266
+ - **Gloves** → If the garment has built-in gloves, covering hands. (Possible values: "yes", "no")
267
+ - **Transparency** → If any part of the garment is see-through (e.g., mesh, lace). (Possible values: "yes", "no")
268
+ - **Cut-outs** → If the garment has openings exposing skin (e.g., backless, ripped areas). (Possible values: "yes", "no")
269
+
270
+ 7. **Description**: Provide a short textual description of the garment, summarizing its appearance, coverage, type, length, style, and key features.
271
+
272
+
273
+ ### **Example Outputs:**
274
+
275
+ #### **Long-Sleeve Shirt**
276
+ {
277
+ "Garment Type": "shirt",
278
+ "Garment Type Category": "upper body garment",
279
+ "Coverage Areas": ["torso", "arms"],
280
+ "Sleeves": "long",
281
+ "Leg Coverage": "none",
282
+ "Special Features": {
283
+ "hood": "no",
284
+ "gloves": "no",
285
+ "transparency": "no",
286
+ "cut-outs": "no"
287
+ },
288
+ Description": "A long-sleeve shirt made of cotton, providing full coverage for the torso and arms. It has a classic design with no additional features."
289
+ }
290
+
291
+ ### **Full-Body Jumpsuit with Hood & Gloves**
292
+ {
293
+ "Garment Type": "jumpsuit",
294
+ "Garment Type Category": "full body garment",
295
+ "Coverage Areas": ["torso", "arms", "legs", "hands", "head"],
296
+ "Sleeves": "long",
297
+ "Leg Coverage": "full-length",
298
+ "Special Features": {
299
+ "hood": "yes",
300
+ "gloves": "yes",
301
+ "transparency": "no",
302
+ "cut-outs": "no"
303
+ },
304
+ "Description": "A full-body jumpsuit with a hood and built-in gloves. It provides full coverage for the torso, arms, legs, hands, and head."
305
+
306
+ }
307
+
308
+ ### **Short-Sleeve Crop Top**
309
+ {
310
+ "Garment Type": "crop top",
311
+ "Garment Type Category": "upper body garment",
312
+ "Coverage Areas": ["torso"],
313
+ "Sleeves": "short",
314
+ "Leg Coverage": "none",
315
+ "Special Features": {
316
+ "hood": "no",
317
+ "gloves": "no",
318
+ "transparency": "no",
319
+ "cut-outs": "no"
320
+ },
321
+ "Description": "A casual short-sleeve crop top that covers the upper torso."
322
+ }
323
+
324
+ Output only JSON.
325
+ """
326
+ return garment_prompt
327
+
328
+
329
+ def _get_erasure_prompt(labels,reasoning):
330
+ prompt = f"""
331
+ Here are the labels:
332
+
333
+ {labels}
334
+
335
+ and Here is a reasoning:
336
+ {reasoning}
337
+
338
+ Based on the labels and reasoning I want you to output a list containing elements that are to be reomved. For example output [3,4,5,6].
339
+
340
+ output only JSON that contains a list: {{"erasure_labels":[indices]}}
341
+ """
342
+
343
+ return prompt
344
+
345
+
346
+ def fashion_masking(imgA,imgB, max_retries = 3):
347
+
348
+ tries=0
349
+
350
+ while tries <3:
351
+ try:
352
+
353
+ model_image = Image.fromarray(imgA)
354
+ dress_image = Image.fromarray(imgB)
355
+
356
+ pred_seg, l = get_segmentation(model_image)
357
+
358
+ # garment_image_path = "dress.png"
359
+
360
+ # model_image = Image.open("model_2.png").convert('RGB')
361
+ # dress_image = Image.open("dress.png")
362
+
363
+ base64_garment_image = encode_pil_image(dress_image)
364
+
365
+ #base64_garment_image = encode_image(garment_image_path)
366
+
367
+ garment_prompt = _get_garment_prompt()
368
+
369
+ garment_desc = get_gpt_response(garment_prompt,base64_garment_image, model = "gpt-4o-mini")
370
+
371
+ detect_prompt = _get_detect_prompt(garment_desc,labels)
372
+
373
+ res = get_gpt_response(detect_prompt, model = "gpt-4o-mini")
374
+ reasoning = json.loads(res)['reasoning']
375
+
376
+
377
+ erasure_labels = get_gpt_response(_get_erasure_prompt(labels,reasoning))
378
+
379
+ parsed_erasure_labels = json.loads(erasure_labels)["erasure_labels"]
380
+
381
+ erased_img = erase_regions(model_image, pred_seg, parsed_erasure_labels)
382
+
383
+ return_text = f"""Garment Description: {garment_desc} \n Reasoning : {reasoning} """
384
+
385
+ return erased_img, return_text
386
+
387
+ except Exception as e :
388
+
389
+ tries = tries+1
390
+
391
+ print(e)
392
+
393
+
394
+
395
+
396
+
397
+
398
+ demo = gr.Interface(
399
+ fn=fashion_masking, # Ensure this function returns (image, text)
400
+ inputs=[
401
+ gr.Image(label="Image A (person)", image_mode="RGB", type="numpy"),
402
+ gr.Image(label="Image B (garment)", image_mode="RGB", type="numpy"),
403
+ ],
404
+ outputs=[
405
+ gr.Image(label="Masked Output", image_mode="RGB", type="numpy"),
406
+ gr.Textbox(label="Output Description")
407
+ ],
408
+ flagging_mode='never',
409
+ )
410
+
411
+ demo.launch(share=True)
412
+
413
+
414
+