File size: 4,779 Bytes
bfa59ab efed4da bfa59ab efed4da bfa59ab 2868b95 efed4da 2868b95 bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab efed4da bfa59ab | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | #!/usr/bin/env python
# -*- coding:utf-8 -*-
import os, sys, math, random
import cv2
import numpy as np
from pathlib import Path
from loguru import logger
from omegaconf import OmegaConf
from utils import util_net
from utils import util_image
from utils import util_common
from utils import util_color_fix
import torch
import torch.nn.functional as F
from datapipe.datasets import create_dataset
from diffusers import StableDiffusionInvEnhancePipeline
# Text prompts handed to the diffusion pipeline by InvSamplerSR.sample_func.
# The positive prompt is always sent; the negative prompt is only sent when
# configs.cfg_scale > 1.0.
_positive = 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, meticulous detailing'
_negative = 'Low quality, blurring, jpeg artifacts, deformed, noisy'
def get_torch_dtype(torch_dtype: str) -> torch.dtype:
    """Return the dtype to run the pipeline in.

    The requested ``torch_dtype`` is deliberately ignored: this CPU-only
    variant always computes in full precision, so ``torch.float32`` is
    returned no matter what was asked for.
    """
    del torch_dtype  # intentionally unused; the signature is kept as-is
    return torch.float32
class BaseSampler:
    """Build a CPU-only StableDiffusionInvEnhancePipeline from a config.

    ``configs`` (presumably an OmegaConf config; it needs attribute access
    and ``.get``) must provide ``seed``, ``base_model``, ``sd_pipe.target``,
    ``sd_pipe.params`` and ``model_start.target``/``ckpt_path``, plus
    optionally ``model_start.params`` and ``scheduler.target``.
    """

    def __init__(self, configs):
        self.configs = configs
        # All computation is pinned to the CPU.
        self.device = torch.device("cpu")
        self.setup_seed()
        self.build_model()

    def setup_seed(self, seed=None):
        """Seed the Python, NumPy and PyTorch RNGs.

        Args:
            seed: seed value; falls back to ``self.configs.seed`` when None.
        """
        seed = self.configs.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def write_log(self, log_str):
        """Print ``log_str`` to stdout and flush immediately."""
        print(log_str, flush=True)

    def build_model(self):
        """Load the base pipeline and the start-noise predictor into ``self.sd_pipe``.

        Raises:
            ValueError: if ``configs.base_model`` is not supported, or if
                ``configs.model_start.ckpt_path`` is missing.
        """
        # Validate before loading anything: from_pretrained can be slow and
        # may download large weights, so an unsupported model should fail fast.
        base_model = self.configs.base_model
        if base_model not in ('sd-turbo', 'sd2base'):
            raise ValueError(f"Unsupported base model: {base_model}!")

        params = dict(self.configs.sd_pipe.params)
        params['torch_dtype'] = torch.float32  # full precision on CPU
        base_pipe = util_common.get_obj_from_str(
            self.configs.sd_pipe.target
        ).from_pretrained(**params)

        # Optionally replace the scheduler, reusing the original scheduler's config.
        if self.configs.get('scheduler', None) is not None:
            base_pipe.scheduler = util_common.get_obj_from_str(
                self.configs.scheduler.target
            ).from_config(base_pipe.scheduler.config)

        sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
        sd_pipe.to(self.device)

        model_configs = self.configs.model_start
        # The default used to be the `dict` type itself, which made `**params`
        # raise when the entry was absent; also tolerate an explicit null.
        params = model_configs.get('params', None) or {}
        model_start = util_common.get_obj_from_str(model_configs.target)(**params)
        model_start.to(self.device)

        ckpt_path = model_configs.get('ckpt_path', None)
        if ckpt_path is None:
            raise ValueError("configs.model_start.ckpt_path must be set")
        self.write_log(f"Loading model from {ckpt_path}...")
        # NOTE(security): torch.load unpickles arbitrary objects; only point
        # ckpt_path at trusted checkpoint files.
        state = torch.load(ckpt_path, map_location=self.device)
        # Some checkpoints wrap the weights under a 'state_dict' key.
        if 'state_dict' in state:
            state = state['state_dict']
        util_net.reload_model(model_start, state)
        model_start.eval()

        # Expose the noise predictor on the pipeline object as `start_noise_predictor`.
        sd_pipe.start_noise_predictor = model_start
        self.sd_pipe = sd_pipe
class InvSamplerSR(BaseSampler):
    """Super-resolution sampler driving the pipeline built by BaseSampler."""

    @torch.no_grad()
    def sample_func(self, im_cond):
        """Super-resolve a batch of images.

        Args:
            im_cond: image batch tensor; assumed (B, C, H, W) float, since the
                last two dims are scaled as height/width and the output is
                permuted from channel-first — TODO confirm value range.

        Returns:
            NumPy array, channels-last (B, H', W', C), clamped to [0, 1].
            H'/W' are whatever the pipeline returns for a target size of
            (H * configs.basesr.sf, W * configs.basesr.sf).
        """
        im_cond = im_cond.to(self.device)
        batch = im_cond.shape[0]
        # The negative prompt is only passed when cfg_scale exceeds 1.0.
        negative_prompt = [_negative] * batch if self.configs.cfg_scale > 1.0 else None

        # Inputs are fed at their native size; no padding or chopping is done.
        sf = self.configs.basesr.sf
        target_size = (im_cond.shape[-2] * sf, im_cond.shape[-1] * sf)

        res_sr = self.sd_pipe(
            image=im_cond.float(),  # float32 on CPU
            prompt=[_positive] * batch,
            negative_prompt=negative_prompt,
            target_size=target_size,
            timesteps=self.configs.timesteps,
            guidance_scale=self.configs.cfg_scale,
            output_type="pt",
        ).images
        return res_sr.clamp(0.0, 1.0).cpu().permute(0, 2, 3, 1).numpy()

    def inference(self, in_path, out_path, bs=1):
        """Super-resolve a single image file or every image in a directory.

        Args:
            in_path: path to an image file, or to a directory that is loaded
                through ``create_dataset`` (type 'base').
            out_path: output directory (created if missing); results are PNGs.
            bs: batch size used in directory mode.
        """
        in_path = Path(in_path)
        out_path = Path(out_path)
        out_path.mkdir(parents=True, exist_ok=True)

        if in_path.is_dir():
            dataset = create_dataset({
                'type': 'base',
                'params': {'dir_path': str(in_path)}
            })
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs)
            # Running count across batches: naming by the in-batch index alone
            # restarted at 0 each batch and overwrote earlier outputs.
            n_saved = 0
            for data in dataloader:
                res = self.sample_func(data['lq'])
                # Prefer the source file name when the dataset supplies it.
                paths = data['path'] if 'path' in data else None
                for jj in range(res.shape[0]):
                    stem = Path(paths[jj]).stem if paths is not None else str(n_saved)
                    save_path = str(out_path / f"{stem}.png")
                    util_image.imwrite(res[jj], save_path, dtype_in='float32')
                    n_saved += 1
        else:
            im_cond = util_image.imread(in_path, chn='rgb', dtype='float32')
            im_cond = util_image.img2tensor(im_cond).to(self.device)
            # Drop the leading batch axis of the single result.
            image = self.sample_func(im_cond).squeeze(0)
            save_path = str(out_path / f"{in_path.stem}.png")
            util_image.imwrite(image, save_path, dtype_in='float32')
        self.write_log(f"Done → {out_path}")
if __name__ == '__main__':
    # Running this file directly is intentionally a no-op; the samplers
    # above are only used when the module is imported.
    ...
|