| from functools import partial |
| import os |
| from typing import List |
| from pathlib import Path |
| import math |
|
|
| from tqdm import tqdm |
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
def get_batch_infer_args(parser=None, argv=None):
    """Register the batch-inference command-line options and parse them.

    Args:
        parser: An existing ``argparse.ArgumentParser`` to add the options to.
            When ``None``, a fresh parser is created.
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps argparse's standard behaviour of reading
            ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` returned by ``parser.parse_args``.

    Raises:
        SystemExit: Raised by argparse on invalid or missing arguments,
            including a value for ``--pst``/``--cps`` that is not a
            recognised boolean spelling.
    """
    # Imported unconditionally: ``str2bool`` below needs
    # ``argparse.ArgumentTypeError`` even when the caller supplies its own
    # parser (previously that path raised NameError on a bad boolean).
    import argparse

    if parser is None:
        parser = argparse.ArgumentParser()

    def str2bool(v):
        """Convert a textual boolean ('yes'/'no', 'true'/'false', '1'/'0', ...) to ``bool``."""
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Boolean value expected.')

    # Model configuration and checkpoint paths.
    parser.add_argument("--model-config", type=str, required=False, default=None)
    parser.add_argument("--model-weight", type=str, required=False, default=None)

    # Sampling options.
    parser.add_argument("--num-step", type=int, required=False, default=20)
    parser.add_argument("--cfg", type=float, required=False, default=2.5)
    parser.add_argument("--pst", type=str2bool, required=False, default=True)
    parser.add_argument("--cps", type=str2bool, required=False, default=False)
    parser.add_argument("--noise-offset", type=float, required=False, default=0.0357)
    parser.add_argument("--seed", type=int, default=0, required=False)

    # Input data.
    parser.add_argument("--real-dir", type=Path, required=True)
    parser.add_argument("--mask-dir", type=Path, required=False)
    parser.add_argument("--resolution", type=int, default=512, required=False)

    # Runtime / batching.
    parser.add_argument("--device", type=str, required=False, default="cuda")
    parser.add_argument("--batch-size", type=int, required=False, default=32)
    parser.add_argument("--num-workers", type=int, required=False, default=64)

    # Output.
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--visualize-latent", action="store_true", default=False)

    return parser.parse_args(argv)
|
|
def build_pipeline(args):
    """Build the batch-mode removal pipeline from parsed CLI arguments.

    Reads ``args.model_config``, ``args.model_weight`` and ``args.device``.
    The removal model and the VAE are constructed from the loaded config and
    moved to ``args.device``; the result of loading the removal weights is
    printed. A DDIM scheduler with a fixed "scaled_linear" beta schedule is
    attached, and the pipeline is created with ``dtype=torch.float``.

    Returns:
        The constructed ``RemovalSDXLPipeline_BatchMode`` instance.
    """
    # Heavy / project-local dependencies are imported lazily so that merely
    # importing this module does not require them.
    from diffusers import DDIMScheduler
    from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode as Removal_Pipeline
    from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
    from utils_train import build_vae

    cfg = load_cfg(args.model_config)
    target = args.device

    # NOTE(review): the meaning of the literal 20 passed to build_removal_model
    # is not visible from this file — confirm against its definition.
    net = build_removal_model(cfg, 20).to(target)
    load_report = load_removal_model(net, args.model_weight, target)
    print(load_report)

    autoencoder = build_vae(cfg).to(target)

    ddim_settings = dict(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )
    noise_scheduler = DDIMScheduler(**ddim_settings)

    return Removal_Pipeline(
        removal_model=net,
        vae=autoencoder,
        scheduler=noise_scheduler,
        device=target,
        dtype=torch.float,
    )
|
|
class SAVER:
    """Static helpers for writing images to a directory.

    Three strategies are offered: sequential (skips names already present in
    the directory), thread-pool, and process-pool. ``save_root`` may be given
    as a ``str`` or a ``pathlib.Path``.
    """

    @staticmethod
    def _check_lengths(images, names):
        """Assert that ``images`` and ``names`` pair up one-to-one."""
        assert len(images) == len(names), \
            f"images and names are not equal: {len(images)}!={len(names)}"

    @staticmethod
    def save_image(img, name, path):
        """Save ``img`` to ``path / name`` via ``img.save`` and return ``name``.

        ``path`` is wrapped in ``Path`` so that a plain ``str`` directory (as
        the ``save_images*`` annotations advertise) works with the ``/`` join.
        """
        img.save(Path(path) / name)
        return name

    @staticmethod
    def save_images(images: List[Image.Image], names: List[str], save_root: str):
        """Save images one by one with a progress bar.

        Any ``name`` already present in ``save_root`` (according to a single
        ``os.listdir`` taken before the loop) is skipped.

        Raises:
            AssertionError: If ``images`` and ``names`` differ in length.
        """
        SAVER._check_lengths(images, names)

        pbar_save = tqdm(zip(images, names), total=len(names))

        # A set gives O(1) membership tests instead of scanning a list per image.
        cache_names = set(os.listdir(save_root))
        for image, name in pbar_save:
            if name not in cache_names:
                SAVER.save_image(image, name, save_root)

    @staticmethod
    def save_images_mt(images: List[Image.Image], names: List[str], save_root: str, num_workers=8):
        """Save images concurrently using a thread pool of ``num_workers`` threads.

        Unlike ``save_images``, existing files are not skipped. Any exception
        raised while saving is re-raised by ``future.result()``.

        Raises:
            AssertionError: If ``images`` and ``names`` differ in length
                (previously the extra items were silently dropped by ``zip``).
        """
        from concurrent.futures import ThreadPoolExecutor

        SAVER._check_lengths(images, names)
        with ThreadPoolExecutor(max_workers=num_workers) as executor:
            futures = [
                executor.submit(SAVER.save_image, image, name, save_root)
                for image, name in zip(images, names)
            ]
            for future in tqdm(futures):
                future.result()

    @staticmethod
    def save_images_mp(images: List[Image.Image], names: List[str], save_root: str, num_workers=8):
        """Save images concurrently using a process pool of ``num_workers`` processes.

        Each image is pickled and sent to a worker process. Existing files are
        not skipped. Any exception raised while saving is re-raised by
        ``future.result()``.

        Raises:
            AssertionError: If ``images`` and ``names`` differ in length
                (previously the extra items were silently dropped by ``zip``).
        """
        from concurrent.futures import ProcessPoolExecutor

        SAVER._check_lengths(images, names)
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            futures = [
                executor.submit(SAVER.save_image, image, name, save_root)
                for image, name in zip(images, names)
            ]
            for future in tqdm(futures):
                future.result()
|
|
|
|
|
|
|
|