#!/usr/bin/env python
# -*- coding:utf-8 -*-

import os
import sys
import math
import random
import cv2
import numpy as np
from pathlib import Path
from loguru import logger
from omegaconf import OmegaConf

from utils import util_net
from utils import util_image
from utils import util_common
from utils import util_color_fix

import torch
import torch.nn.functional as F

from datapipe.datasets import create_dataset
from diffusers import StableDiffusionInvEnhancePipeline

_positive= 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, meticulous detailing'
_negative= 'Low quality, blurring, jpeg artifacts, deformed, noisy'


def get_torch_dtype(torch_dtype: str):
    """Return the torch dtype used for inference.

    The requested ``torch_dtype`` is deliberately ignored: this build runs on
    CPU only, so everything is loaded in ``torch.float32``.

    Args:
        torch_dtype: Requested dtype name (ignored).

    Returns:
        ``torch.float32``.
    """
    return torch.float32


class BaseSampler:
    """Build a CPU Stable Diffusion enhancement pipeline from a config.

    ``configs`` is read as an OmegaConf-style config; the fields accessed here
    are ``seed``, ``sd_pipe.target``, ``sd_pipe.params``, ``base_model``,
    ``model_start`` (``target``, optional ``params``, ``ckpt_path``) and the
    optional ``scheduler.target``.
    """

    def __init__(self, configs):
        self.configs = configs
        # CPU-only build: every model and tensor is placed on this device.
        self.device = torch.device("cpu")

        self.setup_seed()
        self.build_model()

    def setup_seed(self, seed=None):
        """Seed Python's, NumPy's and torch's RNGs.

        Args:
            seed: Seed value; falls back to ``self.configs.seed`` when None.
        """
        seed = self.configs.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def write_log(self, log_str):
        """Print ``log_str`` to stdout and flush immediately."""
        print(log_str, flush=True)

    def build_model(self):
        """Instantiate the diffusion pipeline and the start-noise predictor.

        Sets ``self.sd_pipe`` to a ``StableDiffusionInvEnhancePipeline`` that
        carries an extra ``start_noise_predictor`` attribute holding the model
        restored from ``configs.model_start.ckpt_path``.

        Raises:
            ValueError: if ``configs.base_model`` is unsupported or
                ``configs.model_start.ckpt_path`` is not set.
        """
        params = dict(self.configs.sd_pipe.params)
        # CPU safe: always float32, whatever the config requests.
        params['torch_dtype'] = get_torch_dtype(params.get('torch_dtype'))
        base_pipe = util_common.get_obj_from_str(
            self.configs.sd_pipe.target
        ).from_pretrained(**params)

        if self.configs.get('scheduler', None) is not None:
            base_pipe.scheduler = util_common.get_obj_from_str(
                self.configs.scheduler.target
            ).from_config(base_pipe.scheduler.config)

        if self.configs.base_model in ['sd-turbo', 'sd2base']:
            sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
        else:
            raise ValueError(f"Unsupported base model: {self.configs.base_model}!")

        sd_pipe.to(self.device)

        model_configs = self.configs.model_start
        # Default to an empty mapping (not the ``dict`` type itself) so that
        # ``**params`` works when the config omits ``params``.
        params = model_configs.get('params', {})
        model_start = util_common.get_obj_from_str(model_configs.target)(**params)
        model_start.to(self.device)

        ckpt_path = model_configs.get('ckpt_path')
        if ckpt_path is None:
            raise ValueError("model_start.ckpt_path must be set in the config!")
        self.write_log(f"Loading model from {ckpt_path}...")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from trusted sources.
        state = torch.load(ckpt_path, map_location=self.device)
        # Checkpoints may wrap the weights under a 'state_dict' key.
        if isinstance(state, dict) and 'state_dict' in state:
            state = state['state_dict']
        util_net.reload_model(model_start, state)
        model_start.eval()

        setattr(sd_pipe, 'start_noise_predictor', model_start)
        self.sd_pipe = sd_pipe


class InvSamplerSR(BaseSampler):
    """Super-resolution sampler driving ``StableDiffusionInvEnhancePipeline``."""

    @torch.no_grad()
    def sample_func(self, im_cond):
        """Upscale a batch of images by ``configs.basesr.sf``.

        Args:
            im_cond: Low-quality input batch. Presumably a (B, C, H, W) float
                tensor in [0, 1] -- TODO confirm against the dataset and
                ``util_image.img2tensor``.

        Returns:
            Channels-last NumPy array (batch, height, width, channels) with
            values clamped to [0, 1]; the spatial size is whatever the
            pipeline produces for ``target_size`` = (H * sf, W * sf).
        """
        im_cond = im_cond.to(self.device)
        batch = im_cond.shape[0]
        # A negative prompt is only supplied when guidance is enabled.
        negative_prompt = [_negative] * batch if self.configs.cfg_scale > 1.0 else None

        sf = self.configs.basesr.sf
        target_size = (im_cond.shape[-2] * sf, im_cond.shape[-1] * sf)

        res_sr = self.sd_pipe(
            image=im_cond.float(),  # the CPU pipeline runs in float32
            prompt=[_positive] * batch,
            negative_prompt=negative_prompt,
            target_size=target_size,
            timesteps=self.configs.timesteps,
            guidance_scale=self.configs.cfg_scale,
            output_type="pt",
        ).images

        return res_sr.clamp(0.0, 1.0).cpu().permute(0, 2, 3, 1).numpy()

    def inference(self, in_path, out_path, bs=1):
        """Run super-resolution on one image or on every image in a directory.

        Args:
            in_path: Path to an image file or to a directory of images.
            out_path: Output directory; created if missing.
            bs: Batch size used when ``in_path`` is a directory.

        Raises:
            FileNotFoundError: if ``in_path`` does not exist.
        """
        in_path = Path(in_path)
        out_path = Path(out_path)
        if not in_path.exists():
            raise FileNotFoundError(f"Input path does not exist: {in_path}")
        out_path.mkdir(parents=True, exist_ok=True)

        if in_path.is_dir():
            dataset = create_dataset({
                'type': 'base',
                'params': {'dir_path': str(in_path)},
            })
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs)
            # Use a running index across batches: naming by position within
            # the batch overwrote earlier results once the directory held
            # more than ``bs`` images.
            index = 0
            for data in dataloader:
                res = self.sample_func(data['lq'])
                for image in res:
                    save_path = str(out_path / f"{index}.png")
                    util_image.imwrite(image, save_path, dtype_in='float32')
                    index += 1
        else:
            im_cond = util_image.imread(in_path, chn='rgb', dtype='float32')
            im_cond = util_image.img2tensor(im_cond).to(self.device)
            image = self.sample_func(im_cond).squeeze(0)
            save_path = str(out_path / f"{in_path.stem}.png")
            util_image.imwrite(image, save_path, dtype_in='float32')

        self.write_log(f"Done → {out_path}")


if __name__ == '__main__':
    pass