"""
╔══════════════════════════════════════════════════════════════════╗
║  PIXELSKITS — VIRTUAL TRY-ON API v3.0                            ║
║                                                                  ║
║  PIPELINE: 4 fallback servers — tries each until one works       ║
║    1. IDM-VTON  (yisol/IDM-VTON)                                 ║
║    2. Kolors    (Kwai-Kolors/Kolors-Virtual-Try-On)              ║
║    3. OOTDiff   (levihsu/OOTDiffusion)                           ║
║    4. CatVTON   (zhengchong/CatVTON)                             ║
║                                                                  ║
║  UPSCALE: Real-ESRGAN → Lanczos fallback                         ║
╚══════════════════════════════════════════════════════════════════╝

FastAPI service exposing ``POST /tryon``: it validates two uploaded images,
runs them through the first try-on server that responds, resizes the result
to the requested resolution and returns it as a PNG.
"""
import hashlib
import hmac
import io
import logging
import os
import random
import tempfile
import time
from contextlib import suppress
from datetime import datetime, timezone
from typing import Optional, Tuple

import requests as http_requests
from fastapi import FastAPI, File, Form, Header, HTTPException, Request, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from gradio_client import Client
from PIL import Image

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("pixelskits-tryon")

# Empty MASTER_KEY disables authentication entirely (see _auth).
MASTER_KEY = os.environ.get("PIXELSKITS_API_KEY", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")

RATE_LIMIT = 30       # max requests per client per window
RATE_WINDOW = 3600    # window length in seconds
# Once more than this many clients are tracked, expired records are evicted
# on the next check_rate() call so the table cannot grow without bound.
_RATE_PRUNE_THRESHOLD = 10_000

UPSCALE_URL = "https://api-inference.huggingface.co/models/ai-forever/Real-ESRGAN"

# Output size (width, height) for each accepted `resolution` value.
RESOLUTIONS = {
    "720p": (1280, 720),
    "1080p": (1920, 1080),
    "2k": (2560, 1440),
    "4k": (3840, 2160),
    "8k": (7680, 4320),
}

# All 4 fallback servers
TRYON_SERVERS = [
    "yisol/IDM-VTON",
    "Kwai-Kolors/Kolors-Virtual-Try-On",
    "levihsu/OOTDiffusion",
    "zhengchong/CatVTON",
]

app = FastAPI(title="Pixelskits Virtual Try-On API v3.0", version="3.0.0", docs_url="/docs")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)

# Per-client rate records: hashed ip -> {"count": int, "start": float}.
_rate: dict = {}
_ready_flag = False


# ══════════════════════════════════════════════════════════
# STARTUP
# ══════════════════════════════════════════════════════════
@app.on_event("startup")
async def startup():
    """Mark the service as ready once the application has started."""
    global _ready_flag
    log.info("✅ Pixelskits Try-On v3.0 online — 4 server fallback ready")
    _ready_flag = True


def _ready() -> bool:
    """Return True once the startup hook has run."""
    return _ready_flag


# ══════════════════════════════════════════════════════════
# RATE + AUTH
# ══════════════════════════════════════════════════════════
def check_rate(ip: str) -> None:
    """Count one request for *ip* and enforce RATE_LIMIT per RATE_WINDOW.

    The client identifier is stored as a truncated SHA-256 hash rather than
    in clear text.

    Raises:
        HTTPException: 429 when the client exceeded RATE_LIMIT in the window.
    """
    now = time.time()

    # Evict expired records so one-off clients do not leak memory forever.
    if len(_rate) > _RATE_PRUNE_THRESHOLD:
        stale = [k for k, rec in _rate.items() if now - rec["start"] > RATE_WINDOW]
        for k in stale:
            del _rate[k]

    h = hashlib.sha256(ip.encode()).hexdigest()[:16]
    rec = _rate.get(h)
    if rec is None or now - rec["start"] > RATE_WINDOW:
        rec = {"count": 0, "start": now}
    rec["count"] += 1
    _rate[h] = rec
    if rec["count"] > RATE_LIMIT:
        raise HTTPException(429, f"Rate limit: {RATE_LIMIT} req/hour.")


def _auth(key: Optional[str]) -> None:
    """Reject the request unless *key* matches MASTER_KEY.

    Authentication is skipped when MASTER_KEY is empty.

    Raises:
        HTTPException: 401 when a key is required and missing or wrong.
    """
    if not MASTER_KEY:
        return
    # compare_digest avoids leaking the key through comparison timing;
    # bytes are used because the str form only accepts ASCII.
    if key is None or not hmac.compare_digest(key.encode(), MASTER_KEY.encode()):
        raise HTTPException(401, "Invalid API key.")


