import io
import os
import random
from typing import Optional, List, Tuple, Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
# Keep the logger in the same `pytorch_lightning` namespace as `pl` above, so the
# isinstance check in log_similarity_matrix and the Trainer both see the same class.
from pytorch_lightning.loggers import WandbLogger

import wandb
import matplotlib.pyplot as plt
from PIL import Image

from datasets import OSMDataset
from image_encoder import SatlasPretrainEncoder
from text_encoder import TextEncoder
from orthogonal_adamw import OrthogonalAdamW
from configs.config_e5 import config
from utils import generate_tag_poly_pairs


# Standard InfoNCE loss over matched image/text pairs: the i-th image feature should
# score highest against the i-th text feature. Returns both the loss and the raw
# logits so callers can log the similarity matrix.
def contrastive_loss(image_feats: torch.Tensor,
                     text_feats: torch.Tensor,
                     logit_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = torch.matmul(image_feats, text_feats.t()) * logit_scale
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels), logits
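
# A hedged sketch, not used by this script: CLIP-style training often uses a symmetric
# InfoNCE loss that averages the image->text and text->image directions. Swapping this
# in for contrastive_loss() above is an assumption for illustration only, not something
# the training loop below does.
def symmetric_contrastive_loss(image_feats: torch.Tensor,
                               text_feats: torch.Tensor,
                               logit_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = torch.matmul(image_feats, text_feats.t()) * logit_scale
    labels = torch.arange(logits.size(0), device=logits.device)
    loss_i2t = F.cross_entropy(logits, labels)      # image -> text direction
    loss_t2i = F.cross_entropy(logits.t(), labels)  # text -> image direction
    return 0.5 * (loss_i2t + loss_t2i), logits
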
class OSMBind(pl.LightningModule):
    def __init__(self, train_dataset=None, val_dataset=None, **kwargs):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        self.image_encoder = SatlasPretrainEncoder(
            fpn=True,
            model_name="Aerial_SwinB_SI",
            out_dim=768,
            num_extra_fpn_layers=4,
        )

        # Keys match the kwargs dict built in __main__ below.
        taglist_vocab = torch.load(kwargs.get("taglist_vocab_path"), weights_only=True)
        tag_vocab_inverted = torch.load(kwargs.get("tag_vocab_path"), weights_only=True)  # str -> int
        tag_vocab = {v: k for k, v in tag_vocab_inverted.items()}  # int -> str
        self.text_encoder = TextEncoder(taglist_vocab, tag_vocab, model_name=kwargs.get("text_backbone"))
        # for param in self.text_encoder.parameters():
        #     param.requires_grad = False

        # Learnable temperature, initialised to the CLIP default of 1/0.07
        # (a softer scale for initially misaligned encoders).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.batch_size = kwargs.get("batch_size")
        self.num_workers = kwargs.get("num_workers")
        self.lr = kwargs.get("lr", 1e-4)
        self.num_samples = kwargs.get("num_samples")  # number of OSM classes sampled per batch
        self.ort_grad = kwargs.get("ort_grad")

    def forward(self, sat_img: torch.Tensor, pixel_tensor: torch.Tensor):
        full_image_feats = self.image_encoder(sat_img)  # [B, D, H', W']
        sampled_tag_tensor, image_poly_feats = generate_tag_poly_pairs(
            pixel_tensor, full_image_feats, K=self.num_samples
        )  # [K], [K, D]
        text_sampled_feats = self.text_encoder(sampled_tag_tensor)  # [K, D]
        return image_poly_feats, text_sampled_feats  # [K, D], [K, D]

    def shared_step(self, batch):
        sat_img, pixel_tensor = batch
        image_poly_feats, text_sampled_feats = self(sat_img, pixel_tensor)  # [K, D], [K, D]

        # Contrastive loss over the whole batch of sampled (tag, polygon) pairs.
        image_feats_norm = F.normalize(image_poly_feats, dim=1)
        text_feats_norm = F.normalize(text_sampled_feats, dim=1)
        logit_scale = self.logit_scale.exp()
        loss, logits = contrastive_loss(image_feats_norm, text_feats_norm, logit_scale=logit_scale)
        return loss, logits

    def log_similarity_matrix(self, logits):
        mat = logits.detach().cpu().numpy()
        fig, ax = plt.subplots(figsize=(6, 6))
        cax = ax.matshow(mat, cmap="viridis")
        fig.colorbar(cax)
        ax.set_xlabel("Text samples")
        ax.set_ylabel("Image samples")
        ax.set_title("Similarity Matrix")

        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        plt.close(fig)

        # Convert the buffer to a PIL image so wandb can log it.
        image = Image.open(buf)
        if isinstance(self.logger, WandbLogger):
            self.logger.experiment.log({
                "similarity_matrix": wandb.Image(image),
                "global_step": self.global_step,
            })

    def training_step(self, batch, batch_idx):
        loss, logits = self.shared_step(batch)
        self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        self.log('temperature', self.logit_scale.exp().item(), prog_bar=True, on_epoch=True)

        if self.global_step % 500 == 0:
            self.log_similarity_matrix(logits)

        # Log a histogram of similarity scores every step.
        if self.logger and hasattr(self.logger.experiment, "log"):
            self.logger.experiment.log({"logits_hist": wandb.Histogram(logits.detach().cpu().numpy())})

        # Optionally log mean and max of the logits for monitoring.
        self.log("logits_mean", logits.mean(), on_step=True, on_epoch=False, prog_bar=True)
        self.log("logits_max", logits.max(), on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Clamp the learnable temperature to [1/1.0, 1/0.01] in log space.
        min_log_scale = np.log(1 / 1.0)
        max_log_scale = np.log(1 / 0.01)
        self.logit_scale.data.clamp_(min_log_scale, max_log_scale)

    def on_after_backward(self):
        if self.global_rank == 0 and self.current_epoch == 0:
            for name, param in self.named_parameters():
                if param.requires_grad and param.grad is None:
                    print(f"⚠️ Unused parameter: {name}")

    def validation_step(self, batch, batch_idx):
        loss, _ = self.shared_step(batch)
        self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        return loss

    def train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError("This model was initialized without a training dataset.")
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                          shuffle=True, persistent_workers=False)

    def val_dataloader(self):
        if self.val_dataset is None:
            raise ValueError("This model was initialized without a validation dataset.")
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers,
                          shuffle=False, persistent_workers=False)

    def configure_optimizers(self):
        params = self.parameters()
        if self.ort_grad:
            self.optim = OrthogonalAdamW(
                params,
                lr=self.lr,
                betas=(0.9, 0.98),
                beta_ort=0.9,
                eps=1e-6,
                weight_decay=0.01,
            )
        else:
            self.optim = torch.optim.AdamW(
                params,
                lr=self.lr,
                betas=(0.9, 0.98),
                eps=1e-6,
                weight_decay=0.01,
            )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=self.optim,
            T_0=20,
        )
        return [self.optim], [self.scheduler]

    def sim_map_inf(self, sat_image: torch.Tensor, raw_text: str) -> torch.Tensor:
        """
        Args:
            sat_image: [1, 3, 512, 512] tensor (already normalized)
            raw_text: str, e.g., "building"

        Returns:
            sim_map: [512, 512] similarity map between image and text embedding
        """
        assert sat_image.dim() == 4 and sat_image.size(0) == 1, "Expected input of shape [1, 3, H, W]"

        with torch.no_grad():
            # Image features: extract the spatial map and upsample to the input resolution.
            feat_map = self.image_encoder(sat_image)  # [1, D, H', W']
            feat_map = feat_map.squeeze(0)  # [D, H', W']
            feat_map_upsampled = F.interpolate(
                feat_map.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False
            ).squeeze(0)  # [D, 512, 512]
            feat_map_upsampled = F.normalize(feat_map_upsampled, dim=0)  # [D, 512, 512]

            # Text features.
            text_feat = self.text_encoder.encode_raw_text(raw_text)
            text_feat = F.normalize(text_feat, dim=0)

            # Cosine similarity per pixel.
            sim_map = torch.einsum('chw,c->hw', feat_map_upsampled, text_feat)  # [512, 512]
        return sim_map

    def encode_text(self, text: str) -> torch.Tensor:
        with torch.no_grad():
            return self.text_encoder.encode_raw_text(text)

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.image_encoder(image)
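
# A hedged usage sketch, not called anywhere in this script: querying a trained
# OSMBind model for a per-pixel similarity map and saving it as a heatmap. It assumes
# `sat_image` is already normalized like the training data; the default query and
# output path are placeholders chosen for illustration.
def save_similarity_heatmap(model: "OSMBind", sat_image: torch.Tensor,
                            query: str = "building", out_path: str = "sim_map.png") -> None:
    model.eval()
    sim_map = model.sim_map_inf(sat_image, query)  # [512, 512]
    plt.imsave(out_path, sim_map.cpu().numpy(), cmap="viridis")
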
def seed_everything(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs for reproducible runs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)


if __name__ == '__main__':
    import warnings
    warnings.filterwarnings("ignore")
    torch.set_warn_always(False)

    seed_everything()

    train_dataset = OSMDataset(metadata_path=config.train_csv,
                               image_dir=config.sat_img_dir,
                               pixel_tensor_dir=config.pixel_tensors_dir,
                               mode='train')
    val_dataset = OSMDataset(metadata_path=config.val_csv,
                             image_dir=config.sat_img_dir,
                             pixel_tensor_dir=config.pixel_tensors_dir,
                             mode='val')

    # from torch.utils.data import Subset
    # train_dataset = Subset(train_dataset, range(1000))
    # val_dataset = Subset(val_dataset, range(200))

    kwargs = {
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'num_samples': config.num_contrastive_samples,
        'ort_grad': config.ort_grad,
        'lr': config.lr,
        'taglist_vocab_path': config.taglist_vocab_path,
        'tag_vocab_path': config.tag_vocab_path,
        'text_backbone': config.text_backbone,
    }
    model = OSMBind(train_dataset, val_dataset, **kwargs)

    torch.cuda.empty_cache()

    # Optionally warm-start from an existing checkpoint.
    checkpoint_path = '/data/b.j.wei/rendersynth/osm_clip/checkpoints/osmclip_e5/osmclip_config_e5-epoch=39-val_loss=3.23.ckpt'
    if checkpoint_path:
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])

    checkpoint = ModelCheckpoint(
        monitor='val_loss',
        dirpath=config.save_dir,
        filename=config.filename,
        mode='min',
        save_top_k=1,
        every_n_epochs=1,
    )
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min',
    )

    logger = WandbLogger(project="osmclip", name=f"{config.experiment_name}")

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=config.devices,
        strategy='ddp',
        max_epochs=config.max_epochs,
        num_nodes=1,
        callbacks=[checkpoint, early_stop_callback],
        accumulate_grad_batches=config.accumulate_grad_batches,
        log_every_n_steps=5,
        logger=logger,  # wandb logger
    )
    trainer.fit(model)
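
    # After training, one could reload the best checkpoint and inspect per-pixel
    # similarity maps with the sketch helper above. Left commented out as a hedged
    # example; the query string and the image tensor are placeholders.
    # best_ckpt = checkpoint.best_model_path
    # model.load_state_dict(torch.load(best_ckpt, map_location='cpu')['state_dict'])
    # save_similarity_heatmap(model, sat_image=some_normalized_image, query="building")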