SeunghoEum commited on
Commit
69f6042
·
0 Parent(s):

Add Gradio deraining demo

Browse files
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pth filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Deraining Demo
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ ## Deraining 데모 (Gradio)
13
+
14
+ ### 사용 방법
15
+ - 이미지를 업로드하면 deraining 결과를 반환합니다.
16
+
17
+ ### 가중치(Weights) 준비
18
+ 이 Space는 기본적으로 `checkpoints/model_best.pth`를 찾습니다.
19
+
20
+ - **옵션 A (권장)**: Space repo에 `checkpoints/model_best.pth`를 **Git LFS**로 업로드
21
+ - **옵션 B**: Hub에서 다운로드
22
+ - Space 설정(Environment variables)에 아래를 추가
23
+ - `WEIGHTS_REPO`: 예) `your-username/your-deraining-weights`
24
+ - `WEIGHTS_FILENAME`: 예) `model_best.pth`
25
+
26
+ ### 모델 선택
27
+ 기본 모델은 `CGFWMSRNet` 입니다. 바꾸려면 환경변수로 설정하세요:
28
+ - `MODEL_NAME`: 예) `CGFWMSRNet`
29
+
30
+
__pycache__/app.cpython-311.pyc ADDED
Binary file (7.83 kB). View file
 
app.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from skimage import img_as_ubyte
7
+
8
+ from models import find_models_def
9
+
10
+ try:
11
+ # Optional: download weights from the Hub when not bundled in the Space repo.
12
+ from huggingface_hub import hf_hub_download # type: ignore
13
+ except Exception: # pragma: no cover
14
+ hf_hub_download = None # type: ignore
15
+
16
# -------------------------
# Global (load once)
# -------------------------
MODEL = None  # lazily-built model instance; populated by load_model() on first request
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # inference device
BASE_DIR = os.path.dirname(os.path.abspath(__file__))  # directory containing this file
22
+
23
def _resolve_weights_path() -> str:
    """Locate the model weights file.

    Resolution order:
      1) weights bundled with the Space at ./checkpoints/model_best.pth
      2) a Hub download driven by the WEIGHTS_REPO / WEIGHTS_FILENAME env vars

    Returns:
        Filesystem path to the weights.

    Raises:
        FileNotFoundError: no bundled weights and no usable Hub configuration.
    """
    bundled = os.path.join(BASE_DIR, "checkpoints", "model_best.pth")
    if os.path.exists(bundled):
        return bundled

    hub_repo = os.getenv("WEIGHTS_REPO", "").strip()
    hub_file = os.getenv("WEIGHTS_FILENAME", "model_best.pth").strip()
    if not hub_repo or hf_hub_download is None:
        raise FileNotFoundError(
            "Weights not found. Put weights at 'checkpoints/model_best.pth' "
            "or set env WEIGHTS_REPO and WEIGHTS_FILENAME to download from the Hub."
        )
    return hf_hub_download(repo_id=hub_repo, filename=hub_file)
42
+
43
+
44
+ def _load_state_dict(model: torch.nn.Module, weights_path: str) -> None:
45
+ ckpt = torch.load(weights_path, map_location="cpu")
46
+ state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
47
+
48
+ # Handle DataParallel 'module.' prefix
49
+ if isinstance(state_dict, dict) and any(k.startswith("module.") for k in state_dict.keys()):
50
+ state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
51
+
52
+ try:
53
+ model.load_state_dict(state_dict, strict=True)
54
+ except Exception as e:
55
+ raise RuntimeError(
56
+ f"Failed to load weights from '{weights_path}'. "
57
+ f"Check that MODEL_NAME matches the checkpoint architecture."
58
+ ) from e
59
+
60
+
61
def load_model():
    """Build and cache the global MODEL once (idempotent).

    Selects the architecture from the MODEL_NAME env var (default
    'CGFWMSRNet'), resolves and loads weights (local checkpoint first,
    Hub fallback), then moves the model to DEVICE in eval mode.

    Raises:
        FileNotFoundError: no local weights and no Hub configuration.
        RuntimeError: the weights do not match the selected architecture.
    """
    global MODEL

    if MODEL is not None:
        return

    # Model selection (default: CGFWMSRNet)
    model_name = os.getenv("MODEL_NAME", "CGFWMSRNet").strip() or "CGFWMSRNet"

    # Bug fix: build and load into a local variable first. Previously the
    # global was assigned before the weights loaded, so if weight resolution
    # or loading raised, later calls returned early and served a model with
    # random (un-trained) weights. Now MODEL stays None on failure and the
    # next call retries.
    model = find_models_def(model_name)()

    # Load checkpoint (local first, Hub fallback)
    weights_path = _resolve_weights_path()
    _load_state_dict(model, weights_path)

    # Move to device and freeze for inference
    model.to(DEVICE)
    model.eval()

    MODEL = model
