# ribonpatil's picture
# Rename app-20.py to app.py
# b736b0f verified
"""
╔══════════════════════════════════════════════════════════════════╗
║  PIXELSKITS — VIRTUAL TRY-ON API v3.0                            ║
║                                                                  ║
║  PIPELINE: 4 fallback servers — tries each until one works       ║
║    1. IDM-VTON  (yisol/IDM-VTON)                                 ║
║    2. Kolors    (Kwai-Kolors/Kolors-Virtual-Try-On)              ║
║    3. OOTDiff   (levihsu/OOTDiffusion)                           ║
║    4. CatVTON   (zhengchong/CatVTON)                             ║
║                                                                  ║
║  UPSCALE: Real-ESRGAN → Lanczos fallback                         ║
╚══════════════════════════════════════════════════════════════════╝
"""
import os, io, hashlib, time, logging, tempfile, random
import hmac
from typing import Optional
from datetime import datetime, timezone

import requests as http_requests
from fastapi import FastAPI, File, UploadFile, Header, HTTPException, Request, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from PIL import Image, ImageOps
from gradio_client import Client
# Configure root logging at INFO with a timestamp/level prefix on every line.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
log = logging.getLogger("pixelskits-tryon")
# Shared secret expected in the X-API-Key header; an empty value disables the
# check entirely (_auth only enforces it when MASTER_KEY is non-empty).
MASTER_KEY = os.environ.get("PIXELSKITS_API_KEY", "")
# Optional Hugging Face token. Passed to every gradio Client (as None when
# empty) and required for the Real-ESRGAN upscaler to be attempted at all.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Fixed-window limit per client: at most RATE_LIMIT requests per RATE_WINDOW
# seconds (3600 s = one hour, matching the "req/hour" error message).
RATE_LIMIT = 30
RATE_WINDOW = 3600
# Hugging Face Inference API endpoint that upscale_image() POSTs PNG bytes to.
UPSCALE_URL = "https://api-inference.huggingface.co/models/ai-forever/Real-ESRGAN"
# Output resolution presets: name -> (width, height) in pixels.
RESOLUTIONS = {
    "720p": (1280, 720),
    "1080p": (1920, 1080),
    "2k": (2560, 1440),
    "4k": (3840, 2160),
    "8k": (7680, 4320),
}
# All 4 fallback servers
# Hugging Face Space IDs, only reported by "/" and "/health".
# NOTE: run_tryon_with_fallback() has its own hard-coded list; keep them in sync.
TRYON_SERVERS = [
    "yisol/IDM-VTON",
    "Kwai-Kolors/Kolors-Virtual-Try-On",
    "levihsu/OOTDiffusion",
    "zhengchong/CatVTON",
]
app = FastAPI(title="Pixelskits Virtual Try-On API v3.0", version="3.0.0", docs_url="/docs")
# CORS is open to every origin; only GET and POST methods are allowed.
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_methods=["GET", "POST"], allow_headers=["*"])
# Rate-limit state filled by check_rate():
#   16-hex-char SHA-256 prefix of the client IP -> {"count": int, "start": float}.
# Lives in this process only, so it resets on restart and is not shared
# between worker processes.
_rate: dict = {}
# Flipped to True by the startup hook; read through _ready().
_ready_flag = False
# ══════════════════════════════════════════════════════════
# STARTUP
# ══════════════════════════════════════════════════════════
@app.on_event("startup")
async def startup():
    """FastAPI startup hook: announce the service and mark it as ready.

    Until this has run, _ready() reports False, so "/" and "/health" say
    "loading" and /tryon answers 503.
    """
    global _ready_flag
    log.info("βœ… Pixelskits Try-On v3.0 online β€” 4 server fallback ready")
    _ready_flag = True
def _ready():
    """Return the module-level readiness flag set by the startup hook."""
    flag = _ready_flag
    return flag
# ══════════════════════════════════════════════════════════
# RATE + AUTH
# ══════════════════════════════════════════════════════════
def check_rate(ip: str):
    """Count one request for *ip* and enforce the fixed-window rate limit.

    Only a 16-character SHA-256 prefix of the IP is stored. Each client's
    window starts at its first request and resets once RATE_WINDOW seconds
    have passed.

    Raises:
        HTTPException(429): when the client exceeds RATE_LIMIT requests in
            its current window.
    """
    prune_threshold = 10_000  # tracked clients before expired windows are swept
    now = time.time()
    if len(_rate) >= prune_threshold:
        # Without this sweep _rate keeps an entry for every client ever seen
        # and grows for the lifetime of the process.
        expired = [k for k, r in _rate.items() if now - r["start"] > RATE_WINDOW]
        for k in expired:
            del _rate[k]
    h = hashlib.sha256(ip.encode()).hexdigest()[:16]
    rec = _rate.get(h)
    if rec is None or now - rec["start"] > RATE_WINDOW:
        rec = {"count": 0, "start": now}
    rec["count"] += 1
    _rate[h] = rec
    if rec["count"] > RATE_LIMIT:
        raise HTTPException(429, f"Rate limit: {RATE_LIMIT} req/hour.")
