# Uploaded to the Hugging Face Hub by AliMusaRizvi via huggingface_hub
# ("Upload train.py with huggingface_hub", commit 0c04e69, verified).
import os, gc, random, warnings
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, DDIMScheduler
from diffusers.training_utils import EMAModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator
from accelerate.utils import set_seed
from tqdm.auto import tqdm
import wandb
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
# Silence all Python warnings for the rest of the run.
warnings.filterwarnings('ignore')

# ---- Paths / data -----------------------------------------------------------
DATA_ROOT = '/kaggle/input/datasets/shambac/augmented-sentinel-1-2'  # <root>/<season>/*.csv indexes
OUTPUT_DIR = '/kaggle/working/checkpoints'
SD_MODEL_ID = 'runwayml/stable-diffusion-v1-5'  # source of the VAE and UNet
IMG_SIZE = 256  # square side length images are resized to
SEASONS = ('spring', 'summer', 'fall', 'winter')
MAX_PAIRS = 90_000  # total cap, divided evenly across the seasons found
VAL_FRACTION = 0.10

# ---- Adapter ----------------------------------------------------------------
LORA_RANK = 32
LORA_ALPHA = 32
USE_DORA = True

# ---- Session length / resume ------------------------------------------------
NUM_TRAIN_STEPS = 12_000  # optimizer steps to run in this session
RESUME_STEP = 0           # global step reached by the previous session
RESUME_FROM = None        # checkpoint dir (contains unet_adapter/, null_embed.pt) or None

# ---- Optimisation -----------------------------------------------------------
BATCH_SIZE = 8
GRAD_ACCUM = 2
LR = 5e-5
LR_WARMUP = 1_000
MIXED_PRECISION = 'fp16'
GRAD_CKPT = True
EMA_DECAY = 0.9999
NUM_TIMESTEPS = 1_000

# ---- Auxiliary losses: weight, and "apply when global_step % FREQ == 0" -----
COLOR_LOSS_W = 0.5
PERCEPTUAL_W = 0.1
COLOR_LOSS_FREQ = 35
PERCEPTUAL_FREQ = 50

# ---- Logging / checkpoint cadence (in optimizer steps) ----------------------
LOG_EVERY = 100
SAVE_EVERY = 2_000
VIS_EVERY = 3_000
SEED = 332
def collect_pairs(root, seasons, max_pairs=None, seed=None):
    """Build a shuffled list of SAR/optical pair records from per-season CSV indexes.

    Every ``<root>/<season>/*.csv`` file is read. Each CSV must provide
    ``region``, ``s1_fileName`` and ``s2_fileName`` columns. Backslashes in
    file names become forward slashes. S1 names containing ``<word>1s2_`` are
    rewritten to ``<word>_s1_``. Both names are then joined onto ``root``.

    Args:
        root: Dataset root directory.
        seasons: Season sub-folder names to scan. Seasons without CSV files are
            skipped.
        max_pairs: If given, keep at most ``max_pairs // n_active_seasons``
            randomly chosen records per season (balanced sampling). ``None``
            keeps every record.
        seed: Optional seed for a private ``random.Random``. When ``None`` the
            global ``random`` module is used, exactly as before.

    Returns:
        A shuffled list of dicts with keys ``s1``, ``s2``, ``season`` and
        ``region`` (the region is stripped and lower-cased). Empty when no
        season folder contains a CSV file.
    """
    rng = random if seed is None else random.Random(seed)
    root_str = Path(root).as_posix()
    season_buckets = {s: [] for s in seasons}
    for season in seasons:
        csv_files = list((Path(root) / season).glob('*.csv'))
        if not csv_files:
            continue
        df = pd.concat([pd.read_csv(f) for f in csv_files], ignore_index=True)
        df['season'] = season
        df['region'] = df['region'].str.strip().str.lower()
        df['s1_fileName'] = (df['s1_fileName']
                             .str.replace('\\', '/', regex=False)
                             # Repair S1 names of the form '<word>1s2_' -> '<word>_s1_'.
                             .str.replace(r'(\w+?)1s2_', r'\1_s1_', regex=True))
        df['s2_fileName'] = df['s2_fileName'].str.replace('\\', '/', regex=False)
        df['s1'] = root_str + '/' + df['s1_fileName']
        df['s2'] = root_str + '/' + df['s2_fileName']
        season_buckets[season] = df[['s1', 's2', 'season', 'region']].to_dict('records')

    active = [s for s in seasons if season_buckets[s]]
    if not active:
        # Nothing to sample. Also avoids dividing by zero in the balanced branch.
        return []

    if max_pairs is None:
        pairs = [p for s in active for p in season_buckets[s]]
    else:
        per_season = max_pairs // len(active)
        pairs = []
        for s in active:
            bucket = season_buckets[s].copy()
            rng.shuffle(bucket)
            pairs.extend(bucket[:per_season])
    rng.shuffle(pairs)
    return pairs
