---
license: cc-by-nc-sa-4.0
datasets:
- ibrahimhamamci/CT-RATE
pipeline_tag: feature-extraction
tags:
- university-of-kentucky
- medical
- radiology
- chest-ct
- vision
- lejepa
language:
- en
---
# Model Card for DALE-CT-0

This repository hosts the backbone weights for DALE-CT-0 (Depth-Aware Latent-Euclidean Computed Tomography), a foundational Vision Transformer (ViT-Large) trained on chest CT scans with the Latent-Euclidean Joint-Embedding Predictive Architecture ([LeJEPA](https://github.com/galilai-group/lejepa)) framework.

Unlike its 1S and 2S counterparts, this model was trained with self-supervised LeJEPA objectives only, without any auxiliary supervision, so it learns general-purpose representations of CT volumes from unlabeled data alone.

This model was developed by the Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI) at the University of Kentucky to serve as a feature extractor for downstream medical imaging tasks, including segmentation, multi-instance learning (MIL), and anomaly detection.

## Model Details
* **Model Type:** Vision Transformer (ViT-Large) for chest CT analysis.
* **Developed by:** Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI)
* **Model Date:** 04/2026
* **Base Model Architecture:** `vit_large_patch14_dinov2` (via `timm`), instantiated with modified arguments (patch size 16, 512x512 input, 1 input channel). **Note: This model was randomly initialized and trained entirely from scratch; no pretrained DINOv2 weights were used.**
* **Input:** Single-channel (grayscale) CT slice.
* **Output:** Class token and patch tokens. These can be used for various downstream tasks (e.g., classification, anomaly detection, multi-instance learning).
* **Embedding Dimension:** 1024
* **Patch Size:** 16
* **Image Size Compatibility:** Native `512x512` resolution. The architecture also accepts other input sizes at run time, provided the height and width are divisible by the patch size (16).
* **License:** CC BY-NC-SA 4.0 (inherited from the CT-RATE dataset terms).

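The divisibility rule above also fixes how many tokens the backbone produces for a given input size. The sketch below works through that arithmetic. The helper name `token_layout` is hypothetical, and the default of one prefix token assumes a single class token and no register tokens, which the card does not state explicitly.

```python
PATCH_SIZE = 16  # patch size stated in this model card


def token_layout(height, width, patch_size=PATCH_SIZE, num_prefix_tokens=1):
    """Return (grid_h, grid_w, sequence_length) for an input of the given size.

    num_prefix_tokens=1 is an ASSUMPTION (one class token, no register tokens).
    """
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"Input {height}x{width} is not divisible by patch size {patch_size}"
        )
    grid_h, grid_w = height // patch_size, width // patch_size
    return grid_h, grid_w, grid_h * grid_w + num_prefix_tokens


print(token_layout(512, 512))  # (32, 32, 1025)
print(token_layout(144, 144))  # (9, 9, 82)
```

An input whose sides are not multiples of 16 (for example 500x500) would need padding or resizing before it reaches the model.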

## Intended Uses
This model is intended for research purposes in medical imaging and radiology.
* **Primary Intended Uses:**
    * Feature extraction for quantitative analysis of chest CT scans.
    * Backbone for downstream models that perform organ anomaly prediction, segmentation, or volume-level analysis via MIL.

## Training Data
* **Dataset(s):** The model was trained exclusively on the train split of the [CT-RATE](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE) dataset.
* **Preprocessing:** Hounsfield Unit (HU) values were clipped to `[-997.0, 888.0]`, the 0.5th and 99.5th percentiles of foreground voxel intensities computed on a subset of the CT-RATE dataset. The clipped values were mapped to the `[0, 1]` range and then Z-score normalized using a dataset mean of `-142.39` HU and a standard deviation of `360.97` HU, both rescaled to the same `[0, 1]` range.

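Because the `[0, 1]` mapping and the rescaled statistics share the same range, the two-step normalization reduces algebraically to a plain Z-score in HU applied to the clipped value. The pure-Python sketch below checks this numerically. The function names are illustrative, and it assumes the mean and standard deviation are given in HU before rescaling.

```python
import math

CLIP_MIN, CLIP_MAX = -997.0, 888.0
MEAN_HU, STD_HU = -142.39, 360.97
RANGE = CLIP_MAX - CLIP_MIN  # 1885.0 HU


def normalize_two_step(hu):
    """Clip, map to [0, 1], then Z-score with the rescaled mean and std."""
    clipped = min(max(hu, CLIP_MIN), CLIP_MAX)
    unit = (clipped - CLIP_MIN) / RANGE
    return (unit - (MEAN_HU - CLIP_MIN) / RANGE) / (STD_HU / RANGE)


def normalize_direct(hu):
    """Clip, then Z-score directly in HU (the range cancels out)."""
    clipped = min(max(hu, CLIP_MIN), CLIP_MAX)
    return (clipped - MEAN_HU) / STD_HU


for hu in (-2000.0, -997.0, 0.0, 40.0, 888.0, 3000.0):
    assert math.isclose(
        normalize_two_step(hu), normalize_direct(hu), abs_tol=1e-9
    )
```

Under these constants, normalized intensities fall roughly between -2.37 and 2.85, which is a quick sanity check for a custom data loader.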

## Training Procedure
* **Training System/Framework:** Distributed Data Parallel (DDP) with `bf16` mixed precision.
* **Scale & Hyperparameters:** The model was trained for 50,000 iterations with a batch size of 64 per GPU, a peak learning rate of `3.0e-04` (decaying to `3.0e-05`), and a 5,000-step warmup.
* **Training Strategy:** Global and local crops were sampled from within a 12 mm physical slab rather than from a single 2D plane, so that the self-supervised matching task sees anatomical context across depth.

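The card gives the peak and final learning rates, the warmup length, and the total iteration count, but not the shape of the schedule. The sketch below fills that gap with an assumed shape: linear warmup from zero followed by cosine decay. It is one plausible reading, not the confirmed training configuration.

```python
import math

TOTAL_STEPS = 50_000
WARMUP_STEPS = 5_000
PEAK_LR = 3.0e-4
FINAL_LR = 3.0e-5


def learning_rate(step):
    """ASSUMED shape: linear warmup from 0, then cosine decay to FINAL_LR."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    progress = min(progress, 1.0)
    return FINAL_LR + 0.5 * (PEAK_LR - FINAL_LR) * (1.0 + math.cos(math.pi * progress))


for step in (0, 2_500, 5_000, 27_500, 50_000):
    print(step, f"{learning_rate(step):.2e}")
```

Swapping in a different decay (for example linear) only changes the last line of `learning_rate`; the stated endpoints stay the same.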

### Self-Supervised Objective: LeJEPA
The model was trained solely with the self-supervised LeJEPA objective, which uses no predictor network and combines a spatial invariance loss with Sketched Isotropic Gaussian Regularization (SIGReg):
$$\mathcal{L}_{\text{LeJEPA}}=(1-\lambda)\mathcal{L}_{\text{invariance}}+\lambda\mathcal{L}_{\text{SIGReg}}$$
where \\(\lambda=0.02\\). SIGReg projects the embeddings onto a set of random 1D directions and uses empirical characteristic functions to push each projection toward a normal distribution, which avoids architectural heuristics such as predictor networks.

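To make the regularizer concrete, here is a simplified pure-Python sketch of a SIGReg-style statistic and the weighted sum above. It is not the training code: the function names, the number of directions, the integration grid, and the Gaussian weighting are all illustrative assumptions.

```python
import cmath
import math
import random

LAMBDA = 0.02  # weighting stated in this model card


def sigreg_sketch(embeddings, num_directions=16, num_points=17, t_max=3.0, seed=0):
    """Simplified SIGReg-style statistic (illustrative, not the training code).

    Projects embeddings onto random unit directions and compares each
    projection's empirical characteristic function (ECF) with that of N(0, 1).
    """
    rng = random.Random(seed)
    dim = len(embeddings[0])
    ts = [-t_max + 2 * t_max * i / (num_points - 1) for i in range(num_points)]
    dt = ts[1] - ts[0]
    total = 0.0
    for _ in range(num_directions):
        # Random unit direction: a normalized Gaussian vector
        direction = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(d * d for d in direction))
        direction = [d / norm for d in direction]
        projections = [sum(e * d for e, d in zip(emb, direction)) for emb in embeddings]
        for t in ts:
            ecf = sum(cmath.exp(1j * t * p) for p in projections) / len(projections)
            target = math.exp(-0.5 * t * t)  # characteristic function of N(0, 1)
            # Gaussian-weighted squared error, accumulated as a Riemann sum over t
            total += abs(ecf - target) ** 2 * target * dt
    return total / num_directions


def lejepa_loss(invariance_loss, sigreg_loss, lam=LAMBDA):
    """Weighted sum from the equation above."""
    return (1.0 - lam) * invariance_loss + lam * sigreg_loss


rng = random.Random(1)
gaussian = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(256)]
collapsed = [[0.0] * 8 for _ in range(256)]
print(sigreg_sketch(gaussian) < sigreg_sketch(collapsed))  # True
```

Collapsed embeddings score far worse than isotropic Gaussian ones, which is the behavior the regularizer relies on to prevent representation collapse.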
### Data Augmentation Pipeline
A specialized, GPU-accelerated augmentation pipeline generated the multi-crop views required by the LeJEPA objective.

**1. Spatial Cropping**
* **Global Crops:** Two global crops per volume, each resized to 256x256 pixels, with a random scale between 60% and 100% of the original image dimensions.
* **Local Crops:** Eight local crops per volume, each resized to 144x144 pixels, with a random scale between 30% and 60%.
* **Flip & Resize:** Crops were resized with nearest-neighbor interpolation, which avoids artificial averaging of HU values, and then flipped horizontally with 50% probability.

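The cropping recipe above can be sketched as follows. This is a minimal stdlib-only illustration, not the GPU pipeline itself: the names are hypothetical, the scale is assumed to be a fraction of each side length (the card does not say whether it refers to side length or area), and the nearest-neighbor index rule is one common convention.

```python
import random

GLOBAL_SPEC = {"count": 2, "out_size": 256, "scale": (0.6, 1.0)}
LOCAL_SPEC = {"count": 8, "out_size": 144, "scale": (0.3, 0.6)}


def sample_crop_boxes(height, width, spec, rng):
    """Sample crops as (top, left, crop_h, crop_w, out_size, flip) tuples."""
    boxes = []
    for _ in range(spec["count"]):
        # ASSUMPTION: scale is a fraction of each side length, not of the area
        scale = rng.uniform(*spec["scale"])
        crop_h = max(1, round(height * scale))
        crop_w = max(1, round(width * scale))
        top = rng.randint(0, height - crop_h)
        left = rng.randint(0, width - crop_w)
        flip = rng.random() < 0.5  # horizontal flip with 50% probability
        boxes.append((top, left, crop_h, crop_w, spec["out_size"], flip))
    return boxes


def nearest_source_index(out_index, out_size, crop_size):
    """Nearest-neighbor resize: pick one source pixel, never an average."""
    return min(crop_size - 1, (out_index * crop_size) // out_size)


rng = random.Random(0)
views = sample_crop_boxes(512, 512, GLOBAL_SPEC, rng) + sample_crop_boxes(512, 512, LOCAL_SPEC, rng)
print(len(views))  # 10
```

Because each output pixel copies exactly one source pixel, every value in a resized crop is an HU value that actually occurred in the slice.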
**2. Intensity & Noise Augmentations**
* **Random Gamma Correction:** A random gamma shift (range: 0.9 to 1.1) was applied independently to each crop with 80% probability.

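A minimal sketch of the gamma step, assuming it operates on intensities already mapped to `[0, 1]` (a power law is undefined for negative values, so this ordering is an assumption rather than something the card states). The function name is hypothetical.

```python
import random


def random_gamma(pixels, rng, gamma_range=(0.9, 1.1), probability=0.8):
    """Apply x -> x ** gamma to intensities in [0, 1] with the given probability."""
    if rng.random() >= probability:
        return list(pixels)  # crop left unchanged about 20% of the time
    gamma = rng.uniform(*gamma_range)
    return [p ** gamma for p in pixels]


rng = random.Random(0)
crop = [0.0, 0.25, 0.5, 0.75, 1.0]
# Call once per crop so that each crop draws its own gamma
print(random_gamma(crop, rng))
```

The endpoints 0 and 1 are fixed points of the power law, so the augmentation reshapes mid-range contrast without changing the clipping window.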
## How to Get Started with the Model

Because CT scans require strict Hounsfield Unit (HU) windowing and normalization to match the training distribution, you **must** apply the preprocessing below. *Note: With the native 512px architecture, standard CT slices (512x512) do not require dynamic padding.*

```python
import torch
import numpy as np
import timm
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


class CTInferenceTransform:
    """
    Applies the exact HU windowing and Z-score normalization used during LeJEPA training.
    Assumes standard 512x512 CT slice inputs.
    """
    def __init__(self):
        self.clip_min = -997.0
        self.clip_max = 888.0
        self.mean_hu = -142.39
        self.std_hu = 360.97

        # Rescale the HU mean and std to the [0, 1] range used after clipping
        self.range_val = self.clip_max - self.clip_min
        self.norm_mean = (self.mean_hu - self.clip_min) / self.range_val
        self.norm_std = self.std_hu / self.range_val

    def __call__(self, volume):
        # Expects a 2D numpy array or torch tensor (H, W) in Hounsfield Units.
        # as_tensor + float() also converts integer (e.g. int16) inputs to float32.
        volume = torch.as_tensor(volume).float()
        if volume.ndim == 2:
            volume = volume.unsqueeze(0)  # Add channel dim: (1, H, W)

        # 1. Clamp HU values and map strictly to [0, 1]
        volume = torch.clamp(volume, self.clip_min, self.clip_max)
        volume = (volume - self.clip_min) / self.range_val

        # 2. Z-score standardization
        volume = (volume - self.norm_mean) / self.norm_std

        # Returns (1, 1, H, W). For batched inference, concatenate these along dim=0.
        return volume.unsqueeze(0)


def load_ct_model(repo_id="IBI-CAAI/DALE-CT-0"):
    """
    Downloads and initializes the ViT-Large backbone using timm and safetensors.
    """
    # 1. Initialize the base architecture with overrides (patch_size=16, img_size=512)
    model = timm.create_model(
        "vit_large_patch14_dinov2",
        pretrained=False,
        num_classes=0,
        in_chans=1,
        patch_size=16,
        img_size=512,
        dynamic_img_size=True,
    )

    # 2. Download and load the custom safetensors weights
    model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    state_dict = load_file(model_path)
    # strict=False tolerates key mismatches, so report them instead of ignoring them silently
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f"Missing keys: {missing}\nUnexpected keys: {unexpected}")
    model.eval()

    return model


if __name__ == "__main__":
    # Initialize the transform and the model
    transform = CTInferenceTransform()
    model = load_ct_model()

    # Simulate a raw CT slice (Replace this with an actual NIfTI/DICOM load in Hounsfield Units)
    raw_ct_slice = np.random.uniform(-1000, 1000, size=(512, 512))

    # Process the image to ensure correct normalization
    input_tensor = transform(raw_ct_slice)

    # Extract embeddings
    with torch.no_grad():
        # Option A: Get the single pooled global feature for the entire slice
        global_feature = model(input_tensor)

        # Option B: Get the unpooled token sequence (prefix/class token first, then patch tokens)
        tokens = model.forward_features(input_tensor)
        # Drop the prefix token(s) to keep only the dense spatial patch tokens
        # (for fine-grained tasks like Segmentation)
        patch_tokens = tokens[:, model.num_prefix_tokens:]

    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Extracted features shape: {global_feature.shape}")
    print(f"Dense patch tokens shape: {patch_tokens.shape}")