import argparse
import csv
import os
import random
import sys
import time
from pathlib import Path

import numpy as np
from PIL import Image
from tqdm import tqdm

from config import Config


def parse_args():
    """Build the training command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding model selection, data folders, optimisation
        overrides, validation, EMA, loss, checkpointing and V2-architecture options.
    """
    parser = argparse.ArgumentParser(description="Train the original Restormer model.")
    add = parser.add_argument

    # Experiment / model selection.
    add("--config", type=str, default="training.yml", help="Path to config yaml.")
    add(
        "--model",
        type=str,
        default="restormer",
        choices=["restormer", "fgp_restormer", "fgp_restormer_v2"],
        help="Model variant.",
    )
    add("--session", type=str, default="Restormer", help="Checkpoint session name.")
    add("--save_dir", type=str, default="./checkpoint_restormer", help="Checkpoint root.")

    # Overrides for values otherwise taken from the yaml config.
    add("--epochs", type=int, default=None, help="Override OPTIM.NUM_EPOCHS.")
    add("--batch_size", type=int, default=None, help="Override OPTIM.BATCH_SIZE.")
    add("--num_workers", type=int, default=8, help="Dataloader workers per process.")
    add("--val_every", type=int, default=1, help="Validate every N epochs.")

    # Data folders (inputs and targets are paired by identical file name).
    add("--train_inp", type=str, default="./dataset/train/syn+real/input",
        help="Training low-light rainy input folder.")
    add("--train_tar", type=str, default="./dataset/train/syn+real/target",
        help="Training clean target folder.")
    add("--test_inp", type=str, default="./dataset/test/input",
        help="Test input folder for epoch validation.")
    add("--test_tar", type=str, default="./dataset/test/target",
        help="Test target folder for epoch validation.")

    # Validation behaviour.
    add("--val_max_images", type=int, default=0,
        help="Limit validation images; <=0 means full test set.")
    add("--val_pad_factor", type=int, default=8,
        help="Pad validation inputs to this factor before inference.")
    add("--tta", action="store_true", help="Use flip/rotation self-ensemble during validation.")

    # Exponential moving average of weights.
    add("--use_ema", action="store_true",
        help="Maintain EMA weights and validate/save with EMA model.")
    add("--ema_decay", type=float, default=0.999, help="EMA decay.")

    # Objective.
    add(
        "--loss_mode",
        type=str,
        default="l1",
        choices=["l1", "charbonnier", "charbonnier_edge", "charbonnier_edge_fft"],
        help="Training objective.",
    )
    add("--edge_weight", type=float, default=0.05, help="Edge loss weight.")
    add("--fft_weight", type=float, default=0.01, help="FFT amplitude loss weight.")

    # Checkpointing and logging.
    add("--resume", type=str, default="", help="Path to a Restormer checkpoint.")
    add("--save_every", type=int, default=5,
        help="Save epoch checkpoint every N epochs; model_best is always saved.")
    add("--metrics_file", type=str, default="metrics.csv",
        help="Per-epoch metrics CSV filename under model_dir.")

    # Options consumed only by the fgp_restormer_v2 variant.
    add("--use_rain_structure_prior", action="store_true",
        help="V2: use 7-channel rain/structure disentanglement prior.")
    add("--use_adaptive_temperature", action="store_true",
        help="V2: use learnable layer-wise routing temperature.")
    add("--use_conflict_aware_routing", action="store_true",
        help="V2: enable conflict-aware expert routing.")
    add("--expert_decorr_weight", type=float, default=0.0,
        help="V2: weight for expert decorrelation auxiliary loss.")

    # Debugging.
    add("--max_train_steps", type=int, default=0, help="Debug only: stop each epoch after N steps.")

    return parser.parse_args()


args = parse_args()
opt = Config(args.config)

# GPU selection is written into the environment before torch is imported
# (``import torch`` is deliberately the last statement of this block).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, opt.GPU))

