---
license: cc-by-nc-sa-4.0
datasets:
- ibrahimhamamci/CT-RATE
pipeline_tag: feature-extraction
tags:
- university-of-kentucky
- medical
- radiology
- chest-ct
- vision
- lejepa
language:
- en
---

# Model Card for DALE-CT-0

This repository hosts the backbone weights for DALE-CT-0 (Depth-Aware Latent-Euclidean Computed Tomography), a foundational Vision Transformer (ViT-Large) trained on Chest CT scans using the Latent-Euclidean Joint-Embedding Predictive Architecture ([LeJEPA](https://github.com/galilai-group/lejepa)) framework.

Unlike its 1S and 2S counterparts, this model was trained purely with self-supervised LeJEPA objectives and no auxiliary supervision, so its general-purpose representations of CT volumes are learned without any labels.

This model was developed by the Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI) at the University of Kentucky to serve as a robust feature extractor for downstream medical imaging tasks, including segmentation, multi-instance learning (MIL), and anomaly detection.

## Model Details

* **Model Type:** Vision Transformer (ViT-Large) for Chest CT analysis.
* **Developed by:** Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI)
* **Model Date:** 04/2026
* **Base Model Architecture:** `vit_large_patch14_dinov2` (via `timm`). **Note: This model was randomly initialized and trained entirely from scratch with modified architecture arguments.**
* **Input:** 1-channel Grayscale CT Image.
* **Output:** Class token and patch tokens. These can be used for various downstream tasks (e.g., classification, anomaly detection, multi-instance learning).
* **Embedding Dimension:** 1024
* **Patch Size:** 16
* **Image Size Compatibility:** Native `512x512` resolution. The architecture supports variable input sizes dynamically, provided the height and width are divisible by the patch size (16).
* **License:** CC BY-NC-SA 4.0 (Inherited from the CT-RATE dataset terms).
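
As a quick check of the size constraint above, the number of patch tokens for a given input follows directly from the patch size. The snippet below is a minimal sketch; the helper name `patch_grid` is hypothetical and not part of the released code, and the count excludes the class token.

```python
PATCH_SIZE = 16  # patch size stated in Model Details


def patch_grid(height, width, patch_size=PATCH_SIZE):
    """Return (rows, cols, num_patches) for an input of the given size."""
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"{height}x{width} is not divisible by the patch size ({patch_size})"
        )
    rows, cols = height // patch_size, width // patch_size
    return rows, cols, rows * cols


print(patch_grid(512, 512))  # (32, 32, 1024)
print(patch_grid(256, 256))  # (16, 16, 256)
```

At the native `512x512` resolution this gives a 32x32 grid, so each slice yields 1024 patch tokens of dimension 1024.
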

## Intended Uses

This model is intended for research purposes in the field of medical imaging and radiology.

* **Primary Intended Uses:**
    * Feature extraction for quantitative analysis of Chest CT scans.
    * Foundational backbone for downstream models predicting organ anomalies, segmentation, or volume-level analysis via MIL.

## Training Data

* **Dataset(s):** The model was trained exclusively on the train split of the [CT-RATE](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE) dataset.
* **Preprocessing:** Hounsfield Units (HU) were clipped to `[-997.0, 888.0]`, the 0.5th and 99.5th percentiles of foreground voxel intensities computed on a subset of the CT-RATE dataset. The clipped values were mapped linearly to a `[0, 1]` range and then Z-score normalized using a dataset mean of `-142.39` HU and a standard deviation of `360.97` HU, both rescaled to the same `[0, 1]` units.

## Training Procedure

* **Training System/Framework:** Distributed Data Parallel (DDP) with `bf16` mixed precision.
* **Hardware & Scale:** The model was trained for a total of 50,000 iterations with a batch size of 64 per GPU, a peak learning rate of `3.0e-04` (decaying to `3.0e-05`), and a 5,000-step warmup.
* **Training Strategy:** Global and local crops were sampled from within a 12mm physical slab rather than a single 2D plane to encourage anatomical awareness during the self-supervised matching task.

### Self-Supervised Objective: LeJEPA

The model was trained solely with the self-supervised LeJEPA objective, which uses no predictor network and combines a spatial invariance loss with Sketched Isotropic Gaussian Regularization (SIGReg):

$$\mathcal{L}_{\text{LeJEPA}}=(1-\lambda)\mathcal{L}_{\text{invariance}}+\lambda\mathcal{L}_{\text{SIGReg}}$$

where \\(\lambda=0.02\\).
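
To make the weighting concrete, the sketch below evaluates the combined objective for placeholder loss values. The function name `lejepa_loss` and the example numbers are illustrative assumptions, not taken from the training code; only the value of λ comes from this card.

```python
LAMBDA = 0.02  # SIGReg weight stated in this model card


def lejepa_loss(invariance, sigreg, lam=LAMBDA):
    """Combine the two terms as (1 - lam) * invariance + lam * sigreg."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return (1.0 - lam) * invariance + lam * sigreg


# Placeholder values, chosen only to show the relative weighting.
total = lejepa_loss(invariance=0.50, sigreg=2.00)
print(f"{total:.3f}")  # 0.530
```

With λ = 0.02 the invariance term carries 98% of the weight, so the regularizer acts as a small correction on the embedding distribution rather than as the dominant signal.
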
The SIGReg formulation projects embeddings onto a set of random 1D directions and enforces normality along each direction via empirical characteristic functions, which avoids traditional architectural heuristics such as predictor networks.

### Data Augmentation Pipeline

A specialized, GPU-accelerated augmentation pipeline generated the multi-crop views required for the LeJEPA architecture.

**1. Spatial Cropping**

* **Global Crops:** 2 global crops generated per volume, sized at 256x256 pixels, with a random scale between 60% and 100% of the original image dimensions.
* **Local Crops:** 8 local crops generated per volume, sized at 144x144 pixels, with a random scale between 30% and 60%.
* **Flip & Resize:** Crops were resized using nearest-neighbor interpolation to prevent artificial averaging of HU values, followed by a random horizontal flip (50% probability).

**2. Intensity & Noise Augmentations**

* **Random Gamma Correction:** A random gamma shift (range: 0.9 to 1.1) was applied independently to all crops with an 80% probability.

## How to Get Started with the Model

Because CT scans require strict Hounsfield Unit (HU) windowing and normalization to match the training distribution, you **must** apply the specific preprocessing logic below.

*Note: With the native 512px architecture, standard CT slices (512x512) do not require dynamic padding.*

```python
import torch
import numpy as np
import timm
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


class CTInferenceTransform:
    """
    Applies the exact HU windowing and Z-score normalization used during LeJEPA training.
    Assumes standard 512x512 CT slice inputs.
    """

    def __init__(self):
        self.clip_min = -997.0
        self.clip_max = 888.0
        self.mean_hu = -142.39
        self.std_hu = 360.97

        # Express the HU mean and std in the [0, 1] units produced by the windowing step
        self.range_val = self.clip_max - self.clip_min
        self.norm_mean = (self.mean_hu - self.clip_min) / self.range_val
        self.norm_std = self.std_hu / self.range_val

    def __call__(self, volume):
        # Expects a 2D numpy array or torch tensor (H, W) in Hounsfield Units
        if isinstance(volume, np.ndarray):
            volume = torch.from_numpy(volume)
        volume = volume.float()

        if volume.ndim == 2:
            volume = volume.unsqueeze(0)  # Add channel dim: (1, H, W)

        # 1. Clamp HU values and map to [0, 1]
        volume = torch.clamp(volume, self.clip_min, self.clip_max)
        volume = (volume - self.clip_min) / self.range_val

        # 2. Z-score standardization
        volume = (volume - self.norm_mean) / self.norm_std

        # Returns (1, 1, H, W). For batched inference, concatenate these along dim=0.
        return volume.unsqueeze(0)


def load_ct_model(repo_id="IBI-CAAI/DALE-CT-0"):
    """
    Downloads and initializes the ViT-Large backbone using timm and safetensors.
    """
    # 1. Initialize the base architecture with overrides (patch_size=16, img_size=512)
    model = timm.create_model(
        "vit_large_patch14_dinov2",
        pretrained=False,
        num_classes=0,
        in_chans=1,
        patch_size=16,
        img_size=512,
        dynamic_img_size=True,
    )

    # 2. Download and load the custom safetensors weights
    model_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors")
    state_dict = load_file(model_path)

    # strict=False tolerates key mismatches, so report them instead of ignoring them silently
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        print(f"Warning: missing keys: {missing}; unexpected keys: {unexpected}")

    model.eval()
    return model


if __name__ == "__main__":
    # Initialize the transform and the model
    transform = CTInferenceTransform()
    model = load_ct_model()

    # Simulate a raw CT slice (replace this with an actual NIfTI/DICOM load in Hounsfield Units)
    raw_ct_slice = np.random.uniform(-1000, 1000, size=(512, 512))

    # Process the image to ensure correct normalization
    input_tensor = transform(raw_ct_slice)

    # Extract embeddings
    with torch.no_grad():
        # Option A: Get the single pooled global feature for the entire slice
        global_feature = model(input_tensor)

        # Option B: Get the unpooled token sequence (for fine-grained tasks like segmentation).
        # forward_features also returns the leading prefix (class) token(s), so drop them
        # to keep only the dense spatial patch tokens.
        tokens = model.forward_features(input_tensor)
        patch_tokens = tokens[:, model.num_prefix_tokens:]

    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Extracted features shape: {global_feature.shape}")
    print(f"Dense patch tokens shape: {patch_tokens.shape}")
```