# SMB-Vision-v0
SMB-Vision is a pure-vision backbone designed for medical imaging foundation models, with a strong focus on radiology modalities such as CT, MRI, and X‑ray. It implements efficient 3D patch embedding, rotary position encodings, scalable Transformer blocks, multi‑scale deep feature extraction, and two self‑supervised objectives tailored for medical imagery: masked image modeling (MIM) and joint embedding predictive architecture (JEPA).
## Architecture Overview
The implementation lives in `modeling_smb_vision.py` and exposes three main classes:
- `SMBVisionEncoder`: vision encoder with 3D patch embedding and stacked Transformer blocks
- `SMBVisionPredictor`: lightweight Transformer for JEPA next-embedding prediction
- `SMBVisionModel`: wrapper that combines encoder + predictor and computes the MIM and JEPA losses
Key components and how they map to the code:
### 3D Patch Embedding (`SMBVisionPatchEmbed`)
- A `Conv3d` with `kernel=stride=[temporal_patch_size, patch_size, patch_size]` over per-patch tensors
- Supports `in_channels` of 1 (grayscale), 3 (RGB), or 4; radiology typically uses 1
- Produces per-patch embeddings of size `hidden_size`
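The non-overlapping 3D patching can be sketched with a plain `Conv3d` whose kernel equals its stride. This is a minimal illustration with hypothetical default sizes, not the actual `SMBVisionPatchEmbed` from `modeling_smb_vision.py`:

```python
import torch
import torch.nn as nn

class PatchEmbed3D(nn.Module):
    """Toy 3D patch embedding; sizes are illustrative defaults."""

    def __init__(self, in_channels=1, temporal_patch_size=16, patch_size=16, hidden_size=1152):
        super().__init__()
        kernel = (temporal_patch_size, patch_size, patch_size)
        # kernel == stride -> each output position covers one non-overlapping 3D patch
        self.proj = nn.Conv3d(in_channels, hidden_size, kernel_size=kernel, stride=kernel)

    def forward(self, x):
        # x: (batch, in_channels, depth, height, width)
        x = self.proj(x)                     # (batch, hidden_size, T, H, W)
        return x.flatten(2).transpose(1, 2)  # (batch, num_patches, hidden_size)

emb = PatchEmbed3D()
tokens = emb(torch.randn(1, 1, 64, 160, 160))  # a [1, 64, 160, 160] CT volume
print(tokens.shape)  # torch.Size([1, 400, 1152]) — (64/16) * (160/16) * (160/16) = 400 patches
```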
### Learned 2D Positional Embedding + Fast Interpolation
- A `pos_embed: nn.Embedding(num_position_embeddings, hidden_size)` with bilinear-style interpolation (`fast_pos_embed_interpolate`) to target grid sizes (height×width) per frame
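Interpolating a learned positional grid to an arbitrary target grid can be sketched as follows; the base grid size and helper name here are assumptions, not the actual `fast_pos_embed_interpolate` implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

base, hidden_size = 32, 1152                     # assumed base grid; real value comes from the config
pos_embed = nn.Embedding(base * base, hidden_size)

def interpolate_pos_embed(target_h, target_w):
    # Reshape the flat embedding table into a 2D grid of vectors...
    grid = pos_embed.weight.view(1, base, base, hidden_size).permute(0, 3, 1, 2)
    # ...bilinearly resample it to the target patch grid...
    grid = F.interpolate(grid, size=(target_h, target_w), mode="bilinear", align_corners=False)
    # ...and flatten back to one positional vector per patch.
    return grid.permute(0, 2, 3, 1).reshape(target_h * target_w, hidden_size)

pe = interpolate_pos_embed(10, 10)               # e.g. a 160x160 image with patch_size=16
print(pe.shape)  # torch.Size([100, 1152])
```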
### Rotary Position Embedding (RoPE) in Space (and Time)
- `SMBVisionRotaryEmbedding` generates frequencies; applied in attention via `apply_rotary_pos_emb_vision`
- Encodes spatial (and slice/temporal) structure for robust geometric reasoning
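The rotary application follows the standard RoPE recipe (rotate-half with per-position sin/cos); this toy sketch shows the mechanics rather than the exact `apply_rotary_pos_emb_vision` code:

```python
import torch

def rotate_half(x):
    # Swap and negate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(q, positions, dim):
    # Standard RoPE frequency schedule (base 10000).
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(positions.float(), inv_freq)   # (seq, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)            # (seq, dim)
    # Rotate queries (keys get the same treatment) by their position angle.
    return q * emb.cos() + rotate_half(q) * emb.sin()

q = torch.randn(100, 64)                 # (num_patches, head_dim)
q_rot = apply_rope(q, torch.arange(100), 64)
print(q_rot.shape)  # torch.Size([100, 64])
```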
### Transformer Blocks (`SMBVisionBlock`)
- Pre-norm residual blocks with `SMBVisionAttention` and `SMBVisionMLP`
- Attention backends: eager, SDPA, FlashAttention-2 (config-selectable)
### DeepStack Multi-scale Features
- `deepstack_visual_indexes` selects block indices whose outputs are merged by `SMBVisionPatchMerger`
- Produces multi-level visual descriptors for downstream tasks (e.g., detection, retrieval)
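The collection pattern is simple to sketch: while iterating the block stack, hidden states at the selected indices are merged and kept. The block/merger modules below are illustrative stand-ins for `SMBVisionBlock` and `SMBVisionPatchMerger`:

```python
import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(32, 32) for _ in range(8)])  # stand-in Transformer blocks
deepstack_visual_indexes = [2, 5, 7]                           # illustrative indices
merger = nn.Linear(32, 32)                                     # stand-in for SMBVisionPatchMerger

x = torch.randn(100, 32)  # (num_patches, hidden_size)
deepstack_features = []
for i, blk in enumerate(blocks):
    x = blk(x)
    if i in deepstack_visual_indexes:
        # Keep a merged copy of this intermediate level for downstream heads.
        deepstack_features.append(merger(x))

print(len(deepstack_features))  # 3
```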
### Masked Image Modeling (MIM)
- Randomly masks a ratio of patch tokens and reconstructs pixels via `to_pixels: Linear(hidden_size -> patch_volume)`
- Reconstruction loss: L1 (MAE) on masked patches
- Note: for medical grayscale data, set `in_channels=1` so the reconstruction target matches the output shape
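The objective reduces to an L1 penalty restricted to masked positions. A minimal sketch with random stand-ins for the decoder output and pixel targets (shapes and ratio are examples, not fixed by the repo):

```python
import torch

torch.manual_seed(0)
num_patches, patch_volume, masking_ratio = 400, 4096, 0.5

target = torch.randn(num_patches, patch_volume)  # ground-truth per-patch pixels
pred = torch.randn(num_patches, patch_volume)    # stand-in for the to_pixels output

mask = torch.rand(num_patches) < masking_ratio   # True = token was masked
# L1 / MAE computed only over masked patches.
loss = (pred[mask] - target[mask]).abs().mean()
print(loss.shape)  # torch.Size([]) — a scalar loss
```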
### JEPA Next-Embedding Prediction
- Context/target partitions at the study level expand to patch tokens internally
- `SMBVisionPredictor` predicts target-encoder embeddings; the loss is MSE on target tokens
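The JEPA loss itself is a plain MSE between predicted and target-encoder embeddings, with no gradient flowing into the targets. Random tensors stand in for the predictor and target-encoder outputs here:

```python
import torch
import torch.nn.functional as F

num_target_tokens, hidden_size = 120, 1152

predicted = torch.randn(num_target_tokens, hidden_size)  # stand-in SMBVisionPredictor output
with torch.no_grad():
    # Target embeddings are treated as fixed regression targets (stop-gradient).
    target = torch.randn(num_target_tokens, hidden_size)

jepa_loss = F.mse_loss(predicted, target)
print(jepa_loss.dim())  # 0 — a scalar loss
```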
## Radiology-centric Design Notes
- Modalities: CT/MRI volumes (slice stacks) and X-ray images are supported via patch tokenization
- Through-plane handling: `temporal_patch_size` acts as slice depth for 3D patching over the Z/through-plane axis
- Grayscale emphasis: use `in_channels=1` for CT/MRI/X-ray to align MIM reconstruction shapes
- Scalability: attention backends include SDPA and FlashAttention-2 for large studies and high-resolution inputs
- Multi-scale features: `deepstack_visual_indexes` provides hooks for detection/segmentation heads
## Installation

```shell
pip install torch torchvision
pip install transformers nibabel monai smb_biopan_utils
```
## Quick Start (CT volumes)
The encoder expects a list of patch tokens and a per-sample grid descriptor `grid_thw = [T, H, W]`, where:
- `T = num_slices / temporal_patch_size`
- `H = image_height / patch_size`
- `W = image_width / patch_size`
You must first patchify the volume into non-overlapping 3D patches of shape `[in_channels, temporal_patch_size, patch_size, patch_size]`, flatten each patch to a token, and concatenate all tokens for the batch.
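The patchify step described above can be sketched directly with tensor reshapes. This is an illustrative helper, not the repo's official preprocessing:

```python
import torch

def patchify(volume, temporal_patch_size=16, patch_size=16):
    """Split a (C, D, H, W) volume into flattened non-overlapping 3D patch tokens."""
    c, d, h, w = volume.shape
    t, gh, gw = d // temporal_patch_size, h // patch_size, w // patch_size
    # Carve each axis into (num_patches_along_axis, patch_extent) pairs...
    x = volume.reshape(c, t, temporal_patch_size, gh, patch_size, gw, patch_size)
    # ...group the patch-index axes in front and the per-patch extents behind...
    x = x.permute(1, 3, 5, 0, 2, 4, 6)              # (T, H, W, C, tp, p, p)
    # ...and flatten each patch into one token.
    tokens = x.reshape(t * gh * gw, -1)
    return tokens, torch.tensor([t, gh, gw])

tokens, grid_thw = patchify(torch.randn(1, 64, 160, 160))
print(tokens.shape, grid_thw.tolist())  # torch.Size([400, 4096]) [4, 10, 10]
```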
Example helper for NIfTI volumes:

```python
import torch
from smb_biopan_utils import process_mm_info
from transformers import AutoModel

# Prepare the message spec for your volume(s). Each "image" can be a path to NIfTI/DICOM.
messages = [
    {
        "content": [
            {"type": "image", "image": "dummy.nii.gz"},  # volume size is [1, 64, 160, 160]
            {"type": "image", "image": "dummy.nii.gz"},
        ]
    }
]

# Convert to the patch tokens and grid descriptor expected by SMB-Vision.
# The default patch size is 16 for all dimensions.
images, grid_thw = process_mm_info(messages)  # images size is [800 (400 * 2), 4096]

# Optional: dummy images and grid_thw
images, grid_thw = torch.randn(800, 4096), torch.tensor([[4, 10, 10], [4, 10, 10]])

# Load the backbone from the HF Hub (uses this repo's modeling code with trust_remote_code)
model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-v0",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.to("cuda")

# Encode features
encoded_patches, deepstack_features = model.forward_features(
    images.to("cuda"), grid_thw=grid_thw.to("cuda")
)
print(encoded_patches.shape)
# (800, 1152)
```
## API Summary
- `SMBVisionEncoder.forward(hidden_states, grid_thw)` → `(encoded_patches, deepstack_features)`
  - `hidden_states`: float tensor of shape `(num_patches, in_channels * temporal_patch_size * patch_size^2)`
  - `grid_thw`: int tensor of shape `(num_studies, 3)` with `[T, H, W]` per study
- `SMBVisionModel.forward(hidden_states, grid_thw, context_mask, target_mask)` → `SMBVisionModelOutput`
  - Computes MIM (always) and JEPA (if masks are provided)
  - Output contains losses and (optionally) encoder/predicted hidden states
- `SMBVisionModel.forward_features(hidden_states, grid_thw)` → `(encoded_patches, deepstack_features)`
  - Convenience wrapper that calls the encoder directly for feature extraction
## Recommended Radiology Settings
- CT chest/abdomen: `patch_size=16`, `temporal_patch_size=16`, `in_channels=1`, `masking_ratio=0.4-0.6`
- MRI brain: `patch_size=16`, `temporal_patch_size=16` (or per-sequence 2D with `temporal_patch_size=1`), `masking_ratio=0.4`
- X-ray: `patch_size=16`, `temporal_patch_size=1`, `in_channels=1`, `masking_ratio=0.4-0.6`
## Notes
- FlashAttention-2 can be enabled via the attention-implementation setting in the vision config
- Ensure volume dimensions are divisible by `patch_size` and `temporal_patch_size` (or center-crop/pad before patchifying)
- For multi-sequence MRI or 4-channel inputs, set `in_channels=4` and adapt the reconstruction paths accordingly
## Citation
If you use SMB‑Vision in your research, please cite this repository.