def _get_ip(req: Request) -> str:
    """Return the client address used as the rate-limit key.

    Prefers the first hop of ``X-Forwarded-For`` and falls back to the socket
    peer, or ``"unknown"`` when neither is available.

    NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy
    overwrites it — confirm the deployment sits behind one.
    """
    forwarded = req.headers.get("x-forwarded-for", "")
    first_hop = forwarded.split(",")[0].strip()
    if first_hop:
        return first_hop
    # req.client can be None (e.g. some test clients / transports).
    client = req.client
    if client is not None and client.host:
        return client.host
    return "unknown"


# ══════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════
def _save_temp(data: bytes, suffix: str = ".png") -> str:
    """Write *data* to a new temporary file and return its path.

    The file is created with ``delete=False``; the caller must remove it.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(data)
        return tmp.name


def _pil_to_bytes(img: Image.Image, fmt: str = "PNG") -> bytes:
    """Encode *img* in format *fmt* and return the encoded bytes."""
    buf = io.BytesIO()
    img.save(buf, format=fmt)
    return buf.getvalue()


def _validate_and_save(data: bytes, label: str) -> str:
    """Validate *data* as an image, normalise it to RGB PNG, save to a temp file.

    Args:
        data: Raw uploaded bytes.
        label: Field name used in the error message.

    Returns:
        Path of the temporary PNG file (caller is responsible for deleting it).

    Raises:
        HTTPException: 400 when the bytes cannot be verified or decoded.
    """
    try:
        with Image.open(io.BytesIO(data)) as probe:
            probe.verify()
        # verify() leaves the image unusable, so decode from a fresh handle.
        with Image.open(io.BytesIO(data)) as src:
            rgb = src.convert("RGB")
    except Exception as exc:
        # Pillow raises many unrelated exception types for bad input;
        # all of them mean the upload is not a usable image.
        raise HTTPException(400, f"{label}: invalid image.") from exc
    return _save_temp(_pil_to_bytes(rgb))


def _open_result(result) -> Image.Image:
    """Extract image path from gradio result and open it.

    Accepts a path string, a dict holding the path under ``"image"``,
    ``"path"`` or ``"url"``, or a list/tuple whose first element is one of
    those.

    Raises:
        ValueError: when the result is empty or of an unsupported type.
    """
    if isinstance(result, (list, tuple)):
        if not result:
            raise ValueError("Empty result from try-on server")
        result = result[0]
    if isinstance(result, dict):
        result = result.get("image") or result.get("path") or result.get("url", "")
    if isinstance(result, str):
        if not result:
            raise ValueError("Try-on server returned no image path")
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(result) as im:
            return im.convert("RGB")
    raise ValueError(f"Unknown result type: {type(result)}")


# ══════════════════════════════════════════════════════════
# TRY-ON — attempts each server in order
# ══════════════════════════════════════════════════════════
# NOTE(review): the positional arguments below are passed through exactly as
# each remote Space expects them; their meaning is defined by the Space's own
# API and cannot be verified from this file — confirm against each Space.
def try_idm_vton(person_path: str, garment_path: str,
                 denoise_steps: int, seed: int) -> Image.Image:
    """Run the try-on on the ``yisol/IDM-VTON`` Space (``/tryon`` endpoint)."""
    client = Client("yisol/IDM-VTON", hf_token=HF_TOKEN or None)
    result = client.predict(
        {"background": person_path, "layers": [], "composite": None},
        garment_path,
        "Output",
        True,
        True,
        denoise_steps,
        seed,
        api_name="/tryon",
    )
    return _open_result(result)


def try_kolors(person_path: str, garment_path: str) -> Image.Image:
    """Run the try-on on the ``Kwai-Kolors/Kolors-Virtual-Try-On`` Space."""
    client = Client("Kwai-Kolors/Kolors-Virtual-Try-On", hf_token=HF_TOKEN or None)
    result = client.predict(
        person_path,
        garment_path,
        api_name="/tryon",
    )
    return _open_result(result)


def try_ootdiffusion(person_path: str, garment_path: str) -> Image.Image:
    """Run the try-on on the ``levihsu/OOTDiffusion`` Space (``/process_dc``)."""
    client = Client("levihsu/OOTDiffusion", hf_token=HF_TOKEN or None)
    result = client.predict(
        person_path,
        garment_path,
        "Upper body",
        1,
        1,
        api_name="/process_dc",
    )
    return _open_result(result)


def try_catvton(person_path: str, garment_path: str) -> Image.Image:
    """Run the try-on on the ``zhengchong/CatVTON`` Space (``/submit_tryon``)."""
    client = Client("zhengchong/CatVTON", hf_token=HF_TOKEN or None)
    result = client.predict(
        person_path,
        garment_path,
        "upper",
        50,
        2.5,
        42,
        api_name="/submit_tryon",
    )
    return _open_result(result)


def run_tryon_with_fallback(person_path: str, garment_path: str,
                            denoise_steps: int, seed: int) -> Tuple[Image.Image, str]:
    """Try all 4 servers. Returns (Image, server_name).

    Servers are attempted in order; the first one that returns an image wins.
    ``denoise_steps`` and ``seed`` are only forwarded to IDM-VTON.

    Raises:
        HTTPException: 503 when every server failed; the detail carries the
            last error.
    """
    servers = [
        ("IDM-VTON", lambda: try_idm_vton(person_path, garment_path, denoise_steps, seed)),
        ("Kolors", lambda: try_kolors(person_path, garment_path)),
        ("OOTDiffusion", lambda: try_ootdiffusion(person_path, garment_path)),
        ("CatVTON", lambda: try_catvton(person_path, garment_path)),
    ]
    last_error = None
    for name, fn in servers:
        try:
            log.info("Trying %s…", name)
            result = fn()
        except Exception as e:
            # Any failure (network, quota, bad result) just moves on to the
            # next server.
            log.warning("❌ %s failed: %s", name, e)
            last_error = e
            continue
        log.info("✅ %s success!", name)
        return result, name
    raise HTTPException(503, f"All try-on servers unavailable. Last error: {last_error}")


# ══════════════════════════════════════════════════════════
# UPSCALER
# ══════════════════════════════════════════════════════════
def _upscaler_wait_seconds(resp) -> float:
    """Return how long to wait after a 503 from the upscaler, within 0–25s.

    Falls back to 15s when the body is not the expected JSON object.
    """
    try:
        estimate = float(resp.json().get("estimated_time", 15))
    except (ValueError, TypeError, AttributeError):
        estimate = 15.0
    # time.sleep() rejects negative values, so clamp at both ends.
    return max(0.0, min(estimate, 25))


def upscale_image(img: Image.Image, target_res: str) -> Image.Image:
    """Resize *img* to the size registered for *target_res*.

    When the image must grow by more than 1.5× and HF_TOKEN is set, the
    remote upscaler at UPSCALE_URL is tried first (up to 3 attempts while it
    answers 503). On any failure the image is resized locally with Lanczos.
    The output is always exactly the target size; aspect ratio is not
    preserved. An unknown *target_res* returns *img* unchanged.
    """
    if target_res not in RESOLUTIONS:
        return img
    target_w, target_h = RESOLUTIONS[target_res]
    src_w, src_h = img.size
    scale_needed = max(target_w / src_w, target_h / src_h)
    log.info("Upscale: %s×%s → %s", src_w, src_h, target_res)

    if scale_needed > 1.5 and HF_TOKEN:
        try:
            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
            payload = _pil_to_bytes(img)  # encode once, reuse across retries
            for attempt in range(3):
                resp = http_requests.post(
                    UPSCALE_URL, headers=headers, data=payload, timeout=120
                )
                if resp.status_code == 503:
                    wait = _upscaler_wait_seconds(resp)
                    log.info("Upscaler warming %ss… attempt %s", wait, attempt + 1)
                    time.sleep(wait)
                    continue
                if resp.status_code == 200:
                    with Image.open(io.BytesIO(resp.content)) as remote:
                        upscaled = remote.convert("RGB")
                    log.info("Real-ESRGAN ✅ %s", upscaled.size)
                    return upscaled.resize((target_w, target_h), Image.LANCZOS)
                log.warning("Upscaler %s — Lanczos fallback", resp.status_code)
                break
            else:
                log.warning("Upscaler still warming after 3 attempts — Lanczos fallback")
        except Exception as e:
            # Upscaling is best-effort: never fail the request because of it.
            log.warning("Upscaler error: %s — Lanczos fallback", e)

    return img.resize((target_w, target_h), Image.LANCZOS)


# ══════════════════════════════════════════════════════════
# ROUTES
# ══════════════════════════════════════════════════════════
@app.get("/")
async def root():
    """Describe the API: status, configured engines and accepted resolutions."""
    return {
        "api": "Pixelskits Virtual Try-On v3.0",
        "status": "online" if _ready() else "loading",
        "engines": TRYON_SERVERS,
        "resolutions": list(RESOLUTIONS.keys()),
        "usage": "POST /tryon — person_image + garment_image + resolution",
    }


@app.get("/health")
async def health():
    """Report readiness, whether HF_TOKEN is set, server count and UTC time."""
    # Aware "now" replaces the deprecated utcnow(); tzinfo is dropped so the
    # emitted string keeps the "<iso>Z" shape.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {
        "status": "ok" if _ready() else "loading",
        "hf_token": "set" if HF_TOKEN else "missing",
        "servers": len(TRYON_SERVERS),
        "ts": now_utc.isoformat() + "Z",
    }


@app.get("/ping")
async def ping():
    """Liveness probe."""
    return {"ping": "pong"}


@app.post("/tryon")
async def virtual_tryon(
    request: Request,
    person_image: UploadFile = File(...),
    garment_image: UploadFile = File(...),
    resolution: str = Form("1080p"),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
    """Run the virtual try-on pipeline and return the result as a PNG.

    Args:
        request: Incoming request, used for the rate-limit key.
        person_image: Photo of the person (JPEG, PNG or WebP, max 20MB).
        garment_image: Photo of the garment (JPEG, PNG or WebP, max 20MB).
        resolution: One of the RESOLUTIONS keys.
        denoise_steps: Clamped to the range 20–40.
        seed: ``-1`` picks a random seed in 0–999999.
        x_api_key: Value of the ``X-API-Key`` header.

    Returns:
        A PNG response with ``X-Processing-Time``, ``X-Engine``,
        ``X-Resolution`` and ``X-Output-Size`` headers.

    Raises:
        HTTPException: 401 bad key, 429 rate limit, 503 not ready or all
            servers down, 400 invalid input, 413 upload too large, 500 on any
            other pipeline failure.
    """
    _auth(x_api_key)
    check_rate(_get_ip(request))
    if not _ready():
        raise HTTPException(503, "API starting. Retry in 10s.")

    allowed = ("image/jpeg", "image/jpg", "image/png", "image/webp")
    if person_image.content_type not in allowed:
        raise HTTPException(400, "person_image: use JPEG, PNG, or WebP.")
    if garment_image.content_type not in allowed:
        raise HTTPException(400, "garment_image: use JPEG, PNG, or WebP.")
    if resolution not in RESOLUTIONS:
        raise HTTPException(400, f"resolution must be: {' | '.join(RESOLUTIONS.keys())}")

    denoise_steps = max(20, min(40, denoise_steps))
    if seed == -1:
        seed = random.randint(0, 999999)

    person_bytes = await person_image.read()
    garment_bytes = await garment_image.read()
    if len(person_bytes) > 20 * 1024 * 1024:
        raise HTTPException(413, "person_image: max 20MB.")
    if len(garment_bytes) > 20 * 1024 * 1024:
        raise HTTPException(413, "garment_image: max 20MB.")

    t0 = time.time()
    # Paths are registered as soon as they exist so the finally block removes
    # them even when the second validation fails.
    temp_paths = []
    try:
        person_path = _validate_and_save(person_bytes, "person_image")
        temp_paths.append(person_path)
        garment_path = _validate_and_save(garment_bytes, "garment_image")
        temp_paths.append(garment_path)

        log.info("/tryon resolution=%s steps=%s", resolution, denoise_steps)
        # The pipeline does blocking network calls and sleeps; run it in the
        # threadpool so the event loop keeps serving other requests.
        tryon_img, server_used = await run_in_threadpool(
            run_tryon_with_fallback, person_path, garment_path, denoise_steps, seed
        )
        final_img = await run_in_threadpool(upscale_image, tryon_img, resolution)
        png_bytes = await run_in_threadpool(_pil_to_bytes, final_img)
    except HTTPException:
        raise
    except Exception as e:
        log.exception("Pipeline failed: %s", e)
        raise HTTPException(500, f"Try-on failed: {str(e)}") from e
    finally:
        for p in temp_paths:
            # Best-effort cleanup: a missing temp file must not mask the result.
            with suppress(OSError):
                os.unlink(p)

    elapsed = round(time.time() - t0, 2)
    out_w, out_h = final_img.size
    log.info("/tryon done — %ss %s×%s via %s", elapsed, out_w, out_h, server_used)

    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={
            "X-Processing-Time": f"{elapsed}s",
            "X-Engine": server_used,
            "X-Resolution": resolution,
            "X-Output-Size": f"{out_w}x{out_h}",
        },
    )