moebius / infer /utils.py
Mike0021's picture
Implement Moebius Gradio Space
166ab04 verified
Raw
History Blame Contribute Delete
4.4 kB
from functools import partial
import os
from typing import List
from pathlib import Path
import math
from tqdm import tqdm
import numpy as np
import torch
from PIL import Image
def get_batch_infer_args(parser=None):
if parser is None:
import argparse
parser = argparse.ArgumentParser()
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
# model argument
parser.add_argument("--model-config", type=str, required=False, default=None)
parser.add_argument("--model-weight", type=str, required=False, default=None)
# sampling argument
parser.add_argument("--num-step", type=int, required=False, default=20)
parser.add_argument("--cfg", type=float, required=False, default=2.5)
parser.add_argument("--pst", type=str2bool, required=False, default=True)
parser.add_argument("--cps", type=str2bool, required=False, default=False)
parser.add_argument("--noise-offset", type=float, required=False, default=0.0357)
parser.add_argument("--seed", type=int, default=0, required=False)
# data argument
parser.add_argument("--real-dir", type=Path, required=True)
parser.add_argument("--mask-dir", type=Path, required=False)
parser.add_argument("--resolution", type=int, default=512, required=False)
# runtime argument
parser.add_argument("--device", type=str, required=False, default="cuda")
parser.add_argument("--batch-size", type=int, required=False, default=32)
parser.add_argument("--num-workers", type=int, required=False, default=64)
# save argument
parser.add_argument("--save-dir", type=str, required=True)
parser.add_argument("--visualize-latent", action="store_true", default=False)
return parser.parse_args()
def build_pipeline(args):
from diffusers import DDIMScheduler
from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode as Removal_Pipeline
from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
from utils_train import build_vae
model_cfg = load_cfg(args.model_config)
removal_model = build_removal_model(model_cfg, 20).to(args.device)
print(load_removal_model(removal_model, args.model_weight,args.device))
vae = build_vae(model_cfg).to(args.device)
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
num_train_timesteps=1000, clip_sample=False)
pipe = Removal_Pipeline(
removal_model=removal_model,
vae=vae,
scheduler=scheduler,
device=args.device,
dtype=torch.float)
return pipe
class SAVER:
@staticmethod
def save_image(img, name, path):
img.save(path / name)
return name
@staticmethod
def save_images(images:List[Image.Image], names:List[str], save_root:str):
assert len(images) == len(names), \
f"images and names are not equal: {len(images)}!={len(names)}"
pbar_save = tqdm(zip(images, names), total=len(names))
cache_names = os.listdir(save_root)
for image, name in pbar_save:
if name not in cache_names:
SAVER.save_image(image, name, save_root)
@staticmethod
def save_images_mt(images:List[Image.Image], names:List[str], save_root:str, num_workers=8):
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [
executor.submit(SAVER.save_image, image, name, save_root) for image, name in zip(images, names)]
for future in tqdm(futures):
future.result()
@staticmethod
def save_images_mp(images:List[Image.Image], names:List[str], save_root:str, num_workers=8):
from concurrent.futures import ProcessPoolExecutor
with ProcessPoolExecutor(max_workers=num_workers) as executor:
futures = [
executor.submit(SAVER.save_image, image, name, save_root) for image, name in zip(images, names)]
for future in tqdm(futures):
future.result()