| import argparse |
| import csv |
| import os |
| import random |
| import sys |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| from PIL import Image |
| from tqdm import tqdm |
|
|
| from config import Config |
|
|
|
|
def parse_args(argv=None):
    """Build the training command-line interface and parse it.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` callers behave exactly as before.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Train the original Restormer model.")

    # Run setup: config file, model variant, checkpoint location, schedule overrides.
    parser.add_argument("--config", type=str, default="training.yml", help="Path to config yaml.")
    parser.add_argument("--model", type=str, default="restormer", choices=["restormer", "fgp_restormer", "fgp_restormer_v2"], help="Model variant.")
    parser.add_argument("--session", type=str, default="Restormer", help="Checkpoint session name.")
    parser.add_argument("--save_dir", type=str, default="./checkpoint_restormer", help="Checkpoint root.")
    parser.add_argument("--epochs", type=int, default=None, help="Override OPTIM.NUM_EPOCHS.")
    parser.add_argument("--batch_size", type=int, default=None, help="Override OPTIM.BATCH_SIZE.")
    parser.add_argument("--num_workers", type=int, default=8, help="Dataloader workers per process.")
    parser.add_argument("--val_every", type=int, default=1, help="Validate every N epochs.")

    # Dataset folders.
    parser.add_argument("--train_inp", type=str, default="./dataset/train/syn+real/input", help="Training low-light rainy input folder.")
    parser.add_argument("--train_tar", type=str, default="./dataset/train/syn+real/target", help="Training clean target folder.")
    parser.add_argument("--test_inp", type=str, default="./dataset/test/input", help="Test input folder for epoch validation.")
    parser.add_argument("--test_tar", type=str, default="./dataset/test/target", help="Test target folder for epoch validation.")

    # Validation behavior.
    parser.add_argument("--val_max_images", type=int, default=0, help="Limit validation images; <=0 means full test set.")
    parser.add_argument("--val_pad_factor", type=int, default=8, help="Pad validation inputs to this factor before inference.")
    parser.add_argument("--tta", action="store_true", help="Use flip/rotation self-ensemble during validation.")

    # EMA and training objective.
    parser.add_argument("--use_ema", action="store_true", help="Maintain EMA weights and validate/save with EMA model.")
    parser.add_argument("--ema_decay", type=float, default=0.999, help="EMA decay.")
    parser.add_argument("--loss_mode", type=str, default="l1", choices=["l1", "charbonnier", "charbonnier_edge", "charbonnier_edge_fft"], help="Training objective.")
    parser.add_argument("--edge_weight", type=float, default=0.05, help="Edge loss weight.")
    parser.add_argument("--fft_weight", type=float, default=0.01, help="FFT amplitude loss weight.")

    # Checkpointing and metrics.
    parser.add_argument("--resume", type=str, default="", help="Path to a Restormer checkpoint.")
    parser.add_argument("--save_every", type=int, default=5, help="Save epoch checkpoint every N epochs; model_best is always saved.")
    parser.add_argument("--metrics_file", type=str, default="metrics.csv", help="Per-epoch metrics CSV filename under model_dir.")

    # Options consumed only by the fgp_restormer_v2 variant.
    parser.add_argument("--use_rain_structure_prior", action="store_true", help="V2: use 7-channel rain/structure disentanglement prior.")
    parser.add_argument("--use_adaptive_temperature", action="store_true", help="V2: use learnable layer-wise routing temperature.")
    parser.add_argument("--use_conflict_aware_routing", action="store_true", help="V2: enable conflict-aware expert routing.")
    parser.add_argument("--expert_decorr_weight", type=float, default=0.0, help="V2: weight for expert decorrelation auxiliary loss.")

    # Debugging.
    parser.add_argument("--max_train_steps", type=int, default=0, help="Debug only: stop each epoch after N steps.")

    return parser.parse_args(argv)
|
|
|
|
# The command line and the YAML config are read at import time; the rest of
# the file reaches them through the module-level names `args` and `opt`.
args = parse_args()
opt = Config(args.config)

# Expose only the GPUs listed in the config. These variables are set here,
# ahead of the `torch` import that follows further down the file.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, opt.GPU))

# Several GPUs configured but no RANK / LOCAL_RANK in the environment:
# re-run this same script under torch.distributed.run with one process per
# GPU, then exit with the child's return code.
if len(opt.GPU) > 1 and not ("RANK" in os.environ or "LOCAL_RANK" in os.environ):
    import subprocess

    relaunch_cmd = [
        sys.executable,
        "-m",
        "torch.distributed.run",
        "--nproc_per_node",
        str(len(opt.GPU)),
        sys.argv[0],
        *sys.argv[1:],
    ]
    sys.exit(subprocess.run(relaunch_cmd, env=os.environ.copy()).returncode)
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from torchvision.transforms import functional as TF |
|
|
| from restormer import Restormer |
| from utils.inference_utils import batch_rgb_psnr, run_model |
|
|
|
|
# File suffixes treated as images; the dataset classes compare the
# lower-cased suffix of each file against this set.
IMAGE_EXTS = {
    ".bmp",
    ".gif",
    ".jpeg",
    ".jpg",
    ".png",
    ".tif",
    ".tiff",
}
|
|
|
|
class PairedImageDataset(torch.utils.data.Dataset):
    """Whole-image (input, target) pairs matched by identical filename.

    Files are taken from ``input_dir`` when their lower-cased suffix is in
    ``IMAGE_EXTS``; every one of them must have a same-named file in
    ``target_dir``.

    Args:
        input_dir: Folder with the degraded input images.
        target_dir: Folder with the corresponding target images.

    Raises:
        FileNotFoundError: If either folder does not exist.
        RuntimeError: If an input image has no matching target, or if no
            image pairs are found at all.
    """

    def __init__(self, input_dir, target_dir):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        if not self.input_dir.exists():
            raise FileNotFoundError(f"Test input folder not found: {self.input_dir}")
        if not self.target_dir.exists():
            raise FileNotFoundError(f"Test target folder not found: {self.target_dir}")

        input_names = sorted(self._image_names(self.input_dir))
        target_names = set(self._image_names(self.target_dir))

        missing_targets = sorted(set(input_names) - target_names)
        if missing_targets:
            raise RuntimeError(f"{len(missing_targets)} test input images have no matching target, e.g. {missing_targets[:5]}")

        # Past the check above every input has a target, so the paired list
        # is simply the sorted input list.
        self.names = input_names
        if not self.names:
            raise RuntimeError(f"No paired test images found in {self.input_dir} and {self.target_dir}")

    @staticmethod
    def _image_names(folder):
        """Yield the names of regular files in *folder* with an image suffix."""
        for entry in folder.iterdir():
            if entry.is_file() and entry.suffix.lower() in IMAGE_EXTS:
                yield entry.name

    @staticmethod
    def _load_rgb(path):
        """Open *path*, return an RGB copy, and close the file handle."""
        # `convert` returns a new, fully loaded image, so it stays valid
        # after the context manager closes the underlying file.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        """Return ``(target_tensor, input_tensor, filename)`` for *index*."""
        name = self.names[index]
        input_img = self._load_rgb(self.input_dir / name)
        target_img = self._load_rgb(self.target_dir / name)
        return TF.to_tensor(target_img), TF.to_tensor(input_img), name
|
|
|
|
class PairedPatchDataset(torch.utils.data.Dataset):
    """Random, augmented training patches from filename-matched image pairs.

    Files are taken from ``input_dir`` when their lower-cased suffix is in
    ``IMAGE_EXTS``; every one of them must have a same-named file in
    ``target_dir``.

    Args:
        input_dir: Folder with the degraded input images.
        target_dir: Folder with the corresponding target images.
        patch_size: Side length of the square patch returned per sample.

    Raises:
        FileNotFoundError: If either folder does not exist.
        RuntimeError: If an input image has no matching target, or if no
            image pairs are found at all.
    """

    def __init__(self, input_dir, target_dir, patch_size):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        self.patch_size = patch_size
        if not self.input_dir.exists():
            raise FileNotFoundError(f"Training input folder not found: {self.input_dir}")
        if not self.target_dir.exists():
            raise FileNotFoundError(f"Training target folder not found: {self.target_dir}")

        input_names = sorted(self._image_names(self.input_dir))
        target_names = set(self._image_names(self.target_dir))

        missing_targets = sorted(set(input_names) - target_names)
        if missing_targets:
            raise RuntimeError(f"{len(missing_targets)} training input images have no matching target, e.g. {missing_targets[:5]}")

        # Past the check above every input has a target, so the paired list
        # is simply the sorted input list.
        self.names = input_names
        if not self.names:
            raise RuntimeError(f"No paired training images found in {self.input_dir} and {self.target_dir}")

    @staticmethod
    def _image_names(folder):
        """Yield the names of regular files in *folder* with an image suffix."""
        for entry in folder.iterdir():
            if entry.is_file() and entry.suffix.lower() in IMAGE_EXTS:
                yield entry.name

    @staticmethod
    def _load_rgb(path):
        """Open *path*, return an RGB copy, and close the file handle."""
        # `convert` returns a new, fully loaded image, so it stays valid
        # after the context manager closes the underlying file.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _augment(img, mode):
        """Apply flip/rotation variant *mode* (0-7) over tensor dims 1 and 2.

        Mode 0 returns *img* unchanged. The same mode must be used for the
        input and the target so that the pair stays aligned.
        """
        if mode == 1:
            return img.flip(1)
        if mode == 2:
            return img.flip(2)
        if mode in (3, 4, 5):
            # Modes 3, 4, 5 are one, two and three quarter turns.
            return torch.rot90(img, k=mode - 2, dims=(1, 2))
        if mode == 6:
            return torch.rot90(img.flip(1), dims=(1, 2))
        if mode == 7:
            return torch.rot90(img.flip(2), dims=(1, 2))
        return img

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        """Return ``(target_patch, input_patch, filename)`` for *index*."""
        name = self.names[index]
        input_img = self._load_rgb(self.input_dir / name)
        target_img = self._load_rgb(self.target_dir / name)

        ps = self.patch_size
        w, h = target_img.size
        pad_w = max(ps - w, 0)
        pad_h = max(ps - h, 0)
        if pad_w > 0 or pad_h > 0:
            # Images smaller than the patch are grown by reflection; the
            # padding tuple is (left, top, right, bottom) in torchvision.
            padding = (0, 0, pad_w, pad_h)
            input_img = TF.pad(input_img, padding, padding_mode="reflect")
            target_img = TF.pad(target_img, padding, padding_mode="reflect")

        input_img = TF.to_tensor(input_img)
        target_img = TF.to_tensor(target_img)

        # One crop window, taken from the target's size, is applied to both
        # tensors. The order of the random draws (top, left, aug) is kept.
        _, h, w = target_img.shape
        top = random.randint(0, h - ps)
        left = random.randint(0, w - ps)
        input_img = input_img[:, top : top + ps, left : left + ps]
        target_img = target_img[:, top : top + ps, left : left + ps]

        aug = random.randint(0, 7)
        input_img = self._augment(input_img, aug)
        target_img = self._augment(target_img, aug)

        return target_img, input_img, name
|
|
|
|
def seed_everything(seed=1234):
    """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
