NGain committed on
Commit
cef4cd8
·
verified ·
1 Parent(s): ce77205

Delete test_seesr_sam.py

Browse files
Files changed (1) hide show
  1. test_seesr_sam.py +0 -328
test_seesr_sam.py DELETED
@@ -1,328 +0,0 @@
1
- '''
2
- * SeeSR: Towards Semantics-Aware Real-World Image Super-Resolution
3
- * Modified from diffusers by Rongyuan Wu
4
- * 24/12/2023
5
- '''
6
- import os
7
- import sys
8
- sys.path.append(os.getcwd())
9
- import cv2
10
- import glob
11
- import argparse
12
- import numpy as np
13
- from PIL import Image
14
-
15
- import torch
16
- import torch.utils.checkpoint
17
-
18
- from accelerate import Accelerator
19
- from accelerate.logging import get_logger
20
- from accelerate.utils import set_seed
21
- from diffusers import AutoencoderKL, DDPMScheduler
22
- from diffusers.utils import check_min_version
23
- from diffusers.utils.import_utils import is_xformers_available
24
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
25
-
26
- from pipelines.pipeline_seesr import StableDiffusionControlNetPipeline
27
- from utils.misc import load_dreambooth_lora
28
- from utils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
29
-
30
- from ram.models.ram_lora import ram
31
- from ram import inference_ram as inference
32
- from ram import get_transform
33
-
34
- from typing import Mapping, Any
35
- from torchvision import transforms
36
- import torch.nn as nn
37
- import torch.nn.functional as F
38
- from torchvision import transforms
39
-
40
- sys.path.insert(0, '/media/ssd8T/wyw/code/SeeSR/sam2')
41
- from sam2.build_sam import build_sam2
42
- from sam2.sam2_image_predictor import SAM2ImagePredictor
43
-
44
logger = get_logger(__name__, log_level="INFO")


# Convert a PIL image to a float tensor in [0, 1], shape (C, H, W).
tensor_transforms = transforms.Compose([
    transforms.ToTensor(),
])

# Preprocessing for the RAM tagging model: fixed 384x384 input,
# normalized with ImageNet mean/std.
ram_transforms = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Preprocessing for SAM2: fixed 1024x1024 input, same ImageNet statistics.
sam_mean = [0.485, 0.456, 0.406]
sam_std = [0.229, 0.224, 0.225]
sam_transforms = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.Normalize(mean=sam_mean, std=sam_std)
])
62
-
63
-
64
def load_state_dict_diffbirSwinIR(model: nn.Module, state_dict: Mapping[str, Any], strict: bool = False) -> None:
    """Load ``state_dict`` into ``model``, reconciling a ``module.`` key prefix.

    Checkpoints saved from a ``DataParallel``/``DistributedDataParallel``
    wrapper prefix every parameter key with ``module.``, while a plain model
    does not. This helper adds or strips that prefix as needed so either
    checkpoint layout loads into either model layout.

    Args:
        model: Target module to load the weights into.
        state_dict: Checkpoint mapping; may optionally nest the actual
            weights under a top-level ``"state_dict"`` key.
        strict: Forwarded to ``Module.load_state_dict``.
    """
    # Some checkpoints wrap the weights in a "state_dict" entry.
    state_dict = state_dict.get("state_dict", state_dict)

    # next(iter(...), "") tolerates empty mappings, where the original
    # list(...)[0] indexing would raise IndexError.
    model_prefixed = next(iter(model.state_dict()), "").startswith("module.")
    ckpt_prefixed = next(iter(state_dict), "").startswith("module.")

    if model_prefixed and not ckpt_prefixed:
        state_dict = {f"module.{key}": value for key, value in state_dict.items()}
    elif ckpt_prefixed and not model_prefixed:
        state_dict = {key[len("module."):]: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict, strict=strict)
82
-
83
-
84
def load_seesr_pipeline(args, accelerator, enable_xformers_memory_efficient_attention):
    """Build the SeeSR ControlNet inference pipeline.

    Loads the frozen Stable Diffusion components from
    ``args.pretrained_model_path`` and the SeeSR-specific UNet/ControlNet from
    ``args.seesr_model_path``, enables tiled VAE processing, and moves every
    component to the accelerator device in the precision selected by
    ``accelerator.mixed_precision``.

    Args:
        args: Parsed CLI namespace (model paths and VAE tile sizes).
        accelerator: ``accelerate.Accelerator`` providing device/precision.
        enable_xformers_memory_efficient_attention: Enable xformers attention
            on the UNet and ControlNet.

    Returns:
        A ready-to-use ``StableDiffusionControlNetPipeline``.

    Raises:
        ValueError: If xformers attention is requested but not installed.
    """
    from models.controlnet import ControlNetModel
    from models.unet_2d_condition import UNet2DConditionModel

    # Load scheduler, tokenizer and models.
    scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
    feature_extractor = CLIPImageProcessor.from_pretrained(f"{args.pretrained_model_path}/feature_extractor")
    unet = UNet2DConditionModel.from_pretrained(args.seesr_model_path, subfolder="unet")
    controlnet = ControlNetModel.from_pretrained(args.seesr_model_path, subfolder="controlnet")

    # Inference only: freeze every component.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    controlnet.requires_grad_(False)

    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
            controlnet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Assemble the validation pipeline (no safety checker for SR outputs).
    validation_pipeline = StableDiffusionControlNetPipeline(
        vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor,
        unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
    )

    # Tiled VAE keeps peak GPU memory bounded on large images.
    validation_pipeline._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)

    # For mixed precision inference we cast the frozen weights to half
    # precision; full precision is unnecessary since nothing is trained here.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move all components to the accelerator device in the chosen dtype.
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)
    controlnet.to(accelerator.device, dtype=weight_dtype)

    return validation_pipeline
