"""Gradio app exposing a virtual try-on endpoint and a multi-view endpoint.

* ``tryon`` forwards a person photo and a garment image to a remote
  OOTDiffusion Space and returns the composited picture.
* ``multiview`` runs Zero123++ on the GPU and returns four views of the
  input image.
"""

import os
import tempfile

# NOTE: ``spaces`` is deliberately imported before ``torch``, as in the
# original file; keep this order when touching the imports.
import spaces
import gradio as gr
import torch
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
from gradio_client import Client, handle_file
from PIL import Image

HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Only send an Authorization header when a token is actually configured.
_OOTD_HEADERS = {"Authorization": f"Bearer {HF_TOKEN}"} if HF_TOKEN else None

# Azimuths (degrees, relative to the input view) of the six tiles of a
# Zero123++ v1.2 sheet, in row-major order, as stated by the original
# author's comment.
_ZERO123_AZIMUTHS = (30, 90, 150, 210, 270, 330)


# ─── Model loading ────────────────────────────────────────────────────────────

multiview_pipe = None


def load_models():
    """Load the Zero123++ pipeline into the module-level ``multiview_pipe``.

    The pipeline is loaded in float16 and its scheduler is replaced by an
    ``EulerAncestralDiscreteScheduler`` built from the pipeline's own
    scheduler config with ``timestep_spacing="trailing"``.
    """
    global multiview_pipe
    # NOTE(review): ``trust_remote_code=True`` executes pipeline code fetched
    # from the Hub repository; make sure that repository is trusted.
    multiview_pipe = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.2",
        custom_pipeline="sudo-ai/zero123plus-pipeline",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    multiview_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        multiview_pipe.scheduler.config,
        timestep_spacing="trailing",
    )


load_models()


# ─── Helpers ──────────────────────────────────────────────────────────────────

