Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import numpy as np | |
| #import random | |
| import spaces #[uncomment to use ZeroGPU] | |
| #from diffusers import DiffusionPipeline | |
| import torch | |
| from diffusers import AutoPipelineForInpainting | |
| from diffusers.utils import load_image | |
| from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline | |
| #import cv2 | |
| #import matplotlib.pyplot as plt | |
| from PIL import Image | |
| import os | |
| import gc | |
| import glob | |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU: on CPU many float16 kernels are missing
# or extremely slow, so fall back to float32 there.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Open-vocabulary detector used to propose boxes for the requested labels.
GDINO_MODEL_NAME = "IDEA-Research/grounding-dino-tiny"
# Segment Anything checkpoint used to turn boxes into pixel masks.
SAM_MODEL_NAME = "facebook/sam-vit-base"
GDINO = pipeline(model=GDINO_MODEL_NAME, task="zero-shot-object-detection", device=DEVICE)
SAM = AutoModelForMaskGeneration.from_pretrained(SAM_MODEL_NAME).to(DEVICE)
SAM_PROCESSOR = AutoProcessor.from_pretrained(SAM_MODEL_NAME)

# SDXL inpainting pipeline that repaints the masked garment region.
SD_MODEL = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
SD_PIPLINE = AutoPipelineForInpainting.from_pretrained(SD_MODEL, torch_dtype=DTYPE).to(DEVICE)

# IP-Adapter weights let the cloth image be passed as an image prompt.
IP_ADAPTER = "h94/IP-Adapter"
SUB_FOLDER = "sdxl_models"
IP_WEIGHT_NAME = "ip-adapter_sdxl.bin"
SD_PIPLINE.load_ip_adapter(IP_ADAPTER, subfolder=SUB_FOLDER, weight_name=IP_WEIGHT_NAME)
# IP-Adapter conditioning scale (see diffusers `set_ip_adapter_scale`).
IP_SCALE = 0.6
SD_PIPLINE.set_ip_adapter_scale(IP_SCALE)

# Number of denoising steps used per generation.
GEN_STEPS = 100
def refine_masks(masks: torch.BoolTensor) -> np.ndarray:
    """Collapse the candidate masks of each box into one soft mask.

    Args:
        masks: Boolean tensor whose dimension 1 holds the candidate masks for
            each box. For the SAM post-processing output used in this file it
            is presumably ``(num_boxes, num_candidates, H, W)`` -- TODO confirm.

    Returns:
        A float32 NumPy array with dimension 1 removed. Each value is the
        fraction of candidate masks covering that pixel, so it lies in
        ``[0, 1]``.
    """
    # Move to CPU before widening bool -> float so the larger float copy is
    # not allocated on the accelerator. Averaging over dim 1 is equivalent to
    # permuting that axis last and averaging over the last axis.
    return masks.cpu().float().mean(dim=1).numpy()
def get_boxes(detections: list) -> list:
    """Convert detector output into the nested box list SAM's processor takes.

    Each detection's ``box`` dict is flattened to ``[xmin, ymin, xmax, ymax]``.
    The resulting list of boxes is wrapped in one more list, i.e. a batch that
    contains a single image.
    """
    corner_keys = ("xmin", "ymin", "xmax", "ymax")
    image_boxes = [[detection["box"][key] for key in corner_keys] for detection in detections]
    return [image_boxes]