78
+
79
def infer_image(inp_img: np.ndarray):
    """Derain a single uploaded image.

    Args:
        inp_img: HxW (grayscale), HxWx3 (RGB) or HxWx4 (RGBA) uint8 array
            from the Gradio Image component.

    Returns:
        Restored HxWx3 uint8 array.

    Raises:
        gr.Error: when no image was uploaded.
    """
    load_model()

    if inp_img is None:
        raise gr.Error("이미지를 업로드해주세요.")

    # Normalize channel layout to 3-channel RGB.
    if inp_img.ndim == 2:
        # Robustness fix: grayscale uploads previously crashed at permute();
        # replicate the single channel into RGB.
        inp_img = np.stack([inp_img] * 3, axis=-1)
    elif inp_img.ndim == 3 and inp_img.shape[2] == 4:
        # RGBA upload: drop the alpha channel.
        inp_img = inp_img[:, :, :3]

    # uint8 -> float32 [0,1]
    img = inp_img.astype(np.float32) / 255.0

    # HWC -> CHW, add batch: (1, 3, H, W)
    img_t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    img_t = img_t.to(DEVICE)

    # Pad so H and W are multiples of 16 before running the model.
    height, width = img_t.shape[2], img_t.shape[3]
    img_multiple_of = 16
    H = ((height + img_multiple_of) // img_multiple_of) * img_multiple_of
    W = ((width + img_multiple_of) // img_multiple_of) * img_multiple_of
    padh = H - height if height % img_multiple_of != 0 else 0
    padw = W - width if width % img_multiple_of != 0 else 0

    # NOTE(review): reflect padding requires pad < dim, so inputs smaller
    # than ~16px on a side would fail here — confirm acceptable for the demo.
    img_t = F.pad(img_t, (0, padw, 0, padh), mode="reflect")

    with torch.no_grad():
        restored = MODEL(img_t)

    # The model may return a list/tuple of multi-stage outputs; the first
    # entry is the full-resolution restoration.
    if isinstance(restored, (list, tuple)):
        restored = restored[0]

    restored = torch.clamp(restored, 0, 1)
    # Crop away the padding added above.
    restored = restored[:, :, :height, :width]

    # BCHW -> HWC uint8
    restored_np = restored.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    restored_img = img_as_ubyte(restored_np)  # uint8

    # Free cached GPU memory between requests (useful for a shared demo).
    if DEVICE.startswith("cuda"):
        torch.cuda.empty_cache()

    return restored_img
130
+
131
# Gradio UI: one rainy image in, one restored image out.
demo = gr.Interface(
    fn=infer_image,
    inputs=gr.Image(type="numpy", label="Input (Rainy)"),
    outputs=gr.Image(type="numpy", label="Restored"),
    title="Image Deraining Demo (Gradio + Hugging Face Spaces)",
    description=(
        "이미지를 업로드하면 비 제거 결과를 반환합니다. "
        "Space에 가중치(`checkpoints/model_best.pth`)가 포함되어 있거나, "
        "환경변수 WEIGHTS_REPO/WEIGHTS_FILENAME로 Hub에서 내려받을 수 있어야 합니다."
    ),
)

if __name__ == "__main__":
    # queue() serializes requests so concurrent users share the model safely.
    demo.queue().launch()
checkpoints/model_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f912468bebbea184d41aa14cdc3678b11323a83ba264c3b7e7e7fb11c70abba
3
+ size 73604311
models/CGFWMSRNet.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from models.modules import *
6
+
7
+
8
class Model(nn.Module):
    """CGFWMSRNet: three-stage coarse-to-fine image deraining network.

    Stage 1 runs on a 1/4-scale copy of the input, stage 2 on a 1/2-scale
    copy, stage 3 on the full resolution. Each stage uses CAB-based
    encoder/decoder pairs; later stages fuse the previous stage's
    (upsampled) encoder/decoder features via CSFF, and CGAM modules pass
    gradient-guided attention features between stages.

    forward() returns [full-res restored image, stage-2 image, stage-1 image].
    """

    def __init__(
        self,
        in_c: int = 3,
        out_c: int = 3,
        n_feat: int = 40,
        scale_unetfeats: int = 20,
        num_cab: int = 8,  # NOTE(review): unused in this class body — confirm before removing
        kernel_size: int = 3,
        reduction: int = 4,
        bias: bool = False,
    ):
        super(Model, self).__init__()

        act = nn.PReLU()
        # Shallow feature extraction (conv + channel-attention block), one per stage.
        self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))

        # Cross Stage Feature Fusion (CSFF)
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        # Stage 2 makes two encoder/decoder passes, both fused with stage-1 features.
        self.stage2_encoder_1 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage2_decoder_1 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
        self.stage2_encoder_2 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage2_decoder_2 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        # Stage 3 makes three encoder/decoder passes, fused with stage-2 features.
        self.stage3_encoder_1 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage3_decoder_1 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
        self.stage3_encoder_2 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage3_decoder_2 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
        self.stage3_encoder_3 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage3_decoder_3 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        # Gradient-guided attention bridging consecutive stages.
        self.CGAM12 = CGAM(n_feat, kernel_size=1, bias=bias)
        self.CGAM23 = CGAM(n_feat, kernel_size=1, bias=bias)

        # Fuse shallow features with the previous stage's CGAM features.
        self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
        self.concat23 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
        self.tail = conv(n_feat, out_c, kernel_size, bias=bias)

    def forward(self, x3_img: torch.Tensor):
        # x3_img: full-resolution rainy input (B, in_c, H, W). H and W are
        # presumably divisible by 4 so the 1/4- and 1/2-scale passes realign
        # after upsampling — the caller is expected to pad accordingly.
        # Stage 1: 1/4 resolution
        x1_img = F.interpolate(x3_img, scale_factor=0.25, mode="bilinear", align_corners=False)
        x1 = self.shallow_feat1(x1_img)
        feat1 = self.stage1_encoder(x1)
        res1 = self.stage1_decoder(feat1)

        # CGAM yields attention features for stage 2 plus the stage-1 restored image.
        x2_CGAMfeats, stage1_img = self.CGAM12(res1[0], x1_img)

        # Upsample stage-1 features to match stage-2 resolution.
        feat1 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in feat1]
        res1 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in res1]
        x2_CGAMfeats = F.interpolate(x2_CGAMfeats, scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 2: 1/2 resolution
        x2_img = F.interpolate(x3_img, scale_factor=0.5, mode="bilinear", align_corners=False)
        x2 = self.shallow_feat2(x2_img)
        x2_cat = self.concat12(torch.cat([x2, x2_CGAMfeats], 1))

        feat2_1 = self.stage2_encoder_1(x2_cat, feat1, res1)
        res2 = self.stage2_decoder_1(feat2_1)
        feat2_2 = self.stage2_encoder_2(res2[0], feat1, res1)
        res2 = self.stage2_decoder_2(feat2_2)

        x3_CGAMfeats, stage2_img = self.CGAM23(res2[0], x2_img)

        # Upsample stage-2 features (taken from the second pass) to full resolution.
        feat2 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in feat2_2]
        res2 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in res2]
        x3_CGAMfeats = F.interpolate(x3_CGAMfeats, scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 3: full resolution
        x3 = self.shallow_feat3(x3_img)
        x3_cat = self.concat23(torch.cat([x3, x3_CGAMfeats], 1))

        feat3_1 = self.stage3_encoder_1(x3_cat, feat2, res2)
        res3 = self.stage3_decoder_1(feat3_1)
        feat3_2 = self.stage3_encoder_2(res3[0], feat2, res2)
        res3 = self.stage3_decoder_2(feat3_2)
        feat3_3 = self.stage3_encoder_3(res3[0], feat2, res2)
        res3 = self.stage3_decoder_3(feat3_3)

        # Final residual prediction added back onto the rainy input.
        stage3_img = self.tail(res3[0])

        return [stage3_img + x3_img, stage2_img, stage1_img]
93
+
94
+
models/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+
4
def find_models_def(model_name: str):
    """Return the ``Model`` class defined in ``models.<model_name>``.

    Raises ModuleNotFoundError when no such module exists, and
    AttributeError when the module does not define ``Model``.
    """
    return getattr(importlib.import_module(f"models.{model_name}"), "Model")
8
+
9
+
models/__pycache__/CGFWMSRNet.cpython-311.pyc ADDED
Binary file (7.2 kB). View file
 
models/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (539 Bytes). View file
 
models/__pycache__/modules.cpython-311.pyc ADDED
Binary file (15.4 kB). View file
 
models/modules.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2-D convolution with 'same'-style padding (kernel_size // 2)."""
    same_padding = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=same_padding, bias=bias, stride=stride)
15
+
16
+
17
class CALayer(nn.Module):
    """Channel attention: squeeze (global average pool), excite (bottleneck
    1x1 convs with sigmoid), then rescale the input channels."""

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        # Attribute names preserved for checkpoint (state_dict) compatibility.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Per-channel gate in (0, 1), broadcast over the spatial dims.
        gate = self.conv_du(self.avg_pool(x))
        return x * gate
32
+
33
+
34
class CAB(nn.Module):
    """Channel Attention Block: conv-act-conv, channel attention, residual skip."""

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        # Attribute names preserved for checkpoint (state_dict) compatibility.
        self.body = nn.Sequential(
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        )
        self.CA = CALayer(n_feat, reduction, bias=bias)

    def forward(self, x):
        # Attended residual branch plus identity skip.
        attended = self.CA(self.body(x))
        return attended + x
49
+
50
+
51
class CG(nn.Module):
    """Per-channel Sobel gradient-magnitude filter.

    Improvements over the original:
    - Kernels are registered as non-persistent buffers, so ``module.to(device)``
      moves them with the model (previously they were plain tensor attributes
      and only the per-forward ``.to`` kept things working). ``persistent=False``
      keeps them out of the state_dict, so checkpoint loading is unaffected.
    - The per-channel Python loop is replaced by one grouped conv per
      direction — mathematically identical, one kernel launch instead of C.
    """

    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).unsqueeze(0).unsqueeze(0)
        sobel_y = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]).unsqueeze(0).unsqueeze(0)
        self.register_buffer("sobel_x", sobel_x, persistent=False)
        self.register_buffer("sobel_y", sobel_y, persistent=False)

    def forward(self, image):
        """Return the gradient magnitude of ``image``; same (B, C, H, W) shape."""
        channels = image.shape[1]
        # Apply the same 3x3 kernel to every channel independently (groups=C).
        kx = self.sobel_x.to(image.device).expand(channels, 1, 3, 3)
        ky = self.sobel_y.to(image.device).expand(channels, 1, 3, 3)
        grad_x = F.conv2d(image, kx, padding=1, groups=channels)
        grad_y = F.conv2d(image, ky, padding=1, groups=channels)
        # Clamp keeps sqrt finite/differentiable in perfectly flat regions.
        return torch.sqrt(torch.clamp(grad_x**2 + grad_y**2, min=1e-6))
