import gc
import os
import random
import warnings
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import wandb
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.training_utils import EMAModel
from peft import LoraConfig, get_peft_model
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')

# --- Paths / base model ------------------------------------------------------
DATA_ROOT = '/kaggle/input/datasets/shambac/augmented-sentinel-1-2'
OUTPUT_DIR = '/kaggle/working/checkpoints'
SD_MODEL_ID = 'runwayml/stable-diffusion-v1-5'

# --- Data --------------------------------------------------------------------
IMG_SIZE = 256
SEASONS = ('spring', 'summer', 'fall', 'winter')
MAX_PAIRS = 90000          # total cap, split evenly across seasons that have data
VAL_FRACTION = 0.10

# --- Adapter -----------------------------------------------------------------
LORA_RANK = 32
LORA_ALPHA = 32
USE_DORA = True

# --- Session length / resume -------------------------------------------------
NUM_TRAIN_STEPS = 12_000   # optimizer steps to run in this session
RESUME_STEP = 0            # optimizer steps completed in earlier sessions
RESUME_FROM = None         # checkpoint dir to load adapter / conv_in / null embed from

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 8
GRAD_ACCUM = 2
LR = 5e-5
LR_WARMUP = 1_000          # warmup length, in optimizer steps
MIXED_PRECISION = 'fp16'
GRAD_CKPT = True
EMA_DECAY = 0.9999
NUM_TIMESTEPS = 1000

# --- Auxiliary losses --------------------------------------------------------
COLOR_LOSS_W = 0.5
PERCEPTUAL_W = 0.1
COLOR_LOSS_FREQ = 35       # colour loss only when global_step % COLOR_LOSS_FREQ == 0
PERCEPTUAL_FREQ = 50       # perceptual loss only when global_step % PERCEPTUAL_FREQ == 0

# --- Logging / checkpoints ---------------------------------------------------
LOG_EVERY = 100
SAVE_EVERY = 2_000
VIS_EVERY = 3_000
SEED = 332


def collect_pairs(root, seasons, max_pairs=None, seed=None):
    """Collect SAR/optical file pairs from the per-season CSV indexes under ``root``.

    Every ``*.csv`` in ``<root>/<season>`` is read. The ``s1_fileName`` and
    ``s2_fileName`` columns are converted to forward slashes (the S1 name also has
    ``<word>1s2_`` rewritten to ``<word>_s1_``) and joined onto ``root``.

    Args:
        root: Dataset root directory.
        seasons: Season sub-directory names; seasons without CSVs are skipped.
        max_pairs: Optional cap on the number of pairs. When given, each season
            that has data contributes at most ``max_pairs // n_seasons`` randomly
            chosen pairs.
        seed: Optional seed for a private RNG used for sampling and shuffling.
            ``None`` uses the global ``random`` module. A fixed seed makes the
            selection independent of any per-process global seed.

    Returns:
        A shuffled list of dicts with keys ``'s1'``, ``'s2'``, ``'season'`` and
        ``'region'`` (stripped, lower-cased). Empty if no season had data.
    """
    rng = random if seed is None else random.Random(seed)
    root_str = Path(root).as_posix()
    season_buckets = {s: [] for s in seasons}
    for season in seasons:
        # Sorted so a seeded run sees the same row order regardless of filesystem.
        csv_files = sorted((Path(root) / season).glob('*.csv'))
        if not csv_files:
            continue
        df = pd.concat([pd.read_csv(f) for f in csv_files], ignore_index=True)
        df['season'] = season
        df['region'] = df['region'].str.strip().str.lower()
        df['s1_fileName'] = (df['s1_fileName']
                             .str.replace('\\', '/', regex=False)
                             .str.replace(r'(\w+?)1s2_', r'\1_s1_', regex=True))
        df['s2_fileName'] = df['s2_fileName'].str.replace('\\', '/', regex=False)
        df['s1'] = root_str + '/' + df['s1_fileName']
        df['s2'] = root_str + '/' + df['s2_fileName']
        season_buckets[season] = df[['s1', 's2', 'season', 'region']].to_dict('records')

    active = [s for s in seasons if season_buckets[s]]
    if not active:
        # Nothing found; also avoids ZeroDivisionError in the per-season quota.
        return []
    if max_pairs is None:
        pairs = [p for s in active for p in season_buckets[s]]
    else:
        per_season = max_pairs // len(active)
        pairs = []
        for s in active:
            bucket = season_buckets[s].copy()
            rng.shuffle(bucket)
            pairs.extend(bucket[:per_season])
    rng.shuffle(pairs)
    return pairs


