# llir / train_restormer.py
# linxin02's picture
# Upload portable Low_light_rainy_new code export
# 4336727 verified
import argparse
import csv
import os
import random
import sys
import time
from pathlib import Path
import numpy as np
from PIL import Image
from tqdm import tqdm
from config import Config
def parse_args():
    """Declare the training command-line options and parse ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option below (the flag name
        without the leading dashes).
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        # Config / model / output location.
        ("--config", dict(type=str, default="training.yml", help="Path to config yaml.")),
        (
            "--model",
            dict(
                type=str,
                default="restormer",
                choices=["restormer", "fgp_restormer", "fgp_restormer_v2"],
                help="Model variant.",
            ),
        ),
        ("--session", dict(type=str, default="Restormer", help="Checkpoint session name.")),
        ("--save_dir", dict(type=str, default="./checkpoint_restormer", help="Checkpoint root.")),
        # Schedule overrides and loader settings.
        ("--epochs", dict(type=int, default=None, help="Override OPTIM.NUM_EPOCHS.")),
        ("--batch_size", dict(type=int, default=None, help="Override OPTIM.BATCH_SIZE.")),
        ("--num_workers", dict(type=int, default=8, help="Dataloader workers per process.")),
        ("--val_every", dict(type=int, default=1, help="Validate every N epochs.")),
        # Data folders.
        (
            "--train_inp",
            dict(type=str, default="./dataset/train/syn+real/input", help="Training low-light rainy input folder."),
        ),
        (
            "--train_tar",
            dict(type=str, default="./dataset/train/syn+real/target", help="Training clean target folder."),
        ),
        ("--test_inp", dict(type=str, default="./dataset/test/input", help="Test input folder for epoch validation.")),
        ("--test_tar", dict(type=str, default="./dataset/test/target", help="Test target folder for epoch validation.")),
        # Validation behaviour.
        (
            "--val_max_images",
            dict(type=int, default=0, help="Limit validation images; <=0 means full test set."),
        ),
        (
            "--val_pad_factor",
            dict(type=int, default=8, help="Pad validation inputs to this factor before inference."),
        ),
        ("--tta", dict(action="store_true", help="Use flip/rotation self-ensemble during validation.")),
        # EMA and loss.
        (
            "--use_ema",
            dict(action="store_true", help="Maintain EMA weights and validate/save with EMA model."),
        ),
        ("--ema_decay", dict(type=float, default=0.999, help="EMA decay.")),
        (
            "--loss_mode",
            dict(
                type=str,
                default="l1",
                choices=["l1", "charbonnier", "charbonnier_edge", "charbonnier_edge_fft"],
                help="Training objective.",
            ),
        ),
        ("--edge_weight", dict(type=float, default=0.05, help="Edge loss weight.")),
        ("--fft_weight", dict(type=float, default=0.01, help="FFT amplitude loss weight.")),
        # Checkpointing / logging.
        ("--resume", dict(type=str, default="", help="Path to a Restormer checkpoint.")),
        (
            "--save_every",
            dict(type=int, default=5, help="Save epoch checkpoint every N epochs; model_best is always saved."),
        ),
        (
            "--metrics_file",
            dict(type=str, default="metrics.csv", help="Per-epoch metrics CSV filename under model_dir."),
        ),
        # Options only consumed by the fgp_restormer_v2 variant.
        (
            "--use_rain_structure_prior",
            dict(action="store_true", help="V2: use 7-channel rain/structure disentanglement prior."),
        ),
        (
            "--use_adaptive_temperature",
            dict(action="store_true", help="V2: use learnable layer-wise routing temperature."),
        ),
        (
            "--use_conflict_aware_routing",
            dict(action="store_true", help="V2: enable conflict-aware expert routing."),
        ),
        (
            "--expert_decorr_weight",
            dict(type=float, default=0.0, help="V2: weight for expert decorrelation auxiliary loss."),
        ),
        # Debugging.
        ("--max_train_steps", dict(type=int, default=0, help="Debug only: stop each epoch after N steps.")),
    )

    cli = argparse.ArgumentParser(description="Train the original Restormer model.")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
# Parse the command line and load the config at import time; later definitions
# in this file (e.g. build_model) read the module-level ``args``.
args = parse_args()
opt = Config(args.config)

# Expose only the GPUs listed in the config. These variables are set here, before
# ``torch`` is imported further down in this file.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, opt.GPU))

# Several GPUs configured but no RANK / LOCAL_RANK in the environment: re-run the
# same command line through ``torch.distributed.run`` with one process per GPU and
# exit with that launcher's return code.
if not ("RANK" in os.environ or "LOCAL_RANK" in os.environ) and len(opt.GPU) > 1:
    import subprocess

    env = os.environ.copy()
    cmd = [
        sys.executable,
        "-m",
        "torch.distributed.run",
        "--nproc_per_node",
        str(len(opt.GPU)),
        sys.argv[0],
    ]
    cmd = cmd + sys.argv[1:]
    sys.exit(subprocess.run(cmd, env=env).returncode)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import functional as TF
from restormer import Restormer
from utils.inference_utils import batch_rgb_psnr, run_model
# Lower-case file suffixes (with the leading dot) accepted as images when the
# dataset classes list their folders.
IMAGE_EXTS = {".jpeg", ".jpg", ".png", ".bmp", ".tif", ".tiff", ".gif"}


class PairedImageDataset(torch.utils.data.Dataset):
    """Whole-image (input, target) pairs matched by file name.

    Both folders are listed once at construction time; only regular files whose
    suffix is in ``IMAGE_EXTS`` are considered. Every input image must have a
    target with the same file name; extra files in the target folder are ignored.

    Args:
        input_dir: folder holding the degraded input images.
        target_dir: folder holding the reference images.

    Raises:
        FileNotFoundError: if either folder does not exist.
        RuntimeError: if an input image has no matching target, or if no pair
            is found at all.
    """

    def __init__(self, input_dir, target_dir):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        if not self.input_dir.exists():
            raise FileNotFoundError(f"Test input folder not found: {self.input_dir}")
        if not self.target_dir.exists():
            raise FileNotFoundError(f"Test target folder not found: {self.target_dir}")

        input_names = sorted(
            p.name for p in self.input_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTS
        )
        target_names = {
            p.name for p in self.target_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTS
        }
        # Sorted input names that also exist in the target folder.
        self.names = [name for name in input_names if name in target_names]

        missing_targets = sorted(set(input_names) - target_names)
        if missing_targets:
            raise RuntimeError(
                f"{len(missing_targets)} test input images have no matching target, e.g. {missing_targets[:5]}"
            )
        if not self.names:
            raise RuntimeError(f"No paired test images found in {self.input_dir} and {self.target_dir}")

    def __len__(self):
        return len(self.names)

    @staticmethod
    def _load_rgb(path):
        """Decode ``path`` as an RGB image and close the file before returning."""
        # The context manager releases the file handle deterministically; the
        # converted image returned here is a separate in-memory copy.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        """Return ``(target_tensor, input_tensor, file_name)`` — target first."""
        name = self.names[index]
        input_img = self._load_rgb(self.input_dir / name)
        target_img = self._load_rgb(self.target_dir / name)
        return TF.to_tensor(target_img), TF.to_tensor(input_img), name
class PairedPatchDataset(torch.utils.data.Dataset):
    """Random aligned training patches cut from (input, target) image pairs.

    Pairs are matched by file name exactly as the folders are listed at
    construction time. Each item is a ``patch_size`` x ``patch_size`` crop taken
    at the same location from both images, followed by one randomly chosen
    flip/rotation applied identically to both.

    Args:
        input_dir: folder holding the degraded input images.
        target_dir: folder holding the reference images.
        patch_size: side length of the square crop, in pixels.

    Raises:
        FileNotFoundError: if either folder does not exist.
        RuntimeError: if an input image has no matching target, or if no pair
            is found at all.
    """

    def __init__(self, input_dir, target_dir, patch_size):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        self.patch_size = patch_size
        if not self.input_dir.exists():
            raise FileNotFoundError(f"Training input folder not found: {self.input_dir}")
        if not self.target_dir.exists():
            raise FileNotFoundError(f"Training target folder not found: {self.target_dir}")

        input_names = sorted(
            p.name for p in self.input_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTS
        )
        target_names = {
            p.name for p in self.target_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTS
        }
        self.names = [name for name in input_names if name in target_names]

        missing_targets = sorted(set(input_names) - target_names)
        if missing_targets:
            raise RuntimeError(
                f"{len(missing_targets)} training input images have no matching target, e.g. {missing_targets[:5]}"
            )
        if not self.names:
            raise RuntimeError(f"No paired training images found in {self.input_dir} and {self.target_dir}")

    def __len__(self):
        return len(self.names)

    @staticmethod
    def _load_rgb(path):
        """Decode ``path`` as an RGB image and close the file before returning."""
        # The context manager releases the file handle deterministically; the
        # converted image returned here is a separate in-memory copy.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _augment(img, mode):
        """Apply flip/rotation number ``mode`` (0-7) to a (C, H, W) tensor.

        Mode 0 returns the tensor unchanged.
        """
        if mode == 1:
            return img.flip(1)
        if mode == 2:
            return img.flip(2)
        if mode == 3:
            return torch.rot90(img, dims=(1, 2))
        if mode == 4:
            return torch.rot90(img, k=2, dims=(1, 2))
        if mode == 5:
            return torch.rot90(img, k=3, dims=(1, 2))
        if mode == 6:
            return torch.rot90(img.flip(1), dims=(1, 2))
        if mode == 7:
            return torch.rot90(img.flip(2), dims=(1, 2))
        return img

    def __getitem__(self, index):
        """Return ``(target_patch, input_patch, file_name)`` — target first."""
        name = self.names[index]
        input_img = self._load_rgb(self.input_dir / name)
        target_img = self._load_rgb(self.target_dir / name)

        ps = self.patch_size
        # PIL reports size as (width, height). The padding and crop geometry are
        # derived from the target image only.
        # NOTE(review): this assumes input and target have identical dimensions — confirm for the dataset.
        w, h = target_img.size
        pad_w = max(ps - w, 0)
        pad_h = max(ps - h, 0)
        if pad_w > 0 or pad_h > 0:
            # Images smaller than the patch are grown on the right/bottom edges
            # (padding order: left, top, right, bottom) so a full crop fits.
            padding = (0, 0, pad_w, pad_h)
            input_img = TF.pad(input_img, padding, padding_mode="reflect")
            target_img = TF.pad(target_img, padding, padding_mode="reflect")

        input_img = TF.to_tensor(input_img)
        target_img = TF.to_tensor(target_img)

        # Same random window for both tensors keeps the pair aligned.
        _, h, w = target_img.shape
        top = random.randint(0, h - ps)
        left = random.randint(0, w - ps)
        input_img = input_img[:, top : top + ps, left : left + ps]
        target_img = target_img[:, top : top + ps, left : left + ps]

        # One augmentation mode is drawn per sample and applied to both tensors.
        aug = random.randint(0, 7)
        input_img = self._augment(input_img, aug)
        target_img = self._augment(target_img, aug)
        return target_img, input_img, name
def seed_everything(seed=1234):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices) with ``seed``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def unwrap_model(model):
    """Return ``model.module`` if the object has a ``module`` attribute, otherwise ``model``.

    Used in this file to reach the underlying network when ``model`` has been
    wrapped by DistributedDataParallel / DataParallel.
    """
    return getattr(model, "module", model)
def build_model(model_name):
    """Instantiate the network selected by ``model_name``.

    The two ``fgp_*`` variants are imported lazily, only when requested. The v2
    variant is configured from the module-level ``args`` switches.

    Raises:
        ValueError: if ``model_name`` is not one of the supported variants.
    """
    if model_name == "restormer":
        return Restormer()

    if model_name == "fgp_restormer":
        from net.restormer_lowlight_rain import RestormerLowLightRain

        return RestormerLowLightRain()

    if model_name == "fgp_restormer_v2":
        from net.restormer_lowlight_rain_v2 import RestormerLowLightRainV2

        v2_switches = {
            "use_rain_structure_prior": args.use_rain_structure_prior,
            "use_adaptive_temperature": args.use_adaptive_temperature,
            "use_conflict_aware_routing": args.use_conflict_aware_routing,
        }
        return RestormerLowLightRainV2(**v2_switches)

    raise ValueError(f"Unsupported model: {model_name}")
def unpack_model_output(output):
    """Normalise a model's return value to a tuple.

    A tuple is passed through untouched; anything else is returned as
    ``(output, None)``. The training loop in this file unpacks the result as
    ``output, aux``.
    """
    return output if isinstance(output, tuple) else (output, None)
def expert_decorrelation_loss(aux):
    """Average the ``"decorrelation"`` terms found under ``aux["fusion"]``.

    Args:
        aux: mapping with a ``"fusion"`` iterable of per-entry mappings, or a
            falsy value.

    Returns:
        The mean of the stacked ``"decorrelation"`` tensors, or ``None`` when
        ``aux`` is empty, has no ``"fusion"`` key, or no entry carries a
        ``"decorrelation"`` value.
    """
    if not aux or "fusion" not in aux:
        return None

    terms = []
    for entry in aux["fusion"]:
        # Falsy entries (e.g. None or {}) and entries without the key are skipped.
        if entry and "decorrelation" in entry:
            terms.append(entry["decorrelation"])

    return torch.stack(terms).mean() if terms else None
def load_restormer_checkpoint(model, optimizer, scheduler, ckpt_path, device):
    """Restore weights (and optionally optimizer/scheduler state) from a checkpoint.

    Weight source priority: ``"state_dict_ema"``, then ``"state_dict_G1"``, then
    ``"state_dict"``, and finally the loaded object itself. A leading
    ``"module."`` is stripped from parameter names before a strict load into the
    unwrapped model.

    Args:
        model: network to load into; may be wrapped (see ``unwrap_model``).
        optimizer: restored from ``"optimizer_G1"`` when not None and present.
        scheduler: restored from ``"scheduler_G1"`` when not None and present.
        ckpt_path: path passed to ``torch.load``.
        device: ``map_location`` for ``torch.load``.

    Returns:
        int: the stored ``"epoch"`` (0 if absent) plus one, i.e. the epoch to
        start from.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)

    # NOTE(review): EMA weights win even when optimizer state is restored too,
    # so a resumed run continues from the EMA weights — confirm this is intended.
    for weights_key in ("state_dict_ema", "state_dict_G1", "state_dict"):
        if weights_key in checkpoint:
            weights = checkpoint[weights_key]
            break
    else:
        weights = checkpoint

    prefix = "module."
    cleaned = {}
    for key, tensor in weights.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = tensor
    unwrap_model(model).load_state_dict(cleaned, strict=True)

    restorable = ((optimizer, "optimizer_G1"), (scheduler, "scheduler_G1"))
    for component, state_key in restorable:
        if component is not None and state_key in checkpoint:
            component.load_state_dict(checkpoint[state_key])

    last_epoch = int(checkpoint.get("epoch", 0))
    return last_epoch + 1