def get_mask(img: Image.Image, prompt: str, d_model: pipeline, s_model: AutoModelForMaskGeneration,
             s_processor: AutoProcessor, device: str, threshold: float = 0.3) -> np.ndarray:
    """Segment the face and the requested garment in ``img``.

    The detector ``d_model`` is queried with the labels ``"face"`` and
    ``prompt``; a trailing ``"."`` is appended to each label if missing. The
    highest-scoring box for each label is then segmented with SAM
    (``s_model`` / ``s_processor``).

    Args:
        img: Input image of the person.
        prompt: Garment class to segment (e.g. ``"shirt"``).
        d_model: Zero-shot object-detection pipeline.
        s_model: SAM mask-generation model.
        s_processor: Processor matching ``s_model``.
        device: Device the SAM inputs are moved to.
        threshold: Minimum detection score passed to ``d_model``.

    Returns:
        The array produced by ``refine_masks``, laid out as
        index 0 = face mask and index 1 = garment mask. If no face is
        detected, index 0 is an all-zero mask.

    Raises:
        gr.Error: if no box for ``prompt`` scores above ``threshold``.
    """
    labels = [label if label.endswith(".") else label + "." for label in ["face", prompt]]
    face_label, cloth_label = labels
    detections = d_model(img, candidate_labels=labels, threshold=threshold)

    # Detector output is not guaranteed to be grouped or ordered by label and
    # may hold several boxes per label. Keep the best box per label so the
    # returned layout is fixed and callers can rely on index 1 being the
    # garment. (Assumes each detection's "label" is the candidate label
    # string, as documented for the transformers pipeline.)
    best = {}
    for det in detections:
        current = best.get(det["label"])
        if current is None or det["score"] > current["score"]:
            best[det["label"]] = det
    if cloth_label not in best:
        raise gr.Error(f"Could not find '{prompt}' in the model image.")

    ordered = [best[label] for label in labels if label in best]
    boxes = get_boxes(ordered)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        inputs = s_processor(images=img, input_boxes=boxes, return_tensors="pt").to(device)
        outputs = s_model(**inputs)
        masks = s_processor.post_process_masks(
            masks=outputs.pred_masks,
            original_sizes=inputs.original_sizes,
            reshaped_input_sizes=inputs.reshaped_input_sizes,
        )[0]
    masks = refine_masks(masks)
    if face_label not in best:
        # Preserve the (face, garment) layout with an empty face mask.
        masks = np.concatenate([np.zeros_like(masks[:1]), masks], axis=0)
    return masks
def generate_result(model_img: "Image.Image", cloth_img: "Image.Image",
                    masks: np.ndarray, prompt: str, sd_pipline: "AutoPipelineForInpainting",
                    n_steps: int = 100, seed: int = 4) -> "Image.Image":
    """Inpaint the garment region of ``model_img`` using ``cloth_img`` as reference.

    Args:
        model_img: Image of the person; only its ``.size`` is read here before
            it is passed to the pipeline.
        cloth_img: Reference garment image, passed as ``ip_adapter_image``.
        masks: Sequence of masks; ``masks[1]`` is used as the inpainting mask
            and is expected to be the garment mask produced by ``get_mask``.
        prompt: Text prompt describing the target garment.
        sd_pipline: Inpainting pipeline with an IP-Adapter loaded.
        n_steps: Number of inference steps.
        seed: Seed for the CPU random generator (default 4, as before).

    Returns:
        The first generated image, resized back to ``model_img``'s size.
    """
    width, height = model_img.size
    cloth_mask = masks[1]
    # A seeded CPU generator keeps results reproducible for a given seed.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    images = sd_pipline(
        prompt=prompt,
        image=model_img,
        mask_image=cloth_mask,
        ip_adapter_image=cloth_img,
        generator=generator,
        num_inference_steps=n_steps,
    ).images
    # The pipeline output size may differ from the input; restore it.
    return images[0].resize((width, height))
@spaces.GPU  # Required to get a GPU on ZeroGPU Spaces; no effect elsewhere.
def run(model_img: Image.Image, cloth_img: Image.Image, cloth_class: str,
        close_description: str) -> Image.Image:
    """Gradio callback: dress the person in ``model_img`` with ``cloth_img``.

    Segments the garment named by ``cloth_class`` and inpaints it, guided by
    ``close_description`` as the text prompt and ``cloth_img`` as the
    IP-Adapter image.

    Raises:
        gr.Error: if an image or the cloth class is missing.
    """
    if model_img is None or cloth_img is None:
        raise gr.Error("Please upload both a model image and a cloth image.")
    cloth_class = (cloth_class or "").strip()
    if not cloth_class:
        raise gr.Error("Please enter a cloth class (e.g. 'shirt').")
    try:
        masks = get_mask(model_img, cloth_class, GDINO, SAM, SAM_PROCESSOR, DEVICE)
        return generate_result(model_img, cloth_img, masks, close_description, SD_PIPLINE, GEN_STEPS)
    finally:
        # Release memory even when generation fails, so a failed request does
        # not leave cached allocations behind.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Gradio UI: two image uploads plus the garment class and a text description.
demo = gr.Interface(
    run,
    title='Virtual Try-On',
    inputs=[
        gr.Image(sources='upload', label='Model image', type='pil'),
        gr.Image(sources='upload', label='Cloth image', type='pil'),
        gr.Textbox(label='Cloth class'),
        gr.Textbox(label='Close description'),
    ],
    outputs=[
        gr.Image(),
    ],
    examples=[
        ["./examples/models/girl1.jpg", "./examples/clothes/t_short.jpg", "shirt", "black shirt"],
    ],
)

# Launch only when executed as a script, so importing the module does not
# start a server.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)