Feature Extraction
Transformers
Safetensors
PyTorch
brain-mri-siglip
medical-imaging
mri
brain-mri
siglip
vision-language
contrastive-learning
custom-code
custom_code
Instructions to use shenxiaochen/brain-mri-siglip with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use shenxiaochen/brain-mri-siglip with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("feature-extraction", model="shenxiaochen/brain-mri-siglip", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("shenxiaochen/brain-mri-siglip", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Add files using upload-large-folder tool
Browse files- README.md +156 -0
- __init__.py +19 -0
- common.py +36 -0
- config.json +181 -0
- configuration_brain_mri_siglip.py +112 -0
- model.safetensors +3 -0
- modeling_brain_mri_siglip.py +615 -0
- offline_aligned_preprocessing.py +286 -0
- preprocessor_config.json +52 -0
- processing_brain_mri_siglip.py +680 -0
- processor_config.json +25 -0
- special_tokens_map.json +23 -0
- spiece.model +3 -0
- tokenizer_config.json +34 -0
README.md
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
pipeline_tag: feature-extraction
|
| 4 |
+
base_model: google/medsiglip-448
|
| 5 |
+
tags:
|
| 6 |
+
- medical-imaging
|
| 7 |
+
- mri
|
| 8 |
+
- brain-mri
|
| 9 |
+
- siglip
|
| 10 |
+
- vision-language
|
| 11 |
+
- contrastive-learning
|
| 12 |
+
- feature-extraction
|
| 13 |
+
- custom-code
|
| 14 |
+
- pytorch
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# Brain MRI SigLIP
|
| 18 |
+
|
| 19 |
+
Brain MRI SigLIP is a 3D MRI vision-language representation model trained with a SigLIP-style image-text contrastive objective. This repository publishes the final saved `stage2_joint_finetune` checkpoint from the `brain_mri_siglip_run_0509` experiment.
|
| 20 |
+
|
| 21 |
+
This checkpoint is intended as a research visual encoder for brain MRI downstream tasks and as a warm-start encoder for building a medical VLM. It is not a clinical diagnostic device.
|
| 22 |
+
|
| 23 |
+
## Model Summary
|
| 24 |
+
|
| 25 |
+
- Base text tower: `google/medsiglip-448`
|
| 26 |
+
- Model class: `BrainMRISiglipModel`
|
| 27 |
+
- Vision input: single-channel 3D MRI volumes
|
| 28 |
+
- Expected volume shape: `[1, 128, 192, 192]`
|
| 29 |
+
- Projection dimension: `1152`
|
| 30 |
+
- Patch size: `[8, 16, 16]`
|
| 31 |
+
- Training precision: `bf16`
|
| 32 |
+
- Training input format: preprocessed `.pt` tensors, `float16`, value range `[-1, 1]`
|
| 33 |
+
|
| 34 |
+
## Training Context
|
| 35 |
+
|
| 36 |
+
This model was initialized from the `brain_mri_siglip_run_0509/stage1_freeze_text` checkpoint and then jointly fine-tuned with both vision and text towers trainable.
|
| 37 |
+
|
| 38 |
+
Training summary:
|
| 39 |
+
|
| 40 |
+
- Training samples: `950,720`
|
| 41 |
+
- Validation samples: `67,450`
|
| 42 |
+
- Validation samples with `metadata_text`: `32,278`
|
| 43 |
+
- Stage 1: frozen text tower, vision-heavy training
|
| 44 |
+
- Stage 2: joint vision-text fine-tuning
|
| 45 |
+
- Stage 2 epochs configured: `8`
|
| 46 |
+
- World size: `5`
|
| 47 |
+
- Stage 2 per-device batch size: `160`
|
| 48 |
+
- Stage 2 contrastive forward batch: `800`
|
| 49 |
+
- Gradient checkpointing: text and vision enabled
|
| 50 |
+
|
| 51 |
+
Training-time retrieval evaluation used capped validation subsets and should be treated as monitoring rather than a final benchmark.
|
| 52 |
+
|
| 53 |
+
## Loading
|
| 54 |
+
|
| 55 |
+
This model uses custom Transformers code. Load it with `trust_remote_code=True`.
|
| 56 |
+
|
| 57 |
+
```python
|
| 58 |
+
import torch
|
| 59 |
+
from transformers import AutoModel, AutoProcessor
|
| 60 |
+
|
| 61 |
+
repo_id = "shenxiaochen/brain-mri-siglip"
|
| 62 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 63 |
+
|
| 64 |
+
model = AutoModel.from_pretrained(
|
| 65 |
+
repo_id,
|
| 66 |
+
trust_remote_code=True,
|
| 67 |
+
dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 68 |
+
).to(device).eval()
|
| 69 |
+
|
| 70 |
+
processor = AutoProcessor.from_pretrained(
|
| 71 |
+
repo_id,
|
| 72 |
+
trust_remote_code=True,
|
| 73 |
+
)
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
## NIfTI Preprocessing
|
| 77 |
+
|
| 78 |
+
For reproducible inference from NIfTI files, pass paths directly to the saved processor. This repository includes the offline-aligned preprocessing implementation used to match the training tensor distribution.
|
| 79 |
+
|
| 80 |
+
```python
|
| 81 |
+
nifti_path = "/path/to/brain_mri.nii.gz"
|
| 82 |
+
|
| 83 |
+
inputs = processor(
|
| 84 |
+
volumes=nifti_path,
|
| 85 |
+
return_tensors="pt",
|
| 86 |
+
)
|
| 87 |
+
pixel_values = inputs["pixel_values"].to(device)
|
| 88 |
+
|
| 89 |
+
if torch.cuda.is_available():
|
| 90 |
+
pixel_values = pixel_values.to(dtype=torch.bfloat16)
|
| 91 |
+
|
| 92 |
+
with torch.inference_mode():
|
| 93 |
+
image_embeds = model.get_image_features(pixel_values=pixel_values)
|
| 94 |
+
|
| 95 |
+
print(pixel_values.shape) # [1, 1, 128, 192, 192]
|
| 96 |
+
print(image_embeds.shape) # [1, 1152]
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
The saved path-based preprocessing recipe is:
|
| 100 |
+
|
| 101 |
+
- canonicalize image orientation to closest RAS
|
| 102 |
+
- build foreground mask with threshold `1e-3`
|
| 103 |
+
- keep the largest connected foreground component
|
| 104 |
+
- crop foreground with `5mm` margin
|
| 105 |
+
- normalize foreground intensities with `0.5/99.5` percentiles
|
| 106 |
+
- map intensities to `[-1, 1]`
|
| 107 |
+
- resample to spacing `(1.25, 1.0, 1.0)`
|
| 108 |
+
- downscale to fit `[128, 192, 192]`
|
| 109 |
+
- center-pad with background value `-1.0`
|
| 110 |
+
|
| 111 |
+
The exact settings are saved in `preprocessor_config.json` and `processor_config.json`.
|
| 112 |
+
|
| 113 |
+
## Using Preprocessed `.pt` Inputs
|
| 114 |
+
|
| 115 |
+
If your data is already stored as the same offline preprocessed tensors used during training, you can load it directly:
|
| 116 |
+
|
| 117 |
+
```python
|
| 118 |
+
payload = torch.load("/path/to/sample.pt", map_location="cpu")
|
| 119 |
+
pixel_values = payload["pixel_values"] if isinstance(payload, dict) else payload
|
| 120 |
+
|
| 121 |
+
if pixel_values.ndim == 4:
|
| 122 |
+
pixel_values = pixel_values.unsqueeze(0)
|
| 123 |
+
|
| 124 |
+
pixel_values = pixel_values.to(device=device, dtype=torch.bfloat16)
|
| 125 |
+
|
| 126 |
+
with torch.inference_mode():
|
| 127 |
+
image_embeds = model.get_image_features(pixel_values=pixel_values)
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
Expected tensor format:
|
| 131 |
+
|
| 132 |
+
- shape `[1, 128, 192, 192]` for one volume, or `[B, 1, 128, 192, 192]` for a batch
|
| 133 |
+
- values in `[-1, 1]`
|
| 134 |
+
- padded background voxels near `-1.0`
|
| 135 |
+
|
| 136 |
+
## VLM Integration Notes
|
| 137 |
+
|
| 138 |
+
For VLM construction, use the 3D vision tower as a visual backbone and add a projector, Q-Former, Perceiver resampler, or other token compressor before connecting to an LLM.
|
| 139 |
+
|
| 140 |
+
A practical downstream recipe is:
|
| 141 |
+
|
| 142 |
+
1. Freeze this MRI encoder and train only the multimodal projector/resampler.
|
| 143 |
+
2. Evaluate downstream classification, retrieval, report alignment, or instruction-following behavior.
|
| 144 |
+
3. Optionally unfreeze the top vision layers with a much smaller learning rate.
|
| 145 |
+
|
| 146 |
+
## Limitations
|
| 147 |
+
|
| 148 |
+
- This checkpoint was trained for representation learning, not diagnosis.
|
| 149 |
+
- Performance should be validated on task-specific subject-level or study-level splits.
|
| 150 |
+
- Scanner, protocol, site, and preprocessing differences can affect embeddings.
|
| 151 |
+
- External users should preserve the saved preprocessing pipeline for NIfTI inference.
|
| 152 |
+
- Retrieval monitoring during training is not a substitute for downstream clinical validation.
|
| 153 |
+
|
| 154 |
+
## Citation
|
| 155 |
+
|
| 156 |
+
If you use this checkpoint, please cite this model repository and the upstream MedSigLIP model where appropriate.
|
__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_brain_mri_siglip import BrainMRISiglipConfig
|
| 2 |
+
from .modeling_brain_mri_siglip import BrainMRISiglipModel
|
| 3 |
+
from .processing_brain_mri_siglip import BrainMRISiglipProcessor, BrainMRISiglipVolumeProcessor
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"BrainMRISiglipConfig",
|
| 7 |
+
"BrainMRISiglipModel",
|
| 8 |
+
"BrainMRISiglipProcessor",
|
| 9 |
+
"BrainMRISiglipVolumeProcessor",
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
BrainMRISiglipConfig.register_for_auto_class("AutoConfig")
|
| 14 |
+
BrainMRISiglipModel.register_for_auto_class("AutoModel")
|
| 15 |
+
BrainMRISiglipProcessor.register_for_auto_class("AutoProcessor")
|
| 16 |
+
except Exception:
|
| 17 |
+
# Registration is best-effort and not required for local imports.
|
| 18 |
+
pass
|
| 19 |
+
|
common.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Common utility helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import shutil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Iterable
|
| 8 |
+
from typing import Sequence, Tuple, Union
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def to_3tuple(value: Union[int, Sequence[int]], name: str) -> Tuple[int, int, int]:
|
| 12 |
+
if isinstance(value, int):
|
| 13 |
+
return (value, value, value)
|
| 14 |
+
if len(value) != 3:
|
| 15 |
+
raise ValueError(f"`{name}` must be an int or length-3 sequence. Got: {value}")
|
| 16 |
+
return (int(value[0]), int(value[1]), int(value[2]))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
REMOTE_CODE_FILES = (
|
| 20 |
+
"__init__.py",
|
| 21 |
+
"common.py",
|
| 22 |
+
"configuration_brain_mri_siglip.py",
|
| 23 |
+
"modeling_brain_mri_siglip.py",
|
| 24 |
+
"offline_aligned_preprocessing.py",
|
| 25 |
+
"processing_brain_mri_siglip.py",
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def copy_remote_code_files(destination: Union[str, Path], file_names: Iterable[str] = REMOTE_CODE_FILES) -> None:
|
| 30 |
+
src_dir = Path(__file__).resolve().parent
|
| 31 |
+
dst_dir = Path(destination)
|
| 32 |
+
dst_dir.mkdir(parents=True, exist_ok=True)
|
| 33 |
+
for name in file_names:
|
| 34 |
+
src_file = src_dir / name
|
| 35 |
+
if src_file.exists():
|
| 36 |
+
shutil.copy2(src_file, dst_dir / name)
|
config.json
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BrainMRISiglipModel"
|
| 4 |
+
],
|
| 5 |
+
"attn_implementation": null,
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_brain_mri_siglip.BrainMRISiglipConfig",
|
| 8 |
+
"AutoModel": "modeling_brain_mri_siglip.BrainMRISiglipModel",
|
| 9 |
+
"AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor"
|
| 10 |
+
},
|
| 11 |
+
"dtype": "float32",
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"logit_bias_init_value": -10.0,
|
| 14 |
+
"logit_scale_init_value": 2.6592,
|
| 15 |
+
"logit_scale_max": 100.0,
|
| 16 |
+
"logit_scale_min": 0.001,
|
| 17 |
+
"max_text_length": 64,
|
| 18 |
+
"model_type": "brain-mri-siglip",
|
| 19 |
+
"num_channels": 1,
|
| 20 |
+
"patch_size": [
|
| 21 |
+
8,
|
| 22 |
+
16,
|
| 23 |
+
16
|
| 24 |
+
],
|
| 25 |
+
"projection_dim": 1152,
|
| 26 |
+
"text_config": {
|
| 27 |
+
"_name_or_path": "",
|
| 28 |
+
"add_cross_attention": false,
|
| 29 |
+
"architectures": null,
|
| 30 |
+
"attention_dropout": 0.0,
|
| 31 |
+
"bad_words_ids": null,
|
| 32 |
+
"begin_suppress_tokens": null,
|
| 33 |
+
"bos_token_id": 49406,
|
| 34 |
+
"chunk_size_feed_forward": 0,
|
| 35 |
+
"cross_attention_hidden_size": null,
|
| 36 |
+
"decoder_start_token_id": null,
|
| 37 |
+
"diversity_penalty": 0.0,
|
| 38 |
+
"do_sample": false,
|
| 39 |
+
"dtype": null,
|
| 40 |
+
"early_stopping": false,
|
| 41 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 42 |
+
"eos_token_id": 49407,
|
| 43 |
+
"exponential_decay_length_penalty": null,
|
| 44 |
+
"finetuning_task": null,
|
| 45 |
+
"forced_bos_token_id": null,
|
| 46 |
+
"forced_eos_token_id": null,
|
| 47 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 48 |
+
"hidden_size": 1152,
|
| 49 |
+
"id2label": {
|
| 50 |
+
"0": "LABEL_0",
|
| 51 |
+
"1": "LABEL_1"
|
| 52 |
+
},
|
| 53 |
+
"intermediate_size": 4304,
|
| 54 |
+
"is_decoder": false,
|
| 55 |
+
"is_encoder_decoder": false,
|
| 56 |
+
"label2id": {
|
| 57 |
+
"LABEL_0": 0,
|
| 58 |
+
"LABEL_1": 1
|
| 59 |
+
},
|
| 60 |
+
"layer_norm_eps": 1e-06,
|
| 61 |
+
"length_penalty": 1.0,
|
| 62 |
+
"max_length": 20,
|
| 63 |
+
"max_position_embeddings": 64,
|
| 64 |
+
"min_length": 0,
|
| 65 |
+
"model_type": "siglip_text_model",
|
| 66 |
+
"no_repeat_ngram_size": 0,
|
| 67 |
+
"num_attention_heads": 16,
|
| 68 |
+
"num_beam_groups": 1,
|
| 69 |
+
"num_beams": 1,
|
| 70 |
+
"num_hidden_layers": 27,
|
| 71 |
+
"num_return_sequences": 1,
|
| 72 |
+
"output_attentions": false,
|
| 73 |
+
"output_hidden_states": false,
|
| 74 |
+
"output_scores": false,
|
| 75 |
+
"pad_token_id": 1,
|
| 76 |
+
"prefix": null,
|
| 77 |
+
"problem_type": null,
|
| 78 |
+
"projection_size": 1152,
|
| 79 |
+
"pruned_heads": {},
|
| 80 |
+
"remove_invalid_values": false,
|
| 81 |
+
"repetition_penalty": 1.0,
|
| 82 |
+
"return_dict": true,
|
| 83 |
+
"return_dict_in_generate": false,
|
| 84 |
+
"sep_token_id": null,
|
| 85 |
+
"suppress_tokens": null,
|
| 86 |
+
"task_specific_params": null,
|
| 87 |
+
"temperature": 1.0,
|
| 88 |
+
"tf_legacy_loss": false,
|
| 89 |
+
"tie_encoder_decoder": false,
|
| 90 |
+
"tie_word_embeddings": true,
|
| 91 |
+
"tokenizer_class": null,
|
| 92 |
+
"top_k": 50,
|
| 93 |
+
"top_p": 1.0,
|
| 94 |
+
"torchscript": false,
|
| 95 |
+
"transformers_version": "4.57.6",
|
| 96 |
+
"typical_p": 1.0,
|
| 97 |
+
"use_bfloat16": false,
|
| 98 |
+
"vocab_size": 32000
|
| 99 |
+
},
|
| 100 |
+
"text_model_name_or_path": "google/medsiglip-448",
|
| 101 |
+
"transformers_version": "4.57.6",
|
| 102 |
+
"vision_config": {
|
| 103 |
+
"_name_or_path": "",
|
| 104 |
+
"add_cross_attention": false,
|
| 105 |
+
"architectures": null,
|
| 106 |
+
"attention_dropout": 0.0,
|
| 107 |
+
"bad_words_ids": null,
|
| 108 |
+
"begin_suppress_tokens": null,
|
| 109 |
+
"bos_token_id": null,
|
| 110 |
+
"chunk_size_feed_forward": 0,
|
| 111 |
+
"cross_attention_hidden_size": null,
|
| 112 |
+
"decoder_start_token_id": null,
|
| 113 |
+
"diversity_penalty": 0.0,
|
| 114 |
+
"do_sample": false,
|
| 115 |
+
"dtype": null,
|
| 116 |
+
"early_stopping": false,
|
| 117 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 118 |
+
"eos_token_id": null,
|
| 119 |
+
"exponential_decay_length_penalty": null,
|
| 120 |
+
"finetuning_task": null,
|
| 121 |
+
"forced_bos_token_id": null,
|
| 122 |
+
"forced_eos_token_id": null,
|
| 123 |
+
"hidden_act": "gelu_pytorch_tanh",
|
| 124 |
+
"hidden_size": 1152,
|
| 125 |
+
"id2label": {
|
| 126 |
+
"0": "LABEL_0",
|
| 127 |
+
"1": "LABEL_1"
|
| 128 |
+
},
|
| 129 |
+
"image_size": 448,
|
| 130 |
+
"intermediate_size": 4304,
|
| 131 |
+
"is_decoder": false,
|
| 132 |
+
"is_encoder_decoder": false,
|
| 133 |
+
"label2id": {
|
| 134 |
+
"LABEL_0": 0,
|
| 135 |
+
"LABEL_1": 1
|
| 136 |
+
},
|
| 137 |
+
"layer_norm_eps": 1e-06,
|
| 138 |
+
"length_penalty": 1.0,
|
| 139 |
+
"max_length": 20,
|
| 140 |
+
"min_length": 0,
|
| 141 |
+
"model_type": "siglip_vision_model",
|
| 142 |
+
"no_repeat_ngram_size": 0,
|
| 143 |
+
"num_attention_heads": 16,
|
| 144 |
+
"num_beam_groups": 1,
|
| 145 |
+
"num_beams": 1,
|
| 146 |
+
"num_channels": 1,
|
| 147 |
+
"num_hidden_layers": 27,
|
| 148 |
+
"num_return_sequences": 1,
|
| 149 |
+
"output_attentions": false,
|
| 150 |
+
"output_hidden_states": false,
|
| 151 |
+
"output_scores": false,
|
| 152 |
+
"pad_token_id": null,
|
| 153 |
+
"patch_size": 14,
|
| 154 |
+
"prefix": null,
|
| 155 |
+
"problem_type": null,
|
| 156 |
+
"pruned_heads": {},
|
| 157 |
+
"remove_invalid_values": false,
|
| 158 |
+
"repetition_penalty": 1.0,
|
| 159 |
+
"return_dict": true,
|
| 160 |
+
"return_dict_in_generate": false,
|
| 161 |
+
"sep_token_id": null,
|
| 162 |
+
"suppress_tokens": null,
|
| 163 |
+
"task_specific_params": null,
|
| 164 |
+
"temperature": 1.0,
|
| 165 |
+
"tf_legacy_loss": false,
|
| 166 |
+
"tie_encoder_decoder": false,
|
| 167 |
+
"tie_word_embeddings": true,
|
| 168 |
+
"tokenizer_class": null,
|
| 169 |
+
"top_k": 50,
|
| 170 |
+
"top_p": 1.0,
|
| 171 |
+
"torchscript": false,
|
| 172 |
+
"transformers_version": "4.57.6",
|
| 173 |
+
"typical_p": 1.0,
|
| 174 |
+
"use_bfloat16": false
|
| 175 |
+
},
|
| 176 |
+
"volume_size": [
|
| 177 |
+
128,
|
| 178 |
+
192,
|
| 179 |
+
192
|
| 180 |
+
]
|
| 181 |
+
}
|
configuration_brain_mri_siglip.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration for Brain MRI SigLIP."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Dict, Mapping, Optional, Sequence, Union
|
| 6 |
+
|
| 7 |
+
from transformers import PretrainedConfig, SiglipTextConfig, SiglipVisionConfig
|
| 8 |
+
|
| 9 |
+
from .common import to_3tuple
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class BrainMRISiglipConfig(PretrainedConfig):
|
| 13 |
+
r"""Configuration class for :class:`BrainMRISiglipModel`."""
|
| 14 |
+
|
| 15 |
+
model_type = "brain-mri-siglip"
|
| 16 |
+
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
text_config: Optional[Mapping[str, Any]] = None,
|
| 20 |
+
vision_config: Optional[Mapping[str, Any]] = None,
|
| 21 |
+
text_model_name_or_path: str = "google/medsiglip-448",
|
| 22 |
+
volume_size: Union[int, Sequence[int]] = (128, 192, 192),
|
| 23 |
+
patch_size: Union[int, Sequence[int]] = (8, 16, 16),
|
| 24 |
+
num_channels: int = 1,
|
| 25 |
+
projection_dim: Optional[int] = None,
|
| 26 |
+
logit_scale_init_value: float = 2.6592,
|
| 27 |
+
logit_scale_min: float = 1e-3,
|
| 28 |
+
logit_bias_init_value: float = -10.0,
|
| 29 |
+
logit_scale_max: float = 100.0,
|
| 30 |
+
attn_implementation: Optional[str] = None,
|
| 31 |
+
max_text_length: int = 64,
|
| 32 |
+
initializer_range: float = 0.02,
|
| 33 |
+
auto_map: Optional[Mapping[str, str]] = None,
|
| 34 |
+
**kwargs: Any,
|
| 35 |
+
) -> None:
|
| 36 |
+
if text_config is None:
|
| 37 |
+
text_config_dict = SiglipTextConfig().to_dict()
|
| 38 |
+
else:
|
| 39 |
+
text_config_dict = dict(text_config)
|
| 40 |
+
|
| 41 |
+
if vision_config is None:
|
| 42 |
+
vision_config_dict = SiglipVisionConfig().to_dict()
|
| 43 |
+
else:
|
| 44 |
+
vision_config_dict = dict(vision_config)
|
| 45 |
+
|
| 46 |
+
resolved_volume_size = to_3tuple(volume_size, "volume_size")
|
| 47 |
+
resolved_patch_size = to_3tuple(patch_size, "patch_size")
|
| 48 |
+
if any(v <= 0 for v in resolved_volume_size):
|
| 49 |
+
raise ValueError(f"`volume_size` must contain positive integers. Got {resolved_volume_size}.")
|
| 50 |
+
if any(p <= 0 for p in resolved_patch_size):
|
| 51 |
+
raise ValueError(f"`patch_size` must contain positive integers. Got {resolved_patch_size}.")
|
| 52 |
+
if any(v % p != 0 for v, p in zip(resolved_volume_size, resolved_patch_size)):
|
| 53 |
+
raise ValueError(
|
| 54 |
+
f"`volume_size` must be divisible by `patch_size`. "
|
| 55 |
+
f"Got volume_size={resolved_volume_size}, patch_size={resolved_patch_size}."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
vision_config_dict["num_channels"] = int(num_channels)
|
| 59 |
+
|
| 60 |
+
if projection_dim is None:
|
| 61 |
+
projection_dim = int(
|
| 62 |
+
text_config_dict.get(
|
| 63 |
+
"projection_size",
|
| 64 |
+
text_config_dict.get("hidden_size", vision_config_dict.get("hidden_size", 768)),
|
| 65 |
+
)
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
if auto_map is None:
|
| 69 |
+
# Keep module paths as `<module>.<Class>` for compatibility with HF dynamic loader.
|
| 70 |
+
auto_map = {
|
| 71 |
+
"AutoConfig": "configuration_brain_mri_siglip.BrainMRISiglipConfig",
|
| 72 |
+
"AutoModel": "modeling_brain_mri_siglip.BrainMRISiglipModel",
|
| 73 |
+
"AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor",
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
self.text_config = text_config_dict
|
| 77 |
+
self.vision_config = vision_config_dict
|
| 78 |
+
self.text_model_name_or_path = text_model_name_or_path
|
| 79 |
+
self.volume_size = list(resolved_volume_size)
|
| 80 |
+
self.patch_size = list(resolved_patch_size)
|
| 81 |
+
self.num_channels = int(num_channels)
|
| 82 |
+
self.projection_dim = int(projection_dim)
|
| 83 |
+
self.logit_scale_init_value = float(logit_scale_init_value)
|
| 84 |
+
self.logit_scale_min = float(logit_scale_min)
|
| 85 |
+
self.logit_bias_init_value = float(logit_bias_init_value)
|
| 86 |
+
self.logit_scale_max = float(logit_scale_max)
|
| 87 |
+
self.attn_implementation = attn_implementation
|
| 88 |
+
self.max_text_length = int(max_text_length)
|
| 89 |
+
self.initializer_range = float(initializer_range)
|
| 90 |
+
self.auto_map = dict(auto_map)
|
| 91 |
+
|
| 92 |
+
super().__init__(**kwargs)
|
| 93 |
+
|
| 94 |
+
def get_text_config(self, *args: Any, **kwargs: Any) -> SiglipTextConfig:
|
| 95 |
+
del args, kwargs
|
| 96 |
+
config = SiglipTextConfig(**self.text_config)
|
| 97 |
+
if self.attn_implementation:
|
| 98 |
+
config._attn_implementation = self.attn_implementation
|
| 99 |
+
elif getattr(config, "_attn_implementation", None) is None:
|
| 100 |
+
config._attn_implementation = "sdpa"
|
| 101 |
+
return config
|
| 102 |
+
|
| 103 |
+
def get_vision_config(self, *args: Any, **kwargs: Any) -> SiglipVisionConfig:
|
| 104 |
+
del args, kwargs
|
| 105 |
+
cfg_dict = dict(self.vision_config)
|
| 106 |
+
cfg_dict["num_channels"] = int(self.num_channels)
|
| 107 |
+
config = SiglipVisionConfig(**cfg_dict)
|
| 108 |
+
if self.attn_implementation:
|
| 109 |
+
config._attn_implementation = self.attn_implementation
|
| 110 |
+
elif getattr(config, "_attn_implementation", None) is None:
|
| 111 |
+
config._attn_implementation = "sdpa"
|
| 112 |
+
return config
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:656edd47b6a98dfa950593e10cb0d8214b30e7eca4a4f725c69129f22aabf055
|
| 3 |
+
size 3536557760
|
modeling_brain_mri_siglip.py
ADDED
|
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modeling code for Brain MRI SigLIP."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from typing import Any, Mapping, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.distributed as dist
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from torch import nn
|
| 12 |
+
from torch.distributed.nn.functional import all_gather as all_gather_with_grad
|
| 13 |
+
from transformers import AutoConfig, AutoModel, PreTrainedModel
|
| 14 |
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
| 15 |
+
from transformers.models.siglip import SiglipTextConfig, SiglipVisionConfig
|
| 16 |
+
from transformers.models.siglip.modeling_siglip import (
|
| 17 |
+
SiglipAttention,
|
| 18 |
+
SiglipEncoder,
|
| 19 |
+
SiglipMLP,
|
| 20 |
+
SiglipMultiheadAttentionPoolingHead,
|
| 21 |
+
SiglipOutput,
|
| 22 |
+
SiglipTextModel,
|
| 23 |
+
default_flax_embed_init,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from .configuration_brain_mri_siglip import BrainMRISiglipConfig
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _siglip_sigmoid_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
|
| 30 |
+
eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device, dtype=logits_per_text.dtype)
|
| 31 |
+
labels = -torch.ones_like(logits_per_text) + 2 * eye
|
| 32 |
+
loglik = F.logsigmoid(labels * logits_per_text)
|
| 33 |
+
nll = -torch.sum(loglik, dim=-1)
|
| 34 |
+
return nll.mean()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _lecun_normal_(tensor: torch.Tensor) -> torch.Tensor:
|
| 38 |
+
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(tensor)
|
| 39 |
+
if fan_in <= 0:
|
| 40 |
+
return nn.init.normal_(tensor, mean=0.0, std=1.0)
|
| 41 |
+
return nn.init.normal_(tensor, mean=0.0, std=1.0 / math.sqrt(fan_in))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _siglip_embedding_init_(tensor: torch.Tensor) -> torch.Tensor:
|
| 45 |
+
default_flax_embed_init(tensor)
|
| 46 |
+
return tensor
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _distributed_concat_with_grad(embeddings: torch.Tensor) -> torch.Tensor:
|
| 50 |
+
if not dist.is_available() or not dist.is_initialized():
|
| 51 |
+
return embeddings
|
| 52 |
+
world_size = dist.get_world_size()
|
| 53 |
+
local_batch = embeddings.shape[0]
|
| 54 |
+
local_batch_tensor = torch.tensor([local_batch], dtype=torch.long, device=embeddings.device)
|
| 55 |
+
batch_sizes = [torch.zeros_like(local_batch_tensor) for _ in range(world_size)]
|
| 56 |
+
dist.all_gather(batch_sizes, local_batch_tensor)
|
| 57 |
+
batch_sizes_int = [int(size.item()) for size in batch_sizes]
|
| 58 |
+
max_batch = max(batch_sizes_int)
|
| 59 |
+
|
| 60 |
+
if local_batch < max_batch:
|
| 61 |
+
pad_shape = (max_batch - local_batch, embeddings.shape[1])
|
| 62 |
+
padding = embeddings.new_zeros(pad_shape)
|
| 63 |
+
padded_embeddings = torch.cat([embeddings, padding], dim=0)
|
| 64 |
+
else:
|
| 65 |
+
padded_embeddings = embeddings
|
| 66 |
+
|
| 67 |
+
gathered = all_gather_with_grad(padded_embeddings)
|
| 68 |
+
if isinstance(gathered, torch.Tensor):
|
| 69 |
+
if gathered.ndim == 3 and gathered.shape[0] == world_size:
|
| 70 |
+
chunks = [gathered[rank] for rank in range(world_size)]
|
| 71 |
+
else:
|
| 72 |
+
chunks = list(torch.split(gathered, max_batch, dim=0))
|
| 73 |
+
else:
|
| 74 |
+
chunks = list(gathered)
|
| 75 |
+
|
| 76 |
+
trimmed = [chunk[: batch_sizes_int[rank]] for rank, chunk in enumerate(chunks) if batch_sizes_int[rank] > 0]
|
| 77 |
+
if not trimmed:
|
| 78 |
+
return embeddings.new_zeros((0, embeddings.shape[1]))
|
| 79 |
+
return torch.cat(trimmed, dim=0)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _load_state_dict_with_flexible_prefix(
|
| 83 |
+
module: nn.Module,
|
| 84 |
+
source_state_dict: Mapping[str, torch.Tensor],
|
| 85 |
+
strict: bool = True,
|
| 86 |
+
) -> Tuple[Any, Any]:
|
| 87 |
+
target_keys = list(module.state_dict().keys())
|
| 88 |
+
source_keys = list(source_state_dict.keys())
|
| 89 |
+
|
| 90 |
+
if not target_keys or not source_keys:
|
| 91 |
+
return module.load_state_dict(source_state_dict, strict=strict)
|
| 92 |
+
|
| 93 |
+
target_has_text_model_prefix = all(key.startswith("text_model.") for key in target_keys)
|
| 94 |
+
source_has_text_model_prefix = all(key.startswith("text_model.") for key in source_keys)
|
| 95 |
+
|
| 96 |
+
aligned_state_dict = dict(source_state_dict)
|
| 97 |
+
if target_has_text_model_prefix and not source_has_text_model_prefix:
|
| 98 |
+
aligned_state_dict = {f"text_model.{key}": value for key, value in source_state_dict.items()}
|
| 99 |
+
elif source_has_text_model_prefix and not target_has_text_model_prefix:
|
| 100 |
+
aligned_state_dict = {
|
| 101 |
+
key[len("text_model.") :]: value for key, value in source_state_dict.items() if key.startswith("text_model.")
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
return module.load_state_dict(aligned_state_dict, strict=strict)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class SiglipVisionEmbeddings3D(nn.Module):
|
| 108 |
+
"""3D patch embeddings for MRI volumes."""
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
vision_config: SiglipVisionConfig,
|
| 113 |
+
volume_size: Tuple[int, int, int],
|
| 114 |
+
patch_size: Tuple[int, int, int],
|
| 115 |
+
num_channels: int,
|
| 116 |
+
) -> None:
|
| 117 |
+
super().__init__()
|
| 118 |
+
self.embed_dim = int(vision_config.hidden_size)
|
| 119 |
+
self.volume_size = tuple(int(v) for v in volume_size)
|
| 120 |
+
self.patch_size = tuple(int(v) for v in patch_size)
|
| 121 |
+
|
| 122 |
+
if any(v % p != 0 for v, p in zip(self.volume_size, self.patch_size)):
|
| 123 |
+
raise ValueError(
|
| 124 |
+
"Volume size must be divisible by patch size for all dimensions. "
|
| 125 |
+
f"Got volume_size={self.volume_size}, patch_size={self.patch_size}."
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
self.patch_embedding = nn.Conv3d(
|
| 129 |
+
in_channels=int(num_channels),
|
| 130 |
+
out_channels=self.embed_dim,
|
| 131 |
+
kernel_size=self.patch_size,
|
| 132 |
+
stride=self.patch_size,
|
| 133 |
+
padding=0,
|
| 134 |
+
)
|
| 135 |
+
|
| 136 |
+
patches_per_dim = tuple(v // p for v, p in zip(self.volume_size, self.patch_size))
|
| 137 |
+
self.grid_size = patches_per_dim
|
| 138 |
+
self.num_patches = int(patches_per_dim[0] * patches_per_dim[1] * patches_per_dim[2])
|
| 139 |
+
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
|
| 140 |
+
self.register_buffer("position_ids", torch.arange(self.num_patches).expand((1, -1)), persistent=False)
|
| 141 |
+
|
| 142 |
+
def _interpolate_position_embeddings(
|
| 143 |
+
self,
|
| 144 |
+
grid_size: Tuple[int, int, int],
|
| 145 |
+
target_dtype: torch.dtype,
|
| 146 |
+
target_device: torch.device,
|
| 147 |
+
) -> torch.Tensor:
|
| 148 |
+
base_grid_depth, base_grid_height, base_grid_width = self.grid_size
|
| 149 |
+
position_embeddings = self.position_embedding.weight.reshape(
|
| 150 |
+
base_grid_depth,
|
| 151 |
+
base_grid_height,
|
| 152 |
+
base_grid_width,
|
| 153 |
+
self.embed_dim,
|
| 154 |
+
)
|
| 155 |
+
position_embeddings = position_embeddings.permute(3, 0, 1, 2).unsqueeze(0)
|
| 156 |
+
position_embeddings = F.interpolate(
|
| 157 |
+
position_embeddings,
|
| 158 |
+
size=grid_size,
|
| 159 |
+
mode="trilinear",
|
| 160 |
+
align_corners=False,
|
| 161 |
+
)
|
| 162 |
+
position_embeddings = position_embeddings.squeeze(0).permute(1, 2, 3, 0).reshape(1, -1, self.embed_dim)
|
| 163 |
+
return position_embeddings.to(dtype=target_dtype, device=target_device)
|
| 164 |
+
|
| 165 |
+
def _get_position_embeddings(
|
| 166 |
+
self,
|
| 167 |
+
grid_size: Tuple[int, int, int],
|
| 168 |
+
target_dtype: torch.dtype,
|
| 169 |
+
target_device: torch.device,
|
| 170 |
+
interpolate_pos_encoding: bool,
|
| 171 |
+
) -> torch.Tensor:
|
| 172 |
+
num_patches = int(grid_size[0] * grid_size[1] * grid_size[2])
|
| 173 |
+
if num_patches == self.num_patches:
|
| 174 |
+
return self.position_embedding(self.position_ids).to(dtype=target_dtype, device=target_device)
|
| 175 |
+
if not interpolate_pos_encoding:
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"Unexpected number of patches: {num_patches} vs expected {self.num_patches}. "
|
| 178 |
+
"Enable `interpolate_pos_encoding=True` for variable volume sizes."
|
| 179 |
+
)
|
| 180 |
+
return self._interpolate_position_embeddings(grid_size, target_dtype=target_dtype, target_device=target_device)
|
| 181 |
+
|
| 182 |
+
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = True) -> torch.Tensor:
|
| 183 |
+
if pixel_values.ndim != 5:
|
| 184 |
+
raise ValueError(
|
| 185 |
+
"`pixel_values` must have shape [batch, channels, depth, height, width]. "
|
| 186 |
+
f"Got shape {tuple(pixel_values.shape)}"
|
| 187 |
+
)
|
| 188 |
+
spatial_shape = tuple(int(v) for v in pixel_values.shape[-3:])
|
| 189 |
+
if any(dim % patch != 0 for dim, patch in zip(spatial_shape, self.patch_size)):
|
| 190 |
+
raise ValueError(
|
| 191 |
+
f"Input spatial size {spatial_shape} must be divisible by patch_size {self.patch_size}."
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
target_dtype = self.patch_embedding.weight.dtype
|
| 195 |
+
embeddings = self.patch_embedding(pixel_values.to(dtype=target_dtype))
|
| 196 |
+
grid_size = tuple(int(v) for v in embeddings.shape[-3:])
|
| 197 |
+
embeddings = embeddings.flatten(2).transpose(1, 2)
|
| 198 |
+
position_embeddings = self._get_position_embeddings(
|
| 199 |
+
grid_size=grid_size,
|
| 200 |
+
target_dtype=embeddings.dtype,
|
| 201 |
+
target_device=embeddings.device,
|
| 202 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 203 |
+
)
|
| 204 |
+
return embeddings + position_embeddings
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
class BrainMRISiglipVisionTransformer(nn.Module):
|
| 208 |
+
"""SigLIP vision tower with 3D embeddings."""
|
| 209 |
+
|
| 210 |
+
def __init__(self, config: BrainMRISiglipConfig) -> None:
|
| 211 |
+
super().__init__()
|
| 212 |
+
vision_config = config.get_vision_config()
|
| 213 |
+
volume_size = tuple(int(v) for v in config.volume_size)
|
| 214 |
+
patch_size = tuple(int(v) for v in config.patch_size)
|
| 215 |
+
|
| 216 |
+
self.embeddings = SiglipVisionEmbeddings3D(
|
| 217 |
+
vision_config=vision_config,
|
| 218 |
+
volume_size=volume_size,
|
| 219 |
+
patch_size=patch_size,
|
| 220 |
+
num_channels=int(config.num_channels),
|
| 221 |
+
)
|
| 222 |
+
self.encoder = SiglipEncoder(vision_config)
|
| 223 |
+
self.post_layernorm = nn.LayerNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)
|
| 224 |
+
self.head = SiglipMultiheadAttentionPoolingHead(vision_config)
|
| 225 |
+
|
| 226 |
+
def forward(
|
| 227 |
+
self,
|
| 228 |
+
pixel_values: torch.Tensor,
|
| 229 |
+
interpolate_pos_encoding: bool = True,
|
| 230 |
+
**kwargs: Any,
|
| 231 |
+
) -> BaseModelOutputWithPooling:
|
| 232 |
+
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
| 233 |
+
encoder_outputs = self.encoder(inputs_embeds=hidden_states, **kwargs)
|
| 234 |
+
last_hidden_state = self.post_layernorm(encoder_outputs.last_hidden_state)
|
| 235 |
+
pooler_output = self.head(last_hidden_state)
|
| 236 |
+
return BaseModelOutputWithPooling(last_hidden_state=last_hidden_state, pooler_output=pooler_output)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class BrainMRISiglipPreTrainedModel(PreTrainedModel):
|
| 240 |
+
config_class = BrainMRISiglipConfig
|
| 241 |
+
base_model_prefix = "brain_mri_siglip"
|
| 242 |
+
supports_gradient_checkpointing = True
|
| 243 |
+
|
| 244 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 245 |
+
if isinstance(module, SiglipVisionEmbeddings3D):
|
| 246 |
+
width = int(self.config.get_vision_config().hidden_size)
|
| 247 |
+
nn.init.normal_(module.position_embedding.weight, std=1.0 / math.sqrt(width))
|
| 248 |
+
_lecun_normal_(module.patch_embedding.weight)
|
| 249 |
+
if module.patch_embedding.bias is not None:
|
| 250 |
+
nn.init.zeros_(module.patch_embedding.bias)
|
| 251 |
+
return
|
| 252 |
+
|
| 253 |
+
if isinstance(module, nn.Embedding):
|
| 254 |
+
_siglip_embedding_init_(module.weight)
|
| 255 |
+
return
|
| 256 |
+
|
| 257 |
+
if isinstance(module, SiglipAttention):
|
| 258 |
+
nn.init.xavier_uniform_(module.q_proj.weight)
|
| 259 |
+
nn.init.xavier_uniform_(module.k_proj.weight)
|
| 260 |
+
nn.init.xavier_uniform_(module.v_proj.weight)
|
| 261 |
+
nn.init.xavier_uniform_(module.out_proj.weight)
|
| 262 |
+
if module.q_proj.bias is not None:
|
| 263 |
+
nn.init.zeros_(module.q_proj.bias)
|
| 264 |
+
if module.k_proj.bias is not None:
|
| 265 |
+
nn.init.zeros_(module.k_proj.bias)
|
| 266 |
+
if module.v_proj.bias is not None:
|
| 267 |
+
nn.init.zeros_(module.v_proj.bias)
|
| 268 |
+
if module.out_proj.bias is not None:
|
| 269 |
+
nn.init.zeros_(module.out_proj.bias)
|
| 270 |
+
return
|
| 271 |
+
|
| 272 |
+
if isinstance(module, SiglipMLP):
|
| 273 |
+
nn.init.xavier_uniform_(module.fc1.weight)
|
| 274 |
+
nn.init.xavier_uniform_(module.fc2.weight)
|
| 275 |
+
if module.fc1.bias is not None:
|
| 276 |
+
nn.init.normal_(module.fc1.bias, std=1e-6)
|
| 277 |
+
if module.fc2.bias is not None:
|
| 278 |
+
nn.init.normal_(module.fc2.bias, std=1e-6)
|
| 279 |
+
return
|
| 280 |
+
|
| 281 |
+
if isinstance(module, SiglipMultiheadAttentionPoolingHead):
|
| 282 |
+
nn.init.xavier_uniform_(module.probe)
|
| 283 |
+
nn.init.xavier_uniform_(module.attention.in_proj_weight)
|
| 284 |
+
if module.attention.in_proj_bias is not None:
|
| 285 |
+
nn.init.zeros_(module.attention.in_proj_bias)
|
| 286 |
+
return
|
| 287 |
+
|
| 288 |
+
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
| 289 |
+
_lecun_normal_(module.weight)
|
| 290 |
+
if module.bias is not None:
|
| 291 |
+
nn.init.zeros_(module.bias)
|
| 292 |
+
return
|
| 293 |
+
|
| 294 |
+
if isinstance(module, nn.LayerNorm):
|
| 295 |
+
module.bias.data.zero_()
|
| 296 |
+
module.weight.data.fill_(1.0)
|
| 297 |
+
return
|
| 298 |
+
|
| 299 |
+
class BrainMRISiglipModel(BrainMRISiglipPreTrainedModel):
|
| 300 |
+
"""3D MRI + text dual-encoder model with SigLIP contrastive loss."""
|
| 301 |
+
|
| 302 |
+
def __init__(self, config: BrainMRISiglipConfig) -> None:
|
| 303 |
+
super().__init__(config)
|
| 304 |
+
self.text_config = config.get_text_config()
|
| 305 |
+
self.vision_config = config.get_vision_config()
|
| 306 |
+
|
| 307 |
+
self.text_model = SiglipTextModel(self.text_config)
|
| 308 |
+
self.vision_model = BrainMRISiglipVisionTransformer(config)
|
| 309 |
+
|
| 310 |
+
projection_dim = int(config.projection_dim)
|
| 311 |
+
self.visual_projection = nn.Linear(self.vision_config.hidden_size, projection_dim, bias=False)
|
| 312 |
+
self.text_projection = nn.Linear(self.text_config.hidden_size, projection_dim, bias=False)
|
| 313 |
+
|
| 314 |
+
self.logit_scale = nn.Parameter(torch.tensor(float(config.logit_scale_init_value)))
|
| 315 |
+
self.logit_bias = nn.Parameter(torch.tensor(float(config.logit_bias_init_value)))
|
| 316 |
+
|
| 317 |
+
self.post_init()
|
| 318 |
+
|
| 319 |
+
@classmethod
|
| 320 |
+
def from_medsiglip_pretrained(
|
| 321 |
+
cls,
|
| 322 |
+
text_model_name_or_path: str = "google/medsiglip-448",
|
| 323 |
+
trust_remote_code: bool = True,
|
| 324 |
+
local_files_only: bool = False,
|
| 325 |
+
**kwargs: Any,
|
| 326 |
+
) -> "BrainMRISiglipModel":
|
| 327 |
+
base_config = AutoConfig.from_pretrained(
|
| 328 |
+
text_model_name_or_path,
|
| 329 |
+
trust_remote_code=trust_remote_code,
|
| 330 |
+
local_files_only=local_files_only,
|
| 331 |
+
)
|
| 332 |
+
if hasattr(base_config, "text_config"):
|
| 333 |
+
raw_text_config = base_config.text_config
|
| 334 |
+
text_config = raw_text_config.to_dict() if hasattr(raw_text_config, "to_dict") else dict(raw_text_config)
|
| 335 |
+
else:
|
| 336 |
+
text_config = SiglipTextConfig().to_dict()
|
| 337 |
+
|
| 338 |
+
if hasattr(base_config, "vision_config"):
|
| 339 |
+
raw_vision_config = base_config.vision_config
|
| 340 |
+
vision_config = (
|
| 341 |
+
raw_vision_config.to_dict()
|
| 342 |
+
if hasattr(raw_vision_config, "to_dict")
|
| 343 |
+
else dict(raw_vision_config)
|
| 344 |
+
)
|
| 345 |
+
else:
|
| 346 |
+
vision_config = SiglipVisionConfig().to_dict()
|
| 347 |
+
projection_dim = kwargs.pop(
|
| 348 |
+
"projection_dim",
|
| 349 |
+
int(getattr(base_config, "projection_dim", text_config.get("projection_size", text_config["hidden_size"]))),
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
config = BrainMRISiglipConfig(
|
| 353 |
+
text_config=text_config,
|
| 354 |
+
vision_config=vision_config,
|
| 355 |
+
projection_dim=projection_dim,
|
| 356 |
+
text_model_name_or_path=text_model_name_or_path,
|
| 357 |
+
**kwargs,
|
| 358 |
+
)
|
| 359 |
+
model = cls(config)
|
| 360 |
+
model.load_text_tower_from_pretrained(
|
| 361 |
+
text_model_name_or_path,
|
| 362 |
+
trust_remote_code=trust_remote_code,
|
| 363 |
+
local_files_only=local_files_only,
|
| 364 |
+
)
|
| 365 |
+
return model
|
| 366 |
+
|
| 367 |
+
def load_text_tower_from_pretrained(
|
| 368 |
+
self,
|
| 369 |
+
text_model_name_or_path: str,
|
| 370 |
+
trust_remote_code: bool = True,
|
| 371 |
+
local_files_only: bool = False,
|
| 372 |
+
strict: bool = True,
|
| 373 |
+
) -> Tuple[Any, Any]:
|
| 374 |
+
source_model = None
|
| 375 |
+
try:
|
| 376 |
+
source_model = AutoModel.from_pretrained(
|
| 377 |
+
text_model_name_or_path,
|
| 378 |
+
trust_remote_code=trust_remote_code,
|
| 379 |
+
local_files_only=local_files_only,
|
| 380 |
+
)
|
| 381 |
+
if hasattr(source_model, "text_model"):
|
| 382 |
+
source_text_model = source_model.text_model
|
| 383 |
+
elif isinstance(source_model, SiglipTextModel):
|
| 384 |
+
source_text_model = source_model
|
| 385 |
+
else:
|
| 386 |
+
raise ValueError(
|
| 387 |
+
f"Could not find a SigLIP text tower in `{text_model_name_or_path}` "
|
| 388 |
+
f"({type(source_model).__name__})."
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
missing, unexpected = _load_state_dict_with_flexible_prefix(
|
| 392 |
+
self.text_model,
|
| 393 |
+
source_text_model.state_dict(),
|
| 394 |
+
strict=strict,
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
if hasattr(source_model, "text_projection") and isinstance(source_model.text_projection, nn.Linear):
|
| 398 |
+
if source_model.text_projection.weight.shape == self.text_projection.weight.shape:
|
| 399 |
+
self.text_projection.load_state_dict(source_model.text_projection.state_dict())
|
| 400 |
+
|
| 401 |
+
if hasattr(source_model, "logit_scale") and source_model.logit_scale.shape == self.logit_scale.shape:
|
| 402 |
+
self.logit_scale.data.copy_(source_model.logit_scale.data)
|
| 403 |
+
if hasattr(source_model, "logit_bias") and source_model.logit_bias.shape == self.logit_bias.shape:
|
| 404 |
+
self.logit_bias.data.copy_(source_model.logit_bias.data)
|
| 405 |
+
|
| 406 |
+
return missing, unexpected
|
| 407 |
+
finally:
|
| 408 |
+
if source_model is not None:
|
| 409 |
+
del source_model
|
| 410 |
+
|
| 411 |
+
def freeze_text_tower(self, trainable_layers: int = 0) -> None:
|
| 412 |
+
for parameter in self.text_model.parameters():
|
| 413 |
+
parameter.requires_grad = False
|
| 414 |
+
|
| 415 |
+
trainable_layers = int(trainable_layers)
|
| 416 |
+
if trainable_layers > 0 and hasattr(self.text_model, "text_model") and hasattr(
|
| 417 |
+
self.text_model.text_model, "encoder"
|
| 418 |
+
):
|
| 419 |
+
layers = self.text_model.text_model.encoder.layers
|
| 420 |
+
for layer in layers[-trainable_layers:]:
|
| 421 |
+
for parameter in layer.parameters():
|
| 422 |
+
parameter.requires_grad = True
|
| 423 |
+
|
| 424 |
+
for module_name in ("final_layer_norm", "head"):
|
| 425 |
+
if hasattr(self.text_model.text_model, module_name):
|
| 426 |
+
for parameter in getattr(self.text_model.text_model, module_name).parameters():
|
| 427 |
+
parameter.requires_grad = True
|
| 428 |
+
|
| 429 |
+
for parameter in self.text_projection.parameters():
|
| 430 |
+
parameter.requires_grad = True
|
| 431 |
+
|
| 432 |
+
def freeze_vision_tower(self, trainable_layers: int = 0, train_embeddings: bool = False) -> None:
|
| 433 |
+
for parameter in self.vision_model.parameters():
|
| 434 |
+
parameter.requires_grad = False
|
| 435 |
+
|
| 436 |
+
if train_embeddings:
|
| 437 |
+
for parameter in self.vision_model.embeddings.parameters():
|
| 438 |
+
parameter.requires_grad = True
|
| 439 |
+
|
| 440 |
+
trainable_layers = int(trainable_layers)
|
| 441 |
+
if trainable_layers > 0:
|
| 442 |
+
layers = self.vision_model.encoder.layers
|
| 443 |
+
for layer in layers[-trainable_layers:]:
|
| 444 |
+
for parameter in layer.parameters():
|
| 445 |
+
parameter.requires_grad = True
|
| 446 |
+
for parameter in self.vision_model.post_layernorm.parameters():
|
| 447 |
+
parameter.requires_grad = True
|
| 448 |
+
for parameter in self.vision_model.head.parameters():
|
| 449 |
+
parameter.requires_grad = True
|
| 450 |
+
|
| 451 |
+
for parameter in self.visual_projection.parameters():
|
| 452 |
+
parameter.requires_grad = True
|
| 453 |
+
|
| 454 |
+
def get_text_features(
|
| 455 |
+
self,
|
| 456 |
+
input_ids: torch.LongTensor,
|
| 457 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 458 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 459 |
+
text_kwargs: Optional[Mapping[str, Any]] = None,
|
| 460 |
+
**kwargs: Any,
|
| 461 |
+
) -> torch.FloatTensor:
|
| 462 |
+
kwargs = dict(kwargs)
|
| 463 |
+
nested_text_kwargs = kwargs.pop("text_kwargs", None)
|
| 464 |
+
if kwargs:
|
| 465 |
+
raise TypeError(f"Unexpected keyword arguments for text tower: {sorted(kwargs.keys())}")
|
| 466 |
+
merged_text_kwargs: dict[str, Any] = {}
|
| 467 |
+
if nested_text_kwargs:
|
| 468 |
+
merged_text_kwargs.update(dict(nested_text_kwargs))
|
| 469 |
+
if text_kwargs:
|
| 470 |
+
merged_text_kwargs.update(dict(text_kwargs))
|
| 471 |
+
|
| 472 |
+
text_outputs = self.text_model(
|
| 473 |
+
input_ids=input_ids,
|
| 474 |
+
attention_mask=attention_mask,
|
| 475 |
+
position_ids=position_ids,
|
| 476 |
+
**merged_text_kwargs,
|
| 477 |
+
)
|
| 478 |
+
text_features = self.text_projection(text_outputs.pooler_output)
|
| 479 |
+
return F.normalize(text_features, dim=-1)
|
| 480 |
+
|
| 481 |
+
def get_image_features(
|
| 482 |
+
self,
|
| 483 |
+
pixel_values: torch.FloatTensor,
|
| 484 |
+
interpolate_pos_encoding: bool = True,
|
| 485 |
+
vision_kwargs: Optional[Mapping[str, Any]] = None,
|
| 486 |
+
**kwargs: Any,
|
| 487 |
+
) -> torch.FloatTensor:
|
| 488 |
+
kwargs = dict(kwargs)
|
| 489 |
+
nested_vision_kwargs = kwargs.pop("vision_kwargs", None)
|
| 490 |
+
legacy_interpolate_pos_encoding = kwargs.pop("interpolate_pos_encoding", None)
|
| 491 |
+
if kwargs:
|
| 492 |
+
raise TypeError(f"Unexpected keyword arguments for vision tower: {sorted(kwargs.keys())}")
|
| 493 |
+
|
| 494 |
+
merged_vision_kwargs: dict[str, Any] = {}
|
| 495 |
+
if nested_vision_kwargs:
|
| 496 |
+
merged_vision_kwargs.update(dict(nested_vision_kwargs))
|
| 497 |
+
if vision_kwargs:
|
| 498 |
+
merged_vision_kwargs.update(dict(vision_kwargs))
|
| 499 |
+
if legacy_interpolate_pos_encoding is not None:
|
| 500 |
+
interpolate_pos_encoding = bool(legacy_interpolate_pos_encoding)
|
| 501 |
+
|
| 502 |
+
vision_outputs = self.vision_model(
|
| 503 |
+
pixel_values=pixel_values,
|
| 504 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 505 |
+
**merged_vision_kwargs,
|
| 506 |
+
)
|
| 507 |
+
image_features = self.visual_projection(vision_outputs.pooler_output)
|
| 508 |
+
return F.normalize(image_features, dim=-1)
|
| 509 |
+
|
| 510 |
+
def forward(
|
| 511 |
+
self,
|
| 512 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 513 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 514 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 515 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 516 |
+
return_loss: Optional[bool] = None,
|
| 517 |
+
gather_loss: bool = False,
|
| 518 |
+
interpolate_pos_encoding: bool = True,
|
| 519 |
+
vision_kwargs: Optional[Mapping[str, Any]] = None,
|
| 520 |
+
text_kwargs: Optional[Mapping[str, Any]] = None,
|
| 521 |
+
return_dict: Optional[bool] = None,
|
| 522 |
+
**kwargs: Any,
|
| 523 |
+
) -> SiglipOutput:
|
| 524 |
+
if pixel_values is None:
|
| 525 |
+
raise ValueError("`pixel_values` must be provided.")
|
| 526 |
+
if input_ids is None:
|
| 527 |
+
raise ValueError("`input_ids` must be provided.")
|
| 528 |
+
|
| 529 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 530 |
+
return_loss = bool(return_loss) if return_loss is not None else False
|
| 531 |
+
kwargs = dict(kwargs)
|
| 532 |
+
nested_vision_kwargs = kwargs.pop("vision_kwargs", None)
|
| 533 |
+
nested_text_kwargs = kwargs.pop("text_kwargs", None)
|
| 534 |
+
legacy_interpolate_pos_encoding = kwargs.pop("interpolate_pos_encoding", None)
|
| 535 |
+
if kwargs:
|
| 536 |
+
raise TypeError(f"Unexpected keyword arguments in model.forward: {sorted(kwargs.keys())}")
|
| 537 |
+
|
| 538 |
+
merged_vision_kwargs: dict[str, Any] = {}
|
| 539 |
+
merged_text_kwargs: dict[str, Any] = {}
|
| 540 |
+
if nested_vision_kwargs:
|
| 541 |
+
merged_vision_kwargs.update(dict(nested_vision_kwargs))
|
| 542 |
+
if vision_kwargs:
|
| 543 |
+
merged_vision_kwargs.update(dict(vision_kwargs))
|
| 544 |
+
if nested_text_kwargs:
|
| 545 |
+
merged_text_kwargs.update(dict(nested_text_kwargs))
|
| 546 |
+
if text_kwargs:
|
| 547 |
+
merged_text_kwargs.update(dict(text_kwargs))
|
| 548 |
+
if legacy_interpolate_pos_encoding is not None:
|
| 549 |
+
interpolate_pos_encoding = bool(legacy_interpolate_pos_encoding)
|
| 550 |
+
|
| 551 |
+
vision_outputs = self.vision_model(
|
| 552 |
+
pixel_values=pixel_values,
|
| 553 |
+
interpolate_pos_encoding=interpolate_pos_encoding,
|
| 554 |
+
**merged_vision_kwargs,
|
| 555 |
+
)
|
| 556 |
+
text_outputs = self.text_model(
|
| 557 |
+
input_ids=input_ids,
|
| 558 |
+
attention_mask=attention_mask,
|
| 559 |
+
position_ids=position_ids,
|
| 560 |
+
**merged_text_kwargs,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
image_embeds = self.visual_projection(vision_outputs.pooler_output)
|
| 564 |
+
text_embeds = self.text_projection(text_outputs.pooler_output)
|
| 565 |
+
|
| 566 |
+
image_embeds = F.normalize(image_embeds, p=2, dim=-1)
|
| 567 |
+
text_embeds = F.normalize(text_embeds, p=2, dim=-1)
|
| 568 |
+
|
| 569 |
+
image_embeds_for_loss = image_embeds
|
| 570 |
+
text_embeds_for_loss = text_embeds
|
| 571 |
+
if gather_loss and return_loss:
|
| 572 |
+
image_embeds_for_loss = _distributed_concat_with_grad(image_embeds)
|
| 573 |
+
text_embeds_for_loss = _distributed_concat_with_grad(text_embeds)
|
| 574 |
+
|
| 575 |
+
logit_scale = self.logit_scale.exp().clamp(
|
| 576 |
+
min=float(self.config.logit_scale_min),
|
| 577 |
+
max=float(self.config.logit_scale_max),
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
local_logits_per_text = torch.matmul(
|
| 581 |
+
text_embeds,
|
| 582 |
+
image_embeds.t().to(text_embeds.device),
|
| 583 |
+
)
|
| 584 |
+
local_logits_per_text = local_logits_per_text * logit_scale + self.logit_bias
|
| 585 |
+
local_logits_per_image = local_logits_per_text.t()
|
| 586 |
+
|
| 587 |
+
loss = None
|
| 588 |
+
if return_loss:
|
| 589 |
+
loss_logits_per_text = torch.matmul(
|
| 590 |
+
text_embeds_for_loss,
|
| 591 |
+
image_embeds_for_loss.t().to(text_embeds_for_loss.device),
|
| 592 |
+
)
|
| 593 |
+
loss_logits_per_text = loss_logits_per_text * logit_scale + self.logit_bias
|
| 594 |
+
loss = _siglip_sigmoid_loss(loss_logits_per_text)
|
| 595 |
+
|
| 596 |
+
if not return_dict:
|
| 597 |
+
output = (
|
| 598 |
+
local_logits_per_image,
|
| 599 |
+
local_logits_per_text,
|
| 600 |
+
text_embeds,
|
| 601 |
+
image_embeds,
|
| 602 |
+
text_outputs,
|
| 603 |
+
vision_outputs,
|
| 604 |
+
)
|
| 605 |
+
return ((loss,) + output) if loss is not None else output
|
| 606 |
+
|
| 607 |
+
return SiglipOutput(
|
| 608 |
+
loss=loss,
|
| 609 |
+
logits_per_image=local_logits_per_image,
|
| 610 |
+
logits_per_text=local_logits_per_text,
|
| 611 |
+
text_embeds=text_embeds,
|
| 612 |
+
image_embeds=image_embeds,
|
| 613 |
+
text_model_output=text_outputs,
|
| 614 |
+
vision_model_output=vision_outputs,
|
| 615 |
+
)
|
offline_aligned_preprocessing.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared offline-aligned preprocessing helpers for 3D brain MRI volumes."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Mapping
|
| 8 |
+
|
| 9 |
+
import nibabel as nib
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from scipy import ndimage as scipy_ndimage
|
| 16 |
+
except Exception: # pragma: no cover - optional import surface
|
| 17 |
+
scipy_ndimage = None
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
TARGET_SHAPE = (128, 192, 192)
|
| 21 |
+
TARGET_SPACING = (1.25, 1.0, 1.0)
|
| 22 |
+
CROP_MARGIN_MM = 5.0
|
| 23 |
+
FOREGROUND_THRESHOLD = 1e-3
|
| 24 |
+
BACKGROUND_VALUE = -1.0
|
| 25 |
+
FOREGROUND_STRATEGY = "largest_component_nonzero"
|
| 26 |
+
GENERIC_RECIPE_ID = "generic_foreground_128x192x192_fp16_v1"
|
| 27 |
+
GENERIC_CACHE_VERSION = 1
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_canonical_nifti(path: str | Path):
|
| 31 |
+
return nib.as_closest_canonical(nib.load(str(path)))
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_image_spacing(image) -> tuple[float, float, float]:
|
| 35 |
+
zooms = image.header.get_zooms()[:3]
|
| 36 |
+
if len(zooms) != 3:
|
| 37 |
+
raise ValueError(f"Expected a 3D image spacing tuple, got {zooms}.")
|
| 38 |
+
return tuple(float(value) for value in zooms)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def coerce_volume_to_3d(volume: np.ndarray) -> np.ndarray:
|
| 42 |
+
if volume.ndim == 3:
|
| 43 |
+
return volume.astype(np.float32, copy=False)
|
| 44 |
+
if volume.ndim != 4:
|
| 45 |
+
raise ValueError(f"Expected a 3D or 4D volume, got shape {volume.shape}.")
|
| 46 |
+
|
| 47 |
+
if volume.shape[0] <= 4 and volume.shape[-1] > 4:
|
| 48 |
+
selected = volume[0]
|
| 49 |
+
else:
|
| 50 |
+
selected = volume[..., 0]
|
| 51 |
+
return np.asarray(selected, dtype=np.float32)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def largest_connected_component(mask: np.ndarray) -> np.ndarray:
|
| 55 |
+
if not mask.any() or scipy_ndimage is None:
|
| 56 |
+
return mask
|
| 57 |
+
structure = scipy_ndimage.generate_binary_structure(mask.ndim, 1)
|
| 58 |
+
labels, num_labels = scipy_ndimage.label(mask, structure=structure)
|
| 59 |
+
if num_labels <= 1:
|
| 60 |
+
return mask
|
| 61 |
+
counts = np.bincount(labels.reshape(-1))
|
| 62 |
+
if counts.size <= 1:
|
| 63 |
+
return mask
|
| 64 |
+
counts[0] = 0
|
| 65 |
+
winning_label = int(counts.argmax())
|
| 66 |
+
if winning_label <= 0 or counts[winning_label] <= 0:
|
| 67 |
+
return mask
|
| 68 |
+
return labels == winning_label
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_foreground_mask(volume: np.ndarray, threshold: float = FOREGROUND_THRESHOLD) -> np.ndarray:
|
| 72 |
+
sanitized = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0)
|
| 73 |
+
raw_mask = np.abs(sanitized) > float(threshold)
|
| 74 |
+
if not raw_mask.any():
|
| 75 |
+
return np.ones_like(sanitized, dtype=bool)
|
| 76 |
+
|
| 77 |
+
component_mask = largest_connected_component(raw_mask)
|
| 78 |
+
component_count = int(component_mask.sum())
|
| 79 |
+
raw_count = int(raw_mask.sum())
|
| 80 |
+
if component_count <= 0:
|
| 81 |
+
return raw_mask
|
| 82 |
+
if component_count < 512 and raw_count > component_count:
|
| 83 |
+
return raw_mask
|
| 84 |
+
return component_mask
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def compute_crop_bbox(
|
| 88 |
+
mask: np.ndarray,
|
| 89 |
+
spacing: tuple[float, float, float],
|
| 90 |
+
margin_mm: float = CROP_MARGIN_MM,
|
| 91 |
+
) -> tuple[tuple[int, int], ...]:
|
| 92 |
+
coords = np.where(mask)
|
| 93 |
+
if coords[0].size == 0:
|
| 94 |
+
raise ValueError("Foreground mask contains no positive voxels after selection.")
|
| 95 |
+
|
| 96 |
+
bbox = []
|
| 97 |
+
for axis, values in enumerate(coords):
|
| 98 |
+
margin_voxels = int(math.ceil(float(margin_mm) / float(spacing[axis])))
|
| 99 |
+
start = max(0, int(values.min()) - margin_voxels)
|
| 100 |
+
stop = min(mask.shape[axis], int(values.max()) + margin_voxels + 1)
|
| 101 |
+
bbox.append((start, stop))
|
| 102 |
+
return tuple(bbox)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def crop_volume_and_mask(
|
| 106 |
+
volume: np.ndarray,
|
| 107 |
+
mask: np.ndarray,
|
| 108 |
+
spacing: tuple[float, float, float],
|
| 109 |
+
margin_mm: float = CROP_MARGIN_MM,
|
| 110 |
+
) -> tuple[np.ndarray, np.ndarray, tuple[tuple[int, int], ...]]:
|
| 111 |
+
bbox = compute_crop_bbox(mask, spacing, margin_mm=margin_mm)
|
| 112 |
+
slices = tuple(slice(start, stop) for start, stop in bbox)
|
| 113 |
+
return volume[slices], mask[slices], bbox
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def normalize_foreground_only(volume: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
| 117 |
+
sanitized = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
|
| 118 |
+
foreground_values = sanitized[mask]
|
| 119 |
+
if foreground_values.size == 0:
|
| 120 |
+
raise ValueError("Cannot normalize volume because the foreground mask is empty.")
|
| 121 |
+
|
| 122 |
+
if foreground_values.size > 1_000_000:
|
| 123 |
+
step = max(1, foreground_values.size // 1_000_000)
|
| 124 |
+
foreground_values = foreground_values[::step]
|
| 125 |
+
|
| 126 |
+
low, high = np.percentile(foreground_values, [0.5, 99.5])
|
| 127 |
+
if not np.isfinite(low) or not np.isfinite(high) or high <= low:
|
| 128 |
+
normalized = np.zeros_like(sanitized, dtype=np.float32)
|
| 129 |
+
else:
|
| 130 |
+
normalized = np.clip(sanitized, float(low), float(high))
|
| 131 |
+
normalized = np.clip((normalized - float(low)) / float(high - low), 0.0, 1.0)
|
| 132 |
+
normalized = normalized * 2.0 - 1.0
|
| 133 |
+
return normalized.astype(np.float32, copy=False)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def resize_volume(volume: np.ndarray, size: tuple[int, int, int], mode: str) -> np.ndarray:
|
| 137 |
+
tensor = torch.from_numpy(volume).unsqueeze(0).unsqueeze(0)
|
| 138 |
+
kwargs = {}
|
| 139 |
+
if mode in {"linear", "bilinear", "bicubic", "trilinear"}:
|
| 140 |
+
kwargs["align_corners"] = False
|
| 141 |
+
tensor = F.interpolate(tensor, size=size, mode=mode, **kwargs)
|
| 142 |
+
return tensor.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32, copy=False)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def resize_mask(mask: np.ndarray, size: tuple[int, int, int]) -> np.ndarray:
|
| 146 |
+
tensor = torch.from_numpy(mask.astype(np.float32, copy=False)).unsqueeze(0).unsqueeze(0)
|
| 147 |
+
tensor = F.interpolate(tensor, size=size, mode="nearest")
|
| 148 |
+
return tensor.squeeze(0).squeeze(0).cpu().numpy() > 0.5
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def resample_to_target_spacing(
|
| 152 |
+
volume: np.ndarray,
|
| 153 |
+
mask: np.ndarray,
|
| 154 |
+
source_spacing: tuple[float, float, float],
|
| 155 |
+
target_spacing: tuple[float, float, float] = TARGET_SPACING,
|
| 156 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 157 |
+
target_shape = []
|
| 158 |
+
for current_size, src, dst in zip(volume.shape, source_spacing, target_spacing):
|
| 159 |
+
target_shape.append(max(1, int(round(float(current_size) * float(src) / float(dst)))))
|
| 160 |
+
target_shape_tuple = tuple(target_shape)
|
| 161 |
+
if target_shape_tuple == tuple(int(v) for v in volume.shape):
|
| 162 |
+
return volume.astype(np.float32, copy=False), mask
|
| 163 |
+
return (
|
| 164 |
+
resize_volume(volume, target_shape_tuple, mode="trilinear"),
|
| 165 |
+
resize_mask(mask, target_shape_tuple),
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def downscale_to_fit(
|
| 170 |
+
volume: np.ndarray,
|
| 171 |
+
mask: np.ndarray,
|
| 172 |
+
target_shape: tuple[int, int, int] = TARGET_SHAPE,
|
| 173 |
+
) -> tuple[np.ndarray, np.ndarray]:
|
| 174 |
+
current_shape = tuple(int(v) for v in volume.shape)
|
| 175 |
+
if all(current <= target for current, target in zip(current_shape, target_shape)):
|
| 176 |
+
return volume, mask
|
| 177 |
+
|
| 178 |
+
scale = min(float(target) / float(current) for current, target in zip(current_shape, target_shape))
|
| 179 |
+
if scale >= 1.0:
|
| 180 |
+
return volume, mask
|
| 181 |
+
|
| 182 |
+
new_shape = tuple(
|
| 183 |
+
min(target, max(1, int(math.floor(float(current) * scale))))
|
| 184 |
+
for current, target in zip(current_shape, target_shape)
|
| 185 |
+
)
|
| 186 |
+
return (
|
| 187 |
+
resize_volume(volume, new_shape, mode="trilinear"),
|
| 188 |
+
resize_mask(mask, new_shape),
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def center_pad(
|
| 193 |
+
array: np.ndarray,
|
| 194 |
+
target_shape: tuple[int, int, int] = TARGET_SHAPE,
|
| 195 |
+
fill_value: float = BACKGROUND_VALUE,
|
| 196 |
+
) -> np.ndarray:
|
| 197 |
+
if any(current > target for current, target in zip(array.shape, target_shape)):
|
| 198 |
+
raise ValueError(f"Cannot center-pad shape {array.shape} into smaller target {target_shape}.")
|
| 199 |
+
pad_width = []
|
| 200 |
+
for current, target in zip(array.shape, target_shape):
|
| 201 |
+
delta = target - current
|
| 202 |
+
before = delta // 2
|
| 203 |
+
after = delta - before
|
| 204 |
+
pad_width.append((before, after))
|
| 205 |
+
return np.pad(array, pad_width=tuple(pad_width), mode="constant", constant_values=fill_value)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def preprocess_image_with_foreground_mask(
|
| 209 |
+
image_path: str | Path,
|
| 210 |
+
*,
|
| 211 |
+
target_shape: tuple[int, int, int] = TARGET_SHAPE,
|
| 212 |
+
target_spacing: tuple[float, float, float] = TARGET_SPACING,
|
| 213 |
+
crop_margin_mm: float = CROP_MARGIN_MM,
|
| 214 |
+
foreground_threshold: float = FOREGROUND_THRESHOLD,
|
| 215 |
+
background_value: float = BACKGROUND_VALUE,
|
| 216 |
+
foreground_strategy: str = FOREGROUND_STRATEGY,
|
| 217 |
+
recipe_id: str = GENERIC_RECIPE_ID,
|
| 218 |
+
cache_version: int = GENERIC_CACHE_VERSION,
|
| 219 |
+
) -> dict[str, object]:
|
| 220 |
+
image_path = Path(image_path)
|
| 221 |
+
image = load_canonical_nifti(image_path)
|
| 222 |
+
source_shape = tuple(int(value) for value in image.shape)
|
| 223 |
+
source_spacing = load_image_spacing(image)
|
| 224 |
+
volume = np.asarray(image.get_fdata(dtype=np.float32), dtype=np.float32)
|
| 225 |
+
volume = coerce_volume_to_3d(volume)
|
| 226 |
+
|
| 227 |
+
foreground_mask = build_foreground_mask(volume, threshold=foreground_threshold)
|
| 228 |
+
cropped_volume, cropped_mask, crop_bbox = crop_volume_and_mask(
|
| 229 |
+
volume,
|
| 230 |
+
foreground_mask,
|
| 231 |
+
source_spacing,
|
| 232 |
+
margin_mm=crop_margin_mm,
|
| 233 |
+
)
|
| 234 |
+
normalized_volume = normalize_foreground_only(cropped_volume, cropped_mask)
|
| 235 |
+
resampled_volume, resampled_mask = resample_to_target_spacing(
|
| 236 |
+
normalized_volume,
|
| 237 |
+
cropped_mask,
|
| 238 |
+
source_spacing=source_spacing,
|
| 239 |
+
target_spacing=target_spacing,
|
| 240 |
+
)
|
| 241 |
+
fitted_volume, fitted_mask = downscale_to_fit(
|
| 242 |
+
resampled_volume,
|
| 243 |
+
resampled_mask,
|
| 244 |
+
target_shape=target_shape,
|
| 245 |
+
)
|
| 246 |
+
fitted_volume = np.clip(fitted_volume, -1.0, 1.0).astype(np.float32, copy=False)
|
| 247 |
+
fitted_volume[~fitted_mask] = float(background_value)
|
| 248 |
+
|
| 249 |
+
padded_volume = center_pad(
|
| 250 |
+
fitted_volume,
|
| 251 |
+
target_shape=target_shape,
|
| 252 |
+
fill_value=float(background_value),
|
| 253 |
+
).astype(np.float32, copy=False)
|
| 254 |
+
pixel_values = torch.from_numpy(padded_volume).unsqueeze(0).to(dtype=torch.float16).contiguous()
|
| 255 |
+
|
| 256 |
+
return {
|
| 257 |
+
"pixel_values": pixel_values,
|
| 258 |
+
"source_image": str(image_path),
|
| 259 |
+
"source_shape": list(source_shape),
|
| 260 |
+
"source_spacing": list(source_spacing),
|
| 261 |
+
"crop_bbox": [[int(start), int(stop)] for start, stop in crop_bbox],
|
| 262 |
+
"foreground_strategy": foreground_strategy,
|
| 263 |
+
"recipe_id": recipe_id,
|
| 264 |
+
"cache_version": int(cache_version),
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def validate_fixed_payload(
|
| 269 |
+
payload: Mapping[str, Any],
|
| 270 |
+
*,
|
| 271 |
+
target_shape: tuple[int, int, int] = TARGET_SHAPE,
|
| 272 |
+
) -> None:
|
| 273 |
+
pixel_values = payload.get("pixel_values")
|
| 274 |
+
if not isinstance(pixel_values, torch.Tensor):
|
| 275 |
+
raise TypeError("`pixel_values` must be a torch.Tensor.")
|
| 276 |
+
expected_shape = (1,) + tuple(target_shape)
|
| 277 |
+
if tuple(pixel_values.shape) != expected_shape:
|
| 278 |
+
raise ValueError(f"Expected tensor shape {expected_shape}, got {tuple(pixel_values.shape)}.")
|
| 279 |
+
if pixel_values.dtype != torch.float16:
|
| 280 |
+
raise ValueError(f"Expected tensor dtype torch.float16, got {pixel_values.dtype}.")
|
| 281 |
+
if not torch.isfinite(pixel_values).all():
|
| 282 |
+
raise ValueError("Tensor contains non-finite values.")
|
| 283 |
+
min_value = float(pixel_values.min().item())
|
| 284 |
+
max_value = float(pixel_values.max().item())
|
| 285 |
+
if min_value < -1.01 or max_value > 1.01:
|
| 286 |
+
raise ValueError(f"Expected tensor values in [-1, 1]. Got min={min_value}, max={max_value}.")
|
preprocessor_config.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"canonicalize_orientation": true,
|
| 3 |
+
"clip_percentiles": [
|
| 4 |
+
0.5,
|
| 5 |
+
99.5
|
| 6 |
+
],
|
| 7 |
+
"crop_margin": 4,
|
| 8 |
+
"do_clip": true,
|
| 9 |
+
"do_crop_foreground": true,
|
| 10 |
+
"do_normalize": true,
|
| 11 |
+
"effective_pad_value": -1.0,
|
| 12 |
+
"foreground_threshold": 0.001,
|
| 13 |
+
"image_processor_type": "BrainMRISiglipVolumeProcessor",
|
| 14 |
+
"interpolation_mode": "trilinear",
|
| 15 |
+
"max_channel_dim": 4,
|
| 16 |
+
"output_range": [
|
| 17 |
+
-1.0,
|
| 18 |
+
1.0
|
| 19 |
+
],
|
| 20 |
+
"pad_value": null,
|
| 21 |
+
"path_background_value": -1.0,
|
| 22 |
+
"path_crop_margin_mm": 5.0,
|
| 23 |
+
"path_foreground_strategy": "largest_component_nonzero",
|
| 24 |
+
"path_foreground_threshold": 0.001,
|
| 25 |
+
"path_generic_cache_version": 1,
|
| 26 |
+
"path_generic_recipe_id": "generic_foreground_128x192x192_fp16_v1",
|
| 27 |
+
"path_recipe_mode": "auto",
|
| 28 |
+
"path_target_shape": [
|
| 29 |
+
128,
|
| 30 |
+
192,
|
| 31 |
+
192
|
| 32 |
+
],
|
| 33 |
+
"path_target_spacing": [
|
| 34 |
+
1.25,
|
| 35 |
+
1.0,
|
| 36 |
+
1.0
|
| 37 |
+
],
|
| 38 |
+
"prefer_nibabel_resample": false,
|
| 39 |
+
"resize_strategy": "pad_or_crop",
|
| 40 |
+
"spacing": [
|
| 41 |
+
1.25,
|
| 42 |
+
1.0,
|
| 43 |
+
1.0
|
| 44 |
+
],
|
| 45 |
+
"spacing_tolerance": 0.001,
|
| 46 |
+
"use_foreground_intensity_stats": true,
|
| 47 |
+
"volume_size": [
|
| 48 |
+
128,
|
| 49 |
+
192,
|
| 50 |
+
192
|
| 51 |
+
]
|
| 52 |
+
}
|
processing_brain_mri_siglip.py
ADDED
|
@@ -0,0 +1,680 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Processor code for Brain MRI SigLIP."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from transformers import AutoTokenizer
|
| 14 |
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| 15 |
+
from transformers.processing_utils import ProcessorMixin
|
| 16 |
+
from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
|
| 17 |
+
|
| 18 |
+
from .common import copy_remote_code_files, to_3tuple
|
| 19 |
+
from .offline_aligned_preprocessing import (
|
| 20 |
+
BACKGROUND_VALUE as DEFAULT_PATH_BACKGROUND_VALUE,
|
| 21 |
+
CROP_MARGIN_MM as DEFAULT_PATH_CROP_MARGIN_MM,
|
| 22 |
+
FOREGROUND_STRATEGY as DEFAULT_PATH_FOREGROUND_STRATEGY,
|
| 23 |
+
FOREGROUND_THRESHOLD as DEFAULT_PATH_FOREGROUND_THRESHOLD,
|
| 24 |
+
GENERIC_CACHE_VERSION,
|
| 25 |
+
GENERIC_RECIPE_ID,
|
| 26 |
+
TARGET_SHAPE as DEFAULT_PATH_TARGET_SHAPE,
|
| 27 |
+
TARGET_SPACING as DEFAULT_PATH_TARGET_SPACING,
|
| 28 |
+
preprocess_image_with_foreground_mask,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
from scripts.fomo_300k_offline_pt.common import (
|
| 33 |
+
is_fomo_300k_path,
|
| 34 |
+
preprocess_fomo_300k_image,
|
| 35 |
+
)
|
| 36 |
+
except Exception: # pragma: no cover - optional import surface
|
| 37 |
+
is_fomo_300k_path = None
|
| 38 |
+
preprocess_fomo_300k_image = None
|
| 39 |
+
|
| 40 |
+
try:
|
| 41 |
+
from scripts.mr_rate_offline_pt.common import (
|
| 42 |
+
is_mr_rate_path,
|
| 43 |
+
preprocess_mr_rate_image,
|
| 44 |
+
)
|
| 45 |
+
except Exception: # pragma: no cover - optional import surface
|
| 46 |
+
is_mr_rate_path = None
|
| 47 |
+
preprocess_mr_rate_image = None
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
import nibabel as nib
|
| 51 |
+
try:
|
| 52 |
+
from nibabel import processing as nib_processing
|
| 53 |
+
except Exception: # pragma: no cover - optional import
|
| 54 |
+
nib_processing = None
|
| 55 |
+
except Exception: # pragma: no cover - optional import
|
| 56 |
+
nib = None
|
| 57 |
+
nib_processing = None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
VolumeInput = Union[str, Path, np.ndarray, torch.Tensor]
|
| 61 |
+
SpacingInput = Optional[Union[Sequence[float], Sequence[Sequence[float]]]]
|
| 62 |
+
LOGGER = logging.getLogger(__name__)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _ensure_list(values: Union[VolumeInput, Sequence[VolumeInput]]) -> List[VolumeInput]:
|
| 66 |
+
if isinstance(values, (str, Path, np.ndarray, torch.Tensor)):
|
| 67 |
+
return [values]
|
| 68 |
+
return list(values)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _normalize_spacing_value(value: Optional[Sequence[float]], field_name: str) -> Optional[Tuple[float, float, float]]:
|
| 72 |
+
if value is None:
|
| 73 |
+
return None
|
| 74 |
+
if len(value) != 3:
|
| 75 |
+
raise ValueError(f"`{field_name}` must be a length-3 sequence. Got: {value}")
|
| 76 |
+
return (float(value[0]), float(value[1]), float(value[2]))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _ensure_spacing_list(
|
| 80 |
+
source_spacings: SpacingInput,
|
| 81 |
+
batch_size: int,
|
| 82 |
+
) -> List[Optional[Tuple[float, float, float]]]:
|
| 83 |
+
if source_spacings is None:
|
| 84 |
+
return [None] * batch_size
|
| 85 |
+
if batch_size == 1 and isinstance(source_spacings, Sequence) and len(source_spacings) == 3 and not isinstance(
|
| 86 |
+
source_spacings[0], (list, tuple)
|
| 87 |
+
):
|
| 88 |
+
return [_normalize_spacing_value(source_spacings, "source_spacing")]
|
| 89 |
+
values = list(source_spacings)
|
| 90 |
+
if len(values) != batch_size:
|
| 91 |
+
raise ValueError(
|
| 92 |
+
f"`source_spacings` must have length {batch_size} to match the input batch. Got {len(values)}."
|
| 93 |
+
)
|
| 94 |
+
return [_normalize_spacing_value(value, "source_spacing") for value in values]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def _normalize_shape_value(
|
| 98 |
+
value: Sequence[int],
|
| 99 |
+
field_name: str,
|
| 100 |
+
) -> Tuple[int, int, int]:
|
| 101 |
+
normalized = to_3tuple(value, field_name)
|
| 102 |
+
return (int(normalized[0]), int(normalized[1]), int(normalized[2]))
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class BrainMRISiglipVolumeProcessor(BaseImageProcessor):
|
| 106 |
+
"""Image processor for 3D brain MRI volumes."""
|
| 107 |
+
|
| 108 |
+
model_input_names = ["pixel_values"]
|
| 109 |
+
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
volume_size: Union[int, Sequence[int]] = (128, 192, 192),
|
| 113 |
+
clip_percentiles: Tuple[float, float] = (0.5, 99.5),
|
| 114 |
+
output_range: Tuple[float, float] = (-1.0, 1.0),
|
| 115 |
+
do_clip: bool = True,
|
| 116 |
+
do_normalize: bool = True,
|
| 117 |
+
interpolation_mode: str = "trilinear",
|
| 118 |
+
max_channel_dim: int = 4,
|
| 119 |
+
canonicalize_orientation: bool = True,
|
| 120 |
+
spacing: Optional[Sequence[float]] = None,
|
| 121 |
+
spacing_tolerance: float = 1e-3,
|
| 122 |
+
prefer_nibabel_resample: bool = True,
|
| 123 |
+
use_foreground_intensity_stats: bool = True,
|
| 124 |
+
do_crop_foreground: bool = True,
|
| 125 |
+
foreground_threshold: float = 1e-3,
|
| 126 |
+
crop_margin: int = 4,
|
| 127 |
+
resize_strategy: str = "pad_or_crop",
|
| 128 |
+
pad_value: Optional[float] = None,
|
| 129 |
+
path_recipe_mode: str = "auto",
|
| 130 |
+
path_target_shape: Union[int, Sequence[int]] = DEFAULT_PATH_TARGET_SHAPE,
|
| 131 |
+
path_target_spacing: Optional[Sequence[float]] = DEFAULT_PATH_TARGET_SPACING,
|
| 132 |
+
path_crop_margin_mm: float = DEFAULT_PATH_CROP_MARGIN_MM,
|
| 133 |
+
path_foreground_threshold: float = DEFAULT_PATH_FOREGROUND_THRESHOLD,
|
| 134 |
+
path_background_value: float = DEFAULT_PATH_BACKGROUND_VALUE,
|
| 135 |
+
path_foreground_strategy: str = DEFAULT_PATH_FOREGROUND_STRATEGY,
|
| 136 |
+
path_generic_recipe_id: str = GENERIC_RECIPE_ID,
|
| 137 |
+
path_generic_cache_version: int = GENERIC_CACHE_VERSION,
|
| 138 |
+
**kwargs: Any,
|
| 139 |
+
) -> None:
|
| 140 |
+
super().__init__(**kwargs)
|
| 141 |
+
self.volume_size = list(to_3tuple(volume_size, "volume_size"))
|
| 142 |
+
self.clip_percentiles = (float(clip_percentiles[0]), float(clip_percentiles[1]))
|
| 143 |
+
self.output_range = (float(output_range[0]), float(output_range[1]))
|
| 144 |
+
self.do_clip = bool(do_clip)
|
| 145 |
+
self.do_normalize = bool(do_normalize)
|
| 146 |
+
self.interpolation_mode = str(interpolation_mode)
|
| 147 |
+
self.max_channel_dim = int(max_channel_dim)
|
| 148 |
+
self.canonicalize_orientation = bool(canonicalize_orientation)
|
| 149 |
+
self.spacing = list(_normalize_spacing_value(spacing, "spacing")) if spacing is not None else None
|
| 150 |
+
self.spacing_tolerance = float(spacing_tolerance)
|
| 151 |
+
self.prefer_nibabel_resample = bool(prefer_nibabel_resample)
|
| 152 |
+
self.use_foreground_intensity_stats = bool(use_foreground_intensity_stats)
|
| 153 |
+
self.do_crop_foreground = bool(do_crop_foreground)
|
| 154 |
+
self.foreground_threshold = float(foreground_threshold)
|
| 155 |
+
self.crop_margin = int(crop_margin)
|
| 156 |
+
self.resize_strategy = str(resize_strategy)
|
| 157 |
+
self.pad_value = None if pad_value is None else float(pad_value)
|
| 158 |
+
self.path_recipe_mode = str(path_recipe_mode)
|
| 159 |
+
self.path_target_shape = list(_normalize_shape_value(path_target_shape, "path_target_shape"))
|
| 160 |
+
self.path_target_spacing = (
|
| 161 |
+
list(_normalize_spacing_value(path_target_spacing, "path_target_spacing"))
|
| 162 |
+
if path_target_spacing is not None
|
| 163 |
+
else None
|
| 164 |
+
)
|
| 165 |
+
self.path_crop_margin_mm = float(path_crop_margin_mm)
|
| 166 |
+
self.path_foreground_threshold = float(path_foreground_threshold)
|
| 167 |
+
self.path_background_value = float(path_background_value)
|
| 168 |
+
self.path_foreground_strategy = str(path_foreground_strategy)
|
| 169 |
+
self.path_generic_recipe_id = str(path_generic_recipe_id)
|
| 170 |
+
self.path_generic_cache_version = int(path_generic_cache_version)
|
| 171 |
+
self.effective_pad_value = self._resolve_pad_value()
|
| 172 |
+
if self.max_channel_dim <= 0:
|
| 173 |
+
raise ValueError(f"`max_channel_dim` must be > 0. Got {self.max_channel_dim}.")
|
| 174 |
+
if not (0.0 <= self.clip_percentiles[0] < self.clip_percentiles[1] <= 100.0):
|
| 175 |
+
raise ValueError(
|
| 176 |
+
"`clip_percentiles` must satisfy 0 <= low < high <= 100. "
|
| 177 |
+
f"Got {self.clip_percentiles}."
|
| 178 |
+
)
|
| 179 |
+
if self.resize_strategy not in {"pad_or_crop", "interpolate"}:
|
| 180 |
+
raise ValueError(
|
| 181 |
+
"`resize_strategy` must be one of: pad_or_crop, interpolate. "
|
| 182 |
+
f"Got {self.resize_strategy!r}."
|
| 183 |
+
)
|
| 184 |
+
if self.path_recipe_mode not in {"auto", "legacy"}:
|
| 185 |
+
raise ValueError(
|
| 186 |
+
"`path_recipe_mode` must be one of: auto, legacy. "
|
| 187 |
+
f"Got {self.path_recipe_mode!r}."
|
| 188 |
+
)
|
| 189 |
+
if self.path_crop_margin_mm < 0:
|
| 190 |
+
raise ValueError(f"`path_crop_margin_mm` must be >= 0. Got {self.path_crop_margin_mm}.")
|
| 191 |
+
if self.path_foreground_threshold < 0:
|
| 192 |
+
raise ValueError(
|
| 193 |
+
f"`path_foreground_threshold` must be >= 0. Got {self.path_foreground_threshold}."
|
| 194 |
+
)
|
| 195 |
+
if self.spacing_tolerance < 0:
|
| 196 |
+
raise ValueError(f"`spacing_tolerance` must be >= 0. Got {self.spacing_tolerance}.")
|
| 197 |
+
|
| 198 |
+
def get_path_recipe_config(self) -> Dict[str, Any]:
|
| 199 |
+
return {
|
| 200 |
+
"path_recipe_mode": self.path_recipe_mode,
|
| 201 |
+
"path_target_shape": list(self.path_target_shape),
|
| 202 |
+
"path_target_spacing": None if self.path_target_spacing is None else list(self.path_target_spacing),
|
| 203 |
+
"path_crop_margin_mm": self.path_crop_margin_mm,
|
| 204 |
+
"path_foreground_threshold": self.path_foreground_threshold,
|
| 205 |
+
"path_background_value": self.path_background_value,
|
| 206 |
+
"path_foreground_strategy": self.path_foreground_strategy,
|
| 207 |
+
"path_generic_recipe_id": self.path_generic_recipe_id,
|
| 208 |
+
"path_generic_cache_version": self.path_generic_cache_version,
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
def _target_spacing(self) -> Optional[Tuple[float, float, float]]:
|
| 212 |
+
if self.spacing is None:
|
| 213 |
+
return None
|
| 214 |
+
return tuple(float(item) for item in self.spacing)
|
| 215 |
+
|
| 216 |
+
def _resolve_pad_value(self) -> float:
|
| 217 |
+
if self.pad_value is not None:
|
| 218 |
+
return float(self.pad_value)
|
| 219 |
+
if self.do_normalize:
|
| 220 |
+
return float(self.output_range[0])
|
| 221 |
+
return 0.0
|
| 222 |
+
|
| 223 |
+
def _spacing_matches(
|
| 224 |
+
self,
|
| 225 |
+
source_spacing: Optional[Tuple[float, float, float]],
|
| 226 |
+
target_spacing: Optional[Tuple[float, float, float]],
|
| 227 |
+
) -> bool:
|
| 228 |
+
if source_spacing is None or target_spacing is None:
|
| 229 |
+
return False
|
| 230 |
+
return all(abs(src - dst) <= self.spacing_tolerance for src, dst in zip(source_spacing, target_spacing))
|
| 231 |
+
|
| 232 |
+
def _nibabel_resample_order(self) -> int:
|
| 233 |
+
if self.interpolation_mode == "nearest":
|
| 234 |
+
return 0
|
| 235 |
+
return 1
|
| 236 |
+
|
| 237 |
+
def _resample_nifti_image(
|
| 238 |
+
self,
|
| 239 |
+
image,
|
| 240 |
+
source_spacing: Optional[Tuple[float, float, float]],
|
| 241 |
+
) -> tuple[Any, Optional[Tuple[float, float, float]], bool]:
|
| 242 |
+
if not self.prefer_nibabel_resample or nib_processing is None:
|
| 243 |
+
return image, source_spacing, False
|
| 244 |
+
|
| 245 |
+
target_spacing = self._target_spacing()
|
| 246 |
+
if target_spacing is None or self._spacing_matches(source_spacing, target_spacing):
|
| 247 |
+
return image, source_spacing, False
|
| 248 |
+
|
| 249 |
+
resampled = nib_processing.resample_to_output(
|
| 250 |
+
image,
|
| 251 |
+
voxel_sizes=target_spacing,
|
| 252 |
+
order=self._nibabel_resample_order(),
|
| 253 |
+
)
|
| 254 |
+
return resampled, target_spacing, True
|
| 255 |
+
|
| 256 |
+
def _load_volume(
|
| 257 |
+
self,
|
| 258 |
+
value: VolumeInput,
|
| 259 |
+
source_spacing: Optional[Tuple[float, float, float]] = None,
|
| 260 |
+
) -> tuple[np.ndarray, Optional[Tuple[float, float, float]], bool]:
|
| 261 |
+
if isinstance(value, (str, Path)):
|
| 262 |
+
if nib is None:
|
| 263 |
+
raise ImportError("`nibabel` is required to load NIfTI paths.")
|
| 264 |
+
image = nib.load(str(value))
|
| 265 |
+
if self.canonicalize_orientation:
|
| 266 |
+
image = nib.as_closest_canonical(image)
|
| 267 |
+
image_spacing = image.header.get_zooms()[:3]
|
| 268 |
+
resolved_spacing = None
|
| 269 |
+
if len(image_spacing) == 3:
|
| 270 |
+
resolved_spacing = tuple(float(item) for item in image_spacing)
|
| 271 |
+
image, resolved_spacing, used_nibabel_resample = self._resample_nifti_image(image, resolved_spacing)
|
| 272 |
+
return (
|
| 273 |
+
np.asarray(image.get_fdata(dtype=np.float32), dtype=np.float32),
|
| 274 |
+
resolved_spacing,
|
| 275 |
+
used_nibabel_resample,
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
if isinstance(value, torch.Tensor):
|
| 279 |
+
return value.detach().cpu().numpy().astype(np.float32, copy=False), source_spacing, False
|
| 280 |
+
|
| 281 |
+
if isinstance(value, np.ndarray):
|
| 282 |
+
return value.astype(np.float32, copy=False), source_spacing, False
|
| 283 |
+
|
| 284 |
+
raise TypeError(f"Unsupported volume input type: {type(value).__name__}")
|
| 285 |
+
|
| 286 |
+
def _preprocess_with_offline_recipe(self, value: VolumeInput) -> Optional[np.ndarray]:
|
| 287 |
+
if self.path_recipe_mode != "auto" or not isinstance(value, (str, Path)):
|
| 288 |
+
return None
|
| 289 |
+
|
| 290 |
+
image_path = str(value)
|
| 291 |
+
try:
|
| 292 |
+
if is_mr_rate_path is not None and preprocess_mr_rate_image is not None and is_mr_rate_path(image_path):
|
| 293 |
+
payload = preprocess_mr_rate_image(image_path)
|
| 294 |
+
return payload["pixel_values"].detach().cpu().numpy().astype(np.float32, copy=False)
|
| 295 |
+
if is_fomo_300k_path is not None and preprocess_fomo_300k_image is not None and is_fomo_300k_path(image_path):
|
| 296 |
+
payload = preprocess_fomo_300k_image(image_path)
|
| 297 |
+
return payload["pixel_values"].detach().cpu().numpy().astype(np.float32, copy=False)
|
| 298 |
+
payload = preprocess_image_with_foreground_mask(
|
| 299 |
+
image_path,
|
| 300 |
+
target_shape=tuple(int(value) for value in self.path_target_shape),
|
| 301 |
+
target_spacing=None
|
| 302 |
+
if self.path_target_spacing is None
|
| 303 |
+
else tuple(float(value) for value in self.path_target_spacing),
|
| 304 |
+
crop_margin_mm=self.path_crop_margin_mm,
|
| 305 |
+
foreground_threshold=self.path_foreground_threshold,
|
| 306 |
+
background_value=self.path_background_value,
|
| 307 |
+
foreground_strategy=self.path_foreground_strategy,
|
| 308 |
+
recipe_id=self.path_generic_recipe_id,
|
| 309 |
+
cache_version=self.path_generic_cache_version,
|
| 310 |
+
)
|
| 311 |
+
return payload["pixel_values"].detach().cpu().numpy().astype(np.float32, copy=False)
|
| 312 |
+
except Exception as exc:
|
| 313 |
+
LOGGER.warning(
|
| 314 |
+
"Falling back to legacy online preprocessing for %s after offline-recipe path failed: %s",
|
| 315 |
+
image_path,
|
| 316 |
+
exc,
|
| 317 |
+
)
|
| 318 |
+
return None
|
| 319 |
+
|
| 320 |
+
def _ensure_channel_first(self, volume: np.ndarray) -> np.ndarray:
|
| 321 |
+
if volume.ndim == 3:
|
| 322 |
+
return volume[None, ...]
|
| 323 |
+
if volume.ndim != 4:
|
| 324 |
+
raise ValueError(
|
| 325 |
+
"Volume must be 3D or 4D. For 4D volume, expected channel-first `[C, D, H, W]` "
|
| 326 |
+
"or channel-last `[D, H, W, C]`."
|
| 327 |
+
)
|
| 328 |
+
if volume.shape[0] <= self.max_channel_dim:
|
| 329 |
+
return volume
|
| 330 |
+
if volume.shape[-1] <= self.max_channel_dim:
|
| 331 |
+
return np.moveaxis(volume, -1, 0)
|
| 332 |
+
raise ValueError(
|
| 333 |
+
f"Cannot infer channel dimension for shape {volume.shape}. Expected channel dim <= {self.max_channel_dim}. "
|
| 334 |
+
"Please provide volume in [C, D, H, W] or [D, H, W, C] format."
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
def _foreground_mask(self, volume: np.ndarray) -> np.ndarray:
|
| 338 |
+
threshold = abs(self.foreground_threshold)
|
| 339 |
+
if volume.ndim == 4:
|
| 340 |
+
return np.any(np.abs(volume) > threshold, axis=0)
|
| 341 |
+
return np.abs(volume) > threshold
|
| 342 |
+
|
| 343 |
+
def _intensity_stats_values(self, volume: np.ndarray) -> np.ndarray:
|
| 344 |
+
if not self.use_foreground_intensity_stats:
|
| 345 |
+
return volume.reshape(-1)
|
| 346 |
+
|
| 347 |
+
mask = self._foreground_mask(volume)
|
| 348 |
+
if not mask.any():
|
| 349 |
+
return volume.reshape(-1)
|
| 350 |
+
if volume.ndim == 4:
|
| 351 |
+
return volume[:, mask].reshape(-1)
|
| 352 |
+
return volume[mask].reshape(-1)
|
| 353 |
+
|
| 354 |
+
def _clip_and_normalize(self, volume: np.ndarray) -> np.ndarray:
|
| 355 |
+
output = volume
|
| 356 |
+
if self.do_clip or self.do_normalize:
|
| 357 |
+
# Sanitize before percentile so NaN/inf don't corrupt the result.
|
| 358 |
+
output = np.nan_to_num(output, nan=0.0, posinf=0.0, neginf=0.0)
|
| 359 |
+
stats_values = self._intensity_stats_values(output)
|
| 360 |
+
if self.do_clip:
|
| 361 |
+
flat = stats_values
|
| 362 |
+
if flat.size > 1_000_000:
|
| 363 |
+
# Deterministic stride-based subsample for speed.
|
| 364 |
+
step = max(1, flat.size // 1_000_000)
|
| 365 |
+
flat = flat[::step]
|
| 366 |
+
low, high = np.percentile(flat, self.clip_percentiles)
|
| 367 |
+
else:
|
| 368 |
+
low, high = float(stats_values.min()), float(stats_values.max())
|
| 369 |
+
|
| 370 |
+
if np.isfinite(low) and np.isfinite(high) and high > low:
|
| 371 |
+
if self.do_clip:
|
| 372 |
+
output = np.clip(output, low, high)
|
| 373 |
+
if self.do_normalize:
|
| 374 |
+
out_low, out_high = self.output_range
|
| 375 |
+
output = np.clip((output - low) / (high - low), 0.0, 1.0)
|
| 376 |
+
output = output * (out_high - out_low) + out_low
|
| 377 |
+
elif self.do_normalize:
|
| 378 |
+
output = np.zeros_like(output, dtype=np.float32)
|
| 379 |
+
return output.astype(np.float32, copy=False)
|
| 380 |
+
|
| 381 |
+
def _resample_spacing(
|
| 382 |
+
self,
|
| 383 |
+
volume: np.ndarray,
|
| 384 |
+
source_spacing: Optional[Tuple[float, float, float]],
|
| 385 |
+
affine: Optional[np.ndarray] = None,
|
| 386 |
+
) -> np.ndarray:
|
| 387 |
+
if self.spacing is None or source_spacing is None:
|
| 388 |
+
return volume
|
| 389 |
+
|
| 390 |
+
target_spacing = self._target_spacing()
|
| 391 |
+
if self._spacing_matches(source_spacing, target_spacing):
|
| 392 |
+
return volume
|
| 393 |
+
|
| 394 |
+
target_shape = []
|
| 395 |
+
for current_size, src, dst in zip(volume.shape[1:], source_spacing, target_spacing):
|
| 396 |
+
target_shape.append(max(1, int(round(float(current_size) * float(src) / float(dst)))))
|
| 397 |
+
|
| 398 |
+
if tuple(target_shape) == tuple(int(dim) for dim in volume.shape[1:]):
|
| 399 |
+
return volume
|
| 400 |
+
|
| 401 |
+
tensor = torch.from_numpy(volume).unsqueeze(0)
|
| 402 |
+
tensor = F.interpolate(
|
| 403 |
+
tensor,
|
| 404 |
+
size=tuple(target_shape),
|
| 405 |
+
mode=self.interpolation_mode,
|
| 406 |
+
align_corners=False if self.interpolation_mode in {"linear", "bilinear", "bicubic", "trilinear"} else None,
|
| 407 |
+
)
|
| 408 |
+
return tensor.squeeze(0).numpy().astype(np.float32, copy=False)
|
| 409 |
+
|
| 410 |
+
# def _crop_foreground(self, volume: np.ndarray) -> np.ndarray:
|
| 411 |
+
# if not self.do_crop_foreground:
|
| 412 |
+
# return volume
|
| 413 |
+
|
| 414 |
+
# # Per-axis projection avoids the massive temporary arrays from np.where.
|
| 415 |
+
# src = volume[0] if volume.ndim == 4 else volume
|
| 416 |
+
# mask = src > self.foreground_threshold
|
| 417 |
+
# if not mask.any():
|
| 418 |
+
# return volume
|
| 419 |
+
|
| 420 |
+
# slices = []
|
| 421 |
+
# for dim in range(mask.ndim):
|
| 422 |
+
# proj = mask.any(axis=tuple(d for d in range(mask.ndim) if d != dim))
|
| 423 |
+
# lo = int(np.argmax(proj))
|
| 424 |
+
# hi = len(proj) - 1 - int(np.argmax(proj[::-1]))
|
| 425 |
+
# slices.append(slice(lo, hi + 1))
|
| 426 |
+
|
| 427 |
+
# return volume[(slice(None),) + tuple(slices)].astype(np.float32, copy=False)
|
| 428 |
+
def _crop_foreground(self, volume: np.ndarray) -> np.ndarray:
|
| 429 |
+
if not self.do_crop_foreground:
|
| 430 |
+
return volume
|
| 431 |
+
|
| 432 |
+
margin = self.crop_margin
|
| 433 |
+
src = self._foreground_mask(volume)
|
| 434 |
+
|
| 435 |
+
if not src.any():
|
| 436 |
+
return volume
|
| 437 |
+
|
| 438 |
+
slices = []
|
| 439 |
+
for dim in range(src.ndim):
|
| 440 |
+
proj = src.any(axis=tuple(d for d in range(src.ndim) if d != dim))
|
| 441 |
+
lo = int(np.argmax(proj))
|
| 442 |
+
hi = len(proj) - 1 - int(np.argmax(proj[::-1]))
|
| 443 |
+
|
| 444 |
+
lo = max(0, lo - margin)
|
| 445 |
+
hi = min(src.shape[dim] - 1, hi + margin)
|
| 446 |
+
|
| 447 |
+
slices.append(slice(lo, hi + 1))
|
| 448 |
+
|
| 449 |
+
return volume[(slice(None),) + tuple(slices)].astype(np.float32, copy=False)
|
| 450 |
+
def _pad_or_crop_volume(self, volume: np.ndarray) -> np.ndarray:
|
| 451 |
+
target_size = tuple(int(v) for v in self.volume_size)
|
| 452 |
+
if volume.shape[1:] == target_size:
|
| 453 |
+
return volume
|
| 454 |
+
|
| 455 |
+
slices = [slice(None)]
|
| 456 |
+
for current, target in zip(volume.shape[1:], target_size):
|
| 457 |
+
if current > target:
|
| 458 |
+
start = max(0, (current - target) // 2)
|
| 459 |
+
slices.append(slice(start, start + target))
|
| 460 |
+
else:
|
| 461 |
+
slices.append(slice(0, current))
|
| 462 |
+
cropped = volume[tuple(slices)]
|
| 463 |
+
|
| 464 |
+
pad_width = [(0, 0)]
|
| 465 |
+
for current, target in zip(cropped.shape[1:], target_size):
|
| 466 |
+
if current < target:
|
| 467 |
+
delta = target - current
|
| 468 |
+
before = delta // 2
|
| 469 |
+
after = delta - before
|
| 470 |
+
pad_width.append((before, after))
|
| 471 |
+
else:
|
| 472 |
+
pad_width.append((0, 0))
|
| 473 |
+
if any(before != 0 or after != 0 for before, after in pad_width[1:]):
|
| 474 |
+
cropped = np.pad(
|
| 475 |
+
cropped,
|
| 476 |
+
pad_width=pad_width,
|
| 477 |
+
mode="constant",
|
| 478 |
+
constant_values=self.effective_pad_value,
|
| 479 |
+
)
|
| 480 |
+
return cropped.astype(np.float32, copy=False)
|
| 481 |
+
|
| 482 |
+
def _resize_volume(self, volume: np.ndarray) -> np.ndarray:
|
| 483 |
+
target_size = tuple(int(v) for v in self.volume_size)
|
| 484 |
+
if volume.shape[1:] == target_size:
|
| 485 |
+
return volume
|
| 486 |
+
if self.resize_strategy == "pad_or_crop":
|
| 487 |
+
return self._pad_or_crop_volume(volume)
|
| 488 |
+
|
| 489 |
+
tensor = torch.from_numpy(volume).unsqueeze(0)
|
| 490 |
+
tensor = F.interpolate(
|
| 491 |
+
tensor,
|
| 492 |
+
size=target_size,
|
| 493 |
+
mode=self.interpolation_mode,
|
| 494 |
+
align_corners=False if self.interpolation_mode in {"linear", "bilinear", "bicubic", "trilinear"} else None,
|
| 495 |
+
)
|
| 496 |
+
return tensor.squeeze(0).numpy().astype(np.float32, copy=False)
|
| 497 |
+
|
| 498 |
+
def preprocess(
|
| 499 |
+
self,
|
| 500 |
+
volumes: Union[VolumeInput, Sequence[VolumeInput]],
|
| 501 |
+
return_tensors: Optional[Union[str, bool]] = "pt",
|
| 502 |
+
source_spacings: SpacingInput = None,
|
| 503 |
+
**kwargs: Any,
|
| 504 |
+
) -> BatchFeature:
|
| 505 |
+
del kwargs
|
| 506 |
+
items = _ensure_list(volumes)
|
| 507 |
+
spacing_values = _ensure_spacing_list(source_spacings, len(items))
|
| 508 |
+
batch = []
|
| 509 |
+
for item, source_spacing in zip(items, spacing_values):
|
| 510 |
+
recipe_aligned = self._preprocess_with_offline_recipe(item)
|
| 511 |
+
if recipe_aligned is not None:
|
| 512 |
+
batch.append(torch.from_numpy(recipe_aligned))
|
| 513 |
+
continue
|
| 514 |
+
volume, loaded_spacing, used_nibabel_resample = self._load_volume(item, source_spacing=source_spacing)
|
| 515 |
+
volume = self._ensure_channel_first(volume)
|
| 516 |
+
if not used_nibabel_resample:
|
| 517 |
+
volume = self._resample_spacing(volume, source_spacing=loaded_spacing)
|
| 518 |
+
volume = self._crop_foreground(volume)
|
| 519 |
+
volume = self._clip_and_normalize(volume)
|
| 520 |
+
volume = self._resize_volume(volume)
|
| 521 |
+
batch.append(torch.from_numpy(volume))
|
| 522 |
+
|
| 523 |
+
pixel_values = torch.stack(batch, dim=0).to(dtype=torch.float32)
|
| 524 |
+
return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
|
| 525 |
+
|
| 526 |
+
def __call__(
|
| 527 |
+
self,
|
| 528 |
+
volumes: Union[VolumeInput, Sequence[VolumeInput]],
|
| 529 |
+
return_tensors: Optional[Union[str, bool]] = "pt",
|
| 530 |
+
**kwargs: Any,
|
| 531 |
+
) -> BatchFeature:
|
| 532 |
+
return self.preprocess(volumes=volumes, return_tensors=return_tensors, **kwargs)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
class BrainMRISiglipProcessor(ProcessorMixin):
|
| 536 |
+
"""Processor wrapping MRI volume processor + tokenizer."""
|
| 537 |
+
|
| 538 |
+
attributes = ["image_processor", "tokenizer"]
|
| 539 |
+
image_processor_class = "BaseImageProcessor"
|
| 540 |
+
tokenizer_class = "AutoTokenizer"
|
| 541 |
+
|
| 542 |
+
def __init__(self, image_processor: BrainMRISiglipVolumeProcessor, tokenizer) -> None:
|
| 543 |
+
super().__init__(image_processor=image_processor, tokenizer=tokenizer)
|
| 544 |
+
|
| 545 |
+
@classmethod
|
| 546 |
+
def from_text_pretrained(
|
| 547 |
+
cls,
|
| 548 |
+
text_model_name_or_path: str = "google/medsiglip-448",
|
| 549 |
+
volume_size: Union[int, Sequence[int]] = (128, 192, 192),
|
| 550 |
+
local_files_only: bool = False,
|
| 551 |
+
trust_remote_code: bool = True,
|
| 552 |
+
**kwargs: Any,
|
| 553 |
+
) -> "BrainMRISiglipProcessor":
|
| 554 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 555 |
+
text_model_name_or_path,
|
| 556 |
+
local_files_only=local_files_only,
|
| 557 |
+
trust_remote_code=trust_remote_code,
|
| 558 |
+
)
|
| 559 |
+
image_processor = BrainMRISiglipVolumeProcessor(volume_size=volume_size, **kwargs)
|
| 560 |
+
return cls(image_processor=image_processor, tokenizer=tokenizer)
|
| 561 |
+
|
| 562 |
+
@classmethod
|
| 563 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs: Any):
|
| 564 |
+
image_processor_kwargs = dict(kwargs.pop("image_processor_kwargs", {}) or {})
|
| 565 |
+
tokenizer_kwargs = dict(kwargs.pop("tokenizer_kwargs", {}) or {})
|
| 566 |
+
|
| 567 |
+
# Backward-compatible convenience: treat image-specific keys as image processor kwargs.
|
| 568 |
+
image_only_keys = {
|
| 569 |
+
"volume_size",
|
| 570 |
+
"clip_percentiles",
|
| 571 |
+
"output_range",
|
| 572 |
+
"do_clip",
|
| 573 |
+
"do_normalize",
|
| 574 |
+
"interpolation_mode",
|
| 575 |
+
"max_channel_dim",
|
| 576 |
+
"canonicalize_orientation",
|
| 577 |
+
"spacing",
|
| 578 |
+
"spacing_tolerance",
|
| 579 |
+
"prefer_nibabel_resample",
|
| 580 |
+
"use_foreground_intensity_stats",
|
| 581 |
+
"do_crop_foreground",
|
| 582 |
+
"foreground_threshold",
|
| 583 |
+
"crop_margin",
|
| 584 |
+
"resize_strategy",
|
| 585 |
+
"pad_value",
|
| 586 |
+
"path_recipe_mode",
|
| 587 |
+
"path_target_shape",
|
| 588 |
+
"path_target_spacing",
|
| 589 |
+
"path_crop_margin_mm",
|
| 590 |
+
"path_foreground_threshold",
|
| 591 |
+
"path_background_value",
|
| 592 |
+
"path_foreground_strategy",
|
| 593 |
+
"path_generic_recipe_id",
|
| 594 |
+
"path_generic_cache_version",
|
| 595 |
+
}
|
| 596 |
+
shared_kwargs = dict(kwargs)
|
| 597 |
+
for key in list(shared_kwargs.keys()):
|
| 598 |
+
if key in image_only_keys and key not in image_processor_kwargs:
|
| 599 |
+
image_processor_kwargs[key] = shared_kwargs.pop(key)
|
| 600 |
+
|
| 601 |
+
image_processor = BrainMRISiglipVolumeProcessor.from_pretrained(
|
| 602 |
+
pretrained_model_name_or_path,
|
| 603 |
+
**shared_kwargs,
|
| 604 |
+
**image_processor_kwargs,
|
| 605 |
+
)
|
| 606 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 607 |
+
pretrained_model_name_or_path,
|
| 608 |
+
**shared_kwargs,
|
| 609 |
+
**tokenizer_kwargs,
|
| 610 |
+
)
|
| 611 |
+
return cls(image_processor=image_processor, tokenizer=tokenizer)
|
| 612 |
+
|
| 613 |
+
def save_pretrained(self, save_directory: Union[str, Path], **kwargs: Any) -> tuple[str]:
|
| 614 |
+
save_path = Path(save_directory)
|
| 615 |
+
save_path.mkdir(parents=True, exist_ok=True)
|
| 616 |
+
self.image_processor.save_pretrained(str(save_path), **kwargs)
|
| 617 |
+
self.tokenizer.save_pretrained(str(save_path), **kwargs)
|
| 618 |
+
processor_config = {
|
| 619 |
+
"processor_class": self.__class__.__name__,
|
| 620 |
+
"auto_map": {"AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor"},
|
| 621 |
+
"offline_aligned_preprocessing": self.image_processor.get_path_recipe_config(),
|
| 622 |
+
}
|
| 623 |
+
(save_path / "processor_config.json").write_text(json.dumps(processor_config, indent=2), encoding="utf-8")
|
| 624 |
+
copy_remote_code_files(save_path)
|
| 625 |
+
return (str(save_path),)
|
| 626 |
+
|
| 627 |
+
@property
|
| 628 |
+
def model_input_names(self) -> List[str]:
|
| 629 |
+
names = list(self.tokenizer.model_input_names)
|
| 630 |
+
for item in self.image_processor.model_input_names:
|
| 631 |
+
if item not in names:
|
| 632 |
+
names.append(item)
|
| 633 |
+
return names
|
| 634 |
+
|
| 635 |
+
def __call__(
|
| 636 |
+
self,
|
| 637 |
+
text: Optional[Union[TextInput, PreTokenizedInput, Sequence[TextInput], Sequence[PreTokenizedInput]]] = None,
|
| 638 |
+
volumes: Optional[Union[VolumeInput, Sequence[VolumeInput]]] = None,
|
| 639 |
+
padding: Union[bool, str, PaddingStrategy] = "max_length",
|
| 640 |
+
truncation: Union[bool, str, TruncationStrategy] = True,
|
| 641 |
+
max_length: Optional[int] = None,
|
| 642 |
+
return_tensors: Optional[Union[str, bool]] = "pt",
|
| 643 |
+
**kwargs: Any,
|
| 644 |
+
) -> BatchFeature:
|
| 645 |
+
if text is None and volumes is None:
|
| 646 |
+
raise ValueError("At least one of `text` or `volumes` must be provided.")
|
| 647 |
+
|
| 648 |
+
image_processor_kwargs = dict(kwargs.pop("image_processor_kwargs", {}) or {})
|
| 649 |
+
image_only_keys = {"source_spacings"}
|
| 650 |
+
for key in list(kwargs.keys()):
|
| 651 |
+
if key in image_only_keys and key not in image_processor_kwargs:
|
| 652 |
+
image_processor_kwargs[key] = kwargs.pop(key)
|
| 653 |
+
|
| 654 |
+
data: Dict[str, Any] = {}
|
| 655 |
+
if text is not None:
|
| 656 |
+
text_inputs = self.tokenizer(
|
| 657 |
+
text,
|
| 658 |
+
padding=padding,
|
| 659 |
+
truncation=truncation,
|
| 660 |
+
max_length=max_length,
|
| 661 |
+
return_tensors=return_tensors,
|
| 662 |
+
**kwargs,
|
| 663 |
+
)
|
| 664 |
+
data.update(dict(text_inputs))
|
| 665 |
+
|
| 666 |
+
if volumes is not None:
|
| 667 |
+
image_inputs = self.image_processor(
|
| 668 |
+
volumes=volumes,
|
| 669 |
+
return_tensors=return_tensors,
|
| 670 |
+
**image_processor_kwargs,
|
| 671 |
+
)
|
| 672 |
+
data.update(dict(image_inputs))
|
| 673 |
+
|
| 674 |
+
return BatchFeature(data=data, tensor_type=return_tensors)
|
| 675 |
+
|
| 676 |
+
def batch_decode(self, *args: Any, **kwargs: Any):
|
| 677 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 678 |
+
|
| 679 |
+
def decode(self, *args: Any, **kwargs: Any):
|
| 680 |
+
return self.tokenizer.decode(*args, **kwargs)
|
processor_config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"processor_class": "BrainMRISiglipProcessor",
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoProcessor": "processing_brain_mri_siglip.BrainMRISiglipProcessor"
|
| 5 |
+
},
|
| 6 |
+
"offline_aligned_preprocessing": {
|
| 7 |
+
"path_recipe_mode": "auto",
|
| 8 |
+
"path_target_shape": [
|
| 9 |
+
128,
|
| 10 |
+
192,
|
| 11 |
+
192
|
| 12 |
+
],
|
| 13 |
+
"path_target_spacing": [
|
| 14 |
+
1.25,
|
| 15 |
+
1.0,
|
| 16 |
+
1.0
|
| 17 |
+
],
|
| 18 |
+
"path_crop_margin_mm": 5.0,
|
| 19 |
+
"path_foreground_threshold": 0.001,
|
| 20 |
+
"path_background_value": -1.0,
|
| 21 |
+
"path_foreground_strategy": "largest_component_nonzero",
|
| 22 |
+
"path_generic_recipe_id": "generic_foreground_128x192x192_fp16_v1",
|
| 23 |
+
"path_generic_cache_version": 1
|
| 24 |
+
}
|
| 25 |
+
}
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"eos_token": {
|
| 3 |
+
"content": "</s>",
|
| 4 |
+
"lstrip": true,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": true,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"pad_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": true,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": true,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"unk_token": {
|
| 17 |
+
"content": "<unk>",
|
| 18 |
+
"lstrip": true,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": true,
|
| 21 |
+
"single_word": false
|
| 22 |
+
}
|
| 23 |
+
}
|
spiece.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1e5036bed065526c3c212dfbe288752391797c4bb1a284aa18c9a0b23fcaf8ec
|
| 3 |
+
size 798330
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"1": {
|
| 4 |
+
"content": "</s>",
|
| 5 |
+
"lstrip": true,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": true,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"2": {
|
| 12 |
+
"content": "<unk>",
|
| 13 |
+
"lstrip": true,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": true,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
}
|
| 19 |
+
},
|
| 20 |
+
"additional_special_tokens": [],
|
| 21 |
+
"clean_up_tokenization_spaces": true,
|
| 22 |
+
"do_lower_case": true,
|
| 23 |
+
"eos_token": "</s>",
|
| 24 |
+
"extra_special_tokens": {},
|
| 25 |
+
"model_input_names": [
|
| 26 |
+
"input_ids"
|
| 27 |
+
],
|
| 28 |
+
"model_max_length": 64,
|
| 29 |
+
"pad_token": "</s>",
|
| 30 |
+
"processor_class": "SiglipProcessor",
|
| 31 |
+
"sp_model_kwargs": {},
|
| 32 |
+
"tokenizer_class": "SiglipTokenizer",
|
| 33 |
+
"unk_token": "<unk>"
|
| 34 |
+
}
|