Restormer FGP 改进说明
本文档总结当前低光照去雨任务中,基于原始 Restormer baseline 所做的网络、训练、验证和日志保存改进。
1. 总体变化
原始版本:
Restormer
RGB input
-> Restormer Encoder-Decoder
-> RGB output
Loss: L1
Val: 每个 epoch 测试
Checkpoint: 每个 epoch 保存
当前改进版:
RestormerLowLightRain / FGP-Restormer
RGB input
-> illumination / structure / frequency prior extraction
-> multi-scale prior encoder
-> soft/hard mixture-of-frequency experts
-> global-local degradation router
-> Restormer Encoder-Decoder with multi-scale prior fusion
-> RGB output
Loss: Charbonnier + Edge + FFT
Val: 每 5 epoch 测试
Checkpoint: 每 5 epoch 保存,best 单独保存
Metrics: metrics.csv 逐 epoch 记录
对应核心文件:
net/restormer_lowlight_rain.py:网络结构改进。train_restormer.py:训练、loss、EMA、验证、日志保存。utils/inference_utils.py:验证阶段 padding、crop、PSNR、TTA。run_restormer_fgp.sh:当前 FGP 版本启动脚本。
2. 网络结构改进
2.1 从原始 Restormer 改到 FGP-Restormer
从:
Restormer()
改到:
RestormerLowLightRain()
训练入口中通过参数选择:
--model fgp_restormer
对应代码位置:
train_restormer.py
build_model()
原始 Restormer 只使用 RGB 图像本身作为输入特征,当前版本在 Restormer 主干之外增加了一个先验分支,用于显式建模低光照和雨纹退化。
3. 先验提取改进
3.1 从单纯 RGB 输入改到五类退化先验
从:
RGB input only
改到:
RGB input
-> luminance
-> darkness
-> high-frequency response
-> structure edge response
-> rain residual
对应模块:
net/restormer_lowlight_rain.py
PriorMapExtractor
具体五类先验:
| 先验 | 作用 |
|---|---|
luminance |
表示亮度分布,用于低光照恢复 |
darkness |
表示暗区域强度,引导曝光/照明增强 |
high |
Laplacian 高频响应,用于捕捉雨纹和细节 |
structure |
Sobel 边缘结构,用于保护物体轮廓 |
rain_residual |
luminance - low_frequency 的正残差,用于突出局部雨纹/亮 streak |
原始 Restormer 没有显式区分低频照明和高频雨纹;当前版本把低光和去雨拆成不同先验,更符合低光照去雨的退化特性。
4. 多尺度先验编码
4.1 从无先验分支改到 MultiPriorEncoder
从:
RGB -> patch_embed -> encoder
改到:
RGB -> PriorMapExtractor -> MultiPriorEncoder
-> prior_1 H x W
-> prior_2 H/2 x W/2
-> prior_3 H/4 x W/4
-> prior_4 H/8 x W/8
对应模块:
net/restormer_lowlight_rain.py
MultiPriorEncoder
这让先验可以在 Restormer 的 4 个 encoder scale 上逐层注入,而不是只在输入层简单拼接。
5. 频率专家 MoE 改进
5.1 从单一特征变换改到 Mixture-of-Frequency Experts
从:
main feature -> conv/attention -> output feature
改到:
main feature
-> low-frequency expert
-> high-frequency expert
-> structure expert
-> router selects / blends experts
-> expert feature
对应模块:
net/restormer_lowlight_rain.py
FrequencyExpertMixer
三个专家分别负责:
| Expert | 作用 |
|---|---|
| Low-frequency expert | 处理低光照、曝光、照明不均 |
| High-frequency expert | 处理雨纹、高频退化、细节恢复 |
| Structure expert | 处理边缘、轮廓、结构保持 |
5.2 从统一 soft 路由改到浅层 soft + 深层 hard 路由
从:
所有层使用同一种融合方式
改到:
浅层 encoder: soft expert routing
深层 encoder / latent: hard expert routing
对应代码:
RestormerLowLightRain.__init__()
fuse1 = FrequencyPriorFusion(..., hard_expert=False)
fuse2 = FrequencyPriorFusion(..., hard_expert=False)
fuse3 = FrequencyPriorFusion(..., hard_expert=True)
fuse4 = FrequencyPriorFusion(..., hard_expert=True)
设计动机:
- 浅层特征保留更多低级纹理和退化线索,适合 soft routing。
- 深层特征更接近语义和重建决策,适合 hard routing 强化专家分工。
这个设计参考了近期 all-in-one restoration / adverse weather restoration 中的 MoE routing 思路。
6. Global-Local Router 改进
6.1 从普通 concat fusion 改到全局-局部动态路由
从:
concat(main_feat, prior_feat) -> conv -> residual add
改到:
main_feat + prior_feat
-> global degradation router -> channel gate
-> local weather/texture router -> spatial mask
-> global-local gate
-> gated prior fusion
对应模块:
net/restormer_lowlight_rain.py
GlobalLocalRouter
FrequencyPriorFusion
其中:
global_router通过全局平均池化感知整图退化类型,例如整体低光、整体雨强。local_router通过卷积生成局部空间 mask,关注局部雨纹、边缘、暗区。
这样可以同时处理:
- 全局低光照问题。
- 局部雨纹/高频 streak。
- 结构边缘恢复。
7. 多尺度融合方式
7.1 从 Restormer 原始 encoder 改到 encoder 每层注入先验
从:
out_enc_level1 = encoder_level1(inp_enc_level1)
out_enc_level2 = encoder_level2(inp_enc_level2)
out_enc_level3 = encoder_level3(inp_enc_level3)
latent = latent(inp_enc_level4)
改到:
out_enc_level1 = encoder_level1(inp_enc_level1)
out_enc_level1 = fuse1(out_enc_level1, prior_1)
out_enc_level2 = encoder_level2(inp_enc_level2)
out_enc_level2 = fuse2(out_enc_level2, prior_2)
out_enc_level3 = encoder_level3(inp_enc_level3)
out_enc_level3 = fuse3(out_enc_level3, prior_3)
latent = latent(inp_enc_level4)
latent = fuse4(latent, prior_4)
这样主干仍然是 Restormer,但每个 scale 都受到低光/雨纹/结构先验引导。
8. 训练数据改进
8.1 从依赖旧 DataLoader 改到显式 paired dataset
从:
get_training_data(opt.TRAINING.TRAIN_DIR, ...)
改到:
PairedPatchDataset(
train_inp = ./dataset/train/syn+real/input,
train_tar = ./dataset/train/syn+real/target
)
对应代码:
train_restormer.py
PairedPatchDataset
当前训练集明确为:
/media/home/songmeixi_insta360.com/Low_light_rainy_new/dataset/train/syn+real/input
->
/media/home/songmeixi_insta360.com/Low_light_rainy_new/dataset/train/syn+real/target
不使用:
dataset/train/syn+real/target_smoke
这样可以避免 smoke 任务和低光照去雨任务混杂。
9. Loss 改进
9.1 从 L1 Loss 改到 Charbonnier + Edge + FFT Loss
从:
loss = L1(output, target)
改到:
loss = Charbonnier(output, target)
+ edge_weight * L1(edge(output), edge(target))
+ fft_weight * L1(FFT_amp(output), FFT_amp(target))
对应代码:
train_restormer.py
RestorationLoss
启动参数:
--loss_mode charbonnier_edge_fft
各项作用:
| Loss | 作用 |
|---|---|
| Charbonnier | 比 L1 更平滑,常用于图像复原 |
| Edge loss | 保持边缘和结构,减少去雨后的模糊 |
| FFT loss | 约束频率幅度,提升高频细节和雨纹去除一致性 |
10. EMA 改进
10.1 从直接验证当前权重改到 EMA 权重验证
从:
validate(model)
改到:
ema.update(model)
validate(ema_model)
对应代码:
train_restormer.py
ModelEma
启动参数:
--use_ema
EMA 可以降低训练震荡,让验证 PSNR 更稳定。
11. 验证协议改进
11.1 从直接整图推理改到 factor-8 padding + crop
从:
output = model(input)
psnr = PSNR(output, target)
改到:
padded_input = pad_to_factor(input, factor=8)
output = model(padded_input)
output = crop_to_original_size(output)
output = clamp(output, 0, 1)
psnr = RGB_PSNR(output, target)
对应代码:
utils/inference_utils.py
pad_to_factor
crop_to_size
run_model
batch_rgb_psnr
Restormer 有多次 downsample/upsample,factor-8 padding 可以避免尺寸不能被 8 整除时的潜在问题。
11.2 从每个 epoch 测试改到每 5 epoch 测试
从:
--val_every 1
改到:
--val_every 5
这样减少测试时间开销,适合长时间训练。
12. Checkpoint 保存改进
12.1 从每个 epoch 保存改到每 5 epoch 保存
从:
model_1.pth
model_2.pth
model_3.pth
...
改到:
model_5.pth
model_10.pth
model_15.pth
...
model_best.pth
对应参数:
--save_every 5
说明:
model_{epoch}.pth:每 5 个 epoch 保存一次。model_best.pth:只要验证 PSNR 创新高就保存。
13. 指标记录改进
13.1 从只打印终端改到保存 metrics.csv
从:
terminal only:
Epoch loss / Val PSNR
改到:
checkpoint_restormer_fgp/Deraining/models/RestormerFGP/metrics.csv
字段:
epoch,train_loss,lr,val_psnr,best_psnr,epoch_time,saved_checkpoint,is_best
对应代码:
train_restormer.py
append_metrics_csv
说明:
- 每个 epoch 都会写一行。
- 非测试 epoch 的
val_psnr为空。 - 保存普通 checkpoint 的 epoch,
saved_checkpoint=1。 - PSNR 刷新 best 的 epoch,
is_best=1。
14. 当前启动方式
当前使用:
bash run_restormer_fgp.sh
脚本内容等价于:
python train_restormer.py \
--config training.yml \
--model fgp_restormer \
--session RestormerFGP \
--save_dir ./checkpoint_restormer_fgp \
--loss_mode charbonnier_edge_fft \
--use_ema \
--val_every 5 \
--val_pad_factor 8 \
--save_every 5 \
--metrics_file metrics.csv \
--resume ./checkpoint_restormer_fgp/Deraining/models/RestormerFGP/model_5.pth
15. 当前版本的一句话概括
当前版本从原始 Restormer 改成了一个面向低光照去雨的结构感知频率专家 Restormer:
Plain Restormer
-> Structure-aware Mixture-of-Frequency Prior Restormer
-> Low-light / rain / structure prior guided multi-scale restoration
核心 novelty 可以概括为:
Structure-aware Mixture-of-Frequency Prior Guidance
for Low-Light Rainy Image Restoration