# Several GPUs configured but not started by a distributed launcher: re-run this
# exact command line under torch.distributed.run (one process per configured GPU)
# and exit with the launcher's return code.
if not any(key in os.environ for key in ("RANK", "LOCAL_RANK")) and len(opt.GPU) > 1:
    import subprocess

    launcher = [sys.executable, "-m", "torch.distributed.run", "--nproc_per_node", str(len(opt.GPU))]
    sys.exit(subprocess.run(launcher + sys.argv, env=dict(os.environ)).returncode)

import torch  # noqa: E402  (must follow the CUDA_* environment setup above)
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.transforms import functional as TF

from restormer import Restormer
from utils.inference_utils import batch_rgb_psnr, run_model

IMAGE_EXTS = {".jpeg", ".jpg", ".png", ".bmp", ".tif", ".tiff", ".gif"}


def _collect_image_pairs(input_dir, target_dir, label):
    """Return the sorted image file names present in both folders.

    Inputs and targets are paired by identical file name. Only regular files
    whose lower-cased suffix is in ``IMAGE_EXTS`` are considered.

    Args:
        input_dir: ``Path`` to the degraded-input folder.
        target_dir: ``Path`` to the clean-target folder.
        label: lower-case dataset label used in error messages ("test" / "training").

    Raises:
        FileNotFoundError: if either folder does not exist.
        RuntimeError: if some input image has no same-named target, or no pair exists.
    """
    if not input_dir.exists():
        raise FileNotFoundError(f"{label.capitalize()} input folder not found: {input_dir}")
    if not target_dir.exists():
        raise FileNotFoundError(f"{label.capitalize()} target folder not found: {target_dir}")
    input_names = sorted(p.name for p in input_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTS)
    target_names = {p.name for p in target_dir.iterdir() if p.is_file() and p.suffix.lower() in IMAGE_EXTS}
    missing_targets = sorted(set(input_names) - target_names)
    if missing_targets:
        raise RuntimeError(
            f"{len(missing_targets)} {label} input images have no matching target, e.g. {missing_targets[:5]}"
        )
    names = [name for name in input_names if name in target_names]
    if not names:
        raise RuntimeError(f"No paired {label} images found in {input_dir} and {target_dir}")
    return names


def _load_rgb(path):
    """Decode an image file to RGB and close the file handle before returning."""
    # ``convert`` returns a fully loaded copy, so it stays valid after the file closes.
    with Image.open(path) as img:
        return img.convert("RGB")


class PairedImageDataset(torch.utils.data.Dataset):
    """Full-resolution input/target pairs used for validation.

    Each item is ``(target, input, name)``: two tensors produced by
    ``TF.to_tensor`` plus the shared file name.
    """

    def __init__(self, input_dir, target_dir):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        self.names = _collect_image_pairs(self.input_dir, self.target_dir, "test")

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        name = self.names[index]
        input_img = _load_rgb(self.input_dir / name)
        target_img = _load_rgb(self.target_dir / name)
        return TF.to_tensor(target_img), TF.to_tensor(input_img), name


def _flip_rotate(img, mode):
    """Apply flip/rotation variant ``mode`` (0-7, 0 = unchanged) to a CHW tensor."""
    if mode == 1:
        return img.flip(1)
    if mode == 2:
        return img.flip(2)
    if mode == 3:
        return torch.rot90(img, dims=(1, 2))
    if mode == 4:
        return torch.rot90(img, k=2, dims=(1, 2))
    if mode == 5:
        return torch.rot90(img, k=3, dims=(1, 2))
    if mode == 6:
        return torch.rot90(img.flip(1), dims=(1, 2))
    if mode == 7:
        return torch.rot90(img.flip(2), dims=(1, 2))
    return img


