---
license: cc-by-nc-sa-4.0
datasets:
- ibrahimhamamci/CT-RATE
pipeline_tag: feature-extraction
tags:
- university-of-kentucky
- medical
- radiology
- chest-ct
- vision
- lejepa
language:
- en
---
# Model Card for DALE-CT-0
This repository hosts the backbone weights for DALE-CT-0 (Depth-Aware Latent-Euclidean Computed Tomography), a foundational Vision Transformer (ViT-Large) trained on Chest CT scans using the Latent-Euclidean Joint-Embedding Predictive Architecture ([LeJEPA](https://github.com/galilai-group/lejepa)) framework.
Unlike its 1S and 2S counterparts, this model was trained purely with the self-supervised LeJEPA objective and no auxiliary supervision, so it learns general-purpose representations of chest CT without any labels.
This model was developed by the Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI) at the University of Kentucky to serve as a robust feature extractor for downstream medical imaging tasks, including segmentation, multi-instance learning (MIL), and anomaly detection.
## Model Details
* **Model Type:** Vision Transformer (ViT-Large) for Chest CT analysis.
* **Developed by:** Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI)
* **Model Date:** 04/2026
* **Base Model Architecture:** `vit_large_patch14_dinov2` (via `timm`). **Note: This model was randomly initialized and trained entirely from scratch with modified architecture arguments.**
* **Input:** Single-channel (grayscale) 2D CT slice, given in Hounsfield Units before the required preprocessing.
* **Output:** Class token and patch tokens. These can be used for various downstream tasks (e.g., classification, anomaly detection, multi-instance learning).
* **Embedding Dimension:** 1024
* **Patch Size:** 16
* **Image Size Compatibility:** The native resolution is `512x512`. Other input sizes are supported dynamically (via timm's `dynamic_img_size`), provided the height and width are both divisible by the patch size (16).
* **License:** CC BY-NC-SA 4.0 (Inherited from the CT-RATE dataset terms).
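
Because every input side must be a multiple of 16, the token count follows directly from the input size. The helpers below are a hypothetical sketch, not part of the released code, and they assume the card's timm configuration with a single class token and no register tokens.

```python
PATCH_SIZE = 16  # DALE-CT-0 patch size


def token_grid(height: int, width: int, patch_size: int = PATCH_SIZE):
    """Return (rows, cols, total_tokens) for an input of the given size."""
    if height % patch_size or width % patch_size:
        raise ValueError(f"{height}x{width} is not divisible by patch size {patch_size}")
    rows, cols = height // patch_size, width // patch_size
    # +1 for the class token (assumes no register tokens)
    return rows, cols, rows * cols + 1


def pad_amounts(height: int, width: int, patch_size: int = PATCH_SIZE):
    """Bottom/right padding needed to reach the next multiple of patch_size."""
    return (-height) % patch_size, (-width) % patch_size
```

For example, a native `512x512` slice gives a 32x32 grid of patch tokens plus the class token, and the 256x256 global training crops give a 16x16 grid.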
## Intended Uses
This model is intended for research purposes in the field of medical imaging and radiology.
* **Primary Intended Uses:**
    * Feature extraction for quantitative analysis of Chest CT scans.
    * Foundational backbone for downstream models that perform organ-anomaly prediction, segmentation, or volume-level analysis via MIL.
## Training Data
* **Dataset(s):** The model was trained exclusively on the train split of the [CT-RATE](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE) dataset.
* **Preprocessing:** Hounsfield Unit (HU) values were clipped to `[-997.0, 888.0]`, the 0.5th and 99.5th percentiles of foreground voxel intensities computed on a subset of the CT-RATE dataset. The clipped values were linearly mapped to the `[0, 1]` range and then Z-score normalized using the dataset mean (`-142.39` HU) and standard deviation (`360.97` HU), both converted to the same `[0, 1]` scale.
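
Because the offset and the scale are both applied to the clipped HU value, this two-stage normalization is algebraically a plain Z-score of the clipped HU. Writing \\(c_{\min}=-997\\), \\(c_{\max}=888\\), \\(\mu=-142.39\\) and \\(\sigma=360.97\\) (derived constants rounded):

$$\tilde{x}=\frac{\operatorname{clip}(x)-c_{\min}}{c_{\max}-c_{\min}},\qquad \tilde{\mu}=\frac{\mu-c_{\min}}{c_{\max}-c_{\min}}\approx 0.4534,\qquad \tilde{\sigma}=\frac{\sigma}{c_{\max}-c_{\min}}\approx 0.1915$$

$$\frac{\tilde{x}-\tilde{\mu}}{\tilde{\sigma}}=\frac{\operatorname{clip}(x)-\mu}{\sigma}\in\left[\frac{c_{\min}-\mu}{\sigma},\,\frac{c_{\max}-\mu}{\sigma}\right]\approx[-2.37,\ 2.85]$$

So every normalized input value lies roughly in \\([-2.37, 2.85]\\), which is a quick sanity check on preprocessed slices.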
## Training Procedure
* **Training System/Framework:** Distributed Data Parallel (DDP) with `bf16` mixed precision.
* **Hardware & Scale:** Training ran for 50,000 iterations with a batch size of 64 per GPU, a peak learning rate of `3.0e-04` decaying to `3.0e-05`, and a 5,000-step warmup.
* **Training Strategy:** Global and local crops were sampled from within a 12 mm physical slab of the volume rather than from a single 2D plane, so the self-supervised matching task spans neighboring slices and encourages anatomical (depth) awareness.
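
The card does not say how the slab is positioned or how slices are assigned to views. The sketch below assumes a slab centred on a random slice, a constant slice spacing, and uniform per-view slice selection; `sample_slab_slices` is a hypothetical helper, not the training code.

```python
import random
from typing import List, Optional

SLAB_MM = 12.0  # slab thickness stated on the card


def sample_slab_slices(num_slices: int, spacing_mm: float, n_views: int,
                       slab_mm: float = SLAB_MM,
                       rng: Optional[random.Random] = None) -> List[int]:
    """Pick one axial slice index per view, all inside a single slab of thickness slab_mm.

    Assumption: the slab is centred on a random slice and extends slab_mm / 2 to
    either side; near the ends of the scan it is truncated to valid indices.
    """
    rng = rng or random.Random()
    half = int((slab_mm / 2) // spacing_mm)  # whole slices within half the slab
    centre = rng.randrange(num_slices)
    lo = max(0, centre - half)
    hi = min(num_slices - 1, centre + half)
    return [rng.randint(lo, hi) for _ in range(n_views)]
```

With 1.5 mm spacing, for instance, `half` is 4, so the ten views (2 global + 8 local) come from at most nine consecutive slices spanning 12 mm; with spacing above 6 mm every view falls on the same slice.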
### Self-Supervised Objective: LeJEPA
The model was trained solely with the self-supervised LeJEPA objective, which combines an invariance loss across augmented views with Sketched Isotropic Gaussian Regularization (SIGReg):
$$\mathcal{L}_{\text{LeJEPA}}=(1-\lambda)\mathcal{L}_{\text{invariance}}+\lambda\mathcal{L}_{\text{SIGReg}}$$
where \\(\lambda=0.02\\). SIGReg projects the embeddings onto a set of random 1D directions and uses empirical characteristic functions to push each projection toward normality, driving the embedding distribution toward an isotropic Gaussian. This avoids traditional architectural heuristics such as predictor networks.
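
A minimal PyTorch sketch of this objective follows. It assumes, as described in the LeJEPA paper, that the invariance term is the mean squared distance between every view's embedding and the mean embedding of the global views. The SIGReg part uses an Epps-Pulley-style statistic whose hyperparameters (`num_slices`, `t_max`, `num_points`, and the Gaussian weighting window) are illustrative assumptions, not values from the DALE-CT-0 training code; the LeJEPA repository linked above holds the reference implementation.

```python
import torch


def sigreg(emb: torch.Tensor, num_slices: int = 256, t_max: float = 3.0,
           num_points: int = 17) -> torch.Tensor:
    """Sketch of SIGReg for an (N, D) batch of embeddings (hyperparameters assumed)."""
    n, d = emb.shape
    # Random unit-norm projection directions, re-drawn on every call
    dirs = torch.randn(d, num_slices, device=emb.device, dtype=emb.dtype)
    dirs = dirs / dirs.norm(dim=0, keepdim=True)
    proj = emb @ dirs  # (N, num_slices): one 1D sample set per direction

    # Quadrature points for the integral over t
    t = torch.linspace(-t_max, t_max, num_points, device=emb.device, dtype=emb.dtype)
    tz = proj.unsqueeze(-1) * t  # (N, num_slices, num_points)
    ecf_real = torch.cos(tz).mean(dim=0)  # real part of the empirical characteristic function
    ecf_imag = torch.sin(tz).mean(dim=0)  # imaginary part
    gauss_cf = torch.exp(-0.5 * t ** 2)  # characteristic function of N(0, 1), purely real
    sq_err = (ecf_real - gauss_cf) ** 2 + ecf_imag ** 2
    weighted = sq_err * torch.exp(-0.5 * t ** 2)  # Gaussian weighting window (assumed)
    stat = n * torch.trapezoid(weighted, t, dim=-1)  # one statistic per direction
    return stat.mean()


def lejepa_loss(global_emb: torch.Tensor, all_emb: torch.Tensor,
                lam: float = 0.02) -> torch.Tensor:
    """global_emb: (V_global, B, D); all_emb: (V_all, B, D) for the same B samples."""
    # Invariance term: pull every view toward the mean of the global-view embeddings
    centers = global_emb.mean(dim=0)  # (B, D)
    invariance = (centers.unsqueeze(0) - all_emb).square().mean()
    # SIGReg term, averaged over views
    sig = torch.stack([sigreg(view) for view in all_emb]).mean()
    return (1 - lam) * invariance + lam * sig
```

Samples drawn from an isotropic Gaussian give a small SIGReg value, while shifted or collapsed embeddings are penalized heavily, which is what prevents the trivial constant solution to the invariance term.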
### Data Augmentation Pipeline
A specialized, GPU-accelerated augmentation pipeline generated the multi-crop views required for the LeJEPA architecture.
**1. Spatial Cropping**
* **Global Crops:** 2 global crops generated per volume, sized at 256x256 pixels, with a random scale between 60% and 100% of the original image dimensions.
* **Local Crops:** 8 local crops generated per volume, sized at 144x144 pixels, with a random scale between 30% and 60%.
* **Resize & Flip:** Crops were resized with nearest-neighbor interpolation, which avoids artificially averaging neighboring HU values, and then flipped horizontally with 50% probability.
**2. Intensity & Noise Augmentations**
* **Random Gamma Correction:** A random gamma shift (range: 0.9 to 1.1) was applied independently to all crops with an 80% probability.
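
The crop, resize, flip, and gamma steps above can be sketched per slice as follows. This is a simplified CPU sketch, not the GPU-accelerated pipeline. It assumes square crops, that `scale` is a fraction of the slice area (as in torchvision's `RandomResizedCrop`; the card does not specify), and that gamma is applied to intensities already mapped to `[0, 1]`, since a power function is undefined for the negative values produced by Z-scoring. For simplicity it crops one slice; under the slab strategy each view would come from its own slice.

```python
import random

import torch
import torch.nn.functional as F


def random_resized_crop(img: torch.Tensor, out_size: int, scale: tuple) -> torch.Tensor:
    """Square random crop of a (1, H, W) slice in [0, 1], then resize, flip and gamma."""
    _, h, w = img.shape
    area_frac = random.uniform(*scale)  # assumed: fraction of the slice area
    side = int(round((area_frac * h * w) ** 0.5))
    side = max(1, min(side, h, w))
    top = random.randint(0, h - side)
    left = random.randint(0, w - side)
    crop = img[:, top:top + side, left:left + side]
    # Nearest-neighbour resizing copies existing intensities instead of averaging them
    crop = F.interpolate(crop.unsqueeze(0), size=(out_size, out_size), mode="nearest").squeeze(0)
    if random.random() < 0.5:  # random horizontal flip
        crop = torch.flip(crop, dims=[-1])
    if random.random() < 0.8:  # random gamma in [0.9, 1.1]
        crop = crop.clamp(0.0, 1.0) ** random.uniform(0.9, 1.1)
    return crop


def multi_crop(img: torch.Tensor, n_global: int = 2, n_local: int = 8):
    """Generate the global (256x256) and local (144x144) views for one slice."""
    global_views = [random_resized_crop(img, 256, (0.6, 1.0)) for _ in range(n_global)]
    local_views = [random_resized_crop(img, 144, (0.3, 0.6)) for _ in range(n_local)]
    return global_views, local_views
```

Both crop sizes are multiples of the 16-pixel patch size, so the views can be fed to the backbone without padding.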
## How to Get Started with the Model
Because CT scans require strict Hounsfield Unit (HU) windowing and normalization to match the training distribution, you **must** apply the specific preprocessing logic below. *Note: With the native 512px architecture, standard CT slices (512x512) do not require dynamic padding.*
```python
import torch
import numpy as np
import timm
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

PATCH_SIZE = 16


class CTInferenceTransform:
    """
    Applies the exact HU windowing and Z-score normalization used during LeJEPA training.
    Expects a single 2D slice; standard CT slices are 512x512, and other sizes work
    as long as height and width are divisible by 16.
    """

    def __init__(self):
        self.clip_min = -997.0
        self.clip_max = 888.0
        self.mean_hu = -142.39
        self.std_hu = 360.97
        # Express the HU mean and std on the [0, 1] scale used after clipping
        range_val = self.clip_max - self.clip_min
        self.norm_mean = (self.mean_hu - self.clip_min) / range_val
        self.norm_std = self.std_hu / range_val

    def __call__(self, volume):
        # Expects a 2D numpy array or torch tensor (H, W) in Hounsfield Units.
        # as_tensor also casts integer pixel arrays (e.g. int16 DICOM data) to float32.
        volume = torch.as_tensor(volume, dtype=torch.float32)
        if volume.ndim == 2:
            volume = volume.unsqueeze(0)  # Add channel dim: (1, H, W)
        # 1. Clamp HU values and map strictly to [0, 1]
        volume = torch.clamp(volume, self.clip_min, self.clip_max)
        range_val = self.clip_max - self.clip_min
        volume = (volume - self.clip_min) / range_val
        # 2. Z-score standardization
        volume = (volume - self.norm_mean) / self.norm_std
        # Returns (1, 1, H, W). For batched inference, concatenate with torch.cat(..., dim=0).
        return volume.unsqueeze(0)


def load_ct_model(repo_id="IBI-CAAI/DALE-CT-0"):
    """
    Downloads and initializes the ViT-Large backbone using timm and safetensors.
    """
    # 1. Initialize the base architecture with overrides (patch_size=16, img_size=512)
    model = timm.create_model(
        "vit_large_patch14_dinov2",
        pretrained=False,
        num_classes=0,
        in_chans=1,
        patch_size=PATCH_SIZE,
        img_size=512,
        dynamic_img_size=True,
    )
    # 2. Download and load the custom safetensors weights
    model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    state_dict = load_file(model_path)
    # strict=False tolerates key mismatches; report them so an architecture
    # mismatch does not pass silently.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Missing keys: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Unexpected keys: {result.unexpected_keys}")
    model.eval()
    return model


if __name__ == "__main__":
    # Initialize the transform and the model
    transform = CTInferenceTransform()
    model = load_ct_model()

    # Simulate a raw CT slice (replace this with an actual NIfTI/DICOM load in Hounsfield Units)
    raw_ct_slice = np.random.uniform(-1000, 1000, size=(512, 512))

    # Process the image to ensure correct normalization
    input_tensor = transform(raw_ct_slice)

    # Extract embeddings
    with torch.no_grad():
        # Option A: the single pooled global feature for the entire slice
        # (with timm's default "token" pooling this is the class-token embedding)
        global_feature = model(input_tensor)

        # Option B: the full token sequence. The first num_prefix_tokens entries are
        # the class token; the rest are the dense spatial patch tokens (e.g. for segmentation).
        tokens = model.forward_features(input_tensor)
        patch_tokens = tokens[:, model.num_prefix_tokens:]

    # Arrange the patch tokens as a (B, H/16, W/16, 1024) spatial grid
    grid_h = input_tensor.shape[-2] // PATCH_SIZE
    grid_w = input_tensor.shape[-1] // PATCH_SIZE
    patch_grid = patch_tokens.reshape(patch_tokens.shape[0], grid_h, grid_w, -1)

    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Extracted features shape: {global_feature.shape}")
    print(f"Dense patch tokens shape: {patch_tokens.shape}")
    print(f"Patch token grid shape: {patch_grid.shape}")