153
-
154
def load_tag_model(args, device='cuda',
                   ram_pretrained_path='/media/ssd8T/ly/SeeSR/preset/models/ram_swin_large_14m.pth'):
    """Load the RAM tagging model (~0.47B parameters) for prompt extraction.

    Args:
        args: CLI namespace; ``args.ram_ft_path`` points to the optional
            fine-tuned condition weights.
        device: Device to place the model on.
        ram_pretrained_path: Path to the base RAM checkpoint. Kept as a
            keyword argument with the historical default for backward
            compatibility; pass it explicitly to avoid the machine-specific
            absolute path baked into the default.

    Returns:
        The RAM model in eval mode on ``device``.
    """
    model = ram(pretrained=ram_pretrained_path,
                pretrained_condition=args.ram_ft_path,
                image_size=384,
                vit='swin_l')
    model.eval()
    model.to(device)
    return model
168
-
169
def get_validation_prompt(args, image, model, device='cuda'):
    """Generate a tag-based prompt and RAM image embedding for ``image``.

    Args:
        args: CLI namespace; ``args.prompt`` is appended after the RAM tags.
        image: PIL image (RGB) of the low-quality input.
        model: RAM tagging model from :func:`load_tag_model`.
        device: Device on which to run the tagging model.

    Returns:
        Tuple of ``(prompt_string, ram_encoder_hidden_states)``.
    """
    lq = tensor_transforms(image).unsqueeze(0).to(device)
    lq = ram_transforms(lq)
    res = inference(lq, model)
    ram_encoder_hidden_states = model.generate_image_embeds(lq)

    # res[0] is the comma-separated tag string predicted by RAM; the user
    # prompt (possibly empty) is appended after it.
    validation_prompt = f"{res[0]}, {args.prompt},"

    return validation_prompt, ram_encoder_hidden_states
180
-
181
def load_sam_model(device='cuda'):
    """Load the SAM2 hiera-tiny model (~0.03B params) as an image predictor.

    Args:
        device: Device to build and place the SAM2 model on.

    Returns:
        A ``SAM2ImagePredictor`` wrapping the loaded model.
    """
    sam2_checkpoint = "./preset/models/sam2.1_hiera_tiny.pt"
    model_cfg = "configs/sam2.1/sam2.1_hiera_t.yaml"
    # Fix: honor the `device` parameter (was hard-coded to 'cuda'), so
    # non-default devices actually take effect.
    sam2 = build_sam2(model_cfg, sam2_checkpoint, mode="eval", device=device, apply_postprocessing=False)
    sam2.to(device)
    SAM = SAM2ImagePredictor(sam2)

    return SAM
190
-
191
def get_sam_embedding(image, model, device='cuda'):
    """Compute the SAM2 image embedding used to condition the pipeline."""
    batch = tensor_transforms(image).unsqueeze(0).to(device)
    batch = sam_transforms(batch)
    return model.generate_image_embedding(batch)