class PairedPatchDataset(torch.utils.data.Dataset):
    """Random aligned ``patch_size`` x ``patch_size`` crops from paired training images.

    Each item is ``(target, input, name)``. The same crop window and the same
    randomly chosen flip/rotation variant are applied to both images.
    """

    def __init__(self, input_dir, target_dir, patch_size):
        self.input_dir = Path(input_dir)
        self.target_dir = Path(target_dir)
        self.patch_size = patch_size
        self.names = _collect_image_pairs(self.input_dir, self.target_dir, "training")

    def __len__(self):
        return len(self.names)

    def __getitem__(self, index):
        name = self.names[index]
        input_img = _load_rgb(self.input_dir / name)
        target_img = _load_rgb(self.target_dir / name)
        ps = self.patch_size

        # Reflect-pad on the right/bottom so images smaller than the patch can still
        # be cropped. Pad amounts come from the target size; the input is assumed to
        # have the same size (TODO confirm for all datasets).
        w, h = target_img.size
        pad_w = max(ps - w, 0)
        pad_h = max(ps - h, 0)
        if pad_w > 0 or pad_h > 0:
            input_img = TF.pad(input_img, (0, 0, pad_w, pad_h), padding_mode="reflect")
            target_img = TF.pad(target_img, (0, 0, pad_w, pad_h), padding_mode="reflect")

        input_img = TF.to_tensor(input_img)
        target_img = TF.to_tensor(target_img)

        # One shared random window for both images.
        _, h, w = target_img.shape
        top = random.randint(0, h - ps)
        left = random.randint(0, w - ps)
        input_img = input_img[:, top : top + ps, left : left + ps]
        target_img = target_img[:, top : top + ps, left : left + ps]

        aug = random.randint(0, 7)
        return _flip_rotate(target_img, aug), _flip_rotate(input_img, aug), name


def seed_everything(seed=1234):
    """Seed Python, NumPy and torch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def unwrap_model(model):
    """Return the wrapped module for DataParallel/DDP wrappers, else the model itself."""
    return model.module if hasattr(model, "module") else model


def build_model(model_name):
    """Instantiate the requested architecture.

    The V2 variant reads its feature flags from the module-level ``args``.

    Raises:
        ValueError: for an unknown ``model_name``.
    """
    if model_name == "restormer":
        return Restormer()
    if model_name == "fgp_restormer":
        from net.restormer_lowlight_rain import RestormerLowLightRain

        return RestormerLowLightRain()
    if model_name == "fgp_restormer_v2":
        from net.restormer_lowlight_rain_v2 import RestormerLowLightRainV2

        return RestormerLowLightRainV2(
            use_rain_structure_prior=args.use_rain_structure_prior,
            use_adaptive_temperature=args.use_adaptive_temperature,
            use_conflict_aware_routing=args.use_conflict_aware_routing,
        )
    raise ValueError(f"Unsupported model: {model_name}")


def unpack_model_output(output):
    """Normalise a forward result to ``(output, aux)``.

    Tuples are returned unchanged (presumably ``(output, aux)``; TODO confirm for every
    model). Anything else is paired with ``aux=None``.
    """
    if isinstance(output, tuple):
        return output
    return output, None


def expert_decorrelation_loss(aux):
    """Mean of the ``"decorrelation"`` terms in ``aux["fusion"]``, or None if there are none."""
    if not aux or "fusion" not in aux:
        return None
    losses = [item["decorrelation"] for item in aux["fusion"] if item and "decorrelation" in item]
    if not losses:
        return None
    return torch.stack(losses).mean()


def _strip_module_prefix(state_dict):
    """Drop a leading ``module.`` (left by DataParallel/DDP) from every key."""
    return {k[7:] if k.startswith("module.") else k: v for k, v in state_dict.items()}


def _restore_from_checkpoint(checkpoint, model, optimizer, scheduler, ema=None):
    """Restore training state from an already-loaded checkpoint dict.

    The raw training weights (``state_dict_G1``, else ``state_dict``, else
    ``state_dict_ema``, else the dict itself as a bare state dict) are loaded into
    ``model``. This keeps them consistent with the restored optimizer moments.
    When ``ema`` is given, it receives ``state_dict_ema`` if the checkpoint has one;
    otherwise it is re-seeded from the restored model weights.

    Returns:
        The epoch to resume from: saved ``epoch`` + 1, or 1 if none was stored.
    """
    if "state_dict_G1" in checkpoint:
        raw_state = checkpoint["state_dict_G1"]
    elif "state_dict" in checkpoint:
        raw_state = checkpoint["state_dict"]
    elif "state_dict_ema" in checkpoint:
        raw_state = checkpoint["state_dict_ema"]
    else:
        raw_state = checkpoint
    unwrap_model(model).load_state_dict(_strip_module_prefix(raw_state), strict=True)

    if ema is not None:
        if "state_dict_ema" in checkpoint:
            ema.module.load_state_dict(_strip_module_prefix(checkpoint["state_dict_ema"]), strict=True)
        else:
            ema.module.load_state_dict(unwrap_model(model).state_dict())

    if optimizer is not None and "optimizer_G1" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_G1"])
    if scheduler is not None and "scheduler_G1" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_G1"])
    return int(checkpoint.get("epoch", 0)) + 1


def load_restormer_checkpoint(model, optimizer, scheduler, ckpt_path, device, ema=None):
    """Load ``ckpt_path`` onto ``device`` and restore model/optimizer/scheduler (and EMA).

    Returns:
        The epoch to resume from (see ``_restore_from_checkpoint``).
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    return _restore_from_checkpoint(checkpoint, model, optimizer, scheduler, ema=ema)


