shenxiaochen committed · Commit 8360541 · verified · 1 Parent(s): b295d47

Add files using upload-large-folder tool

README.md ADDED
@@ -0,0 +1,156 @@
+ ---
+ library_name: transformers
+ pipeline_tag: feature-extraction
+ base_model: google/medsiglip-448
+ tags:
+ - medical-imaging
+ - mri
+ - brain-mri
+ - siglip
+ - vision-language
+ - contrastive-learning
+ - feature-extraction
+ - custom-code
+ - pytorch
+ ---
+
+ # Brain MRI SigLIP
+
+ Brain MRI SigLIP is a 3D MRI vision-language representation model trained with a SigLIP-style image-text contrastive objective. This repository publishes the final saved `stage2_joint_finetune` checkpoint from the `brain_mri_siglip_run_0509` experiment.
+
+ This checkpoint is intended as a research visual encoder for brain MRI downstream tasks and as a warm-start encoder for building a medical VLM. It is not a clinical diagnostic device.
+
+ ## Model Summary
+
+ - Base text tower: `google/medsiglip-448`
+ - Model class: `BrainMRISiglipModel`
+ - Vision input: single-channel 3D MRI volumes
+ - Expected volume shape: `[1, 128, 192, 192]`
+ - Projection dimension: `1152`
+ - Patch size: `[8, 16, 16]`
+ - Training precision: `bf16`
+ - Training input format: preprocessed `.pt` tensors, `float16`, value range `[-1, 1]`
+
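To make the shape figures in the summary concrete, the sketch below works out how many patch tokens one volume produces when a `[128, 192, 192]` volume is cut into non-overlapping `[8, 16, 16]` patches. It is plain arithmetic on the listed values; the helper name `patch_grid` is hypothetical and not part of the repository.

```python
from math import prod


def patch_grid(volume_size, patch_size):
    """Return the per-axis patch counts for non-overlapping patches."""
    if any(v % p != 0 for v, p in zip(volume_size, patch_size)):
        raise ValueError("each volume dimension must be divisible by its patch dimension")
    return tuple(v // p for v, p in zip(volume_size, patch_size))


grid = patch_grid((128, 192, 192), (8, 16, 16))
num_tokens = prod(grid)                 # sequence length seen by the vision encoder
voxels_per_patch = prod((8, 16, 16))    # voxels folded into each token (one channel)

print(grid, num_tokens, voxels_per_patch)  # (16, 12, 12) 2304 2048
```

Each volume therefore becomes a sequence of 2304 tokens, which is the figure to keep in mind when budgeting memory or choosing a token compressor downstream.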
+ ## Training Context
+
+ This model was initialized from the `brain_mri_siglip_run_0509/stage1_freeze_text` checkpoint and then jointly fine-tuned with both vision and text towers trainable.
+
+ Training summary:
+
+ - Training samples: `950,720`
+ - Validation samples: `67,450`
+ - Validation samples with `metadata_text`: `32,278`
+ - Stage 1: frozen text tower, vision-heavy training
+ - Stage 2: joint vision-text fine-tuning
+ - Stage 2 epochs configured: `8`
+ - World size: `5`
+ - Stage 2 per-device batch size: `160`
+ - Stage 2 contrastive forward batch: `800`
+ - Gradient checkpointing: text and vision enabled
+
+ Training-time retrieval evaluation used capped validation subsets and should be treated as monitoring rather than a final benchmark.
+
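The batch figures in the training summary are consistent with a contrastive batch assembled across all ranks: 5 ranks times 160 samples gives 800. The arithmetic below assumes embeddings are gathered across ranks before the pairwise loss, which the gather helper in the modeling file suggests but the README does not state outright. Variable names are illustrative.

```python
world_size = 5
per_device_batch = 160

global_batch = world_size * per_device_batch   # size of the contrastive batch
total_pairs = global_batch ** 2                # every volume scored against every text
positive_pairs = global_batch                  # the matched (diagonal) pairs
negative_pairs = total_pairs - positive_pairs

print(global_batch, total_pairs, negative_pairs)  # 800 640000 639200
```

With a pairwise sigmoid objective, each step would therefore see 800 positives and 639,200 negatives under this assumption.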
+ ## Loading
+
+ This model uses custom Transformers code. Load it with `trust_remote_code=True`.
+
+ ```python
+ import torch
+ from transformers import AutoModel, AutoProcessor
+
+ repo_id = "shenxiaochen/brain-mri-siglip"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model = AutoModel.from_pretrained(
+     repo_id,
+     trust_remote_code=True,
+     dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+ ).to(device).eval()
+
+ processor = AutoProcessor.from_pretrained(
+     repo_id,
+     trust_remote_code=True,
+ )
+ ```
+
+ ## NIfTI Preprocessing
+
+ For reproducible inference from NIfTI files, pass paths directly to the saved processor. This repository includes the offline-aligned preprocessing implementation used to match the training tensor distribution.
+
+ ```python
+ nifti_path = "/path/to/brain_mri.nii.gz"
+
+ inputs = processor(
+     volumes=nifti_path,
+     return_tensors="pt",
+ )
+ pixel_values = inputs["pixel_values"].to(device)
+
+ if torch.cuda.is_available():
+     pixel_values = pixel_values.to(dtype=torch.bfloat16)
+
+ with torch.inference_mode():
+     image_embeds = model.get_image_features(pixel_values=pixel_values)
+
+ print(pixel_values.shape)  # [1, 1, 128, 192, 192]
+ print(image_embeds.shape)  # [1, 1152]
+ ```
+
+ The saved path-based preprocessing recipe is:
+
+ - canonicalize image orientation to closest RAS
+ - build foreground mask with threshold `1e-3`
+ - keep the largest connected foreground component
+ - crop foreground with `5mm` margin
+ - normalize foreground intensities with `0.5/99.5` percentiles
+ - map intensities to `[-1, 1]`
+ - resample to spacing `(1.25, 1.0, 1.0)`
+ - downscale to fit `[128, 192, 192]`
+ - center-pad with background value `-1.0`
+
+ The exact settings are saved in `preprocessor_config.json` and `processor_config.json`.
+
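Two steps of the recipe above are easy to misread, so here is a minimal standard-library sketch of them: percentile-window normalization to `[-1, 1]`, and the geometry of "downscale to fit, then center-pad". This is not the repository's `offline_aligned_preprocessing.py`, which is not shown in this diff. The function names, the linear percentile interpolation, and the "never upscale" reading of "downscale to fit" are all assumptions.

```python
import math


def percentile(sorted_values, q):
    """Linearly interpolated percentile (q in [0, 100]) of an ascending list."""
    pos = (len(sorted_values) - 1) * q / 100.0
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return sorted_values[lo] * (1.0 - frac) + sorted_values[hi] * frac


def normalize_to_unit_range(values, low_q=0.5, high_q=99.5):
    """Clip to the [low_q, high_q] percentile window, then map linearly to [-1, 1]."""
    ordered = sorted(values)
    lo, hi = percentile(ordered, low_q), percentile(ordered, high_q)
    span = max(hi - lo, 1e-8)  # guard against a constant-intensity input
    return [2.0 * (min(max(v, lo), hi) - lo) / span - 1.0 for v in values]


def fit_and_pad(shape, target=(128, 192, 192)):
    """Shrink isotropically (assumed: never enlarge) to fit `target`; return fitted shape and per-axis padding."""
    scale = min(1.0, *(t / s for s, t in zip(shape, target)))
    fitted = tuple(min(t, max(1, round(s * scale))) for s, t in zip(shape, target))
    pads = tuple(((t - f) // 2, (t - f) - (t - f) // 2) for f, t in zip(fitted, target))
    return fitted, pads


foreground = [0.0, 10.0, 20.0, 30.0, 1000.0]  # toy intensities with one bright outlier
normalized = normalize_to_unit_range(foreground)
fitted_shape, padding = fit_and_pad((160, 240, 200))

print(normalized[0], normalized[-1])  # -1.0 1.0
print(fitted_shape, padding)          # (128, 192, 160) ((0, 0), (0, 0), (16, 16))
```

In this toy case the outlier is clipped to the upper percentile rather than stretching the whole range, and the padded voxels would be filled with the background value `-1.0` named in the recipe.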
+ ## Using Preprocessed `.pt` Inputs
+
+ If your data is already stored as the same offline preprocessed tensors used during training, you can load it directly:
+
+ ```python
+ payload = torch.load("/path/to/sample.pt", map_location="cpu")
+ pixel_values = payload["pixel_values"] if isinstance(payload, dict) else payload
+
+ if pixel_values.ndim == 4:
+     pixel_values = pixel_values.unsqueeze(0)
+
+ pixel_values = pixel_values.to(device=device, dtype=torch.bfloat16)
+
+ with torch.inference_mode():
+     image_embeds = model.get_image_features(pixel_values=pixel_values)
+ ```
+
+ Expected tensor format:
+
+ - shape `[1, 128, 192, 192]` for one volume, or `[B, 1, 128, 192, 192]` for a batch
+ - values in `[-1, 1]`
+ - padded background voxels near `-1.0`
+
+ ## VLM Integration Notes
+
+ For VLM construction, use the 3D vision tower as a visual backbone and add a projector, Q-Former, Perceiver resampler, or other token compressor before connecting to an LLM.
+
+ A practical downstream recipe is:
+
+ 1. Freeze this MRI encoder and train only the multimodal projector/resampler.
+ 2. Evaluate downstream classification, retrieval, report alignment, or instruction-following behavior.
+ 3. Optionally unfreeze the top vision layers with a much smaller learning rate.
+
+ ## Limitations
+
+ - This checkpoint was trained for representation learning, not diagnosis.
+ - Performance should be validated on task-specific subject-level or study-level splits.
+ - Scanner, protocol, site, and preprocessing differences can affect embeddings.
+ - External users should preserve the saved preprocessing pipeline for NIfTI inference.
+ - Retrieval monitoring during training is not a substitute for downstream clinical validation.
+
+ ## Citation
+
+ If you use this checkpoint, please cite this model repository and the upstream MedSigLIP model where appropriate.
__init__.py ADDED
@@ -0,0 +1,19 @@
+ from .configuration_brain_mri_siglip import BrainMRISiglipConfig
+ from .modeling_brain_mri_siglip import BrainMRISiglipModel
+ from .processing_brain_mri_siglip import BrainMRISiglipProcessor, BrainMRISiglipVolumeProcessor
+
+ __all__ = [
+     "BrainMRISiglipConfig",
+     "BrainMRISiglipModel",
+     "BrainMRISiglipProcessor",
+     "BrainMRISiglipVolumeProcessor",
+ ]
+
+ try:
+     BrainMRISiglipConfig.register_for_auto_class("AutoConfig")
+     BrainMRISiglipModel.register_for_auto_class("AutoModel")
+     BrainMRISiglipProcessor.register_for_auto_class("AutoProcessor")
+ except Exception:
+     # Registration is best-effort and not required for local imports.
+     pass
+
common.py ADDED
@@ -0,0 +1,36 @@
+ """Common utility helpers."""
+
+ from __future__ import annotations
+
+ import shutil
+ from pathlib import Path
+ from typing import Iterable
+ from typing import Sequence, Tuple, Union
+
+
+ def to_3tuple(value: Union[int, Sequence[int]], name: str) -> Tuple[int, int, int]:
+     if isinstance(value, int):
+         return (value, value, value)
+     if len(value) != 3:
+         raise ValueError(f"`{name}` must be an int or length-3 sequence. Got: {value}")
+     return (int(value[0]), int(value[1]), int(value[2]))
+
+
+ REMOTE_CODE_FILES = (
+     "__init__.py",
+     "common.py",
+     "configuration_brain_mri_siglip.py",
+     "modeling_brain_mri_siglip.py",
+     "offline_aligned_preprocessing.py",
+     "processing_brain_mri_siglip.py",
+ )
+
+
+ def copy_remote_code_files(destination: Union[str, Path], file_names: Iterable[str] = REMOTE_CODE_FILES) -> None:
+     src_dir = Path(__file__).resolve().parent
+     dst_dir = Path(destination)
+     dst_dir.mkdir(parents=True, exist_ok=True)
+     for name in file_names:
+         src_file = src_dir / name
+         if src_file.exists():
+             shutil.copy2(src_file, dst_dir / name)
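A short usage illustration for `to_3tuple` from `common.py` above. The function body is copied from the diff so the snippet runs on its own; the calls and argument values are illustrative only.

```python
from typing import Sequence, Tuple, Union


def to_3tuple(value: Union[int, Sequence[int]], name: str) -> Tuple[int, int, int]:
    # Copied from common.py: broadcast an int, or validate a length-3 sequence.
    if isinstance(value, int):
        return (value, value, value)
    if len(value) != 3:
        raise ValueError(f"`{name}` must be an int or length-3 sequence. Got: {value}")
    return (int(value[0]), int(value[1]), int(value[2]))


print(to_3tuple(16, "patch_size"))           # (16, 16, 16)
print(to_3tuple([8, 16, 16], "patch_size"))  # (8, 16, 16)

try:
    to_3tuple((128, 192), "volume_size")
except ValueError as err:
    print(err)  # the message names the offending argument
```

Passing the argument name through to the error message is what lets the config class report whether `volume_size` or `patch_size` was malformed.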
config.json ADDED
@@ -0,0 +1,181 @@
+ {
+   "architectures": [
+     "BrainMRISiglipModel"
+   ],
+   "attn_implementation": null,
+   "auto_map": {
+     "AutoConfig": "configuration_brain_mri_siglip.BrainMRISiglipConfig",
+     "AutoModel": "modeling_brain_mri_siglip.BrainMRISiglipModel",
+     "AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor"
+   },
+   "dtype": "float32",
+   "initializer_range": 0.02,
+   "logit_bias_init_value": -10.0,
+   "logit_scale_init_value": 2.6592,
+   "logit_scale_max": 100.0,
+   "logit_scale_min": 0.001,
+   "max_text_length": 64,
+   "model_type": "brain-mri-siglip",
+   "num_channels": 1,
+   "patch_size": [
+     8,
+     16,
+     16
+   ],
+   "projection_dim": 1152,
+   "text_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": 49406,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "dtype": null,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 49407,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "gelu_pytorch_tanh",
+     "hidden_size": 1152,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "intermediate_size": 4304,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-06,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 64,
+     "min_length": 0,
+     "model_type": "siglip_text_model",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 16,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 27,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 1,
+     "prefix": null,
+     "problem_type": null,
+     "projection_size": 1152,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torchscript": false,
+     "transformers_version": "4.57.6",
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "vocab_size": 32000
+   },
+   "text_model_name_or_path": "google/medsiglip-448",
+   "transformers_version": "4.57.6",
+   "vision_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "dtype": null,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "gelu_pytorch_tanh",
+     "hidden_size": 1152,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "image_size": 448,
+     "intermediate_size": 4304,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-06,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "min_length": 0,
+     "model_type": "siglip_vision_model",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 16,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_channels": 1,
+     "num_hidden_layers": 27,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "patch_size": 14,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torchscript": false,
+     "transformers_version": "4.57.6",
+     "typical_p": 1.0,
+     "use_bfloat16": false
+   },
+   "volume_size": [
+     128,
+     192,
+     192
+   ]
+ }
configuration_brain_mri_siglip.py ADDED
@@ -0,0 +1,112 @@
+ """Configuration for Brain MRI SigLIP."""
+
+ from __future__ import annotations
+
+ from typing import Any, Dict, Mapping, Optional, Sequence, Union
+
+ from transformers import PretrainedConfig, SiglipTextConfig, SiglipVisionConfig
+
+ from .common import to_3tuple
+
+
+ class BrainMRISiglipConfig(PretrainedConfig):
+     r"""Configuration class for :class:`BrainMRISiglipModel`."""
+
+     model_type = "brain-mri-siglip"
+
+     def __init__(
+         self,
+         text_config: Optional[Mapping[str, Any]] = None,
+         vision_config: Optional[Mapping[str, Any]] = None,
+         text_model_name_or_path: str = "google/medsiglip-448",
+         volume_size: Union[int, Sequence[int]] = (128, 192, 192),
+         patch_size: Union[int, Sequence[int]] = (8, 16, 16),
+         num_channels: int = 1,
+         projection_dim: Optional[int] = None,
+         logit_scale_init_value: float = 2.6592,
+         logit_scale_min: float = 1e-3,
+         logit_bias_init_value: float = -10.0,
+         logit_scale_max: float = 100.0,
+         attn_implementation: Optional[str] = None,
+         max_text_length: int = 64,
+         initializer_range: float = 0.02,
+         auto_map: Optional[Mapping[str, str]] = None,
+         **kwargs: Any,
+     ) -> None:
+         if text_config is None:
+             text_config_dict = SiglipTextConfig().to_dict()
+         else:
+             text_config_dict = dict(text_config)
+
+         if vision_config is None:
+             vision_config_dict = SiglipVisionConfig().to_dict()
+         else:
+             vision_config_dict = dict(vision_config)
+
+         resolved_volume_size = to_3tuple(volume_size, "volume_size")
+         resolved_patch_size = to_3tuple(patch_size, "patch_size")
+         if any(v <= 0 for v in resolved_volume_size):
+             raise ValueError(f"`volume_size` must contain positive integers. Got {resolved_volume_size}.")
+         if any(p <= 0 for p in resolved_patch_size):
+             raise ValueError(f"`patch_size` must contain positive integers. Got {resolved_patch_size}.")
+         if any(v % p != 0 for v, p in zip(resolved_volume_size, resolved_patch_size)):
+             raise ValueError(
+                 f"`volume_size` must be divisible by `patch_size`. "
+                 f"Got volume_size={resolved_volume_size}, patch_size={resolved_patch_size}."
+             )
+
+         vision_config_dict["num_channels"] = int(num_channels)
+
+         if projection_dim is None:
+             projection_dim = int(
+                 text_config_dict.get(
+                     "projection_size",
+                     text_config_dict.get("hidden_size", vision_config_dict.get("hidden_size", 768)),
+                 )
+             )
+
+         if auto_map is None:
+             # Keep module paths as `<module>.<Class>` for compatibility with HF dynamic loader.
+             auto_map = {
+                 "AutoConfig": "configuration_brain_mri_siglip.BrainMRISiglipConfig",
+                 "AutoModel": "modeling_brain_mri_siglip.BrainMRISiglipModel",
+                 "AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor",
+             }
+
+         self.text_config = text_config_dict
+         self.vision_config = vision_config_dict
+         self.text_model_name_or_path = text_model_name_or_path
+         self.volume_size = list(resolved_volume_size)
+         self.patch_size = list(resolved_patch_size)
+         self.num_channels = int(num_channels)
+         self.projection_dim = int(projection_dim)
+         self.logit_scale_init_value = float(logit_scale_init_value)
+         self.logit_scale_min = float(logit_scale_min)
+         self.logit_bias_init_value = float(logit_bias_init_value)
+         self.logit_scale_max = float(logit_scale_max)
+         self.attn_implementation = attn_implementation
+         self.max_text_length = int(max_text_length)
+         self.initializer_range = float(initializer_range)
+         self.auto_map = dict(auto_map)
+
+         super().__init__(**kwargs)
+
+     def get_text_config(self, *args: Any, **kwargs: Any) -> SiglipTextConfig:
+         del args, kwargs
+         config = SiglipTextConfig(**self.text_config)
+         if self.attn_implementation:
+             config._attn_implementation = self.attn_implementation
+         elif getattr(config, "_attn_implementation", None) is None:
+             config._attn_implementation = "sdpa"
+         return config
+
+     def get_vision_config(self, *args: Any, **kwargs: Any) -> SiglipVisionConfig:
+         del args, kwargs
+         cfg_dict = dict(self.vision_config)
+         cfg_dict["num_channels"] = int(self.num_channels)
+         config = SiglipVisionConfig(**cfg_dict)
+         if self.attn_implementation:
+             config._attn_implementation = self.attn_implementation
+         elif getattr(config, "_attn_implementation", None) is None:
+             config._attn_implementation = "sdpa"
+         return config
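The `projection_dim` fallback in `BrainMRISiglipConfig.__init__` above is a nested `dict.get` chain, which can be hard to read at a glance. The sketch below isolates that chain on plain dictionaries so the precedence is visible. The helper name `resolve_projection_dim` is hypothetical; only the lookup order is taken from the diff.

```python
def resolve_projection_dim(text_config, vision_config, projection_dim=None):
    """Mirror the fallback order: explicit value, text projection_size, text hidden_size, vision hidden_size, 768."""
    if projection_dim is not None:
        return int(projection_dim)
    return int(
        text_config.get(
            "projection_size",
            text_config.get("hidden_size", vision_config.get("hidden_size", 768)),
        )
    )


# Values matching the shipped config.json: the text tower's projection_size wins.
print(resolve_projection_dim({"projection_size": 1152, "hidden_size": 1152}, {"hidden_size": 1152}))  # 1152
# Without projection_size, the text hidden_size takes precedence over the vision one.
print(resolve_projection_dim({"hidden_size": 512}, {"hidden_size": 1024}))  # 512
# With nothing available, the hard-coded default applies.
print(resolve_projection_dim({}, {}))  # 768
```

One consequence of this ordering is that the vision tower's width is only consulted when the text config carries neither key.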
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:656edd47b6a98dfa950593e10cb0d8214b30e7eca4a4f725c69129f22aabf055
+ size 3536557760
modeling_brain_mri_siglip.py ADDED
@@ -0,0 +1,615 @@
+ """Modeling code for Brain MRI SigLIP."""
+
+ from __future__ import annotations
+
+ import math
+ from typing import Any, Mapping, Optional, Tuple
+
+ import torch
+ import torch.distributed as dist
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.distributed.nn.functional import all_gather as all_gather_with_grad
+ from transformers import AutoConfig, AutoModel, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.models.siglip import SiglipTextConfig, SiglipVisionConfig
+ from transformers.models.siglip.modeling_siglip import (
+     SiglipAttention,
+     SiglipEncoder,
+     SiglipMLP,
+     SiglipMultiheadAttentionPoolingHead,
+     SiglipOutput,
+     SiglipTextModel,
+     default_flax_embed_init,
+ )
+
+ from .configuration_brain_mri_siglip import BrainMRISiglipConfig
+
+
+ def _siglip_sigmoid_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
+     eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device, dtype=logits_per_text.dtype)
+     labels = -torch.ones_like(logits_per_text) + 2 * eye
+     loglik = F.logsigmoid(labels * logits_per_text)
+     nll = -torch.sum(loglik, dim=-1)
+     return nll.mean()
+
+
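To see what `_siglip_sigmoid_loss` computes without a tensor library, the sketch below re-implements the same formula on nested Python lists: labels are +1 on the diagonal and -1 elsewhere, each row sums the negative log-sigmoid of `label * logit`, and the rows are averaged. It is a pure-Python restatement for illustration, not the code the model runs, and the toy logit matrices are made up.

```python
import math


def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def siglip_sigmoid_loss(logits_per_text):
    """Mean over rows of the summed pairwise negative log-likelihood."""
    row_losses = []
    for i, row in enumerate(logits_per_text):
        # +1 for the matched pair on the diagonal, -1 for every mismatched pair.
        row_losses.append(-sum(log_sigmoid((1.0 if i == j else -1.0) * z) for j, z in enumerate(row)))
    return sum(row_losses) / len(row_losses)


uninformative = [[0.0, 0.0], [0.0, 0.0]]    # model has no preference for any pair
aligned = [[10.0, -10.0], [-10.0, 10.0]]    # matched pairs high, mismatched pairs low

print(siglip_sigmoid_loss(uninformative))   # 2 * ln 2, about 1.3863
print(siglip_sigmoid_loss(aligned))         # close to zero
```

Because each row is summed rather than averaged, the loss for an uninformative model grows with the batch size (N times ln 2 per row), which is worth remembering when comparing loss curves across batch sizes.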
+ def _lecun_normal_(tensor: torch.Tensor) -> torch.Tensor:
+     fan_in, _ = nn.init._calculate_fan_in_and_fan_out(tensor)
+     if fan_in <= 0:
+         return nn.init.normal_(tensor, mean=0.0, std=1.0)
+     return nn.init.normal_(tensor, mean=0.0, std=1.0 / math.sqrt(fan_in))
+
+
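For a sense of scale, the arithmetic below evaluates the `1 / sqrt(fan_in)` standard deviation that `_lecun_normal_` would use for the 3D patch-embedding convolution. It assumes the usual PyTorch convention that a convolution's fan-in is input channels times kernel volume; `lecun_std` is a hypothetical helper, and the channel and patch values come from the config in this commit.

```python
import math


def lecun_std(fan_in):
    """Standard deviation used by the initializer above; falls back to 1.0 for non-positive fan-in."""
    return 1.0 if fan_in <= 0 else 1.0 / math.sqrt(fan_in)


in_channels, patch_size = 1, (8, 16, 16)
fan_in = in_channels * math.prod(patch_size)  # inputs feeding each output unit of the Conv3d

print(fan_in, round(lecun_std(fan_in), 4))    # 2048 0.0221
```

The large kernel volume is why the patch-embedding weights start out small compared with, say, a 1152-wide linear layer, whose value under the same rule would be about 0.0295.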
+ def _siglip_embedding_init_(tensor: torch.Tensor) -> torch.Tensor:
+     default_flax_embed_init(tensor)
+     return tensor
+
+
+ def _distributed_concat_with_grad(embeddings: torch.Tensor) -> torch.Tensor:
+     if not dist.is_available() or not dist.is_initialized():
+         return embeddings
+     world_size = dist.get_world_size()
+     local_batch = embeddings.shape[0]
+     local_batch_tensor = torch.tensor([local_batch], dtype=torch.long, device=embeddings.device)
+     batch_sizes = [torch.zeros_like(local_batch_tensor) for _ in range(world_size)]
+     dist.all_gather(batch_sizes, local_batch_tensor)
+     batch_sizes_int = [int(size.item()) for size in batch_sizes]
+     max_batch = max(batch_sizes_int)
+
+     if local_batch < max_batch:
+         pad_shape = (max_batch - local_batch, embeddings.shape[1])
+         padding = embeddings.new_zeros(pad_shape)
+         padded_embeddings = torch.cat([embeddings, padding], dim=0)
+     else:
+         padded_embeddings = embeddings
+
+     gathered = all_gather_with_grad(padded_embeddings)
+     if isinstance(gathered, torch.Tensor):
+         if gathered.ndim == 3 and gathered.shape[0] == world_size:
+             chunks = [gathered[rank] for rank in range(world_size)]
+         else:
+             chunks = list(torch.split(gathered, max_batch, dim=0))
+     else:
+         chunks = list(gathered)
+
+     trimmed = [chunk[: batch_sizes_int[rank]] for rank, chunk in enumerate(chunks) if batch_sizes_int[rank] > 0]
+     if not trimmed:
+         return embeddings.new_zeros((0, embeddings.shape[1]))
+     return torch.cat(trimmed, dim=0)
+
+
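`_distributed_concat_with_grad` follows a pad, gather, trim pattern so that ranks with different local batch sizes can still take part in a fixed-shape all-gather. The sketch below simulates that pattern on plain lists with no process group, purely to show the data flow. The function name and toy inputs are hypothetical, and the simulation says nothing about gradient flow, which the real helper gets from the autograd-aware `all_gather`.

```python
def pad_gather_trim(per_rank_batches, embed_dim):
    """Simulate pad -> gather -> trim for per-rank lists of embedding rows."""
    batch_sizes = [len(batch) for batch in per_rank_batches]
    max_batch = max(batch_sizes)
    # Pad every rank to the same length so a fixed-shape gather would be valid.
    padded = [batch + [[0.0] * embed_dim] * (max_batch - len(batch)) for batch in per_rank_batches]
    # "Gather": each rank would receive all padded chunks, in rank order.
    gathered = list(padded)
    # Trim the zero rows using the true per-rank sizes, skipping empty ranks.
    trimmed = [chunk[:size] for chunk, size in zip(gathered, batch_sizes) if size > 0]
    return [row for chunk in trimmed for row in chunk]


ranks = [
    [[1.0, 1.0], [2.0, 2.0]],  # rank 0 holds two samples
    [[3.0, 3.0]],              # rank 1 holds one sample
    [],                        # rank 2 holds none (for example, a short final batch)
]
print(pad_gather_trim(ranks, embed_dim=2))  # [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
```

Trimming matters for the contrastive loss: leftover zero rows would otherwise enter the similarity matrix as spurious negatives.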
+ def _load_state_dict_with_flexible_prefix(
+     module: nn.Module,
+     source_state_dict: Mapping[str, torch.Tensor],
+     strict: bool = True,
+ ) -> Tuple[Any, Any]:
+     target_keys = list(module.state_dict().keys())
+     source_keys = list(source_state_dict.keys())
+
+     if not target_keys or not source_keys:
+         return module.load_state_dict(source_state_dict, strict=strict)
+
+     target_has_text_model_prefix = all(key.startswith("text_model.") for key in target_keys)
+     source_has_text_model_prefix = all(key.startswith("text_model.") for key in source_keys)
+
+     aligned_state_dict = dict(source_state_dict)
+     if target_has_text_model_prefix and not source_has_text_model_prefix:
+         aligned_state_dict = {f"text_model.{key}": value for key, value in source_state_dict.items()}
+     elif source_has_text_model_prefix and not target_has_text_model_prefix:
+         aligned_state_dict = {
+             key[len("text_model.") :]: value for key, value in source_state_dict.items() if key.startswith("text_model.")
+         }
+
+     return module.load_state_dict(aligned_state_dict, strict=strict)
+
+
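The key-renaming rule in `_load_state_dict_with_flexible_prefix` can be exercised on ordinary dictionaries. The sketch below reproduces only the prefix alignment (no module loading); `align_prefix` and the example key names are hypothetical stand-ins for real checkpoint keys.

```python
PREFIX = "text_model."


def align_prefix(target_keys, source_state):
    """Add or strip the `text_model.` prefix so source keys match the target's naming."""
    target_prefixed = all(key.startswith(PREFIX) for key in target_keys)
    source_prefixed = all(key.startswith(PREFIX) for key in source_state)
    if target_prefixed and not source_prefixed:
        return {PREFIX + key: value for key, value in source_state.items()}
    if source_prefixed and not target_prefixed:
        return {key[len(PREFIX):]: value for key, value in source_state.items()}
    return dict(source_state)


# Target expects prefixed keys, checkpoint has bare keys: the prefix is added.
print(align_prefix(["text_model.embeddings.weight"], {"embeddings.weight": 0}))
# Target expects bare keys, checkpoint is prefixed: the prefix is stripped.
print(align_prefix(["embeddings.weight"], {"text_model.embeddings.weight": 0}))
```

Note that the check uses `all(...)`, so a checkpoint with a mix of prefixed and bare keys counts as "not prefixed"; against a fully prefixed target, every key would then receive the prefix, including those that already had it.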
+ class SiglipVisionEmbeddings3D(nn.Module):
+     """3D patch embeddings for MRI volumes."""
+
+     def __init__(
+         self,
+         vision_config: SiglipVisionConfig,
+         volume_size: Tuple[int, int, int],
+         patch_size: Tuple[int, int, int],
+         num_channels: int,
+     ) -> None:
+         super().__init__()
+         self.embed_dim = int(vision_config.hidden_size)
+         self.volume_size = tuple(int(v) for v in volume_size)
+         self.patch_size = tuple(int(v) for v in patch_size)
+
+         if any(v % p != 0 for v, p in zip(self.volume_size, self.patch_size)):
+             raise ValueError(
+                 "Volume size must be divisible by patch size for all dimensions. "
+                 f"Got volume_size={self.volume_size}, patch_size={self.patch_size}."
+             )
+
+         self.patch_embedding = nn.Conv3d(
+             in_channels=int(num_channels),
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding=0,
+         )
+
+         patches_per_dim = tuple(v // p for v, p in zip(self.volume_size, self.patch_size))
+         self.grid_size = patches_per_dim
+         self.num_patches = int(patches_per_dim[0] * patches_per_dim[1] * patches_per_dim[2])
+         self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
+         self.register_buffer("position_ids", torch.arange(self.num_patches).expand((1, -1)), persistent=False)
+
+     def _interpolate_position_embeddings(
+         self,
+         grid_size: Tuple[int, int, int],
+         target_dtype: torch.dtype,
+         target_device: torch.device,
+     ) -> torch.Tensor:
+         base_grid_depth, base_grid_height, base_grid_width = self.grid_size
+         position_embeddings = self.position_embedding.weight.reshape(
+             base_grid_depth,
+             base_grid_height,
+             base_grid_width,
+             self.embed_dim,
+         )
+         position_embeddings = position_embeddings.permute(3, 0, 1, 2).unsqueeze(0)
+         position_embeddings = F.interpolate(
+             position_embeddings,
+             size=grid_size,
+             mode="trilinear",
+             align_corners=False,
+         )
+         position_embeddings = position_embeddings.squeeze(0).permute(1, 2, 3, 0).reshape(1, -1, self.embed_dim)
+         return position_embeddings.to(dtype=target_dtype, device=target_device)
+
+     def _get_position_embeddings(
+         self,
+         grid_size: Tuple[int, int, int],
+         target_dtype: torch.dtype,
+         target_device: torch.device,
+         interpolate_pos_encoding: bool,
+     ) -> torch.Tensor:
+         num_patches = int(grid_size[0] * grid_size[1] * grid_size[2])
+         if num_patches == self.num_patches:
+             return self.position_embedding(self.position_ids).to(dtype=target_dtype, device=target_device)
+         if not interpolate_pos_encoding:
+             raise ValueError(
+                 f"Unexpected number of patches: {num_patches} vs expected {self.num_patches}. "
+                 "Enable `interpolate_pos_encoding=True` for variable volume sizes."
+             )
+         return self._interpolate_position_embeddings(grid_size, target_dtype=target_dtype, target_device=target_device)
+
+     def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = True) -> torch.Tensor:
+         if pixel_values.ndim != 5:
+             raise ValueError(
+                 "`pixel_values` must have shape [batch, channels, depth, height, width]. "
+                 f"Got shape {tuple(pixel_values.shape)}"
+             )
+         spatial_shape = tuple(int(v) for v in pixel_values.shape[-3:])
+         if any(dim % patch != 0 for dim, patch in zip(spatial_shape, self.patch_size)):
+             raise ValueError(
+                 f"Input spatial size {spatial_shape} must be divisible by patch_size {self.patch_size}."
+             )
+
+         target_dtype = self.patch_embedding.weight.dtype
+         embeddings = self.patch_embedding(pixel_values.to(dtype=target_dtype))
+         grid_size = tuple(int(v) for v in embeddings.shape[-3:])
+         embeddings = embeddings.flatten(2).transpose(1, 2)
+         position_embeddings = self._get_position_embeddings(
+             grid_size=grid_size,
+             target_dtype=embeddings.dtype,
+             target_device=embeddings.device,
+             interpolate_pos_encoding=interpolate_pos_encoding,
+         )
+         return embeddings + position_embeddings
+
+
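The branch in `_get_position_embeddings` decides between a direct table lookup and trilinear interpolation by comparing patch counts. The sketch below traces that decision for a few input sizes using only integer arithmetic. `position_embedding_plan` is a hypothetical helper, with the base grid and patch size taken from this commit's config.

```python
from math import prod

BASE_GRID = (16, 12, 12)   # grid for the trained volume size [128, 192, 192]
PATCH_SIZE = (8, 16, 16)


def position_embedding_plan(spatial_shape, interpolate_pos_encoding=True):
    """Return (grid, path) where path is 'lookup' or 'interpolate', following the count-based check above."""
    if any(dim % patch != 0 for dim, patch in zip(spatial_shape, PATCH_SIZE)):
        raise ValueError(f"Input spatial size {spatial_shape} must be divisible by patch_size {PATCH_SIZE}.")
    grid = tuple(dim // patch for dim, patch in zip(spatial_shape, PATCH_SIZE))
    if prod(grid) == prod(BASE_GRID):
        return grid, "lookup"
    if not interpolate_pos_encoding:
        raise ValueError("patch count differs and interpolation is disabled")
    return grid, "interpolate"


print(position_embedding_plan((128, 192, 192)))  # ((16, 12, 12), 'lookup')
print(position_embedding_plan((64, 192, 192)))   # ((8, 12, 12), 'interpolate')
print(position_embedding_plan((96, 256, 192)))   # ((12, 16, 12), 'lookup')
```

The last case is worth noticing: because the comparison is on the total patch count rather than the grid shape, a differently shaped volume with the same 2304 patches takes the lookup path and reuses the trained position table without interpolation. Keeping inputs at the trained `[128, 192, 192]` size, as the saved processor does, avoids that situation.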
+ class BrainMRISiglipVisionTransformer(nn.Module):
+     """SigLIP vision tower with 3D embeddings."""
+
+     def __init__(self, config: BrainMRISiglipConfig) -> None:
+         super().__init__()
+         vision_config = config.get_vision_config()
+         volume_size = tuple(int(v) for v in config.volume_size)
+         patch_size = tuple(int(v) for v in config.patch_size)
+
+         self.embeddings = SiglipVisionEmbeddings3D(
+             vision_config=vision_config,
+             volume_size=volume_size,
+             patch_size=patch_size,
+             num_channels=int(config.num_channels),
+         )
+         self.encoder = SiglipEncoder(vision_config)
+         self.post_layernorm = nn.LayerNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)
+         self.head = SiglipMultiheadAttentionPoolingHead(vision_config)
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+         interpolate_pos_encoding: bool = True,
+         **kwargs: Any,
+     ) -> BaseModelOutputWithPooling:
+         hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
+         encoder_outputs = self.encoder(inputs_embeds=hidden_states, **kwargs)
+         last_hidden_state = self.post_layernorm(encoder_outputs.last_hidden_state)
+         pooler_output = self.head(last_hidden_state)
+         return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state, pooler_output=pooler_output)
+
+
+ class BrainMRISiglipPreTrainedModel(PreTrainedModel):
+     config_class = BrainMRISiglipConfig
+     base_model_prefix = "brain_mri_siglip"
+     supports_gradient_checkpointing = True
+
+     def _init_weights(self, module: nn.Module) -> None:
+         if isinstance(module, SiglipVisionEmbeddings3D):
+             width = int(self.config.get_vision_config().hidden_size)
+             nn.init.normal_(module.position_embedding.weight, std=1.0 / math.sqrt(width))
+             _lecun_normal_(module.patch_embedding.weight)
+             if module.patch_embedding.bias is not None:
+                 nn.init.zeros_(module.patch_embedding.bias)
+             return
+
+         if isinstance(module, nn.Embedding):
+             _siglip_embedding_init_(module.weight)
+             return
+
+         if isinstance(module, SiglipAttention):
+             nn.init.xavier_uniform_(module.q_proj.weight)
+             nn.init.xavier_uniform_(module.k_proj.weight)
+             nn.init.xavier_uniform_(module.v_proj.weight)
+             nn.init.xavier_uniform_(module.out_proj.weight)
+             if module.q_proj.bias is not None:
+                 nn.init.zeros_(module.q_proj.bias)
+             if module.k_proj.bias is not None:
+                 nn.init.zeros_(module.k_proj.bias)
+             if module.v_proj.bias is not None:
+                 nn.init.zeros_(module.v_proj.bias)
+             if module.out_proj.bias is not None:
+                 nn.init.zeros_(module.out_proj.bias)
+             return
+
+         if isinstance(module, SiglipMLP):
+             nn.init.xavier_uniform_(module.fc1.weight)
+             nn.init.xavier_uniform_(module.fc2.weight)
+             if module.fc1.bias is not None:
+                 nn.init.normal_(module.fc1.bias, std=1e-6)
+             if module.fc2.bias is not None:
+                 nn.init.normal_(module.fc2.bias, std=1e-6)
+             return
+
+         if isinstance(module, SiglipMultiheadAttentionPoolingHead):
+             nn.init.xavier_uniform_(module.probe)
+             nn.init.xavier_uniform_(module.attention.in_proj_weight)
+             if module.attention.in_proj_bias is not None:
+                 nn.init.zeros_(module.attention.in_proj_bias)
+             return
+
+         if isinstance(module, (nn.Linear, nn.Conv3d)):
+             _lecun_normal_(module.weight)
+             if module.bias is not None:
+                 nn.init.zeros_(module.bias)
+             return
+
+         if isinstance(module, nn.LayerNorm):
+             module.bias.data.zero_()
+             module.weight.data.fill_(1.0)
+             return
+
+ class BrainMRISiglipModel(BrainMRISiglipPreTrainedModel):
+     """3D MRI + text dual-encoder model with SigLIP contrastive loss."""
+
+     def __init__(self, config: BrainMRISiglipConfig) -> None:
+         super().__init__(config)
+         self.text_config = config.get_text_config()
+         self.vision_config = config.get_vision_config()
+
+         self.text_model = SiglipTextModel(self.text_config)
+         self.vision_model = BrainMRISiglipVisionTransformer(config)
+
+         projection_dim = int(config.projection_dim)
+         self.visual_projection = nn.Linear(self.vision_config.hidden_size, projection_dim, bias=False)
+         self.text_projection = nn.Linear(self.text_config.hidden_size, projection_dim, bias=False)
+
+         self.logit_scale = nn.Parameter(torch.tensor(float(config.logit_scale_init_value)))
+         self.logit_bias = nn.Parameter(torch.tensor(float(config.logit_bias_init_value)))
+
+         self.post_init()
+
+     @classmethod
+     def from_medsiglip_pretrained(
+         cls,
+         text_model_name_or_path: str = "google/medsiglip-448",
+         trust_remote_code: bool = True,
+         local_files_only: bool = False,
+         **kwargs: Any,
+     ) -> "BrainMRISiglipModel":
+         base_config = AutoConfig.from_pretrained(
+             text_model_name_or_path,
+             trust_remote_code=trust_remote_code,
+             local_files_only=local_files_only,
+         )
+         if hasattr(base_config, "text_config"):
+             raw_text_config = base_config.text_config
+             text_config = raw_text_config.to_dict() if hasattr(raw_text_config, "to_dict") else dict(raw_text_config)
+         else:
+             text_config = SiglipTextConfig().to_dict()
+
+         if hasattr(base_config, "vision_config"):
+             raw_vision_config = base_config.vision_config
+             vision_config = (
+                 raw_vision_config.to_dict()
+                 if hasattr(raw_vision_config, "to_dict")
+                 else dict(raw_vision_config)
+             )
+         else:
+             vision_config = SiglipVisionConfig().to_dict()
+         projection_dim = kwargs.pop(
+             "projection_dim",
+             int(getattr(base_config, "projection_dim", text_config.get("projection_size", text_config["hidden_size"]))),
+         )
+
+         config = BrainMRISiglipConfig(
+             text_config=text_config,
+             vision_config=vision_config,
+             projection_dim=projection_dim,
+             text_model_name_or_path=text_model_name_or_path,
+             **kwargs,
+         )
+         model = cls(config)
+         model.load_text_tower_from_pretrained(
+             text_model_name_or_path,
+             trust_remote_code=trust_remote_code,
+             local_files_only=local_files_only,
+         )
+         return model
+
+     def load_text_tower_from_pretrained(
+         self,
+         text_model_name_or_path: str,
+         trust_remote_code: bool = True,
+         local_files_only: bool = False,
+         strict: bool = True,
+     ) -> Tuple[Any, Any]:
+         source_model = None
+         try:
+             source_model = AutoModel.from_pretrained(
+                 text_model_name_or_path,
+                 trust_remote_code=trust_remote_code,
+                 local_files_only=local_files_only,
+             )
+             if hasattr(source_model, "text_model"):
+                 source_text_model = source_model.text_model
+             elif isinstance(source_model, SiglipTextModel):
+                 source_text_model = source_model
+             else:
+                 raise ValueError(
+                     f"Could not find a SigLIP text tower in `{text_model_name_or_path}` "
+                     f"({type(source_model).__name__})."
+                 )
+
+             missing, unexpected = _load_state_dict_with_flexible_prefix(
+                 self.text_model,
+                 source_text_model.state_dict(),
+                 strict=strict,
+             )
+
+             if hasattr(source_model, "text_projection") and isinstance(source_model.text_projection, nn.Linear):
+                 if source_model.text_projection.weight.shape == self.text_projection.weight.shape:
+                     self.text_projection.load_state_dict(source_model.text_projection.state_dict())
+
+             if hasattr(source_model, "logit_scale") and source_model.logit_scale.shape == self.logit_scale.shape:
+                 self.logit_scale.data.copy_(source_model.logit_scale.data)
+             if hasattr(source_model, "logit_bias") and source_model.logit_bias.shape == self.logit_bias.shape:
+                 self.logit_bias.data.copy_(source_model.logit_bias.data)
+
+             return missing, unexpected
+         finally:
+             if source_model is not None:
+                 del source_model
+
+     def freeze_text_tower(self, trainable_layers: int = 0) -> None:
+         for parameter in self.text_model.parameters():
+             parameter.requires_grad = False
+
+         trainable_layers = int(trainable_layers)
+         if trainable_layers > 0 and hasattr(self.text_model, "text_model") and hasattr(
+             self.text_model.text_model, "encoder"
+         ):
+             layers = self.text_model.text_model.encoder.layers
+             for layer in layers[-trainable_layers:]:
+                 for parameter in layer.parameters():
+                     parameter.requires_grad = True
+
+             for module_name in ("final_layer_norm", "head"):
+                 if hasattr(self.text_model.text_model, module_name):
+                     for parameter in getattr(self.text_model.text_model, module_name).parameters():
+                         parameter.requires_grad = True
+
+         for parameter in self.text_projection.parameters():
+             parameter.requires_grad = True
+
+     def freeze_vision_tower(self, trainable_layers: int = 0, train_embeddings: bool = False) -> None:
+         for parameter in self.vision_model.parameters():
+             parameter.requires_grad = False
+
+         if train_embeddings:
+             for parameter in self.vision_model.embeddings.parameters():
+                 parameter.requires_grad = True
+
+         trainable_layers = int(trainable_layers)
+         if trainable_layers > 0:
+             layers = self.vision_model.encoder.layers
+             for layer in layers[-trainable_layers:]:
+                 for parameter in layer.parameters():
+                     parameter.requires_grad = True
+             for parameter in self.vision_model.post_layernorm.parameters():
+                 parameter.requires_grad = True
+             for parameter in self.vision_model.head.parameters():
+                 parameter.requires_grad = True
+
+         for parameter in self.visual_projection.parameters():
+             parameter.requires_grad = True
+
+     def get_text_features(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         text_kwargs: Optional[Mapping[str, Any]] = None,
+         **kwargs: Any,
+     ) -> torch.FloatTensor:
+         kwargs = dict(kwargs)
+         nested_text_kwargs = kwargs.pop("text_kwargs", None)
+         if kwargs:
+             raise TypeError(f"Unexpected keyword arguments for text tower: {sorted(kwargs.keys())}")
+         merged_text_kwargs: dict[str, Any] = {}
+         if nested_text_kwargs:
+             merged_text_kwargs.update(dict(nested_text_kwargs))
+         if text_kwargs:
+             merged_text_kwargs.update(dict(text_kwargs))
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             **merged_text_kwargs,
+         )
+         text_features = self.text_projection(text_outputs.pooler_output)
+         return F.normalize(text_features, dim=-1)
+
+     def get_image_features(
+         self,
+         pixel_values: torch.FloatTensor,
+         interpolate_pos_encoding: bool = True,
+         vision_kwargs: Optional[Mapping[str, Any]] = None,
+         **kwargs: Any,
+     ) -> torch.FloatTensor:
+         kwargs = dict(kwargs)
+         nested_vision_kwargs = kwargs.pop("vision_kwargs", None)
+         legacy_interpolate_pos_encoding = kwargs.pop("interpolate_pos_encoding", None)
+         if kwargs:
+             raise TypeError(f"Unexpected keyword arguments for vision tower: {sorted(kwargs.keys())}")
+
+         merged_vision_kwargs: dict[str, Any] = {}
+         if nested_vision_kwargs:
+             merged_vision_kwargs.update(dict(nested_vision_kwargs))
+         if vision_kwargs:
+             merged_vision_kwargs.update(dict(vision_kwargs))
+         if legacy_interpolate_pos_encoding is not None:
+             interpolate_pos_encoding = bool(legacy_interpolate_pos_encoding)
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             interpolate_pos_encoding=interpolate_pos_encoding,
+             **merged_vision_kwargs,
+         )
+         image_features = self.visual_projection(vision_outputs.pooler_output)
+         return F.normalize(image_features, dim=-1)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         return_loss: Optional[bool] = None,
+         gather_loss: bool = False,
+         interpolate_pos_encoding: bool = True,
+         vision_kwargs: Optional[Mapping[str, Any]] = None,
+         text_kwargs: Optional[Mapping[str, Any]] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs: Any,
+     ) -> SiglipOutput:
+         if pixel_values is None:
+             raise ValueError("`pixel_values` must be provided.")
+         if input_ids is None:
+             raise ValueError("`input_ids` must be provided.")
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         return_loss = bool(return_loss) if return_loss is not None else False
+         kwargs = dict(kwargs)
+         nested_vision_kwargs = kwargs.pop("vision_kwargs", None)
+         nested_text_kwargs = kwargs.pop("text_kwargs", None)
+         legacy_interpolate_pos_encoding = kwargs.pop("interpolate_pos_encoding", None)
+         if kwargs:
+             raise TypeError(f"Unexpected keyword arguments in model.forward: {sorted(kwargs.keys())}")
+
+         merged_vision_kwargs: dict[str, Any] = {}
+         merged_text_kwargs: dict[str, Any] = {}
+         if nested_vision_kwargs:
+             merged_vision_kwargs.update(dict(nested_vision_kwargs))
+         if vision_kwargs:
+             merged_vision_kwargs.update(dict(vision_kwargs))
+         if nested_text_kwargs:
+             merged_text_kwargs.update(dict(nested_text_kwargs))
+         if text_kwargs:
+             merged_text_kwargs.update(dict(text_kwargs))
+         if legacy_interpolate_pos_encoding is not None:
+             interpolate_pos_encoding = bool(legacy_interpolate_pos_encoding)
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             interpolate_pos_encoding=interpolate_pos_encoding,
+             **merged_vision_kwargs,
+         )
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             **merged_text_kwargs,
+         )
+
+         image_embeds = self.visual_projection(vision_outputs.pooler_output)
+         text_embeds = self.text_projection(text_outputs.pooler_output)
+
+         image_embeds = F.normalize(image_embeds, p=2, dim=-1)
+         text_embeds = F.normalize(text_embeds, p=2, dim=-1)
+
+         image_embeds_for_loss = image_embeds
+         text_embeds_for_loss = text_embeds
+         if gather_loss and return_loss:
+             image_embeds_for_loss = _distributed_concat_with_grad(image_embeds)
+             text_embeds_for_loss = _distributed_concat_with_grad(text_embeds)
+
+         logit_scale = self.logit_scale.exp().clamp(
+             min=float(self.config.logit_scale_min),
+             max=float(self.config.logit_scale_max),
+         )
+
+         local_logits_per_text = torch.matmul(
+             text_embeds,
+             image_embeds.t().to(text_embeds.device),
+         )
+         local_logits_per_text = local_logits_per_text * logit_scale + self.logit_bias
+         local_logits_per_image = local_logits_per_text.t()
+
+         loss = None
+         if return_loss:
+             loss_logits_per_text = torch.matmul(
+                 text_embeds_for_loss,
+                 image_embeds_for_loss.t().to(text_embeds_for_loss.device),
+             )
+             loss_logits_per_text = loss_logits_per_text * logit_scale + self.logit_bias
+             loss = _siglip_sigmoid_loss(loss_logits_per_text)
+
+         if not return_dict:
+             output = (
+                 local_logits_per_image,
+                 local_logits_per_text,
+                 text_embeds,
+                 image_embeds,
+                 text_outputs,
+                 vision_outputs,
+             )
+             return ((loss,) + output) if loss is not None else output
+
+         return SiglipOutput(
+             loss=loss,
+             logits_per_image=local_logits_per_image,
+             logits_per_text=local_logits_per_text,
+             text_embeds=text_embeds,
+             image_embeds=image_embeds,
+             text_model_output=text_outputs,
+             vision_model_output=vision_outputs,
+         )
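The `forward` method above L2-normalizes both embeddings, multiplies their cosine similarities by `exp(logit_scale)` clamped to a configured range, adds `logit_bias`, and passes the text-to-image logits to `_siglip_sigmoid_loss`. That helper is defined outside this excerpt, so the following is only a minimal NumPy sketch that assumes it follows the standard SigLIP pairwise sigmoid loss (label +1 for matched pairs on the diagonal, -1 elsewhere). The function names and the clamp bounds `1.0` and `100.0` are hypothetical and do not come from the repository.

```python
import math

import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm, mirroring F.normalize(x, p=2, dim=-1)."""
    norms = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norms, 1e-12)


def siglip_logits(text_embeds, image_embeds, logit_scale, logit_bias, scale_min, scale_max):
    """Text-to-image logits: clamp(exp(logit_scale)) * cosine similarity + bias."""
    scale = float(np.clip(np.exp(logit_scale), scale_min, scale_max))
    return l2_normalize(text_embeds) @ l2_normalize(image_embeds).T * scale + logit_bias


def sigmoid_pairwise_loss(logits_per_text: np.ndarray) -> float:
    """Assumed standard SigLIP loss: -log sigmoid(label * logit), summed per row, averaged."""
    n = logits_per_text.shape[0]
    labels = 2.0 * np.eye(n) - 1.0  # +1 on the diagonal, -1 off the diagonal
    log_likelihood = -np.logaddexp(0.0, -labels * logits_per_text)  # log sigmoid
    return float(-log_likelihood.sum(axis=-1).mean())


# Two perfectly matched, mutually orthogonal pairs; all values are illustrative.
embeds = np.eye(2)
logits = siglip_logits(embeds, embeds, math.log(10.0), -10.0, scale_min=1.0, scale_max=100.0)
loss = sigmoid_pairwise_loss(logits)
print(logits.shape, loss)
```

Because every pair gets its own binary label, the loss needs no softmax over the batch, which is why the model can compute it on gathered embeddings while still returning only the local logits.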
offline_aligned_preprocessing.py ADDED
@@ -0,0 +1,286 @@
+ """Shared offline-aligned preprocessing helpers for 3D brain MRI volumes."""
+
+ from __future__ import annotations
+
+ import math
+ from pathlib import Path
+ from typing import Any, Mapping
+
+ import nibabel as nib
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ try:
+     from scipy import ndimage as scipy_ndimage
+ except Exception:  # pragma: no cover - optional import surface
+     scipy_ndimage = None
+
+
+ TARGET_SHAPE = (128, 192, 192)
+ TARGET_SPACING = (1.25, 1.0, 1.0)
+ CROP_MARGIN_MM = 5.0
+ FOREGROUND_THRESHOLD = 1e-3
+ BACKGROUND_VALUE = -1.0
+ FOREGROUND_STRATEGY = "largest_component_nonzero"
+ GENERIC_RECIPE_ID = "generic_foreground_128x192x192_fp16_v1"
+ GENERIC_CACHE_VERSION = 1
+
+
+ def load_canonical_nifti(path: str | Path):
+     return nib.as_closest_canonical(nib.load(str(path)))
+
+
+ def load_image_spacing(image) -> tuple[float, float, float]:
+     zooms = image.header.get_zooms()[:3]
+     if len(zooms) != 3:
+         raise ValueError(f"Expected a 3D image spacing tuple, got {zooms}.")
+     return tuple(float(value) for value in zooms)
+
+
+ def coerce_volume_to_3d(volume: np.ndarray) -> np.ndarray:
+     if volume.ndim == 3:
+         return volume.astype(np.float32, copy=False)
+     if volume.ndim != 4:
+         raise ValueError(f"Expected a 3D or 4D volume, got shape {volume.shape}.")
+
+     if volume.shape[0] <= 4 and volume.shape[-1] > 4:
+         selected = volume[0]
+     else:
+         selected = volume[..., 0]
+     return np.asarray(selected, dtype=np.float32)
+
+
+ def largest_connected_component(mask: np.ndarray) -> np.ndarray:
+     if not mask.any() or scipy_ndimage is None:
+         return mask
+     structure = scipy_ndimage.generate_binary_structure(mask.ndim, 1)
+     labels, num_labels = scipy_ndimage.label(mask, structure=structure)
+     if num_labels <= 1:
+         return mask
+     counts = np.bincount(labels.reshape(-1))
+     if counts.size <= 1:
+         return mask
+     counts[0] = 0
+     winning_label = int(counts.argmax())
+     if winning_label <= 0 or counts[winning_label] <= 0:
+         return mask
+     return labels == winning_label
+
+
+ def build_foreground_mask(volume: np.ndarray, threshold: float = FOREGROUND_THRESHOLD) -> np.ndarray:
+     sanitized = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0)
+     raw_mask = np.abs(sanitized) > float(threshold)
+     if not raw_mask.any():
+         return np.ones_like(sanitized, dtype=bool)
+
+     component_mask = largest_connected_component(raw_mask)
+     component_count = int(component_mask.sum())
+     raw_count = int(raw_mask.sum())
+     if component_count <= 0:
+         return raw_mask
+     if component_count < 512 and raw_count > component_count:
+         return raw_mask
+     return component_mask
+
+
+ def compute_crop_bbox(
+     mask: np.ndarray,
+     spacing: tuple[float, float, float],
+     margin_mm: float = CROP_MARGIN_MM,
+ ) -> tuple[tuple[int, int], ...]:
+     coords = np.where(mask)
+     if coords[0].size == 0:
+         raise ValueError("Foreground mask contains no positive voxels after selection.")
+
+     bbox = []
+     for axis, values in enumerate(coords):
+         margin_voxels = int(math.ceil(float(margin_mm) / float(spacing[axis])))
+         start = max(0, int(values.min()) - margin_voxels)
+         stop = min(mask.shape[axis], int(values.max()) + margin_voxels + 1)
+         bbox.append((start, stop))
+     return tuple(bbox)
+
+
+ def crop_volume_and_mask(
+     volume: np.ndarray,
+     mask: np.ndarray,
+     spacing: tuple[float, float, float],
+     margin_mm: float = CROP_MARGIN_MM,
+ ) -> tuple[np.ndarray, np.ndarray, tuple[tuple[int, int], ...]]:
+     bbox = compute_crop_bbox(mask, spacing, margin_mm=margin_mm)
+     slices = tuple(slice(start, stop) for start, stop in bbox)
+     return volume[slices], mask[slices], bbox
+
+
+ def normalize_foreground_only(volume: np.ndarray, mask: np.ndarray) -> np.ndarray:
+     sanitized = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
+     foreground_values = sanitized[mask]
+     if foreground_values.size == 0:
+         raise ValueError("Cannot normalize volume because the foreground mask is empty.")
+
+     if foreground_values.size > 1_000_000:
+         step = max(1, foreground_values.size // 1_000_000)
+         foreground_values = foreground_values[::step]
+
+     low, high = np.percentile(foreground_values, [0.5, 99.5])
+     if not np.isfinite(low) or not np.isfinite(high) or high <= low:
+         normalized = np.zeros_like(sanitized, dtype=np.float32)
+     else:
+         normalized = np.clip(sanitized, float(low), float(high))
+         normalized = np.clip((normalized - float(low)) / float(high - low), 0.0, 1.0)
+         normalized = normalized * 2.0 - 1.0
+     return normalized.astype(np.float32, copy=False)
+
+
+ def resize_volume(volume: np.ndarray, size: tuple[int, int, int], mode: str) -> np.ndarray:
+     tensor = torch.from_numpy(volume).unsqueeze(0).unsqueeze(0)
+     kwargs = {}
+     if mode in {"linear", "bilinear", "bicubic", "trilinear"}:
+         kwargs["align_corners"] = False
+     tensor = F.interpolate(tensor, size=size, mode=mode, **kwargs)
+     return tensor.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32, copy=False)
+
+
+ def resize_mask(mask: np.ndarray, size: tuple[int, int, int]) -> np.ndarray:
+     tensor = torch.from_numpy(mask.astype(np.float32, copy=False)).unsqueeze(0).unsqueeze(0)
+     tensor = F.interpolate(tensor, size=size, mode="nearest")
+     return tensor.squeeze(0).squeeze(0).cpu().numpy() > 0.5
+
+
+ def resample_to_target_spacing(
+     volume: np.ndarray,
+     mask: np.ndarray,
+     source_spacing: tuple[float, float, float],
+     target_spacing: tuple[float, float, float] = TARGET_SPACING,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     target_shape = []
+     for current_size, src, dst in zip(volume.shape, source_spacing, target_spacing):
+         target_shape.append(max(1, int(round(float(current_size) * float(src) / float(dst)))))
+     target_shape_tuple = tuple(target_shape)
+     if target_shape_tuple == tuple(int(v) for v in volume.shape):
+         return volume.astype(np.float32, copy=False), mask
+     return (
+         resize_volume(volume, target_shape_tuple, mode="trilinear"),
+         resize_mask(mask, target_shape_tuple),
+     )
+
+
+ def downscale_to_fit(
+     volume: np.ndarray,
+     mask: np.ndarray,
+     target_shape: tuple[int, int, int] = TARGET_SHAPE,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     current_shape = tuple(int(v) for v in volume.shape)
+     if all(current <= target for current, target in zip(current_shape, target_shape)):
+         return volume, mask
+
+     scale = min(float(target) / float(current) for current, target in zip(current_shape, target_shape))
+     if scale >= 1.0:
+         return volume, mask
+
+     new_shape = tuple(
+         min(target, max(1, int(math.floor(float(current) * scale))))
+         for current, target in zip(current_shape, target_shape)
+     )
+     return (
+         resize_volume(volume, new_shape, mode="trilinear"),
+         resize_mask(mask, new_shape),
+     )
+
+
+ def center_pad(
+     array: np.ndarray,
+     target_shape: tuple[int, int, int] = TARGET_SHAPE,
+     fill_value: float = BACKGROUND_VALUE,
+ ) -> np.ndarray:
+     if any(current > target for current, target in zip(array.shape, target_shape)):
+         raise ValueError(f"Cannot center-pad shape {array.shape} into smaller target {target_shape}.")
+     pad_width = []
+     for current, target in zip(array.shape, target_shape):
+         delta = target - current
+         before = delta // 2
+         after = delta - before
+         pad_width.append((before, after))
+     return np.pad(array, pad_width=tuple(pad_width), mode="constant", constant_values=fill_value)
+
+
+ def preprocess_image_with_foreground_mask(
+     image_path: str | Path,
+     *,
+     target_shape: tuple[int, int, int] = TARGET_SHAPE,
+     target_spacing: tuple[float, float, float] = TARGET_SPACING,
+     crop_margin_mm: float = CROP_MARGIN_MM,
+     foreground_threshold: float = FOREGROUND_THRESHOLD,
+     background_value: float = BACKGROUND_VALUE,
+     foreground_strategy: str = FOREGROUND_STRATEGY,
+     recipe_id: str = GENERIC_RECIPE_ID,
+     cache_version: int = GENERIC_CACHE_VERSION,
+ ) -> dict[str, object]:
+     image_path = Path(image_path)
+     image = load_canonical_nifti(image_path)
+     source_shape = tuple(int(value) for value in image.shape)
+     source_spacing = load_image_spacing(image)
+     volume = np.asarray(image.get_fdata(dtype=np.float32), dtype=np.float32)
+     volume = coerce_volume_to_3d(volume)
+
+     foreground_mask = build_foreground_mask(volume, threshold=foreground_threshold)
+     cropped_volume, cropped_mask, crop_bbox = crop_volume_and_mask(
+         volume,
+         foreground_mask,
+         source_spacing,
+         margin_mm=crop_margin_mm,
+     )
+     normalized_volume = normalize_foreground_only(cropped_volume, cropped_mask)
+     resampled_volume, resampled_mask = resample_to_target_spacing(
+         normalized_volume,
+         cropped_mask,
+         source_spacing=source_spacing,
+         target_spacing=target_spacing,
+     )
+     fitted_volume, fitted_mask = downscale_to_fit(
+         resampled_volume,
+         resampled_mask,
+         target_shape=target_shape,
+     )
+     fitted_volume = np.clip(fitted_volume, -1.0, 1.0).astype(np.float32, copy=False)
+     fitted_volume[~fitted_mask] = float(background_value)
+
+     padded_volume = center_pad(
+         fitted_volume,
+         target_shape=target_shape,
+         fill_value=float(background_value),
+     ).astype(np.float32, copy=False)
+     pixel_values = torch.from_numpy(padded_volume).unsqueeze(0).to(dtype=torch.float16).contiguous()
+
+     return {
+         "pixel_values": pixel_values,
+         "source_image": str(image_path),
+         "source_shape": list(source_shape),
+         "source_spacing": list(source_spacing),
+         "crop_bbox": [[int(start), int(stop)] for start, stop in crop_bbox],
+         "foreground_strategy": foreground_strategy,
+         "recipe_id": recipe_id,
+         "cache_version": int(cache_version),
+     }
+
+
+ def validate_fixed_payload(
+     payload: Mapping[str, Any],
+     *,
+     target_shape: tuple[int, int, int] = TARGET_SHAPE,
+ ) -> None:
+     pixel_values = payload.get("pixel_values")
+     if not isinstance(pixel_values, torch.Tensor):
+         raise TypeError("`pixel_values` must be a torch.Tensor.")
+     expected_shape = (1,) + tuple(target_shape)
+     if tuple(pixel_values.shape) != expected_shape:
+         raise ValueError(f"Expected tensor shape {expected_shape}, got {tuple(pixel_values.shape)}.")
+     if pixel_values.dtype != torch.float16:
+         raise ValueError(f"Expected tensor dtype torch.float16, got {pixel_values.dtype}.")
+     if not torch.isfinite(pixel_values).all():
+         raise ValueError("Tensor contains non-finite values.")
+     min_value = float(pixel_values.min().item())
+     max_value = float(pixel_values.max().item())
+     if min_value < -1.01 or max_value > 1.01:
+         raise ValueError(f"Expected tensor values in [-1, 1]. Got min={min_value}, max={max_value}.")
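The pipeline in `preprocess_image_with_foreground_mask` crops with a millimetre margin, resamples to the target spacing, shrinks the volume if it still exceeds the target shape, and then center-pads. The shape arithmetic is easier to follow without any image data, so the sketch below restates only the integer computations from `compute_crop_bbox`, `resample_to_target_spacing`, `downscale_to_fit`, and `center_pad` using the standard library. The helper names and the example cropped shape `(160, 256, 192)` are hypothetical; no interpolation or masking is performed.

```python
import math


def margin_voxels(margin_mm, spacing):
    """Per-axis crop margin in voxels: ceil(margin_mm / spacing)."""
    return tuple(int(math.ceil(margin_mm / s)) for s in spacing)


def resampled_shape(shape, source_spacing, target_spacing):
    """Shape after resampling so that the physical extent is preserved."""
    return tuple(
        max(1, int(round(n * src / dst)))
        for n, src, dst in zip(shape, source_spacing, target_spacing)
    )


def fitted_shape(shape, target_shape):
    """Shrink all axes by one common factor (never enlarge) until they fit the target."""
    if all(n <= t for n, t in zip(shape, target_shape)):
        return tuple(shape)
    scale = min(t / n for n, t in zip(shape, target_shape))
    return tuple(min(t, max(1, int(math.floor(n * scale)))) for n, t in zip(shape, target_shape))


def pad_widths(shape, target_shape):
    """Symmetric padding; an odd remainder goes on the trailing side."""
    widths = []
    for n, t in zip(shape, target_shape):
        before = (t - n) // 2
        widths.append((before, t - n - before))
    return tuple(widths)


target_shape = (128, 192, 192)
target_spacing = (1.25, 1.0, 1.0)

print(margin_voxels(5.0, target_spacing))  # (4, 5, 5)

cropped = (160, 256, 192)  # hypothetical cropped shape at 1 mm isotropic spacing
resampled = resampled_shape(cropped, (1.0, 1.0, 1.0), target_spacing)
fitted = fitted_shape(resampled, target_shape)
print(resampled)                         # (128, 256, 192)
print(fitted)                            # (96, 192, 144)
print(pad_widths(fitted, target_shape))  # ((16, 16), (0, 0), (24, 24))
```

Note that the shrink step uses a single scale factor for all three axes, so an oversized axis reduces the others as well and the remaining space is filled with the background value by padding.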
preprocessor_config.json ADDED
@@ -0,0 +1,52 @@
+ {
+   "canonicalize_orientation": true,
+   "clip_percentiles": [
+     0.5,
+     99.5
+   ],
+   "crop_margin": 4,
+   "do_clip": true,
+   "do_crop_foreground": true,
+   "do_normalize": true,
+   "effective_pad_value": -1.0,
+   "foreground_threshold": 0.001,
+   "image_processor_type": "BrainMRISiglipVolumeProcessor",
+   "interpolation_mode": "trilinear",
+   "max_channel_dim": 4,
+   "output_range": [
+     -1.0,
+     1.0
+   ],
+   "pad_value": null,
+   "path_background_value": -1.0,
+   "path_crop_margin_mm": 5.0,
+   "path_foreground_strategy": "largest_component_nonzero",
+   "path_foreground_threshold": 0.001,
+   "path_generic_cache_version": 1,
+   "path_generic_recipe_id": "generic_foreground_128x192x192_fp16_v1",
+   "path_recipe_mode": "auto",
+   "path_target_shape": [
+     128,
+     192,
+     192
+   ],
+   "path_target_spacing": [
+     1.25,
+     1.0,
+     1.0
+   ],
+   "prefer_nibabel_resample": false,
+   "resize_strategy": "pad_or_crop",
+   "spacing": [
+     1.25,
+     1.0,
+     1.0
+   ],
+   "spacing_tolerance": 0.001,
+   "use_foreground_intensity_stats": true,
+   "volume_size": [
+     128,
+     192,
+     192
+   ]
+ }
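The configuration above carries two groups of settings that repeat the same geometry: `volume_size` and `spacing` on one side, `path_target_shape` and `path_target_spacing` on the other. The `path_` keys appear to configure the file-path branch that calls `preprocess_image_with_foreground_mask`, though that reading is inferred from the import aliases in `processing_brain_mri_siglip.py` rather than stated anywhere. Assuming the two groups are meant to agree, a small standard-library check such as the sketch below can flag a config where they have drifted apart. The function name `find_config_mismatches` is hypothetical, and the JSON fragment copies only a subset of the keys.

```python
import json

# Subset of preprocessor_config.json; values copied from the file above.
CONFIG_FRAGMENT = """
{
  "volume_size": [128, 192, 192],
  "path_target_shape": [128, 192, 192],
  "spacing": [1.25, 1.0, 1.0],
  "path_target_spacing": [1.25, 1.0, 1.0],
  "output_range": [-1.0, 1.0],
  "path_background_value": -1.0
}
"""

# Pairs assumed (not confirmed by the source) to describe the same quantity.
PAIRED_KEYS = (
    ("volume_size", "path_target_shape"),
    ("spacing", "path_target_spacing"),
)


def find_config_mismatches(config):
    """Return descriptions of paired settings whose values disagree."""
    problems = []
    for array_key, path_key in PAIRED_KEYS:
        if list(config[array_key]) != list(config[path_key]):
            problems.append(f"{array_key}={config[array_key]} != {path_key}={config[path_key]}")
    # Background voxels are expected to sit at the bottom of the output range.
    if config["path_background_value"] != config["output_range"][0]:
        problems.append("path_background_value differs from output_range[0]")
    return problems


config = json.loads(CONFIG_FRAGMENT)
print(find_config_mismatches(config))  # []
```

Running the same check after editing one group, for example when changing the target spacing, would show whether array inputs and file-path inputs are still preprocessed to the same grid.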
processing_brain_mri_siglip.py ADDED
@@ -0,0 +1,680 @@
+ """Processor code for Brain MRI SigLIP."""
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from pathlib import Path
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
+ from transformers.processing_utils import ProcessorMixin
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
+
+ from .common import copy_remote_code_files, to_3tuple
+ from .offline_aligned_preprocessing import (
+     BACKGROUND_VALUE as DEFAULT_PATH_BACKGROUND_VALUE,
+     CROP_MARGIN_MM as DEFAULT_PATH_CROP_MARGIN_MM,
+     FOREGROUND_STRATEGY as DEFAULT_PATH_FOREGROUND_STRATEGY,
+     FOREGROUND_THRESHOLD as DEFAULT_PATH_FOREGROUND_THRESHOLD,
+     GENERIC_CACHE_VERSION,
+     GENERIC_RECIPE_ID,
+     TARGET_SHAPE as DEFAULT_PATH_TARGET_SHAPE,
+     TARGET_SPACING as DEFAULT_PATH_TARGET_SPACING,
+     preprocess_image_with_foreground_mask,
+ )
+
+ try:
+     from scripts.fomo_300k_offline_pt.common import (
+         is_fomo_300k_path,
+         preprocess_fomo_300k_image,
+     )
+ except Exception:  # pragma: no cover - optional import surface
+     is_fomo_300k_path = None
+     preprocess_fomo_300k_image = None
+
+ try:
+     from scripts.mr_rate_offline_pt.common import (
+         is_mr_rate_path,
+         preprocess_mr_rate_image,
+     )
+ except Exception:  # pragma: no cover - optional import surface
+     is_mr_rate_path = None
+     preprocess_mr_rate_image = None
+
+ try:
+     import nibabel as nib
+     try:
+         from nibabel import processing as nib_processing
+     except Exception:  # pragma: no cover - optional import
+         nib_processing = None
+ except Exception:  # pragma: no cover - optional import
+     nib = None
+     nib_processing = None
+
+
+ VolumeInput = Union[str, Path, np.ndarray, torch.Tensor]
+ SpacingInput = Optional[Union[Sequence[float], Sequence[Sequence[float]]]]
+ LOGGER = logging.getLogger(__name__)
+
+
+ def _ensure_list(values: Union[VolumeInput, Sequence[VolumeInput]]) -> List[VolumeInput]:
+     if isinstance(values, (str, Path, np.ndarray, torch.Tensor)):
+         return [values]
+     return list(values)
+
+
+ def _normalize_spacing_value(value: Optional[Sequence[float]], field_name: str) -> Optional[Tuple[float, float, float]]:
+     if value is None:
+         return None
+     if len(value) != 3:
+         raise ValueError(f"`{field_name}` must be a length-3 sequence. Got: {value}")
+     return (float(value[0]), float(value[1]), float(value[2]))
+
+
+ def _ensure_spacing_list(
+     source_spacings: SpacingInput,
+     batch_size: int,
+ ) -> List[Optional[Tuple[float, float, float]]]:
+     if source_spacings is None:
+         return [None] * batch_size
+     if batch_size == 1 and isinstance(source_spacings, Sequence) and len(source_spacings) == 3 and not isinstance(
+         source_spacings[0], (list, tuple)
+     ):
+         return [_normalize_spacing_value(source_spacings, "source_spacing")]
+     values = list(source_spacings)
+     if len(values) != batch_size:
+         raise ValueError(
+             f"`source_spacings` must have length {batch_size} to match the input batch. Got {len(values)}."
+         )
+     return [_normalize_spacing_value(value, "source_spacing") for value in values]
+
+
+ def _normalize_shape_value(
+     value: Sequence[int],
+     field_name: str,
+ ) -> Tuple[int, int, int]:
+     normalized = to_3tuple(value, field_name)
+     return (int(normalized[0]), int(normalized[1]), int(normalized[2]))
+
+
+ class BrainMRISiglipVolumeProcessor(BaseImageProcessor):
+     """Image processor for 3D brain MRI volumes."""
+
+     model_input_names = ["pixel_values"]
+
+     def __init__(
+         self,
+         volume_size: Union[int, Sequence[int]] = (128, 192, 192),
+         clip_percentiles: Tuple[float, float] = (0.5, 99.5),
+         output_range: Tuple[float, float] = (-1.0, 1.0),
+         do_clip: bool = True,
+         do_normalize: bool = True,
+         interpolation_mode: str = "trilinear",
+         max_channel_dim: int = 4,
+         canonicalize_orientation: bool = True,
+         spacing: Optional[Sequence[float]] = None,
+         spacing_tolerance: float = 1e-3,
+         prefer_nibabel_resample: bool = True,
+         use_foreground_intensity_stats: bool = True,
+         do_crop_foreground: bool = True,
+         foreground_threshold: float = 1e-3,
+         crop_margin: int = 4,
+         resize_strategy: str = "pad_or_crop",
+         pad_value: Optional[float] = None,
+         path_recipe_mode: str = "auto",
+         path_target_shape: Union[int, Sequence[int]] = DEFAULT_PATH_TARGET_SHAPE,
+         path_target_spacing: Optional[Sequence[float]] = DEFAULT_PATH_TARGET_SPACING,
+         path_crop_margin_mm: float = DEFAULT_PATH_CROP_MARGIN_MM,
+         path_foreground_threshold: float = DEFAULT_PATH_FOREGROUND_THRESHOLD,
+         path_background_value: float = DEFAULT_PATH_BACKGROUND_VALUE,
+         path_foreground_strategy: str = DEFAULT_PATH_FOREGROUND_STRATEGY,
+         path_generic_recipe_id: str = GENERIC_RECIPE_ID,
+         path_generic_cache_version: int = GENERIC_CACHE_VERSION,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.volume_size = list(to_3tuple(volume_size, "volume_size"))
+         self.clip_percentiles = (float(clip_percentiles[0]), float(clip_percentiles[1]))
+         self.output_range = (float(output_range[0]), float(output_range[1]))
+         self.do_clip = bool(do_clip)
+         self.do_normalize = bool(do_normalize)
+         self.interpolation_mode = str(interpolation_mode)
+         self.max_channel_dim = int(max_channel_dim)
+         self.canonicalize_orientation = bool(canonicalize_orientation)
+         self.spacing = list(_normalize_spacing_value(spacing, "spacing")) if spacing is not None else None
+         self.spacing_tolerance = float(spacing_tolerance)
+         self.prefer_nibabel_resample = bool(prefer_nibabel_resample)
+         self.use_foreground_intensity_stats = bool(use_foreground_intensity_stats)
+         self.do_crop_foreground = bool(do_crop_foreground)
+         self.foreground_threshold = float(foreground_threshold)
+         self.crop_margin = int(crop_margin)
+         self.resize_strategy = str(resize_strategy)
+         self.pad_value = None if pad_value is None else float(pad_value)
+         self.path_recipe_mode = str(path_recipe_mode)
+         self.path_target_shape = list(_normalize_shape_value(path_target_shape, "path_target_shape"))
+         self.path_target_spacing = (
+             list(_normalize_spacing_value(path_target_spacing, "path_target_spacing"))
+             if path_target_spacing is not None
+             else None
+         )
+         self.path_crop_margin_mm = float(path_crop_margin_mm)
+         self.path_foreground_threshold = float(path_foreground_threshold)
+         self.path_background_value = float(path_background_value)
+         self.path_foreground_strategy = str(path_foreground_strategy)
+         self.path_generic_recipe_id = str(path_generic_recipe_id)
+         self.path_generic_cache_version = int(path_generic_cache_version)
+         self.effective_pad_value = self._resolve_pad_value()
+         if self.max_channel_dim <= 0:
+             raise ValueError(f"`max_channel_dim` must be > 0. Got {self.max_channel_dim}.")
+         if not (0.0 <= self.clip_percentiles[0] < self.clip_percentiles[1] <= 100.0):
+             raise ValueError(
+                 "`clip_percentiles` must satisfy 0 <= low < high <= 100. "
+                 f"Got {self.clip_percentiles}."
+             )
+         if self.resize_strategy not in {"pad_or_crop", "interpolate"}:
+             raise ValueError(
+                 "`resize_strategy` must be one of: pad_or_crop, interpolate. "
+                 f"Got {self.resize_strategy!r}."
+             )
+         if self.path_recipe_mode not in {"auto", "legacy"}:
+             raise ValueError(
+                 "`path_recipe_mode` must be one of: auto, legacy. "
+                 f"Got {self.path_recipe_mode!r}."
+             )
+         if self.path_crop_margin_mm < 0:
+             raise ValueError(f"`path_crop_margin_mm` must be >= 0. Got {self.path_crop_margin_mm}.")
+         if self.path_foreground_threshold < 0:
+             raise ValueError(
+                 f"`path_foreground_threshold` must be >= 0. Got {self.path_foreground_threshold}."
+             )
+         if self.spacing_tolerance < 0:
+             raise ValueError(f"`spacing_tolerance` must be >= 0. Got {self.spacing_tolerance}.")
+
+     def get_path_recipe_config(self) -> Dict[str, Any]:
+         return {
+             "path_recipe_mode": self.path_recipe_mode,
+             "path_target_shape": list(self.path_target_shape),
+             "path_target_spacing": None if self.path_target_spacing is None else list(self.path_target_spacing),
+             "path_crop_margin_mm": self.path_crop_margin_mm,
+             "path_foreground_threshold": self.path_foreground_threshold,
+             "path_background_value": self.path_background_value,
+             "path_foreground_strategy": self.path_foreground_strategy,
+             "path_generic_recipe_id": self.path_generic_recipe_id,
+             "path_generic_cache_version": self.path_generic_cache_version,
+         }
+
+     def _target_spacing(self) -> Optional[Tuple[float, float, float]]:
+         if self.spacing is None:
+             return None
+         return tuple(float(item) for item in self.spacing)
+
+     def _resolve_pad_value(self) -> float:
+         if self.pad_value is not None:
+             return float(self.pad_value)
+         if self.do_normalize:
+             return float(self.output_range[0])
+         return 0.0
+
+     def _spacing_matches(
+         self,
+         source_spacing: Optional[Tuple[float, float, float]],
+         target_spacing: Optional[Tuple[float, float, float]],
+     ) -> bool:
+         if source_spacing is None or target_spacing is None:
+             return False
+         return all(abs(src - dst) <= self.spacing_tolerance for src, dst in zip(source_spacing, target_spacing))
+
+     def _nibabel_resample_order(self) -> int:
+         if self.interpolation_mode == "nearest":
+             return 0
+         return 1
+
+     def _resample_nifti_image(
+         self,
+         image,
+         source_spacing: Optional[Tuple[float, float, float]],
+     ) -> tuple[Any, Optional[Tuple[float, float, float]], bool]:
+         if not self.prefer_nibabel_resample or nib_processing is None:
+             return image, source_spacing, False
+
+         target_spacing = self._target_spacing()
+         if target_spacing is None or self._spacing_matches(source_spacing, target_spacing):
+             return image, source_spacing, False
+
+         resampled = nib_processing.resample_to_output(
+             image,
+             voxel_sizes=target_spacing,
+             order=self._nibabel_resample_order(),
+         )
+         return resampled, target_spacing, True
+
+     def _load_volume(
+         self,
+         value: VolumeInput,
+         source_spacing: Optional[Tuple[float, float, float]] = None,
+     ) -> tuple[np.ndarray, Optional[Tuple[float, float, float]], bool]:
+         if isinstance(value, (str, Path)):
+             if nib is None:
+                 raise ImportError("`nibabel` is required to load NIfTI paths.")
+             image = nib.load(str(value))
+             if self.canonicalize_orientation:
+                 image = nib.as_closest_canonical(image)
+             image_spacing = image.header.get_zooms()[:3]
+             resolved_spacing = None
+             if len(image_spacing) == 3:
+                 resolved_spacing = tuple(float(item) for item in image_spacing)
+             image, resolved_spacing, used_nibabel_resample = self._resample_nifti_image(image, resolved_spacing)
+             return (
+                 np.asarray(image.get_fdata(dtype=np.float32), dtype=np.float32),
+                 resolved_spacing,
+                 used_nibabel_resample,
+             )
+
+         if isinstance(value, torch.Tensor):
+             return value.detach().cpu().numpy().astype(np.float32, copy=False), source_spacing, False
+
+         if isinstance(value, np.ndarray):
+             return value.astype(np.float32, copy=False), source_spacing, False
+
+         raise TypeError(f"Unsupported volume input type: {type(value).__name__}")
+
+     def _preprocess_with_offline_recipe(self, value: VolumeInput) -> Optional[np.ndarray]:
+         if self.path_recipe_mode != "auto" or not isinstance(value, (str, Path)):
+             return None
+
+         image_path = str(value)
+         try:
+             if is_mr_rate_path is not None and preprocess_mr_rate_image is not None and is_mr_rate_path(image_path):
+                 payload = preprocess_mr_rate_image(image_path)
+                 return payload["pixel_values"].detach().cpu().numpy().astype(np.float32, copy=False)
+             if is_fomo_300k_path is not None and preprocess_fomo_300k_image is not None and is_fomo_300k_path(image_path):
+                 payload = preprocess_fomo_300k_image(image_path)
+                 return payload["pixel_values"].detach().cpu().numpy().astype(np.float32, copy=False)
+             payload = preprocess_image_with_foreground_mask(
+                 image_path,
+                 # Loop variables renamed so they no longer shadow the `value` argument.
+                 target_shape=tuple(int(dim) for dim in self.path_target_shape),
+                 target_spacing=None
+                 if self.path_target_spacing is None
+                 else tuple(float(item) for item in self.path_target_spacing),
+                 crop_margin_mm=self.path_crop_margin_mm,
+                 foreground_threshold=self.path_foreground_threshold,
+                 background_value=self.path_background_value,
+                 foreground_strategy=self.path_foreground_strategy,
+                 recipe_id=self.path_generic_recipe_id,
+                 cache_version=self.path_generic_cache_version,
+             )
+             return payload["pixel_values"].detach().cpu().numpy().astype(np.float32, copy=False)
+         except Exception as exc:
+             LOGGER.warning(
+                 "Falling back to legacy online preprocessing for %s after offline-recipe path failed: %s",
+                 image_path,
+                 exc,
+             )
+             return None
+
+     def _ensure_channel_first(self, volume: np.ndarray) -> np.ndarray:
+         if volume.ndim == 3:
+             return volume[None, ...]
+         if volume.ndim != 4:
+             raise ValueError(
+                 "Volume must be 3D or 4D. For 4D volume, expected channel-first `[C, D, H, W]` "
+                 "or channel-last `[D, H, W, C]`."
+             )
+         if volume.shape[0] <= self.max_channel_dim:
+             return volume
+         if volume.shape[-1] <= self.max_channel_dim:
+             return np.moveaxis(volume, -1, 0)
+         raise ValueError(
+             f"Cannot infer channel dimension for shape {volume.shape}. Expected channel dim <= {self.max_channel_dim}. "
+             "Please provide volume in [C, D, H, W] or [D, H, W, C] format."
+         )
+
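The channel-axis heuristic in `_ensure_channel_first` can be sketched on shapes alone. The helper below is hypothetical (it is not part of this commit) and assumes the default `max_channel_dim=4`; it only mirrors the branch order shown above.

```python
from typing import Tuple


def infer_channel_first_shape(shape: Tuple[int, ...], max_channel_dim: int = 4) -> Tuple[int, ...]:
    """Hypothetical helper: shape a volume would have after the channel-first heuristic."""
    if len(shape) == 3:
        # A bare [D, H, W] volume gets a singleton channel axis in front.
        return (1,) + tuple(shape)
    if len(shape) != 4:
        raise ValueError(f"Volume must be 3D or 4D. Got shape {shape}.")
    if shape[0] <= max_channel_dim:
        # A small leading axis is taken to be the channel axis already.
        return tuple(shape)
    if shape[-1] <= max_channel_dim:
        # Channel-last [D, H, W, C] is moved to [C, D, H, W].
        return (shape[-1],) + tuple(shape[:-1])
    raise ValueError(f"Cannot infer channel dimension for shape {shape}.")


print(infer_channel_first_shape((128, 192, 192)))
print(infer_channel_first_shape((160, 192, 192, 2)))
```

Because the leading axis is checked first, a 4D array whose first and last axes are both at most `max_channel_dim` is treated as channel-first and left unchanged.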
+     def _foreground_mask(self, volume: np.ndarray) -> np.ndarray:
+         threshold = abs(self.foreground_threshold)
+         if volume.ndim == 4:
+             return np.any(np.abs(volume) > threshold, axis=0)
+         return np.abs(volume) > threshold
+
+     def _intensity_stats_values(self, volume: np.ndarray) -> np.ndarray:
+         if not self.use_foreground_intensity_stats:
+             return volume.reshape(-1)
+
+         mask = self._foreground_mask(volume)
+         if not mask.any():
+             return volume.reshape(-1)
+         if volume.ndim == 4:
+             return volume[:, mask].reshape(-1)
+         return volume[mask].reshape(-1)
+
+     def _clip_and_normalize(self, volume: np.ndarray) -> np.ndarray:
+         output = volume
+         if self.do_clip or self.do_normalize:
+             # Sanitize before percentile so NaN/inf don't corrupt the result.
+             output = np.nan_to_num(output, nan=0.0, posinf=0.0, neginf=0.0)
+             stats_values = self._intensity_stats_values(output)
+             if self.do_clip:
+                 flat = stats_values
+                 if flat.size > 1_000_000:
+                     # Deterministic stride-based subsample for speed.
+                     step = max(1, flat.size // 1_000_000)
+                     flat = flat[::step]
+                 low, high = np.percentile(flat, self.clip_percentiles)
+             else:
+                 low, high = float(stats_values.min()), float(stats_values.max())
+
+             if np.isfinite(low) and np.isfinite(high) and high > low:
+                 if self.do_clip:
+                     output = np.clip(output, low, high)
+                 if self.do_normalize:
+                     out_low, out_high = self.output_range
+                     output = np.clip((output - low) / (high - low), 0.0, 1.0)
+                     output = output * (out_high - out_low) + out_low
+             elif self.do_normalize:
+                 output = np.zeros_like(output, dtype=np.float32)
+         return output.astype(np.float32, copy=False)
+
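The intensity step in `_clip_and_normalize` can be illustrated without NumPy. The functions below are hypothetical stand-ins written for this note, not the commit's code: they work on a flat, non-empty list, use a linearly interpolated percentile, and skip the NaN/inf sanitizing and the stride-based subsampling.

```python
import math
from typing import List, Sequence, Tuple


def percentile(sorted_values: Sequence[float], q: float) -> float:
    """Linearly interpolated percentile of pre-sorted values, q in [0, 100]."""
    rank = (q / 100.0) * (len(sorted_values) - 1)
    lower, upper = math.floor(rank), math.ceil(rank)
    fraction = rank - lower
    return sorted_values[lower] + fraction * (sorted_values[upper] - sorted_values[lower])


def clip_and_normalize(
    values: Sequence[float],
    clip_percentiles: Tuple[float, float] = (0.5, 99.5),
    output_range: Tuple[float, float] = (-1.0, 1.0),
    foreground_threshold: float = 1e-3,
) -> List[float]:
    # Percentiles come from foreground voxels only, falling back to all voxels.
    foreground = sorted(v for v in values if abs(v) > foreground_threshold)
    stats = foreground if foreground else sorted(values)
    low = percentile(stats, clip_percentiles[0])
    high = percentile(stats, clip_percentiles[1])
    if not high > low:
        # Degenerate intensity range: the processor above emits zeros here.
        return [0.0 for _ in values]
    out_low, out_high = output_range
    scaled = []
    for v in values:
        unit = min(1.0, max(0.0, (v - low) / (high - low)))  # clip into [0, 1]
        scaled.append(unit * (out_high - out_low) + out_low)
    return scaled


print(clip_and_normalize([0.0, 1.0, 2.0, 3.0], clip_percentiles=(0.0, 100.0)))
```

Background voxels below the lower percentile land on the low end of `output_range`, which matches the default padding value resolved by `_resolve_pad_value`.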
+     def _resample_spacing(
+         self,
+         volume: np.ndarray,
+         source_spacing: Optional[Tuple[float, float, float]],
+         affine: Optional[np.ndarray] = None,
+     ) -> np.ndarray:
+         if self.spacing is None or source_spacing is None:
+             return volume
+
+         target_spacing = self._target_spacing()
+         if self._spacing_matches(source_spacing, target_spacing):
+             return volume
+
+         target_shape = []
+         for current_size, src, dst in zip(volume.shape[1:], source_spacing, target_spacing):
+             target_shape.append(max(1, int(round(float(current_size) * float(src) / float(dst)))))
+
+         if tuple(target_shape) == tuple(int(dim) for dim in volume.shape[1:]):
+             return volume
+
+         tensor = torch.from_numpy(volume).unsqueeze(0)
+         tensor = F.interpolate(
+             tensor,
+             size=tuple(target_shape),
+             mode=self.interpolation_mode,
+             align_corners=False if self.interpolation_mode in {"linear", "bilinear", "bicubic", "trilinear"} else None,
+         )
+         return tensor.squeeze(0).numpy().astype(np.float32, copy=False)
+
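The target grid computed in `_resample_spacing` is plain arithmetic: each axis is scaled by `source / target` spacing and rounded, with a floor of one voxel. A minimal sketch follows; the function name and the example spacings are assumptions chosen for illustration (the processor's own `spacing` defaults to `None`, which disables this step).

```python
from typing import Sequence, Tuple


def resampled_shape(
    shape: Sequence[int],
    source_spacing: Sequence[float],
    target_spacing: Sequence[float],
) -> Tuple[int, ...]:
    """Hypothetical helper: spatial shape after resampling to `target_spacing`."""
    return tuple(
        max(1, int(round(float(size) * float(src) / float(dst))))
        for size, src, dst in zip(shape, source_spacing, target_spacing)
    )


# A 160 x 256 x 256 scan at 1.0 x 0.9375 x 0.9375 resampled to 1.25 x 1.0 x 1.0.
print(resampled_shape((160, 256, 256), (1.0, 0.9375, 0.9375), (1.25, 1.0, 1.0)))
```

Only the spatial axes are scaled; the channel axis is left out of the computation, as in the method above, which iterates over `volume.shape[1:]`.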
+     # def _crop_foreground(self, volume: np.ndarray) -> np.ndarray:
+     #     if not self.do_crop_foreground:
+     #         return volume
+
+     #     # Per-axis projection avoids the massive temporary arrays from np.where.
+     #     src = volume[0] if volume.ndim == 4 else volume
+     #     mask = src > self.foreground_threshold
+     #     if not mask.any():
+     #         return volume
+
+     #     slices = []
+     #     for dim in range(mask.ndim):
+     #         proj = mask.any(axis=tuple(d for d in range(mask.ndim) if d != dim))
+     #         lo = int(np.argmax(proj))
+     #         hi = len(proj) - 1 - int(np.argmax(proj[::-1]))
+     #         slices.append(slice(lo, hi + 1))
+
+     #     return volume[(slice(None),) + tuple(slices)].astype(np.float32, copy=False)
+     def _crop_foreground(self, volume: np.ndarray) -> np.ndarray:
+         if not self.do_crop_foreground:
+             return volume
+
+         margin = self.crop_margin
+         src = self._foreground_mask(volume)
+
+         if not src.any():
+             return volume
+
+         slices = []
+         for dim in range(src.ndim):
+             proj = src.any(axis=tuple(d for d in range(src.ndim) if d != dim))
+             lo = int(np.argmax(proj))
+             hi = len(proj) - 1 - int(np.argmax(proj[::-1]))
+
+             lo = max(0, lo - margin)
+             hi = min(src.shape[dim] - 1, hi + margin)
+
+             slices.append(slice(lo, hi + 1))
+
+         return volume[(slice(None),) + tuple(slices)].astype(np.float32, copy=False)
+
+     def _pad_or_crop_volume(self, volume: np.ndarray) -> np.ndarray:
+         target_size = tuple(int(v) for v in self.volume_size)
+         if volume.shape[1:] == target_size:
+             return volume
+
+         slices = [slice(None)]
+         for current, target in zip(volume.shape[1:], target_size):
+             if current > target:
+                 start = max(0, (current - target) // 2)
+                 slices.append(slice(start, start + target))
+             else:
+                 slices.append(slice(0, current))
+         cropped = volume[tuple(slices)]
+
+         pad_width = [(0, 0)]
+         for current, target in zip(cropped.shape[1:], target_size):
+             if current < target:
+                 delta = target - current
+                 before = delta // 2
+                 after = delta - before
+                 pad_width.append((before, after))
+             else:
+                 pad_width.append((0, 0))
+         if any(before != 0 or after != 0 for before, after in pad_width[1:]):
+             cropped = np.pad(
+                 cropped,
+                 pad_width=pad_width,
+                 mode="constant",
+                 constant_values=self.effective_pad_value,
+             )
+         return cropped.astype(np.float32, copy=False)
+
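The per-axis bookkeeping in `_pad_or_crop_volume` reduces to one small rule: crop symmetrically when the axis is too long, otherwise pad symmetrically with any odd voxel going at the end. The helper below is a hypothetical single-axis sketch of that rule, not code from this commit.

```python
from typing import Tuple


def center_crop_and_pad(current: int, target: int) -> Tuple[slice, Tuple[int, int]]:
    """Hypothetical helper: (crop slice, (pad_before, pad_after)) for one axis."""
    if current > target:
        # Too long: keep the central `target` voxels, no padding needed.
        start = (current - target) // 2
        return slice(start, start + target), (0, 0)
    # Too short (or equal): keep everything and split the deficit around it.
    delta = target - current
    before = delta // 2
    return slice(0, current), (before, delta - before)


print(center_crop_and_pad(240, 192))
print(center_crop_and_pad(121, 128))
```

With the defaults shown in `__init__`, the padded voxels are filled with `output_range[0]`, that is `-1.0`, since `pad_value` is `None` and `do_normalize` is `True`.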
+     def _resize_volume(self, volume: np.ndarray) -> np.ndarray:
+         target_size = tuple(int(v) for v in self.volume_size)
+         if volume.shape[1:] == target_size:
+             return volume
+         if self.resize_strategy == "pad_or_crop":
+             return self._pad_or_crop_volume(volume)
+
+         tensor = torch.from_numpy(volume).unsqueeze(0)
+         tensor = F.interpolate(
+             tensor,
+             size=target_size,
+             mode=self.interpolation_mode,
+             align_corners=False if self.interpolation_mode in {"linear", "bilinear", "bicubic", "trilinear"} else None,
+         )
+         return tensor.squeeze(0).numpy().astype(np.float32, copy=False)
+
+     def preprocess(
+         self,
+         volumes: Union[VolumeInput, Sequence[VolumeInput]],
+         return_tensors: Optional[Union[str, bool]] = "pt",
+         source_spacings: SpacingInput = None,
+         **kwargs: Any,
+     ) -> BatchFeature:
+         del kwargs
+         items = _ensure_list(volumes)
+         spacing_values = _ensure_spacing_list(source_spacings, len(items))
+         batch = []
+         for item, source_spacing in zip(items, spacing_values):
+             recipe_aligned = self._preprocess_with_offline_recipe(item)
+             if recipe_aligned is not None:
+                 batch.append(torch.from_numpy(recipe_aligned))
+                 continue
+             volume, loaded_spacing, used_nibabel_resample = self._load_volume(item, source_spacing=source_spacing)
+             volume = self._ensure_channel_first(volume)
+             if not used_nibabel_resample:
+                 volume = self._resample_spacing(volume, source_spacing=loaded_spacing)
+             volume = self._crop_foreground(volume)
+             volume = self._clip_and_normalize(volume)
+             volume = self._resize_volume(volume)
+             batch.append(torch.from_numpy(volume))
+
+         pixel_values = torch.stack(batch, dim=0).to(dtype=torch.float32)
+         return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
+
+     def __call__(
+         self,
+         volumes: Union[VolumeInput, Sequence[VolumeInput]],
+         return_tensors: Optional[Union[str, bool]] = "pt",
+         **kwargs: Any,
+     ) -> BatchFeature:
+         return self.preprocess(volumes=volumes, return_tensors=return_tensors, **kwargs)
+
+
+ class BrainMRISiglipProcessor(ProcessorMixin):
+     """Processor wrapping MRI volume processor + tokenizer."""
+
+     attributes = ["image_processor", "tokenizer"]
+     image_processor_class = "BaseImageProcessor"
+     tokenizer_class = "AutoTokenizer"
+
+     def __init__(self, image_processor: BrainMRISiglipVolumeProcessor, tokenizer) -> None:
+         super().__init__(image_processor=image_processor, tokenizer=tokenizer)
+
+     @classmethod
+     def from_text_pretrained(
+         cls,
+         text_model_name_or_path: str = "google/medsiglip-448",
+         volume_size: Union[int, Sequence[int]] = (128, 192, 192),
+         local_files_only: bool = False,
+         trust_remote_code: bool = True,
+         **kwargs: Any,
+     ) -> "BrainMRISiglipProcessor":
+         tokenizer = AutoTokenizer.from_pretrained(
+             text_model_name_or_path,
+             local_files_only=local_files_only,
+             trust_remote_code=trust_remote_code,
+         )
+         image_processor = BrainMRISiglipVolumeProcessor(volume_size=volume_size, **kwargs)
+         return cls(image_processor=image_processor, tokenizer=tokenizer)
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs: Any):
+         image_processor_kwargs = dict(kwargs.pop("image_processor_kwargs", {}) or {})
+         tokenizer_kwargs = dict(kwargs.pop("tokenizer_kwargs", {}) or {})
+
+         # Backward-compatible convenience: treat image-specific keys as image processor kwargs.
+         image_only_keys = {
+             "volume_size",
+             "clip_percentiles",
+             "output_range",
+             "do_clip",
+             "do_normalize",
+             "interpolation_mode",
+             "max_channel_dim",
+             "canonicalize_orientation",
+             "spacing",
+             "spacing_tolerance",
+             "prefer_nibabel_resample",
+             "use_foreground_intensity_stats",
+             "do_crop_foreground",
+             "foreground_threshold",
+             "crop_margin",
+             "resize_strategy",
+             "pad_value",
+             "path_recipe_mode",
+             "path_target_shape",
+             "path_target_spacing",
+             "path_crop_margin_mm",
+             "path_foreground_threshold",
+             "path_background_value",
+             "path_foreground_strategy",
+             "path_generic_recipe_id",
+             "path_generic_cache_version",
+         }
+         shared_kwargs = dict(kwargs)
+         for key in list(shared_kwargs.keys()):
+             if key in image_only_keys and key not in image_processor_kwargs:
+                 image_processor_kwargs[key] = shared_kwargs.pop(key)
+
+         image_processor = BrainMRISiglipVolumeProcessor.from_pretrained(
+             pretrained_model_name_or_path,
+             **shared_kwargs,
+             **image_processor_kwargs,
+         )
+         tokenizer = AutoTokenizer.from_pretrained(
+             pretrained_model_name_or_path,
+             **shared_kwargs,
+             **tokenizer_kwargs,
+         )
+         return cls(image_processor=image_processor, tokenizer=tokenizer)
+
+     def save_pretrained(self, save_directory: Union[str, Path], **kwargs: Any) -> tuple[str]:
+         save_path = Path(save_directory)
+         save_path.mkdir(parents=True, exist_ok=True)
+         self.image_processor.save_pretrained(str(save_path), **kwargs)
+         self.tokenizer.save_pretrained(str(save_path), **kwargs)
+         processor_config = {
+             "processor_class": self.__class__.__name__,
+             "auto_map": {"AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor"},
+             "offline_aligned_preprocessing": self.image_processor.get_path_recipe_config(),
+         }
+         (save_path / "processor_config.json").write_text(json.dumps(processor_config, indent=2), encoding="utf-8")
+         copy_remote_code_files(save_path)
+         return (str(save_path),)
+
+     @property
+     def model_input_names(self) -> List[str]:
+         names = list(self.tokenizer.model_input_names)
+         for item in self.image_processor.model_input_names:
+             if item not in names:
+                 names.append(item)
+         return names
+
+     def __call__(
+         self,
+         text: Optional[Union[TextInput, PreTokenizedInput, Sequence[TextInput], Sequence[PreTokenizedInput]]] = None,
+         volumes: Optional[Union[VolumeInput, Sequence[VolumeInput]]] = None,
+         padding: Union[bool, str, PaddingStrategy] = "max_length",
+         truncation: Union[bool, str, TruncationStrategy] = True,
+         max_length: Optional[int] = None,
+         return_tensors: Optional[Union[str, bool]] = "pt",
+         **kwargs: Any,
+     ) -> BatchFeature:
+         if text is None and volumes is None:
+             raise ValueError("At least one of `text` or `volumes` must be provided.")
+
+         image_processor_kwargs = dict(kwargs.pop("image_processor_kwargs", {}) or {})
+         image_only_keys = {"source_spacings"}
+         for key in list(kwargs.keys()):
+             if key in image_only_keys and key not in image_processor_kwargs:
+                 image_processor_kwargs[key] = kwargs.pop(key)
+
+         data: Dict[str, Any] = {}
+         if text is not None:
+             text_inputs = self.tokenizer(
+                 text,
+                 padding=padding,
+                 truncation=truncation,
+                 max_length=max_length,
+                 return_tensors=return_tensors,
+                 **kwargs,
+             )
+             data.update(dict(text_inputs))
+
+         if volumes is not None:
+             image_inputs = self.image_processor(
+                 volumes=volumes,
+                 return_tensors=return_tensors,
+                 **image_processor_kwargs,
+             )
+             data.update(dict(image_inputs))
+
+         return BatchFeature(data=data, tensor_type=return_tensors)
+
+     def batch_decode(self, *args: Any, **kwargs: Any):
+         return self.tokenizer.batch_decode(*args, **kwargs)
+
+     def decode(self, *args: Any, **kwargs: Any):
+         return self.tokenizer.decode(*args, **kwargs)
processor_config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "processor_class": "BrainMRISiglipProcessor",
+   "auto_map": {
+     "AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor"
+   },
+   "offline_aligned_preprocessing": {
+     "path_recipe_mode": "auto",
+     "path_target_shape": [
+       128,
+       192,
+       192
+     ],
+     "path_target_spacing": [
+       1.25,
+       1.0,
+       1.0
+     ],
+     "path_crop_margin_mm": 5.0,
+     "path_foreground_threshold": 0.001,
+     "path_background_value": -1.0,
+     "path_foreground_strategy": "largest_component_nonzero",
+     "path_generic_recipe_id": "generic_foreground_128x192x192_fp16_v1",
+     "path_generic_cache_version": 1
+   }
+ }
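As a quick sanity check on the recipe values above, the target grid and spacing together imply a physical extent per axis. The snippet assumes that `path_target_shape` and `path_target_spacing` are listed in the same axis order and that the spacing is in millimetres (the usual NIfTI convention); neither is stated in this config.

```python
# Values copied from "offline_aligned_preprocessing" in processor_config.json.
path_target_shape = (128, 192, 192)
path_target_spacing = (1.25, 1.0, 1.0)  # assumed to be mm per voxel

# Extent covered along each axis = number of voxels * voxel size.
field_of_view = tuple(n * s for n, s in zip(path_target_shape, path_target_spacing))
print(field_of_view)
```

Under those assumptions the offline recipe covers 160 x 192 x 192 mm.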
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+ {
+   "eos_token": {
+     "content": "</s>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": true,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "</s>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": true,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "<unk>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": true,
+     "single_word": false
+   }
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1e5036bed065526c3c212dfbe288752391797c4bb1a284aa18c9a0b23fcaf8ec
+ size 798330
tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
+ {
+   "added_tokens_decoder": {
+     "1": {
+       "content": "</s>",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": true,
+       "single_word": false,
+       "special": true
+     },
+     "2": {
+       "content": "<unk>",
+       "lstrip": true,
+       "normalized": false,
+       "rstrip": true,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [],
+   "clean_up_tokenization_spaces": true,
+   "do_lower_case": true,
+   "eos_token": "</s>",
+   "extra_special_tokens": {},
+   "model_input_names": [
+     "input_ids"
+   ],
+   "model_max_length": 64,
+   "pad_token": "</s>",
+   "processor_class": "SiglipProcessor",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "SiglipTokenizer",
+   "unk_token": "<unk>"
+ }