def unwrap_model(model):
    """Return ``model.module`` when that attribute exists, else ``model`` itself."""
    return getattr(model, "module", model)
|
|
|
|
def build_model(model_name):
    """Instantiate the network selected by *model_name*.

    The two ``fgp_*`` variants are imported only when requested. The v2
    variant takes its three feature switches from the module-level ``args``.

    Raises:
        ValueError: If *model_name* is not one of the supported variants.
    """
    if model_name == "fgp_restormer_v2":
        from net.restormer_lowlight_rain_v2 import RestormerLowLightRainV2

        v2_switches = {
            "use_rain_structure_prior": args.use_rain_structure_prior,
            "use_adaptive_temperature": args.use_adaptive_temperature,
            "use_conflict_aware_routing": args.use_conflict_aware_routing,
        }
        return RestormerLowLightRainV2(**v2_switches)

    if model_name == "fgp_restormer":
        from net.restormer_lowlight_rain import RestormerLowLightRain

        return RestormerLowLightRain()

    if model_name == "restormer":
        return Restormer()

    raise ValueError(f"Unsupported model: {model_name}")
|
|
|
|
def unpack_model_output(output):
    """Normalize a model's return value for unpacking by the caller.

    A tuple is passed through untouched; anything else is paired with
    ``None`` as ``(output, None)``.
    """
    return output if isinstance(output, tuple) else (output, None)
