Spaces:
Running
Running
| """ | |
╔══════════════════════════════════════════════════════════════════╗
║  PIXELSKITS — VIRTUAL TRY-ON API v3.0                            ║
║                                                                  ║
║  PIPELINE: 4 fallback servers — tries each until one works       ║
║    1. IDM-VTON   (yisol/IDM-VTON)                                ║
║    2. Kolors     (Kwai-Kolors/Kolors-Virtual-Try-On)             ║
║    3. OOTDiff    (levihsu/OOTDiffusion)                          ║
║    4. CatVTON    (zhengchong/CatVTON)                            ║
║                                                                  ║
║  UPSCALE: Real-ESRGAN → Lanczos fallback                         ║
╚══════════════════════════════════════════════════════════════════╝
| """ | |
import hashlib
import hmac
import io
import logging
import os
import random
import tempfile
import time
from datetime import datetime, timezone
from typing import Optional

import requests as http_requests
from fastapi import FastAPI, File, Form, Header, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from gradio_client import Client
from PIL import Image
# Logging: timestamped, level-tagged lines under a service-specific logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("pixelskits-tryon")

# Secrets are read from the environment; an empty string means "not set".
MASTER_KEY = os.environ.get("PIXELSKITS_API_KEY", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Rate limiting: at most RATE_LIMIT requests per client within RATE_WINDOW seconds.
RATE_LIMIT = 30
RATE_WINDOW = 3600

# Hosted inference endpoint used for upscaling.
UPSCALE_URL = "https://api-inference.huggingface.co/models/ai-forever/Real-ESRGAN"

# Output presets, each a (width, height) pair in pixels.
RESOLUTIONS = {
    "720p": (1280, 720),
    "1080p": (1920, 1080),
    "2k": (2560, 1440),
    "4k": (3840, 2160),
    "8k": (7680, 4320),
}

# All 4 fallback servers
TRYON_SERVERS = [
    "yisol/IDM-VTON",
    "Kwai-Kolors/Kolors-Virtual-Try-On",
    "levihsu/OOTDiffusion",
    "zhengchong/CatVTON",
]

app = FastAPI(
    title="Pixelskits Virtual Try-On API v3.0",
    version="3.0.0",
    docs_url="/docs",
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Rate-limit table: hashed client IP -> {"count": int, "start": float}.
_rate: dict = {}
# Set to True by the startup hook; read through _ready().
_ready_flag = False
# ──────────────────────────────────────────────────────────
# STARTUP
# ──────────────────────────────────────────────────────────
# NOTE(review): no registration for this hook was present in the reviewed
# text. Without it nothing calls startup(), _ready_flag stays False and
# every /tryon request is answered with 503 — confirm that the decorator
# below matches the deployed file.
@app.on_event("startup")
async def startup():
    """Flag the service as ready once the application has started.

    Sets the module-level ``_ready_flag`` that ``_ready()`` reports, and logs
    a one-line "online" message.
    """
    global _ready_flag
    log.info("✅ Pixelskits Try-On v3.0 online — 4 server fallback ready")
    _ready_flag = True
def _ready() -> bool:
    """Return the current value of the module-level readiness flag."""
    return _ready_flag
# ──────────────────────────────────────────────────────────
# RATE + AUTH
# ──────────────────────────────────────────────────────────
def check_rate(ip: str, max_tracked: int = 10000) -> None:
    """Count one request for ``ip`` and reject it once the quota is spent.

    Clients are tracked in the module-level ``_rate`` table under a truncated
    SHA-256 of the address, so raw IPs are not kept in memory. Each entry
    holds the request count and the start time of its current window; a
    window lasts ``RATE_WINDOW`` seconds and allows ``RATE_LIMIT`` requests.

    Args:
        ip: Client identifier (normally the caller's IP address).
        max_tracked: Table size above which expired windows are purged.

    Raises:
        HTTPException: 429 when the client exceeds ``RATE_LIMIT`` requests
            within the current window.
    """
    now = time.time()

    # The key is derived from client-supplied data, so the table would grow
    # without bound if old entries were never dropped. Purge expired windows
    # once it gets large; this is O(n) but only runs above the threshold.
    if len(_rate) > max_tracked:
        expired = [
            key for key, entry in _rate.items()
            if now - entry["start"] > RATE_WINDOW
        ]
        for key in expired:
            del _rate[key]

    h = hashlib.sha256(ip.encode()).hexdigest()[:16]
    rec = _rate.get(h)
    if rec is None or now - rec["start"] > RATE_WINDOW:
        # First request from this client, or its previous window has expired.
        rec = {"count": 0, "start": now}
    rec["count"] += 1
    _rate[h] = rec
    if rec["count"] > RATE_LIMIT:
        raise HTTPException(429, f"Rate limit: {RATE_LIMIT} req/hour.")
def _auth(key: Optional[str]):
    """Validate the caller's API key against ``MASTER_KEY``.

    Authentication is disabled when ``MASTER_KEY`` is empty; in that case
    every request is accepted.

    Args:
        key: Value of the API-key header, or None when the header is absent.

    Raises:
        HTTPException: 401 when a master key is configured and ``key`` does
            not match it.
    """
    if not MASTER_KEY:
        return
    # Compare as bytes in constant time so response timing does not reveal
    # how much of a guessed key is correct. Encoding first also keeps
    # compare_digest from raising on non-ASCII input.
    supplied = (key or "").encode("utf-8")
    if not hmac.compare_digest(supplied, MASTER_KEY.encode("utf-8")):
        raise HTTPException(401, "Invalid API key.")
def _get_ip(req: Request) -> str:
    """Return the best available client address for ``req``.

    The first entry of the ``X-Forwarded-For`` header is preferred, because
    a proxy chain appends its own hops after the original client. When the
    header is missing or empty, the socket peer address is used, and
    ``"unknown"`` when no peer information is available.

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
    proxy overwrites it, so a caller can vary it to dodge rate limiting.
    """
    forwarded = req.headers.get("x-forwarded-for", "")
    first_hop = forwarded.split(",")[0].strip()
    if first_hop:
        return first_hop

    # req.client can be None (no peer information), so guard before reading
    # its host attribute.
    client = req.client
    host = client.host if client is not None else None
    return host or "unknown"
# ──────────────────────────────────────────────────────────
# HELPERS
# ──────────────────────────────────────────────────────────
def _save_temp(data: bytes, suffix: str = ".png") -> str:
    """Write ``data`` to a new temporary file and return its path.

    The file is created with ``delete=False``, so the caller is responsible
    for removing it when done.

    Args:
        data: Bytes to write.
        suffix: File-name suffix (extension) for the temporary file.

    Returns:
        Path of the closed temporary file.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # The context manager closes the handle even if the write fails.
        with tmp:
            tmp.write(data)
    except Exception:
        # Do not leave a partially written file behind.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise
    return tmp.name
def _pil_to_bytes(img: Image.Image, fmt: str = "PNG") -> bytes:
    """Encode ``img`` in format ``fmt`` and return the encoded bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format=fmt)
        return sink.getvalue()
def _validate_and_save(data: bytes, label: str) -> str:
    """Check that ``data`` is a decodable image and save it as an RGB PNG.

    Args:
        data: Raw uploaded bytes.
        label: Field name used as the prefix of the error message.

    Returns:
        Path of a temporary PNG file holding the RGB-converted image. The
        caller must delete it.

    Raises:
        HTTPException: 400 when the bytes cannot be verified or decoded as
            an image.
    """
    try:
        # verify() leaves the image object unusable, so the integrity check
        # and the actual decode each need their own open().
        with Image.open(io.BytesIO(data)) as probe:
            probe.verify()
        # The decode stays inside the try: a file that passes verify() can
        # still fail to load (e.g. truncated pixel data), and that is a
        # client error rather than a server fault.
        with Image.open(io.BytesIO(data)) as src:
            rgb = src.convert("RGB")
    except Exception as exc:
        raise HTTPException(400, f"{label}: invalid image.") from exc
    return _save_temp(_pil_to_bytes(rgb))
def _open_result(result) -> Image.Image:
    """Extract an image path from a gradio result and open it as RGB.

    Accepted shapes:
      * a list or tuple (possibly nested), whose first element is used;
      * a dict, from which the first truthy value among the ``"image"``,
        ``"path"`` and ``"url"`` keys is taken;
      * a string, treated as a path that ``Image.open`` can read.

    Returns:
        The decoded image, converted to RGB and fully loaded.

    Raises:
        ValueError: If the result is empty, carries no path, or has an
            unsupported type.
    """
    # Unwrap sequences down to their first element. An empty sequence gets
    # a clear error instead of a bare IndexError.
    while isinstance(result, (list, tuple)):
        if not result:
            raise ValueError("Empty result sequence: no image returned")
        result = result[0]

    if isinstance(result, dict):
        result = result.get("image") or result.get("path") or result.get("url", "")

    if isinstance(result, str):
        if not result:
            raise ValueError("Result contains no image path")
        # convert() loads the pixel data, so the source file can be closed
        # before returning.
        with Image.open(result) as img:
            return img.convert("RGB")

    raise ValueError(f"Unknown result type: {type(result)}")
# ──────────────────────────────────────────────────────────
# TRY-ON — attempts each server in order
# ──────────────────────────────────────────────────────────
def try_idm_vton(person_path: str, garment_path: str,
                 denoise_steps: int, seed: int) -> Image.Image:
    """Run a try-on on the ``yisol/IDM-VTON`` Space and return the image.

    Calls the Space's ``/tryon`` endpoint and decodes its output with
    ``_open_result``. An empty ``HF_TOKEN`` is passed to the client as None.
    """
    # Dict payload for the first input: the person photo as "background",
    # no layers and no composite.
    person_input = {"background": person_path, "layers": [], "composite": None}
    space = Client("yisol/IDM-VTON", hf_token=HF_TOKEN or None)
    # NOTE(review): "Output" and the two True flags are positional inputs of
    # the remote endpoint; presumably a garment description and two
    # auto-processing switches — confirm against the Space's API page.
    call_args = (
        person_input,
        garment_path,
        "Output",
        True,
        True,
        denoise_steps,
        seed,
    )
    outcome = space.predict(*call_args, api_name="/tryon")
    return _open_result(outcome)
def try_kolors(person_path: str, garment_path: str) -> Image.Image:
    """Run a try-on on the ``Kwai-Kolors/Kolors-Virtual-Try-On`` Space.

    Sends the person and garment image paths to the Space's ``/tryon``
    endpoint and decodes its output with ``_open_result``. An empty
    ``HF_TOKEN`` is passed to the client as None.
    """
    space = Client("Kwai-Kolors/Kolors-Virtual-Try-On", hf_token=HF_TOKEN or None)
    outcome = space.predict(person_path, garment_path, api_name="/tryon")
    return _open_result(outcome)
def try_ootdiffusion(person_path: str, garment_path: str) -> Image.Image:
    """Run a try-on on the ``levihsu/OOTDiffusion`` Space.

    Calls the Space's ``/process_dc`` endpoint and decodes its output with
    ``_open_result``. An empty ``HF_TOKEN`` is passed to the client as None.
    """
    space = Client("levihsu/OOTDiffusion", hf_token=HF_TOKEN or None)
    # NOTE(review): "Upper body" is presumably the garment category; the
    # meaning of the two trailing 1s is not visible here — confirm against
    # the Space's API page.
    call_args = (person_path, garment_path, "Upper body", 1, 1)
    outcome = space.predict(*call_args, api_name="/process_dc")
    return _open_result(outcome)
def try_catvton(person_path: str, garment_path: str) -> Image.Image:
    """Run a try-on on the ``zhengchong/CatVTON`` Space.

    Calls the Space's ``/submit_tryon`` endpoint and decodes its output with
    ``_open_result``. An empty ``HF_TOKEN`` is passed to the client as None.
    """
    space = Client("zhengchong/CatVTON", hf_token=HF_TOKEN or None)
    # NOTE(review): the fixed values are presumably garment type, step
    # count, guidance scale and seed — confirm against the Space's API page.
    call_args = (person_path, garment_path, "upper", 50, 2.5, 42)
    outcome = space.predict(*call_args, api_name="/submit_tryon")
    return _open_result(outcome)
def run_tryon_with_fallback(person_path: str, garment_path: str,
                            denoise_steps: int, seed: int) -> tuple:
    """Try each try-on server in order and return the first success.

    ``denoise_steps`` and ``seed`` are forwarded only to IDM-VTON; the other
    backends are called with the two image paths alone.

    Args:
        person_path: Path of the person image file.
        garment_path: Path of the garment image file.
        denoise_steps: Denoising step count for IDM-VTON.
        seed: Random seed for IDM-VTON.

    Returns:
        ``(image, server_name)`` from the first backend that succeeds.

    Raises:
        HTTPException: 503 when every backend fails. The last failure is
            included in the detail and chained as the cause.
    """
    servers = [
        ("IDM-VTON", try_idm_vton, (person_path, garment_path, denoise_steps, seed)),
        ("Kolors", try_kolors, (person_path, garment_path)),
        ("OOTDiffusion", try_ootdiffusion, (person_path, garment_path)),
        ("CatVTON", try_catvton, (person_path, garment_path)),
    ]
    last_error = None
    for name, backend, args in servers:
        log.info(f"Trying {name}…")
        try:
            image = backend(*args)
        except Exception as e:
            # Any failure (network, quota, unexpected output) moves on to
            # the next backend.
            log.warning(f"❌ {name} failed: {e}")
            last_error = e
        else:
            log.info(f"✅ {name} success!")
            return image, name
    raise HTTPException(
        503, f"All try-on servers unavailable. Last error: {last_error}"
    ) from last_error
# ──────────────────────────────────────────────────────────
# UPSCALER
# ──────────────────────────────────────────────────────────
def upscale_image(img: Image.Image, target_res: str) -> Image.Image:
    """Resize ``img`` to the preset named ``target_res``.

    When the required scale factor exceeds 1.5 and ``HF_TOKEN`` is set, the
    image is first sent to the hosted upscaler at ``UPSCALE_URL`` (up to
    three attempts while it answers 503). On any failure, or when the remote
    step is skipped, a plain Lanczos resize is used instead.

    NOTE(review): the final resize forces the exact preset width and height
    and does not preserve the source aspect ratio — confirm this is intended.

    Args:
        img: Source image.
        target_res: Key of ``RESOLUTIONS``. An unknown key returns ``img``
            unchanged.

    Returns:
        The image resized to the preset's (width, height).
    """
    if target_res not in RESOLUTIONS:
        return img
    target_w, target_h = RESOLUTIONS[target_res]
    src_w, src_h = img.size
    scale_needed = max(target_w / src_w, target_h / src_h)
    log.info(f"Upscale: {src_w}×{src_h} → {target_res}")

    if scale_needed > 1.5 and HF_TOKEN:
        try:
            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
            # Encode once: the payload is identical on every retry.
            payload = _pil_to_bytes(img)
            for attempt in range(3):
                resp = http_requests.post(UPSCALE_URL, headers=headers,
                                          data=payload, timeout=120)
                if resp.status_code == 503:
                    # 503 is treated as "still warming up". A missing or
                    # malformed estimate must not abort the retry loop.
                    try:
                        estimate = float(resp.json().get("estimated_time", 15))
                    except (ValueError, TypeError, AttributeError):
                        estimate = 15.0
                    wait = max(0.0, min(estimate, 25))
                    log.info(f"Upscaler warming {wait}s… attempt {attempt+1}")
                    time.sleep(wait)
                    continue
                if resp.status_code == 200:
                    upscaled = Image.open(io.BytesIO(resp.content)).convert("RGB")
                    log.info(f"Real-ESRGAN → {upscaled.size}")
                    return upscaled.resize((target_w, target_h), Image.LANCZOS)
                log.warning(f"Upscaler {resp.status_code} → Lanczos fallback")
                break
        except Exception as e:
            # Deliberately best-effort: any upscaler problem degrades to the
            # local resize below.
            log.warning(f"Upscaler error: {e} → Lanczos fallback")

    return img.resize((target_w, target_h), Image.LANCZOS)
# ──────────────────────────────────────────────────────────
# ROUTES
# ──────────────────────────────────────────────────────────
# NOTE(review): no route registration was present in the reviewed text;
# the "/" path is inferred from the handler name — confirm.
@app.get("/")
async def root():
    """Describe the API: name, readiness, engines, resolutions and usage."""
    return {
        "api": "Pixelskits Virtual Try-On v3.0",
        "status": "online" if _ready() else "loading",
        "engines": TRYON_SERVERS,
        "resolutions": list(RESOLUTIONS.keys()),
        "usage": "POST /tryon — person_image + garment_image + resolution",
    }
# NOTE(review): no route registration was present in the reviewed text;
# the "/health" path is inferred from the handler name — confirm.
@app.get("/health")
async def health():
    """Report readiness, whether ``HF_TOKEN`` is set, the server count and a UTC timestamp."""
    # datetime.utcnow() is deprecated. Take an aware UTC time and drop the
    # tzinfo so the output keeps the same "<isoformat>Z" shape as before.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {
        "status": "ok" if _ready() else "loading",
        "hf_token": "set" if HF_TOKEN else "missing",
        "servers": len(TRYON_SERVERS),
        "ts": now_utc.isoformat() + "Z",
    }
