# SkySensepp / s2 / modeling_skysensepp.py
# Source: BiliSakura — "Update all files for SkySensepp" (commit 457788f, verified)
"""HuggingFace PreTrainedModel wrapper for the SkySense++ model."""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from .configuration_skysensepp import SkySensePPConfig
from .modality_vae import ModalityCompletionVAE
from lib.models.backbones import SwinTransformerV2MSL, VisionTransformerMSL
from lib.models.necks import TransformerEncoder
from lib.models.heads import UPerHead, UPHead
class SkySensePPModel(PreTrainedModel):
    """HuggingFace wrapper for the SkySense++ multi-modal segmentation model.

    The model fuses high-resolution optical imagery (HR), Sentinel-2 (S2) and
    Sentinel-1 SAR (S1) features through independent backbones, an optional
    modality-completion VAE, a shared transformer fusion encoder, and a UPer
    decode head.

    Args:
        config (:class:`SkySensePPConfig`): Model configuration.
    """

    config_class = SkySensePPConfig
    # The underlying backbones call tensor.item() during __init__, which is
    # incompatible with the meta-tensor fast-init path in transformers.
    _supports_param_buffer_assignment = False
    # Spatial downsampling factor between HR input and annotation mask grid.
    _BLOCK_SIZE = 32

    @classmethod
    def get_init_context(cls, dtype, is_quantized, _is_ds_init_called):
        """Override to avoid meta-device init (backbones use ``.item()``)."""
        import contextlib
        from transformers.modeling_utils import local_torch_dtype

        try:
            from transformers.modeling_utils.init import no_tie_weights
        except (ImportError, AttributeError):
            # Helper not present in this transformers version; use a no-op
            # context manager so the returned context list stays well-formed.
            no_tie_weights = contextlib.nullcontext
        return [local_torch_dtype(dtype, cls.__name__), no_tie_weights()]

    def __init__(self, config: SkySensePPConfig):
        super().__init__(config)
        # Which modalities this variant instantiates, e.g. ("hr", "s2", "s1").
        self.sources = config.sources

        # --- Backbone HR (SwinTransformerV2MSL) ---
        if "hr" in self.sources:
            self.backbone_hr = SwinTransformerV2MSL(
                arch=config.hr_arch,
                img_size=config.hr_img_size,
                patch_size=config.hr_patch_size,
                in_channels=config.hr_in_channels,
                vocabulary_size=config.vocabulary_size,
                window_size=config.hr_window_size,
                drop_path_rate=config.hr_drop_path_rate,
                out_indices=config.hr_out_indices,
                use_abs_pos_embed=config.hr_use_abs_pos_embed,
                with_cp=config.hr_with_cp,
                pad_small_map=config.hr_pad_small_map,
            )

        # --- Backbone S2 (VisionTransformerMSL) ---
        if "s2" in self.sources:
            self.backbone_s2 = VisionTransformerMSL(
                img_size=config.s2_img_size,
                patch_size=config.s2_patch_size,
                in_channels=config.s2_in_channels,
                embed_dims=config.s2_embed_dims,
                num_layers=config.s2_num_layers,
                num_heads=config.s2_num_heads,
                mlp_ratio=config.s2_mlp_ratio,
                out_indices=config.s2_out_indices,
                drop_path_rate=config.s2_drop_path_rate,
                vocabulary_size=config.vocabulary_size,
            )
            # Projects the last S2 backbone stage into the fusion dimension
            # (up_scale=1: no spatial upsampling).
            self.head_s2 = UPHead(
                in_dim=config.s2_embed_dims,
                out_dim=config.s2_embed_dims,
                up_scale=1,
            )

        # --- Fusion Encoder (always created) ---
        self.fusion = TransformerEncoder(
            input_dims=config.fusion_input_dims,
            embed_dims=config.fusion_embed_dims,
            num_layers=config.fusion_num_layers,
            num_heads=config.fusion_num_heads,
            with_cls_token=config.fusion_with_cls_token,
            output_cls_token=config.fusion_output_cls_token,
        )

        # --- Backbone S1 (VisionTransformerMSL) ---
        if "s1" in self.sources:
            self.backbone_s1 = VisionTransformerMSL(
                img_size=config.s1_img_size,
                patch_size=config.s1_patch_size,
                in_channels=config.s1_in_channels,
                embed_dims=config.s1_embed_dims,
                num_layers=config.s1_num_layers,
                num_heads=config.s1_num_heads,
                vocabulary_size=config.vocabulary_size,
            )
            self.head_s1 = UPHead(
                in_dim=config.s1_embed_dims,
                out_dim=config.s1_embed_dims,
                up_scale=1,
            )

        # --- Modality VAE (diffusers-style loadable component) ---
        # All three modalities share the fusion channel count on a 16x16
        # stage-3 grid.
        if config.use_modal_vae:
            self.modality_vae = ModalityCompletionVAE(
                input_shape_hr=(config.fusion_input_dims, 16, 16),
                input_shape_s2=(config.fusion_input_dims, 16, 16),
                input_shape_s1=(config.fusion_input_dims, 16, 16),
            )

        # --- Decode Head (UPerHead for HR reconstruction) ---
        # in_index covers 5 inputs: the HR backbone stages plus the fused
        # cls-token feature map appended in forward().
        self.head_rec_hr = UPerHead(
            in_channels=config.decode_in_channels,
            channels=config.decode_channels,
            num_classes=config.decode_num_classes,
            in_index=[0, 1, 2, 3, 4],
            align_corners=True,
        )
        self.post_init()

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        hr_img=None,
        s2_img=None,
        s1_img=None,
        anno_img=None,
        anno_mask=None,
        s2_ct=None,
        s2_ct2=None,
        modality_flag_hr=None,
        modality_flag_s2=None,
        modality_flag_s1=None,
        return_features=False,
    ):
        """Run multi-modal forward pass.

        Args:
            hr_img (Tensor, optional): High-resolution image ``(B, C, H, W)``.
            s2_img (Tensor, optional): Sentinel-2 image ``(B, C, S, H, W)``.
            s1_img (Tensor, optional): Sentinel-1 SAR image ``(B, C, S, H, W)``.
            anno_img (Tensor, optional): Annotation image ``(B, H, W)``.
            anno_mask (Tensor, optional): Annotation mask ``(B, H_m, W_m)``.
            s2_ct (Tensor, optional): Calendar-time index for S2 pass 1.
                Accepted for interface compatibility; not consumed here.
            s2_ct2 (Tensor, optional): Calendar-time index for S2 pass 2.
                Accepted for interface compatibility; not consumed here.
            modality_flag_hr (Tensor, optional): Per-sample HR availability flag.
            modality_flag_s2 (Tensor, optional): Per-sample S2 availability flag.
            modality_flag_s1 (Tensor, optional): Per-sample S1 availability flag.
            return_features (bool): If True, include backbone and fusion
                representations in output (for representation extraction).
                Default False.

        Returns:
            dict: ``logits_hr`` (when HR present), ``vae_out`` (when the VAE
            runs), and optionally ``features_hr``, ``features_s2``,
            ``features_s1``, ``features_fusion`` when ``return_features=True``.
            Empty dict when no usable input modality is supplied.

        Raises:
            RuntimeError: If HR decoding is requested while
                ``config.fusion_output_cls_token`` is False (the decode head
                needs the fused cls-token map as its fifth input).
        """
        output = {}

        # Determine batch size and a reference device from any available input.
        B = None
        for img in (hr_img, s2_img, s1_img):
            if img is not None:
                B = img.shape[0]
                _ref_device = img.device
                break
        if B is None:
            # Nothing to process at all.
            return output

        # Build modality flags ------------------------------------------
        def _default_flag(img):
            # All-True when the modality tensor was supplied, all-False
            # otherwise; callers may override per sample via the flag args.
            return torch.full(
                (B,), img is not None, dtype=torch.bool, device=_ref_device
            )

        if modality_flag_hr is None:
            modality_flag_hr = _default_flag(hr_img)
        if modality_flag_s2 is None:
            modality_flag_s2 = _default_flag(s2_img)
        if modality_flag_s1 is None:
            modality_flag_s1 = _default_flag(s1_img)
        # (B, 3) availability matrix ordered (hr, s2, s1) for the VAE.
        modalities = torch.stack(
            [modality_flag_hr, modality_flag_s2, modality_flag_s1], dim=-1
        )

        # 1. Backbone feature extraction --------------------------------
        hr_features = None
        s2_features = None
        s1_features = None
        if "hr" in self.sources and hr_img is not None:
            if anno_img is not None and anno_mask is not None:
                B_M, H_M, W_M = anno_mask.shape
                _, _, H_img, W_img = hr_img.shape
                # Derive block size from anno_mask and image dims, then
                # nearest-neighbor expand the mask to HR pixel resolution.
                block_h = H_img // H_M
                block_w = W_img // W_M
                anno_mask_hr = (
                    anno_mask.unsqueeze(-1)
                    .unsqueeze(-1)
                    .repeat(1, 1, 1, block_h, block_w)
                )
                anno_mask_hr = (
                    anno_mask_hr.permute(0, 1, 3, 2, 4)
                    .reshape(B_M, H_M * block_h, W_M * block_w)
                    .contiguous()
                )
                hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr)
            else:
                # Inference without annotation: pass dummy zeros so the
                # backbone signature is satisfied.
                B_hr, _, H, W = hr_img.shape
                patch_size = self.config.hr_patch_size
                dummy_anno = torch.zeros(
                    B_hr, H, W, dtype=torch.long, device=hr_img.device
                )
                dummy_mask = torch.zeros(
                    B_hr, H // patch_size, W // patch_size,
                    dtype=torch.bool, device=hr_img.device,
                )
                hr_features = self.backbone_hr(hr_img, dummy_anno, dummy_mask)

        if "s2" in self.sources and s2_img is not None:
            # Fold the temporal axis S into the batch for the ViT backbone.
            B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape
            s2_flat = s2_img.permute(0, 2, 1, 3, 4).reshape(
                B * S_S2, C_S2, H_S2, W_S2
            ).contiguous()
            if anno_img is not None and anno_mask is not None:
                # Subsample annotation to S2 spatial resolution (centered
                # strided pick, then crop to exactly H_S2 x W_S2).
                step = max(1, anno_img.shape[1] // H_S2)
                offset = step // 2
                anno_s2 = anno_img[:, offset::step, offset::step][:, :H_S2, :W_S2]
                s2_features = self.backbone_s2(s2_flat, anno_s2, anno_mask)
            else:
                patch_s2 = self.config.s2_patch_size
                h_p, w_p = H_S2 // patch_s2, W_S2 // patch_s2
                dummy_anno = torch.zeros(
                    B * S_S2, H_S2, W_S2, dtype=torch.long, device=s2_img.device
                )
                dummy_mask = torch.zeros(
                    B * S_S2, h_p, w_p, dtype=torch.bool, device=s2_img.device
                )
                s2_features = self.backbone_s2(s2_flat, dummy_anno, dummy_mask)
            # Keep only the last stage, projected by the S2 head.
            s2_features = [self.head_s2(s2_features[-1])]

        if "s1" in self.sources and s1_img is not None:
            B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape
            s1_flat = s1_img.permute(0, 2, 1, 3, 4).reshape(
                B * S_S1, C_S1, H_S1, W_S1
            ).contiguous()
            if anno_img is not None and anno_mask is not None:
                step = max(1, anno_img.shape[1] // H_S1)
                offset = step // 2
                anno_s1 = anno_img[:, offset::step, offset::step][:, :H_S1, :W_S1]
                s1_features = self.backbone_s1(s1_flat, anno_s1, anno_mask)
            else:
                patch_s1 = self.config.s1_patch_size
                h_p, w_p = H_S1 // patch_s1, W_S1 // patch_s1
                dummy_anno = torch.zeros(
                    B * S_S1, H_S1, W_S1, dtype=torch.long, device=s1_img.device
                )
                dummy_mask = torch.zeros(
                    B * S_S1, h_p, w_p, dtype=torch.bool, device=s1_img.device
                )
                s1_features = self.backbone_s1(s1_flat, dummy_anno, dummy_mask)
            s1_features = [self.head_s1(s1_features[-1])]

        # 2. Modality VAE -----------------------------------------------
        # The VAE completes missing modalities from the available ones; it
        # only runs when all three stage-3 maps exist.
        hr_stage3 = hr_features[-1] if hr_features is not None else None
        s2_stage3 = s2_features[-1] if s2_features is not None else None
        s1_stage3 = s1_features[-1] if s1_features is not None else None
        if (
            self.config.use_modal_vae
            and hr_stage3 is not None
            and s2_stage3 is not None
            and s1_stage3 is not None
        ):
            modalities_dev = modalities.to(hr_stage3.device)
            vae_out = self.modality_vae(
                hr_stage3, s2_stage3, s1_stage3, modalities_dev
            )
            hr_stage3 = vae_out["hr_out"]
            s2_stage3 = vae_out["s2_out"]
            s1_stage3 = vae_out["s1_out"]
            output["vae_out"] = vae_out

        # 3. Fusion ------------------------------------------------------
        # Collect per-modality tokens for fusion. Each spatial location
        # becomes a fusion "batch" item; the token axis concatenates 1 HR
        # token with S temporal tokens per satellite modality. Assumes all
        # modalities share the same stage-3 grid (H3, W3) — the VAE shapes
        # above suggest 16x16; confirm against the backbone configs.
        feature_parts = []
        H3, W3 = None, None
        if hr_stage3 is not None:
            B, C3, H3, W3 = hr_stage3.shape
            hr_tok = hr_stage3.permute(0, 2, 3, 1).reshape(
                B * H3 * W3, C3
            ).unsqueeze(1).contiguous()
            feature_parts.append(hr_tok)
        if s2_stage3 is not None:
            _, C3_S2, H3_S2, W3_S2 = s2_stage3.shape
            if H3 is None:
                H3, W3 = H3_S2, W3_S2
            s2_tok = (
                s2_stage3.reshape(B, S_S2, C3_S2, H3_S2, W3_S2)
                .permute(0, 3, 4, 1, 2)
                .reshape(B * H3_S2 * W3_S2, S_S2, C3_S2)
                .contiguous()
            )
            feature_parts.append(s2_tok)
        if s1_stage3 is not None:
            _, C3_S1, H3_S1, W3_S1 = s1_stage3.shape
            if H3 is None:
                H3, W3 = H3_S1, W3_S1
            s1_tok = (
                s1_stage3.reshape(B, S_S1, C3_S1, H3_S1, W3_S1)
                .permute(0, 3, 4, 1, 2)
                .reshape(B * H3_S1 * W3_S1, S_S1, C3_S1)
                .contiguous()
            )
            feature_parts.append(s1_tok)
        if not feature_parts:
            # Supplied images do not match any configured source; nothing to
            # fuse (previously this crashed inside torch.cat).
            return output
        features_stage3 = torch.cat(feature_parts, dim=1)

        cls_token = None
        if self.config.fusion_output_cls_token:
            cls_token = self.fusion(features_stage3)
            _, C3_cls = cls_token.shape
            # Fold the per-location cls tokens back into a (B, C, H3, W3) map.
            cls_token = (
                cls_token.reshape(B, H3, W3, C3_cls)
                .contiguous()
                .permute(0, 3, 1, 2)
                .contiguous()
            )
        else:
            features_stage3 = self.fusion(features_stage3)

        # Representation extraction outputs (when requested)
        if return_features:
            output["features_hr"] = hr_features
            output["features_s2"] = s2_features[0] if s2_features else None
            output["features_s1"] = s1_features[0] if s1_features else None
            output["features_fusion"] = (
                cls_token if self.config.fusion_output_cls_token else features_stage3
            )

        # 4. Decode -------------------------------------------------------
        if hr_features is not None:
            if cls_token is None:
                # The UPerHead consumes 5 inputs (HR stages + fused cls-token
                # map). Previously this path died with an opaque NameError.
                raise RuntimeError(
                    "HR decoding requires config.fusion_output_cls_token=True"
                )
            hr_rec_inputs = list(hr_features)
            feat_stage1 = hr_rec_inputs[0]
            if feat_stage1.shape[-1] == feat_stage1.shape[-2]:
                # Square stage-1 map: halve the width and stack the halves on
                # the channel axis — presumably undoing a backbone-side
                # horizontal concat; TODO confirm against SwinTransformerV2MSL.
                left, right = torch.split(
                    feat_stage1, feat_stage1.shape[-1] // 2, dim=-1
                )
                hr_rec_inputs[0] = torch.cat([left, right], dim=1)
            rec_feats = [*hr_rec_inputs, cls_token]
            logits_hr = self.head_rec_hr(rec_feats)
            # Decode in fp32 for numerically stable interpolation/softmax.
            logits_hr = logits_hr.to(torch.float32)
            logits_hr = F.interpolate(
                logits_hr, scale_factor=4, mode="bilinear", align_corners=True
            )
            output["logits_hr"] = logits_hr
        return output

    # ------------------------------------------------------------------
    # Loading (diffusers-style: VAE in subfolder modality_vae/)
    # ------------------------------------------------------------------
    def load_vae(
        self,
        pretrained_model_name_or_path=None,
        subfolder="modality_vae",
        **kwargs,
    ):
        """Load modality VAE from a pretrained repo (diffusers-style).

        Uses ModalityCompletionVAE.from_pretrained(..., subfolder="modality_vae").
        Layout: {path}/modality_vae/diffusion_pytorch_model.safetensors
        Fallback: {path}/modality_vae.safetensors (legacy)

        Args:
            pretrained_model_name_or_path: Model path. If None, uses config path.
            subfolder: Subfolder name, default "modality_vae".
            **kwargs: Passed to from_pretrained.

        Returns:
            self (for chaining).

        Raises:
            ValueError: When the model has no VAE, or no path can be derived.
        """
        if not getattr(self.config, "use_modal_vae", False) or not hasattr(
            self, "modality_vae"
        ):
            raise ValueError(
                "Model has no modality_vae (use_modal_vae=False or single-modality variant)"
            )
        path = pretrained_model_name_or_path
        if path is None and hasattr(self.config, "config_file") and self.config.config_file:
            # Fall back to the directory the config was loaded from.
            path = os.path.dirname(self.config.config_file)
        if path is None:
            raise ValueError(
                "pretrained_model_name_or_path required when config has no config_file"
            )
        loaded = ModalityCompletionVAE.from_pretrained(
            path, subfolder=subfolder, **kwargs
        )
        # strict=False tolerates legacy checkpoints with partial key overlap.
        self.modality_vae.load_state_dict(loaded.state_dict(), strict=False)
        return self

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Standard HF loading, then best-effort loading of the optional VAE."""
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        if getattr(model.config, "use_modal_vae", False) and hasattr(
            model, "modality_vae"
        ):
            subfolder = getattr(model.config, "vae_subfolder", "modality_vae")
            try:
                model.load_vae(
                    pretrained_model_name_or_path=pretrained_model_name_or_path,
                    subfolder=subfolder,
                )
            except OSError:
                # from_pretrained-style loaders raise EnvironmentError (an
                # OSError) for a missing subfolder/weights, not only
                # FileNotFoundError; the VAE is optional if absent.
                pass
        return model

    # ------------------------------------------------------------------
    # Checkpoint conversion helper
    # ------------------------------------------------------------------
    @classmethod
    def from_original_checkpoint(cls, config, checkpoint_path):
        """Load an original SkySensePP checkpoint into this HuggingFace model.

        Args:
            config (:class:`SkySensePPConfig`): Model configuration.
            checkpoint_path (str): Path to the original ``.pth`` checkpoint.

        Returns:
            :class:`SkySensePPModel`: Model with loaded weights.
        """
        model = cls(config)
        # SECURITY: torch.load unpickles arbitrary objects — only call this
        # on checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        state_dict = ckpt["model"] if "model" in ckpt else ckpt
        # Strip leading ``model.`` prefix that the antmmf framework adds.
        cleaned = {}
        for k, v in state_dict.items():
            new_key = k[len("model.") :] if k.startswith("model.") else k
            cleaned[new_key] = v
        missing, unexpected = model.load_state_dict(cleaned, strict=False)
        if missing:
            print(f"Missing keys ({len(missing)}): {missing[:10]}...")
        if unexpected:
            print(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}...")
        return model