class SAROpticalDataset(Dataset):
    """Paired SAR / optical images for SAR-to-optical translation.

    ``pairs`` is a list of records with ``'s1'`` (SAR path), ``'s2'`` (optical
    path), ``'season'`` and ``'region'`` keys, as produced by ``collect_pairs``.

    Each image is resized (bilinear) to ``img_size`` x ``img_size`` when needed.
    It is returned as a float32 CHW tensor scaled to [-1, 1]. The single-channel
    SAR image is replicated to 3 channels. With ``augment=True`` the same random
    horizontal flip, vertical flip and 90-degree rotation are applied to both
    images, so they stay pixel-aligned.
    """

    def __init__(self, pairs, img_size=256, augment=True):
        self.pairs = pairs
        self.img_size = img_size
        self.augment = augment

    def __len__(self):
        return len(self.pairs)

    def _read_resized(self, path, mode):
        """Decode ``path`` in PIL ``mode`` and return a float32 array scaled to [0, 1]."""
        # The context manager releases the file handle deterministically instead of
        # leaving it to garbage collection (Pillow only auto-closes after load() for
        # some formats).
        with Image.open(path) as src:
            img = src.convert(mode)
        if img.size != (self.img_size, self.img_size):
            img = img.resize((self.img_size, self.img_size), Image.BILINEAR)
        return np.array(img, dtype=np.float32) / 255.0

    @staticmethod
    def _to_signed_chw(arr):
        """Convert an HWC float array in [0, 1] to a CHW tensor in [-1, 1]."""
        return torch.from_numpy(arr).permute(2, 0, 1) * 2.0 - 1.0

    def _load_sar(self, path):
        arr = self._read_resized(path, 'L')
        # Replicate the grey channel so the SAR image fits the 3-channel VAE input.
        return self._to_signed_chw(np.stack([arr, arr, arr], axis=2))

    def _load_optical(self, path):
        return self._to_signed_chw(self._read_resized(path, 'RGB'))

    def __getitem__(self, idx):
        pair = self.pairs[idx]
        sar = self._load_sar(pair['s1'])
        opt = self._load_optical(pair['s2'])
        if self.augment:
            # Identical geometric transforms on both images keep them aligned.
            if random.random() > 0.5:
                sar = TF.hflip(sar)
                opt = TF.hflip(opt)
            if random.random() > 0.5:
                sar = TF.vflip(sar)
                opt = TF.vflip(opt)
            k = random.randint(0, 3)
            if k > 0:
                sar = torch.rot90(sar, k, [1, 2])
                opt = torch.rot90(opt, k, [1, 2])
        return {'sar': sar, 'optical': opt, 'season': pair['season'], 'region': pair['region']}
class VGGPerceptualLoss(nn.Module):
    """L1 distance between early VGG-16 feature maps of two images.

    Inputs are expected in [-1, 1]. They are mapped to [0, 1], normalised with
    the ImageNet mean/std buffers, and passed through the first 9 layers of
    ``vgg16().features`` (ImageNet weights, frozen, eval mode).
    """

    def __init__(self):
        super().__init__()
        import torchvision.models as models

        backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        trunk = list(backbone.features)
        self.features = nn.Sequential(*trunk[:9]).eval()
        # The loss network is never trained.
        self.requires_grad_(False)
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
        imagenet_std = torch.tensor([0.229, 0.224, 0.225])
        self.register_buffer('mean', imagenet_mean.reshape(1, 3, 1, 1))
        self.register_buffer('std', imagenet_std.reshape(1, 3, 1, 1))

    def _normalize(self, x):
        """Map a [-1, 1] image to ImageNet-normalised VGG input."""
        return (x * 0.5 + 0.5 - self.mean) / self.std

    def forward(self, pred, target):
        pred_feats = self.features(self._normalize(pred))
        target_feats = self.features(self._normalize(target))
        return F.l1_loss(pred_feats, target_feats)
