Upload folder using huggingface_hub
Browse files- RetinexFormer_FiveK.yml +129 -0
- inference.py +118 -0
- pretrained_weights/pytorch_model.pth +3 -0
- push.py +19 -0
RetinexFormer_FiveK.yml
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# general settings
|
| 2 |
+
name: Enhancement_RetinexFormer_FiveK
|
| 3 |
+
model_type: ImageCleanModel
|
| 4 |
+
scale: 1
|
| 5 |
+
num_gpu: 1 # set num_gpu: 0 for cpu mode
|
| 6 |
+
manual_seed: 100
|
| 7 |
+
|
| 8 |
+
# dataset and data loader settings
|
| 9 |
+
datasets:
|
| 10 |
+
train:
|
| 11 |
+
name: TrainSet
|
| 12 |
+
type: Dataset_PairedImage
|
| 13 |
+
dataroot_gt: data/FiveK/train/target
|
| 14 |
+
dataroot_lq: data/FiveK/train/target
|
| 15 |
+
geometric_augs: true
|
| 16 |
+
cache_data: true
|
| 17 |
+
filename_tmpl: '{}'
|
| 18 |
+
io_backend:
|
| 19 |
+
type: disk
|
| 20 |
+
|
| 21 |
+
# data loader
|
| 22 |
+
use_shuffle: true
|
| 23 |
+
num_worker_per_gpu: 8
|
| 24 |
+
batch_size_per_gpu: 8
|
| 25 |
+
|
| 26 |
+
### -------------Progressive training--------------------------
|
| 27 |
+
# mini_batch_sizes: [8,5,4,2,1,1] # Batch size per gpu
|
| 28 |
+
# iters: [92000,64000,48000,36000,36000,24000]
|
| 29 |
+
# gt_size: 384 # Max patch size for progressive training
|
| 30 |
+
# gt_sizes: [128,160,192,256,320,384] # Patch sizes for progressive training.
|
| 31 |
+
### ------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
### ------- Training on single fixed-patch size 128x128---------
|
| 34 |
+
mini_batch_sizes: [8]
|
| 35 |
+
iters: [300000]
|
| 36 |
+
gt_size: 280
|
| 37 |
+
gt_sizes: [280]
|
| 38 |
+
### ------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
dataset_enlarge_ratio: 1
|
| 41 |
+
prefetch_mode: ~
|
| 42 |
+
|
| 43 |
+
val:
|
| 44 |
+
name: ValSet
|
| 45 |
+
type: Dataset_PairedImage
|
| 46 |
+
# dataroot_gt: data/FiveK/test/target
|
| 47 |
+
# dataroot_lq: data/FiveK/test/input
|
| 48 |
+
dataroot_gt: data/Faded/Low
|
| 49 |
+
dataroot_lq: data/Faded/Normal
|
| 50 |
+
io_backend:
|
| 51 |
+
type: disk
|
| 52 |
+
|
| 53 |
+
# network structures
|
| 54 |
+
network_g:
|
| 55 |
+
type: RetinexFormer
|
| 56 |
+
in_channels: 3
|
| 57 |
+
out_channels: 3
|
| 58 |
+
n_feat: 40
|
| 59 |
+
stage: 1
|
| 60 |
+
num_blocks: [1,2,2]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# path
|
| 65 |
+
path:
|
| 66 |
+
pretrain_network_g: ~
|
| 67 |
+
strict_load_g: true
|
| 68 |
+
resume_state: ~
|
| 69 |
+
|
| 70 |
+
# training settings
|
| 71 |
+
train:
|
| 72 |
+
total_iter: 300000
|
| 73 |
+
warmup_iter: -1 # no warm up
|
| 74 |
+
use_grad_clip: true
|
| 75 |
+
|
| 76 |
+
# Split 300k iterations into two cycles.
|
| 77 |
+
# 1st cycle: fixed 3e-4 LR for 92k iters.
|
| 78 |
+
# 2nd cycle: cosine annealing (3e-4 to 1e-6) for 208k iters.
|
| 79 |
+
scheduler:
|
| 80 |
+
type: CosineAnnealingRestartCyclicLR
|
| 81 |
+
periods: [92000, 208000]
|
| 82 |
+
restart_weights: [1,1]
|
| 83 |
+
eta_mins: [0.0003,0.000001]
|
| 84 |
+
|
| 85 |
+
mixing_augs:
|
| 86 |
+
mixup: true
|
| 87 |
+
mixup_beta: 1.2
|
| 88 |
+
use_identity: true
|
| 89 |
+
|
| 90 |
+
optim_g:
|
| 91 |
+
type: Adam
|
| 92 |
+
lr: !!float 2e-4
|
| 93 |
+
# weight_decay: !!float 1e-4
|
| 94 |
+
betas: [0.9, 0.999]
|
| 95 |
+
|
| 96 |
+
# losses
|
| 97 |
+
pixel_opt:
|
| 98 |
+
type: L1Loss
|
| 99 |
+
loss_weight: 1
|
| 100 |
+
reduction: mean
|
| 101 |
+
|
| 102 |
+
# validation settings
|
| 103 |
+
val:
|
| 104 |
+
window_size: 4
|
| 105 |
+
val_freq: !!float 4e3
|
| 106 |
+
save_img: false
|
| 107 |
+
rgb2bgr: true
|
| 108 |
+
use_image: false
|
| 109 |
+
max_minibatch: 8
|
| 110 |
+
|
| 111 |
+
metrics:
|
| 112 |
+
psnr: # metric name, can be arbitrary
|
| 113 |
+
type: calculate_psnr
|
| 114 |
+
crop_border: 0
|
| 115 |
+
test_y_channel: false
|
| 116 |
+
|
| 117 |
+
# logging settings
|
| 118 |
+
logger:
|
| 119 |
+
print_freq: 1000
|
| 120 |
+
save_checkpoint_freq: !!float 4e3
|
| 121 |
+
use_tb_logger: true
|
| 122 |
+
wandb:
|
| 123 |
+
project: low_light
|
| 124 |
+
resume_id: ~
|
| 125 |
+
|
| 126 |
+
# dist training settings
|
| 127 |
+
dist_params:
|
| 128 |
+
backend: nccl
|
| 129 |
+
port: 29500
|
inference.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import cv2
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from basicsr.models import create_model
|
| 11 |
+
from basicsr.utils.options import parse
|
| 12 |
+
from skimage import img_as_ubyte
|
| 13 |
+
from torch.cuda.amp import autocast, GradScaler
|
| 14 |
+
def load_img(filepath):
|
| 15 |
+
return cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
|
| 16 |
+
def save_img(filepath, img):
|
| 17 |
+
cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
| 18 |
+
def self_ensemble(x, model):
|
| 19 |
+
def forward_transformed(x, hflip, vflip, rotate, model):
|
| 20 |
+
if hflip:
|
| 21 |
+
x = torch.flip(x, (-2,))
|
| 22 |
+
if vflip:
|
| 23 |
+
x = torch.flip(x, (-1,))
|
| 24 |
+
if rotate:
|
| 25 |
+
x = torch.rot90(x, dims=(-2, -1))
|
| 26 |
+
x = model(x)
|
| 27 |
+
if rotate:
|
| 28 |
+
x = torch.rot90(x, dims=(-2, -1), k=3)
|
| 29 |
+
if vflip:
|
| 30 |
+
x = torch.flip(x, (-1,))
|
| 31 |
+
if hflip:
|
| 32 |
+
x = torch.flip(x, (-2,))
|
| 33 |
+
return x
|
| 34 |
+
t = []
|
| 35 |
+
for hflip in [False, True]:
|
| 36 |
+
for vflip in [False, True]:
|
| 37 |
+
for rot in [False, True]:
|
| 38 |
+
t.append(forward_transformed(x, hflip, vflip, rot, model))
|
| 39 |
+
t = torch.stack(t)
|
| 40 |
+
return torch.mean(t, dim=0)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Set GPU
|
| 44 |
+
gpu_list = ','.join(str(x) for x in '0')
|
| 45 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
| 46 |
+
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
| 47 |
+
|
| 48 |
+
# Load YAML configuration
|
| 49 |
+
opt = parse("RetinexFormer_FiveK.yml", is_train=False)
|
| 50 |
+
opt['dist'] = False
|
| 51 |
+
print(opt)
|
| 52 |
+
|
| 53 |
+
# Load model
|
| 54 |
+
model_restoration = create_model(opt).net_g
|
| 55 |
+
checkpoint = torch.load("pretrained_weights/FiveK.pth")
|
| 56 |
+
model_restoration.load_state_dict(checkpoint['params'])
|
| 57 |
+
print("===>Testing using weights: ", "pretrained_weights/FiveK.pth")
|
| 58 |
+
model_restoration.cuda()
|
| 59 |
+
model_restoration = nn.DataParallel(model_restoration)
|
| 60 |
+
model_restoration.eval()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def process_image(inp_path, model_restoration, out_dir, factor=4):
|
| 64 |
+
torch.cuda.ipc_collect()
|
| 65 |
+
torch.cuda.empty_cache()
|
| 66 |
+
|
| 67 |
+
img = np.float32(load_img(inp_path)) / 255.
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Resize image to have height 1024px while maintaining aspect ratio
|
| 71 |
+
max_height = 1024
|
| 72 |
+
aspect_ratio = img.shape[1] / img.shape[0]
|
| 73 |
+
new_width = int(max_height * aspect_ratio)
|
| 74 |
+
img = cv2.resize(img, (new_width, max_height), interpolation=cv2.INTER_AREA)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
img = torch.from_numpy(img).permute(2, 0, 1)
|
| 78 |
+
input_ = img.unsqueeze(0).cuda()
|
| 79 |
+
|
| 80 |
+
# Padding in case images are not multiples of 4
|
| 81 |
+
b, c, h, w = input_.shape
|
| 82 |
+
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
|
| 83 |
+
padh = H - h if h % factor != 0 else 0
|
| 84 |
+
padw = W - w if w % factor != 0 else 0
|
| 85 |
+
input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
|
| 86 |
+
|
| 87 |
+
scaler = GradScaler() # for mixed precision
|
| 88 |
+
|
| 89 |
+
with autocast(): # enable mixed precision
|
| 90 |
+
if h < 3000 and w < 3000:
|
| 91 |
+
if 1 == 0:
|
| 92 |
+
restored = self_ensemble(input_, model_restoration)
|
| 93 |
+
else:
|
| 94 |
+
restored = model_restoration(input_)
|
| 95 |
+
else:
|
| 96 |
+
# split and test
|
| 97 |
+
input_1 = input_[:, :, :, 1::2]
|
| 98 |
+
input_2 = input_[:, :, :, 0::2]
|
| 99 |
+
if 1 == 0:
|
| 100 |
+
restored_1 = self_ensemble(input_1, model_restoration)
|
| 101 |
+
restored_2 = self_ensemble(input_2, model_restoration)
|
| 102 |
+
else:
|
| 103 |
+
restored_1 = model_restoration(input_1)
|
| 104 |
+
restored_2 = model_restoration(input_2)
|
| 105 |
+
restored = torch.zeros_like(input_)
|
| 106 |
+
restored[:, :, :, 1::2] = restored_1
|
| 107 |
+
restored[:, :, :, 0::2] = restored_2
|
| 108 |
+
|
| 109 |
+
# Unpad images to original dimensions
|
| 110 |
+
restored = restored[:, :, :h, :w]
|
| 111 |
+
|
| 112 |
+
restored = torch.clamp(restored, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if True:
|
| 116 |
+
save_img(os.path.join(out_dir, os.path.splitext(os.path.split(inp_path)[-1])[0] + '.png'), img_as_ubyte(restored))
|
| 117 |
+
|
| 118 |
+
|
pretrained_weights/pytorch_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:800f6a9281fe8d95daca3108f2b826d5a2adead09031e0e998d30a615286d9c1
|
| 3 |
+
size 6478393
|
push.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import Repository, HfApi
|
| 2 |
+
|
| 3 |
+
# # Initialize repository
|
| 4 |
+
# repo = Repository(
|
| 5 |
+
# local_dir="/home/duyht/ShotXRetouch",
|
| 6 |
+
# clone_from="DuySota/retouch"
|
| 7 |
+
# )
|
| 8 |
+
|
| 9 |
+
# # Add, commit, and push files
|
| 10 |
+
# repo.git_add(auto_lfs_track=True)
|
| 11 |
+
# repo.git_commit("Initial commit")
|
| 12 |
+
# repo.git_push()
|
| 13 |
+
|
| 14 |
+
# Alternatively, you can use the HfApi for more granular control
|
| 15 |
+
api = HfApi()
|
| 16 |
+
api.upload_folder(
|
| 17 |
+
repo_id="DuySota/retouch",
|
| 18 |
+
folder_path="/home/duyht/ShotXRetouch"
|
| 19 |
+
)
|