|
|
|
|
def expert_decorrelation_loss(aux):
    """Average the ``"decorrelation"`` terms found under ``aux["fusion"]``.

    Returns ``None`` when *aux* is empty, has no ``"fusion"`` entry, or none
    of the fusion items carries a ``"decorrelation"`` value; otherwise the
    mean of the stacked values.
    """
    if not aux or "fusion" not in aux:
        return None

    terms = []
    for entry in aux["fusion"]:
        # Falsy entries (e.g. None or an empty dict) are skipped.
        if entry and "decorrelation" in entry:
            terms.append(entry["decorrelation"])

    return torch.stack(terms).mean() if terms else None
|
|
|
|
def load_restormer_checkpoint(model, optimizer, scheduler, ckpt_path, device):
    """Restore model (and optionally optimizer/scheduler) state from a file.

    Args:
        model: Network to load weights into; it is unwrapped with
            ``unwrap_model`` first.
        optimizer: Optimizer to restore, or ``None`` to skip.
        scheduler: LR scheduler to restore, or ``None`` to skip.
        ckpt_path: Path handed to ``torch.load``.
        device: ``map_location`` for ``torch.load``.

    Returns:
        The epoch to continue from: the stored ``"epoch"`` plus one, or 1
        when the checkpoint has no ``"epoch"`` entry.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Weight source in priority order; a checkpoint with none of these keys
    # is treated as a bare state dict.
    # NOTE(review): when "state_dict_ema" is present the EMA weights are
    # loaded into `model`, not the raw training weights - confirm intended.
    state_dict = checkpoint
    for weights_key in ("state_dict_ema", "state_dict_G1", "state_dict"):
        if weights_key in checkpoint:
            state_dict = checkpoint[weights_key]
            break

    # Drop a leading "module." (7 characters) from parameter names.
    cleaned = {}
    for param_name, tensor in state_dict.items():
        bare_name = param_name[7:] if param_name.startswith("module.") else param_name
        cleaned[bare_name] = tensor
    unwrap_model(model).load_state_dict(cleaned, strict=True)

    restorable = ((optimizer, "optimizer_G1"), (scheduler, "scheduler_G1"))
    for component, state_key in restorable:
        if component is not None and state_key in checkpoint:
            component.load_state_dict(checkpoint[state_key])

    return int(checkpoint.get("epoch", 0)) + 1
|
|
|
|
def append_metrics_csv(path, row):
    """Append one epoch's metrics to a CSV file, writing the header if needed.

    The header is written when the file does not exist yet or exists but is
    empty (for example after being pre-created or truncated), so the file
    always starts with the column names.

    Args:
        path: Destination CSV path (str or Path).
        row: Mapping whose keys are drawn from the fixed column list below.

    Raises:
        ValueError: From ``csv.DictWriter`` if *row* contains a key that is
            not one of the columns.
    """
    fieldnames = [
        "epoch",
        "train_loss",
        "lr",
        "val_psnr",
        "best_psnr",
        "epoch_time",
        "saved_checkpoint",
        "is_best",
    ]
    path = Path(path)
    need_header = not path.exists() or path.stat().st_size == 0
    # newline="" lets the csv module control line endings itself.
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if need_header:
            writer.writeheader()
        writer.writerow(row)
|
|
|
|
class ModelEma:
    """Exponential moving average of a model's state, kept in ``self.module``.

    The averaged copy is a deep copy of the unwrapped model, switched to
    eval mode, with gradients disabled on all of its parameters.
    """

    def __init__(self, model, decay=0.999):
        import copy

        shadow = copy.deepcopy(unwrap_model(model))
        shadow.eval()
        for weight in shadow.parameters():
            weight.requires_grad_(False)

        self.module = shadow
        self.decay = decay

    @torch.no_grad()
    def update(self, model):
        """Blend the current state of *model* into the averaged copy in place.

        Floating-point entries become ``decay * ema + (1 - decay) * current``;
        entries of any other dtype are copied over from *model* unchanged.
        """
        live_state = unwrap_model(model).state_dict()
        keep = self.decay
        for name, averaged in self.module.state_dict().items():
            current = live_state[name]
            if averaged.dtype.is_floating_point:
                averaged.mul_(keep).add_(current.detach(), alpha=1.0 - keep)
            else:
                averaged.copy_(current)
|
|
|
|
class CharbonnierLoss(nn.Module):
    """Charbonnier penalty: the mean of ``sqrt((pred - target)^2 + eps^2)``."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        residual = pred - target
        smoothed_abs = torch.sqrt(residual ** 2 + self.eps ** 2)
        return smoothed_abs.mean()