def color_supervision_loss(noise_pred, noise, noisy_latents, timesteps, scheduler, vae, opt_latents):
    """Pixel-space L1 between the decoded x0 estimate and the decoded target latent.

    Only samples with ``timesteps < 500`` take part. For those samples, x0 is
    recovered from the epsilon prediction as
    ``x0 = (x_t - sqrt(1 - a_bar) * eps) / sqrt(a_bar)`` and clamped to
    [-4, 4]. The x0 estimate and the ground-truth latent are both decoded by
    the VAE, and the loss is the L1 distance between the two images.

    Args:
        noise_pred: UNet epsilon prediction.
        noise: True noise. Unused; kept for call compatibility.
        noisy_latents: UNet input. Only its first 4 channels (the noisy target
            latent) are used, so the concatenated SAR-conditioned input can be
            passed directly.
        timesteps: Per-sample integer diffusion timesteps.
        scheduler: Object exposing ``alphas_cumprod`` (e.g. the DDPM scheduler).
        vae: VAE, optionally wrapped in an object with a ``.module`` attribute.
        opt_latents: Clean, scaled target latents.

    Returns:
        Scalar loss tensor that is differentiable w.r.t. ``noise_pred``. It is
        a constant 0 when no sample has ``timesteps < 500``.
    """
    low_noise_mask = timesteps < 500
    if not bool(low_noise_mask.any()):
        return torch.tensor(0.0, device=noise_pred.device)
    vae_mod = vae.module if hasattr(vae, 'module') else vae
    scale = vae_mod.config.scaling_factor
    alphas_cumprod = scheduler.alphas_cumprod.to(noise_pred.device)
    alpha_bar = alphas_cumprod[timesteps[low_noise_mask]].view(-1, 1, 1, 1)
    noisy_sub = noisy_latents[low_noise_mask, :4]
    eps_sub = noise_pred[low_noise_mask]
    x0_pred = (noisy_sub - (1 - alpha_bar).sqrt() * eps_sub) / (alpha_bar.sqrt() + 1e-8)
    x0_pred = x0_pred.clamp(-4, 4)
    # Decode the prediction WITH autograd so this term actually reaches the UNet.
    # Decoding it under no_grad would make the loss a constant with zero gradient.
    # NOTE(review): backprop through the VAE decoder is memory-hungry; lower the
    # batch size or raise COLOR_LOSS_FREQ if this OOMs.
    pred_img = vae_mod.decode((x0_pred / scale).to(vae_mod.dtype)).sample
    with torch.no_grad():
        gt_img = vae_mod.decode((opt_latents[low_noise_mask] / scale).to(vae_mod.dtype)).sample
    return F.l1_loss(pred_img.float(), gt_img.float())
@torch.no_grad()
def translate_single(sar_tensor, unet_m, vae_m, null_embed, device, num_steps=50):
    """Sample one optical image conditioned on a single SAR image using DDIM.

    Args:
        sar_tensor: CHW SAR tensor in [-1, 1] (as returned by ``SAROpticalDataset``).
        unet_m: UNet that takes ``cat([noisy_latents, sar_latents], dim=1)``.
        vae_m: VAE, optionally wrapped in an object with a ``.module`` attribute.
            Its inputs are cast to float16.
        null_embed: Learned conditioning tensor used as ``encoder_hidden_states``.
        device: Device to run sampling on.
        num_steps: Number of DDIM denoising steps.

    Returns:
        CHW float32 CPU tensor clamped to [-1, 1].
    """
    half = torch.float16
    vae_core = getattr(vae_m, 'module', vae_m)
    scale = vae_core.config.scaling_factor
    sampler = DDIMScheduler(num_train_timesteps=NUM_TIMESTEPS, beta_schedule='scaled_linear', prediction_type='epsilon')

    # The conditioning latent is the posterior mean (deterministic), not a sample.
    sar_batch = sar_tensor.unsqueeze(0).to(device, dtype=half)
    cond_latent = vae_core.encode(sar_batch).latent_dist.mean * scale
    latents = torch.randn_like(cond_latent) * sampler.init_noise_sigma
    context = null_embed.to(device, dtype=half).expand(1, -1, -1).float()

    sampler.set_timesteps(num_steps)
    for t in sampler.timesteps:
        # UNet inputs are cast to float32; the latents given to the scheduler stay float16.
        unet_in = torch.cat([latents, cond_latent], dim=1).float()
        eps = unet_m(unet_in, t.unsqueeze(0).to(device), encoder_hidden_states=context).sample
        latents = sampler.step(eps.to(half), t, latents).prev_sample

    decoded = vae_core.decode(latents / scale).sample
    return decoded.squeeze(0).clamp(-1, 1).float().cpu()