def _auth(key: Optional[str]):
    """Reject the request unless *key* equals the configured MASTER_KEY.

    When MASTER_KEY is empty, authentication is disabled and every key
    (including None) is accepted.

    Raises:
        HTTPException(401): when a master key is configured and *key* is
            missing or does not match.
    """
    if not MASTER_KEY:
        return
    # compare_digest takes the same time wherever the inputs differ, so response
    # timing cannot be used to recover the key. Bytes avoid its ASCII-only
    # restriction on str arguments.
    if key is None or not hmac.compare_digest(key.encode(), MASTER_KEY.encode()):
        raise HTTPException(401, "Invalid API key.")
def _get_ip(req: "Request") -> str:
    """Return the client address used as the rate-limit key.

    Uses the first (client-most) entry of X-Forwarded-For when present and
    non-empty. Otherwise it uses the socket peer address, or "unknown" when
    the request has no client information.

    NOTE(review): X-Forwarded-For is client-controllable. Behind a single
    trusted proxy, the last entry is the spoof-proof one. Confirm the
    deployment's proxy chain before relying on this for abuse prevention.
    """
    forwarded = req.headers.get("x-forwarded-for")
    if forwarded:
        first = forwarded.split(",")[0].strip()
        if first:
            return first
    client = req.client
    if client is not None and client.host:
        return client.host
    return "unknown"
# ══════════════════════════════════════════════════════════
# HELPERS
# ══════════════════════════════════════════════════════════
def _save_temp(data: bytes, suffix: str = ".png") -> str:
    """Write *data* to a new named temporary file and return its path.

    The file is not deleted automatically; the caller owns it and must remove
    it. If the write fails, the handle is closed and the partial file is
    removed before the error propagates.
    """
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        with tmp:
            tmp.write(data)
    except BaseException:
        # delete=False means nothing else would clean up the partial file.
        os.unlink(tmp.name)
        raise
    return tmp.name
def _pil_to_bytes(img: Image.Image, fmt: str = "PNG") -> bytes:
    """Encode *img* in the given Pillow format (PNG by default) and return the bytes."""
    with io.BytesIO() as sink:
        img.save(sink, format=fmt)
        encoded = sink.getvalue()
    return encoded
def _validate_and_save(data: bytes, label: str) -> str:
    """Check that *data* is a decodable image and save it as an RGB PNG temp file.

    The EXIF Orientation tag is applied before re-encoding, because the PNG
    written here does not carry it and a rotated phone photo would otherwise
    reach the try-on model sideways.

    Args:
        data: raw uploaded bytes.
        label: field name used in the error message.

    Returns:
        Path of the temporary PNG; the caller is responsible for deleting it.

    Raises:
        HTTPException(400): if the bytes cannot be parsed or fully decoded.
    """
    try:
        # verify() checks structure without decoding pixels and leaves the
        # object unusable, so the data is opened a second time to decode it.
        # Decoding stays inside the try because truncated files can pass verify().
        with Image.open(io.BytesIO(data)) as probe:
            probe.verify()
        with Image.open(io.BytesIO(data)) as src:
            img = ImageOps.exif_transpose(src).convert("RGB")
    except Exception as exc:
        raise HTTPException(400, f"{label}: invalid image.") from exc
    return _save_temp(_pil_to_bytes(img))
def _extract_result_path(result) -> str:
    """Unwrap a gradio_client result until a non-empty path string remains.

    Accepts a bare string or a list/tuple, taking the first element. It also
    accepts dicts, using the first truthy value of "image", "path" or "url".
    These containers may nest; for example, a gallery item whose "image" is
    itself a {"path": ...} dict.

    Raises:
        ValueError: for empty containers, empty strings, or any other type.
    """
    current = result
    for _ in range(10):  # depth guard; real outputs nest only a few levels
        if isinstance(current, str):
            if not current:
                raise ValueError("Try-on server returned an empty result path.")
            return current
        if isinstance(current, (list, tuple)):
            if not current:
                raise ValueError("Try-on server returned an empty result list.")
            current = current[0]
        elif isinstance(current, dict):
            current = current.get("image") or current.get("path") or current.get("url")
        else:
            break
    raise ValueError(f"Unknown result type: {type(current)}")


