SMB-Vision-v0

SMB-Vision is a pure-vision backbone designed for medical imaging foundation models, with a strong focus on radiology modalities such as CT, MRI, and X‑ray. It implements efficient 3D patch embedding, rotary position encodings, scalable Transformer blocks, multi‑scale deep feature extraction, and two self‑supervised objectives tailored for medical imagery: masked image modeling (MIM) and joint embedding predictive architecture (JEPA).

Architecture Overview

The implementation lives in modeling_smb_vision.py and exposes three main classes:

  • SMBVisionEncoder: Vision encoder with 3D patch embedding and stacked Transformer blocks
  • SMBVisionPredictor: Lightweight Transformer for JEPA next‑embedding prediction
  • SMBVisionModel: Wrapper that combines encoder + predictor and computes MIM and JEPA losses

Key components and how they map to the code:

  • 3D Patch Embedding (SMBVisionPatchEmbed)

    • A Conv3d with kernel=stride=[temporal_patch_size, patch_size, patch_size] over per‑patch tensors
    • Supports in_channels = 1 (grayscale), 3 (RGB), or 4; radiology typically uses 1
    • Produces per‑patch embeddings of size hidden_size
  • Learned 2D Positional Embedding + Fast Interpolation

    • pos_embed: nn.Embedding(num_position_embeddings, hidden_size) with bilinear‑style interpolation (fast_pos_embed_interpolate) to target grid sizes (height×width) per frame
  • Rotary Position Embedding (RoPE) in Space (and Time)

    • SMBVisionRotaryEmbedding generates frequencies; applied in attention via apply_rotary_pos_emb_vision
    • Encodes spatial (and slice/temporal) structure for robust geometric reasoning
  • Transformer Blocks (SMBVisionBlock)

    • Pre‑norm residual blocks with SMBVisionAttention and SMBVisionMLP
    • Attention backends: eager, SDPA, FlashAttention‑2 (config‑selectable)
  • DeepStack Multi‑scale Features

    • deepstack_visual_indexes selects block indices whose outputs are merged by SMBVisionPatchMerger
    • Produces multi‑level visual descriptors for downstream tasks (e.g., detection, retrieval)
  • Masked Image Modeling (MIM)

    • Randomly masks a ratio of patch tokens and reconstructs pixels via to_pixels: Linear(hidden_size -> patch_volume)
    • Reconstruction loss: L1 (MAE) on masked patches
    • Note: For medical grayscale data, set in_channels=1 so reconstruction target matches output shape
  • JEPA Next‑Embedding Prediction

    • Context/target partitions at the study level expand to patch tokens internally
    • SMBVisionPredictor predicts target encoder embeddings; loss is MSE on target tokens
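
The MIM objective in the list above (mask a ratio of patch tokens, project hidden states to pixels, take L1 on the masked patches only) can be sketched as follows. This is a minimal NumPy illustration, not the code in modeling_smb_vision.py: the function name masked_l1_loss, the random matrix standing in for to_pixels, the small hidden_size, and the way the mask is drawn are all assumptions made for the example.

```python
import numpy as np

def masked_l1_loss(pred, target, mask):
    """Mean absolute error over masked patches only.

    pred, target: arrays of shape (num_patches, patch_volume)
    mask: boolean array of shape (num_patches,), True where a patch was masked
    """
    if not mask.any():
        raise ValueError("mask selects no patches")
    return float(np.abs(pred[mask] - target[mask]).mean())

rng = np.random.default_rng(0)
# hidden_size is kept small for the demo; it is not the model's real width.
num_patches, hidden_size, patch_volume = 400, 32, 4096
masking_ratio = 0.5

# Mask exactly round(ratio * N) patches (hypothetical sampling scheme).
num_masked = int(round(masking_ratio * num_patches))
mask = np.zeros(num_patches, dtype=bool)
mask[rng.choice(num_patches, size=num_masked, replace=False)] = True

target = rng.standard_normal((num_patches, patch_volume))     # original patch pixels
hidden = rng.standard_normal((num_patches, hidden_size))      # stand-in encoder output
w = rng.standard_normal((hidden_size, patch_volume)) * 0.01   # stand-in for to_pixels
pred = hidden @ w                                             # (num_patches, patch_volume)

loss = masked_l1_loss(pred, target, mask)
```

Because the loss indexes with the mask before averaging, errors on visible (unmasked) patches do not contribute to it.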

Radiology‑centric Design Notes

  • Modalities: CT/MRI volumes (slice stacks) and X‑ray images are supported via patch tokenization
  • Through‑plane handling: temporal_patch_size acts as slice depth for 3D patching over the Z/through‑plane axis
  • Grayscale emphasis: Use in_channels=1 for CT/MRI/X‑ray to align MIM reconstruction shapes
  • Scalability: SDPA and FlashAttention‑2 attention backends are available for large studies and high‑resolution inputs
  • Multi‑scale features: deepstack_visual_indexes provide hooks for detection/segmentation heads

Installation

pip install torch torchvision
pip install transformers nibabel monai smb_biopan_utils

Quick Start (CT volumes)

The encoder expects a 2D tensor of flattened patch tokens (one row per patch, concatenated across the batch) and a per‑sample grid descriptor grid_thw = [T, H, W], where:

  • T = num_slices / temporal_patch_size
  • H = image_height / patch_size
  • W = image_width / patch_size
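
As a worked example of the three formulas above, the sketch below computes grid_thw and the token count for a 64×160×160 volume with the default patch size of 16. The helper name grid_thw_for_volume is hypothetical and is not part of smb_biopan_utils.

```python
def grid_thw_for_volume(num_slices, height, width,
                        patch_size=16, temporal_patch_size=16):
    """Return [T, H, W] for one volume, requiring exact divisibility."""
    checks = (
        ("num_slices", num_slices, temporal_patch_size),
        ("height", height, patch_size),
        ("width", width, patch_size),
    )
    for name, size, step in checks:
        if size % step != 0:
            raise ValueError(f"{name}={size} is not divisible by {step}")
    return [num_slices // temporal_patch_size,
            height // patch_size,
            width // patch_size]

grid = grid_thw_for_volume(64, 160, 160)
num_tokens = grid[0] * grid[1] * grid[2]
print(grid, num_tokens)  # [4, 10, 10] 400
```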

You must first patchify the volume into non‑overlapping 3D patches of shape [in_channels, temporal_patch_size, patch_size, patch_size], flatten each patch to a token, and concatenate all tokens for the batch.
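
The patchify step described above might look like the following NumPy sketch. It assumes a [C, D, H, W] array and emits tokens in T‑major, then H, then W order. That ordering, the function name patchify_volume, and the absence of any intensity normalization are assumptions; process_mm_info may order or preprocess patches differently, so prefer the helper for real inputs.

```python
import numpy as np

def patchify_volume(volume, patch_size=16, temporal_patch_size=16):
    """Split a [C, D, H, W] array into flattened non-overlapping 3D patches.

    Returns (tokens, grid_thw), where tokens has shape
    (T*H*W, C * temporal_patch_size * patch_size**2).
    """
    c, d, h, w = volume.shape
    tp, p = temporal_patch_size, patch_size
    if d % tp or h % p or w % p:
        raise ValueError("volume dimensions must be divisible by the patch sizes")
    t, gh, gw = d // tp, h // p, w // p
    # Split each spatial axis into (grid index, offset inside the patch).
    x = volume.reshape(c, t, tp, gh, p, gw, p)
    # Move grid indices to the front: [t, gh, gw, c, tp, p, p].
    x = x.transpose(1, 3, 5, 0, 2, 4, 6)
    # One row per patch, each row a flattened [c, tp, p, p] block.
    tokens = x.reshape(t * gh * gw, c * tp * p * p)
    return tokens, np.array([t, gh, gw])

volume = np.random.rand(1, 64, 160, 160).astype(np.float32)
tokens, grid_thw = patchify_volume(volume)
print(tokens.shape, grid_thw.tolist())  # (400, 4096) [4, 10, 10]
```

Tokens from several volumes can then be concatenated along the first axis, with one grid_thw row per volume, and converted to tensors before being passed to the model.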

Example helper for NIfTI volumes:

import torch
from smb_biopan_utils import process_mm_info
from transformers import AutoModel


# Prepare message spec for your volume(s). Each "image" can be a path to NIfTI/DICOM.
messages = [
    {
        "content": [
            {"type": "image", "image": "dummy.nii.gz"}, # Volume size is [1, 64, 160, 160]
            {"type": "image", "image": "dummy.nii.gz"},
        ]
    }
]

# Convert to patch tokens and grid descriptor expected by SMB‑Vision
# Default patch size is 16 for all dimensions
images, grid_thw = process_mm_info(messages) # images size is [800(400*2), 4096]

# Optional: skip file loading and use dummy inputs with the same shapes instead
images, grid_thw = torch.randn(800, 4096), torch.tensor([[4, 10, 10], [4, 10, 10]])

# Load backbone from HF Hub (uses this repo's modeling with trust_remote_code)
model = AutoModel.from_pretrained(
    "standardmodelbio/smb-vision-v0",
    trust_remote_code=True,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # needs the flash-attn package; alternatives: "sdpa", "eager"
)
model.to("cuda")

# Encode features
encoded_patches, deepstack_features = model.forward_features(
    images.to("cuda"), grid_thw=grid_thw.to("cuda")
)
print(encoded_patches.shape)
# torch.Size([800, 1152])

API Summary

  • SMBVisionEncoder.forward(hidden_states, grid_thw) → (encoded_patches, deepstack_features)

    • hidden_states: Float tensor of shape (num_patches, in_channels * temporal_patch_size * patch_size^2)
    • grid_thw: Int tensor of shape (num_studies, 3) with [T, H, W] per study
  • SMBVisionModel.forward(hidden_states, grid_thw, context_mask, target_mask) → SMBVisionModelOutput

    • Computes MIM (always) and JEPA (if masks provided)
    • Output contains losses and (optionally) encoder/predicted hidden states
  • SMBVisionModel.forward_features(hidden_states, grid_thw) → (encoded_patches, deepstack_features)

    • Convenience wrapper that calls the encoder directly for feature extraction
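
Because hidden_states concatenates the patches of every study, it is often useful to split the encoder output back into per‑study chunks. The sketch below does this from grid_thw alone. It assumes that each study's tokens are stored contiguously in the same order as the rows of grid_thw and that the encoder returns one row per input patch (as in the Quick Start, where 800 input tokens give 800 output rows). The helper name split_by_study is hypothetical.

```python
def split_by_study(encoded_patches, grid_thw):
    """Slice a (num_patches, ...) sequence into one chunk per study.

    encoded_patches: anything that supports len() and slicing
                     (list, NumPy array, torch tensor).
    grid_thw: iterable of [T, H, W] triples of plain ints
              (for a tensor, pass grid_thw.tolist()).
    """
    sizes = [t * h * w for t, h, w in grid_thw]
    if sum(sizes) != len(encoded_patches):
        raise ValueError(
            f"grid_thw implies {sum(sizes)} patches, got {len(encoded_patches)}"
        )
    chunks, start = [], 0
    for n in sizes:
        chunks.append(encoded_patches[start:start + n])
        start += n
    return chunks

# Demo with a plain list standing in for the (800, 1152) encoder output.
rows = list(range(800))
chunks = split_by_study(rows, [[4, 10, 10], [4, 10, 10]])
print([len(c) for c in chunks])  # [400, 400]
```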

Recommended Radiology Settings

  • CT chest/abdomen: patch_size=16, temporal_patch_size=16, in_channels=1, masking_ratio=0.4–0.6
  • MRI brain: patch_size=16, temporal_patch_size=16 (or per‑sequence 2D with temporal_patch_size=1), masking_ratio=0.4
  • X‑ray: patch_size=16, temporal_patch_size=1, in_channels=1, masking_ratio=0.4–0.6
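
To see how these settings change the per‑token input size (in_channels * temporal_patch_size * patch_size^2, the width of hidden_states), the following sketch evaluates it for the CT and X‑ray settings listed above. The PRESETS dictionary and the token_dim helper are illustrative names only, not part of the model's configuration API.

```python
# Hypothetical preset table mirroring the CT and X-ray recommendations.
PRESETS = {
    "ct_chest_abdomen": {"patch_size": 16, "temporal_patch_size": 16, "in_channels": 1},
    "xray": {"patch_size": 16, "temporal_patch_size": 1, "in_channels": 1},
}

def token_dim(cfg):
    """Number of values in one flattened patch token."""
    return cfg["in_channels"] * cfg["temporal_patch_size"] * cfg["patch_size"] ** 2

dims = {name: token_dim(cfg) for name, cfg in PRESETS.items()}
print(dims)  # {'ct_chest_abdomen': 4096, 'xray': 256}
```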

Notes

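  • FlashAttention‑2 can be enabled via the attention implementation setting in the vision config, for example by passing attn_implementation="flash_attention_2" to from_pretrained as in the Quick Start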
  • Ensure volume dimensions are divisible by patch_size and temporal_patch_size (or center‑crop/pad before patchify)
  • For multi‑sequence MRI or 4‑channel inputs, set in_channels=4 and adapt reconstruction paths accordingly
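
For the divisibility requirement above, one option is to pad each spatial axis up to the next multiple of its patch size before patchifying. The NumPy sketch below pads a [C, D, H, W] array symmetrically. The function name pad_to_multiple and the default pad_value of 0.0 are assumptions; the appropriate background value depends on your intensity preprocessing (raw CT in Hounsfield units, for instance, would usually not use 0 for air).

```python
import numpy as np

def pad_to_multiple(volume, patch_size=16, temporal_patch_size=16, pad_value=0.0):
    """Symmetrically pad a [C, D, H, W] array so D, H, W divide evenly."""
    _, d, h, w = volume.shape
    pads = [(0, 0)]  # never pad the channel axis
    for size, step in ((d, temporal_patch_size), (h, patch_size), (w, patch_size)):
        extra = (-size) % step          # amount needed to reach the next multiple
        pads.append((extra // 2, extra - extra // 2))
    return np.pad(volume, pads, mode="constant", constant_values=pad_value)

volume = np.ones((1, 60, 150, 150), dtype=np.float32)
padded = pad_to_multiple(volume)
print(padded.shape)  # (1, 64, 160, 160)
```

Center‑cropping down to the previous multiple is the alternative when discarding border voxels is acceptable.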

Citation

If you use SMB‑Vision in your research, please cite this repository.

Model size: 0.6B params, tensor type BF16 (Safetensors)