|
|
|
|
class RestorationLoss(nn.Module):
    """Training objective selected by *mode*.

    ``"l1"`` is a plain L1 loss. Every other mode starts from the
    Charbonnier loss and adds an L1 term on Sobel edge magnitudes when the
    mode name contains ``"edge"`` and an L1 term on log FFT amplitudes when
    it contains ``"fft"``, scaled by *edge_weight* and *fft_weight*.
    """

    def __init__(self, mode="l1", edge_weight=0.05, fft_weight=0.01):
        super().__init__()
        self.mode = mode
        self.edge_weight = edge_weight
        self.fft_weight = fft_weight
        self.base_l1 = nn.L1Loss()
        self.base_charb = CharbonnierLoss()

        # 3x3 Sobel kernels, one copy per channel for the grouped
        # (groups=3) convolution in `_edge`.
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        for buffer_name, kernel in (("kernel_x", sobel_x), ("kernel_y", sobel_y)):
            self.register_buffer(buffer_name, kernel.view(1, 1, 3, 3).repeat(3, 1, 1, 1))

    def _edge(self, x):
        """Per-channel Sobel gradient magnitude of a 3-channel batch."""
        conv2d = torch.nn.functional.conv2d
        gx = conv2d(x, self.kernel_x, padding=1, groups=3)
        gy = conv2d(x, self.kernel_y, padding=1, groups=3)
        # The 1e-6 keeps the square root away from exactly zero.
        return torch.sqrt(gx * gx + gy * gy + 1e-6)

    def _fft_amp(self, x):
        """``log1p`` of the magnitude of the orthonormal real 2-D FFT of *x*."""
        spectrum = torch.fft.rfft2(x, norm="ortho")
        return torch.log1p(torch.abs(spectrum))

    def forward(self, pred, target):
        if self.mode == "l1":
            return self.base_l1(pred, target)

        total = self.base_charb(pred, target)
        if "edge" in self.mode:
            edge_term = self.base_l1(self._edge(pred), self._edge(target))
            total = total + self.edge_weight * edge_term
        if "fft" in self.mode:
            fft_term = self.base_l1(self._fft_amp(pred), self._fft_amp(target))
            total = total + self.fft_weight * fft_term
        return total