class SAROpticalDataset(Dataset):
    """Paired SAR/optical images returned as CHW float tensors scaled to [-1, 1].

    SAR images are read as greyscale and replicated to 3 channels; optical images
    are read as RGB. Both are bilinearly resized to ``img_size`` when needed. With
    ``augment=True`` the same random h/v flips and 90-degree rotation are applied
    to both images of a pair.
    """

    def __init__(self, pairs, img_size=256, augment=True):
        self.pairs = pairs
        self.img_size = img_size
        self.augment = augment

    def __len__(self):
        return len(self.pairs)

    def _read(self, path, mode):
        """Load ``path`` in PIL ``mode``, resize, and return a float32 array in [0, 1]."""
        # The context manager closes the file; convert() returns an independent image.
        with Image.open(path) as src:
            img = src.convert(mode)
        target = (self.img_size, self.img_size)
        if img.size != target:
            img = img.resize(target, Image.BILINEAR)
        return np.array(img, dtype=np.float32) / 255.0

    def _load_sar(self, path):
        arr = self._read(path, 'L')
        arr = np.stack([arr, arr, arr], axis=2)
        return torch.from_numpy(arr).permute(2, 0, 1) * 2.0 - 1.0

    def _load_optical(self, path):
        arr = self._read(path, 'RGB')
        return torch.from_numpy(arr).permute(2, 0, 1) * 2.0 - 1.0

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        sar = self._load_sar(pair['s1'])
        opt = self._load_optical(pair['s2'])
        if self.augment:
            if random.random() > 0.5:
                sar = TF.hflip(sar)
                opt = TF.hflip(opt)
            if random.random() > 0.5:
                sar = TF.vflip(sar)
                opt = TF.vflip(opt)
            k = random.randint(0, 3)
            if k > 0:
                sar = torch.rot90(sar, k, [1, 2])
                opt = torch.rot90(opt, k, [1, 2])
        return {'sar': sar, 'optical': opt,
                'season': pair['season'], 'region': pair['region']}


class VGGPerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps (first 9 layers of ``features``).

    Inputs are expected in [-1, 1]; they are mapped to [0, 1] and normalised with
    the ImageNet mean/std before the feature extractor.
    """

    def __init__(self):
        super().__init__()
        import torchvision.models as models
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        self.features = nn.Sequential(*list(vgg.features)[:9]).eval()
        for p in self.parameters():
            p.requires_grad_(False)
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred, target):
        pred = (pred * 0.5 + 0.5 - self.mean) / self.std
        target = (target * 0.5 + 0.5 - self.mean) / self.std
        return F.l1_loss(self.features(pred), self.features(target))


def _decode_x0_pair(noisy_latents, noise_pred, clean_latents, timesteps, mask, scheduler, vae):
    """Decode the predicted x0 (differentiable) and the clean latents (no grad) for ``mask``.

    x0 is recovered from the epsilon prediction as
    ``(x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)`` and clamped to [-4, 4] in latent space.
    """
    vae_mod = vae.module if hasattr(vae, 'module') else vae
    scale = vae_mod.config.scaling_factor
    alphas_cumprod = scheduler.alphas_cumprod.to(noise_pred.device)
    a_t = alphas_cumprod[timesteps[mask]].view(-1, 1, 1, 1)
    x0_pred = (noisy_latents[mask] - (1 - a_t).sqrt() * noise_pred[mask]) / (a_t.sqrt() + 1e-8)
    x0_pred = x0_pred.clamp(-4, 4)
    # The VAE is frozen (no param grads), but gradients must flow through its
    # activations back to noise_pred; otherwise the auxiliary loss trains nothing.
    pred_img = vae_mod.decode((x0_pred / scale).to(vae_mod.dtype)).sample
    with torch.no_grad():
        gt_img = vae_mod.decode((clean_latents[mask] / scale).to(vae_mod.dtype)).sample
    return pred_img, gt_img


def color_supervision_loss(noise_pred, noise, noisy_latents, timesteps, scheduler, vae,
                           opt_latents, max_timestep=500):
    """Pixel-space L1 between decoded predicted x0 and decoded ground-truth latents.

    Only samples with ``timesteps < max_timestep`` contribute; returns a zero tensor
    when none do. ``noisy_latents`` may carry extra conditioning channels — only
    the first 4 are used. ``noise`` is unused and kept for interface compatibility.
    """
    low_noise_mask = timesteps < max_timestep
    if not low_noise_mask.any():
        return torch.tensor(0.0, device=noise_pred.device)
    pred_img, gt_img = _decode_x0_pair(noisy_latents[:, :4], noise_pred, opt_latents,
                                       timesteps, low_noise_mask, scheduler, vae)
    return F.l1_loss(pred_img.float(), gt_img.float())


def perceptual_supervision_loss(noise_pred, noisy_latents, timesteps, scheduler, vae,
                                opt_latents, loss_fn, max_timestep=300):
    """``loss_fn`` on decoded predicted x0 vs decoded ground truth, both clamped to [-1, 1].

    Only samples with ``timesteps < max_timestep`` contribute; returns a zero tensor
    when none do.
    """
    low_mask = timesteps < max_timestep
    if not low_mask.any():
        return torch.tensor(0.0, device=noise_pred.device)
    pred_img, gt_img = _decode_x0_pair(noisy_latents[:, :4], noise_pred, opt_latents,
                                       timesteps, low_mask, scheduler, vae)
    return loss_fn(pred_img.float().clamp(-1, 1), gt_img.float().clamp(-1, 1))


@torch.no_grad()
def translate_single(sar_tensor, unet_m, vae_m, null_embed, device, num_steps=50,
                     generator=None):
    """Translate one SAR tensor (CHW, presumably in [-1, 1]) to an optical image via DDIM.

    Args:
        sar_tensor: Single SAR image without a batch dimension.
        unet_m: Unwrapped UNet taking 8 latent channels (noisy latent + SAR latent).
        vae_m: VAE (optionally wrapped in ``.module``).
        null_embed: Learned conditioning embedding, expanded to batch size 1.
        device: Device to run on.
        num_steps: Number of DDIM steps.
        generator: Optional ``torch.Generator`` for reproducible initial noise.

    Returns:
        CHW float32 CPU tensor clamped to [-1, 1].
    """
    # clip_sample=False: SD latents are not bounded to [-1, 1], and the scheduler's
    # default clipping of predicted x0 distorts the decoded image.
    ddim = DDIMScheduler(num_train_timesteps=NUM_TIMESTEPS, beta_schedule='scaled_linear',
                         prediction_type='epsilon', clip_sample=False)
    vae_mod = vae_m.module if hasattr(vae_m, 'module') else vae_m
    sar = sar_tensor.unsqueeze(0).to(device, dtype=vae_mod.dtype)
    sar_latent = vae_mod.encode(sar).latent_dist.mean * vae_mod.config.scaling_factor
    if generator is None:
        latents = torch.randn_like(sar_latent)
    else:
        latents = torch.randn(sar_latent.shape, generator=generator, device=generator.device,
                              dtype=torch.float32).to(sar_latent.device, sar_latent.dtype)
    latents = latents * ddim.init_noise_sigma
    embed = null_embed.detach().to(device, dtype=torch.float32).expand(1, -1, -1)
    ddim.set_timesteps(num_steps)
    for t in ddim.timesteps:
        model_in = torch.cat([latents, sar_latent], dim=1).float()
        noise_pred = unet_m(model_in, t.unsqueeze(0).to(device),
                            encoder_hidden_states=embed).sample
        latents = ddim.step(noise_pred.to(latents.dtype), t, latents).prev_sample
    image = vae_mod.decode((latents / vae_mod.config.scaling_factor).to(vae_mod.dtype)).sample
    return image.squeeze(0).clamp(-1, 1).float().cpu()


def denorm(t):
    """Map a CHW tensor in [-1, 1] to an HWC numpy array in [0, 1] for plotting."""
    return ((t.clamp(-1, 1) + 1) / 2).permute(1, 2, 0).numpy()


def save_validation_grid(val_pairs, unet_m, vae_m, null_embed, device, step, out_dir, seed=None):
    """Plot SAR / prediction / ground truth for one random val pair per season.

    With ``seed`` set, the same pairs and initial noise are used every call, so
    grids from different steps are directly comparable.

    Returns:
        Path of the saved PNG, or ``None`` when ``val_pairs`` yields no samples.
    """
    rng = random if seed is None else random.Random(seed)
    selected = []
    for season in SEASONS:
        pool = [p for p in val_pairs if p['season'] == season]
        if pool:
            selected.append(rng.choice(pool))
    selected = selected[:8]
    if not selected:
        return None

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    # squeeze=False keeps axes 2-D even when only one row is plotted.
    fig, axes = plt.subplots(len(selected), 3, figsize=(12, 3.5 * len(selected)), squeeze=False)
    for i, pair in enumerate(selected):
        item = SAROpticalDataset([pair], img_size=IMG_SIZE, augment=False)[0]
        pred = translate_single(item['sar'], unet_m, vae_m, null_embed, device,
                                generator=generator)
        axes[i, 0].imshow(denorm(item['sar'])[:, :, 0], cmap='gray')
        axes[i, 0].set_title(f"SAR | {pair['season']}", fontsize=8)
        axes[i, 0].axis('off')
        axes[i, 1].imshow(denorm(pred))
        axes[i, 1].set_title(f'Predicted (step {step})', fontsize=8)
        axes[i, 1].axis('off')
        axes[i, 2].imshow(denorm(item['optical']))
        axes[i, 2].set_title('Ground Truth', fontsize=8)
        axes[i, 2].axis('off')
    plt.tight_layout()
    path = os.path.join(out_dir, f'val_step_{step:06d}.png')
    plt.savefig(path, dpi=100, bbox_inches='tight')
    plt.close(fig)
    return path


def _split_pairs(pairs, val_fraction=VAL_FRACTION, seed=SEED):
    """Train/val split, stratified by season+region, falling back to season, then none.

    ``train_test_split`` rejects stratification labels with a single member, so a
    coarser label is used whenever any stratum is a singleton.
    """
    labels = [f"{p['season']}_{p['region']}" for p in pairs]
    if min(Counter(labels).values()) < 2:
        labels = [p['season'] for p in pairs]
        if min(Counter(labels).values()) < 2:
            labels = None
    return train_test_split(pairs, test_size=val_fraction, stratify=labels, random_state=seed)


def _save_checkpoint(peft_unet, null_embed, ckpt_dir):
    """Write the adapter, the expanded ``conv_in`` and the null embedding to ``ckpt_dir``."""
    os.makedirs(ckpt_dir, exist_ok=True)
    peft_unet.save_pretrained(os.path.join(ckpt_dir, 'unet_adapter'))
    # conv_in is trained directly (not through the adapter), so the adapter save
    # does not contain it; without this file a resume loses the SAR input channels.
    conv_state = {k: v.detach().cpu() for k, v in peft_unet.conv_in.state_dict().items()}
    torch.save(conv_state, os.path.join(ckpt_dir, 'conv_in.pt'))
    torch.save(null_embed.detach().cpu(), os.path.join(ckpt_dir, 'null_embed.pt'))


def main():
    """Fine-tune SD-1.5 (8-channel conv_in + DoRA) to translate SAR latents to optical."""
    accelerator = Accelerator(
        mixed_precision=MIXED_PRECISION,
        gradient_accumulation_steps=GRAD_ACCUM,
        log_with='wandb',
        project_dir=OUTPUT_DIR,
    )
    set_seed(SEED + accelerator.process_index)
    is_main = accelerator.is_main_process
    device = accelerator.device

    if is_main:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, 'val_grids'), exist_ok=True)
        api_key = os.environ.get('WANDB_API_KEY')
        if api_key:
            wandb.login(key=api_key, relogin=True)
        accelerator.init_trackers(
            project_name='sar-optical-diffusion',
            config={'lr': LR, 'batch_size': BATCH_SIZE, 'lora_rank': LORA_RANK,
                    'steps': NUM_TRAIN_STEPS, 'resume_step': RESUME_STEP,
                    'color_loss_w': COLOR_LOSS_W, 'perceptual_w': PERCEPTUAL_W,
                    'effective_batch': BATCH_SIZE * GRAD_ACCUM * accelerator.num_processes},
            init_kwargs={'wandb': {
                'name': f'dora-r{LORA_RANK}-step{RESUME_STEP}-to-{RESUME_STEP+NUM_TRAIN_STEPS}',
                'tags': ['sentinel', 'SAR', 'diffusion', 'DoRA', 'color-supervision'],
                'resume': 'allow',
            }},
        )

    # Seeded independently of the per-process seed so all ranks share one split.
    all_pairs = collect_pairs(DATA_ROOT, SEASONS, MAX_PAIRS, seed=SEED)
    if not all_pairs:
        raise RuntimeError(f'No SAR/optical pairs found under {DATA_ROOT}')
    train_pairs, val_pairs = _split_pairs(all_pairs)

    train_ds = SAROpticalDataset(train_pairs, img_size=IMG_SIZE, augment=True)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
                          pin_memory=True, drop_last=True, persistent_workers=False)

    # Frozen VAE: moved to the device directly; it has nothing to train or wrap.
    vae = AutoencoderKL.from_pretrained(SD_MODEL_ID, subfolder='vae', torch_dtype=torch.float16)
    vae.requires_grad_(False)
    if GRAD_CKPT:
        # The auxiliary losses backprop through the decoder's activations.
        vae.enable_gradient_checkpointing()
    vae.to(device)

    unet = UNet2DConditionModel.from_pretrained(SD_MODEL_ID, subfolder='unet',
                                                torch_dtype=torch.float32)
    # Expand conv_in to 8 channels: [noisy optical latent (4) | SAR latent (4)].
    # The new SAR channels start at zero, so the initial model ignores SAR.
    old_conv = unet.conv_in
    new_conv = nn.Conv2d(8, old_conv.out_channels, old_conv.kernel_size, old_conv.stride,
                         old_conv.padding, bias=(old_conv.bias is not None))
    with torch.no_grad():
        new_conv.weight[:, :4].copy_(old_conv.weight)
        new_conv.weight[:, 4:].zero_()
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    unet.conv_in = new_conv
    # register_to_config updates the serialized config; assigning
    # unet.config.in_channels does not, and unet_full would be saved with 4.
    unet.register_to_config(in_channels=8)

    text_embed_dim = unet.config.cross_attention_dim
    # Created on the training device so it stays a leaf Parameter: .to(device) on a
    # CPU Parameter returns a non-leaf copy that the optimizer never updates.
    null_text_embed = nn.Parameter(torch.randn(1, 77, text_embed_dim, device=device) * 0.01)

    lora_cfg = LoraConfig(
        r=LORA_RANK, lora_alpha=LORA_ALPHA, use_dora=USE_DORA,
        init_lora_weights='gaussian',
        target_modules=['to_q', 'to_k', 'to_v', 'to_out.0', 'add_q_proj', 'add_k_proj',
                        'add_v_proj', 'ff.net.0.proj', 'ff.net.2'],
        lora_dropout=0.0, bias='none',
    )
    unet = get_peft_model(unet, lora_cfg)
    unet.conv_in.weight.requires_grad_(True)
    if unet.conv_in.bias is not None:
        unet.conv_in.bias.requires_grad_(True)
    if GRAD_CKPT:
        unet.enable_gradient_checkpointing()

    perceptual_loss_fn = VGGPerceptualLoss().to(device)
    noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TIMESTEPS,
                                    beta_schedule='scaled_linear', prediction_type='epsilon')

    unet_trainable = [p for p in unet.parameters() if p.requires_grad]
    trainable_params = unet_trainable + [null_text_embed]
    optimizer = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.999),
                                  weight_decay=1e-2, eps=1e-8)
    total_steps_across_sessions = RESUME_STEP + NUM_TRAIN_STEPS
    # The accelerate-wrapped scheduler advances num_processes times per optimizer
    # step (and not at all on accumulation micro-steps), so scale by num_processes.
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=LR_WARMUP * accelerator.num_processes,
        num_training_steps=total_steps_across_sessions * accelerator.num_processes,
    )
    unet, optimizer, train_dl, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dl, lr_scheduler
    )
    unwrapped_unet = accelerator.unwrap_model(unet)

    if RESUME_FROM is not None:
        adapter_path = os.path.join(RESUME_FROM, 'unet_adapter')
        embed_path = os.path.join(RESUME_FROM, 'null_embed.pt')
        conv_in_path = os.path.join(RESUME_FROM, 'conv_in.pt')
        if os.path.exists(adapter_path):
            unwrapped_unet.load_adapter(adapter_path, adapter_name='default', is_trainable=True)
            tqdm.write(f'Loaded adapter from {adapter_path}')
        else:
            tqdm.write(f'WARNING: no adapter at {adapter_path}')
        if os.path.exists(conv_in_path):
            unwrapped_unet.conv_in.load_state_dict(torch.load(conv_in_path, map_location=device))
            tqdm.write(f'Loaded conv_in from {conv_in_path}')
        else:
            tqdm.write(f'WARNING: no conv_in at {conv_in_path}; SAR input channels restart from zero')
        if os.path.exists(embed_path):
            with torch.no_grad():
                null_text_embed.copy_(torch.load(embed_path, map_location=device).detach())
            tqdm.write(f'Loaded null_embed from {embed_path}')
        # One call per completed optimizer step, matching how training advances it.
        for _ in range(RESUME_STEP):
            lr_scheduler.step()
        tqdm.write(f'LR scheduler fast-forwarded to step {RESUME_STEP}')

    # EMA over trainable params only (frozen ones never change), created after any
    # resume load so its shadow starts from the loaded weights.
    ema_unet = EMAModel(unet_trainable, decay=EMA_DECAY)
    ema_unet.to(device)

    @torch.no_grad()
    def encode_to_latent(images):
        images = images.to(device, dtype=vae.dtype)
        return vae.encode(images).latent_dist.sample() * vae.config.scaling_factor

    steps_per_epoch = max(1, len(train_dl) // GRAD_ACCUM)
    total_epochs = (NUM_TRAIN_STEPS + steps_per_epoch - 1) // steps_per_epoch
    target_step = RESUME_STEP + NUM_TRAIN_STEPS
    global_step = RESUME_STEP
    avg_loss = 0.0
    # Most recent values of the sparsely computed auxiliary losses, for logging.
    last_color = 0.0
    last_perceptual = 0.0

    overall_bar = tqdm(total=NUM_TRAIN_STEPS,
                       desc=f'Training (step {RESUME_STEP} -> {RESUME_STEP+NUM_TRAIN_STEPS})',
                       disable=not is_main, position=0, dynamic_ncols=True, leave=True)
    unet.train()
    for epoch in range(total_epochs):
        accum_loss = 0.0
        batches_seen = 0
        for batch in train_dl:
            with accelerator.accumulate(unet):
                opt_latents = encode_to_latent(batch['optical'])
                sar_latents = encode_to_latent(batch['sar'])
                noise = torch.randn_like(opt_latents)
                bsz = opt_latents.shape[0]
                timesteps = torch.randint(0, NUM_TIMESTEPS, (bsz,), device=device,
                                          dtype=torch.long)
                noisy_latents = noise_scheduler.add_noise(opt_latents, noise, timesteps)
                model_input = torch.cat([noisy_latents, sar_latents], dim=1)
                enc_hidden = null_text_embed.to(opt_latents.dtype).expand(bsz, -1, -1)
                noise_pred = unet(model_input, timesteps, encoder_hidden_states=enc_hidden).sample

                loss_diffusion = F.mse_loss(noise_pred.float(), noise.float())
                loss_color = torch.tensor(0.0, device=device)
                loss_perceptual = torch.tensor(0.0, device=device)
                if global_step % COLOR_LOSS_FREQ == 0:
                    loss_color = color_supervision_loss(
                        noise_pred, noise, model_input, timesteps, noise_scheduler, vae,
                        opt_latents)
                    last_color = loss_color.item()
                if global_step % PERCEPTUAL_FREQ == 0:
                    loss_perceptual = perceptual_supervision_loss(
                        noise_pred, noisy_latents, timesteps, noise_scheduler, vae,
                        opt_latents, perceptual_loss_fn)
                    last_perceptual = loss_perceptual.item()
                loss = (loss_diffusion + COLOR_LOSS_W * loss_color
                        + PERCEPTUAL_W * loss_perceptual)
                accum_loss += loss.item()
                batches_seen += 1

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    # null_text_embed is outside the DDP module, so average its grad by hand.
                    if accelerator.num_processes > 1 and null_text_embed.grad is not None:
                        null_text_embed.grad = accelerator.reduce(null_text_embed.grad,
                                                                  reduction='mean')
                    accelerator.clip_grad_norm_(trainable_params, 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

            if not accelerator.sync_gradients:
                continue

            ema_unet.step(unet_trainable)
            global_step += 1
            avg_loss = accum_loss / batches_seen
            lr_now = lr_scheduler.get_last_lr()[0]
            accum_loss = 0.0
            batches_seen = 0

            if is_main:
                overall_bar.update(1)
                overall_bar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{lr_now:.1e}',
                                         'step': global_step})

            if is_main and global_step % LOG_EVERY == 0:
                accelerator.log({'train/loss': avg_loss, 'train/loss_color': last_color,
                                 'train/loss_perceptual': last_perceptual,
                                 'train/lr': lr_now}, step=global_step)

            if is_main and global_step % VIS_EVERY == 0:
                unet.eval()
                # Swap EMA weights in only for sampling; restore keeps training weights intact.
                ema_unet.store(unet_trainable)
                ema_unet.copy_to(unet_trainable)
                try:
                    grid_path = save_validation_grid(
                        val_pairs, unwrapped_unet, vae, null_text_embed, device,
                        global_step, os.path.join(OUTPUT_DIR, 'val_grids'), seed=SEED,
                    )
                finally:
                    ema_unet.restore(unet_trainable)
                    unet.train()
                if grid_path:
                    tqdm.write(f'  Val grid -> {grid_path}')

            if is_main and global_step % SAVE_EVERY == 0:
                ckpt = os.path.join(OUTPUT_DIR, f'step_{global_step}')
                _save_checkpoint(unwrapped_unet, null_text_embed, ckpt)
                tqdm.write(f'step {global_step} | adapter saved -> {ckpt}')

            if global_step >= target_step:
                break

        gc.collect()
        torch.cuda.empty_cache()
        if is_main:
            tqdm.write(f'Epoch {epoch+1}/{total_epochs} | step {global_step} | loss {avg_loss:.4f}')
        if global_step >= target_step:
            break

    overall_bar.close()
    accelerator.wait_for_everyone()

    if is_main:
        final_dir = os.path.join(OUTPUT_DIR, f'session_final_step{global_step}')
        ema_unet.copy_to(unet_trainable)
        _save_checkpoint(unwrapped_unet, null_text_embed, final_dir)
        merged = unwrapped_unet.merge_and_unload()
        merged.save_pretrained(os.path.join(final_dir, 'unet_full'))
        vae.save_pretrained(os.path.join(final_dir, 'vae'))
        tqdm.write(f'Session complete — step {global_step}. Saved -> {final_dir}')
        tqdm.write('To resume next session set:')
        tqdm.write(f'  RESUME_STEP = {global_step}')
        tqdm.write(f'  RESUME_FROM = "{final_dir}"')

    accelerator.end_training()


if __name__ == '__main__':
    main()