| |
| |
|
|
| import os, sys, math, random |
|
|
| import cv2 |
| import numpy as np |
| from pathlib import Path |
| from loguru import logger |
| from omegaconf import OmegaConf |
|
|
| from utils import util_net |
| from utils import util_image |
| from utils import util_common |
| from utils import util_color_fix |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from datapipe.datasets import create_dataset |
| from diffusers import StableDiffusionInvEnhancePipeline |
|
|
# Default text prompts handed to the diffusion pipeline: `_positive` is used as
# the prompt and `_negative` as the negative prompt (one copy per batch item).
_positive = 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, meticulous detailing'
_negative = 'Low quality, blurring, jpeg artifacts, deformed, noisy'
|
|
def get_torch_dtype(torch_dtype: str):
    """Return the torch dtype this module works in.

    Args:
        torch_dtype: Name of the requested dtype. It is accepted for interface
            compatibility but is not consulted.

    Returns:
        Always ``torch.float32``, whatever ``torch_dtype`` says.
    """
    fixed_dtype = torch.float32
    return fixed_dtype
|
|
|
|
class BaseSampler:
    """Seed the RNGs and assemble the diffusion pipeline described by a config.

    Construction seeds Python's ``random``, NumPy and PyTorch, then builds the
    pipeline and the start-noise predictor and stores the result in
    ``self.sd_pipe``. Everything is placed on the CPU and loaded as float32.
    """

    def __init__(self, configs):
        """Store ``configs``, seed the RNGs and build the pipeline.

        Args:
            configs: Configuration object. This class reads ``seed``,
                ``sd_pipe``, ``base_model``, ``model_start`` and, optionally,
                ``scheduler`` from it, using attribute access and ``.get``.
        """
        self.configs = configs

        # Inference is pinned to the CPU.
        self.device = torch.device("cpu")

        self.setup_seed()
        self.build_model()

    def setup_seed(self, seed=None):
        """Seed ``random``, NumPy and PyTorch.

        Args:
            seed: Seed to use; when ``None`` the value of ``configs.seed`` is
                used instead.
        """
        seed = self.configs.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def write_log(self, log_str):
        """Print ``log_str`` to stdout, flushing immediately."""
        print(log_str, flush=True)

    def build_model(self):
        """Build the pipeline and the start-noise predictor into ``self.sd_pipe``.

        Raises:
            ValueError: If ``configs.base_model`` is not ``'sd-turbo'`` or
                ``'sd2base'``, or if ``configs.model_start`` has no
                ``ckpt_path``.
        """
        # Base pipeline: copy the configured kwargs and force float32 weights.
        pipe_params = dict(self.configs.sd_pipe.params)
        pipe_params['torch_dtype'] = torch.float32

        base_pipe = util_common.get_obj_from_str(
            self.configs.sd_pipe.target
        ).from_pretrained(**pipe_params)

        # Optionally swap the scheduler class, reusing the existing scheduler
        # configuration.
        if self.configs.get('scheduler', None) is not None:
            base_pipe.scheduler = util_common.get_obj_from_str(
                self.configs.scheduler.target
            ).from_config(base_pipe.scheduler.config)

        if self.configs.base_model in ['sd-turbo', 'sd2base']:
            sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
        else:
            raise ValueError(f"Unsupported base model: {self.configs.base_model}!")

        sd_pipe.to(self.device)

        # Start-noise predictor.
        model_configs = self.configs.model_start

        # Fall back to an empty mapping when no params are configured. The
        # previous default was the `dict` type itself, which cannot be
        # unpacked with `**`.
        start_params = model_configs.get('params', None)
        if start_params is None:
            start_params = {}

        model_start = util_common.get_obj_from_str(model_configs.target)(**start_params)
        model_start.to(self.device)

        ckpt_path = model_configs.get('ckpt_path')
        if ckpt_path is None:
            # Fail early with a clear message instead of inside torch.load.
            raise ValueError("configs.model_start.ckpt_path must be set to a checkpoint file!")
        self.write_log(f"Loading model from {ckpt_path}...")

        # NOTE: torch.load unpickles the file; only load checkpoints from
        # sources you trust.
        state = torch.load(ckpt_path, map_location=self.device)

        # Some checkpoints wrap the weights under a 'state_dict' key.
        if 'state_dict' in state:
            state = state['state_dict']

        util_net.reload_model(model_start, state)

        model_start.eval()
        sd_pipe.start_noise_predictor = model_start

        self.sd_pipe = sd_pipe
|
|
|
|
class InvSamplerSR(BaseSampler):
    """Super-resolution sampler that runs images through ``self.sd_pipe``."""

    @torch.no_grad()
    def sample_func(self, im_cond):
        """Super-resolve one batch of images.

        Args:
            im_cond: Tensor batch. Dimension 0 is used as the batch size and
                the last two dimensions as the spatial size (presumably
                N x C x H x W; confirm against the caller).

        Returns:
            NumPy array holding the pipeline output clamped to [0, 1], with
            its axes permuted by ``(0, 2, 3, 1)``.
        """
        im_cond = im_cond.to(self.device)
        batch_size = im_cond.shape[0]

        # A negative prompt is only supplied when the guidance scale exceeds 1.
        if self.configs.cfg_scale > 1.0:
            negative_prompt = [_negative] * batch_size
        else:
            negative_prompt = None

        # Output spatial size: input size times the configured scale factor.
        scale = self.configs.basesr.sf
        target_size = (im_cond.shape[-2] * scale, im_cond.shape[-1] * scale)

        res_sr = self.sd_pipe(
            image=im_cond.float(),
            prompt=[_positive] * batch_size,
            negative_prompt=negative_prompt,
            target_size=target_size,
            timesteps=self.configs.timesteps,
            guidance_scale=self.configs.cfg_scale,
            output_type="pt",
        ).images

        res_sr = res_sr.clamp(0.0, 1.0).cpu().permute(0, 2, 3, 1).numpy()

        return res_sr

    def inference(self, in_path, out_path, bs=1):
        """Super-resolve one image file or every image in a directory.

        Args:
            in_path: Input image file, or a directory handed to the 'base'
                dataset.
            out_path: Output directory; created if it does not exist.
            bs: Batch size used when ``in_path`` is a directory.

        For a directory, results are written as ``0.png``, ``1.png``, ... in
        the order the data loader yields them. For a single file, the result
        is written as ``<input stem>.png``.
        """
        in_path = Path(in_path)
        out_path = Path(out_path)
        out_path.mkdir(parents=True, exist_ok=True)

        if in_path.is_dir():
            dataset = create_dataset({
                'type': 'base',
                'params': {'dir_path': str(in_path)}
            })

            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs)

            # Running index across all batches, so later batches do not
            # overwrite the files written by earlier ones.
            num_saved = 0
            for data in dataloader:
                res = self.sample_func(data['lq'])

                for im_sr in res:
                    save_path = str(out_path / f"{num_saved}.png")
                    util_image.imwrite(im_sr, save_path, dtype_in='float32')
                    num_saved += 1

        else:
            im_cond = util_image.imread(in_path, chn='rgb', dtype='float32')
            im_cond = util_image.img2tensor(im_cond).to(self.device)

            # Drop the batch axis of the single result.
            image = self.sample_func(im_cond).squeeze(0)

            save_path = str(out_path / f"{in_path.stem}.png")
            util_image.imwrite(image, save_path, dtype_in='float32')

        self.write_log(f"Done → {out_path}")
|
|
if __name__ == '__main__':
    # Nothing happens when this file is run as a script.
    pass
|
|
|
|