BiliSakura commited on
Commit
814c03f
·
verified ·
1 Parent(s): 61b666e

Update all files for SkySensepp

Browse files
Files changed (1) hide show
  1. s2/pipeline_skysensepp.py +100 -0
s2/pipeline_skysensepp.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HuggingFace Pipeline for SkySense++ (diffusers-style)."""
2
+
3
+ import numpy as np
4
+ import torch
5
+ from transformers import AutoModel, Pipeline
6
+
7
+
8
+ class SkySensePPPipeline(Pipeline):
9
+ """Pipeline for representation extraction (primary) and optional segmentation.
10
+
11
+ **Diffusers-style loading:** ::
12
+ pipe = SkySensePPPipeline.from_pretrained("path/to/SkySensepp")
13
+ result = pipe({"hr_img": hr_array}, extract=True)
14
+
15
+ **Primary use: representation extraction.** Get backbone and fusion features
16
+ for downstream tasks. Segmentation output (``extract=False``) requires fine-tuned head.
17
+ """
18
+
19
+ model_cpu_offload_seq = "model"
20
+
21
+ @classmethod
22
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
23
+ """Load pipeline (diffusers-style). Loads model + VAE from modality_vae/ subfolder."""
24
+ model = AutoModel.from_pretrained(
25
+ pretrained_model_name_or_path,
26
+ trust_remote_code=True,
27
+ **kwargs,
28
+ )
29
+ return cls(model=model)
30
+
31
+ def _sanitize_parameters(self, extract=None, **kwargs):
32
+ preprocess_kwargs = {}
33
+ forward_kwargs = {"return_features": extract if extract is not None else True}
34
+ postprocess_kwargs = {"extract": extract if extract is not None else True}
35
+ if "sources" in kwargs:
36
+ preprocess_kwargs["sources"] = kwargs["sources"]
37
+ return preprocess_kwargs, forward_kwargs, postprocess_kwargs
38
+
39
+ # ------------------------------------------------------------------
40
+ # Pre-process
41
+ # ------------------------------------------------------------------
42
+
43
+ def preprocess(self, inputs, sources=None):
44
+ """Convert raw inputs into model-ready tensors.
45
+
46
+ Args:
47
+ inputs: A dict with optional keys ``hr_img``, ``s2_img``,
48
+ ``s1_img`` (numpy arrays or tensors).
49
+ sources: Optional list restricting which modalities to forward.
50
+
51
+ Returns:
52
+ dict of tensors placed on the model device.
53
+ """
54
+ if not isinstance(inputs, dict):
55
+ raise ValueError(
56
+ "SkySensePPPipeline expects a dict with image tensors, "
57
+ f"got {type(inputs)}"
58
+ )
59
+
60
+ active = sources or list(inputs.keys())
61
+ active_modalities = {s.replace("_img", "") for s in active}
62
+ model_inputs = {}
63
+ for key in ("hr_img", "s2_img", "s1_img"):
64
+ if key in inputs and key.replace("_img", "") in active_modalities:
65
+ tensor = inputs[key]
66
+ if isinstance(tensor, np.ndarray):
67
+ tensor = torch.from_numpy(tensor).float()
68
+ if tensor.dim() == 3:
69
+ tensor = tensor.unsqueeze(0)
70
+ model_inputs[key] = tensor
71
+
72
+ return model_inputs
73
+
74
+ # ------------------------------------------------------------------
75
+ # Forward
76
+ # ------------------------------------------------------------------
77
+
78
+ def _forward(self, model_inputs, return_features=True):
79
+ """Run the model forward pass."""
80
+ with torch.no_grad():
81
+ outputs = self.model(**model_inputs, return_features=return_features)
82
+ return outputs
83
+
84
+ # ------------------------------------------------------------------
85
+ # Post-process
86
+ # ------------------------------------------------------------------
87
+
88
+ def postprocess(self, model_outputs, extract=True):
89
+ """Return representations (extract=True) or segmentation map (extract=False)."""
90
+ if extract:
91
+ out = {}
92
+ for k in ("features_hr", "features_s2", "features_s1", "features_fusion"):
93
+ if k in model_outputs and model_outputs[k] is not None:
94
+ v = model_outputs[k]
95
+ out[k] = v.cpu() if isinstance(v, torch.Tensor) else v
96
+ return out
97
+ logits = model_outputs.get("logits_hr")
98
+ if logits is None:
99
+ return {"segmentation_map": None}
100
+ return {"segmentation_map": logits.argmax(dim=1).cpu().numpy()}