Instructions to use AliMusaRizvi/sar-to-optical-diffusion with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use AliMusaRizvi/sar-to-optical-diffusion with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("AliMusaRizvi/sar-to-optical-diffusion", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import os, gc, random, warnings | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from pathlib import Path | |
| from PIL import Image | |
| from torch.utils.data import Dataset, DataLoader | |
| import torchvision.transforms.functional as TF | |
| from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, DDIMScheduler | |
| from diffusers.training_utils import EMAModel | |
| from diffusers.optimization import get_cosine_schedule_with_warmup | |
| from peft import LoraConfig, get_peft_model | |
| from accelerate import Accelerator | |
| from accelerate.utils import set_seed | |
| from tqdm.auto import tqdm | |
| import wandb | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from sklearn.model_selection import train_test_split | |
| warnings.filterwarnings('ignore') | |
# --- Paths and base model ---------------------------------------------------
DATA_ROOT = '/kaggle/input/datasets/shambac/augmented-sentinel-1-2'
OUTPUT_DIR = '/kaggle/working/checkpoints'
SD_MODEL_ID = 'runwayml/stable-diffusion-v1-5'

# --- Data selection ---------------------------------------------------------
IMG_SIZE = 256                      # images are resized to IMG_SIZE x IMG_SIZE
SEASONS = ('spring', 'summer', 'fall', 'winter')
MAX_PAIRS = 90_000                  # cap on collected pairs, split evenly per season
VAL_FRACTION = 0.1                  # share of pairs held out by train_test_split

# --- Adapter configuration (passed to LoraConfig) ---------------------------
LORA_RANK = 32
LORA_ALPHA = 32
USE_DORA = True

# --- Schedule / session bookkeeping -----------------------------------------
NUM_TRAIN_STEPS = 12_000            # optimiser steps to run in this session
RESUME_STEP = 0                     # global step to start counting from
RESUME_FROM = None                  # checkpoint directory to resume from, or None

# --- Optimisation -----------------------------------------------------------
BATCH_SIZE = 8
GRAD_ACCUM = 2
LR = 5e-5
LR_WARMUP = 1_000
MIXED_PRECISION = 'fp16'
GRAD_CKPT = True                    # enable UNet gradient checkpointing
EMA_DECAY = 0.9999
NUM_TIMESTEPS = 1000                # diffusion training timesteps

# --- Auxiliary losses -------------------------------------------------------
COLOR_LOSS_W = 0.5                  # weight of the colour-supervision term
PERCEPTUAL_W = 0.1                  # weight of the VGG perceptual term
COLOR_LOSS_FREQ = 35                # colour term computed when step % FREQ == 0
PERCEPTUAL_FREQ = 50                # perceptual term computed when step % FREQ == 0

# --- Logging / checkpoint cadence (in global steps) -------------------------
LOG_EVERY = 100
SAVE_EVERY = 2_000
VIS_EVERY = 3_000

SEED = 332
def collect_pairs(root, seasons, max_pairs=None):
    """Collect SAR/optical file pairs listed in the per-season CSV files.

    For every season, all ``*.csv`` files under ``root/<season>`` are read and
    concatenated. Each CSV must provide the columns ``s1_fileName``,
    ``s2_fileName`` and ``region``.

    Args:
        root: Dataset root directory.
        seasons: Iterable of season sub-directory names to scan.
        max_pairs: Optional cap on the total number of pairs. The budget is
            split evenly across the seasons that actually contain data
            (``max_pairs // n_active`` random pairs per season).

    Returns:
        A shuffled list of dicts with keys ``s1``, ``s2``, ``season`` and
        ``region``. Empty when no season has any CSV rows.
    """
    root_path = Path(root)
    root_str = root_path.as_posix()
    season_buckets = {s: [] for s in seasons}

    for season in seasons:
        # Sorted so the row order (and therefore the seeded shuffles below)
        # does not depend on filesystem enumeration order.
        csv_files = sorted((root_path / season).glob('*.csv'))
        if not csv_files:
            continue
        df = pd.concat([pd.read_csv(f) for f in csv_files], ignore_index=True)
        df['season'] = season
        df['region'] = df['region'].str.strip().str.lower()
        # Normalise Windows separators, then rewrite the '<prefix>1s2_' form
        # of the SAR file name into '<prefix>_s1_'.
        df['s1_fileName'] = (df['s1_fileName']
                             .str.replace('\\', '/', regex=False)
                             .str.replace(r'(\w+?)1s2_', r'\1_s1_', regex=True))
        df['s2_fileName'] = df['s2_fileName'].str.replace('\\', '/', regex=False)
        df['s1'] = root_str + '/' + df['s1_fileName']
        df['s2'] = root_str + '/' + df['s2_fileName']
        season_buckets[season] = df[['s1', 's2', 'season', 'region']].to_dict('records')

    active = [s for s in seasons if season_buckets[s]]
    if not active:
        # Nothing found: avoid dividing the budget by zero seasons below.
        return []

    if max_pairs is None:
        pairs = [p for s in active for p in season_buckets[s]]
    else:
        per_season = max_pairs // len(active)
        pairs = []
        for s in active:
            bucket = season_buckets[s].copy()
            random.shuffle(bucket)
            pairs.extend(bucket[:per_season])

    random.shuffle(pairs)
    return pairs
class SAROpticalDataset(Dataset):
    """Paired SAR / optical image dataset.

    Each item is a dict with:
      * ``sar``: SAR image read as greyscale and replicated to 3 channels,
      * ``optical``: RGB optical image,
      * ``season`` / ``region``: copied from the pair record.

    Both tensors are channel-first float32 scaled to ``[-1, 1]``. When
    ``augment`` is true the same random flips / 90-degree rotation are applied
    to both images so they stay aligned.
    """

    def __init__(self, pairs, img_size=256, augment=True):
        """Store the pair records (dicts with 's1', 's2', 'season', 'region')."""
        self.pairs = pairs
        self.img_size = img_size
        self.augment = augment

    def __len__(self):
        return len(self.pairs)

    def _load_array(self, path, mode):
        """Read ``path`` in PIL ``mode`` and return a float32 array in [0, 1]."""
        # Context manager closes the file handle promptly instead of leaving
        # it to the garbage collector (matters in long-running workers).
        with Image.open(path) as handle:
            img = handle.convert(mode)
        if img.size != (self.img_size, self.img_size):
            img = img.resize((self.img_size, self.img_size), Image.BILINEAR)
        return np.array(img, dtype=np.float32) / 255.0

    def _load_sar(self, path):
        """Load a SAR image as a 3-channel tensor in [-1, 1]."""
        arr = self._load_array(path, 'L')
        # Replicate the single channel so the tensor has 3 channels like RGB.
        arr = np.stack([arr, arr, arr], axis=2)
        return torch.from_numpy(arr).permute(2, 0, 1) * 2.0 - 1.0

    def _load_optical(self, path):
        """Load an optical image as an RGB tensor in [-1, 1]."""
        arr = self._load_array(path, 'RGB')
        return torch.from_numpy(arr).permute(2, 0, 1) * 2.0 - 1.0

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        sar = self._load_sar(pair['s1'])
        opt = self._load_optical(pair['s2'])
        if self.augment:
            # Identical geometric transform on both images keeps them paired.
            if random.random() > 0.5:
                sar = TF.hflip(sar)
                opt = TF.hflip(opt)
            if random.random() > 0.5:
                sar = TF.vflip(sar)
                opt = TF.vflip(opt)
            k = random.randint(0, 3)
            if k > 0:
                sar = torch.rot90(sar, k, [1, 2])
                opt = torch.rot90(opt, k, [1, 2])
        return {'sar': sar, 'optical': opt, 'season': pair['season'], 'region': pair['region']}
class VGGPerceptualLoss(nn.Module):
    """L1 distance between VGG-16 feature maps of two image batches.

    Uses the first nine layers of torchvision's ImageNet-pretrained VGG-16
    ``features`` stack, frozen and in eval mode. Inputs are expected in
    ``[-1, 1]``; they are mapped to ``[0, 1]`` and normalised with the
    ImageNet mean / std before feature extraction.
    """

    def __init__(self):
        super().__init__()
        import torchvision.models as models

        backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        early_layers = list(backbone.features)[:9]
        self.features = nn.Sequential(*early_layers).eval()
        # Only `features` holds parameters at this point; freeze all of them.
        self.requires_grad_(False)
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
        imagenet_std = torch.tensor([0.229, 0.224, 0.225])
        self.register_buffer('mean', imagenet_mean.view(1, 3, 1, 1))
        self.register_buffer('std', imagenet_std.view(1, 3, 1, 1))

    def _normalise(self, images):
        """Map [-1, 1] images to ImageNet-normalised space."""
        return (images * 0.5 + 0.5 - self.mean) / self.std

    def forward(self, pred, target):
        """Return the mean absolute difference of the two feature maps."""
        pred_feats = self.features(self._normalise(pred))
        target_feats = self.features(self._normalise(target))
        return F.l1_loss(pred_feats, target_feats)
def color_supervision_loss(noise_pred, noise, noisy_latents, timesteps, scheduler, vae, opt_latents):
    """Image-space L1 loss between the decoded x0 estimate and the decoded target.

    Only samples with ``timesteps < 500`` contribute. For those samples the
    clean latent is estimated from the epsilon prediction,
    ``x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)``, decoded with the VAE and
    compared to the decoded ground-truth latent.

    Args:
        noise_pred: Predicted noise, one entry per sample.
        noise: True noise. Unused; kept so existing call sites keep working.
        noisy_latents: Noisy input; only its first 4 channels are read, so the
            concatenated UNet input may be passed directly.
        timesteps: Integer timestep per sample.
        scheduler: Object exposing ``alphas_cumprod``.
        vae: VAE (optionally wrapped, exposing ``.module``) with ``decode``,
            ``dtype`` and ``config.scaling_factor``.
        opt_latents: Ground-truth latents (already multiplied by the scaling
            factor).

    Returns:
        Scalar loss tensor. A constant zero when no sample has a timestep
        below 500.
    """
    low_noise_mask = timesteps < 500
    if not low_noise_mask.any():
        return torch.tensor(0.0, device=noise_pred.device)

    vae_mod = vae.module if hasattr(vae, 'module') else vae
    scale = vae_mod.config.scaling_factor

    alphas_cumprod = scheduler.alphas_cumprod.to(noise_pred.device)
    alpha_t = alphas_cumprod[timesteps[low_noise_mask]]
    sqrt_alpha = alpha_t.sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus = (1 - alpha_t).sqrt().view(-1, 1, 1, 1)

    noisy_sub = noisy_latents[low_noise_mask, :4]
    noise_sub = noise_pred[low_noise_mask]
    x0_pred = (noisy_sub - sqrt_one_minus * noise_sub) / (sqrt_alpha + 1e-8)
    x0_pred = x0_pred.clamp(-4, 4)

    # The prediction must be decoded WITH autograd enabled: decoding it under
    # no_grad detaches the result, so the loss would be a constant with no
    # gradient and would never influence training.
    pred_img = vae_mod.decode((x0_pred / scale).to(vae_mod.dtype)).sample
    # The target carries no gradient, so skip graph construction for it.
    with torch.no_grad():
        gt_img = vae_mod.decode((opt_latents[low_noise_mask] / scale).to(vae_mod.dtype)).sample
    return F.l1_loss(pred_img.float(), gt_img.float())
def translate_single(sar_tensor, unet_m, vae_m, null_embed, device, num_steps=50):
    """Translate one SAR image to an optical image with DDIM sampling.

    Args:
        sar_tensor: Single SAR image tensor without a batch dimension.
        unet_m: UNet taking the 8-channel concatenation of the current latents
            and the SAR latent.
        vae_m: VAE (optionally wrapped, exposing ``.module``).
        null_embed: Conditioning embedding passed as ``encoder_hidden_states``.
        device: Device to run sampling on.
        num_steps: Number of DDIM inference steps.

    Returns:
        The decoded image on CPU as float32, clamped to ``[-1, 1]``, without a
        batch dimension.
    """
    ddim = DDIMScheduler(num_train_timesteps=NUM_TIMESTEPS, beta_schedule='scaled_linear',
                         prediction_type='epsilon')
    vae_mod = vae_m.module if hasattr(vae_m, 'module') else vae_m
    # Follow the VAE's own dtype rather than assuming fp16.
    work_dtype = vae_mod.dtype
    scale = vae_mod.config.scaling_factor

    # Sampling is pure inference. Without no_grad, a UNet with trainable
    # parameters builds an autograd graph spanning every denoising step, and
    # the returned tensor requires grad (so .numpy() on it raises).
    with torch.no_grad():
        sar = sar_tensor.unsqueeze(0).to(device, dtype=work_dtype)
        sar_latent = vae_mod.encode(sar).latent_dist.mean * scale
        latents = torch.randn_like(sar_latent) * ddim.init_noise_sigma
        embed = null_embed.to(device, dtype=work_dtype).expand(1, -1, -1)

        ddim.set_timesteps(num_steps)
        for t in ddim.timesteps:
            model_in = torch.cat([latents, sar_latent], dim=1)
            noise_pred = unet_m(model_in.float(), t.unsqueeze(0).to(device),
                                encoder_hidden_states=embed.float()).sample
            latents = ddim.step(noise_pred.to(work_dtype), t, latents).prev_sample

        image = vae_mod.decode(latents / scale).sample
    return image.squeeze(0).clamp(-1, 1).float().cpu()
def denorm(t):
    """Convert a channel-first tensor in [-1, 1] to a channel-last array in [0, 1].

    Values outside ``[-1, 1]`` are clamped. The tensor is detached and moved
    to CPU first, so inputs that require grad or live on an accelerator can be
    converted without raising.
    """
    scaled = (t.detach().cpu().clamp(-1, 1) + 1) / 2
    return scaled.permute(1, 2, 0).numpy()
def save_validation_grid(val_pairs, unet_m, vae_m, null_embed, device, step, out_dir):
    """Render a SAR / prediction / ground-truth grid and save it as a PNG.

    One random validation pair is drawn per season in ``SEASONS`` (seasons
    with no validation pairs are skipped). Each row shows the SAR input, the
    model output from ``translate_single`` and the optical ground truth.

    Args:
        val_pairs: Pair records (dicts with 's1', 's2', 'season', 'region').
        unet_m, vae_m, null_embed, device: Forwarded to ``translate_single``.
        step: Training step, used in the titles and the file name.
        out_dir: Directory the PNG is written to (created if missing).

    Returns:
        Path of the written PNG.

    Raises:
        ValueError: If no validation pair matches any season in ``SEASONS``.
    """
    selected = []
    for s in SEASONS:
        pool = [p for p in val_pairs if p['season'] == s]
        if pool:
            selected.append(random.choice(pool))
    selected = selected[:8]
    if not selected:
        raise ValueError('save_validation_grid: no validation pairs match any configured season')

    os.makedirs(out_dir, exist_ok=True)
    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[i, j] indexing below is always valid.
    fig, axes = plt.subplots(len(selected), 3, figsize=(12, 3.5 * len(selected)), squeeze=False)
    try:
        for i, pair in enumerate(selected):
            ds = SAROpticalDataset([pair], img_size=IMG_SIZE, augment=False)
            item = ds[0]
            pred = translate_single(item['sar'], unet_m, vae_m, null_embed, device)

            axes[i, 0].imshow(denorm(item['sar'])[:, :, 0], cmap='gray')
            axes[i, 0].set_title(f"SAR | {pair['season']}", fontsize=8)
            axes[i, 0].axis('off')

            axes[i, 1].imshow(denorm(pred))
            axes[i, 1].set_title(f'Predicted (step {step})', fontsize=8)
            axes[i, 1].axis('off')

            axes[i, 2].imshow(denorm(item['optical']))
            axes[i, 2].set_title('Ground Truth', fontsize=8)
            axes[i, 2].axis('off')

        fig.tight_layout()
        path = os.path.join(out_dir, f'val_step_{step:06d}.png')
        fig.savefig(path, dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, including when sampling or saving fails.
        plt.close(fig)
    return path
def main():
    """Fine-tune a Stable Diffusion UNet for SAR-to-optical translation.

    Pipeline, as implemented below:
      1. Set up Accelerate (mixed precision, gradient accumulation, wandb).
      2. Collect SAR/optical pairs and split them into train / validation.
      3. Load the frozen VAE and the UNet, widen the UNet's input convolution
         from 4 to 8 channels (noisy optical latent + SAR latent) and attach a
         LoRA/DoRA adapter.
      4. Train with an epsilon-prediction MSE loss plus periodic colour and
         perceptual image-space losses, tracking an EMA of the UNet.
      5. Periodically log, write validation grids and save checkpoints; at the
         end export the EMA adapter, the merged UNet, the embedding and the VAE.

    Requires the ``WANDB_API_KEY`` environment variable on the main process.
    """
    accelerator = Accelerator(
        mixed_precision=MIXED_PRECISION,
        gradient_accumulation_steps=GRAD_ACCUM,
        log_with='wandb',
        project_dir=OUTPUT_DIR,
    )
    set_seed(SEED + accelerator.process_index)
    is_main = accelerator.is_main_process

    if is_main:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, 'val_grids'), exist_ok=True)
        wandb.login(key=os.environ['WANDB_API_KEY'], relogin=True)
        accelerator.init_trackers(
            project_name='sar-optical-diffusion',
            config={'lr': LR, 'batch_size': BATCH_SIZE, 'lora_rank': LORA_RANK,
                    'steps': NUM_TRAIN_STEPS, 'resume_step': RESUME_STEP,
                    'color_loss_w': COLOR_LOSS_W, 'perceptual_w': PERCEPTUAL_W,
                    'effective_batch': BATCH_SIZE * GRAD_ACCUM * accelerator.num_processes},
            init_kwargs={'wandb': {
                'name': f'dora-r{LORA_RANK}-step{RESUME_STEP}-to-{RESUME_STEP+NUM_TRAIN_STEPS}',
                'tags': ['sentinel', 'SAR', 'diffusion', 'DoRA', 'color-supervision'],
                'resume': 'allow',
            }}
        )

    # ------------------------------------------------------------------ data
    all_pairs = collect_pairs(DATA_ROOT, SEASONS, MAX_PAIRS)
    train_pairs, val_pairs = train_test_split(
        all_pairs, test_size=VAL_FRACTION,
        stratify=[f"{p['season']}_{p['region']}" for p in all_pairs],
        random_state=SEED,
    )
    train_ds = SAROpticalDataset(train_pairs, img_size=IMG_SIZE, augment=True)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True, drop_last=True, persistent_workers=False)

    # ---------------------------------------------------------------- models
    vae = AutoencoderKL.from_pretrained(SD_MODEL_ID, subfolder='vae', torch_dtype=torch.float16)
    vae.requires_grad_(False)
    unet = UNet2DConditionModel.from_pretrained(SD_MODEL_ID, subfolder='unet', torch_dtype=torch.float32)

    # Widen conv_in to 8 input channels. The original 4 channels keep their
    # pretrained weights; the 4 new (SAR latent) channels start at zero so the
    # widened UNet initially behaves like the pretrained one.
    old_conv = unet.conv_in
    new_conv = nn.Conv2d(8, old_conv.out_channels, old_conv.kernel_size,
                         old_conv.stride, old_conv.padding, bias=(old_conv.bias is not None))
    with torch.no_grad():
        new_conv.weight[:, :4].copy_(old_conv.weight)
        new_conv.weight[:, 4:].zero_()
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    unet.conv_in = new_conv
    unet.config.in_channels = 8

    # Learnable stand-in for the text conditioning. It is created directly on
    # the training device: calling .to(device) on the Parameter afterwards
    # would replace it with a non-leaf copy, so the optimiser would update a
    # tensor the forward pass never reads.
    # NOTE(review): this parameter is not part of a prepared model, so its
    # gradients are presumably not synchronised across processes in
    # multi-process runs - confirm if training on more than one device.
    text_embed_dim = unet.config.cross_attention_dim
    null_text_embed = nn.Parameter(
        (torch.randn(1, 77, text_embed_dim) * 0.01).to(accelerator.device))

    lora_cfg = LoraConfig(
        r=LORA_RANK, lora_alpha=LORA_ALPHA, use_dora=USE_DORA,
        init_lora_weights='gaussian',
        target_modules=['to_q', 'to_k', 'to_v', 'to_out.0', 'add_q_proj', 'add_k_proj',
                        'add_v_proj', 'ff.net.0.proj', 'ff.net.2'],
        lora_dropout=0.0, bias='none',
    )
    unet = get_peft_model(unet, lora_cfg)
    # conv_in is not an adapter target, so re-enable its gradients explicitly.
    unet.conv_in.weight.requires_grad_(True)
    if unet.conv_in.bias is not None:
        unet.conv_in.bias.requires_grad_(True)
    if GRAD_CKPT:
        unet.enable_gradient_checkpointing()

    perceptual_loss_fn = VGGPerceptualLoss()
    noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TIMESTEPS,
                                    beta_schedule='scaled_linear', prediction_type='epsilon')
    ema_unet = EMAModel(unet.parameters(), decay=EMA_DECAY)

    # ------------------------------------------------------------- optimiser
    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters())) + [null_text_embed]
    total_steps_across_sessions = RESUME_STEP + NUM_TRAIN_STEPS
    optimizer = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.999), weight_decay=1e-2, eps=1e-8)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=LR_WARMUP * GRAD_ACCUM,
        num_training_steps=total_steps_across_sessions * GRAD_ACCUM,
    )
    unet, vae, optimizer, train_dl, lr_scheduler = accelerator.prepare(
        unet, vae, optimizer, train_dl, lr_scheduler
    )
    perceptual_loss_fn = perceptual_loss_fn.to(accelerator.device)
    ema_unet.to(accelerator.device)

    # ---------------------------------------------------------------- resume
    if RESUME_FROM is not None:
        adapter_path = os.path.join(RESUME_FROM, 'unet_adapter')
        embed_path = os.path.join(RESUME_FROM, 'null_embed.pt')
        conv_in_path = os.path.join(RESUME_FROM, 'conv_in.pt')
        if os.path.exists(adapter_path):
            accelerator.unwrap_model(unet).load_adapter(
                adapter_path, adapter_name='default', is_trainable=True)
            tqdm.write(f'Loaded adapter from {adapter_path}')
        if os.path.exists(conv_in_path):
            # The adapter checkpoint does not contain the widened conv_in, so
            # it is stored and restored separately. Older checkpoints without
            # this file simply skip the step.
            accelerator.unwrap_model(unet).conv_in.load_state_dict(
                torch.load(conv_in_path, map_location=accelerator.device))
            tqdm.write(f'Loaded conv_in from {conv_in_path}')
        if os.path.exists(embed_path):
            null_text_embed.data = torch.load(embed_path, map_location=accelerator.device).data
            tqdm.write(f'Loaded null_embed from {embed_path}')
        for _ in range(RESUME_STEP * GRAD_ACCUM):
            lr_scheduler.step()
        tqdm.write(f'LR scheduler fast-forwarded to step {RESUME_STEP}')

    def encode_to_latent(images):
        """Encode images with the frozen VAE and apply its scaling factor."""
        images = images.to(dtype=vae.dtype)
        vae_mod = vae.module if hasattr(vae, 'module') else vae
        return vae_mod.encode(images).latent_dist.sample() * vae_mod.config.scaling_factor

    if len(train_dl) == 0:
        raise ValueError('Training dataloader is empty: not enough pairs for a single batch')
    # At least one optimiser step per epoch, even when the loader is shorter
    # than the accumulation window (avoids dividing by zero below).
    steps_per_epoch = max(1, len(train_dl) // GRAD_ACCUM)
    total_epochs = (NUM_TRAIN_STEPS + steps_per_epoch - 1) // steps_per_epoch
    global_step = RESUME_STEP
    avg_loss = 0.0
    overall_bar = tqdm(total=NUM_TRAIN_STEPS,
                       desc=f'Training (step {RESUME_STEP} -> {RESUME_STEP+NUM_TRAIN_STEPS})',
                       disable=not is_main, position=0, dynamic_ncols=True, leave=True)

    # ----------------------------------------------------------------- train
    unet.train()
    for epoch in range(total_epochs):
        accum_loss = 0.0
        batches_seen = 0
        for batch in train_dl:
            with accelerator.accumulate(unet):
                sar_imgs = batch['sar'].to(accelerator.device)
                opt_imgs = batch['optical'].to(accelerator.device)
                opt_latents = encode_to_latent(opt_imgs)
                sar_latents = encode_to_latent(sar_imgs)

                noise = torch.randn_like(opt_latents)
                B = opt_latents.shape[0]
                timesteps = torch.randint(0, NUM_TIMESTEPS, (B,), device=accelerator.device, dtype=torch.long)
                noisy_latents = noise_scheduler.add_noise(opt_latents, noise, timesteps)
                model_input = torch.cat([noisy_latents, sar_latents], dim=1)
                enc_hidden = null_text_embed.to(opt_latents.dtype).expand(B, -1, -1)

                noise_pred = unet(model_input, timesteps, encoder_hidden_states=enc_hidden).sample
                loss_diffusion = F.mse_loss(noise_pred.float(), noise.float())

                loss_color = torch.tensor(0.0, device=accelerator.device)
                loss_perceptual = torch.tensor(0.0, device=accelerator.device)

                if global_step % COLOR_LOSS_FREQ == 0:
                    loss_color = color_supervision_loss(
                        noise_pred, noise, model_input, timesteps, noise_scheduler, vae, opt_latents)

                if global_step % PERCEPTUAL_FREQ == 0:
                    low_mask = timesteps < 300
                    if low_mask.sum() > 0:
                        vae_mod = vae.module if hasattr(vae, 'module') else vae
                        alphas_cp = noise_scheduler.alphas_cumprod.to(accelerator.device)
                        sqrt_a = alphas_cp[timesteps[low_mask]].sqrt().view(-1, 1, 1, 1)
                        sqrt_1ma = (1 - alphas_cp[timesteps[low_mask]]).sqrt().view(-1, 1, 1, 1)
                        x0_pred = (noisy_latents[low_mask] - sqrt_1ma * noise_pred[low_mask]) / (sqrt_a + 1e-8)
                        x0_pred = x0_pred.clamp(-4, 4)
                        # Decode the prediction with autograd enabled; under
                        # no_grad the perceptual term would have no gradient
                        # and could not affect training.
                        pred_img = vae_mod.decode(
                            (x0_pred / vae_mod.config.scaling_factor).to(vae_mod.dtype)).sample
                        with torch.no_grad():
                            gt_img = vae_mod.decode(
                                (opt_latents[low_mask] / vae_mod.config.scaling_factor).to(vae_mod.dtype)).sample
                        loss_perceptual = perceptual_loss_fn(
                            pred_img.float().clamp(-1, 1), gt_img.float().clamp(-1, 1))

                loss = loss_diffusion + COLOR_LOSS_W * loss_color + PERCEPTUAL_W * loss_perceptual
                accum_loss += loss.item()
                batches_seen += 1

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(accelerator.unwrap_model(unet).parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Everything below runs once per optimiser step, not per micro-batch.
            if accelerator.sync_gradients:
                ema_unet.step(accelerator.unwrap_model(unet).parameters())
                global_step += 1
                avg_loss = accum_loss / batches_seen
                lr_now = lr_scheduler.get_last_lr()[0]
                accum_loss = 0.0
                batches_seen = 0

                if is_main:
                    overall_bar.update(1)
                    overall_bar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{lr_now:.1e}', 'step': global_step})

                if global_step % LOG_EVERY == 0 and is_main:
                    accelerator.log({'train/loss': avg_loss, 'train/loss_color': loss_color.item(),
                                     'train/lr': lr_now}, step=global_step)

                if global_step % VIS_EVERY == 0 and is_main:
                    unet.eval()
                    ema_unwrapped = accelerator.unwrap_model(unet)
                    live_params = list(ema_unwrapped.parameters())
                    # Swap the EMA weights in only for the visualisation and
                    # put the live training weights back afterwards; otherwise
                    # training would continue from the EMA weights.
                    ema_unet.store(live_params)
                    ema_unet.copy_to(live_params)
                    try:
                        with torch.no_grad():
                            grid_path = save_validation_grid(
                                val_pairs, ema_unwrapped,
                                vae.module if hasattr(vae, 'module') else vae,
                                null_text_embed, accelerator.device, global_step,
                                os.path.join(OUTPUT_DIR, 'val_grids'),
                            )
                    finally:
                        ema_unet.restore(live_params)
                        unet.train()
                    tqdm.write(f' Val grid -> {grid_path}')

                if global_step % SAVE_EVERY == 0 and is_main:
                    ckpt = os.path.join(OUTPUT_DIR, f'step_{global_step}')
                    os.makedirs(ckpt, exist_ok=True)
                    unwrapped = accelerator.unwrap_model(unet)
                    unwrapped.save_pretrained(os.path.join(ckpt, 'unet_adapter'))
                    # conv_in is trained but is not an adapter target, so it is
                    # written next to the adapter to make the checkpoint resumable.
                    torch.save(unwrapped.conv_in.state_dict(), os.path.join(ckpt, 'conv_in.pt'))
                    torch.save(null_text_embed.detach().cpu(), os.path.join(ckpt, 'null_embed.pt'))
                    tqdm.write(f'step {global_step} | adapter saved -> {ckpt}')

            if global_step >= RESUME_STEP + NUM_TRAIN_STEPS:
                break

        gc.collect()
        torch.cuda.empty_cache()
        if is_main:
            tqdm.write(f'Epoch {epoch+1}/{total_epochs} | step {global_step} | loss {avg_loss:.4f}')
        if global_step >= RESUME_STEP + NUM_TRAIN_STEPS:
            break

    overall_bar.close()

    # ---------------------------------------------------------- final export
    if is_main:
        final_dir = os.path.join(OUTPUT_DIR, f'session_final_step{global_step}')
        os.makedirs(final_dir, exist_ok=True)
        unwrapped = accelerator.unwrap_model(unet)
        # Training is finished, so the EMA weights are copied in permanently
        # and exported.
        ema_unet.copy_to(unwrapped.parameters())
        unwrapped.save_pretrained(os.path.join(final_dir, 'unet_adapter'))
        torch.save(unwrapped.conv_in.state_dict(), os.path.join(final_dir, 'conv_in.pt'))
        merged = unwrapped.merge_and_unload()
        merged.save_pretrained(os.path.join(final_dir, 'unet_full'))
        torch.save(null_text_embed.detach().cpu(), os.path.join(final_dir, 'null_embed.pt'))
        vae_mod = vae.module if hasattr(vae, 'module') else vae
        vae_mod.save_pretrained(os.path.join(final_dir, 'vae'))
        tqdm.write(f'Session complete — step {global_step}. Saved -> {final_dir}')
        tqdm.write(f'To resume next session set:')
        tqdm.write(f' RESUME_STEP = {global_step}')
        tqdm.write(f' RESUME_FROM = "{final_dir}"')

    accelerator.end_training()


if __name__ == '__main__':
    main()