def _open_result(result) -> "Image.Image":
    """Extract the image path from a gradio result, open it, and return it as RGB.

    Raises:
        ValueError: when no path can be extracted from *result*.
        OSError: when the extracted path cannot be opened as an image.
    """
    path = _extract_result_path(result)
    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, so it remains usable after the source is closed.
    with Image.open(path) as src:
        return src.convert("RGB")
# ══════════════════════════════════════════════════════════
# TRY-ON — attempts each server in order
# ══════════════════════════════════════════════════════════
def try_idm_vton(person_path: str, garment_path: str,
                 denoise_steps: int, seed: int) -> Image.Image:
    """Call the yisol/IDM-VTON Space's "/tryon" endpoint and open its output.

    The person image is sent wrapped in a {"background", "layers",
    "composite"} dict (presumably the Space's image-editor input — confirm on
    its API page). The garment goes as a plain path. The caller's
    *denoise_steps* and *seed* are forwarded.
    """
    editor_value = {"background": person_path, "layers": [], "composite": None}
    token = HF_TOKEN if HF_TOKEN else None
    space = Client("yisol/IDM-VTON", hf_token=token)
    raw = space.predict(
        editor_value,
        garment_path,
        "Output",  # NOTE(review): presumably the garment description text — confirm
        True,      # NOTE(review): two boolean flags, presumably auto-mask / auto-crop — confirm
        True,
        denoise_steps,
        seed,
        api_name="/tryon",
    )
    return _open_result(raw)
def try_kolors(person_path: str, garment_path: str) -> Image.Image:
    """Call the Kwai-Kolors Virtual Try-On Space's "/tryon" endpoint and open its output.

    Only the two image paths are sent; no steps or seed are passed.
    """
    token = HF_TOKEN if HF_TOKEN else None
    space = Client("Kwai-Kolors/Kolors-Virtual-Try-On", hf_token=token)
    raw = space.predict(person_path, garment_path, api_name="/tryon")
    return _open_result(raw)
def try_ootdiffusion(person_path: str, garment_path: str) -> Image.Image:
    """Call the levihsu/OOTDiffusion Space's "/process_dc" endpoint and open its output.

    Extra arguments are fixed: category "Upper body" followed by 1, 1
    (NOTE(review): presumably sample count and step count — confirm against
    the Space's API page).
    """
    fixed_settings = ("Upper body", 1, 1)
    token = HF_TOKEN if HF_TOKEN else None
    space = Client("levihsu/OOTDiffusion", hf_token=token)
    raw = space.predict(person_path, garment_path, *fixed_settings,
                        api_name="/process_dc")
    return _open_result(raw)
def try_catvton(person_path: str, garment_path: str) -> Image.Image:
    """Call the zhengchong/CatVTON Space's "/submit_tryon" endpoint and open its output.

    Extra arguments are fixed: cloth type "upper" followed by 50, 2.5, 42
    (NOTE(review): presumably steps, guidance scale and seed — confirm
    against the Space's API page).
    """
    fixed_settings = ("upper", 50, 2.5, 42)
    token = HF_TOKEN if HF_TOKEN else None
    space = Client("zhengchong/CatVTON", hf_token=token)
    raw = space.predict(person_path, garment_path, *fixed_settings,
                        api_name="/submit_tryon")
    return _open_result(raw)
def run_tryon_with_fallback(person_path: str, garment_path: str,
                            denoise_steps: int, seed: int) -> tuple:
    """Try each backend in priority order and return the first success.

    Order: IDM-VTON, Kolors, OOTDiffusion, CatVTON. Only IDM-VTON receives
    *denoise_steps* and *seed*; the others use their own fixed arguments.
    Any exception from a backend is logged and the next one is tried.

    Returns:
        (image, backend_name) from the first backend that did not raise.

    Raises:
        HTTPException(503): when every backend failed; the detail includes
            the last backend's error.
    """
    both = (person_path, garment_path)
    attempts = (
        ("IDM-VTON", try_idm_vton, both + (denoise_steps, seed)),
        ("Kolors", try_kolors, both),
        ("OOTDiffusion", try_ootdiffusion, both),
        ("CatVTON", try_catvton, both),
    )
    failure = None
    for label, runner, args in attempts:
        log.info(f"Trying {label}…")
        try:
            image = runner(*args)
        except Exception as exc:
            log.warning(f"❌ {label} failed: {exc}")
            failure = exc
        else:
            log.info(f"βœ… {label} success!")
            return image, label
    raise HTTPException(503, f"All try-on servers unavailable. Last error: {failure}")
