# SkySensepp / pipeline_skysensepp.py
# Commit 6c2a548 (verified) by BiliSakura: "Update all files for SkySensepp"
"""HuggingFace Pipeline for SkySense++ representation extraction and segmentation.
Diffusers-style API: SkySensePPPipeline.from_pretrained(model_id) loads model + VAE.
"""
import os
import numpy as np
import torch
from transformers import AutoModel, Pipeline
class SkySensePPPipeline(Pipeline):
    """Pipeline for representation extraction and optional segmentation (diffusers-style).

    **Primary use: representation extraction.** Extract backbone and fusion features
    for downstream tasks. The segmentation output is optional (head not pretrained).

    Diffusers-style loading::

        pipe = SkySensePPPipeline.from_pretrained("path/to/SkySensepp")
        result = pipe(hr_img=hr_array, extract=True)

    Or via transformers pipeline()::

        pipe = pipeline(..., model="path/to/SkySensepp", pipeline_class=SkySensePPPipeline, trust_remote_code=True)
        result = pipe({"hr_img": hr_array}, extract=True)

    Call-time options (may also be given once at construction time):
        extract: ``True`` (default) returns the ``features_*`` tensors;
            ``False`` returns ``{"segmentation_map": ...}`` from ``logits_hr``.
        sources: Optional iterable (or a single string) of modality names,
            e.g. ``["hr", "s2"]`` or ``["hr_img"]``, restricting which images
            are forwarded. Defaults to every recognized image in the input.
    """

    # NOTE(review): diffusers-style attribute (offload order); nothing in this
    # class reads it — kept for API parity.
    model_cpu_offload_seq = "model"

    # Image keys recognized in the input dict and forwarded to the model as
    # keyword arguments, in this order.
    _IMAGE_KEYS = ("hr_img", "s2_img", "s1_img")
    # Model output keys copied into the result when extracting representations.
    _FEATURE_KEYS = ("features_hr", "features_s2", "features_s1", "features_fusion")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the pipeline from a pretrained checkpoint (diffusers-style).

        Args:
            pretrained_model_name_or_path: Hub repo id or local directory.
            **kwargs: Forwarded to ``AutoModel.from_pretrained``.
                ``trust_remote_code`` defaults to ``True`` (needed for the
                checkpoint's custom model class) but may be overridden.

        Returns:
            A ``SkySensePPPipeline`` wrapping the loaded model. Any VAE in a
            ``modality_vae/`` subfolder is expected to be loaded by the model's
            own ``from_pretrained`` code; this method does not touch it.
        """
        # setdefault rather than a literal keyword: an explicit
        # trust_remote_code in kwargs would otherwise raise
        # "got multiple values for keyword argument".
        kwargs.setdefault("trust_remote_code", True)
        model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
        return cls(model=model)

    def __call__(self, inputs=None, *args, **kwargs):
        """Run the pipeline on one input dict (or a list of dicts).

        Images may be given as a dict in ``inputs`` (transformers style) and/or
        as ``hr_img=`` / ``s2_img=`` / ``s1_img=`` keyword arguments (diffusers
        style). Keyword images override same-named entries of a dict input.
        All other keyword arguments (``extract``, ``sources``, ``batch_size``,
        ...) are passed through to ``Pipeline.__call__``.

        Raises:
            ValueError: if image keywords are combined with a non-dict ``inputs``.
        """
        # Pop image keywords so they are not mistaken for pipeline options.
        image_kwargs = {key: kwargs.pop(key) for key in self._IMAGE_KEYS if key in kwargs}
        if image_kwargs:
            if inputs is None:
                inputs = image_kwargs
            elif isinstance(inputs, dict):
                inputs = {**inputs, **image_kwargs}
            else:
                raise ValueError(
                    "Image keyword arguments (hr_img/s2_img/s1_img) can only be "
                    f"combined with a dict input, got {type(inputs)}"
                )
        return super().__call__(inputs, *args, **kwargs)

    def _sanitize_parameters(self, extract=None, sources=None, **kwargs):
        """Split call/constructor options into per-stage keyword dicts.

        Only options that were explicitly given are returned. The base
        ``Pipeline`` merges these over the options stored at construction
        time, so returning defaults here would silently undo e.g. an
        ``extract=False`` set when the pipeline was built. Defaults live in
        the stage method signatures instead.

        Args:
            extract: Feeds both ``_forward(return_features=...)`` and
                ``postprocess(extract=...)`` so the two stages stay in sync.
            sources: Forwarded to ``preprocess``.
            **kwargs: Unrecognized options; ignored.

        Returns:
            ``(preprocess_kwargs, forward_kwargs, postprocess_kwargs)``.
        """
        preprocess_kwargs = {}
        forward_kwargs = {}
        postprocess_kwargs = {}
        if sources is not None:
            preprocess_kwargs["sources"] = sources
        if extract is not None:
            forward_kwargs["return_features"] = extract
            postprocess_kwargs["extract"] = extract
        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    # ------------------------------------------------------------------
    # Pre-process
    # ------------------------------------------------------------------
    @staticmethod
    def _to_batched_tensor(value):
        """Convert one image (ndarray or tensor) to a float tensor, batching 3-D input."""
        if isinstance(value, np.ndarray):
            # Cast in NumPy first: torch.from_numpy rejects negative-stride
            # arrays (e.g. after np.flip) and, on older torch versions, uint16
            # arrays, which are common for satellite imagery.
            value = torch.from_numpy(np.ascontiguousarray(value, dtype=np.float32))
        elif isinstance(value, torch.Tensor) and not value.is_floating_point():
            # Match the ndarray path: integer/bool tensors become float32.
            value = value.float()
        if value.dim() == 3:
            # Presumably (C, H, W) -> (1, C, H, W); TODO confirm per modality.
            value = value.unsqueeze(0)
        return value

    def preprocess(self, inputs, sources=None):
        """Convert raw inputs into model-ready tensors.

        Args:
            inputs: A dict with optional keys ``hr_img``, ``s2_img``,
                ``s1_img`` (numpy arrays or torch tensors). Other keys are
                ignored.
            sources: Optional iterable of modality names (``"hr"`` or
                ``"hr_img"`` spelling) restricting which images are forwarded.
                A single string is also accepted. ``None`` or empty means all
                recognized images present in ``inputs``.

        Returns:
            dict mapping image keys to floating-point tensors; 3-D inputs get a
            leading batch axis. Device placement is not done here (presumably
            handled by the base ``Pipeline.forward``).

        Raises:
            ValueError: if ``inputs`` is not a dict, or no recognized image
                remains after applying ``sources``.
        """
        if not isinstance(inputs, dict):
            raise ValueError(
                "SkySensePPPipeline expects a dict with image tensors, "
                f"got {type(inputs)}"
            )
        if isinstance(sources, str):
            # Avoid iterating a single modality name character by character.
            sources = [sources]
        active = sources or list(inputs)
        # Normalize "hr_img" / "hr" spellings to the bare modality name.
        active_modalities = {name.replace("_img", "") for name in active}
        model_inputs = {
            key: self._to_batched_tensor(inputs[key])
            for key in self._IMAGE_KEYS
            if key in inputs and key.replace("_img", "") in active_modalities
        }
        if not model_inputs:
            raise ValueError(
                f"No image to forward: expected at least one of {self._IMAGE_KEYS} "
                f"in inputs (got keys {list(inputs)}), with sources={sources!r}"
            )
        return model_inputs

    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def _forward(self, model_inputs, return_features=True):
        """Run the model on the preprocessed image tensors.

        Args:
            model_inputs: Output of ``preprocess``, splatted as keyword args.
            return_features: Passed straight to the model; bound to the
                ``extract`` option by ``_sanitize_parameters``.

        Returns:
            Whatever the model's forward returns (read as a mapping by
            ``postprocess``).
        """
        # no_grad also covers direct calls to _forward made outside
        # Pipeline.__call__, so no autograd graph is built.
        with torch.no_grad():
            return self.model(**model_inputs, return_features=return_features)

    # ------------------------------------------------------------------
    # Post-process
    # ------------------------------------------------------------------
    def postprocess(self, model_outputs, extract=True):
        """Return representations or a segmentation map.

        Args:
            model_outputs: Mapping returned by ``_forward``.
            extract: If true, return the available ``features_*`` entries
                (tensors moved to CPU; missing or ``None`` entries omitted).
                Otherwise return a segmentation map.

        Returns:
            Either a dict of features, or ``{"segmentation_map": ndarray | None}``.
            The map is ``logits_hr.argmax(dim=1)``, which assumes class scores
            on axis 1, e.g. (B, C, H, W) (TODO confirm). It is ``None`` when the
            model produced no ``logits_hr``.
        """
        if extract:
            features = {}
            for key in self._FEATURE_KEYS:
                value = model_outputs.get(key)
                if value is not None:
                    features[key] = value.cpu() if isinstance(value, torch.Tensor) else value
            return features
        # Segmentation (head not pretrained)
        logits = model_outputs.get("logits_hr")
        if logits is None:
            return {"segmentation_map": None}
        return {"segmentation_map": logits.argmax(dim=1).cpu().numpy()}