Spaces:
Sleeping
Sleeping
Commit ·
69f6042
0
Parent(s):
Add Gradio deraining demo
Browse files- .gitattributes +1 -0
- README.md +30 -0
- __pycache__/app.cpython-311.pyc +0 -0
- app.py +144 -0
- checkpoints/model_best.pth +3 -0
- models/CGFWMSRNet.py +94 -0
- models/__init__.py +9 -0
- models/__pycache__/CGFWMSRNet.cpython-311.pyc +0 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/modules.cpython-311.pyc +0 -0
- models/modules.py +177 -0
- requirements.txt +6 -0
.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Deraining Demo
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## Deraining 데모 (Gradio)
|
| 13 |
+
|
| 14 |
+
### 사용 방법
|
| 15 |
+
- 이미지를 업로드하면 deraining 결과를 반환합니다.
|
| 16 |
+
|
| 17 |
+
### 가중치(Weights) 준비
|
| 18 |
+
이 Space는 기본적으로 `checkpoints/model_best.pth`를 찾습니다.
|
| 19 |
+
|
| 20 |
+
- **옵션 A (권장)**: Space repo에 `checkpoints/model_best.pth`를 **Git LFS**로 업로드
|
| 21 |
+
- **옵션 B**: Hub에서 다운로드
|
| 22 |
+
- Space 설정(Environment variables)에 아래를 추가
|
| 23 |
+
- `WEIGHTS_REPO`: 예) `your-username/your-deraining-weights`
|
| 24 |
+
- `WEIGHTS_FILENAME`: 예) `model_best.pth`
|
| 25 |
+
|
| 26 |
+
### 모델 선택
|
| 27 |
+
기본 모델은 `CGFWMSRNet` 입니다. 바꾸려면 환경변수로 설정하세요:
|
| 28 |
+
- `MODEL_NAME`: 예) `CGFWMSRNet`
|
| 29 |
+
|
| 30 |
+
|
__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (7.83 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from skimage import img_as_ubyte
|
| 7 |
+
|
| 8 |
+
from models import find_models_def
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
# Optional: download weights from the Hub when not bundled in the Space repo.
|
| 12 |
+
from huggingface_hub import hf_hub_download # type: ignore
|
| 13 |
+
except Exception: # pragma: no cover
|
| 14 |
+
hf_hub_download = None # type: ignore
|
| 15 |
+
|
| 16 |
+
# -------------------------
|
| 17 |
+
# Global (load once)
|
| 18 |
+
# -------------------------
|
| 19 |
+
MODEL = None
|
| 20 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 22 |
+
|
| 23 |
+
def _resolve_weights_path() -> str:
    """Locate the model checkpoint file and return its path.

    Resolution order:
    1. Weights bundled with the Space at ``./checkpoints/model_best.pth``.
    2. A Hub download driven by the ``WEIGHTS_REPO`` / ``WEIGHTS_FILENAME``
       environment variables (requires ``huggingface_hub`` to be importable).

    Raises:
        FileNotFoundError: when neither source is available.
    """
    bundled = os.path.join(BASE_DIR, "checkpoints", "model_best.pth")
    if os.path.exists(bundled):
        return bundled

    # Fall back to the Hub only when a repo is configured and the optional
    # huggingface_hub import succeeded at module load time.
    repo_id = os.getenv("WEIGHTS_REPO", "").strip()
    filename = os.getenv("WEIGHTS_FILENAME", "model_best.pth").strip()
    if repo_id and hf_hub_download is not None:
        return hf_hub_download(repo_id=repo_id, filename=filename)

    raise FileNotFoundError(
        "Weights not found. Put weights at 'checkpoints/model_best.pth' "
        "or set env WEIGHTS_REPO and WEIGHTS_FILENAME to download from the Hub."
    )
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _load_state_dict(model: torch.nn.Module, weights_path: str) -> None:
|
| 45 |
+
ckpt = torch.load(weights_path, map_location="cpu")
|
| 46 |
+
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
|
| 47 |
+
|
| 48 |
+
# Handle DataParallel 'module.' prefix
|
| 49 |
+
if isinstance(state_dict, dict) and any(k.startswith("module.") for k in state_dict.keys()):
|
| 50 |
+
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
model.load_state_dict(state_dict, strict=True)
|
| 54 |
+
except Exception as e:
|
| 55 |
+
raise RuntimeError(
|
| 56 |
+
f"Failed to load weights from '{weights_path}'. "
|
| 57 |
+
f"Check that MODEL_NAME matches the checkpoint architecture."
|
| 58 |
+
) from e
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_model():
    """Build, load, and cache the global restoration model (lazy, idempotent).

    The network is only published to the global ``MODEL`` after its weights
    are loaded and it has been moved to ``DEVICE``. Previously ``MODEL`` was
    assigned *before* weight loading, so a failure in
    ``_resolve_weights_path``/``_load_state_dict`` left a half-initialized
    model cached, and every later call silently reused a network with random
    weights instead of retrying.
    """
    global MODEL

    if MODEL is not None:
        return

    # Model selection (default: CGFWMSRNet)
    model_name = os.getenv("MODEL_NAME", "CGFWMSRNet").strip() or "CGFWMSRNet"
    net = find_models_def(model_name)()

    # Load checkpoint (local first, Hub fallback); raises on failure so the
    # next call retries instead of serving an unloaded model.
    _load_state_dict(net, _resolve_weights_path())

    net.to(DEVICE)
    net.eval()

    # Publish only after full initialization succeeded.
    MODEL = net
| 78 |
+
|
| 79 |
+
def infer_image(inp_img: np.ndarray):
    """Run deraining on a single image.

    Args:
        inp_img: image from the Gradio component — HxWx3 uint8, but HxWx4
            (RGBA) and HxW (grayscale) are normalized to RGB here.

    Returns:
        Restored HxWx3 uint8 image.

    Raises:
        gr.Error: when no image was provided.
    """
    load_model()

    if inp_img is None:
        raise gr.Error("이미지를 업로드해주세요.")

    # Normalize channel layout. The previous version crashed on 2-D
    # grayscale input (permute(2, 0, 1) on an HxW array); replicate it to
    # three channels instead. RGBA drops its alpha channel as before.
    if inp_img.ndim == 2:
        inp_img = np.stack([inp_img] * 3, axis=-1)
    elif inp_img.ndim == 3 and inp_img.shape[2] == 4:
        inp_img = inp_img[:, :, :3]

    # uint8 -> float32 in [0, 1]
    img = inp_img.astype(np.float32) / 255.0

    # HWC -> CHW, add batch dimension: (1, 3, H, W)
    img_t = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(DEVICE)

    # Reflect-pad right/bottom so H and W are multiples of 16 (the encoder
    # downsamples twice, so spatial dims must divide evenly).
    height, width = img_t.shape[2], img_t.shape[3]
    img_multiple_of = 16
    padh = (img_multiple_of - height % img_multiple_of) % img_multiple_of
    padw = (img_multiple_of - width % img_multiple_of) % img_multiple_of
    img_t = F.pad(img_t, (0, padw, 0, padh), mode="reflect")

    with torch.no_grad():
        restored = MODEL(img_t)

    # The model may return a list/tuple of multi-stage outputs; the first
    # entry is the full-resolution restoration.
    if isinstance(restored, (list, tuple)):
        restored = restored[0]

    # Clamp to the valid range and crop the padding back off.
    restored = torch.clamp(restored, 0, 1)[:, :, :height, :width]

    # BCHW -> HWC uint8
    restored_np = restored.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    restored_img = img_as_ubyte(restored_np)

    # Release cached GPU memory between requests (useful on shared demo GPUs).
    if DEVICE.startswith("cuda"):
        torch.cuda.empty_cache()

    return restored_img
| 130 |
+
|
| 131 |
+
# Gradio UI: single image-in / image-out interface backed by infer_image.
# The model is loaded lazily on the first request, not at import time.
demo = gr.Interface(
    fn=infer_image,
    inputs=gr.Image(type="numpy", label="Input (Rainy)"),
    outputs=gr.Image(type="numpy", label="Restored"),
    title="Image Deraining Demo (Gradio + Hugging Face Spaces)",
    description=(
        "이미지를 업로드하면 비 제거 결과를 반환합니다. "
        "Space에 가중치(`checkpoints/model_best.pth`)가 포함되어 있거나, "
        "환경변수 WEIGHTS_REPO/WEIGHTS_FILENAME로 Hub에서 내려받을 수 있어야 합니다."
    ),
)

if __name__ == "__main__":
    # queue() enables Gradio's request queue before launching the server.
    demo.queue().launch()
checkpoints/model_best.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f912468bebbea184d41aa14cdc3678b11323a83ba264c3b7e7e7fb11c70abba
|
| 3 |
+
size 73604311
|
models/CGFWMSRNet.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
from models.modules import *
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Model(nn.Module):
    """CGFWMSRNet: three-stage coarse-to-fine deraining network.

    Processes the input at 1/4, 1/2, and full resolution. Each stage runs a
    U-Net style Encoder/Decoder; later stages receive the previous stage's
    encoder/decoder features via cross-stage feature fusion (CSFF), and CGAM
    modules pass gradient-attended features (plus an intermediate restored
    image) from stage to stage.

    Attribute names here are load-bearing: they must match the keys of the
    released checkpoint (checkpoints/model_best.pth) for strict loading.
    """

    def __init__(
        self,
        in_c: int = 3,
        out_c: int = 3,
        n_feat: int = 40,
        scale_unetfeats: int = 20,
        num_cab: int = 8,  # NOTE(review): unused in this implementation; kept for signature compatibility
        kernel_size: int = 3,
        reduction: int = 4,
        bias: bool = False,
    ):
        super(Model, self).__init__()

        act = nn.PReLU()
        # Per-stage shallow feature extractors: conv followed by one CAB.
        self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))
        self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat, kernel_size, reduction, bias=bias, act=act))

        # Cross Stage Feature Fusion (CSFF)
        # Stage 1 has no earlier stage to fuse from (csff=False); stages 2 and 3
        # each run two (stage 3: three) encoder/decoder passes with csff=True.
        self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False)
        self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        self.stage2_encoder_1 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage2_decoder_1 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
        self.stage2_encoder_2 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage2_decoder_2 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        self.stage3_encoder_1 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage3_decoder_1 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
        self.stage3_encoder_2 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage3_decoder_2 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)
        self.stage3_encoder_3 = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True)
        self.stage3_decoder_3 = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats)

        # Gradient-attention bridges between consecutive stages (1->2, 2->3).
        self.CGAM12 = CGAM(n_feat, kernel_size=1, bias=bias)
        self.CGAM23 = CGAM(n_feat, kernel_size=1, bias=bias)

        # 1x1-equivalent fusion convs merging shallow features with the
        # upsampled CGAM features from the previous stage.
        self.concat12 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
        self.concat23 = conv(n_feat * 2, n_feat, kernel_size, bias=bias)
        self.tail = conv(n_feat, out_c, kernel_size, bias=bias)

    def forward(self, x3_img: torch.Tensor):
        """Return [full-res restored image, stage-2 image, stage-1 image].

        x3_img is the full-resolution input; H and W are assumed divisible
        by 4 (the caller pads to a multiple of 16).
        """
        # Stage 1: 1/4 resolution
        x1_img = F.interpolate(x3_img, scale_factor=0.25, mode="bilinear", align_corners=False)
        x1 = self.shallow_feat1(x1_img)
        feat1 = self.stage1_encoder(x1)
        res1 = self.stage1_decoder(feat1)

        # CGAM emits attended features for the next stage plus a coarse restored image.
        x2_CGAMfeats, stage1_img = self.CGAM12(res1[0], x1_img)

        # Upsample all stage-1 features to stage-2 (1/2) resolution for CSFF.
        feat1 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in feat1]
        res1 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in res1]
        x2_CGAMfeats = F.interpolate(x2_CGAMfeats, scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 2: 1/2 resolution (two encoder/decoder refinement passes)
        x2_img = F.interpolate(x3_img, scale_factor=0.5, mode="bilinear", align_corners=False)
        x2 = self.shallow_feat2(x2_img)
        x2_cat = self.concat12(torch.cat([x2, x2_CGAMfeats], 1))

        feat2_1 = self.stage2_encoder_1(x2_cat, feat1, res1)
        res2 = self.stage2_decoder_1(feat2_1)
        feat2_2 = self.stage2_encoder_2(res2[0], feat1, res1)
        res2 = self.stage2_decoder_2(feat2_2)

        x3_CGAMfeats, stage2_img = self.CGAM23(res2[0], x2_img)

        # Upsample stage-2 features to full resolution (second pass features are used).
        feat2 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in feat2_2]
        res2 = [F.interpolate(f, scale_factor=2, mode="bilinear", align_corners=False) for f in res2]
        x3_CGAMfeats = F.interpolate(x3_CGAMfeats, scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 3: full resolution (three encoder/decoder refinement passes)
        x3 = self.shallow_feat3(x3_img)
        x3_cat = self.concat23(torch.cat([x3, x3_CGAMfeats], 1))

        feat3_1 = self.stage3_encoder_1(x3_cat, feat2, res2)
        res3 = self.stage3_decoder_1(feat3_1)
        feat3_2 = self.stage3_encoder_2(res3[0], feat2, res2)
        res3 = self.stage3_decoder_2(feat3_2)
        feat3_3 = self.stage3_encoder_3(res3[0], feat2, res2)
        res3 = self.stage3_decoder_3(feat3_3)

        # Final residual prediction added to the input (global skip connection).
        stage3_img = self.tail(res3[0])

        return [stage3_img + x3_img, stage2_img, stage1_img]
| 93 |
+
|
| 94 |
+
|
models/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def find_models_def(model_name: str):
    """Return the ``Model`` class defined in the ``models.<model_name>`` module.

    The submodule is imported dynamically, so adding a new architecture only
    requires dropping a file with a ``Model`` class into the ``models`` package.
    """
    module = importlib.import_module("models." + model_name)
    return module.Model
| 8 |
+
|
| 9 |
+
|
models/__pycache__/CGFWMSRNet.cpython-311.pyc
ADDED
|
Binary file (7.2 kB). View file
|
|
|
models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (539 Bytes). View file
|
|
|
models/__pycache__/modules.cpython-311.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
models/modules.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """2-D convolution with 'same'-style padding (kernel_size // 2) for odd kernels."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        bias=bias,
    )
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Global-average-pools each channel to a scalar descriptor, squeezes it
    through a bottleneck of width ``channel // reduction``, and rescales the
    input by the resulting per-channel sigmoid gate.
    """

    def __init__(self, channel, reduction=16, bias=False):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
            nn.Sigmoid(),
        )

    def forward(self, x):
        gate = self.conv_du(self.avg_pool(x))
        return x * gate
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CAB(nn.Module):
    """Channel Attention Block.

    conv -> activation -> conv, followed by a CALayer channel-attention
    rescale, with a residual connection back to the block input.
    """

    def __init__(self, n_feat, kernel_size, reduction, bias, act):
        super().__init__()
        self.body = nn.Sequential(
            conv(n_feat, n_feat, kernel_size, bias=bias),
            act,
            conv(n_feat, n_feat, kernel_size, bias=bias),
        )
        self.CA = CALayer(n_feat, reduction, bias=bias)

    def forward(self, x):
        out = self.CA(self.body(x))
        return out + x
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class CG(nn.Module):
    """Per-channel Sobel gradient-magnitude filter.

    Applies the fixed 3x3 Sobel x/y kernels to every channel independently
    and returns sqrt(gx^2 + gy^2) (clamped at 1e-6 before the sqrt for
    numerical stability), preserving the input's channel count.
    """

    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]).unsqueeze(0).unsqueeze(0)
        sobel_y = torch.tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]]).unsqueeze(0).unsqueeze(0)
        # Non-persistent buffers follow the module through .to(device)/.half()
        # but are excluded from the state_dict, so checkpoints saved with the
        # old plain-attribute version still load with strict=True. The old
        # code copied both tensors host->device on every forward call.
        self.register_buffer("sobel_x", sobel_x, persistent=False)
        self.register_buffer("sobel_y", sobel_y, persistent=False)

    def forward(self, image):
        c = image.shape[1]
        # Grouped convolution applies the same 3x3 kernel to each channel
        # independently in one fused call, replacing the original
        # per-channel Python loop (identical math, same channel order).
        kx = self.sobel_x.to(device=image.device, dtype=image.dtype).repeat(c, 1, 1, 1)
        ky = self.sobel_y.to(device=image.device, dtype=image.dtype).repeat(c, 1, 1, 1)
        grad_x = F.conv2d(image, kx, padding=1, groups=c)
        grad_y = F.conv2d(image, ky, padding=1, groups=c)
        return torch.sqrt(torch.clamp(grad_x**2 + grad_y**2, min=1e-6))
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class CGAM(nn.Module):
    """Gradient-attention module bridging two stages.

    Produces (a) stage features re-weighted by an attention map derived from
    the Sobel-gradient difference between the rainy input and an
    intermediate restored image, and (b) that restored image itself.
    """

    def __init__(self, n_feat, kernel_size, bias):
        super().__init__()
        self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias)
        self.conv2 = conv(n_feat, 3, kernel_size, bias=bias)
        self.conv3 = conv(3, n_feat, kernel_size, bias=bias)
        self.gradfilter = CG()

    def forward(self, x, x_img):
        feats = self.conv1(x)
        # Residual image prediction: features projected to RGB, added to input.
        restored = self.conv2(x) + x_img
        # Attention from the gradient gap between rainy and restored images.
        grad_gap = torch.abs(self.gradfilter(x_img) - self.gradfilter(restored))
        attn = torch.sigmoid(feats * self.conv3(grad_gap))
        feats = feats + feats * attn
        return feats, restored
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class DownSample(nn.Module):
    """Halve the spatial resolution and widen channels by ``s_factor``.

    Bilinear 0.5x resize followed by a bias-free 1x1 convolution mapping
    ``in_channels`` -> ``in_channels + s_factor``.
    """

    def __init__(self, in_channels, s_factor):
        super().__init__()
        halve = nn.Upsample(scale_factor=0.5, mode="bilinear", align_corners=False)
        widen = nn.Conv2d(in_channels, in_channels + s_factor, 1, stride=1, padding=0, bias=False)
        self.down = nn.Sequential(halve, widen)

    def forward(self, x):
        return self.down(x)
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class SkipUpSample(nn.Module):
    """Double the spatial resolution, narrow channels, and add a skip tensor.

    Bilinear 2x resize, a bias-free 1x1 convolution mapping
    ``in_channels + s_factor`` -> ``in_channels``, then element-wise addition
    of the skip connection ``y``.
    """

    def __init__(self, in_channels, s_factor):
        super().__init__()
        enlarge = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        narrow = nn.Conv2d(in_channels + s_factor, in_channels, 1, stride=1, padding=0, bias=False)
        self.up = nn.Sequential(enlarge, narrow)

    def forward(self, x, y):
        return self.up(x) + y
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Encoder(nn.Module):
    """Three-level U-Net encoder built from CAB pairs.

    Channel widths grow by ``scale_unetfeats`` per level. When ``csff`` is
    True, 1x1 convs are created so a previous stage's encoder/decoder
    outputs can be fused into each level (Cross Stage Feature Fusion).

    Note: the csff_* convolutions exist only when ``csff=True``; calling
    forward with encoder_outs/decoder_outs on a csff=False instance would
    raise AttributeError.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff):
        super().__init__()
        # Two CABs per level; width increases with depth.
        self.encoder_level1 = nn.Sequential(*[CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)])
        self.encoder_level2 = nn.Sequential(*[CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)])
        self.encoder_level3 = nn.Sequential(*[CAB(n_feat + (scale_unetfeats * 2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)])

        self.down12 = DownSample(n_feat, scale_unetfeats)
        self.down23 = DownSample(n_feat + scale_unetfeats, scale_unetfeats)

        if csff:
            # 1x1 projections for fusing the previous stage's encoder features ...
            self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_enc2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_enc3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, bias=bias)

            # ... and the previous stage's decoder features.
            self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)
            self.csff_dec2 = nn.Conv2d(n_feat + scale_unetfeats, n_feat + scale_unetfeats, kernel_size=1, bias=bias)
            self.csff_dec3 = nn.Conv2d(n_feat + (scale_unetfeats * 2), n_feat + (scale_unetfeats * 2), kernel_size=1, bias=bias)

    def forward(self, x, encoder_outs=None, decoder_outs=None):
        """Return per-level features [enc1, enc2, enc3] (finest to coarsest).

        encoder_outs/decoder_outs: optional 3-element lists from a previous
        stage; fusion happens only when BOTH are provided.
        """
        enc1 = self.encoder_level1(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0])

        x = self.down12(enc1)
        enc2 = self.encoder_level2(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1])

        x = self.down23(enc2)
        enc3 = self.encoder_level3(x)
        if (encoder_outs is not None) and (decoder_outs is not None):
            enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2])

        return [enc1, enc2, enc3]
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class Decoder(nn.Module):
    """Three-level U-Net decoder mirroring :class:`Encoder`.

    Refines the coarsest encoder features and walks back up, fusing each
    finer encoder level through a CAB skip-attention block and a
    SkipUpSample merge. Returns features for all three levels.
    """

    def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats):
        super().__init__()

        def cab_pair(width):
            # Two stacked CABs at the given channel width.
            return nn.Sequential(
                CAB(width, kernel_size, reduction, bias=bias, act=act),
                CAB(width, kernel_size, reduction, bias=bias, act=act),
            )

        self.decoder_level1 = cab_pair(n_feat)
        self.decoder_level2 = cab_pair(n_feat + scale_unetfeats)
        self.decoder_level3 = cab_pair(n_feat + (scale_unetfeats * 2))

        # Attention applied to encoder skip connections before merging.
        self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act)
        self.skip_attn2 = CAB(n_feat + scale_unetfeats, kernel_size, reduction, bias=bias, act=act)

        self.up21 = SkipUpSample(n_feat, scale_unetfeats)
        self.up32 = SkipUpSample(n_feat + scale_unetfeats, scale_unetfeats)

    def forward(self, outs):
        enc1, enc2, enc3 = outs
        deepest = self.decoder_level3(enc3)
        mid = self.decoder_level2(self.up32(deepest, self.skip_attn2(enc2)))
        finest = self.decoder_level1(self.up21(mid, self.skip_attn1(enc1)))
        return [finest, mid, deepest]
| 176 |
+
|
| 177 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.0,<6.0
|
| 2 |
+
numpy
|
| 3 |
+
torch
|
| 4 |
+
scikit-image
|
| 5 |
+
huggingface_hub
|
| 6 |
+
|