# ══════════════════════════════════════════════════════════
# UPSCALER
# ══════════════════════════════════════════════════════════
def _fit_size(src_w: int, src_h: int, target_w: int, target_h: int) -> tuple:
    """Return the largest (w, h) with the source aspect ratio that fits the target box.

    The box is first rotated to match the source orientation, so a portrait
    image asked for "1080p" fits a 1080x1920 frame instead of being squeezed
    into 1920x1080.
    """
    if (src_h > src_w) != (target_h > target_w):
        target_w, target_h = target_h, target_w
    scale = min(target_w / src_w, target_h / src_h)
    return max(1, round(src_w * scale)), max(1, round(src_h * scale))


def upscale_image(img: "Image.Image", target_res: str) -> "Image.Image":
    """Resize *img* to the named output resolution while keeping its aspect ratio.

    The output is the largest size with the image's own aspect ratio that
    fits the RESOLUTIONS box for *target_res* (see _fit_size). When that needs
    more than 1.5x enlargement and HF_TOKEN is set, the image is first sent to
    the Real-ESRGAN endpoint (up to 3 attempts, waiting while the model warms
    up). Otherwise, or on any upscaler failure, a Lanczos resize is used.
    An unknown *target_res* returns *img* unchanged.
    """
    if target_res not in RESOLUTIONS:
        return img
    target_w, target_h = RESOLUTIONS[target_res]
    src_w, src_h = img.size
    # Stretching to the exact box would distort the (typically portrait) try-on result.
    out_size = _fit_size(src_w, src_h, target_w, target_h)
    scale_needed = out_size[0] / src_w
    log.info(f"Upscale: {src_w}Γ—{src_h} β†’ {target_res}")
    if scale_needed > 1.5 and HF_TOKEN:
        try:
            headers = {"Authorization": f"Bearer {HF_TOKEN}"}
            payload = _pil_to_bytes(img)  # encode once, reuse on every retry
            attempts = 3
            for attempt in range(attempts):
                resp = http_requests.post(UPSCALE_URL, headers=headers,
                                          data=payload, timeout=120)
                if resp.status_code == 503:
                    if attempt == attempts - 1:
                        # No retry left: sleeping again would only delay the fallback.
                        log.warning("Upscaler still warming after retries - Lanczos fallback")
                        break
                    wait = max(0.0, min(float(resp.json().get("estimated_time", 15)), 25))
                    log.info(f"Upscaler warming {wait}s… attempt {attempt+1}")
                    time.sleep(wait)
                    continue
                if resp.status_code == 200:
                    upscaled = Image.open(io.BytesIO(resp.content)).convert("RGB")
                    log.info(f"Real-ESRGAN βœ… {upscaled.size}")
                    return upscaled.resize(out_size, Image.LANCZOS)
                log.warning(f"Upscaler {resp.status_code} β€” Lanczos fallback")
                break
        except Exception as e:
            log.warning(f"Upscaler error: {e} β€” Lanczos fallback")
    return img.resize(out_size, Image.LANCZOS)
# ══════════════════════════════════════════════════════════
# ROUTES
# ══════════════════════════════════════════════════════════
@app.get("/")
async def root():
    """Describe the API: name, readiness, engine list, resolutions and a usage hint."""
    info = {"api": "Pixelskits Virtual Try-On v3.0"}
    info["status"] = "online" if _ready() else "loading"
    info["engines"] = TRYON_SERVERS
    info["resolutions"] = list(RESOLUTIONS)
    info["usage"] = "POST /tryon β€” person_image + garment_image + resolution"
    return info
@app.get("/health")
async def health():
    """Report readiness, whether HF_TOKEN is configured, backend count and UTC time."""
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC time,
    # then drop tzinfo so isoformat() keeps the original "...Z" format
    # (no "+00:00" offset).
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {
        "status": "ok" if _ready() else "loading",
        "hf_token": "set" if HF_TOKEN else "missing",
        "servers": len(TRYON_SERVERS),
        "ts": now_utc.isoformat() + "Z",
    }