# NOTE(review): no route registration was present in the reviewed text;
# the "/ping" path is inferred from the handler name — confirm.
@app.get("/ping")
async def ping():
    """Liveness probe: always returns ``{"ping": "pong"}``."""
    return {"ping": "pong"}
# NOTE(review): no route registration was present in the reviewed text;
# "POST /tryon" is taken from the usage string returned by root() — confirm.
@app.post("/tryon")
async def virtual_tryon(
    request: Request,
    person_image: UploadFile = File(...),
    garment_image: UploadFile = File(...),
    resolution: str = Form("1080p"),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
    """Run the virtual try-on pipeline and return the result as a PNG.

    Steps: authenticate, rate-limit, validate inputs, save both uploads as
    temporary RGB PNGs, run the try-on with server fallback, resize to the
    requested resolution, and return the encoded image.

    Args:
        request: Incoming request, used to derive the client address.
        person_image: Photo of the person (JPEG, PNG or WebP, max 20MB).
        garment_image: Photo of the garment (JPEG, PNG or WebP, max 20MB).
        resolution: Key of ``RESOLUTIONS``.
        denoise_steps: Clamped to the range 20..40.
        seed: Random seed; ``-1`` picks a random one.
        x_api_key: Value of the ``X-API-Key`` header.

    Returns:
        A PNG ``Response`` with ``X-Processing-Time``, ``X-Engine``,
        ``X-Resolution`` and ``X-Output-Size`` headers.

    Raises:
        HTTPException: 400 for bad input, 401 for a bad key, 413 for
            oversized uploads, 429 when rate-limited, 503 when the service
            is not ready or every backend failed, 500 for any other
            pipeline failure.
    """
    _auth(x_api_key)
    check_rate(_get_ip(request))
    if not _ready():
        raise HTTPException(503, "API starting. Retry in 10s.")

    allowed = ("image/jpeg", "image/jpg", "image/png", "image/webp")
    if person_image.content_type not in allowed:
        raise HTTPException(400, "person_image: use JPEG, PNG, or WebP.")
    if garment_image.content_type not in allowed:
        raise HTTPException(400, "garment_image: use JPEG, PNG, or WebP.")
    if resolution not in RESOLUTIONS:
        raise HTTPException(400, f"resolution must be: {' | '.join(RESOLUTIONS.keys())}")

    denoise_steps = max(20, min(40, denoise_steps))
    if seed == -1:
        seed = random.randint(0, 999999)

    max_bytes = 20 * 1024 * 1024
    person_bytes = await person_image.read()
    garment_bytes = await garment_image.read()
    if len(person_bytes) > max_bytes:
        raise HTTPException(413, "person_image: max 20MB.")
    if len(garment_bytes) > max_bytes:
        raise HTTPException(413, "garment_image: max 20MB.")

    t0 = time.time()
    # Record each temp file as soon as it exists, so a failure on the
    # second upload still cleans up the first.
    temp_paths = []
    try:
        # Image decoding, the remote try-on calls and the upscaler (which
        # sleeps between retries) are all blocking. Running them in the
        # thread pool keeps the event loop free for other requests.
        person_path = await run_in_threadpool(
            _validate_and_save, person_bytes, "person_image"
        )
        temp_paths.append(person_path)
        garment_path = await run_in_threadpool(
            _validate_and_save, garment_bytes, "garment_image"
        )
        temp_paths.append(garment_path)

        log.info(f"/tryon resolution={resolution} steps={denoise_steps}")
        tryon_img, server_used = await run_in_threadpool(
            run_tryon_with_fallback, person_path, garment_path, denoise_steps, seed
        )
        final_img = await run_in_threadpool(upscale_image, tryon_img, resolution)
        png_bytes = await run_in_threadpool(_pil_to_bytes, final_img)
    except HTTPException:
        # Already carries the right status code; pass it through untouched.
        raise
    except Exception as e:
        log.error(f"Pipeline failed: {e}")
        raise HTTPException(500, f"Try-on failed: {str(e)}") from e
    finally:
        for p in temp_paths:
            try:
                os.unlink(p)
            except OSError:
                # Best-effort cleanup: a missing file is not worth failing for.
                pass

    elapsed = round(time.time() - t0, 2)
    out_w, out_h = final_img.size
    log.info(f"/tryon done — {elapsed}s {out_w}×{out_h} via {server_used}")
    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={
            "X-Processing-Time": f"{elapsed}s",
            "X-Engine": server_used,
            "X-Resolution": resolution,
            "X-Output-Size": f"{out_w}x{out_h}",
        },
    )