198
-
199
def main(args, enable_xformers_memory_efficient_attention=True,):
    """Run SeeSR super-resolution over the input image(s).

    Loads the SeeSR pipeline, the RAM tagging model and SAM2, then for each
    input image: derives a tag prompt, computes the SAM2 embedding used as
    conditioning, runs the diffusion pipeline ``args.sample_times`` times,
    and writes the (optionally color-corrected) outputs under
    ``args.output_dir/sampleNN/``.
    """
    txt_path = os.path.join(args.output_dir, 'txt')
    os.makedirs(txt_path, exist_ok=True)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
    )

    # If passed along, set the seed now for reproducible sampling.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the output folder creation
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("SeeSR")

    pipeline = load_seesr_pipeline(args, accelerator, enable_xformers_memory_efficient_attention)
    model = load_tag_model(args, accelerator.device)
    sam_model = load_sam_model(accelerator.device)

    if accelerator.is_main_process:
        generator = torch.Generator(device=accelerator.device)
        if args.seed is not None:
            generator.manual_seed(args.seed)

        if os.path.isdir(args.image_path):
            image_names = sorted(glob.glob(f'{args.image_path}/*.*'))
        else:
            image_names = [args.image_path]

        for image_idx, image_name in enumerate(image_names):
            print(f'================== process {image_idx} imgs... ===================')
            validation_image = Image.open(image_name).convert("RGB")

            # Fix: run the helper models on the accelerator device instead of
            # relying on the implicit 'cuda' default, matching where they
            # were loaded above.
            validation_prompt, _ = get_validation_prompt(args, validation_image, model, device=accelerator.device)
            validation_prompt += args.added_prompt  # clean, extremely detailed, best quality, sharp, clean
            negative_prompt = args.negative_prompt  # dirty, messy, low quality, frames, deformed,

            sam_encoder_hidden_states = get_sam_embedding(validation_image, sam_model, device=accelerator.device)

            if args.save_prompts:
                txt_save_path = f"{txt_path}/{os.path.basename(image_name).split('.')[0]}.txt"
                # Fix: context manager closes the file even if write() raises.
                with open(txt_save_path, "w") as file:
                    file.write(validation_prompt)
            print(f'{validation_prompt}')

            # Pre-upsample inputs whose short side is below
            # process_size // upscale so the pipeline input is large enough.
            ori_width, ori_height = validation_image.size
            resize_flag = False
            rscale = args.upscale
            if ori_width < args.process_size // rscale or ori_height < args.process_size // rscale:
                scale = (args.process_size // rscale) / min(ori_width, ori_height)
                validation_image = validation_image.resize((int(scale * ori_width), int(scale * ori_height)))
                resize_flag = True

            # Upscale by the requested factor, then snap both sides down to a
            # multiple of 8 (required by the latent-space VAE).
            validation_image = validation_image.resize((validation_image.size[0] * rscale, validation_image.size[1] * rscale))
            validation_image = validation_image.resize((validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8))
            width, height = validation_image.size
            # NOTE(review): set unconditionally in the original, so the final
            # resize below always runs; preserved as-is.
            resize_flag = True

            print(f'input size: {height}x{width}')

            for sample_idx in range(args.sample_times):
                sample_dir = f'{args.output_dir}/sample{str(sample_idx).zfill(2)}/'
                os.makedirs(sample_dir, exist_ok=True)

                with torch.autocast("cuda"):
                    image = pipeline(
                        validation_prompt, validation_image, num_inference_steps=args.num_inference_steps, generator=generator, height=height, width=width,
                        guidance_scale=args.guidance_scale, negative_prompt=negative_prompt, conditioning_scale=args.conditioning_scale,
                        start_point=args.start_point, ram_encoder_hidden_states=sam_encoder_hidden_states,
                        latent_tiled_size=args.latent_tiled_size, latent_tiled_overlap=args.latent_tiled_overlap,
                        args=args,
                    ).images[0]

                # Optional color correction against the (resized) input;
                # 'nofix' keeps the raw output unchanged.
                if args.align_method == 'wavelet':
                    image = wavelet_color_fix(image, validation_image)
                elif args.align_method == 'adain':
                    image = adain_color_fix(image, validation_image)

                if resize_flag:
                    image = image.resize((ori_width * rscale, ori_height * rscale))

                name, ext = os.path.splitext(os.path.basename(image_name))
                image.save(f'{sample_dir}{name}.png')
296
-
297
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Model locations.
    add("--seesr_model_path", type=str, default=None)
    add("--ram_ft_path", type=str, default=None)
    add("--pretrained_model_path", type=str, default=None)

    # Prompting; users may add a self-prompt to improve the results.
    add("--prompt", type=str, default="")
    add("--added_prompt", type=str, default="clean, high-resolution, 8k")
    add("--negative_prompt", type=str, default="dotted, noise, blur, lowres, smooth")

    # I/O.
    add("--image_path", type=str, default=None)
    add("--output_dir", type=str, default=None)
    add("--mixed_precision", type=str, default="fp16")  # no / fp16 / bf16

    # Sampling configuration.
    add("--guidance_scale", type=float, default=5.5)
    add("--conditioning_scale", type=float, default=1.0)
    add("--blending_alpha", type=float, default=1.0)
    add("--num_inference_steps", type=int, default=50)
    add("--process_size", type=int, default=512)

    # Tiling sizes (memory/footprint trade-offs).
    add("--vae_decoder_tiled_size", type=int, default=224)  # latent size, for 24G
    add("--vae_encoder_tiled_size", type=int, default=1024)  # image size, for 13G
    add("--latent_tiled_size", type=int, default=96)
    add("--latent_tiled_overlap", type=int, default=32)

    add("--upscale", type=int, default=4)
    add("--seed", type=int, default=None)
    add("--sample_times", type=int, default=1)
    add("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain')
    add("--start_steps", type=int, default=999)  # defaults set to 999.
    # LR Embedding Strategy: 'lr' uses the LR latent + 999-step noise as the
    # diffusion start point.
    add("--start_point", type=str, choices=['lr', 'noise'], default='lr')
    add("--save_prompts", action='store_true')

    args = parser.parse_args()
    main(args)
326
-
327
-
328
-