# Ker-VLJEPA-3B
A multi-modal vision-language model for automated Chest CT radiology report generation. Ker-VLJEPA-3B takes pre-computed 1024-d slice embeddings from Guided-Chest-CT-LeJEPA as input (not raw CT images) and generates free-text narrative findings reports in the style of a radiologist.
Dependency: This model requires IBI-CAAI/Guided-Chest-CT-LeJEPA to first extract per-slice embeddings from raw CT volumes. Ker-VLJEPA-3B does not process CT images directly.
New SOTA on CT-RATE: Macro F1 = 0.429, surpassing the previous state of the art, U-VLM (F1 = 0.414), by 0.015 absolute (+3.6% relative).
Developed by the Institute for Biomedical Informatics Center for Applied AI (IBI-CAAI) at the University of Kentucky.
## Model Details
| Property | Value |
|---|---|
| Model Type | Multi-modal Vision-Language Model (VLM) for Chest CT report generation |
| Language Model | Llama 3.2 3B with LoRA adapters (r=64, alpha=128) |
| Input Embeddings | Guided-Chest-CT-LeJEPA (ViT-Large) – required dependency; provides 1024-d per-slice embeddings from raw CT |
| Visual Encoder | Z-Zoned Perceiver – compresses variable-length slice embeddings into 32 fixed visual tokens (3072-d) |
| Visual Grounding | Flamingo-style gated cross-attention adapters at LLM layers 7, 14, 21 |
| Input | Pre-computed LeJEPA slice embeddings (num_slices × 1024) – not raw CT images |
| Output | Free-text narrative radiology findings report |
| Model Date | 03/2026 |
| License | CC BY-NC-SA 4.0 (inherited from CT-RATE dataset terms) |
| Parameters | ~3.2B (Llama 3B base) + 1.7 GB LoRA + 320 MB bridge components |
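For orientation, the LoRA settings in the table (r=64, alpha=128) correspond to the standard low-rank reparameterization `W x + (alpha/r) * B A x` applied to the frozen base weights. A minimal NumPy sketch (dimensions illustrative, not the released code):

```python
import numpy as np

# Sketch of a LoRA-adapted linear layer with r=64, alpha=128 as in the table.
rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 3072, 3072, 64, 128

W = rng.standard_normal((d_out, d_in)) * 0.01  # frozen base weight
A = rng.standard_normal((r, d_in)) * 0.01      # trainable down-projection
B = np.zeros((d_out, r))                       # trainable up-projection (zero-init)

x = rng.standard_normal(d_in)
y = W @ x + (alpha / r) * (B @ (A @ x))        # base output + scaled low-rank update

# With B zero-initialized, the adapter starts as an exact no-op:
assert np.allclose(y, W @ x)
```

Only the A/B factors (plus the bridge components) are trained, which is why the adapter weights ship separately from the Llama 3.2 3B base.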
## Architecture

Ker-VLJEPA-3B bridges vision and language through a 3-stage architecture:

```
Raw CT Volume
       │
       ▼
┌──────────────────────────────┐
│ Guided-Chest-CT-LeJEPA       │  ViT-Large backbone (separate model)
│ (per-slice feature extract)  │  Produces 1024-d embedding per slice
└──────────────┬───────────────┘
               │ (num_slices, 1024)
               ▼
┌──────────────────────────────┐
│ Z-Zoned Perceiver            │  32 anatomical zones, each with learned queries
│ + Global Self-Attention      │  Compresses variable-length slices → 32 tokens
│ + JEPA Predictor (→ 3072d)   │  Projects into LLM hidden space
└──────────────┬───────────────┘
               │ (32, 3072)
               ▼
┌──────────────────────────────┐
│ Llama 3.2 3B + LoRA          │  Embedding grafting: visual tokens replace
│ + Cross-Attention Adapters   │  <|visual_region|> placeholders in prompt
│   @ layers 7, 14, 21         │  Cross-attention: text attends to visual tokens
│ + Additive Layer Projectors  │  at 3 intermediate LLM layers
└──────────────┬───────────────┘
               │
               ▼
       Narrative Report
```
### Key Architectural Innovations
- Z-Zoned Cross-Attention: Divides the body axis into 32 spatial zones. Each zone's queries attend only to slices within their anatomical region, enforcing spatial specialization by construction.
- Flamingo-style Cross-Attention Bridge: Rather than simple additive injection (which the LLM can trivially ignore), text hidden states explicitly attend to visual tokens via cross-attention at layers 7, 14, and 21.
- Warm Bridge Technique: Bridge components (cross-attention adapters, layer projectors, LoRA) are initialized from a prior converged checkpoint rather than random, yielding +4.4% F1 improvement.
- Frozen Cross-Attention + EWC: During narrative fine-tuning (Phase 4), cross-attention adapters are frozen to preserve visual grounding, while Elastic Weight Consolidation constrains LoRA drift.
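The zone-constrained compression can be illustrated with a toy sketch (not the released code; the equal-width zone assignment and single-head attention are simplifying assumptions):

```python
import numpy as np

# Toy sketch of Z-Zoned cross-attention: the slice axis is split into 32 zones,
# and each zone's learned query attends ONLY to slices inside its zone, so
# spatial specialization is enforced by construction.
rng = np.random.default_rng(0)
num_slices, d, n_zones = 240, 1024, 32

slices = rng.standard_normal((num_slices, d))   # LeJEPA slice embeddings
queries = rng.standard_normal((n_zones, d))     # one learned query per zone
# Assumed equal-width zone assignment along the z-axis:
zone_of = np.minimum((np.arange(num_slices) * n_zones) // num_slices, n_zones - 1)

tokens = np.zeros((n_zones, d))
for z in range(n_zones):
    keys = slices[zone_of == z]                  # only this zone's slices
    attn = np.exp(queries[z] @ keys.T / np.sqrt(d))
    attn /= attn.sum()
    tokens[z] = attn @ keys                      # one pooled token per zone

assert tokens.shape == (n_zones, d)              # (32, 1024) before the 3072-d projection
```

In the full model these per-zone tokens then pass through global self-attention and the JEPA predictor to reach the LLM's 3072-d hidden space.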
## Intended Uses
This model is intended for research purposes in medical imaging and radiology.
- Automated generation of narrative chest CT findings reports
- Research into vision-language architectures for medical imaging
- Benchmarking on the CT-RATE evaluation protocol
Not intended for: Clinical diagnosis, treatment decisions, or any patient-facing application without proper clinical validation.
## Training

### Data

Trained exclusively on the train split of CT-RATE (~46,400 volumes), with guided cropping annotations from ReXGroundingCT.

### Hardware

8x NVIDIA H200 GPUs with DDP via HuggingFace Accelerate, bf16 mixed precision.

### 4-Phase Training Pipeline
| Phase | Objective | Trainable Components | Key Hyperparameters | Result |
|---|---|---|---|---|
| Phase 1 | Visual encoder alignment | Visual encoder + JEPA head | LR=5e-5, BS=32, 20 epochs | F1=0.460, AUC=0.811 |
| Phase 2 | Contrastive bridge (InfoNCE + MMD) | Visual encoder + LoRA + JEPA head | LR=3e-5, BS=64, 30 epochs | F1=0.465, AUC=0.816 |
| Phase 3 | Generative fine-tuning (positive-findings-only text) | JEPA predictor + LoRA + cross-attn + layer projectors | LR=2e-5, BS=8, warm bridge init | Gen F1=0.422 |
| Phase 4 | Raw narrative fine-tuning | LoRA only (cross-attn frozen, EWC=100.0) | LR=5e-7, BS=8, 14 epochs | Gen F1=0.429 |
Phase 3 → Phase 4 innovation: In Phase 3, the model trains on positive-findings-only text to learn pathology detection without gradient domination from normal-text tokens (~90% of raw reports). Phase 4 then adapts to raw radiologist prose with frozen cross-attention + ultra-conservative LoRA updates.
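The EWC constraint in Phase 4 is the standard quadratic penalty anchoring the LoRA parameters to their Phase-3 values; a minimal sketch (the Fisher estimate here is a stand-in, and the diagonal-Fisher assumption is ours):

```python
import numpy as np

# Sketch of the Elastic Weight Consolidation penalty used in Phase 4
# (lambda = 100.0 per the table above):
#   L_total = L_task + (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2
rng = np.random.default_rng(0)
lam = 100.0

theta_star = rng.standard_normal(1000)  # Phase-3 LoRA params (the anchor)
fisher = rng.random(1000)               # diagonal Fisher estimate (assumed to come
                                        # from squared gradients on Phase-3 data)
theta = theta_star + 0.01 * rng.standard_normal(1000)  # current Phase-4 params

ewc_penalty = 0.5 * lam * np.sum(fisher * (theta - theta_star) ** 2)
assert ewc_penalty >= 0.0  # penalty grows as LoRA drifts from its anchor
```

The large lambda, together with the 5e-7 learning rate, keeps the LoRA close to the visually grounded Phase-3 solution while it adapts to narrative prose.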
### Generation Configuration

| Parameter | Value | Rationale |
|---|---|---|
| Temperature | 0.6 | Optimal via sweep – balances diversity and accuracy |
| Top-p | 0.9 | Nucleus sampling |
| Repetition penalty | 1.1 | Prevents degenerate loops |
| No-repeat n-gram | 4 | Medical reports need some anatomical repetition |
| Max new tokens | 384 | Covers 95%+ of reports |
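Assuming the repo's `model.generate` wrapper forwards standard HuggingFace-style generation kwargs (an assumption; the wrapper may name them differently), the table maps onto:

```python
# Generation settings from the table above, in HuggingFace kwarg naming.
generation_config = dict(
    do_sample=True,          # required for temperature/top-p sampling
    temperature=0.6,         # optimal via sweep
    top_p=0.9,               # nucleus sampling
    repetition_penalty=1.1,  # prevents degenerate loops
    no_repeat_ngram_size=4,  # allows short anatomical repetition, blocks 4-gram loops
    max_new_tokens=384,      # covers 95%+ of reports
)
```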
## Results

### CT-RATE Benchmark (Cross-Method Comparison)

All methods evaluated on the CT-RATE validation set (2,984 volumes) using the official RadBERT classifier for 18-class binary label extraction.
| Method | Macro F1 | Macro Prec | Macro Rec | Notes |
|---|---|---|---|---|
| CT-CLIP (Hamamci et al.) | 0.194 | – | – | Zero-shot |
| CT-CHAT (Hamamci et al.) | 0.287 | – | – | Fine-tuned |
| BTB3D (Song et al.) | 0.354 | – | – | |
| U-VLM (Lee et al.) | 0.414 | 0.491 | 0.429 | Previous SOTA |
| Ker-VLJEPA-3B | 0.429 | 0.389 | 0.524 | New SOTA |
### Per-Class Results (2,984 validation samples)
| Class | Prec | Rec | F1 | Support |
|---|---|---|---|---|
| Pleural effusion | 0.574 | 0.789 | 0.664 | 370 |
| Arterial wall calcification | 0.642 | 0.660 | 0.651 | 849 |
| Coronary artery wall calcification | 0.580 | 0.585 | 0.582 | 752 |
| Cardiomegaly | 0.437 | 0.743 | 0.550 | 315 |
| Lung nodule | 0.491 | 0.560 | 0.523 | 1,344 |
| Lung opacity | 0.545 | 0.453 | 0.494 | 1,173 |
| Emphysema | 0.346 | 0.667 | 0.456 | 588 |
| Lymphadenopathy | 0.393 | 0.532 | 0.452 | 769 |
| Pulmonary fibrotic sequela | 0.384 | 0.505 | 0.436 | 819 |
| Atelectasis | 0.369 | 0.496 | 0.423 | 698 |
| Consolidation | 0.430 | 0.389 | 0.408 | 576 |
| Pericardial effusion | 0.238 | 0.614 | 0.343 | 215 |
| Mosaic attenuation pattern | 0.256 | 0.494 | 0.337 | 245 |
| Hiatal hernia | 0.257 | 0.397 | 0.312 | 413 |
| Peribronchial thickening | 0.271 | 0.354 | 0.307 | 347 |
| Medical material | 0.194 | 0.722 | 0.305 | 306 |
| Bronchiectasis | 0.238 | 0.328 | 0.276 | 326 |
| Interlobular septal thickening | 0.357 | 0.142 | 0.203 | 246 |
Macro (default 0.5 threshold): F1=0.429, Prec=0.389, Rec=0.524
### Evaluation Protocol
Clinical accuracy follows the CT-RATE evaluation protocol, the same methodology used by CT-CLIP, CT-CHAT, BTB3D, and U-VLM:
- Generate free-text narrative reports from all validation CT volumes (temp=0.6, top_p=0.9)
- Clean degenerate text (truncate at onset of repeated chars/unicode)
- Strip negation suffixes before RadBERT (prevents false positive extraction)
- Extract 18 binary abnormality labels using the official CT-RATE RadBERT classifier with CT-RATE weights
- Compute macro-averaged F1, precision, recall against ground-truth labels
The RadBERT classifier achieves F1=0.982 on ground-truth reports, confirming it is a reliable extraction tool.
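Step 5 is plain macro-averaging over the 18 binary labels: per-class precision, recall, and F1, then an unweighted mean. A minimal sketch with toy labels (the real pipeline extracts the label matrices via RadBERT):

```python
import numpy as np

# Macro-averaged precision/recall/F1 over 18 binary abnormality labels,
# as in the CT-RATE protocol. Toy random labels stand in for RadBERT output.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(2984, 18))  # ground-truth labels
y_pred = rng.integers(0, 2, size=(2984, 18))  # labels extracted from generated reports

tp = ((y_pred == 1) & (y_true == 1)).sum(axis=0)
fp = ((y_pred == 1) & (y_true == 0)).sum(axis=0)
fn = ((y_pred == 0) & (y_true == 1)).sum(axis=0)

prec = tp / np.maximum(tp + fp, 1)
rec = tp / np.maximum(tp + fn, 1)
f1 = np.where(prec + rec > 0, 2 * prec * rec / np.maximum(prec + rec, 1e-12), 0.0)

# Macro = unweighted mean over the 18 classes (not support-weighted).
macro_f1, macro_prec, macro_rec = f1.mean(), prec.mean(), rec.mean()
```

Because the macro average is unweighted, rare classes such as Pericardial effusion (support 215) count as much as Lung nodule (support 1,344).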
## How to Use

### Prerequisites

```bash
pip install torch transformers peft safetensors accelerate
pip install flash-attn --no-build-isolation  # recommended for fast generation
```
You will also need:
- A local copy of meta-llama/Llama-3.2-3B
- Pre-computed LeJEPA slice embeddings from the Guided-Chest-CT-LeJEPA backbone
### Repository Contents

```
Ker-VLJEPA-3B/
├── README.md                         # This file
├── model.py                          # Inference-only model class
├── generate_report.py                # CLI example
├── weights/
│   ├── visual_encoder.safetensors    # Z-Zoned Perceiver + JEPA predictor (43 MB)
│   ├── bridge_components.safetensors # Cross-attn adapters, layer projectors, norms (277 MB)
│   └── lora_adapters/                # LoRA adapter weights for Llama 3.2 3B (1.7 GB)
│       ├── adapter_config.json
│       └── adapter_model.safetensors
└── tokenizer/                        # Tokenizer with <|visual_region|> special token
    ├── chat_template.jinja
    ├── tokenizer.json
    └── tokenizer_config.json
```
### Python API

```python
import numpy as np
import torch

from model import load_model

# 1. Load the model (requires local Llama 3.2 3B)
model = load_model(
    llm_path="/path/to/Llama-3.2-3B",
    weights_dir="weights",
    device="cuda",
)

# 2. Load pre-computed LeJEPA slice embeddings for a CT volume.
#    The .npy file has shape (num_slices, 1024) – one 1024-d embedding per
#    CT slice – and is batched to (1, num_slices, 1024) below.
#    See IBI-CAAI/Guided-Chest-CT-LeJEPA for how to extract these.
embeddings = torch.from_numpy(np.load("volume_embeddings.npy")).unsqueeze(0).float()
mask = torch.ones(1, embeddings.shape[1])  # 1 = valid, 0 = padding

# 3. Generate a narrative report
report = model.generate(
    slice_embeddings=embeddings,
    mask=mask,
    temperature=0.6,
    max_new_tokens=384,
)
print(report)
```
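The mask exists because volumes have different slice counts. A hypothetical padding sketch for putting two volumes into one batch (an illustration, not a documented API; the single-volume call above uses an all-ones mask):

```python
import numpy as np

# Zero-pad variable-length volumes to a common slice count and record the
# valid positions in the mask (1 = real slice, 0 = padding).
rng = np.random.default_rng(0)
vols = [rng.standard_normal((n, 1024)).astype(np.float32) for n in (180, 240)]
max_len = max(v.shape[0] for v in vols)

embeddings = np.zeros((len(vols), max_len, 1024), dtype=np.float32)
mask = np.zeros((len(vols), max_len), dtype=np.float32)
for i, v in enumerate(vols):
    embeddings[i, : v.shape[0]] = v
    mask[i, : v.shape[0]] = 1.0

assert embeddings.shape == (2, 240, 1024)
# Convert with torch.from_numpy(...) before passing to model.generate.
```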
### CLI

```bash
python generate_report.py \
    --llm_path /path/to/Llama-3.2-3B \
    --embeddings volume_embeddings.npy \
    --temperature 0.6
```
### End-to-End: Raw CT Volume → Report

This example shows the full pipeline, from a raw NIfTI CT volume through LeJEPA embedding extraction to report generation.

```python
import nibabel as nib
import numpy as np
import timm
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from model import load_model

# --- Step 1: Load the LeJEPA visual backbone ---
backbone = timm.create_model(
    "vit_large_patch14_dinov2",
    pretrained=False,
    num_classes=0,
    in_chans=1,
    img_size=518,
    dynamic_img_size=True,
)
ckpt = load_file(hf_hub_download("IBI-CAAI/Guided-Chest-CT-LeJEPA", "model.safetensors"))
backbone.load_state_dict(ckpt, strict=False)
backbone.eval().cuda()

# --- Step 2: Preprocess and extract slice embeddings ---
# HU clipping and normalization must match LeJEPA training exactly
CLIP_MIN, CLIP_MAX = -997.0, 888.0
MEAN_HU, STD_HU = -142.39, 360.97
RANGE = CLIP_MAX - CLIP_MIN
NORM_MEAN = (MEAN_HU - CLIP_MIN) / RANGE
NORM_STD = STD_HU / RANGE

vol = nib.load("chest_ct.nii.gz").get_fdata()  # (H, W, D) in Hounsfield units

slice_embeddings = []
with torch.no_grad():
    for i in range(vol.shape[2]):
        s = torch.from_numpy(vol[:, :, i]).float().unsqueeze(0)  # (1, H, W)
        s = torch.clamp(s, CLIP_MIN, CLIP_MAX)
        s = (s - CLIP_MIN) / RANGE
        s = (s - NORM_MEAN) / NORM_STD
        # Align spatial dims to the ViT patch size (14)
        _, H, W = s.shape
        tH, tW = (H // 14) * 14, (W // 14) * 14
        if tH != H or tW != W:
            s = F.interpolate(s.unsqueeze(0), size=(tH, tW), mode="nearest").squeeze(0)
        emb = backbone(s.unsqueeze(0).cuda())  # (1, 1024)
        slice_embeddings.append(emb.cpu())

embeddings = torch.cat(slice_embeddings, dim=0).unsqueeze(0)  # (1, num_slices, 1024)
mask = torch.ones(1, embeddings.shape[1])

# --- Step 3: Generate the report ---
model = load_model("/path/to/Llama-3.2-3B", "weights", "cuda")
report = model.generate(slice_embeddings=embeddings, mask=mask)
print(report)
```
## Prompt Format

The model uses the Llama 3.2 chat template with 32 `<|visual_region|>` tokens injected into the user message:

```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a radiology reporting assistant. Describe thoracic findings based on
the provided CT scan visual features. Report only what you observe.<|eot_id|>
<|start_header_id|>user<|end_header_id|>

Based on the visual features from this CT scan, describe the thoracic findings.
<|visual_region|> x 32<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```
Each <|visual_region|> placeholder is replaced at the embedding level with the corresponding visual token from the perceiver encoder. Cross-attention adapters at layers 7, 14, and 21 provide additional grounding throughout generation.
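The embedding-level replacement can be sketched in a few lines (a toy illustration; the 3072-d width matches the model card, but the special-token id used here is an assumption for demonstration):

```python
import numpy as np

# Toy sketch of embedding grafting: after tokenization, the positions holding
# the <|visual_region|> placeholder are overwritten with the perceiver's
# 32 visual tokens before the sequence enters the LLM.
rng = np.random.default_rng(0)
d_model = 3072
VISUAL_ID = 128256  # assumed id for the added special token (illustrative only)

input_ids = np.array([1, 2, 3] + [VISUAL_ID] * 32 + [4, 5])   # toy prompt ids
text_embeds = rng.standard_normal((len(input_ids), d_model))  # from embed_tokens
visual_tokens = rng.standard_normal((32, d_model))            # perceiver output

grafted = text_embeds.copy()
grafted[input_ids == VISUAL_ID] = visual_tokens  # replace placeholders in order

assert (input_ids == VISUAL_ID).sum() == 32  # exactly 32 grafted positions
```

The grafted sequence then flows through the LLM as usual, with the cross-attention adapters at layers 7, 14, and 21 re-attending to the same 32 visual tokens.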
## Limitations
- Domain-specific: Trained exclusively on Chest CT from CT-RATE. Performance on other body regions, modalities, or datasets is unknown.
- Not clinically validated: This is a research model. Generated reports should not be used for clinical decision-making.
- Hallucination: Like all generative models, Ker-VLJEPA-3B can produce findings not present in the scan. The 0.389 macro precision means that, on average, only ~39% of generated positive findings are true positives.
- Degeneration: At high token counts (>384), generation quality degrades. Reports should be truncated.
- English only: All training text is in English.
## Citation

If you use this model, please cite:

```bibtex
@misc{bumgardner2026curriculumdriven3dctreport,
  title={Curriculum-Driven 3D CT Report Generation via Language-Free Visual Grafting and Zone-Constrained Compression},
  author={V. K. Cody Bumgardner and Mitchell A. Klusty and Mahmut S. Gokmen and Evan W. Damron},
  year={2026},
  eprint={2603.23308},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2603.23308},
}
```
## Model tree for IBI-CAAI/Ker-VLJEPA-3B

Base model: meta-llama/Llama-3.2-3B