hyperfit-tryon / app.py
kevtico20
fix: pasar HF_TOKEN como auth header al llamar OOTDiffusion (cuota ZeroGPU)
072a251
import spaces
import gradio as gr
import torch
import tempfile
import os
from PIL import Image
from gradio_client import Client, handle_file
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
# Hugging Face access token read from the environment; empty string when unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Authorization header handed to the remote OOTDiffusion client.
# Stays None when no token is configured, so no header is sent.
_OOTD_HEADERS = None
if HF_TOKEN:
    _OOTD_HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"}
# ─── Carga de modelos ─────────────────────────────────────────────────────────
# Zero123++ pipeline used by the multi-view endpoint; populated by load_models().
multiview_pipe = None


def load_models() -> None:
    """Load the Zero123++ pipeline into the module-level ``multiview_pipe``.

    The pipeline is created in float16 from ``sudo-ai/zero123plus-v1.2`` using
    the ``sudo-ai/zero123plus-pipeline`` custom pipeline, then its scheduler is
    replaced by an Euler-ancestral scheduler built from the pipeline's own
    scheduler config with ``timestep_spacing="trailing"``.
    """
    global multiview_pipe

    pipeline_options = {
        "custom_pipeline": "sudo-ai/zero123plus-pipeline",
        "torch_dtype": torch.float16,
        # The custom pipeline ships its own code, which must be allowed to run.
        "trust_remote_code": True,
    }
    multiview_pipe = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2", **pipeline_options
    )

    base_scheduler_config = multiview_pipe.scheduler.config
    multiview_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        base_scheduler_config,
        timestep_spacing="trailing",
    )