def append_metrics_csv(path, row):
    """Append one epoch's metrics to a CSV file, writing the header when needed.

    Args:
        path: CSV file location (str or Path); created if it does not exist.
        row: mapping whose keys are a subset of the column names below.
            ``csv.DictWriter`` raises ``ValueError`` for unknown keys.
    """
    fieldnames = [
        "epoch",
        "train_loss",
        "lr",
        "val_psnr",
        "best_psnr",
        "epoch_time",
        "saved_checkpoint",
        "is_best",
    ]
    path = Path(path)
    # A file that exists but is still empty (e.g. pre-created) also needs the
    # header, otherwise the CSV would start directly with a data row.
    need_header = not path.exists() or path.stat().st_size == 0
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if need_header:
            writer.writeheader()
        writer.writerow(row)
class ModelEma:
    """Exponential moving average of a model's state.

    ``self.module`` is a deep copy of the unwrapped model, put in eval mode with
    gradients disabled on its parameters. Call ``update`` after each optimizer
    step to blend the live weights into the copy.
    """

    def __init__(self, model, decay=0.999):
        import copy

        shadow = copy.deepcopy(unwrap_model(model))
        shadow.eval()
        self.module = shadow
        self.decay = decay
        for frozen in shadow.parameters():
            frozen.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        """Blend ``model``'s current state into the averaged copy, in place."""
        live_state = unwrap_model(model).state_dict()
        keep = self.decay
        for key, averaged in self.module.state_dict().items():
            if not averaged.dtype.is_floating_point:
                # Non-floating entries are copied verbatim instead of averaged.
                averaged.copy_(live_state[key])
                continue
            # averaged <- decay * averaged + (1 - decay) * live
            averaged.mul_(keep).add_(live_state[key].detach(), alpha=1.0 - keep)