@app.get("/ping")
async def ping():
    """Liveness probe; always answers {"ping": "pong"}."""
    reply = dict(ping="pong")
    return reply
def _run_tryon_pipeline(person_bytes: bytes, garment_bytes: bytes, resolution: str,
                        denoise_steps: int, seed: int) -> tuple:
    """Blocking part of /tryon: validate both uploads, run the try-on, then upscale.

    Returns:
        (final image, backend name).

    Temp files created for the inputs are always removed, including when the
    garment fails validation after the person image was already saved.
    """
    temp_paths = []
    try:
        temp_paths.append(_validate_and_save(person_bytes, "person_image"))
        temp_paths.append(_validate_and_save(garment_bytes, "garment_image"))
        person_path, garment_path = temp_paths
        tryon_img, server_used = run_tryon_with_fallback(
            person_path, garment_path, denoise_steps, seed
        )
        return upscale_image(tryon_img, resolution), server_used
    finally:
        for p in temp_paths:
            try:
                os.unlink(p)
            except OSError as exc:
                # Best-effort cleanup: never mask the real result, but leave a trace.
                log.warning(f"Could not remove temp file {p}: {exc}")


@app.post("/tryon")
async def virtual_tryon(
    request: Request,
    person_image: UploadFile = File(...),
    garment_image: UploadFile = File(...),
    resolution: str = Form("1080p"),
    denoise_steps: int = Form(30),
    seed: int = Form(42),
    x_api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
    """Dress the person in *person_image* with *garment_image* and return a PNG.

    Form fields:
        resolution: a RESOLUTIONS key.
        denoise_steps: clamped to 20..40.
        seed: -1 picks a random seed.

    Both images must be JPEG/PNG/WebP and at most 20 MB. Response headers
    report processing time, the backend used, the requested resolution and
    the actual output size.

    Raises:
        HTTPException: 401 bad key, 429 rate limit, 503 not ready or every
            backend failed, 400 bad input, 413 upload too large, 500 any
            other pipeline error.
    """
    _auth(x_api_key)
    check_rate(_get_ip(request))
    if not _ready():
        raise HTTPException(503, "API starting. Retry in 10s.")
    allowed = ("image/jpeg", "image/jpg", "image/png", "image/webp")
    if person_image.content_type not in allowed:
        raise HTTPException(400, "person_image: use JPEG, PNG, or WebP.")
    if garment_image.content_type not in allowed:
        raise HTTPException(400, "garment_image: use JPEG, PNG, or WebP.")
    if resolution not in RESOLUTIONS:
        raise HTTPException(400, f"resolution must be: {' | '.join(RESOLUTIONS.keys())}")
    denoise_steps = max(20, min(40, denoise_steps))
    if seed == -1:
        seed = random.randint(0, 999999)
    max_upload = 20 * 1024 * 1024
    # Read one byte past the limit: enough to detect an oversized upload
    # without pulling an arbitrarily large body into memory.
    person_bytes = await person_image.read(max_upload + 1)
    garment_bytes = await garment_image.read(max_upload + 1)
    if len(person_bytes) > max_upload:
        raise HTTPException(413, "person_image: max 20MB.")
    if len(garment_bytes) > max_upload:
        raise HTTPException(413, "garment_image: max 20MB.")
    t0 = time.time()
    log.info(f"/tryon resolution={resolution} steps={denoise_steps}")
    try:
        # The pipeline blocks for a long time (gradio calls, the upscaler's
        # HTTP retries and sleeps, PIL work). Running it on the threadpool
        # keeps the event loop serving other requests meanwhile.
        final_img, server_used = await run_in_threadpool(
            _run_tryon_pipeline, person_bytes, garment_bytes,
            resolution, denoise_steps, seed,
        )
    except HTTPException:
        raise
    except Exception as e:
        log.error(f"Pipeline failed: {e}")
        raise HTTPException(500, f"Try-on failed: {str(e)}") from e
    elapsed = round(time.time() - t0, 2)
    out_w, out_h = final_img.size
    log.info(f"/tryon done β€” {elapsed}s {out_w}Γ—{out_h} via {server_used}")
    # PNG-encoding a large (up to 8k) image is CPU-heavy, so it also runs off the loop.
    png_bytes = await run_in_threadpool(_pil_to_bytes, final_img)
    return Response(
        content=png_bytes,
        media_type="image/png",
        headers={
            "X-Processing-Time": f"{elapsed}s",
            "X-Engine": server_used,
            "X-Resolution": resolution,
            "X-Output-Size": f"{out_w}x{out_h}",
        },
    )