| """HuggingFace PreTrainedModel wrapper for the SkySense++ model.""" |
|
|
| import os |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
|
|
| from .configuration_skysensepp import SkySensePPConfig |
| from .modality_vae import ModalityCompletionVAE |
| from lib.models.backbones import SwinTransformerV2MSL, VisionTransformerMSL |
| from lib.models.necks import TransformerEncoder |
| from lib.models.heads import UPerHead, UPHead |
|
|
|
|
class SkySensePPModel(PreTrainedModel):
    """HuggingFace wrapper for the SkySense++ multi-modal segmentation model.

    The model fuses high-resolution optical imagery (HR), Sentinel-2 (S2) and
    Sentinel-1 SAR (S1) features through independent backbones, an optional
    modality-completion VAE, a shared transformer fusion encoder, and a UPer
    decode head.

    Args:
        config (:class:`SkySensePPConfig`): Model configuration.
    """

    config_class = SkySensePPConfig

    # The backbones call .item() and do other eager work during init, which is
    # incompatible with the param/buffer-assignment fast path in newer
    # `transformers` loading code.
    _supports_param_buffer_assignment = False

    # NOTE(review): not referenced anywhere in this file — presumably consumed
    # by external callers; confirm before removing.
    _BLOCK_SIZE = 32

    @classmethod
    def get_init_context(cls, dtype, is_quantized, _is_ds_init_called):
        """Override to avoid meta-device init (backbones use .item()).

        Returns the list of context managers `transformers` enters around
        model construction. We keep the dtype context but drop meta-device
        placement, because eager ops in the backbones fail on meta tensors.
        """
        import contextlib
        # NOTE(review): assumes `local_torch_dtype` exists at this import path
        # for the installed transformers version — confirm against the pinned
        # version; only `no_tie_weights` has a fallback below.
        from transformers.modeling_utils import local_torch_dtype
        try:
            from transformers.modeling_utils.init import no_tie_weights
        except (ImportError, AttributeError):
            # Older transformers: no such helper; a no-op context is fine.
            no_tie_weights = contextlib.nullcontext
        return [local_torch_dtype(dtype, cls.__name__), no_tie_weights()]

    def __init__(self, config: SkySensePPConfig):
        super().__init__(config)
        self.sources = config.sources

        # --- Per-modality backbones (only built for configured sources) ---
        if "hr" in self.sources:
            self.backbone_hr = SwinTransformerV2MSL(
                arch=config.hr_arch,
                img_size=config.hr_img_size,
                patch_size=config.hr_patch_size,
                in_channels=config.hr_in_channels,
                vocabulary_size=config.vocabulary_size,
                window_size=config.hr_window_size,
                drop_path_rate=config.hr_drop_path_rate,
                out_indices=config.hr_out_indices,
                use_abs_pos_embed=config.hr_use_abs_pos_embed,
                with_cp=config.hr_with_cp,
                pad_small_map=config.hr_pad_small_map,
            )

        if "s2" in self.sources:
            self.backbone_s2 = VisionTransformerMSL(
                img_size=config.s2_img_size,
                patch_size=config.s2_patch_size,
                in_channels=config.s2_in_channels,
                embed_dims=config.s2_embed_dims,
                num_layers=config.s2_num_layers,
                num_heads=config.s2_num_heads,
                mlp_ratio=config.s2_mlp_ratio,
                out_indices=config.s2_out_indices,
                drop_path_rate=config.s2_drop_path_rate,
                vocabulary_size=config.vocabulary_size,
            )
            # Projects S2 backbone output into the fusion feature space.
            self.head_s2 = UPHead(
                in_dim=config.s2_embed_dims,
                out_dim=config.s2_embed_dims,
                up_scale=1,
            )

        # Shared transformer that fuses per-pixel tokens across modalities.
        self.fusion = TransformerEncoder(
            input_dims=config.fusion_input_dims,
            embed_dims=config.fusion_embed_dims,
            num_layers=config.fusion_num_layers,
            num_heads=config.fusion_num_heads,
            with_cls_token=config.fusion_with_cls_token,
            output_cls_token=config.fusion_output_cls_token,
        )

        if "s1" in self.sources:
            self.backbone_s1 = VisionTransformerMSL(
                img_size=config.s1_img_size,
                patch_size=config.s1_patch_size,
                in_channels=config.s1_in_channels,
                embed_dims=config.s1_embed_dims,
                num_layers=config.s1_num_layers,
                num_heads=config.s1_num_heads,
                vocabulary_size=config.vocabulary_size,
            )
            self.head_s1 = UPHead(
                in_dim=config.s1_embed_dims,
                out_dim=config.s1_embed_dims,
                up_scale=1,
            )

        # Optional VAE that imputes features for missing modalities.
        # NOTE(review): 16x16 spatial shape is hard-coded — presumably the
        # stage-3 feature resolution for the default image sizes; confirm.
        if config.use_modal_vae:
            self.modality_vae = ModalityCompletionVAE(
                input_shape_hr=(config.fusion_input_dims, 16, 16),
                input_shape_s2=(config.fusion_input_dims, 16, 16),
                input_shape_s1=(config.fusion_input_dims, 16, 16),
            )

        # UPer decode head over the 4 HR pyramid levels + 1 fused level.
        self.head_rec_hr = UPerHead(
            in_channels=config.decode_in_channels,
            channels=config.decode_channels,
            num_classes=config.decode_num_classes,
            in_index=[0, 1, 2, 3, 4],
            align_corners=True,
        )

        self.post_init()

    def forward(
        self,
        hr_img=None,
        s2_img=None,
        s1_img=None,
        anno_img=None,
        anno_mask=None,
        s2_ct=None,
        s2_ct2=None,
        modality_flag_hr=None,
        modality_flag_s2=None,
        modality_flag_s1=None,
        return_features=False,
    ):
        """Run multi-modal forward pass.

        Args:
            hr_img (Tensor, optional): High-resolution image ``(B, C, H, W)``.
            s2_img (Tensor, optional): Sentinel-2 image ``(B, C, S, H, W)``.
            s1_img (Tensor, optional): Sentinel-1 SAR image ``(B, C, S, H, W)``.
            anno_img (Tensor, optional): Annotation image ``(B, H, W)``.
            anno_mask (Tensor, optional): Annotation mask ``(B, H_m, W_m)``.
            s2_ct (Tensor, optional): Calendar-time index for S2 pass 1.
            s2_ct2 (Tensor, optional): Calendar-time index for S2 pass 2.
            modality_flag_hr (Tensor, optional): Per-sample HR availability flag.
            modality_flag_s2 (Tensor, optional): Per-sample S2 availability flag.
            modality_flag_s1 (Tensor, optional): Per-sample S1 availability flag.
            return_features (bool): If True, include backbone and fusion
                representations in output (for representation extraction).
                Default False.

        Returns:
            dict: ``logits_hr`` (when HR present), and optionally
            ``features_hr``, ``features_s2``, ``features_s1``, ``features_fusion``
            when return_features=True.

        Raises:
            RuntimeError: If HR decoding is requested while
                ``config.fusion_output_cls_token`` is False (the decode head
                needs the fused per-pixel token).
        """
        output = {}

        # Derive batch size / device from the first available modality.
        B = None
        for img in (hr_img, s2_img, s1_img):
            if img is not None:
                B = img.shape[0]
                _ref_device = img.device
                break
        if B is None:
            return output

        # Default availability flags: "all available" iff the tensor was given.
        if modality_flag_hr is None:
            modality_flag_hr = torch.ones(B, dtype=torch.bool, device=_ref_device) if hr_img is not None else torch.zeros(B, dtype=torch.bool, device=_ref_device)
        if modality_flag_s2 is None:
            modality_flag_s2 = torch.ones(B, dtype=torch.bool, device=_ref_device) if s2_img is not None else torch.zeros(B, dtype=torch.bool, device=_ref_device)
        if modality_flag_s1 is None:
            modality_flag_s1 = torch.ones(B, dtype=torch.bool, device=_ref_device) if s1_img is not None else torch.zeros(B, dtype=torch.bool, device=_ref_device)

        # (B, 3) availability matrix in fixed order: HR, S2, S1.
        modalities = torch.stack(
            [modality_flag_hr, modality_flag_s2, modality_flag_s1], dim=-1
        )

        hr_features = None
        s2_features = None
        s1_features = None

        # --- HR branch ---
        if "hr" in self.sources and hr_img is not None:
            if anno_img is not None and anno_mask is not None:
                # Upsample the annotation mask block-wise to HR patch grid
                # resolution (each mask cell covers a block_h x block_w tile).
                B_M, H_M, W_M = anno_mask.shape
                _, _, H_img, W_img = hr_img.shape
                block_h = H_img // H_M
                block_w = W_img // W_M
                anno_mask_hr = (
                    anno_mask.unsqueeze(-1)
                    .unsqueeze(-1)
                    .repeat(1, 1, 1, block_h, block_w)
                )
                anno_mask_hr = (
                    anno_mask_hr.permute(0, 1, 3, 2, 4)
                    .reshape(B_M, H_M * block_h, W_M * block_w)
                    .contiguous()
                )
                hr_features = self.backbone_hr(hr_img, anno_img, anno_mask_hr)
            else:
                # Inference without annotations: feed all-zero anno/mask so
                # the backbone's MSL pathway is a no-op.
                B_hr, _, H, W = hr_img.shape
                patch_size = self.config.hr_patch_size
                dummy_anno = torch.zeros(B_hr, H, W, dtype=torch.long, device=hr_img.device)
                dummy_mask = torch.zeros(B_hr, H // patch_size, W // patch_size, dtype=torch.bool, device=hr_img.device)
                hr_features = self.backbone_hr(hr_img, dummy_anno, dummy_mask)

        # --- S2 branch (time dimension folded into the batch) ---
        if "s2" in self.sources and s2_img is not None:
            B, C_S2, S_S2, H_S2, W_S2 = s2_img.shape
            s2_flat = s2_img.permute(0, 2, 1, 3, 4).reshape(
                B * S_S2, C_S2, H_S2, W_S2
            ).contiguous()
            if anno_img is not None and anno_mask is not None:
                # Center-sample the HR-resolution annotation down to S2 size.
                # NOTE(review): anno_s2/anno_mask have batch B while s2_flat
                # has batch B*S_S2 — presumably the backbone broadcasts over
                # the time dimension; confirm against VisionTransformerMSL.
                step = max(1, anno_img.shape[1] // H_S2)
                offset = step // 2
                anno_s2 = anno_img[:, offset::step, offset::step][:, :H_S2, :W_S2]
                s2_features = self.backbone_s2(s2_flat, anno_s2, anno_mask)
            else:
                patch_s2 = self.config.s2_patch_size
                h_p, w_p = H_S2 // patch_s2, W_S2 // patch_s2
                dummy_anno = torch.zeros(B * S_S2, H_S2, W_S2, dtype=torch.long, device=s2_img.device)
                dummy_mask = torch.zeros(B * S_S2, h_p, w_p, dtype=torch.bool, device=s2_img.device)
                s2_features = self.backbone_s2(s2_flat, dummy_anno, dummy_mask)
            # Keep only the last stage, projected into fusion space.
            s2_features = [self.head_s2(s2_features[-1])]

        # --- S1 branch (mirrors the S2 branch) ---
        if "s1" in self.sources and s1_img is not None:
            B, C_S1, S_S1, H_S1, W_S1 = s1_img.shape
            s1_flat = s1_img.permute(0, 2, 1, 3, 4).reshape(
                B * S_S1, C_S1, H_S1, W_S1
            ).contiguous()
            if anno_img is not None and anno_mask is not None:
                step = max(1, anno_img.shape[1] // H_S1)
                offset = step // 2
                anno_s1 = anno_img[:, offset::step, offset::step][:, :H_S1, :W_S1]
                s1_features = self.backbone_s1(s1_flat, anno_s1, anno_mask)
            else:
                patch_s1 = self.config.s1_patch_size
                h_p, w_p = H_S1 // patch_s1, W_S1 // patch_s1
                dummy_anno = torch.zeros(B * S_S1, H_S1, W_S1, dtype=torch.long, device=s1_img.device)
                dummy_mask = torch.zeros(B * S_S1, h_p, w_p, dtype=torch.bool, device=s1_img.device)
                s1_features = self.backbone_s1(s1_flat, dummy_anno, dummy_mask)
            s1_features = [self.head_s1(s1_features[-1])]

        # Deepest (stage-3) feature map per modality, for fusion.
        hr_stage3 = hr_features[-1] if hr_features is not None else None
        s2_stage3 = s2_features[-1] if s2_features is not None else None
        s1_stage3 = s1_features[-1] if s1_features is not None else None

        # Modality-completion VAE: imputes features for samples whose
        # availability flag is False. Requires all three maps present.
        if (
            self.config.use_modal_vae
            and hr_stage3 is not None
            and s2_stage3 is not None
            and s1_stage3 is not None
        ):
            modalities_dev = modalities.to(hr_stage3.device)
            vae_out = self.modality_vae(
                hr_stage3, s2_stage3, s1_stage3, modalities_dev
            )
            hr_stage3 = vae_out["hr_out"]
            s2_stage3 = vae_out["s2_out"]
            s1_stage3 = vae_out["s1_out"]
            output["vae_out"] = vae_out

        # --- Tokenize each stage-3 map to (B*H3*W3, seq, C) for fusion ---
        feature_parts = []
        H3, W3 = None, None
        cls_token = None  # fused per-pixel representation (set below)

        if hr_stage3 is not None:
            B, C3, H3, W3 = hr_stage3.shape
            # HR contributes one token per pixel.
            hr_tok = hr_stage3.permute(0, 2, 3, 1).reshape(
                B * H3 * W3, C3
            ).unsqueeze(1).contiguous()
            feature_parts.append(hr_tok)

        if s2_stage3 is not None:
            _, C3_S2, H3_S2, W3_S2 = s2_stage3.shape
            if H3 is None:
                H3, W3 = H3_S2, W3_S2
            # S2 contributes one token per pixel per timestep.
            s2_tok = (
                s2_stage3.reshape(B, S_S2, C3_S2, H3_S2, W3_S2)
                .permute(0, 3, 4, 1, 2)
                .reshape(B * H3_S2 * W3_S2, S_S2, C3_S2)
                .contiguous()
            )
            feature_parts.append(s2_tok)

        if s1_stage3 is not None:
            _, C3_S1, H3_S1, W3_S1 = s1_stage3.shape
            if H3 is None:
                H3, W3 = H3_S1, W3_S1
            s1_tok = (
                s1_stage3.reshape(B, S_S1, C3_S1, H3_S1, W3_S1)
                .permute(0, 3, 4, 1, 2)
                .reshape(B * H3_S1 * W3_S1, S_S1, C3_S1)
                .contiguous()
            )
            feature_parts.append(s1_tok)

        # Bug fix: previously torch.cat([]) raised an opaque RuntimeError when
        # no modality produced features (e.g. sources/config mismatch);
        # return the (possibly empty) output instead, matching the early
        # return used when no image is given at all.
        if not feature_parts:
            return output

        features_stage3 = torch.cat(feature_parts, dim=1)

        if self.config.fusion_output_cls_token:
            # Fusion returns a single CLS token per spatial position; fold the
            # token sequence back into a (B, C, H3, W3) map.
            cls_token = self.fusion(features_stage3)
            _, C3_cls = cls_token.shape
            cls_token = (
                cls_token.reshape(B, H3, W3, C3_cls)
                .contiguous()
                .permute(0, 3, 1, 2)
                .contiguous()
            )
        else:
            features_stage3 = self.fusion(features_stage3)

        if return_features:
            output["features_hr"] = hr_features
            output["features_s2"] = s2_features[0] if s2_features else None
            output["features_s1"] = s1_features[0] if s1_features else None
            output["features_fusion"] = cls_token if self.config.fusion_output_cls_token else features_stage3

        # --- HR decode head: 4 HR pyramid levels + fused map ---
        if hr_features is not None:
            # Bug fix: `cls_token` used to be referenced unconditionally here,
            # raising a bare NameError when fusion_output_cls_token=False.
            if cls_token is None:
                raise RuntimeError(
                    "HR decoding requires config.fusion_output_cls_token=True "
                    "(the decode head consumes the fused per-pixel token)."
                )
            hr_rec_inputs = list(hr_features)
            feat_stage1 = hr_rec_inputs[0]
            # NOTE(review): square stage-1 maps are split along width and
            # re-stacked on channels — presumably to match the decode head's
            # expected channel layout; confirm against UPerHead config.
            if feat_stage1.shape[-1] == feat_stage1.shape[-2]:
                left, right = torch.split(
                    feat_stage1, feat_stage1.shape[-1] // 2, dim=-1
                )
                hr_rec_inputs[0] = torch.cat([left, right], dim=1)

            rec_feats = [*hr_rec_inputs, cls_token]
            logits_hr = self.head_rec_hr(rec_feats)
            # Decode in float32 for numerically stable interpolation/softmax.
            logits_hr = logits_hr.to(torch.float32)
            logits_hr = F.interpolate(
                logits_hr, scale_factor=4, mode="bilinear", align_corners=True
            )
            output["logits_hr"] = logits_hr

        return output

    def load_vae(
        self,
        pretrained_model_name_or_path=None,
        subfolder="modality_vae",
        **kwargs,
    ):
        """Load modality VAE from a pretrained repo (diffusers-style).

        Uses ModalityCompletionVAE.from_pretrained(..., subfolder="modality_vae").
        Layout: {path}/modality_vae/diffusion_pytorch_model.safetensors
        Fallback: {path}/modality_vae.safetensors (legacy)

        Args:
            pretrained_model_name_or_path: Model path. If None, uses config path.
            subfolder: Subfolder name, default "modality_vae".
            **kwargs: Passed to from_pretrained.

        Returns:
            self (for chaining).

        Raises:
            ValueError: If this model was built without a modality VAE, or no
                path can be derived.
        """
        if not getattr(self.config, "use_modal_vae", False) or not hasattr(
            self, "modality_vae"
        ):
            raise ValueError(
                "Model has no modality_vae (use_modal_vae=False or single-modality variant)"
            )

        # Fall back to the directory the config itself was loaded from.
        path = pretrained_model_name_or_path
        if path is None and hasattr(self.config, "config_file") and self.config.config_file:
            path = os.path.dirname(self.config.config_file)
        if path is None:
            raise ValueError(
                "pretrained_model_name_or_path required when config has no config_file"
            )

        loaded = ModalityCompletionVAE.from_pretrained(
            path, subfolder=subfolder, **kwargs
        )
        # strict=False: tolerate minor key drift between repo and local VAE.
        self.modality_vae.load_state_dict(loaded.state_dict(), strict=False)
        return self

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Standard HF loading, then best-effort load of the modality VAE."""
        model = super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        if getattr(model.config, "use_modal_vae", False) and hasattr(
            model, "modality_vae"
        ):
            subfolder = getattr(model.config, "vae_subfolder", "modality_vae")
            try:
                model.load_vae(
                    pretrained_model_name_or_path=pretrained_model_name_or_path,
                    subfolder=subfolder,
                )
            except OSError:
                # Bug fix: HF/diffusers raise OSError (EnvironmentError) for
                # missing repo files, which the previous FileNotFoundError
                # clause did not catch (FileNotFoundError is a subclass, so
                # this remains backward-compatible). VAE weights are optional.
                pass
        return model

    @classmethod
    def from_original_checkpoint(cls, config, checkpoint_path):
        """Load an original SkySensePP checkpoint into this HuggingFace model.

        Args:
            config (:class:`SkySensePPConfig`): Model configuration.
            checkpoint_path (str): Path to the original ``.pth`` checkpoint.

        Returns:
            :class:`SkySensePPModel`: Model with loaded weights.
        """
        model = cls(config)
        # weights_only=False keeps legacy pickle loading working on
        # PyTorch >= 2.6, where the default flipped to True. Original
        # checkpoints are trusted local files here.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        state_dict = ckpt["model"] if "model" in ckpt else ckpt

        # Strip an optional leading "model." prefix from checkpoint keys.
        cleaned = {}
        for k, v in state_dict.items():
            new_key = k[len("model.") :] if k.startswith("model.") else k
            cleaned[new_key] = v

        missing, unexpected = model.load_state_dict(cleaned, strict=False)
        if missing:
            print(f"Missing keys ({len(missing)}): {missing[:10]}...")
        if unexpected:
            print(f"Unexpected keys ({len(unexpected)}): {unexpected[:10]}...")
        return model
|