def denorm(t):
    """Map a CHW tensor in [-1, 1] to an HWC NumPy array in [0, 1] for plotting."""
    unit_range = t.clamp(-1, 1).add(1).div(2)
    return unit_range.permute(1, 2, 0).numpy()
def save_validation_grid(val_pairs, unet_m, vae_m, null_embed, device, step, out_dir):
    """Render a SAR | prediction | ground-truth grid and save it as a PNG.

    One random validation pair is drawn per season in ``SEASONS``; seasons
    absent from ``val_pairs`` are skipped. Each pair is loaded without
    augmentation, translated with ``translate_single`` and drawn as one row of
    three panels. The figure is written to ``<out_dir>/val_step_<step:06d>.png``.

    Returns:
        The written file path, or ``None`` when ``val_pairs`` contains no pair
        from any season in ``SEASONS`` (nothing to plot).
    """
    selected = []
    for season in SEASONS:
        pool = [p for p in val_pairs if p['season'] == season]
        if pool:
            selected.append(random.choice(pool))
    selected = selected[:8]
    if not selected:
        return None

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i, j] always works.
    fig, axes = plt.subplots(len(selected), 3, figsize=(12, 3.5 * len(selected)), squeeze=False)
    try:
        for i, pair in enumerate(selected):
            item = SAROpticalDataset([pair], img_size=IMG_SIZE, augment=False)[0]
            pred = translate_single(item['sar'], unet_m, vae_m, null_embed, device)
            axes[i, 0].imshow(denorm(item['sar'])[:, :, 0], cmap='gray')
            axes[i, 0].set_title(f"SAR | {pair['season']}", fontsize=8)
            axes[i, 1].imshow(denorm(pred))
            axes[i, 1].set_title(f'Predicted (step {step})', fontsize=8)
            axes[i, 2].imshow(denorm(item['optical']))
            axes[i, 2].set_title('Ground Truth', fontsize=8)
            for ax in axes[i]:
                ax.axis('off')
        fig.tight_layout()
        path = os.path.join(out_dir, f'val_step_{step:06d}.png')
        fig.savefig(path, dpi=100, bbox_inches='tight')
    finally:
        # Always release the figure, even if translation fails, so figures do not pile up.
        plt.close(fig)
    return path
def _split_train_val(all_pairs):
    """Split pairs into (train, val), stratified by season+region when possible.

    Falls back to an unstratified split (same ``VAL_FRACTION`` and seed) when
    scikit-learn rejects the stratification. That happens, for example, when a
    season/region group has a single member.
    """
    strata = [f"{p['season']}_{p['region']}" for p in all_pairs]
    try:
        return train_test_split(all_pairs, test_size=VAL_FRACTION, stratify=strata, random_state=SEED)
    except ValueError as err:
        tqdm.write(f'Stratified split failed ({err}); falling back to an unstratified split')
        return train_test_split(all_pairs, test_size=VAL_FRACTION, random_state=SEED)


