# `spaces` is kept first: on ZeroGPU Spaces it should be imported before torch
# initializes CUDA.
import spaces

import logging
import os

import gradio as gr
import numpy as np
import torch
from detectron2.data.detection_utils import _apply_exif_orientation, convert_PIL_to_numpy
from diffusers import AutoencoderKL, DDPMScheduler
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from transformers import (
    AutoTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

import apply_net
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose
from src.tryon_pipeline import StableDiffusionXLInpaintPipeline as TryonPipeline
from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref
from src.unet_hacked_tryon import UNet2DConditionModel
from utils_mask import get_mask_location


class GenerationError(Exception):
    """Raised when try-on generation fails for a user-reportable reason."""


logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
def pil_to_binary_mask(pil_image, threshold=0):
    """Convert a PIL image to a binary 0/255 mask (pixels above `threshold` become 255)."""
    grayscale = pil_image.convert("L")
    binary = np.array(grayscale) > threshold
    return Image.fromarray(binary.astype(np.uint8) * 255)


base_path = "yisol/IDM-VTON"
example_path = os.path.join(os.path.dirname(__file__), "example")
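# All diffusion components are loaded from the IDM-VTON checkpoint in fp16
# and frozen below, since the app only performs inference.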
unet = UNet2DConditionModel.from_pretrained(
    base_path,
    subfolder="unet",
    torch_dtype=torch.float16,
)
tokenizer_one = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer",
    revision=None,
    use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
    base_path,
    subfolder="tokenizer_2",
    revision=None,
    use_fast=False,
)
noise_scheduler = DDPMScheduler.from_pretrained(base_path, subfolder="scheduler")
text_encoder_one = CLIPTextModel.from_pretrained(
    base_path,
    subfolder="text_encoder",
    torch_dtype=torch.float16,
)
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(
    base_path,
    subfolder="text_encoder_2",
    torch_dtype=torch.float16,
)
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    base_path,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKL.from_pretrained(
    base_path,
    subfolder="vae",
    torch_dtype=torch.float16,
)

# Reference UNet that encodes garment features for the try-on UNet.
UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(
    base_path,
    subfolder="unet_encoder",
    torch_dtype=torch.float16,
)
# Preprocessing models (human parsing and OpenPose), both on GPU 0.
parsing_model = Parsing(0)
openpose_model = OpenPose(0)

# Inference only: freeze all model weights.
unet.requires_grad_(False)
UNet_Encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)

# Maps a PIL image to a tensor normalized to [-1, 1].
tensor_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
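# Assemble the try-on pipeline from the components loaded above, then attach
# the garment-feature encoder it expects.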
pipe = TryonPipeline.from_pretrained(
    base_path,
    unet=unet,
    vae=vae,
    feature_extractor=CLIPImageProcessor(),
    text_encoder=text_encoder_one,
    text_encoder_2=text_encoder_two,
    tokenizer=tokenizer_one,
    tokenizer_2=tokenizer_two,
    scheduler=noise_scheduler,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
pipe.unet_encoder = UNet_Encoder
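# Optional NSFW classifier; if it fails to load, generation still works and
# the safety check is skipped.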
try:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification

    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    logger.info("NSFW detector loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load NSFW detector: {e}")
    nsfw_model = None
    nsfw_processor = None


def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if `image` is classified as NSFW with probability above `threshold`."""
    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = nsfw_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Look up the "nsfw" class index from the model config rather than hard-coding it.
    nsfw_idx = nsfw_model.config.label2id.get("nsfw", 1)
    return probs[0][nsfw_idx].item() > threshold
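# Module-level Progress object shared with the @spaces.GPU worker below.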
progress = gr.Progress()


@spaces.GPU
def _infer(person, garment, denoise_steps, seed):
    progress(0, desc="Starting")
    device = "cuda"
    try:
        openpose_model.preprocessor.body_estimation.model.to(device)
        pipe.to(device)
        pipe.unet_encoder.to(device)

        person_rgb = person.convert("RGB")
        crop_size = person_rgb.size  # original size, restored after generation

        # The pipeline operates at a fixed 768x1024 working resolution.
        human_img = person_rgb.resize((768, 1024))
        garm_img = garment.convert("RGB").resize((768, 1024))

        progress(0.1, desc="Mask generating")

        # Keypoints and parsing run at 384x512; the resulting upper-body
        # inpainting mask is upscaled back to the working resolution.
        keypoints = openpose_model(human_img.resize((384, 512)))
        model_parse, _ = parsing_model(human_img.resize((384, 512)))
        mask, mask_gray = get_mask_location("hd", "upper_body", model_parse, keypoints)
        mask = mask.resize((768, 1024))

        mask_gray = (1 - transforms.ToTensor()(mask)) * tensor_transform(human_img)
        mask_gray = to_pil_image((mask_gray + 1.0) / 2.0)
        progress(0.3, desc="DensePose processing")

        human_img_arg = _apply_exif_orientation(human_img.resize((384, 512)))
        human_img_arg = convert_PIL_to_numpy(human_img_arg, format="BGR")

        # Run DensePose segmentation through apply_net's CLI-style entry point.
        args = apply_net.create_argument_parser().parse_args(
            (
                "show",
                "./configs/densepose_rcnn_R_50_FPN_s1x.yaml",
                "./ckpt/densepose/model_final_162be9.pkl",
                "dp_segm",
                "-v",
                "--opts",
                "MODEL.DEVICE",
                "cuda",
            )
        )
        pose_img = args.func(args, human_img_arg)
        pose_img = pose_img[:, :, ::-1]  # BGR -> RGB
        pose_img = Image.fromarray(pose_img).resize((768, 1024))
        progress(0.5, desc="Image generating")

        def callback(pipe, step, timestep, callback_kwargs):
            # Map denoising progress onto the 0.5..1.0 range of the bar.
            progress_value = 0.5 + 0.5 * (step + 1) / denoise_steps
            progress(progress_value, desc=f"Image generating, {step + 1}/{denoise_steps} steps")
            return callback_kwargs
        # Encode both prompts and run the diffusion pipeline under fp16 autocast.
        with torch.no_grad(), torch.cuda.amp.autocast():
            prompt = "model is wearing clothing"
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = pipe.encode_prompt(
                prompt,
                num_images_per_prompt=1,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

            # Separate prompt for the garment-encoding branch (no guidance).
            prompt = "a photo of clothing"
            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
            if not isinstance(prompt, list):
                prompt = [prompt]
            if not isinstance(negative_prompt, list):
                negative_prompt = [negative_prompt]
            (
                prompt_embeds_c,
                _,
                _,
                _,
            ) = pipe.encode_prompt(
                prompt,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=negative_prompt,
            )

            pose_tensor = tensor_transform(pose_img).unsqueeze(0).to(device, torch.float16)
            garm_tensor = tensor_transform(garm_img).unsqueeze(0).to(device, torch.float16)
            # A seed of -1 (or None) means "use a random seed".
            generator = (
                torch.Generator(device=device).manual_seed(int(seed))
                if seed is not None and int(seed) != -1
                else None
            )

            images = pipe(
                prompt_embeds=prompt_embeds.to(device, torch.float16),
                negative_prompt_embeds=negative_prompt_embeds.to(device, torch.float16),
                pooled_prompt_embeds=pooled_prompt_embeds.to(device, torch.float16),
                negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.to(device, torch.float16),
                num_inference_steps=int(denoise_steps),
                generator=generator,
                strength=1.0,
                pose_img=pose_tensor,
                text_embeds_cloth=prompt_embeds_c.to(device, torch.float16),
                cloth=garm_tensor,
                mask_image=mask,
                image=human_img,
                height=1024,
                width=768,
                ip_adapter_image=garm_img.resize((768, 1024)),
                guidance_scale=2.0,
                callback_on_step_end=callback,
            )[0]

        out_img = images[0].resize(crop_size)
        # Safety check on the generated image.
        if nsfw_model and nsfw_processor:
            if detect_nsfw(out_img):
                raise GenerationError(
                    "Generated image contains NSFW content and cannot be displayed. "
                    "Please provide a different image and try again."
                )

        progress(1, desc="Complete")
        return out_img, {"status": "success"}
    except GenerationError as e:
        logger.warning(f"Generation blocked: {e}")
        return None, {"error": str(e), "status": "failed"}
    except Exception as e:
        logger.exception("Generation failed")
        return None, {"error": str(e), "status": "failed"}
def infer(person, garment, denoise_steps, seed):
    out_img, info = _infer(person, garment, denoise_steps, seed)

    if info["status"] == "failed":
        raise gr.Error(info["error"])

    return out_img
title = "## AI Clothes Changer"
description = "Step into the world of AI clothes swap and unlock style possibilities with [AI Clothes Changer](https://www.aiclotheschanger.org)"

# Example images shipped with the app.
person_list = sorted(os.listdir(os.path.join(example_path, "human")))
person_images = [os.path.join(example_path, "human", person) for person in person_list]

garment_list = sorted(os.listdir(os.path.join(example_path, "cloth")))
garment_images = [os.path.join(example_path, "cloth", garment) for garment in garment_list]
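# Gradio UI: person/garment inputs, generated output, and advanced options.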
with gr.Blocks().queue() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Person Image")
            person_image = gr.Image(
                sources=["upload"],
                type="pil",
                label="Person Image",
                width=512,
                height=512,
            )
            gr.Examples(
                inputs=person_image,
                examples_per_page=20,
                examples=person_images,
            )
        with gr.Column():
            gr.Markdown("#### Garment Image")
            garment_image = gr.Image(
                sources=["upload"],
                type="pil",
                label="Garment Image",
                width=512,
                height=512,
            )
            gr.Examples(
                inputs=garment_image,
                examples_per_page=20,
                examples=garment_images,
            )
        with gr.Column():
            gr.Markdown("#### Generated Image")
            gen_image = gr.Image(
                label="Generated Image",
                width=512,
                height=512,
            )

    with gr.Row():
        gen_button = gr.Button("Generate")

    with gr.Accordion("Advanced Options", open=False):
        denoise_steps = gr.Number(label="Denoising Steps", minimum=20, maximum=40, value=30, step=1)
        seed = gr.Number(label="Seed (-1 for random)", minimum=-1, maximum=2147483647, step=1, value=42)

    gen_button.click(
        fn=infer,
        inputs=[person_image, garment_image, denoise_steps, seed],
        outputs=[gen_image],
    )

demo.launch()