class CharbonnierLoss(nn.Module):
    """Charbonnier penalty: mean of ``sqrt((pred - target)^2 + eps^2)``.

    Args:
        eps: constant added (squared) under the square root.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        residual = pred - target
        return torch.sqrt(residual ** 2 + self.eps ** 2).mean()
class RestorationLoss(nn.Module):
    """Training objective selected by ``mode``.

    * ``"l1"``: plain L1 loss.
    * any other mode: Charbonnier loss, plus
      ``edge_weight`` x L1 between Sobel edge maps when ``"edge"`` is in the
      mode name, plus ``fft_weight`` x L1 between log-amplitude spectra when
      ``"fft"`` is in the mode name.

    The Sobel kernels are registered as buffers with three depthwise filters,
    so the edge term expects 3-channel inputs.
    """

    def __init__(self, mode="l1", edge_weight=0.05, fft_weight=0.01):
        super().__init__()
        self.mode = mode
        self.edge_weight = edge_weight
        self.fft_weight = fft_weight
        self.base_l1 = nn.L1Loss()
        self.base_charb = CharbonnierLoss()

        sobel_x = torch.tensor(
            [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
            dtype=torch.float32,
        )
        sobel_y = torch.tensor(
            [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
            dtype=torch.float32,
        )
        # One copy of each 3x3 kernel per channel: shape (3, 1, 3, 3).
        self.register_buffer("kernel_x", sobel_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1))
        self.register_buffer("kernel_y", sobel_y.view(1, 1, 3, 3).repeat(3, 1, 1, 1))

    def _edge(self, x):
        """Per-channel gradient magnitude from the two Sobel responses."""
        grad_x = nn.functional.conv2d(x, self.kernel_x, padding=1, groups=3)
        grad_y = nn.functional.conv2d(x, self.kernel_y, padding=1, groups=3)
        # The small constant keeps the square root away from exactly zero.
        return torch.sqrt(grad_x * grad_x + grad_y * grad_y + 1e-6)

    def _fft_amp(self, x):
        """``log(1 + |rfft2(x)|)`` with orthonormal FFT scaling."""
        spectrum = torch.fft.rfft2(x, norm="ortho")
        return torch.log1p(torch.abs(spectrum))

    def forward(self, pred, target):
        if self.mode == "l1":
            return self.base_l1(pred, target)

        total = self.base_charb(pred, target)
        if "edge" in self.mode:
            edge_term = self.base_l1(self._edge(pred), self._edge(target))
            total = total + self.edge_weight * edge_term
        if "fft" in self.mode:
            fft_term = self.base_l1(self._fft_amp(pred), self._fft_amp(target))
            total = total + self.fft_weight * fft_term
        return total
@torch.no_grad()
def validate(model, val_loader, device, max_images, pad_factor=8, tta=False):
    """Return the image-weighted mean PSNR of ``model`` over ``val_loader``.

    Args:
        model: network to evaluate; switched to eval mode here.
        val_loader: yields batches whose item 0 is the target and item 1 the input.
        device: device the tensors are moved to.
        max_images: stop once at least this many images were scored; values
            <= 0 mean the whole loader is used.
        pad_factor: forwarded to ``run_model``.
        tta: forwarded to ``run_model``.

    Returns:
        Sum of per-batch PSNR x batch size divided by the number of images
        scored (0.0 when the loader yields nothing).
    """
    model.eval()
    weighted_psnr = 0.0
    seen = 0
    for batch in tqdm(val_loader, desc="Val", leave=False):
        target = batch[0].to(device, non_blocking=True)
        input_img = batch[1].to(device, non_blocking=True)
        # Outputs are clamped to [0, 1] before scoring.
        restored = run_model(model, input_img, pad_factor=pad_factor, tta=tta).clamp(0, 1)
        batch_images = restored.size(0)
        weighted_psnr += batch_rgb_psnr(restored, target) * batch_images
        seen += batch_images
        if 0 < max_images <= seen:
            break
    return weighted_psnr / max(seen, 1)
def _read_best_psnr(ckpt_path):
    """Return the ``"best_psnr"`` stored in a checkpoint file, or 0.0 when absent."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "best_psnr" in checkpoint:
        return float(checkpoint["best_psnr"])
    return 0.0