69
+
70
+
71
class CGAM(nn.Module):
    """Gradient-guided attention module.

    Predicts this stage's restored image and uses the gradient difference
    between the rainy input and that prediction as a spatial attention cue
    on the feature map. Returns (attended features, restored image).
    """

    def __init__(self, n_feat, kernel_size, bias):
        super().__init__()
        # Attribute names preserved for checkpoint (state_dict) compatibility.
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)
        self.gradfilter = CG()

    def forward(self, x, x_img):
        feats = self.conv1(x)
        # Residual image prediction for this stage.
        img = self.conv2(x) + x_img
        # Gradient residual between the rainy input and the predicted clean image.
        grad_residual = torch.abs(self.gradfilter(x_img) - self.gradfilter(img))
        grad_feats = self.conv3(grad_residual)
        attention = torch.sigmoid(feats * grad_feats)
        return feats + feats * attention, img
88
+
89
+
90
class DownSample(nn.Module):
    """Halve spatial resolution (bilinear) and widen channels by ``s_factor``
    with a 1x1 convolution."""

    def __init__(self, in_channels, s_factor):
        super().__init__()
        self.down = nn.Sequential(
            nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False),
            nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False),
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, C + s_factor, H/2, W/2)
        return self.down(x)
100
+
101
+
102
class SkipUpSample(nn.Module):
    """Double spatial resolution (bilinear), narrow channels back to
    ``in_channels`` with a 1x1 conv, then add the skip connection."""

    def __init__(self, in_channels, s_factor):
        super().__init__()
        # Attribute name preserved for checkpoint (state_dict) compatibility.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False),
        )

    def forward(self, x, y):
        # Upsample the coarse features, then fuse with the skip tensor y.
        return self.up(x) + y
