import patch_gradio  # local module bundled with this Space; imported before Gradio

import os
import sys
import gc
import traceback

import torch
import gradio as gr
from PIL import Image
from huggingface_hub import snapshot_download

sys.path.insert(0, '/app/CatVTON')
from model.pipeline import CatVTONPipeline
from model.cloth_masker import AutoMasker
from utils import resize_and_crop, resize_and_padding
pipeline = None
automasker = None


def load_models():
    """Download and initialize the CatVTON pipeline and auto-masker once."""
    global pipeline, automasker
    if pipeline is not None and automasker is not None:
        return
    print("🚀 Loading models...", file=sys.stderr)
    try:
        repo_path = snapshot_download(
            repo_id="zhengchong/CatVTON",
            cache_dir="/tmp/models"
        )
        # Black placeholder image for flagged outputs, created if missing.
        nsfw_path = "/tmp/NSFW.jpg"
        if not os.path.exists(nsfw_path):
            Image.new('RGB', (512, 512), color='black').save(nsfw_path)
        pipeline = CatVTONPipeline(
            base_ckpt="booksforcharlie/stable-diffusion-inpainting",
            attn_ckpt=repo_path,
            attn_ckpt_version="mix",
            weight_dtype=torch.float16,
            use_tf32=True,
            device='cuda'
        )
        # DensePose and SCHP run on CPU to leave GPU memory for the pipeline.
        automasker = AutoMasker(
            densepose_ckpt=os.path.join(repo_path, "DensePose"),
            schp_ckpt=os.path.join(repo_path, "SCHP"),
            device='cpu'
        )
        print("✅ Models loaded!", file=sys.stderr)
    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise
def generate_tryon(person_img, cloth_img, garment_category):
    print("=" * 50, file=sys.stderr)
    print(f"Received - Person: {type(person_img)}, Cloth: {type(cloth_img)}, "
          f"Category: {garment_category}", file=sys.stderr)
    if person_img is None or cloth_img is None:
        raise gr.Error("Both a person image and a garment image are required!")
    try:
        print("Images received as PIL", file=sys.stderr)
        load_models()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Work on copies so the images handed over by Gradio stay untouched.
        person_img = person_img.copy()
        cloth_img = cloth_img.copy()
        # CatVTON expects a fixed 768x1024 portrait resolution.
        target_height = 1024
        target_width = 768
        person_img = resize_and_crop(person_img, (target_width, target_height))
        cloth_img = resize_and_padding(cloth_img, (target_width, target_height))
        # Mask only the region matching the selected garment category
        # ("upper", "lower", or "overall").
        mask = automasker(person_img, garment_category)['mask']
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        result = pipeline(
            image=person_img,
            condition_image=cloth_img,
            mask=mask,
            num_inference_steps=50,
            guidance_scale=2.5,
            seed=None,
            height=target_height,
            width=target_width
        )[0]
        print("✅ Success!", file=sys.stderr)
        return result
    except Exception as e:
        print(f"❌ Error: {e}", file=sys.stderr)
        traceback.print_exc()
        raise gr.Error(str(e))
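
# Quick local smoke test of generate_tryon() (a hedged sketch; the image
# paths below are placeholders, not files shipped with this Space):
#
#     person = Image.open("person.jpg").convert("RGB")
#     garment = Image.open("garment.jpg").convert("RGB")
#     generate_tryon(person, garment, "upper").save("result.png")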
# Gradio interface with a garment-category dropdown.
demo = gr.Interface(
    fn=generate_tryon,
    inputs=[
        gr.Image(label="Person Image", type="pil"),
        gr.Image(label="Garment Image", type="pil"),
        gr.Dropdown(
            choices=["upper", "lower", "overall"],
            value="upper",
            label="Garment Category",
            info="Select: upper (tops), lower (pants/skirts), overall (dresses/full outfits)"
        )
    ],
    outputs=gr.Image(label="Result", type="pil"),
    title="Try-Space Virtual Try-On",
    description="Upload person and garment images, then select a category. Processing takes 2-3 minutes on a T4 GPU.",
    api_name="generate_tryon",
    allow_flagging="never"
)
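
# Minimal client-side sketch for calling this app over the Gradio API
# (assumes the Space ID "user/try-space" as a placeholder):
#
#     from gradio_client import Client, handle_file
#
#     client = Client("user/try-space")
#     result_path = client.predict(
#         handle_file("person.jpg"),   # Person Image
#         handle_file("garment.jpg"),  # Garment Image
#         "upper",                     # Garment Category
#         api_name="/generate_tryon",
#     )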
if __name__ == "__main__":
    print("🚀 Starting...", file=sys.stderr)
    try:
        # Warm the models at startup; if this fails, generate_tryon()
        # loads them lazily on the first request instead.
        load_models()
    except Exception:
        pass
    demo.queue().launch(show_error=True)