def append_metrics_csv(path, row):
    """Append one row of per-epoch metrics, writing the header if the file is new."""
    path = Path(path)
    need_header = not path.exists()
    with path.open("a", newline="") as f:
        writer = csv.DictWriter(
            f,
            fieldnames=[
                "epoch",
                "train_loss",
                "lr",
                "val_psnr",
                "best_psnr",
                "epoch_time",
                "saved_checkpoint",
                "is_best",
            ],
        )
        if need_header:
            writer.writeheader()
        writer.writerow(row)


class ModelEma:
    """Exponential moving average of a model's state dict, kept in a frozen eval-mode copy.

    Floating-point entries are blended as ``ema = decay * ema + (1 - decay) * current``.
    Non-floating entries are copied verbatim.
    """

    def __init__(self, model, decay=0.999):
        import copy

        self.module = copy.deepcopy(unwrap_model(model)).eval()
        self.decay = decay
        for param in self.module.parameters():
            param.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        source = unwrap_model(model).state_dict()
        target = self.module.state_dict()
        for key, value in target.items():
            if value.dtype.is_floating_point:
                value.mul_(self.decay).add_(source[key].detach(), alpha=1.0 - self.decay)
            else:
                value.copy_(source[key])


class CharbonnierLoss(nn.Module):
    """Mean of ``sqrt((pred - target)^2 + eps^2)``."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        return torch.mean(torch.sqrt((pred - target) ** 2 + self.eps ** 2))


class RestorationLoss(nn.Module):
    """Training objective selected by ``mode``.

    - ``"l1"``: plain L1.
    - Otherwise, Charbonnier, plus ``edge_weight`` * L1 between per-channel Sobel
      gradient magnitudes if ``"edge"`` is in the mode, plus ``fft_weight`` * L1
      between log1p FFT amplitudes if ``"fft"`` is in the mode.

    The edge kernels are built for 3-channel inputs (``groups=3``).
    """

    def __init__(self, mode="l1", edge_weight=0.05, fft_weight=0.01):
        super().__init__()
        self.mode = mode
        self.edge_weight = edge_weight
        self.fft_weight = fft_weight
        self.base_l1 = nn.L1Loss()
        self.base_charb = CharbonnierLoss()
        kernel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        kernel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32).view(1, 1, 3, 3)
        self.register_buffer("kernel_x", kernel_x.repeat(3, 1, 1, 1))
        self.register_buffer("kernel_y", kernel_y.repeat(3, 1, 1, 1))

    def _edge(self, x):
        gx = torch.nn.functional.conv2d(x, self.kernel_x, padding=1, groups=3)
        gy = torch.nn.functional.conv2d(x, self.kernel_y, padding=1, groups=3)
        return torch.sqrt(gx * gx + gy * gy + 1e-6)

    def _fft_amp(self, x):
        return torch.log1p(torch.abs(torch.fft.rfft2(x, norm="ortho")))

    def forward(self, pred, target):
        if self.mode == "l1":
            return self.base_l1(pred, target)
        loss = self.base_charb(pred, target)
        if "edge" in self.mode:
            loss = loss + self.edge_weight * self.base_l1(self._edge(pred), self._edge(target))
        if "fft" in self.mode:
            loss = loss + self.fft_weight * self.base_l1(self._fft_amp(pred), self._fft_amp(target))
        return loss


@torch.no_grad()
def validate(model, val_loader, device, max_images, pad_factor=8, tta=False):
    """Return the image-weighted mean PSNR of ``model`` over ``val_loader``.

    Outputs are clamped to [0, 1]. Iteration stops once ``max_images`` images are
    seen, if ``max_images`` > 0. Assumes ``batch_rgb_psnr`` returns a per-batch mean
    (TODO confirm).
    """
    model.eval()
    psnr_sum = 0.0
    count = 0
    for data in tqdm(val_loader, desc="Val", leave=False):
        target = data[0].to(device, non_blocking=True)
        input_img = data[1].to(device, non_blocking=True)
        output = run_model(model, input_img, pad_factor=pad_factor, tta=tta).clamp(0, 1)
        psnr_sum += batch_rgb_psnr(output, target) * output.size(0)
        count += output.size(0)
        if max_images > 0 and count >= max_images:
            return psnr_sum / count
    return psnr_sum / max(count, 1)


def main():
    """Train the selected model; validate, checkpoint and log metrics on rank 0."""
    torch.backends.cudnn.benchmark = True
    seed_everything()

    use_ddp = "RANK" in os.environ or "LOCAL_RANK" in os.environ
    if use_ddp:
        torch.distributed.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        device = torch.device("cuda", local_rank)
    else:
        local_rank = 0
        rank = 0
        world_size = 1
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    epochs = args.epochs or opt.OPTIM.NUM_EPOCHS
    batch_size = args.batch_size or opt.OPTIM.BATCH_SIZE
    model_dir = os.path.join(args.save_dir, "Deraining", "models", args.session)
    os.makedirs(model_dir, exist_ok=True)
    metrics_path = os.path.join(model_dir, args.metrics_file)

    model = build_model(args.model).to(device)
    if use_ddp:
        model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
    elif torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=opt.OPTIM.LR_INITIAL, betas=(0.9, 0.999), eps=1e-8)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=opt.OPTIM.LR_MIN)
    criterion = RestorationLoss(args.loss_mode, edge_weight=args.edge_weight, fft_weight=args.fft_weight).to(device)
    ema = ModelEma(model, decay=args.ema_decay) if args.use_ema and rank == 0 else None

    start_epoch = 1
    best_psnr = 0.0
    if args.resume:
        checkpoint = torch.load(args.resume, map_location=device)
        start_epoch = _restore_from_checkpoint(checkpoint, model, optimizer, scheduler, ema=ema)
        # Carry the previous best forward so the first post-resume validation
        # cannot overwrite model_best.pth with a worse model.
        best_psnr = float(checkpoint.get("best_psnr", 0.0))
        del checkpoint

    train_dataset = PairedPatchDataset(args.train_inp, args.train_tar, opt.TRAINING.TRAIN_PS)
    val_dataset = PairedImageDataset(args.test_inp, args.test_tar)
    if use_ddp:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=False,
            sampler=train_sampler,
            num_workers=max(1, args.num_workers // world_size),
            pin_memory=True,
            drop_last=False,
        )
    else:
        train_sampler = None
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )
    val_loader = DataLoader(
        val_dataset, batch_size=1, shuffle=False, num_workers=max(1, args.num_workers // 2), pin_memory=True
    )

    if rank == 0:
        print(f"==> Model: {args.model}")
        print(f"==> Train on: {args.train_inp} -> {args.train_tar}")
        print(f"==> Train images: {len(train_dataset)}, Test images for val: {len(val_dataset)}")
        print(f"==> Val every epoch on: {args.test_inp} -> {args.test_tar}")
        print(f"==> Loss: {args.loss_mode}, EMA: {args.use_ema}, TTA: {args.tta}")
        print(f"==> Epochs: {epochs}, Batch size per process: {batch_size}, World size: {world_size}")
        print(f"==> Checkpoints: {model_dir}")
        print(f"==> Metrics CSV: {metrics_path}")
        print(f"==> Save checkpoint every {args.save_every} epoch(s)")

    for epoch in range(start_epoch, epochs + 1):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        model.train()
        epoch_loss = 0.0
        train_steps = 0
        epoch_start = time.time()
        progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", disable=(rank != 0))
        for step, data in enumerate(progress, start=1):
            train_steps = step
            target = data[0].to(device, non_blocking=True)
            input_img = data[1].to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)
            output, aux = unpack_model_output(model(input_img))
            loss = criterion(output, target)
            if args.expert_decorr_weight > 0:
                decorr = expert_decorrelation_loss(aux)
                if decorr is not None:
                    loss = loss + args.expert_decorr_weight * decorr
            loss.backward()
            optimizer.step()
            if ema is not None:
                ema.update(model)
            epoch_loss += loss.item()
            progress.set_postfix(loss=f"{epoch_loss / step:.4f}")
            if args.max_train_steps > 0 and step >= args.max_train_steps:
                break
        scheduler.step()

        if rank == 0:
            # Loss is averaged over this rank's steps only.
            avg_loss = epoch_loss / max(train_steps, 1)
            cur_lr = scheduler.get_last_lr()[0]
            epoch_time = time.time() - epoch_start
            print(f"Epoch {epoch}: loss={avg_loss:.4f}, lr={cur_lr:.2e}, time={epoch_time:.1f}s")

            val_psnr = ""
            is_best = False
            if args.val_every > 0 and epoch % args.val_every == 0:
                eval_model = ema.module if ema is not None else unwrap_model(model)
                psnr = validate(
                    eval_model, val_loader, device, args.val_max_images, pad_factor=args.val_pad_factor, tta=args.tta
                )
                val_psnr = psnr
                print(f"Val PSNR: {psnr:.4f}")
                if psnr > best_psnr:
                    best_psnr = psnr
                    is_best = True

            # Built after validation so the saved best_psnr is up to date for resume.
            ckpt = {
                "epoch": epoch,
                "state_dict_G1": unwrap_model(model).state_dict(),
                "optimizer_G1": optimizer.state_dict(),
                "scheduler_G1": scheduler.state_dict(),
                "model": args.model,
                "loss_mode": args.loss_mode,
                "best_psnr": float(best_psnr),
            }
            if ema is not None:
                ckpt["state_dict_ema"] = ema.module.state_dict()

            if is_best:
                torch.save(ckpt, os.path.join(model_dir, "model_best.pth"))
                print(f"Saved best checkpoint: PSNR={best_psnr:.4f}")

            saved_checkpoint = False
            if args.save_every > 0 and epoch % args.save_every == 0:
                torch.save(ckpt, os.path.join(model_dir, f"model_{epoch}.pth"))
                saved_checkpoint = True
                print(f"Saved epoch checkpoint: model_{epoch}.pth")

            append_metrics_csv(
                metrics_path,
                {
                    "epoch": epoch,
                    "train_loss": f"{avg_loss:.6f}",
                    "lr": f"{cur_lr:.8e}",
                    "val_psnr": f"{val_psnr:.6f}" if val_psnr != "" else "",
                    "best_psnr": f"{best_psnr:.6f}",
                    "epoch_time": f"{epoch_time:.2f}",
                    "saved_checkpoint": int(saved_checkpoint),
                    "is_best": int(is_best),
                },
            )

        if use_ddp:
            # Other ranks wait here while rank 0 validates and saves.
            torch.distributed.barrier()

    if use_ddp:
        torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()