Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files- app.py +198 -0
- requirements.txt +6 -0
- try_on_diffusion_client.py +130 -0
app.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# import cv2
|
| 2 |
+
# import gradio as gr
|
| 3 |
+
# import numpy as np
|
| 4 |
+
# import os
|
| 5 |
+
# from PIL import Image
|
| 6 |
+
# from try_on_diffusion_client import TryOnDiffusionClient # <‑ the file you just pulled
|
| 7 |
+
|
| 8 |
+
# # --- 1. Initialise the API client -----------------------------------------
|
| 9 |
+
# API_URL = "https://try-on-diffusion.p.rapidapi.com" # RapidAPI endpoint
|
| 10 |
+
# API_KEY = os.getenv("RAPIDAPI_KEY") # put your key in an env var
|
| 11 |
+
# client = TryOnDiffusionClient(base_url=API_URL, api_key=API_KEY)
|
| 12 |
+
|
| 13 |
+
# # --- 2. The Gradio callback ------------------------------------------------
|
| 14 |
+
|
| 15 |
+
# def try_on(user_img, outfit_img, height, chest, waist, sleeve):
|
| 16 |
+
# # 1️⃣ Force 3‑channel RGB and uint8 dtype
|
| 17 |
+
# user_rgb = user_img.convert("RGB")
|
| 18 |
+
# outfit_rgb = outfit_img.convert("RGB")
|
| 19 |
+
|
| 20 |
+
# avatar_bgr = cv2.cvtColor(np.array(user_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
| 21 |
+
# clothing_bgr = cv2.cvtColor(np.array(outfit_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
| 22 |
+
|
| 23 |
+
# # 2️⃣ Call the API (unchanged)
|
| 24 |
+
# resp = client.try_on_file(
|
| 25 |
+
# clothing_image = clothing_bgr,
|
| 26 |
+
# avatar_image = avatar_bgr,
|
| 27 |
+
# seed = -1
|
| 28 |
+
# )
|
| 29 |
+
|
| 30 |
+
# if resp.status_code != 200 or resp.image is None:
|
| 31 |
+
# return None, f"API error {resp.status_code}: {resp.error_details}"
|
| 32 |
+
|
| 33 |
+
# result_rgb = cv2.cvtColor(resp.image, cv2.COLOR_BGR2RGB)
|
| 34 |
+
|
| 35 |
+
# # 3️⃣ Simple fit message
|
| 36 |
+
# fit_msg = "✅ Fit looks OK"
|
| 37 |
+
# if chest < 80:
|
| 38 |
+
# fit_msg = "⚠️ Chest measurement suggests a tight fit."
|
| 39 |
+
|
| 40 |
+
# return Image.fromarray(result_rgb), fit_msg
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# # --- 3. Create the Gradio Interface ------------------------------------------
|
| 44 |
+
# demo = gr.Interface(
|
| 45 |
+
# fn=try_on,
|
| 46 |
+
# inputs=[
|
| 47 |
+
# gr.Image(type="pil", label="Upload your image"),
|
| 48 |
+
# gr.Image(type="pil", label="Upload outfit image"),
|
| 49 |
+
# gr.Number(label="Height (cm)"),
|
| 50 |
+
# gr.Number(label="Chest (cm)"),
|
| 51 |
+
# gr.Number(label="Waist (cm)"),
|
| 52 |
+
# gr.Number(label="Sleeve length (cm)")
|
| 53 |
+
# ],
|
| 54 |
+
# outputs=[
|
| 55 |
+
# gr.Image(type="pil", label="Try-On Result"),
|
| 56 |
+
# gr.Text(label="Fit Advice")
|
| 57 |
+
# ],
|
| 58 |
+
# title="👕 Virtual Try-On",
|
| 59 |
+
# description="Upload your image and an outfit to see how it might look on you!"
|
| 60 |
+
# )
|
| 61 |
+
|
| 62 |
+
# # --- 4. Launch the app -------------------------------------------------------
|
| 63 |
+
# if __name__ == "__main__":
|
| 64 |
+
# print("Launching Virtual Try-On App...")
|
| 65 |
+
# demo.launch()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
import cv2
|
| 69 |
+
import gradio as gr
|
| 70 |
+
import numpy as np
|
| 71 |
+
import os
|
| 72 |
+
from PIL import Image
|
| 73 |
+
from try_on_diffusion_client import TryOnDiffusionClient
|
| 74 |
+
|
| 75 |
+
# --- 1. Initialize the API client -----------------------------------------
|
| 76 |
+
API_URL = "https://try-on-diffusion.p.rapidapi.com"
|
| 77 |
+
API_KEY = os.getenv("RAPIDAPI_KEY") # This will be set as a secret in HF Spaces
|
| 78 |
+
|
| 79 |
+
if not API_KEY:
|
| 80 |
+
print("⚠️ RAPIDAPI_KEY not found. Please set it in Hugging Face Spaces secrets.")
|
| 81 |
+
|
| 82 |
+
client = TryOnDiffusionClient(base_url=API_URL, api_key=API_KEY)
|
| 83 |
+
|
| 84 |
+
# --- 2. The Gradio callback ------------------------------------------------
|
| 85 |
+
def try_on(user_img, outfit_img, height, chest, waist, sleeve):
|
| 86 |
+
"""
|
| 87 |
+
Process the virtual try-on request
|
| 88 |
+
"""
|
| 89 |
+
# Check if API key is available
|
| 90 |
+
if not API_KEY:
|
| 91 |
+
return None, "❌ API key not configured. Please contact the app administrator."
|
| 92 |
+
|
| 93 |
+
# Validate inputs
|
| 94 |
+
if user_img is None or outfit_img is None:
|
| 95 |
+
return None, "❌ Please upload both user image and outfit image."
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
# Convert images to RGB and ensure uint8 dtype
|
| 99 |
+
user_rgb = user_img.convert("RGB")
|
| 100 |
+
outfit_rgb = outfit_img.convert("RGB")
|
| 101 |
+
|
| 102 |
+
avatar_bgr = cv2.cvtColor(np.array(user_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
| 103 |
+
clothing_bgr = cv2.cvtColor(np.array(outfit_rgb, dtype=np.uint8), cv2.COLOR_RGB2BGR)
|
| 104 |
+
|
| 105 |
+
# Call the API
|
| 106 |
+
resp = client.try_on_file(
|
| 107 |
+
clothing_image=clothing_bgr,
|
| 108 |
+
avatar_image=avatar_bgr,
|
| 109 |
+
seed=-1
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
if resp.status_code != 200 or resp.image is None:
|
| 113 |
+
error_msg = resp.error_details if resp.error_details else f"API error {resp.status_code}"
|
| 114 |
+
return None, f"❌ {error_msg}"
|
| 115 |
+
|
| 116 |
+
# Convert result back to RGB
|
| 117 |
+
result_rgb = cv2.cvtColor(resp.image, cv2.COLOR_BGR2RGB)
|
| 118 |
+
|
| 119 |
+
# Generate fit advice
|
| 120 |
+
fit_msg = generate_fit_advice(height, chest, waist, sleeve)
|
| 121 |
+
|
| 122 |
+
return Image.fromarray(result_rgb), fit_msg
|
| 123 |
+
|
| 124 |
+
except Exception as e:
|
| 125 |
+
return None, f"❌ Error processing request: {str(e)}"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def generate_fit_advice(height, chest, waist, sleeve):
|
| 129 |
+
"""
|
| 130 |
+
Generate basic fit advice based on measurements
|
| 131 |
+
"""
|
| 132 |
+
advice = ["✅ Fit Analysis:"]
|
| 133 |
+
|
| 134 |
+
if chest and chest < 80:
|
| 135 |
+
advice.append("⚠️ Chest measurement suggests the item might be tight.")
|
| 136 |
+
elif chest and chest > 120:
|
| 137 |
+
advice.append("ℹ️ Chest measurement suggests a loose fit.")
|
| 138 |
+
else:
|
| 139 |
+
advice.append("✅ Chest measurement looks good.")
|
| 140 |
+
|
| 141 |
+
if waist and chest and abs(chest - waist) < 10:
|
| 142 |
+
advice.append("ℹ️ Similar chest and waist measurements - consider fitted styles.")
|
| 143 |
+
|
| 144 |
+
if height and height < 160:
|
| 145 |
+
advice.append("ℹ️ For shorter stature, consider checking garment length.")
|
| 146 |
+
elif height and height > 185:
|
| 147 |
+
advice.append("ℹ️ For taller stature, ensure adequate garment length.")
|
| 148 |
+
|
| 149 |
+
return "\n".join(advice)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# --- 3. Create the Gradio Interface ------------------------------------------
|
| 153 |
+
with gr.Blocks(theme=gr.themes.Soft(), title="Virtual Try-On") as demo:
|
| 154 |
+
gr.Markdown(
|
| 155 |
+
"""
|
| 156 |
+
# 👕 Virtual Try-On App
|
| 157 |
+
|
| 158 |
+
Upload your photo and a clothing item to see how it would look on you!
|
| 159 |
+
|
| 160 |
+
**Tips for best results:**
|
| 161 |
+
- Use clear, well-lit photos
|
| 162 |
+
- Person should be facing forward
|
| 163 |
+
- Clothing item should be clearly visible
|
| 164 |
+
"""
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
with gr.Row():
|
| 168 |
+
with gr.Column():
|
| 169 |
+
user_img = gr.Image(type="pil", label="Your Photo", height=300)
|
| 170 |
+
outfit_img = gr.Image(type="pil", label="Clothing Item", height=300)
|
| 171 |
+
|
| 172 |
+
with gr.Column():
|
| 173 |
+
gr.Markdown("### 📏 Your Measurements (Optional)")
|
| 174 |
+
height = gr.Number(label="Height (cm)", value=170, minimum=140, maximum=220)
|
| 175 |
+
chest = gr.Number(label="Chest (cm)", value=90, minimum=70, maximum=140)
|
| 176 |
+
waist = gr.Number(label="Waist (cm)", value=80, minimum=60, maximum=130)
|
| 177 |
+
sleeve = gr.Number(label="Sleeve Length (cm)", value=60, minimum=50, maximum=80)
|
| 178 |
+
|
| 179 |
+
submit_btn = gr.Button("✨ Try On Outfit", variant="primary", size="lg")
|
| 180 |
+
|
| 181 |
+
with gr.Row():
|
| 182 |
+
result_img = gr.Image(type="pil", label="Try-On Result", height=400)
|
| 183 |
+
fit_advice = gr.Textbox(label="Fit Advice", lines=6, max_lines=10)
|
| 184 |
+
|
| 185 |
+
# Examples section
|
| 186 |
+
gr.Markdown("### 📸 Example Images")
|
| 187 |
+
gr.Markdown("*Upload your own images or try with sample photos*")
|
| 188 |
+
|
| 189 |
+
submit_btn.click(
|
| 190 |
+
fn=try_on,
|
| 191 |
+
inputs=[user_img, outfit_img, height, chest, waist, sleeve],
|
| 192 |
+
outputs=[result_img, fit_advice]
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# --- 4. Launch the app -------------------------------------------------------
|
| 196 |
+
if __name__ == "__main__":
|
| 197 |
+
print("🚀 Launching Virtual Try-On App...")
|
| 198 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
requests
|
| 3 |
+
pillow
|
| 4 |
+
numpy
|
| 5 |
+
opencv-python-headless
|
| 6 |
+
try-on-diffusion-client
|
try_on_diffusion_client.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import requests
|
| 4 |
+
from requests_toolbelt.multipart.encoder import MultipartEncoder
|
| 5 |
+
from urllib.parse import urlparse
|
| 6 |
+
import logging
|
| 7 |
+
import json
|
| 8 |
+
from io import BytesIO
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class TryOnDiffusionAPIResponse:
|
| 14 |
+
status_code: int
|
| 15 |
+
image: np.ndarray = None
|
| 16 |
+
response_data: bytes = None
|
| 17 |
+
error_details: str = None
|
| 18 |
+
seed: int = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TryOnDiffusionClient:
|
| 22 |
+
def __init__(self, base_url: str = "http://localhost:8000/", api_key: str = ""):
|
| 23 |
+
self._logger = logging.getLogger("try_on_diffusion_client")
|
| 24 |
+
self._base_url = base_url
|
| 25 |
+
self._api_key = api_key
|
| 26 |
+
|
| 27 |
+
if self._base_url[-1] == "/":
|
| 28 |
+
self._base_url = self._base_url[:-1]
|
| 29 |
+
|
| 30 |
+
parsed_url = urlparse(self._base_url)
|
| 31 |
+
|
| 32 |
+
self._rapidapi_host = parsed_url.netloc if parsed_url.netloc.endswith(".rapidapi.com") else None
|
| 33 |
+
|
| 34 |
+
if self._rapidapi_host is not None:
|
| 35 |
+
self._logger.info(f"Using RapidAPI proxy: {self._rapidapi_host}")
|
| 36 |
+
|
| 37 |
+
@staticmethod
|
| 38 |
+
def _image_to_upload_file(image: np.ndarray) -> tuple:
|
| 39 |
+
_, jpeg_data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), 99])
|
| 40 |
+
jpeg_data = jpeg_data.tobytes()
|
| 41 |
+
|
| 42 |
+
fp = BytesIO(jpeg_data)
|
| 43 |
+
|
| 44 |
+
return "image.jpg", fp, "image/jpeg"
|
| 45 |
+
|
| 46 |
+
def try_on_file(
|
| 47 |
+
self,
|
| 48 |
+
clothing_image: np.ndarray = None,
|
| 49 |
+
clothing_prompt: str = None,
|
| 50 |
+
avatar_image: np.ndarray = None,
|
| 51 |
+
avatar_prompt: str = None,
|
| 52 |
+
avatar_sex: str = None,
|
| 53 |
+
background_image: np.ndarray = None,
|
| 54 |
+
background_prompt: str = None,
|
| 55 |
+
seed: int = -1,
|
| 56 |
+
raw_response: bool = False,
|
| 57 |
+
) -> TryOnDiffusionAPIResponse:
|
| 58 |
+
url = self._base_url + "/try-on-file"
|
| 59 |
+
|
| 60 |
+
request_data = {"seed": str(seed)}
|
| 61 |
+
|
| 62 |
+
if clothing_image is not None:
|
| 63 |
+
request_data["clothing_image"] = self._image_to_upload_file(clothing_image)
|
| 64 |
+
|
| 65 |
+
if clothing_prompt is not None:
|
| 66 |
+
request_data["clothing_prompt"] = clothing_prompt
|
| 67 |
+
|
| 68 |
+
if avatar_image is not None:
|
| 69 |
+
request_data["avatar_image"] = self._image_to_upload_file(avatar_image)
|
| 70 |
+
|
| 71 |
+
if avatar_prompt is not None:
|
| 72 |
+
request_data["avatar_prompt"] = avatar_prompt
|
| 73 |
+
|
| 74 |
+
if avatar_sex is not None:
|
| 75 |
+
request_data["avatar_sex"] = avatar_sex
|
| 76 |
+
|
| 77 |
+
if background_image is not None:
|
| 78 |
+
request_data["background_image"] = self._image_to_upload_file(background_image)
|
| 79 |
+
|
| 80 |
+
if background_prompt is not None:
|
| 81 |
+
request_data["background_prompt"] = background_prompt
|
| 82 |
+
|
| 83 |
+
multipart_data = MultipartEncoder(fields=request_data)
|
| 84 |
+
|
| 85 |
+
headers = {"Content-Type": multipart_data.content_type}
|
| 86 |
+
|
| 87 |
+
if self._rapidapi_host is not None:
|
| 88 |
+
headers["X-RapidAPI-Key"] = self._api_key
|
| 89 |
+
headers["X-RapidAPI-Host"] = self._rapidapi_host
|
| 90 |
+
else:
|
| 91 |
+
headers["X-API-Key"] = self._api_key
|
| 92 |
+
|
| 93 |
+
try:
|
| 94 |
+
response = requests.post(
|
| 95 |
+
url,
|
| 96 |
+
data=multipart_data,
|
| 97 |
+
headers=headers,
|
| 98 |
+
)
|
| 99 |
+
except Exception as e:
|
| 100 |
+
self._logger.error(e, exc_info=True)
|
| 101 |
+
return TryOnDiffusionAPIResponse(status_code=0)
|
| 102 |
+
|
| 103 |
+
if response.status_code != 200:
|
| 104 |
+
self._logger.warning(f"Request failed, status code: {response.status_code}, response: {response.content}")
|
| 105 |
+
|
| 106 |
+
result = TryOnDiffusionAPIResponse(status_code=response.status_code)
|
| 107 |
+
|
| 108 |
+
if not raw_response and response.status_code == 200:
|
| 109 |
+
try:
|
| 110 |
+
result.image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
|
| 111 |
+
except:
|
| 112 |
+
result.image = None
|
| 113 |
+
else:
|
| 114 |
+
result.response_data = response.content
|
| 115 |
+
|
| 116 |
+
if result.status_code == 200:
|
| 117 |
+
if "X-Seed" in response.headers:
|
| 118 |
+
result.seed = int(response.headers["X-Seed"])
|
| 119 |
+
else:
|
| 120 |
+
try:
|
| 121 |
+
response_json = (
|
| 122 |
+
json.loads(result.response_data.decode("utf-8")) if result.response_data is not None else None
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if response_json is not None and "detail" in response_json:
|
| 126 |
+
result.error_details = response_json["detail"]
|
| 127 |
+
except:
|
| 128 |
+
result.error_details = None
|
| 129 |
+
|
| 130 |
+
return result
|