zilong123321 committed
Commit 3bfd811 · Parent: 454f5ab
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/*.png filter=lfs diff=lfs merge=lfs -text
+ examples/*.jpg filter=lfs diff=lfs merge=lfs -text
+ examples/*.jpeg filter=lfs diff=lfs merge=lfs -text
+ examples/*.bmp filter=lfs diff=lfs merge=lfs -text
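These patterns keep the new example images in Git LFS rather than in regular git storage. As a side note, a minimal sketch of adding another example image through `huggingface_hub`; the repo id and local filename below are assumptions, since this commit does not state the Space's id:

from huggingface_hub import upload_file

# Hypothetical repo id and local path; binary files matching the LFS
# patterns above end up in LFS storage on the Hub.
upload_file(
    path_or_fileobj="new_example.png",
    path_in_repo="examples/new_example.png",
    repo_id="zilong123321/RAM_plus_plus",
    repo_type="space",
)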
app.py ADDED
@@ -0,0 +1,150 @@
import os
import glob
from functools import lru_cache

import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from huggingface_hub import hf_hub_download, snapshot_download
from torchvision.transforms.functional import normalize

from restormerRFR_arch import RestormerRFR
from dino_feature_extractor import DinoFeatureModule

WEIGHT_REPO_ID = "233zzl/RAM_plus_plus"
WEIGHT_FILENAME = "7task/RestormerRFR.pth"
MODEL_NAME = "RestormerRFR"


def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def warmup():
    # Pre-fetch the restoration weights and the DINOv2 backbone so the first
    # GPU request does not pay the download cost.
    hf_hub_download(
        repo_id=WEIGHT_REPO_ID,
        filename=WEIGHT_FILENAME,
        repo_type="model",
        revision="main",
    )
    snapshot_download(
        repo_id="facebook/dinov2-giant",
        repo_type="model",
        revision="main",
    )


def build_model():
    model = RestormerRFR(
        inp_channels=3,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
        LayerNorm_type="WithBias",
        finetune_type=None,
        img_size=128,
    )
    return model


@lru_cache(maxsize=1)
def get_dino_extractor(device):
    extractor = DinoFeatureModule().to(device).eval()
    return extractor


@lru_cache(maxsize=1)
def get_model_and_device():
    device = get_device()
    model = build_model()

    weight_path = hf_hub_download(
        repo_id=WEIGHT_REPO_ID,
        filename=WEIGHT_FILENAME,
    )

    ckpt = torch.load(weight_path, map_location="cpu")
    # BasicSR-style checkpoints wrap the weights in a "params" key; fall back
    # to the raw dict otherwise.
    state_dict = ckpt["params"] if "params" in ckpt else ckpt
    model.load_state_dict(state_dict, strict=False)

    model.eval().to(device)
    return model, device


@spaces.GPU(duration=120)
def restore_image(pil_img: Image.Image) -> Image.Image:
    """
    Takes an input image and returns the restored image, following the
    RAM++ RestormerRFR + DINO feature inference pipeline.
    """
    model, device = get_model_and_device()
    dino_extractor = get_dino_extractor(device)

    # Force 3-channel RGB; uploads may be RGBA or grayscale.
    img_np = np.array(pil_img.convert("RGB")).astype(np.float32) / 255.0
    img = torch.from_numpy(np.transpose(img_np, (2, 0, 1))).float()  # (3,H,W), RGB
    img = img.unsqueeze(0).to(device)  # (1,3,H,W)

    # ImageNet normalization expected by the model.
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    normalize(img, mean, std, inplace=True)

    with torch.no_grad():
        dino_features = dino_extractor(img)
        output = model(img, dino_features)

    # Undo the normalization: x = y * std + mean.
    output = normalize(output, -1 * mean / std, 1 / std)
    output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()  # (3,H,W)
    # Channels are already RGB; no BGR swap is needed when saving through PIL.
    output = np.transpose(output, (1, 2, 0))  # (H,W,3)
    output = (output * 255.0).round().astype(np.uint8)
    return Image.fromarray(output, mode="RGB")


DESCRIPTION = """
# RAM++ Demo
"""

with gr.Blocks(title="RAM++ ZeroGPU Demo") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            inp = gr.Image(type="pil", label="Input image (JPEG/PNG)")
            btn = gr.Button("Run (ZeroGPU)")
        with gr.Column():
            out = gr.Image(type="pil", label="Restored output")

    ex_files = []
    for ext in ("*.png", "*.jpg", "*.jpeg", "*.bmp"):
        ex_files.extend(glob.glob(os.path.join("examples", ext)))
    ex_files = sorted(ex_files)
    if ex_files:
        gr.Examples(examples=ex_files, inputs=inp, label="Examples")

    btn.click(restore_image, inputs=inp, outputs=out, api_name="run")

    gr.Markdown("""
**Tips**
- If the queue is long or you hit the quota, please try again later, or upgrade to Pro for a higher ZeroGPU quota and priority.
""")

    demo.load(fn=warmup, inputs=None, outputs=None)


if __name__ == "__main__":
    demo.launch()
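Since `btn.click` registers `api_name="run"`, the demo can also be driven programmatically once the Space is live. A minimal sketch with a recent `gradio_client`; the Space id and the input filename are assumptions, not stated in this commit:

from gradio_client import Client, handle_file

client = Client("zilong123321/RAM_plus_plus")  # hypothetical Space id
result = client.predict(
    handle_file("degraded.png"),  # hypothetical local input image
    api_name="/run",
)
print(result)  # local path of the restored image returned by the Space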
dino_feature_extractor.py ADDED
@@ -0,0 +1,138 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoImageProcessor, AutoModel


class DinoFeatureModule(nn.Module):
    def __init__(self, model_id: str = "facebook/dinov2-giant"):
        super(DinoFeatureModule, self).__init__()
        self.model_id = model_id  # needed again in forward() to build the processor
        self.dino = AutoModel.from_pretrained(model_id, torch_dtype=torch.float32)

        # Freeze the backbone: it is a feature extractor, never trained here.
        self.dino.eval()
        for param in self.dino.parameters():
            param.requires_grad = False

        frozen = all(not p.requires_grad for p in self.dino.parameters())
        assert frozen, "DINOv2 model parameters are not completely frozen!"

        self.shallow_dim = 1536
        self.mid_dim = 1536
        self.deep_dim = 1536

    def get_dino_features(self, x):
        with torch.no_grad():
            outputs = self.dino(x, output_hidden_states=True)
        hidden_states = outputs.hidden_states

        _, _, H, W = x.shape
        aspect_ratio = W / H

        # Two taps per depth level of the 40-block ViT-giant.
        shallow_feat1 = hidden_states[7]
        shallow_feat2 = hidden_states[15]
        mid_feat1 = hidden_states[20]
        mid_feat2 = hidden_states[22]
        deep_feat1 = hidden_states[33]
        deep_feat2 = hidden_states[39]

        def reshape_features(feat):
            feat = feat[:, 1:, :]  # drop the CLS token
            B, N, C = feat.shape

            # Recover the (h, w) patch grid from the token count N and the
            # image aspect ratio, nudging h or w until h * w == N.
            h = int(math.sqrt(N / aspect_ratio))
            w = int(N / h)

            if aspect_ratio > 1:
                if h * w > N:
                    h -= 1
                    w = N // h
                if h * w < N:
                    h += 1
                    w = N // h
            else:
                if h * w > N:
                    w -= 1
                    h = N // w
                if h * w < N:
                    w += 1
                    h = N // w

            assert h * w == N, f"Dimensions mismatch: {h}*{w} != {N}"

            feat = feat.reshape(B, h, w, C).permute(0, 3, 1, 2)
            return feat

        shallow_feat1 = reshape_features(shallow_feat1).float()
        mid_feat1 = reshape_features(mid_feat1).float()
        deep_feat1 = reshape_features(deep_feat1).float()
        shallow_feat2 = reshape_features(shallow_feat2).float()
        mid_feat2 = reshape_features(mid_feat2).float()
        deep_feat2 = reshape_features(deep_feat2).float()

        return shallow_feat1, mid_feat1, deep_feat1, shallow_feat2, mid_feat2, deep_feat2

    def check_image_size(self, x):
        _, _, h, w = x.size()
        pad_size = 16
        mod_pad_h = (pad_size - h % pad_size) % pad_size
        mod_pad_w = (pad_size - w % pad_size) % pad_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward(self, inp_img):
        device = inp_img.device

        mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

        # The restoration branch passes in an ImageNet-normalized tensor; undo
        # that before handing the image to the DINOv2 processor.
        denormalized_img = inp_img * std + mean
        denormalized_img = self.check_image_size(denormalized_img)
        h_denormalized, w_denormalized = denormalized_img.shape[2], denormalized_img.shape[3]

        # To keep the changes minimal and the code general, the image is
        # rescaled here so the 14-pixel DINO patch grid aligns spatially with
        # the stride-8 restoration latent.
        target_h = (h_denormalized // 8) * 14
        target_w = (w_denormalized // 8) * 14

        shortest_edge = min(target_h, target_w)
        processor = AutoImageProcessor.from_pretrained(
            self.model_id,
            local_files_only=False,
            do_rescale=False,
            do_center_crop=False,
            use_fast=True,
            size={"shortest_edge": shortest_edge},
        )

        inputs = processor(images=denormalized_img, return_tensors="pt").to(device)

        shallow_feat1, mid_feat1, deep_feat1, shallow_feat2, mid_feat2, deep_feat2 = \
            self.get_dino_features(inputs['pixel_values'])

        dino_features = {
            'shallow_feat1': shallow_feat1,
            'mid_feat1': mid_feat1,
            'deep_feat1': deep_feat1,
            'shallow_feat2': shallow_feat2,
            'mid_feat2': mid_feat2,
            'deep_feat2': deep_feat2,
        }

        return dino_features
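The grid-recovery arithmetic in `reshape_features` can be sanity-checked without downloading the backbone. A standalone copy of the same logic, with an arbitrary example size (a 224x392 input at patch size 14 gives a 16x28 grid of 448 tokens):

import math

def recover_grid(n_tokens: int, aspect_ratio: float):
    # Same h/w recovery as reshape_features, minus the tensor handling.
    h = int(math.sqrt(n_tokens / aspect_ratio))
    w = int(n_tokens / h)
    if aspect_ratio > 1:
        if h * w > n_tokens:
            h -= 1
            w = n_tokens // h
        if h * w < n_tokens:
            h += 1
            w = n_tokens // h
    else:
        if h * w > n_tokens:
            w -= 1
            h = n_tokens // w
        if h * w < n_tokens:
            w += 1
            h = n_tokens // w
    assert h * w == n_tokens
    return h, w

print(recover_grid(448, 392 / 224))  # -> (16, 28)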
examples/BSD_0038.png ADDED

Git LFS Details

  • SHA256: 80366849b1874226e5c2fb5ef85579a393e2584bd9f8e097ee5ee61fb1c263cd
  • Pointer size: 131 Bytes
  • Size of remote file: 454 kB
examples/BSD_0047.png ADDED

Git LFS Details

  • SHA256: 0e7fd6ef10041def5387720aaca2efd0533e802963721baabec1cd7a8a051fa2
  • Pointer size: 131 Bytes
  • Size of remote file: 463 kB
examples/Rain100H_15.png ADDED

Git LFS Details

  • SHA256: fb6c98a260852143e49cfe5c25c3d2f308eaccf9996bdc5d0c563617718168b1
  • Pointer size: 131 Bytes
  • Size of remote file: 227 kB
examples/Rain100L_79.png ADDED

Git LFS Details

  • SHA256: 69f4c95b1f9620e7e91ab7518435a636a902512565b73f9fbae2a823547117e4
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
examples/SOTS_0271_0.85_0.12.jpg ADDED

Git LFS Details

  • SHA256: 7d5122f955c324485af0246302ab3d80744f144b85137edb55d2e0466541ab7e
  • Pointer size: 130 Bytes
  • Size of remote file: 94.9 kB
examples/SOTS_1977_0.8_0.08.jpg ADDED

Git LFS Details

  • SHA256: 61ffbdb89c9f18881c0b554c65232749c89043e7022d6c19f894defeba345e44
  • Pointer size: 131 Bytes
  • Size of remote file: 125 kB
requirements.txt ADDED
@@ -0,0 +1,12 @@
gradio>=4.0.0
spaces>=0.28.3
huggingface_hub>=0.23.0
transformers>=4.41.0
safetensors>=0.4.3
numpy>=1.26.0
Pillow>=10.0.0
opencv-python-headless>=4.8.0.76
einops>=0.7.0
torch>=2.1.0
torchvision>=0.16.0
timm>=0.9.10
restormerRFR_arch.py ADDED
@@ -0,0 +1,408 @@
# RAM++: Robust Representation Learning via Adaptive Mask for All-in-One Image Restoration
# Zilong Zhang, Chujie Qin, Chunle Guo, Yong Zhang, Chao Xue, Ming-Ming Cheng and Chongyi Li
# https://arxiv.org/abs/2509.12039

import numbers

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')


def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)


class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight


class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        assert len(normalized_shape) == 1

        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias


class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)


##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias, finetune_type=None):
        super(FeedForward, self).__init__()

        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1,
                                padding=1, groups=hidden_features * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out


class resblock(nn.Module):
    def __init__(self, dim):
        super(resblock, self).__init__()
        # self.norm = LayerNorm(dim, LayerNorm_type='BiasFree')
        self.body = nn.Sequential(nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PReLU(dim),
                                  nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False))

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


##########################################################################
## Resizing modules
class Downsample(nn.Module):
    def __init__(self, n_feat):
        super(Downsample, self).__init__()
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelUnshuffle(2))

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    def __init__(self, n_feat):
        super(Upsample, self).__init__()
        self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
                                  nn.PixelShuffle(2))

    def forward(self, x):
        return self.body(x)


##########################################################################
## Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, finetune_type=None):
        super(TransformerBlock, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias, finetune_type)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        return x


class TemperatureSoftmax(nn.Module):
    def __init__(self, temperature):
        super().__init__()
        self.temperature = temperature

    def forward(self, x):
        return F.softmax(x / torch.clamp(self.temperature, min=1e-8), dim=1)


class DinoFeatureFusion(nn.Module):
    """Gated fusion of the two DINO feature maps tapped at the same depth level."""
    def __init__(self, dino_dim=1536):
        super(DinoFeatureFusion, self).__init__()

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.temperature = nn.Parameter(torch.ones(1) * 1.0)

        self.gate_network = nn.Sequential(
            nn.Linear(dino_dim * 2, dino_dim),
            nn.PReLU(dino_dim),
            nn.Linear(dino_dim, 512),
            nn.PReLU(512),
            nn.Linear(512, 2),
            TemperatureSoftmax(self.temperature)
        )

    def forward(self, dino_feat1, dino_feat2):
        pooled_feat1 = self.global_pool(dino_feat1).squeeze(-1).squeeze(-1)
        pooled_feat2 = self.global_pool(dino_feat2).squeeze(-1).squeeze(-1)
        pooled_features = torch.cat([pooled_feat1, pooled_feat2], dim=1)

        weights = self.gate_network(pooled_features)
        weighted_feat1 = dino_feat1 * weights[:, 0:1].view(-1, 1, 1, 1)
        weighted_feat2 = dino_feat2 * weights[:, 1:2].view(-1, 1, 1, 1)

        fused_feat = weighted_feat1 + weighted_feat2
        return fused_feat


class DRAdaptation(nn.Module):
    """Projects 1536-dim DINO features to the restoration latent width (restore_dim * 8)."""
    def __init__(self, dino_dim=1536, restore_dim=48, scale_factor=14, size=128):
        super(DRAdaptation, self).__init__()
        self.size = size
        self.restore_dim = restore_dim
        self.adaptation = nn.Sequential(
            nn.Conv2d(dino_dim, restore_dim * 16, kernel_size=3, padding=1),  # 768
            nn.PReLU(restore_dim * 16),
            nn.Conv2d(restore_dim * 16, restore_dim * 8, kernel_size=1),  # 384
        )

    def forward(self, dino_feat, restore_feat):
        # restore_feat is accepted for interface symmetry; only the DINO
        # features are transformed here.
        adapted_dino = self.adaptation(dino_feat)
        return adapted_dino


##########################################################################
##---------- D-R Fusion -----------------------
class DinoRestoreFeatureFusion(nn.Module):
    def __init__(self, dim, num_heads, bias):
        super(DinoRestoreFeatureFusion, self).__init__()
        self.reduce_chan = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias)

    def forward(self, dino_feat, restore_feat):
        x_fusion = self.reduce_chan(torch.cat([dino_feat, restore_feat], dim=1))
        res = x_fusion + restore_feat
        return res


##---------- RestormerRFR -----------------------
class RestormerRFR(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=[4, 6, 6, 8],
                 num_refinement_blocks=4,
                 heads=[1, 2, 4, 8],
                 ffn_expansion_factor=2.66,
                 bias=False,
                 LayerNorm_type='WithBias',
                 finetune_type=None,
                 img_size=128
                 ):
        super(RestormerRFR, self).__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.mask_token = torch.zeros(1, 3, img_size, img_size)

        self.dr_adaptation1 = DRAdaptation(dino_dim=1536, restore_dim=48, scale_factor=14, size=128)
        self.dr_adaptation2 = DRAdaptation(dino_dim=1536, restore_dim=48, scale_factor=14, size=128)
        self.dr_adaptation3 = DRAdaptation(dino_dim=1536, restore_dim=48, scale_factor=14, size=128)
        self.dr_fusion1 = DinoRestoreFeatureFusion(dim=int(dim * 2 ** 3), num_heads=heads[3], bias=bias)
        self.dr_fusion2 = DinoRestoreFeatureFusion(dim=int(dim * 2 ** 2), num_heads=heads[2], bias=bias)
        self.dr_fusion3 = DinoRestoreFeatureFusion(dim=int(dim * 2 ** 1), num_heads=heads[1], bias=bias)
        self.up_4_3_dino1 = Upsample(int(dim * 2 ** 3))
        self.up_4_3_dino2 = Upsample(int(dim * 2 ** 3))
        self.up_3_2_dino = Upsample(int(dim * 2 ** 2))
        self.dino_fusion_shallow = DinoFeatureFusion(dino_dim=1536)
        self.dino_fusion_mid = DinoFeatureFusion(dino_dim=1536)
        self.dino_fusion_deep = DinoFeatureFusion(dino_dim=1536)

        self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_blocks[0] - 1 else None) for i in range(num_blocks[0])])
        self.down1_2 = Downsample(dim)
        self.encoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_blocks[1] - 1 else None) for i in range(num_blocks[1])])
        self.down2_3 = Downsample(int(dim * 2 ** 1))
        self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_blocks[2] - 1 else None) for i in range(num_blocks[2])])
        self.down3_4 = Downsample(int(dim * 2 ** 2))
        self.latent = nn.Sequential(*[TransformerBlock(dim=int(dim * 2 ** 3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_blocks[3] - 1 else None) for i in range(num_blocks[3])])

        self.up4_3 = Upsample(int(dim * 2 ** 3))
        self.reduce_chan_level3 = nn.Conv2d(int(dim * 2 ** 3), int(dim * 2 ** 2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim * 2 ** 2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_blocks[2] - 1 else None) for i in range(num_blocks[2])])

        self.up3_2 = Upsample(int(dim * 2 ** 2))
        self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), int(dim * 2 ** 1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_blocks[1] - 1 else None) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim * 2 ** 1))
        self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_blocks[0] - 1 else None) for i in range(num_blocks[0])])
        self.refinement = nn.Sequential(*[TransformerBlock(dim=int(dim * 2 ** 1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type, finetune_type=finetune_type if i == num_refinement_blocks - 1 else None) for i in range(num_refinement_blocks)])
        self.output = nn.Conv2d(int(dim * 2 ** 1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def check_image_size(self, x):
        _, _, h, w = x.size()
        pad_size = 16
        mod_pad_h = (pad_size - h % pad_size) % pad_size
        mod_pad_w = (pad_size - w % pad_size) % pad_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
        return x

    def forward(self, inp_img, dino_features=None):
        b, c, h, w = inp_img.shape

        shallow_feat1, mid_feat1, deep_feat1, shallow_feat2, mid_feat2, deep_feat2 = dino_features.values()
        inp_img = self.check_image_size(inp_img)

        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        # Fuse the paired DINO taps, adapt them to the restoration width, then
        # inject them at the latent and decoder levels.
        shallow_feat = self.dino_fusion_shallow(shallow_feat1, shallow_feat2)
        mid_feat = self.dino_fusion_mid(mid_feat1, mid_feat2)
        deep_feat = self.dino_fusion_deep(deep_feat1, deep_feat2)

        shallow_feat = self.dr_adaptation1(shallow_feat, latent)
        mid_feat = self.dr_adaptation2(mid_feat, latent)
        deep_feat = self.dr_adaptation3(deep_feat, latent)

        latent = self.dr_fusion1(dino_feat=deep_feat, restore_feat=latent)
        shallow_feat = self.up_4_3_dino1(shallow_feat)
        mid_feat = self.up_4_3_dino2(mid_feat)

        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        out_dec_level3 = self.dr_fusion2(dino_feat=mid_feat, restore_feat=out_dec_level3)
        shallow_feat = self.up_3_2_dino(shallow_feat)
        inp_dec_level2 = self.up3_2(out_dec_level3)

        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        out_dec_level2 = self.dr_fusion3(dino_feat=shallow_feat, restore_feat=out_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)
        out_dec_level1 = self.refinement(out_dec_level1)
        out_dec_level1 = self.output(out_dec_level1)

        # Crop back to the pre-padding size.
        return out_dec_level1[:, :, :h, :w]
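For reference, a minimal shape check of the architecture above, assuming randomly initialized weights (no checkpoint), a 128x128 input, and random stand-ins for the six DINO feature maps at the 16x16 latent grid (128 / 8); it only verifies that the forward pass wires up:

import torch
from restormerRFR_arch import RestormerRFR

model = RestormerRFR(dim=48, num_blocks=[4, 6, 6, 8], heads=[1, 2, 4, 8]).eval()

img = torch.randn(1, 3, 128, 128)
# Stand-ins for the DINOv2 features: 1536 channels on the latent grid.
feats = {k: torch.randn(1, 1536, 16, 16)
         for k in ('shallow_feat1', 'mid_feat1', 'deep_feat1',
                   'shallow_feat2', 'mid_feat2', 'deep_feat2')}

with torch.no_grad():
    out = model(img, feats)
print(out.shape)  # torch.Size([1, 3, 128, 128])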