import os
import io

# external libraries
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

# custom imports
from model.src.datasets.dresscode import DressCodeDataset
from model.src.datasets.vitonhd import VitonHDDataset
from model.src.mgd_pipelines.mgd_pipe import MGDPipe
from model.src.mgd_pipelines.mgd_pipe_disentangled import MGDPipeDisentangled
from model.src.utils.arg_parser import eval_parse_args
from model.src.utils.image_from_pipe import generate_images_from_mgd_pipe
from model.src.utils.set_seeds import set_seed
from PIL import Image
from huggingface_hub import HfApi, HfFolder

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")

os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_START_METHOD"] = "thread"

# Authenticate against the Hugging Face Hub if a token is provided.
hf_token = os.getenv("HF_TOKEN")
api = HfApi()
if hf_token:
    HfFolder.save_token(hf_token)


def main(json_from_req: dict):
    args = eval_parse_args()
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )
    device = accelerator.device

    # If passed along, set the seed now.
    if args.seed is not None:
        set_seed(args.seed)
    # Load scheduler, tokenizer and models.
    val_scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    val_scheduler.set_timesteps(50, device=device)

    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = load_mgd_model(dataset=args.dataset, pretrained=True)
    # unet = torch.hub.load(dataset=args.dataset, repo_or_dir='aimagelab/multimodal-garment-designer',
    #                       source='github', model='mgd', pretrained=True)

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Enable memory efficient attention if requested
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    if args.category:
        category = [args.category]
    else:
        category = ['dresses', 'upper_body', 'lower_body']

    if args.dataset == "dresscode":
        test_dataset = DressCodeDataset(
            dataroot_path=args.dataset_path,
            phase='test',
            order=args.test_order,
            radius=5,
            sketch_threshold_range=(20, 20),
            tokenizer=tokenizer,
            category=category,
            size=(512, 384),
            json_from_req=json_from_req,
        )
    elif args.dataset == "vitonhd":
        test_dataset = VitonHDDataset(
            dataroot_path=args.dataset_path,
            phase='test',
            order=args.test_order,
            sketch_threshold_range=(20, 20),
            radius=5,
            tokenizer=tokenizer,
            size=(512, 384),
            json_from_req=json_from_req,
        )
    else:
        raise NotImplementedError

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers_test,
    )
    # For mixed-precision inference, cast the text_encoder and vae weights to half precision.
    # These models are only used for inference, so keeping them in full precision is not required.
    weight_dtype = torch.float32
    if args.mixed_precision == 'fp16':
        weight_dtype = torch.float16

    # Move text_encoder and vae to the device and cast to weight_dtype
    text_encoder.to(device, dtype=weight_dtype)
    vae.to(device, dtype=weight_dtype)

    unet.eval()

    # Select standard or disentangled classifier-free guidance according to the disentagle flag in args
    with torch.inference_mode():
        if args.disentagle:
            val_pipe = MGDPipeDisentangled(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet.to(vae.dtype),
                tokenizer=tokenizer,
                scheduler=val_scheduler,
            ).to(device)
        else:
            val_pipe = MGDPipe(
                text_encoder=text_encoder,
                vae=vae,
                unet=unet.to(vae.dtype),
                tokenizer=tokenizer,
                scheduler=val_scheduler,
            ).to(device)

        val_pipe.enable_attention_slicing()
        test_dataloader = accelerator.prepare(test_dataloader)

        final_image = generate_images_from_mgd_pipe(
            test_order=args.test_order,
            pipe=val_pipe,
            test_dataloader=test_dataloader,
            save_name=args.save_name,
            dataset=args.dataset,
            output_dir=args.output_dir,
            guidance_scale=args.guidance_scale,
            guidance_scale_pose=args.guidance_scale_pose,
            guidance_scale_sketch=args.guidance_scale_sketch,
            sketch_cond_rate=args.sketch_cond_rate,
            start_cond_rate=args.start_cond_rate,
            no_pose=False,
            disentagle=False,
            seed=args.seed,
        )

    return final_image  # Now returning the generated image


def load_mgd_model(dataset: str, pretrained: bool = True) -> UNet2DConditionModel:
    """
    Build the MGD denoising U-Net.

    The U-Net is instantiated from the Stable Diffusion v1.5 inpainting config with the
    input widened to 28 channels, then (optionally) loaded with the dataset-specific
    weights released by aimagelab/multimodal-garment-designer.

    dataset (str): which released checkpoint to load ('dresscode' or 'vitonhd')
    pretrained (bool): load pretrained weights into the model
    """
    config = UNet2DConditionModel.load_config("benjamin-paine/stable-diffusion-v1-5-inpainting", subfolder="unet")
    config['in_channels'] = 28
    unet = UNet2DConditionModel.from_config(config)
    if pretrained:
        checkpoint = f"https://github.com/aimagelab/multimodal-garment-designer/releases/download/weights/{dataset}.pth"
        unet.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True))
    return unet
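

# The helper below is an illustrative, hypothetical sanity check, not part of the original
# pipeline. It assumes only what load_mgd_model() above already does (a UNet2DConditionModel
# with 28 input channels and 4 output latent channels) and runs one dummy forward pass at the
# 64x48 latent resolution that corresponds to 512x384 images in Stable Diffusion v1.5.
def _sanity_check_mgd_unet() -> None:
    unet = load_mgd_model(dataset="dresscode", pretrained=False)
    assert unet.config.in_channels == 28

    sample = torch.randn(1, 28, 64, 48)   # concatenated latent + conditioning channels
    timestep = torch.tensor([10])         # arbitrary diffusion timestep
    text_emb = torch.randn(1, 77, 768)    # CLIP ViT-L/14 token embeddings used by SD v1.5

    with torch.inference_mode():
        out = unet(sample, timestep, encoder_hidden_states=text_emb).sample
    assert out.shape == (1, 4, 64, 48)    # prediction lives in the 4-channel latent space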


if __name__ == "__main__":
    # main() expects the request payload that the serving layer normally provides;
    # an empty dict is passed here as a placeholder for standalone runs.
    main(json_from_req={})