def main():
    """Train the selected model, validating and checkpointing on rank 0.

    Reads the module-level ``args`` (CLI) and ``opt`` (config). Runs under
    DistributedDataParallel when RANK / LOCAL_RANK are set in the environment.
    After every epoch rank 0 optionally validates, writes ``model_best.pth``
    when the validation PSNR improves, writes ``model_<epoch>.pth`` every
    ``--save_every`` epochs, and appends a row to the metrics CSV.

    The best PSNR so far is stored in every checkpoint and restored on
    ``--resume``, so a resumed run does not overwrite ``model_best.pth`` with a
    worse model. Checkpoints without that field resume with a best PSNR of 0.
    """
    torch.backends.cudnn.benchmark = True
    seed_everything()

    # --- process / device setup -------------------------------------------
    use_ddp = "RANK" in os.environ or "LOCAL_RANK" in os.environ
    if use_ddp:
        torch.distributed.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        device = torch.device("cuda", local_rank)
    else:
        local_rank = 0
        rank = 0
        world_size = 1
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CLI overrides win over the config when given (and non-zero).
    epochs = args.epochs or opt.OPTIM.NUM_EPOCHS
    batch_size = args.batch_size or opt.OPTIM.BATCH_SIZE
    model_dir = os.path.join(args.save_dir, "Deraining", "models", args.session)
    os.makedirs(model_dir, exist_ok=True)
    metrics_path = os.path.join(model_dir, args.metrics_file)

    # --- model, optimizer, loss -------------------------------------------
    model = build_model(args.model).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    elif torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=opt.OPTIM.LR_INITIAL, betas=(0.9, 0.999), eps=1e-8)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=opt.OPTIM.LR_MIN)
    criterion = RestorationLoss(args.loss_mode, edge_weight=args.edge_weight, fft_weight=args.fft_weight).to(device)
    # The EMA copy only exists on rank 0, the only rank that validates and saves.
    ema = ModelEma(model, decay=args.ema_decay) if args.use_ema and rank == 0 else None

    # --- optional resume ---------------------------------------------------
    start_epoch = 1
    best_psnr = 0.0
    if args.resume:
        start_epoch = load_restormer_checkpoint(model, optimizer, scheduler, args.resume, device)
        if ema is not None:
            ema.module.load_state_dict(unwrap_model(model).state_dict())
        if rank == 0:
            # Only rank 0 tracks the best score, so only it needs to read it back.
            best_psnr = _read_best_psnr(args.resume)

    # --- data ----------------------------------------------------------------
    train_dataset = PairedPatchDataset(args.train_inp, args.train_tar, opt.TRAINING.TRAIN_PS)
    val_dataset = PairedImageDataset(args.test_inp, args.test_tar)

    if use_ddp:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        train_workers = max(1, args.num_workers // world_size)
    else:
        train_sampler = None
        train_workers = args.num_workers
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        # Shuffling is delegated to the sampler when one is used.
        shuffle=train_sampler is None,
        sampler=train_sampler,
        num_workers=train_workers,
        pin_memory=True,
        drop_last=False,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=max(1, args.num_workers // 2),
        pin_memory=True,
    )

    if rank == 0:
        print(f"==> Model: {args.model}")
        print(f"==> Train on: {args.train_inp} -> {args.train_tar}")
        print(f"==> Train images: {len(train_dataset)}, Test images for val: {len(val_dataset)}")
        print(f"==> Val every epoch on: {args.test_inp} -> {args.test_tar}")
        print(f"==> Loss: {args.loss_mode}, EMA: {args.use_ema}, TTA: {args.tta}")
        print(f"==> Epochs: {epochs}, Batch size per process: {batch_size}, World size: {world_size}")
        print(f"==> Checkpoints: {model_dir}")
        print(f"==> Metrics CSV: {metrics_path}")
        print(f"==> Save checkpoint every {args.save_every} epoch(s)")

    # --- training loop -------------------------------------------------------
    for epoch in range(start_epoch, epochs + 1):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        model.train()
        epoch_loss = 0.0
        train_steps = 0
        epoch_start = time.time()
        progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", disable=(rank != 0))
        for step, data in enumerate(progress, start=1):
            train_steps = step
            target = data[0].to(device, non_blocking=True)
            input_img = data[1].to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            output, aux = unpack_model_output(model(input_img))
            loss = criterion(output, target)
            if args.expert_decorr_weight > 0:
                decorr = expert_decorrelation_loss(aux)
                if decorr is not None:
                    loss = loss + args.expert_decorr_weight * decorr
            loss.backward()
            optimizer.step()
            if ema is not None:
                ema.update(model)

            epoch_loss += loss.item()
            progress.set_postfix(loss=f"{epoch_loss / step:.4f}")
            if args.max_train_steps > 0 and step >= args.max_train_steps:
                break
        scheduler.step()

        if rank == 0:
            avg_loss = epoch_loss / max(train_steps, 1)
            cur_lr = scheduler.get_last_lr()[0]
            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch}: loss={avg_loss:.4f}, lr={cur_lr:.2e}, time={epoch_time:.1f}s")

            # Validate first so the checkpoint built below records the updated best PSNR.
            val_psnr = None
            is_best = False
            if args.val_every > 0 and epoch % args.val_every == 0:
                eval_model = ema.module if ema is not None else unwrap_model(model)
                # float() gives a plain Python number whatever numeric type validate() returns.
                val_psnr = float(
                    validate(
                        eval_model,
                        val_loader,
                        device,
                        args.val_max_images,
                        pad_factor=args.val_pad_factor,
                        tta=args.tta,
                    )
                )
                print(f"Val PSNR: {val_psnr:.4f}")
                if val_psnr > best_psnr:
                    best_psnr = val_psnr
                    is_best = True

            ckpt = {
                "epoch": epoch,
                "state_dict_G1": unwrap_model(model).state_dict(),
                "optimizer_G1": optimizer.state_dict(),
                "scheduler_G1": scheduler.state_dict(),
                "model": args.model,
                "loss_mode": args.loss_mode,
                "best_psnr": float(best_psnr),
            }
            if ema is not None:
                ckpt["state_dict_ema"] = ema.module.state_dict()

            if is_best:
                torch.save(ckpt, os.path.join(model_dir, "model_best.pth"))
                print(f"Saved best checkpoint: PSNR={best_psnr:.4f}")

            saved_checkpoint = False
            if args.save_every > 0 and epoch % args.save_every == 0:
                torch.save(ckpt, os.path.join(model_dir, f"model_{epoch}.pth"))
                saved_checkpoint = True
                print(f"Saved epoch checkpoint: model_{epoch}.pth")

            append_metrics_csv(
                metrics_path,
                {
                    "epoch": epoch,
                    "train_loss": f"{avg_loss:.6f}",
                    "lr": f"{cur_lr:.8e}",
                    "val_psnr": f"{val_psnr:.6f}" if val_psnr is not None else "",
                    "best_psnr": f"{best_psnr:.6f}",
                    "epoch_time": f"{epoch_time:.2f}",
                    "saved_checkpoint": int(saved_checkpoint),
                    "is_best": int(is_best),
                },
            )

        if use_ddp:
            # Other ranks wait here while rank 0 validates and writes files.
            torch.distributed.barrier()

    if use_ddp:
        # Release the process group once training has finished normally.
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()