# Load once at import time so the endpoints can use the pipeline immediately.
load_models()
# ─── Helpers ──────────────────────────────────────────────────────────────────
def _preprocess_for_zero123(img: Image.Image, size: int = 320) -> Image.Image:
    """Center the image on a square white canvas and resize it for Zero123++.

    The image is converted to RGBA, cropped to the bounding box reported by
    ``getbbox()`` (when there is one), pasted in the middle of a white square
    whose side is the larger image dimension, using its own alpha channel as
    the paste mask, and finally returned as an RGB image of ``size`` x ``size``
    pixels resampled with Lanczos.
    """
    rgba = img.convert("RGBA")

    content_box = rgba.getbbox()
    if content_box is not None:
        rgba = rgba.crop(content_box)

    side = max(rgba.width, rgba.height)
    canvas = Image.new("RGBA", (side, side), (255, 255, 255, 255))

    # Top-left corner that centers the (possibly cropped) image on the canvas.
    offset = ((side - rgba.width) // 2, (side - rgba.height) // 2)
    *_, alpha = rgba.split()
    canvas.paste(rgba, offset, alpha)

    flattened = canvas.convert("RGB")
    return flattened.resize((size, size), Image.LANCZOS)
# ─── Endpoint 1 — Try-on foto ─────────────────────────────────────────────────
def _save_as_temp_jpeg(image: Image.Image) -> str:
    """Write ``image`` to a new temporary ``.jpg`` file and return its path.

    The file is created with ``delete=False``; the caller owns it and is
    responsible for removing it. If writing fails, the file is removed here
    before the error is re-raised.
    """
    # Reserve a unique name, then close the handle before PIL writes to the path.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp:
        path = tmp.name
    try:
        # JPEG cannot store alpha or palette modes; normalise to RGB first.
        image.convert("RGB").save(path, format="JPEG")
    except Exception:
        os.unlink(path)
        raise
    return path


def tryon(person_image: Image.Image, garment_image: Image.Image) -> Image.Image:
    """Apply a garment to a person photo through the remote OOTDiffusion Space.

    Both images are written to temporary JPEG files, uploaded to the
    ``eduardo4547/OOTDiffusion`` Space (``/process_hd`` endpoint) and the
    resulting image is returned as RGB. The temporary files are always removed,
    whether the call succeeds or fails.

    The person photo should be a full-body, front-facing shot.

    Args:
        person_image: Photo of the user.
        garment_image: Image of the garment to apply.

    Returns:
        The try-on result as an RGB image.

    Raises:
        gr.Error: If either input image is missing.
        Exception: Errors from the remote call or from reading the result
            propagate unchanged.
    """
    if person_image is None or garment_image is None:
        raise gr.Error("Se requieren la foto del usuario y la imagen de la prenda.")

    # Every temp file is registered as soon as it exists so `finally` removes it
    # even when a later step (second save, remote call, decoding) fails.
    temp_paths: list[str] = []
    try:
        person_path = _save_as_temp_jpeg(person_image)
        temp_paths.append(person_path)
        garment_path = _save_as_temp_jpeg(garment_image)
        temp_paths.append(garment_path)

        client = Client("eduardo4547/OOTDiffusion", headers=_OOTD_HEADERS, verbose=False)
        # Per the original author's note, the endpoint returns a file path
        # directly (gr.Image output).
        output_path = client.predict(
            vton_img=handle_file(person_path),
            garm_img=handle_file(garment_path),
            n_samples=1,
            n_steps=20,
            image_scale=2.0,
            seed=42,
            api_name="/process_hd",
        )
        # convert() returns an independent, fully loaded copy, so the source
        # file handle can be closed before returning.
        with Image.open(output_path) as output:
            return output.convert("RGB")
    finally:
        for path in temp_paths:
            try:
                os.unlink(path)
            except OSError:
                # Best-effort cleanup: never mask the primary error, and keep
                # going so the remaining files are still removed.
                pass
# ─── Endpoint 2 — Multi-vista ─────────────────────────────────────────────────
@spaces.GPU
def multiview(front_image: Image.Image) -> list[Image.Image]:
    """Generate four views (front, left, right, back) of the subject with Zero123++.

    The input is expected to be the try-on image with a clean background
    (per the original note, rembg is applied in the backend). It is
    preprocessed, run through the Zero123++ pipeline on the GPU, and the
    resulting six-view contact sheet is cut into tiles.

    Tiles are numbered in row-major order and, per the original comments,
    correspond to azimuths 30°, 90°, 150°, 210°, 270°, 330°.

    NOTE(review): upstream Zero123++ documents its output as a portrait sheet
    (2 columns x 3 rows of 320x320 tiles, 640x960 overall). The previous code
    always sliced 3 columns x 2 rows, which cuts a portrait sheet in the wrong
    places. The layout is now picked from the aspect ratio: portrait means
    2 x 3, anything else keeps the former 3 x 2 slicing unchanged. Confirm
    against the pipeline version in use.

    Args:
        front_image: Front-facing image of the subject.

    Returns:
        ``[front, left, right, back]`` tiles cropped from the generated sheet.
    """
    processed = _preprocess_for_zero123(front_image)

    multiview_pipe.to("cuda")
    try:
        sheet = multiview_pipe(processed, num_inference_steps=75).images[0]
    finally:
        # Always move the pipeline back, even when inference raises.
        multiview_pipe.to("cpu")

    width, height = sheet.size
    cols, rows = (2, 3) if height > width else (3, 2)
    tile_w, tile_h = width // cols, height // rows

    def tile(index: int) -> Image.Image:
        """Crop the ``index``-th tile (row-major) out of the sheet."""
        col, row = index % cols, index // cols
        # The last column/row extends to the sheet edge so leftover pixels from
        # the integer division are kept, as in the original slicing.
        right = width if col == cols - 1 else (col + 1) * tile_w
        bottom = height if row == rows - 1 else (row + 1) * tile_h
        return sheet.crop((col * tile_w, row * tile_h, right, bottom))

    views = {
        "front": tile(5),  # 330° ≈ front-left
        "right": tile(1),  # 90° = right profile
        "back": tile(3),   # 210° = back-left
        "left": tile(4),   # 270° = left profile
    }
    return [views["front"], views["left"], views["right"], views["back"]]
# ─── Interfaz Gradio ──────────────────────────────────────────────────────────
# Two-tab UI: each tab wraps one endpoint in its own gr.Interface, which also
# exposes it under a named API route ("tryon" / "multiview").
with gr.Blocks() as demo:
    with gr.Tab("Try-On"):
        # Components are created inside the tab, right before the Interface
        # that consumes them.
        tryon_inputs = [
            gr.Image(type="pil", label="Foto del usuario (cuerpo completo, frontal)"),
            gr.Image(type="pil", label="Imagen de la prenda"),
        ]
        tryon_output = gr.Image(type="pil", label="Resultado")
        gr.Interface(
            fn=tryon,
            inputs=tryon_inputs,
            outputs=tryon_output,
            api_name="tryon",
        )

    with gr.Tab("Multi-Vista"):
        multiview_input = gr.Image(type="pil", label="Imagen frontal del try-on")
        # One output slot per view, in the order multiview() returns them.
        multiview_outputs = [
            gr.Image(type="pil", label=view_label)
            for view_label in ("Front", "Left", "Right", "Back")
        ]
        gr.Interface(
            fn=multiview,
            inputs=multiview_input,
            outputs=multiview_outputs,
            api_name="multiview",
        )

# show_error=True surfaces exceptions raised by the endpoints in the UI.
demo.launch(show_error=True)