#!/usr/bin/env python
# -*- coding:utf-8 -*-

import os, sys, math, random

import cv2
import numpy as np
from pathlib import Path
from loguru import logger
from omegaconf import OmegaConf

from utils import util_net
from utils import util_image
from utils import util_common
from utils import util_color_fix

import torch
import torch.nn.functional as F

from datapipe.datasets import create_dataset
from diffusers import StableDiffusionInvEnhancePipeline

# Text prompts handed to the diffusion pipeline: `_positive` is used as the
# prompt, `_negative` as the negative prompt.
_positive = (
    'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, '
    'meticulous detailing'
)
_negative = 'Low quality, blurring, jpeg artifacts, deformed, noisy'

def get_torch_dtype(torch_dtype: str) -> torch.dtype:
    """Return the dtype to compute with, which is always ``torch.float32``.

    Args:
        torch_dtype: Requested dtype name. It is accepted so existing callers
            keep working, but it is ignored: this CPU variant forces float32.

    Returns:
        ``torch.float32``.
    """
    del torch_dtype  # intentionally unused, see the docstring
    return torch.float32


class BaseSampler:
    """Base sampler that assembles the diffusion pipeline on the CPU.

    Creating an instance seeds the random number generators and builds the
    pipeline; the finished pipeline is stored on ``self.sd_pipe``.
    """

    def __init__(self, configs):
        """Store the configuration, seed the RNGs and build the pipeline.

        Args:
            configs: Configuration object. This class reads ``seed``,
                ``sd_pipe``, ``base_model`` and ``model_start`` from it and
                looks up the optional ``scheduler`` entry via ``configs.get``.
        """
        self.configs = configs

        # This variant always runs on the CPU.
        self.device = torch.device("cpu")

        self.setup_seed()
        self.build_model()

    def setup_seed(self, seed=None):
        """Seed Python's ``random``, NumPy and PyTorch.

        Args:
            seed: Seed value; ``None`` falls back to ``self.configs.seed``.
        """
        seed = self.configs.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def write_log(self, log_str):
        """Print ``log_str`` to stdout and flush immediately."""
        print(log_str, flush=True)

    @staticmethod
    def _start_model_params(model_configs):
        """Return the constructor keyword arguments of the start model.

        Args:
            model_configs: Mapping-like config offering ``get``.

        Returns:
            A plain ``dict`` copy of the ``params`` entry, or an empty dict
            when the entry is missing or ``None``.
        """
        params = model_configs.get('params', None)
        return {} if params is None else dict(params)

    def build_model(self):
        """Build the enhancement pipeline and attach the start-noise model.

        Steps, in order:
          1. Instantiate the base pipeline named by ``configs.sd_pipe.target``
             through ``from_pretrained`` with ``configs.sd_pipe.params``
             (``torch_dtype`` is overridden to float32).
          2. Replace the scheduler when ``configs.scheduler`` is set.
          3. Wrap the base pipeline in ``StableDiffusionInvEnhancePipeline``.
          4. Instantiate the model described by ``configs.model_start``,
             load its checkpoint and attach it to the pipeline as
             ``start_noise_predictor``.

        Raises:
            ValueError: If ``configs.base_model`` is not ``'sd-turbo'`` or
                ``'sd2base'``, or if ``configs.model_start.ckpt_path`` is
                not set.
        """
        pipe_params = dict(self.configs.sd_pipe.params)
        pipe_params['torch_dtype'] = torch.float32  # float32 for CPU execution

        base_pipe = util_common.get_obj_from_str(
            self.configs.sd_pipe.target
        ).from_pretrained(**pipe_params)

        if self.configs.get('scheduler', None) is not None:
            # Re-create the configured scheduler class from the config of the
            # scheduler that came with the base pipeline.
            base_pipe.scheduler = util_common.get_obj_from_str(
                self.configs.scheduler.target
            ).from_config(base_pipe.scheduler.config)

        if self.configs.base_model in ['sd-turbo', 'sd2base']:
            sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
        else:
            raise ValueError(f"Unsupported base model: {self.configs.base_model}!")

        sd_pipe.to(self.device)

        model_configs = self.configs.model_start
        # The fallback must be a mapping instance: the previous default was
        # the ``dict`` type itself, which cannot be unpacked with ``**``.
        start_params = self._start_model_params(model_configs)

        model_start = util_common.get_obj_from_str(model_configs.target)(**start_params)
        model_start.to(self.device)

        ckpt_path = model_configs.get('ckpt_path')
        if ckpt_path is None:
            # Fail with a clear message instead of an obscure error from
            # inside torch.load(None).
            raise ValueError("No checkpoint path given: configs.model_start.ckpt_path is not set!")
        self.write_log(f"Loading model from {ckpt_path}...")

        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from trusted sources should be loaded here.
        state = torch.load(ckpt_path, map_location=self.device)

        # Some checkpoints keep the weights under a 'state_dict' key.
        if 'state_dict' in state:
            state = state['state_dict']

        util_net.reload_model(model_start, state)

        model_start.eval()
        sd_pipe.start_noise_predictor = model_start

        self.sd_pipe = sd_pipe