def _save_adapter_checkpoint(unwrapped_unet, null_text_embed, ckpt_dir):
    """Write the PEFT adapter, the widened ``conv_in`` and the null embedding to ``ckpt_dir``."""
    os.makedirs(ckpt_dir, exist_ok=True)
    unwrapped_unet.save_pretrained(os.path.join(ckpt_dir, 'unet_adapter'))
    # conv_in is trained but is not a LoRA/DoRA module, so the adapter files above do
    # not contain it. Without this file a resumed run would restart with zeroed SAR
    # input weights.
    torch.save(unwrapped_unet.conv_in.state_dict(), os.path.join(ckpt_dir, 'conv_in.pt'))
    torch.save(null_text_embed.detach().cpu(), os.path.join(ckpt_dir, 'null_embed.pt'))


def _load_resume_state(unet, null_text_embed, resume_dir):
    """Restore the adapter, conv_in and null-embedding weights saved in ``resume_dir``.

    Missing files are skipped. Checkpoints written by older versions of this
    script have no ``conv_in.pt``.
    """
    adapter_path = os.path.join(resume_dir, 'unet_adapter')
    conv_path = os.path.join(resume_dir, 'conv_in.pt')
    embed_path = os.path.join(resume_dir, 'null_embed.pt')
    if os.path.exists(adapter_path):
        unet.load_adapter(adapter_path, adapter_name='default', is_trainable=True)
        tqdm.write(f'Loaded adapter from {adapter_path}')
    if os.path.exists(conv_path):
        unet.conv_in.load_state_dict(torch.load(conv_path, map_location='cpu'))
        tqdm.write(f'Loaded conv_in from {conv_path}')
    else:
        tqdm.write(f'No {conv_path}; conv_in keeps its fresh initialisation')
    if os.path.exists(embed_path):
        saved = torch.load(embed_path, map_location='cpu')
        # In-place copy keeps the exact Parameter object that the optimizer holds.
        with torch.no_grad():
            null_text_embed.copy_(saved.detach())
        tqdm.write(f'Loaded null_embed from {embed_path}')


