ThatITGuy commited on
Commit
612f4c4
·
verified ·
1 Parent(s): b3dac1f

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +198 -0
  2. requirements.txt +6 -0
  3. try_on_diffusion_client.py +130 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import cv2
2
+ # import gradio as gr
3
+ # import numpy as np
4
+ # import os
5
+ # from PIL import Image
6
+ # from try_on_diffusion_client import TryOnDiffusionClient # <‑ the file you just pulled
7
+
8
+ # # --- 1. Initialise the API client -----------------------------------------
9
+ # API_URL = "https://try-on-diffusion.p.rapidapi.com" # RapidAPI endpoint
10
+ # API_KEY = os.getenv("RAPIDAPI_KEY") # put your key in an env var
11
+ # client = TryOnDiffusionClient(base_url=API_URL, api_key=API_KEY)
12
+
13
+ # # --- 2. The Gradio callback ------------------------------------------------
14
+
15
+ # def try_on(user_img, outfit_img, height, chest, waist, sleeve):
16
+ # # 1️⃣ Force 3‑channel RGB and uint8 dtype
17
+ # user_rgb = user_img.convert("RGB")
18
+ # outfit_rgb = outfit_img.convert("RGB")
19
+
20
+ # avatar_bgr = cv2.cvtColor(np.array(user_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
21
+ # clothing_bgr = cv2.cvtColor(np.array(outfit_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
22
+
23
+ # # 2️⃣ Call the API (unchanged)
24
+ # resp = client.try_on_file(
25
+ # clothing_image = clothing_bgr,
26
+ # avatar_image = avatar_bgr,
27
+ # seed = -1
28
+ # )
29
+
30
+ # if resp.status_code != 200 or resp.image is None:
31
+ # return None, f"API error {resp.status_code}: {resp.error_details}"
32
+
33
+ # result_rgb = cv2.cvtColor(resp.image, cv2.COLOR_BGR2RGB)
34
+
35
+ # # 3️⃣ Simple fit message
36
+ # fit_msg = "✅ Fit looks OK"
37
+ # if chest < 80:
38
+ # fit_msg = "⚠️ Chest measurement suggests a tight fit."
39
+
40
+ # return Image.fromarray(result_rgb), fit_msg
41
+
42
+
43
+ # # --- 3. Create the Gradio Interface ------------------------------------------
44
+ # demo = gr.Interface(
45
+ # fn=try_on,
46
+ # inputs=[
47
+ # gr.Image(type="pil", label="Upload your image"),
48
+ # gr.Image(type="pil", label="Upload outfit image"),
49
+ # gr.Number(label="Height (cm)"),
50
+ # gr.Number(label="Chest (cm)"),
51
+ # gr.Number(label="Waist (cm)"),
52
+ # gr.Number(label="Sleeve length (cm)")
53
+ # ],
54
+ # outputs=[
55
+ # gr.Image(type="pil", label="Try-On Result"),
56
+ # gr.Text(label="Fit Advice")
57
+ # ],
58
+ # title="👕 Virtual Try-On",
59
+ # description="Upload your image and an outfit to see how it might look on you!"
60
+ # )
61
+
62
+ # # --- 4. Launch the app -------------------------------------------------------
63
+ # if __name__ == "__main__":
64
+ # print("Launching Virtual Try-On App...")
65
+ # demo.launch()
66
+
67
+
68
+ import cv2
69
+ import gradio as gr
70
+ import numpy as np
71
+ import os
72
+ from PIL import Image
73
+ from try_on_diffusion_client import TryOnDiffusionClient
74
+
75
+ # --- 1. Initialize the API client -----------------------------------------
76
+ API_URL = "https://try-on-diffusion.p.rapidapi.com"
77
+ API_KEY = os.getenv("RAPIDAPI_KEY") # This will be set as a secret in HF Spaces
78
+
79
+ if not API_KEY:
80
+ print("⚠️ RAPIDAPI_KEY not found. Please set it in Hugging Face Spaces secrets.")
81
+
82
+ client = TryOnDiffusionClient(base_url=API_URL, api_key=API_KEY)
83
+
84
+ # --- 2. The Gradio callback ------------------------------------------------
85
+ def try_on(user_img, outfit_img, height, chest, waist, sleeve):
86
+ """
87
+ Process the virtual try-on request
88
+ """
89
+ # Check if API key is available
90
+ if not API_KEY:
91
+ return None, "❌ API key not configured. Please contact the app administrator."
92
+
93
+ # Validate inputs
94
+ if user_img is None or outfit_img is None:
95
+ return None, "❌ Please upload both user image and outfit image."
96
+
97
+ try:
98
+ # Convert images to RGB and ensure uint8 dtype
99
+ user_rgb = user_img.convert("RGB")
100
+ outfit_rgb = outfit_img.convert("RGB")
101
+
102
+ avatar_bgr = cv2.cvtColor(np.array(user_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
103
+ clothing_bgr = cv2.cvtColor(np.array(outfit_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
104
+
105
+ # Call the API
106
+ resp = client.try_on_file(
107
+ clothing_image=clothing_bgr,
108
+ avatar_image=avatar_bgr,
109
+ seed=-1
110
+ )
111
+
112
+ if resp.status_code != 200 or resp.image is None:
113
+ error_msg = resp.error_details if resp.error_details else f"API error {resp.status_code}"
114
+ return None, f"❌ {error_msg}"
115
+
116
+ # Convert result back to RGB
117
+ result_rgb = cv2.cvtColor(resp.image, cv2.COLOR_BGR2RGB)
118
+
119
+ # Generate fit advice
120
+ fit_msg = generate_fit_advice(height, chest, waist, sleeve)
121
+
122
+ return Image.fromarray(result_rgb), fit_msg
123
+
124
+ except Exception as e:
125
+ return None, f"❌ Error processing request: {str(e)}"
126
+
127
+
128
+ def generate_fit_advice(height, chest, waist, sleeve):
129
+ """
130
+ Generate basic fit advice based on measurements
131
+ """
132
+ advice = ["✅ Fit Analysis:"]
133
+
134
+ if chest and chest < 80:
135
+ advice.append("⚠️ Chest measurement suggests the item might be tight.")
136
+ elif chest and chest > 120:
137
+ advice.append("ℹ️ Chest measurement suggests a loose fit.")
138
+ else:
139
+ advice.append("✅ Chest measurement looks good.")
140
+
141
+ if waist and chest and abs(chest - waist) < 10:
142
+ advice.append("ℹ️ Similar chest and waist measurements - consider fitted styles.")
143
+
144
+ if height and height < 160:
145
+ advice.append("ℹ️ For shorter stature, consider checking garment length.")
146
+ elif height and height > 185:
147
+ advice.append("ℹ️ For taller stature, ensure adequate garment length.")
148
+
149
+ return "\n".join(advice)
150
+
151
+
152
+ # --- 3. Create the Gradio Interface ------------------------------------------
153
+ with gr.Blocks(theme=gr.themes.Soft(), title="Virtual Try-On") as demo:
154
+ gr.Markdown(
155
+ """
156
+ # 👕 Virtual Try-On App
157
+
158
+ Upload your photo and a clothing item to see how it would look on you!
159
+
160
+ **Tips for best results:**
161
+ - Use clear, well-lit photos
162
+ - Person should be facing forward
163
+ - Clothing item should be clearly visible
164
+ """
165
+ )
166
+
167
+ with gr.Row():
168
+ with gr.Column():
169
+ user_img = gr.Image(type="pil", label="Your Photo", height=300)
170
+ outfit_img = gr.Image(type="pil", label="Clothing Item", height=300)
171
+
172
+ with gr.Column():
173
+ gr.Markdown("### 📏 Your Measurements (Optional)")
174
+ height = gr.Number(label="Height (cm)", value=170, minimum=140, maximum=220)
175
+ chest = gr.Number(label="Chest (cm)", value=90, minimum=70, maximum=140)
176
+ waist = gr.Number(label="Waist (cm)", value=80, minimum=60, maximum=130)
177
+ sleeve = gr.Number(label="Sleeve Length (cm)", value=60, minimum=50, maximum=80)
178
+
179
+ submit_btn = gr.Button("✨ Try On Outfit", variant="primary", size="lg")
180
+
181
+ with gr.Row():
182
+ result_img = gr.Image(type="pil", label="Try-On Result", height=400)
183
+ fit_advice = gr.Textbox(label="Fit Advice", lines=6, max_lines=10)
184
+
185
+ # Examples section
186
+ gr.Markdown("### 📸 Example Images")
187
+ gr.Markdown("*Upload your own images or try with sample photos*")
188
+
189
+ submit_btn.click(
190
+ fn=try_on,
191
+ inputs=[user_img, outfit_img, height, chest, waist, sleeve],
192
+ outputs=[result_img, fit_advice]
193
+ )
194
+
195
+ # --- 4. Launch the app -------------------------------------------------------
196
+ if __name__ == "__main__":
197
+ print("🚀 Launching Virtual Try-On App...")
198
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ requests
3
+ pillow
4
+ numpy
5
+ opencv-python-headless
6
+ try-on-diffusion-client
try_on_diffusion_client.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import requests
4
+ from requests_toolbelt.multipart.encoder import MultipartEncoder
5
+ from urllib.parse import urlparse
6
+ import logging
7
+ import json
8
+ from io import BytesIO
9
+ from dataclasses import dataclass
10
+
11
+
12
+ @dataclass
13
+ class TryOnDiffusionAPIResponse:
14
+ status_code: int
15
+ image: np.ndarray = None
16
+ response_data: bytes = None
17
+ error_details: str = None
18
+ seed: int = None
19
+
20
+
21
+ class TryOnDiffusionClient:
22
+ def __init__(self, base_url: str = "http://localhost:8000/", api_key: str = ""):
23
+ self._logger = logging.getLogger("try_on_diffusion_client")
24
+ self._base_url = base_url
25
+ self._api_key = api_key
26
+
27
+ if self._base_url[-1] == "/":
28
+ self._base_url = self._base_url[:-1]
29
+
30
+ parsed_url = urlparse(self._base_url)
31
+
32
+ self._rapidapi_host = parsed_url.netloc if parsed_url.netloc.endswith(".rapidapi.com") else None
33
+
34
+ if self._rapidapi_host is not None:
35
+ self._logger.info(f"Using RapidAPI proxy: {self._rapidapi_host}")
36
+
37
+ @staticmethod
38
+ def _image_to_upload_file(image: np.ndarray) -> tuple:
39
+ _, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
40
+ jpeg_data = jpeg_data.tobytes()
41
+
42
+ fp = BytesIO(jpeg_data)
43
+
44
+ return "image.jpg", fp, "image/jpeg"
45
+
46
+ def try_on_file(
47
+ self,
48
+ clothing_image: np.ndarray = None,
49
+ clothing_prompt: str = None,
50
+ avatar_image: np.ndarray = None,
51
+ avatar_prompt: str = None,
52
+ avatar_sex: str = None,
53
+ background_image: np.ndarray = None,
54
+ background_prompt: str = None,
55
+ seed: int = -1,
56
+ raw_response: bool = False,
57
+ ) -> TryOnDiffusionAPIResponse:
58
+ url = self._base_url + "/try-on-file"
59
+
60
+ request_data = {"seed": str(seed)}
61
+
62
+ if clothing_image is not None:
63
+ request_data["clothing_image"] = self._image_to_upload_file(clothing_image)
64
+
65
+ if clothing_prompt is not None:
66
+ request_data["clothing_prompt"] = clothing_prompt
67
+
68
+ if avatar_image is not None:
69
+ request_data["avatar_image"] = self._image_to_upload_file(avatar_image)
70
+
71
+ if avatar_prompt is not None:
72
+ request_data["avatar_prompt"] = avatar_prompt
73
+
74
+ if avatar_sex is not None:
75
+ request_data["avatar_sex"] = avatar_sex
76
+
77
+ if background_image is not None:
78
+ request_data["background_image"] = self._image_to_upload_file(background_image)
79
+
80
+ if background_prompt is not None:
81
+ request_data["background_prompt"] = background_prompt
82
+
83
+ multipart_data = MultipartEncoder(fields=request_data)
84
+
85
+ headers = {"Content-Type": multipart_data.content_type}
86
+
87
+ if self._rapidapi_host is not None:
88
+ headers["X-RapidAPI-Key"] = self._api_key
89
+ headers["X-RapidAPI-Host"] = self._rapidapi_host
90
+ else:
91
+ headers["X-API-Key"] = self._api_key
92
+
93
+ try:
94
+ response = requests.post(
95
+ url,
96
+ data=multipart_data,
97
+ headers=headers,
98
+ )
99
+ except Exception as e:
100
+ self._logger.error(e, exc_info=True)
101
+ return TryOnDiffusionAPIResponse(status_code=0)
102
+
103
+ if response.status_code != 200:
104
+ self._logger.warning(f"Request failed, status code: {response.status_code}, response: {response.content}")
105
+
106
+ result = TryOnDiffusionAPIResponse(status_code=response.status_code)
107
+
108
+ if not raw_response and response.status_code == 200:
109
+ try:
110
+ result.image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
111
+ except:
112
+ result.image = None
113
+ else:
114
+ result.response_data = response.content
115
+
116
+ if result.status_code == 200:
117
+ if "X-Seed" in response.headers:
118
+ result.seed = int(response.headers["X-Seed"])
119
+ else:
120
+ try:
121
+ response_json = (
122
+ json.loads(result.response_data.decode("utf-8")) if result.response_data is not None else None
123
+ )
124
+
125
+ if response_json is not None and "detail" in response_json:
126
+ result.error_details = response_json["detail"]
127
+ except:
128
+ result.error_details = None
129
+
130
+ return result