def _preprocess_for_zero123(img: Image.Image, size: int = 320) -> Image.Image:
    """Crop to the visible subject, centre it on a white square and resize.

    Args:
        img: Input image in any mode; it is converted to RGBA first.
        size: Side length, in pixels, of the returned square image.

    Returns:
        An RGB image of ``size`` x ``size`` pixels with a white background.
    """
    rgba = img.convert("RGBA")
    # Take the bounding box from the alpha channel explicitly: whether
    # ``getbbox()`` on an RGBA image looks only at alpha depends on the
    # Pillow version, and transparent pixels may still carry colour data.
    bbox = rgba.getchannel("A").getbbox()
    if bbox:
        rgba = rgba.crop(bbox)

    side = max(rgba.size)
    canvas = Image.new("RGBA", (side, side), (255, 255, 255, 255))
    offset = ((side - rgba.width) // 2, (side - rgba.height) // 2)
    # The alpha channel doubles as paste mask so transparent areas stay white.
    canvas.paste(rgba, offset, rgba.getchannel("A"))
    return canvas.convert("RGB").resize((size, size), Image.LANCZOS)


def _split_view_grid(grid: Image.Image) -> dict[int, Image.Image]:
    """Cut a six-view Zero123++ sheet into tiles keyed by azimuth.

    Zero123++ documents its output as a portrait sheet of 3 rows x 2
    columns (640x960). The layout is chosen from the sheet's orientation,
    so a landscape sheet is still split as 2 rows x 3 columns. Tiles are
    read in row-major order and matched to ``_ZERO123_AZIMUTHS``.

    Args:
        grid: The tiled image produced by the pipeline.

    Returns:
        Mapping from azimuth in degrees to the corresponding tile.
    """
    width, height = grid.size
    cols, rows = (2, 3) if height > width else (3, 2)
    tile_w, tile_h = width // cols, height // rows

    tiles = {}
    for index, azimuth in enumerate(_ZERO123_AZIMUTHS):
        row, col = divmod(index, cols)
        left, top = col * tile_w, row * tile_h
        tiles[azimuth] = grid.crop((left, top, left + tile_w, top + tile_h))
    return tiles


# ─── Endpoint 1 — Photo try-on ────────────────────────────────────────────────

def tryon(person_image: Image.Image, garment_image: Image.Image) -> Image.Image:
    """Apply a garment to a person photo through the remote OOTDiffusion Space.

    Both images are written as JPEG files into a temporary directory,
    uploaded to the ``/process_hd`` endpoint of the
    ``eduardo4547/OOTDiffusion`` Space, and the returned file is loaded.
    Per the original author's note, the person photo should be a full-body,
    front-facing shot.

    Args:
        person_image: Photo of the user.
        garment_image: Picture of the garment to apply.

    Returns:
        The try-on result as an RGB image.

    Raises:
        gr.Error: If either input image is missing.
    """
    if person_image is None or garment_image is None:
        raise gr.Error("Sube la foto del usuario y la imagen de la prenda.")

    # A temporary directory is removed even when saving or the remote call
    # fails, so no input file is left behind on any error path.
    with tempfile.TemporaryDirectory() as workdir:
        person_path = os.path.join(workdir, "person.jpg")
        garment_path = os.path.join(workdir, "garment.jpg")
        # JPEG cannot store alpha or palette data; normalise to RGB first.
        person_image.convert("RGB").save(person_path, format="JPEG")
        garment_image.convert("RGB").save(garment_path, format="JPEG")

        client = Client(
            "eduardo4547/OOTDiffusion",
            headers=_OOTD_HEADERS,
            verbose=False,
        )
        result = client.predict(
            vton_img=handle_file(person_path),
            garm_img=handle_file(garment_path),
            n_samples=1,
            n_steps=20,
            image_scale=2.0,
            seed=42,
            api_name="/process_hd",
        )

    # Per the original comment, this Space returns a file path directly
    # (gr.Image output). ``convert`` loads the pixels, so the file handle
    # can be closed right after.
    with Image.open(result) as output:
        return output.convert("RGB")


# ─── Endpoint 2 — Multi-view ──────────────────────────────────────────────────

@spaces.GPU
def multiview(front_image: Image.Image) -> list[Image.Image]:
    """Generate four views (front, left, right, back) with Zero123++.

    Per the original author's note, the input is expected to be the try-on
    image with a clean background (rembg applied in the backend).

    Args:
        front_image: Front-facing image to generate views from.

    Returns:
        Four images in the order ``[front, left, right, back]``.

    Raises:
        gr.Error: If no input image is provided.
    """
    if front_image is None:
        raise gr.Error("Sube la imagen frontal del try-on.")

    processed = _preprocess_for_zero123(front_image)

    multiview_pipe.to("cuda")
    try:
        sheet = multiview_pipe(processed, num_inference_steps=75).images[0]
    finally:
        # Always move the pipeline back to the CPU, even if inference fails.
        multiview_pipe.to("cpu")

    views = _split_view_grid(sheet)
    # Azimuth-to-name mapping kept from the original author:
    # 330° ≈ front-left, 270° = left profile, 90° = right profile,
    # 210° = back-left.
    return [views[330], views[270], views[90], views[210]]


# ─── Gradio interface ─────────────────────────────────────────────────────────

with gr.Blocks() as demo:
    with gr.Tab("Try-On"):
        gr.Interface(
            fn=tryon,
            inputs=[
                gr.Image(type="pil", label="Foto del usuario (cuerpo completo, frontal)"),
                gr.Image(type="pil", label="Imagen de la prenda"),
            ],
            outputs=gr.Image(type="pil", label="Resultado"),
            api_name="tryon",
        )

    with gr.Tab("Multi-Vista"):
        gr.Interface(
            fn=multiview,
            inputs=gr.Image(type="pil", label="Imagen frontal del try-on"),
            outputs=[
                gr.Image(type="pil", label="Front"),
                gr.Image(type="pil", label="Left"),
                gr.Image(type="pil", label="Right"),
                gr.Image(type="pil", label="Back"),
            ],
            api_name="multiview",
        )

demo.launch(show_error=True)