|
|
|
|
@torch.no_grad()
def validate(model, val_loader, device, max_images, pad_factor=8, tta=False):
    """Return the image-weighted mean of ``batch_rgb_psnr`` over *val_loader*.

    The model is switched to eval mode and is not switched back. Each batch
    is read as ``(target, input, ...)``, run through ``run_model`` and
    clamped to [0, 1] before scoring.

    Args:
        model: Network to evaluate.
        val_loader: Iterable of validation batches.
        device: Device the tensors are moved to.
        max_images: Stop once at least this many images were scored;
            values <= 0 mean no limit.
        pad_factor: Forwarded to ``run_model``.
        tta: Forwarded to ``run_model``.
    """
    model.eval()
    weighted_psnr = 0.0
    images_seen = 0

    for batch in tqdm(val_loader, desc="Val", leave=False):
        target = batch[0].to(device, non_blocking=True)
        input_img = batch[1].to(device, non_blocking=True)
        restored = run_model(model, input_img, pad_factor=pad_factor, tta=tta).clamp(0, 1)

        batch_images = restored.size(0)
        weighted_psnr += batch_rgb_psnr(restored, target) * batch_images
        images_seen += batch_images
        if 0 < max_images <= images_seen:
            break

    # max(..., 1) guards the division when the loader produced no batches.
    return weighted_psnr / max(images_seen, 1)