class InvSamplerSR(BaseSampler):
    """Super-resolution sampler running the pipeline built by ``BaseSampler``."""

    @torch.no_grad()
    def sample_func(self, im_cond):
        """Run the pipeline on a batch of input images.

        Args:
            im_cond: Input tensor. Its first dimension is used as the batch
                size and its last two dimensions as height and width.

        Returns:
            NumPy array with values clamped to [0, 1]; the pipeline output is
            permuted with ``(0, 2, 3, 1)`` before conversion.
        """
        im_cond = im_cond.to(self.device)
        batch_size = im_cond.shape[0]

        # A negative prompt is only passed when guidance is enabled.
        if self.configs.cfg_scale > 1.0:
            negative_prompt = [_negative] * batch_size
        else:
            negative_prompt = None

        # Output resolution: input height/width times the scale factor.
        sf = self.configs.basesr.sf
        target_size = (im_cond.shape[-2] * sf, im_cond.shape[-1] * sf)

        res_sr = self.sd_pipe(
            image=im_cond.float(),  # float32 for CPU execution
            prompt=[_positive] * batch_size,
            negative_prompt=negative_prompt,
            target_size=target_size,
            timesteps=self.configs.timesteps,
            guidance_scale=self.configs.cfg_scale,
            output_type="pt",
        ).images

        return res_sr.clamp(0.0, 1.0).cpu().permute(0, 2, 3, 1).numpy()

    def inference(self, in_path, out_path, bs=1):
        """Enhance one image or every image of a directory and save as PNG.

        Args:
            in_path: Path of an image file or of a directory.
            out_path: Output directory; created if it does not exist.
            bs: Batch size used when ``in_path`` is a directory.

        In directory mode results are written as ``0.png``, ``1.png``, ... in
        the order the data loader yields them. In single-file mode the result
        is written as ``<input stem>.png``.
        """
        in_path = Path(in_path)
        out_path = Path(out_path)
        out_path.mkdir(parents=True, exist_ok=True)

        if in_path.is_dir():
            dataset = create_dataset({
                'type': 'base',
                'params': {'dir_path': str(in_path)}
            })

            dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs)

            # The index runs over the whole dataset, not over a single batch;
            # an in-batch index would make each batch overwrite the files of
            # the previous one.
            num_saved = 0
            for data in dataloader:
                res = self.sample_func(data['lq'])

                for im_res in res:
                    save_path = str(out_path / f"{num_saved}.png")
                    util_image.imwrite(im_res, save_path, dtype_in='float32')
                    num_saved += 1

        else:
            im_cond = util_image.imread(in_path, chn='rgb', dtype='float32')
            im_cond = util_image.img2tensor(im_cond).to(self.device)

            # Drop the leading batch dimension of the single result.
            image = self.sample_func(im_cond).squeeze(0)

            save_path = str(out_path / f"{in_path.stem}.png")
            util_image.imwrite(image, save_path, dtype_in='float32')

        self.write_log(f"Done → {out_path}")

if __name__ == '__main__':
    # No command-line entry point is defined: running this file directly
    # does nothing.
    pass