DuySota commited on
Commit
defb4d0
·
verified ·
1 Parent(s): 793728e

Upload folder using huggingface_hub

Browse files
RetinexFormer_FiveK.yml ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # general settings
2
+ name: Enhancement_RetinexFormer_FiveK
3
+ model_type: ImageCleanModel
4
+ scale: 1
5
+ num_gpu: 1 # set num_gpu: 0 for cpu mode
6
+ manual_seed: 100
7
+
8
+ # dataset and data loader settings
9
+ datasets:
10
+ train:
11
+ name: TrainSet
12
+ type: Dataset_PairedImage
13
+ dataroot_gt: data/FiveK/train/target
+ # NOTE(review): dataroot_lq points at the same *target* folder as
+ # dataroot_gt, so every training pair would be identical — presumably
+ # this should be data/FiveK/train/input; verify before training.
+ dataroot_lq: data/FiveK/train/target
15
+ geometric_augs: true
16
+ cache_data: true
17
+ filename_tmpl: '{}'
18
+ io_backend:
19
+ type: disk
20
+
21
+ # data loader
22
+ use_shuffle: true
23
+ num_worker_per_gpu: 8
24
+ batch_size_per_gpu: 8
25
+
26
+ ### -------------Progressive training--------------------------
27
+ # mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
28
+ # iters: [92000,64000,48000,36000,36000,24000]
29
+ # gt_size: 384 # Max patch size for progressive training
30
+ # gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
31
+ ### ------------------------------------------------------------
32
+
33
+ ### ------- Training on single fixed-patch size 280x280---------
34
+ mini_batch_sizes: [8]
35
+ iters: [300000]
36
+ gt_size: 280
37
+ gt_sizes: [280]
38
+ ### ------------------------------------------------------------
39
+
40
+ dataset_enlarge_ratio: 1
41
+ prefetch_mode: ~
42
+
43
+ val:
44
+ name: ValSet
45
+ type: Dataset_PairedImage
46
+ # dataroot_gt: data/FiveK/test/target
+ # dataroot_lq: data/FiveK/test/input
+ # NOTE(review): for enhancement, gt is usually the *Normal* images and
+ # lq the *Low* ones — these two look swapped; verify against the data.
+ dataroot_gt: data/Faded/Low
+ dataroot_lq: data/Faded/Normal
50
+ io_backend:
51
+ type: disk
52
+
53
+ # network structures
54
+ network_g:
55
+ type: RetinexFormer
56
+ in_channels: 3
57
+ out_channels: 3
58
+ n_feat: 40
59
+ stage: 1
60
+ num_blocks: [1,2,2]
61
+
62
+
63
+
64
+ # path
65
+ path:
66
+ pretrain_network_g: ~
67
+ strict_load_g: true
68
+ resume_state: ~
69
+
70
+ # training settings
71
+ train:
72
+ total_iter: 300000
73
+ warmup_iter: -1 # no warm up
74
+ use_grad_clip: true
75
+
76
+ # Split 300k iterations into two cycles.
+ # 1st cycle: fixed 3e-4 LR for 92k iters.
+ # 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
+ # NOTE(review): optim_g.lr below is 2e-4, which does not match the
+ # 3e-4 described here — confirm which value is intended.
79
+ scheduler:
80
+ type: CosineAnnealingRestartCyclicLR
81
+ periods: [92000, 208000]
82
+ restart_weights: [1,1]
83
+ eta_mins: [0.0003,0.000001]
84
+
85
+ mixing_augs:
86
+ mixup: true
87
+ mixup_beta: 1.2
88
+ use_identity: true
89
+
90
+ optim_g:
91
+ type: Adam
92
+ lr: !!float 2e-4
93
+ # weight_decay: !!float 1e-4
94
+ betas: [0.9, 0.999]
95
+
96
+ # losses
97
+ pixel_opt:
98
+ type: L1Loss
99
+ loss_weight: 1
100
+ reduction: mean
101
+
102
+ # validation settings
103
+ val:
104
+ window_size: 4
105
+ val_freq: !!float 4e3
106
+ save_img: false
107
+ rgb2bgr: true
108
+ use_image: false
109
+ max_minibatch: 8
110
+
111
+ metrics:
112
+ psnr: # metric name, can be arbitrary
113
+ type: calculate_psnr
114
+ crop_border: 0
115
+ test_y_channel: false
116
+
117
+ # logging settings
118
+ logger:
119
+ print_freq: 1000
120
+ save_checkpoint_freq: !!float 4e3
121
+ use_tb_logger: true
122
+ wandb:
123
+ project: low_light
124
+ resume_id: ~
125
+
126
+ # dist training settings
127
+ dist_params:
128
+ backend: nccl
129
+ port: 29500
inference.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import argparse
4
+ import cv2
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from basicsr.models import create_model
11
+ from basicsr.utils.options import parse
12
+ from skimage import img_as_ubyte
13
+ from torch.cuda.amp import autocast, GradScaler
14
def load_img(filepath):
    """Read an image from *filepath* and return it as an RGB uint8 array.

    ``cv2.imread`` returns ``None`` instead of raising when the file is
    missing or unreadable, which would otherwise surface downstream as a
    cryptic ``cvtColor`` error — fail fast with a clear message instead.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = cv2.imread(filepath)
    if img is None:
        raise FileNotFoundError(f"Could not read image: {filepath}")
    # OpenCV loads BGR; the rest of the pipeline works in RGB.
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
16
def save_img(filepath, img):
    """Write an RGB image array to *filepath* (OpenCV expects BGR)."""
    bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filepath, bgr)
18
def self_ensemble(x, model):
    """Geometric self-ensemble over the 8 flip/rotation variants of *x*.

    Each variant is passed through *model*, mapped back to the original
    orientation (inverse transforms applied in reverse order), and the
    eight predictions are averaged.
    """

    def _run(inp, hflip, vflip, rotate):
        # Forward geometric transform.
        if hflip:
            inp = torch.flip(inp, (-2,))
        if vflip:
            inp = torch.flip(inp, (-1,))
        if rotate:
            inp = torch.rot90(inp, dims=(-2, -1))
        out = model(inp)
        # Inverse transform, undone in reverse order.
        if rotate:
            out = torch.rot90(out, dims=(-2, -1), k=3)
        if vflip:
            out = torch.flip(out, (-1,))
        if hflip:
            out = torch.flip(out, (-2,))
        return out

    outputs = [
        _run(x, h, v, r)
        for h in (False, True)
        for v in (False, True)
        for r in (False, True)
    ]
    return torch.stack(outputs).mean(dim=0)
41
+
42
+
43
# Set GPU
# NOTE(review): joining over the string '0' iterates its characters —
# fine for a single-digit id, but a list like ['0', '1'] would be the
# clearer form for multi-GPU setups.
gpu_list = ','.join(str(x) for x in '0')
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)

# Load YAML configuration (the same file used for training).
opt = parse("RetinexFormer_FiveK.yml", is_train=False)
opt['dist'] = False  # single-process inference: disable distributed mode
print(opt)

# Load model: build the generator network from the config, then restore
# the pretrained weights from the checkpoint's 'params' entry.
model_restoration = create_model(opt).net_g
checkpoint = torch.load("pretrained_weights/FiveK.pth")
model_restoration.load_state_dict(checkpoint['params'])
print("===>Testing using weights: ", "pretrained_weights/FiveK.pth")
model_restoration.cuda()
# DataParallel wrapper even for one GPU — presumably kept so checkpoints
# saved from multi-GPU runs load transparently; verify it is needed here.
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()
61
+
62
+
63
def process_image(inp_path, model_restoration, out_dir, factor=4):
    """Enhance one image with the restoration model and save it as PNG.

    Pipeline: read RGB, resize to a fixed height of 1024 px (aspect ratio
    preserved), reflect-pad so both sides are multiples of *factor*, run
    the network under autocast, un-pad, clamp to [0, 1], and write
    ``<out_dir>/<basename>.png``.

    Args:
        inp_path: path of the input image.
        model_restoration: the CUDA restoration network (eval mode).
        out_dir: directory that receives the output PNG.
        factor: spatial multiple required by the network (default 4).
    """
    # Free cached CUDA memory from any previous image.
    torch.cuda.ipc_collect()
    torch.cuda.empty_cache()

    img = np.float32(load_img(inp_path)) / 255.

    # Resize image to have height 1024px while maintaining aspect ratio.
    max_height = 1024
    aspect_ratio = img.shape[1] / img.shape[0]
    new_width = int(max_height * aspect_ratio)
    img = cv2.resize(img, (new_width, max_height), interpolation=cv2.INTER_AREA)

    img = torch.from_numpy(img).permute(2, 0, 1)
    input_ = img.unsqueeze(0).cuda()

    # Padding in case image sides are not multiples of `factor`.
    b, c, h, w = input_.shape
    H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
    padh = H - h if h % factor != 0 else 0
    padw = W - w if w % factor != 0 else 0
    input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')

    # Inference only: disable autograd (the original tracked gradients
    # needlessly) and use mixed precision for speed/memory.  The unused
    # GradScaler and the dead `if 1 == 0:` self-ensemble branches from
    # the original have been removed.
    with torch.no_grad(), autocast():
        if h < 3000 and w < 3000:
            restored = model_restoration(input_)
        else:
            # Very large images: split into odd/even columns, restore each
            # half separately, then interleave the halves back together.
            input_1 = input_[:, :, :, 1::2]
            input_2 = input_[:, :, :, 0::2]
            restored_1 = model_restoration(input_1)
            restored_2 = model_restoration(input_2)
            restored = torch.zeros_like(input_)
            restored[:, :, :, 1::2] = restored_1
            restored[:, :, :, 0::2] = restored_2

    # Unpad to the pre-padding dimensions and clamp to the valid range.
    restored = restored[:, :, :h, :w]
    restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()

    out_path = os.path.join(
        out_dir, os.path.splitext(os.path.split(inp_path)[-1])[0] + '.png')
    save_img(out_path, img_as_ubyte(restored))
117
+
118
+
pretrained_weights/pytorch_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:800f6a9281fe8d95daca3108f2b826d5a2adead09031e0e998d30a615286d9c1
3
+ size 6478393
push.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Upload the local project folder to the Hugging Face Hub.
from huggingface_hub import Repository, HfApi

# # Initialize repository
# repo = Repository(
# local_dir="/home/duyht/ShotXRetouch",
# clone_from="DuySota/retouch"
# )

# # Add, commit, and push files
# repo.git_add(auto_lfs_track=True)
# repo.git_commit("Initial commit")
# repo.git_push()

# Alternatively, you can use the HfApi for more granular control
# NOTE(review): upload_folder pushes everything under folder_path to the
# public repo — confirm no private files or credentials live there.
api = HfApi()
api.upload_folder(
    repo_id="DuySota/retouch",
    folder_path="/home/duyht/ShotXRetouch"
)