|
|
|
|
def _read_best_psnr(ckpt_path):
    """Return the ``best_psnr`` stored in a checkpoint file, or 0.0 if absent.

    Checkpoints written before this field existed, and bare state dicts, do
    not carry it; 0.0 reproduces the previous start-from-zero behavior.
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "best_psnr" in checkpoint:
        return float(checkpoint["best_psnr"])
    return 0.0


def main():
    """Train the selected model; validate, checkpoint and log on rank 0.

    Reads the module-level ``args`` (CLI) and ``opt`` (YAML config). Runs
    single-process, or distributed when RANK / LOCAL_RANK are present in the
    environment. Each epoch, rank 0 optionally validates, saves
    ``model_best.pth`` on a new best PSNR, saves ``model_<epoch>.pth`` every
    ``--save_every`` epochs, and appends one row to the metrics CSV.

    The best PSNR so far is stored in every checkpoint and restored on
    ``--resume``, so a resumed run does not overwrite ``model_best.pth``
    with a worse model on its first validation.
    """
    torch.backends.cudnn.benchmark = True
    seed_everything()

    # RANK / LOCAL_RANK in the environment mean a distributed launcher
    # started this process.
    use_ddp = "RANK" in os.environ or "LOCAL_RANK" in os.environ
    if use_ddp:
        torch.distributed.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        device = torch.device("cuda", local_rank)
    else:
        local_rank = 0
        rank = 0
        world_size = 1
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # CLI overrides win over the config when given (None / 0 fall through).
    epochs = args.epochs or opt.OPTIM.NUM_EPOCHS
    batch_size = args.batch_size or opt.OPTIM.BATCH_SIZE

    model_dir = os.path.join(args.save_dir, "Deraining", "models", args.session)
    os.makedirs(model_dir, exist_ok=True)
    metrics_path = os.path.join(model_dir, args.metrics_file)

    model = build_model(args.model).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    elif torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=opt.OPTIM.LR_INITIAL, betas=(0.9, 0.999), eps=1e-8)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=opt.OPTIM.LR_MIN)
    criterion = RestorationLoss(args.loss_mode, edge_weight=args.edge_weight, fft_weight=args.fft_weight).to(device)
    # The EMA copy lives on rank 0 only, the rank that validates and saves.
    ema = ModelEma(model, decay=args.ema_decay) if args.use_ema and rank == 0 else None

    start_epoch = 1
    best_psnr = 0.0
    if args.resume:
        start_epoch = load_restormer_checkpoint(model, optimizer, scheduler, args.resume, device)
        if ema is not None:
            ema.module.load_state_dict(unwrap_model(model).state_dict())
        if rank == 0:
            # Only rank 0 compares against and records the best PSNR.
            best_psnr = _read_best_psnr(args.resume)

    train_dataset = PairedPatchDataset(args.train_inp, args.train_tar, opt.TRAINING.TRAIN_PS)
    val_dataset = PairedImageDataset(args.test_inp, args.test_tar)

    # Under DDP the sampler shards and shuffles the data, so the loader's own
    # shuffle is off and the worker count is split across processes.
    if use_ddp:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        train_workers = max(1, args.num_workers // world_size)
    else:
        train_sampler = None
        train_workers = args.num_workers
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        num_workers=train_workers,
        pin_memory=True,
        drop_last=False,
    )

    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=max(1, args.num_workers // 2), pin_memory=True)

    if rank == 0:
        print(f"==> Model: {args.model}")
        print(f"==> Train on: {args.train_inp} -> {args.train_tar}")
        print(f"==> Train images: {len(train_dataset)}, Test images for val: {len(val_dataset)}")
        print(f"==> Val every epoch on: {args.test_inp} -> {args.test_tar}")
        print(f"==> Loss: {args.loss_mode}, EMA: {args.use_ema}, TTA: {args.tta}")
        print(f"==> Epochs: {epochs}, Batch size per process: {batch_size}, World size: {world_size}")
        print(f"==> Checkpoints: {model_dir}")
        print(f"==> Metrics CSV: {metrics_path}")
        print(f"==> Save checkpoint every {args.save_every} epoch(s)")

    for epoch in range(start_epoch, epochs + 1):
        if train_sampler is not None:
            # Gives the distributed sampler a different shuffle each epoch.
            train_sampler.set_epoch(epoch)

        model.train()
        epoch_loss = 0.0
        train_steps = 0
        epoch_start = time.time()
        progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", disable=(rank != 0))

        for step, data in enumerate(progress, start=1):
            train_steps = step
            # Batches are (target, input, name).
            target = data[0].to(device, non_blocking=True)
            input_img = data[1].to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            output, aux = unpack_model_output(model(input_img))
            loss = criterion(output, target)
            if args.expert_decorr_weight > 0:
                decorr = expert_decorrelation_loss(aux)
                if decorr is not None:
                    loss = loss + args.expert_decorr_weight * decorr
            loss.backward()
            optimizer.step()
            if ema is not None:
                ema.update(model)

            epoch_loss += loss.item()
            progress.set_postfix(loss=f"{epoch_loss / step:.4f}")

            if args.max_train_steps > 0 and step >= args.max_train_steps:
                break

        # The cosine schedule advances once per epoch (T_max=epochs).
        scheduler.step()

        if rank == 0:
            avg_loss = epoch_loss / max(train_steps, 1)
            cur_lr = scheduler.get_last_lr()[0]
            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch}: loss={avg_loss:.4f}, lr={cur_lr:.2e}, time={epoch_time:.1f}s")

            ckpt = {
                "epoch": epoch,
                "state_dict_G1": unwrap_model(model).state_dict(),
                "optimizer_G1": optimizer.state_dict(),
                "scheduler_G1": scheduler.state_dict(),
                "model": args.model,
                "loss_mode": args.loss_mode,
            }
            if ema is not None:
                ckpt["state_dict_ema"] = ema.module.state_dict()

            val_psnr = ""
            is_best = False
            if args.val_every > 0 and epoch % args.val_every == 0:
                # Validation uses the EMA weights when EMA is enabled.
                eval_model = ema.module if ema is not None else unwrap_model(model)
                psnr = validate(eval_model, val_loader, device, args.val_max_images, pad_factor=args.val_pad_factor, tta=args.tta)
                val_psnr = psnr
                print(f"Val PSNR: {psnr:.4f}")
                if psnr > best_psnr:
                    best_psnr = psnr
                    is_best = True

            # Recorded in every checkpoint so that --resume can restore it.
            ckpt["best_psnr"] = float(best_psnr)

            if is_best:
                torch.save(ckpt, os.path.join(model_dir, "model_best.pth"))
                print(f"Saved best checkpoint: PSNR={best_psnr:.4f}")

            saved_checkpoint = False
            if args.save_every > 0 and epoch % args.save_every == 0:
                torch.save(ckpt, os.path.join(model_dir, f"model_{epoch}.pth"))
                saved_checkpoint = True
                print(f"Saved epoch checkpoint: model_{epoch}.pth")

            append_metrics_csv(
                metrics_path,
                {
                    "epoch": epoch,
                    "train_loss": f"{avg_loss:.6f}",
                    "lr": f"{cur_lr:.8e}",
                    "val_psnr": f"{val_psnr:.6f}" if val_psnr != "" else "",
                    "best_psnr": f"{best_psnr:.6f}",
                    "epoch_time": f"{epoch_time:.2f}",
                    "saved_checkpoint": int(saved_checkpoint),
                    "is_best": int(is_best),
                },
            )

        if use_ddp:
            # Other ranks wait here while rank 0 validates and writes files.
            torch.distributed.barrier()


if __name__ == "__main__":
    main()
|
|