def main():
    """Fine-tune the SD-1.5 UNet to translate Sentinel-1 SAR into optical imagery.

    What is trained:
      - DoRA/LoRA adapters on the attention and feed-forward projections.
      - A ``conv_in`` widened to 8 channels (noisy optical latent | SAR latent).
      - A learned "null" text embedding.

    Objective: epsilon-MSE, plus (on steps where ``global_step`` is a multiple of
    the configured frequency) a colour L1 loss and a VGG perceptual loss on
    decoded x0 predictions.

    Checkpoints go to ``OUTPUT_DIR``. Set ``RESUME_STEP`` / ``RESUME_FROM`` to
    continue from a previous session.
    """
    accelerator = Accelerator(
        mixed_precision=MIXED_PRECISION,
        gradient_accumulation_steps=GRAD_ACCUM,
        log_with='wandb',
        project_dir=OUTPUT_DIR,
    )
    is_main = accelerator.is_main_process
    # Use the same seed on every rank while the data split and initial weights are
    # built, so all processes share one train/val split and start from identical
    # weights. The null embedding is not DDP-wrapped, so nothing else would sync it.
    # Per-rank seeding for the noise draws happens right before training.
    set_seed(SEED)
    if is_main:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        os.makedirs(os.path.join(OUTPUT_DIR, 'val_grids'), exist_ok=True)
        api_key = os.environ.get('WANDB_API_KEY')
        if api_key:
            wandb.login(key=api_key, relogin=True)
        accelerator.init_trackers(
            project_name='sar-optical-diffusion',
            config={'lr': LR, 'batch_size': BATCH_SIZE, 'lora_rank': LORA_RANK,
                    'steps': NUM_TRAIN_STEPS, 'resume_step': RESUME_STEP,
                    'color_loss_w': COLOR_LOSS_W, 'perceptual_w': PERCEPTUAL_W,
                    'effective_batch': BATCH_SIZE * GRAD_ACCUM * accelerator.num_processes},
            init_kwargs={'wandb': {
                'name': f'dora-r{LORA_RANK}-step{RESUME_STEP}-to-{RESUME_STEP+NUM_TRAIN_STEPS}',
                'tags': ['sentinel', 'SAR', 'diffusion', 'DoRA', 'color-supervision'],
                'resume': 'allow',
            }}
        )

    # ---- Data ----------------------------------------------------------------
    all_pairs = collect_pairs(DATA_ROOT, SEASONS, MAX_PAIRS)
    if not all_pairs:
        raise FileNotFoundError(f'No SAR/optical pairs found under {DATA_ROOT} for seasons {SEASONS}')
    train_pairs, val_pairs = _split_train_val(all_pairs)
    train_ds = SAROpticalDataset(train_pairs, img_size=IMG_SIZE, augment=True)
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True, drop_last=True, persistent_workers=False)

    # ---- Models --------------------------------------------------------------
    vae = AutoencoderKL.from_pretrained(SD_MODEL_ID, subfolder='vae', torch_dtype=torch.float16)
    vae.requires_grad_(False)
    unet = UNet2DConditionModel.from_pretrained(SD_MODEL_ID, subfolder='unet', torch_dtype=torch.float32)
    # Widen conv_in from 4 to 8 input channels: [noisy optical latent | SAR latent].
    # The SAR half starts at zero, so the pretrained behaviour is unchanged at step 0.
    old_conv = unet.conv_in
    new_conv = nn.Conv2d(8, old_conv.out_channels, old_conv.kernel_size,
                         old_conv.stride, old_conv.padding, bias=(old_conv.bias is not None))
    with torch.no_grad():
        new_conv.weight[:, :4].copy_(old_conv.weight)
        new_conv.weight[:, 4:].zero_()
        if old_conv.bias is not None:
            new_conv.bias.copy_(old_conv.bias)
    unet.conv_in = new_conv
    # register_to_config updates the config that save_pretrained serialises. Plain
    # attribute assignment on unet.config only sets a Python attribute, so
    # unet_full/config.json would still report in_channels=4.
    unet.register_to_config(in_channels=8)
    text_embed_dim = unet.config.cross_attention_dim
    # Create the learnable null embedding directly on the training device. The
    # optimizer must hold the very tensor used in the forward pass; calling .to()
    # after building the optimizer returns a new tensor that is never updated.
    null_text_embed = nn.Parameter((torch.randn(1, 77, text_embed_dim) * 0.01).to(accelerator.device))

    lora_cfg = LoraConfig(
        r=LORA_RANK, lora_alpha=LORA_ALPHA, use_dora=USE_DORA,
        init_lora_weights='gaussian',
        target_modules=['to_q', 'to_k', 'to_v', 'to_out.0', 'add_q_proj', 'add_k_proj', 'add_v_proj',
                        'ff.net.0.proj', 'ff.net.2'],
        lora_dropout=0.0, bias='none',
    )
    unet = get_peft_model(unet, lora_cfg)
    # get_peft_model freezes non-adapter weights; the widened conv_in must keep training.
    unet.conv_in.weight.requires_grad_(True)
    if unet.conv_in.bias is not None:
        unet.conv_in.bias.requires_grad_(True)
    if GRAD_CKPT:
        unet.enable_gradient_checkpointing()
    perceptual_loss_fn = VGGPerceptualLoss()
    # NOTE(review): beta_start/beta_end are left at DDPMScheduler defaults rather than
    # SD-1.5's own scheduler values; kept for checkpoint compatibility, confirm intent.
    noise_scheduler = DDPMScheduler(num_train_timesteps=NUM_TIMESTEPS,
                                    beta_schedule='scaled_linear', prediction_type='epsilon')

    # Load the previous session BEFORE the EMA snapshot, so the EMA starts from those weights.
    if RESUME_FROM is not None:
        _load_resume_state(unet, null_text_embed, RESUME_FROM)

    # EMA only over trainable tensors. Frozen weights never change, so shadowing them
    # would just cost memory and a full-model copy every step.
    unet_trainable = [p for p in unet.parameters() if p.requires_grad]
    ema_unet = EMAModel(unet_trainable, decay=EMA_DECAY)
    trainable_params = unet_trainable + [null_text_embed]
    total_steps_across_sessions = RESUME_STEP + NUM_TRAIN_STEPS
    optimizer = torch.optim.AdamW(trainable_params, lr=LR, betas=(0.9, 0.999), weight_decay=1e-2, eps=1e-8)
    # The Accelerate-prepared scheduler advances only on real optimizer steps,
    # num_processes times per step. Lengths must therefore scale by num_processes,
    # not GRAD_ACCUM; with GRAD_ACCUM=2 on one process the cosine would stop at its midpoint.
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=LR_WARMUP * accelerator.num_processes,
        num_training_steps=total_steps_across_sessions * accelerator.num_processes,
    )
    if RESUME_FROM is not None:
        # Fast-forward the raw scheduler (before prepare) to the previous session's step.
        for _ in range(RESUME_STEP * accelerator.num_processes):
            lr_scheduler.step()
        tqdm.write(f'LR scheduler fast-forwarded to step {RESUME_STEP}')

    unet, vae, optimizer, train_dl, lr_scheduler = accelerator.prepare(
        unet, vae, optimizer, train_dl, lr_scheduler
    )
    perceptual_loss_fn = perceptual_loss_fn.to(accelerator.device)
    ema_unet.to(accelerator.device)
    # Distinct noise / timestep streams per rank from here on.
    set_seed(SEED + accelerator.process_index)

    @torch.no_grad()
    def encode_to_latent(images):
        images = images.to(dtype=vae.dtype)
        vae_mod = vae.module if hasattr(vae, 'module') else vae
        return vae_mod.encode(images).latent_dist.sample() * vae_mod.config.scaling_factor

    # ---- Training loop ---------------------------------------------------------
    steps_per_epoch = max(1, len(train_dl) // GRAD_ACCUM)
    total_epochs = (NUM_TRAIN_STEPS + steps_per_epoch - 1) // steps_per_epoch
    target_step = RESUME_STEP + NUM_TRAIN_STEPS
    global_step = RESUME_STEP
    avg_loss = 0.0
    overall_bar = tqdm(total=NUM_TRAIN_STEPS, desc=f'Training (step {RESUME_STEP} -> {RESUME_STEP+NUM_TRAIN_STEPS})',
                       disable=not is_main, position=0, dynamic_ncols=True, leave=True)
    unet.train()
    for epoch in range(total_epochs):
        accum_loss = 0.0
        batches_seen = 0
        for batch in train_dl:
            with accelerator.accumulate(unet):
                sar_imgs = batch['sar'].to(accelerator.device)
                opt_imgs = batch['optical'].to(accelerator.device)
                opt_latents = encode_to_latent(opt_imgs)
                sar_latents = encode_to_latent(sar_imgs)
                noise = torch.randn_like(opt_latents)
                B = opt_latents.shape[0]
                timesteps = torch.randint(0, NUM_TIMESTEPS, (B,), device=accelerator.device, dtype=torch.long)
                noisy_latents = noise_scheduler.add_noise(opt_latents, noise, timesteps)
                model_input = torch.cat([noisy_latents, sar_latents], dim=1)
                enc_hidden = null_text_embed.to(opt_latents.dtype).expand(B, -1, -1)
                noise_pred = unet(model_input, timesteps, encoder_hidden_states=enc_hidden).sample
                loss_diffusion = F.mse_loss(noise_pred.float(), noise.float())
                loss_color = torch.tensor(0.0, device=accelerator.device)
                loss_perceptual = torch.tensor(0.0, device=accelerator.device)
                if global_step % COLOR_LOSS_FREQ == 0:
                    loss_color = color_supervision_loss(
                        noise_pred, noise, model_input, timesteps, noise_scheduler, vae, opt_latents)
                if global_step % PERCEPTUAL_FREQ == 0:
                    low_mask = timesteps < 300
                    if low_mask.any():
                        vae_mod = vae.module if hasattr(vae, 'module') else vae
                        scale = vae_mod.config.scaling_factor
                        alphas_cp = noise_scheduler.alphas_cumprod.to(accelerator.device)
                        a_bar = alphas_cp[timesteps[low_mask]].view(-1, 1, 1, 1)
                        x0_pred = (noisy_latents[low_mask] - (1 - a_bar).sqrt() * noise_pred[low_mask]) / (a_bar.sqrt() + 1e-8)
                        x0_pred = x0_pred.clamp(-4, 4)
                        # Decode the prediction with autograd so the perceptual term trains the
                        # UNet. Only the ground-truth decode is gradient-free.
                        pred_img = vae_mod.decode((x0_pred / scale).to(vae_mod.dtype)).sample
                        with torch.no_grad():
                            gt_img = vae_mod.decode((opt_latents[low_mask] / scale).to(vae_mod.dtype)).sample
                        loss_perceptual = perceptual_loss_fn(pred_img.float().clamp(-1, 1), gt_img.float().clamp(-1, 1))
                loss = loss_diffusion + COLOR_LOSS_W * loss_color + PERCEPTUAL_W * loss_perceptual
                accum_loss += loss.item()
                batches_seen += 1
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    # The null embedding lives outside the DDP module, so its gradient is
                    # averaged across ranks by hand to keep every copy identical.
                    if accelerator.num_processes > 1 and null_text_embed.grad is not None:
                        null_text_embed.grad = accelerator.reduce(null_text_embed.grad, reduction='mean')
                    accelerator.clip_grad_norm_(trainable_params, 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
            if accelerator.sync_gradients:
                ema_unet.step(unet_trainable)
                global_step += 1
                avg_loss = accum_loss / batches_seen
                lr_now = lr_scheduler.get_last_lr()[0]
                accum_loss = 0.0
                batches_seen = 0
                if is_main:
                    overall_bar.update(1)
                    overall_bar.set_postfix({'loss': f'{avg_loss:.4f}', 'lr': f'{lr_now:.1e}', 'step': global_step})
                if global_step % LOG_EVERY == 0 and is_main:
                    accelerator.log({'train/loss': avg_loss,
                                     'train/loss_diffusion': loss_diffusion.item(),
                                     'train/loss_color': loss_color.item(),
                                     'train/loss_perceptual': loss_perceptual.item(),
                                     'train/lr': lr_now}, step=global_step)
                if global_step % VIS_EVERY == 0 and is_main:
                    unet.eval()
                    unwrapped = accelerator.unwrap_model(unet)
                    # Evaluate with EMA weights, then put the live training weights back.
                    # Without restore, the EMA copy would overwrite the model being optimised
                    # (and, under DDP, desynchronise this rank from the others).
                    ema_unet.store(unet_trainable)
                    ema_unet.copy_to(unet_trainable)
                    try:
                        grid_path = save_validation_grid(
                            val_pairs, unwrapped,
                            vae.module if hasattr(vae, 'module') else vae,
                            null_text_embed, accelerator.device, global_step,
                            os.path.join(OUTPUT_DIR, 'val_grids'),
                        )
                    finally:
                        ema_unet.restore(unet_trainable)
                        unet.train()
                    tqdm.write(f' Val grid -> {grid_path}')
                if global_step % SAVE_EVERY == 0 and is_main:
                    ckpt = os.path.join(OUTPUT_DIR, f'step_{global_step}')
                    _save_adapter_checkpoint(accelerator.unwrap_model(unet), null_text_embed, ckpt)
                    tqdm.write(f'step {global_step} | adapter saved -> {ckpt}')
                if global_step >= target_step:
                    break
        gc.collect()
        torch.cuda.empty_cache()
        if is_main:
            tqdm.write(f'Epoch {epoch+1}/{total_epochs} | step {global_step} | loss {avg_loss:.4f}')
        if global_step >= target_step:
            break
    overall_bar.close()

    # ---- Final export (EMA weights) ----------------------------------------------
    if is_main:
        final_dir = os.path.join(OUTPUT_DIR, f'session_final_step{global_step}')
        unwrapped = accelerator.unwrap_model(unet)
        ema_unet.copy_to(unet_trainable)
        _save_adapter_checkpoint(unwrapped, null_text_embed, final_dir)
        merged = unwrapped.merge_and_unload()
        merged.save_pretrained(os.path.join(final_dir, 'unet_full'))
        vae_mod = vae.module if hasattr(vae, 'module') else vae
        vae_mod.save_pretrained(os.path.join(final_dir, 'vae'))
        tqdm.write(f'Session complete — step {global_step}. Saved -> {final_dir}')
        tqdm.write(f'To resume next session set:')
        tqdm.write(f' RESUME_STEP = {global_step}')
        tqdm.write(f' RESUME_FROM = "{final_dir}"')
    accelerator.end_training()
if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()