113
+
114
+
115
class Encoder(nn.Module):
    """Three-level U-Net encoder built from CAB blocks.

    Channel width grows by ``scale_unetfeats`` at each deeper level.
    With ``csff=True``, 1x1 convolutions are created that fuse a previous
    stage's encoder/decoder features into each level (Cross-Stage Feature
    Fusion); forward() applies them only when both feature lists are passed.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super().__init__()
        # Two CAB blocks per level; channels: n_feat, +s, +2s.
        self.encoder_level1 = nn.Sequential(*[CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)])
        self.encoder_level2 = nn.Sequential(*[CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)])
        self.encoder_level3 = nn.Sequential(*[CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)])

        # Downsampling between levels (halves resolution, adds channels).
        self.down12 = DownSample(n_feat, scale_unetfeats)
        self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats)

        # CSFF projections exist only when csff=True, so forward() must not be
        # given encoder_outs/decoder_outs unless built with csff=True.
        if csff:
            self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_enc2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_enc3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, bias=bias)

            self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_dec2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_dec3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        """Return [enc1, enc2, enc3], finest to coarsest level."""
        enc1 = self.encoder_level1(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])

        x = self.down12(enc1)
        enc2 = self.encoder_level2(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])

        x = self.down23(enc2)
        enc3 = self.encoder_level3(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])

        return [enc1, enc2, enc3]
+ return [enc1, enc2, enc3]
150
+
151
+
152
class Decoder(nn.Module):
    """Three-level U-Net decoder mirroring ``Encoder``.

    Skip connections from the two finer encoder levels pass through CAB
    attention blocks before being fused by the skip-upsample stages.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super().__init__()
        # Attribute names preserved for checkpoint (state_dict) compatibility.
        self.decoder_level1 = nn.Sequential(*[CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)])
        self.decoder_level2 = nn.Sequential(*[CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)])
        self.decoder_level3 = nn.Sequential(*[CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)])

        self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)

        self.up21 = SkipUpSample(n_feat, scale_unetfeats)
        self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats)

    def forward(self, outs):
        """Decode [enc1, enc2, enc3] (fine -> coarse) into [dec1, dec2, dec3]."""
        enc1, enc2, enc3 = outs

        dec3 = self.decoder_level3(enc3)
        dec2 = self.decoder_level2(self.up32(dec3, self.skip_attn2(enc2)))
        dec1 = self.decoder_level1(self.up21(dec2, self.skip_attn1(enc1)))

        return [dec1, dec2, dec3]
+ return [dec1, dec2, dec3]
176
+
177
+
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio>=4.0,<6.0
2
+ numpy
3
+ torch
4
+ scikit-image
5
+ huggingface_hub
6
+