TILA: Temporal Inversion-aware Learning and Alignment
Temporal Inversion for Learning Interval Change in Chest X-Rays (accepted at CVPR 2026)
TILA is a vision-language framework that uses temporal inversion (reversing image pairs) as a supervisory signal to enhance the sensitivity of temporal vision-language models to directional change in chest X-rays. Given a current and a prior radiograph, TILA can:
- Extract temporal-aware image embeddings (128-dim) that capture both the static anatomy and the interval change between the two images.
- Encode radiology text into the same 128-dim space for zero-shot classification via image-text similarity.
- Predict interval change (binary: change vs. no change) using a lightweight classification head.
The image encoder is based on the BioViL-T architecture (ResNet-50 + Vision Transformer temporal pooler), and the text encoder is CXR-BERT, both fine-tuned with temporal inversion-aware alignment.
Quick Start
Installation
pip install "torch>=2.0" "torchvision>=0.15" "timm>=0.9" "transformers>=4.30" "safetensors>=0.4" pillow opencv-python numpy
Load Model and Processor
import torch
from transformers import AutoModel
# Load from HuggingFace Hub
model = AutoModel.from_pretrained("lukeingawesome/TILA", trust_remote_code=True)
model = model.to("cuda", dtype=torch.bfloat16)
Or load locally:
from model import TILAModel
model = TILAModel.from_pretrained("model.safetensors")
model = model.to("cuda", dtype=torch.bfloat16)
from processor import TILAProcessor
# Processor handles everything: raw image → model-ready tensor
processor = TILAProcessor(dtype=torch.bfloat16, device="cuda")
Extract Embeddings
current = processor("current_cxr.png") # accepts file paths, numpy arrays, or PIL images
previous = processor("previous_cxr.png")
# 128-dim L2-normalized embeddings
embeddings = model.get_embeddings(current, previous)
The processor automatically applies medical image preprocessing (windowing, black padding removal, resize) followed by model transforms (center crop to 448x448, expand to 3 channels). If your images are already preprocessed, skip the medical preprocessing:
processor = TILAProcessor(raw_preprocess=False, dtype=torch.bfloat16, device="cuda")
The embeddings encode both the current image state and the temporal difference from the prior. They can be used for retrieval, similarity search, or as features for downstream tasks.
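Because the embeddings are L2-normalized, retrieval reduces to a dot product. A minimal numpy sketch with randomly generated stand-in vectors (in practice each row would come from `model.get_embeddings` for one current/prior study pair):

```python
import numpy as np

# Stand-in gallery of 128-dim L2-normalized embeddings; real rows would
# come from model.get_embeddings for one (current, prior) study pair each.
rng = np.random.default_rng(0)
gallery = rng.normal(size=(100, 128))
gallery /= np.linalg.norm(gallery, axis=1, keepdims=True)

# Query: a slightly perturbed copy of study 42, re-normalized.
query = gallery[42] + 0.01 * rng.normal(size=128)
query /= np.linalg.norm(query)

# With unit-norm vectors, the dot product is cosine similarity, so
# nearest-neighbour retrieval is a single matrix-vector product.
scores = gallery @ query
top5 = np.argsort(-scores)[:5]  # indices of the five most similar studies
```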
Encode Text
text_emb = model.encode_text([
"Improved pulmonary edema.",
"Stable pulmonary edema.",
"Worsening pulmonary edema.",
])
# Zero-shot classification via image-text similarity
similarities = embeddings @ text_emb.T # [1, 3]
prediction = similarities.argmax(dim=1) # 0=improving, 1=stable, 2=worsening
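The raw similarities can also be turned into a distribution over the prompts with a temperature-scaled softmax. A numpy sketch; the temperature (0.07, a common CLIP-style choice) is an assumption here, not TILA's trained value, and the argmax is unaffected by it:

```python
import numpy as np

# Hypothetical cosine similarities of one image pair against the three
# prompts above; in practice this is `embeddings @ text_emb.T`.
sims = np.array([[0.12, 0.31, 0.05]])

TEMPERATURE = 0.07  # assumed CLIP-style value, not TILA's trained temperature
logits = sims / TEMPERATURE
logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
prediction = probs.argmax(axis=1)  # 1 -> "Stable pulmonary edema."
```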
Predict Interval Change
result = model.get_interval_change_prediction(current, previous, mode="bestf1")
print(result["probabilities"]) # Raw change probability
print(result["predictions"]) # Binary: 0 = no change, 1 = change
print(result["threshold"]) # Threshold used
Three threshold modes are available:
| Mode | Threshold | Description |
|---|---|---|
| `"bestf1"` | 0.29 | Maximizes F1 score (balanced sensitivity/specificity) |
| `"default"` | 0.50 | Standard sigmoid cutoff |
| `"spec95"` | 0.64 | Targets 95% specificity (conservative, fewer false positives) |
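If you keep the raw probabilities, you can binarize them at any operating point yourself. A sketch using the table's thresholds and hypothetical probabilities (stand-ins for `result["probabilities"]`):

```python
import numpy as np

# Thresholds from the table above.
THRESHOLDS = {"bestf1": 0.29, "default": 0.50, "spec95": 0.64}

# Hypothetical change probabilities for three image pairs.
probabilities = np.array([0.22, 0.41, 0.71])

# Binarize at each operating point: lower thresholds trade specificity
# for sensitivity, higher ones the reverse.
predictions = {mode: (probabilities >= t).astype(int)
               for mode, t in THRESHOLDS.items()}
```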
CLI Example
python inference.py \
--checkpoint model.safetensors \
--current_image /path/to/current.png \
--previous_image /path/to/previous.png
Preprocessing Raw Images
Note: TILAProcessor applies this preprocessing automatically. Run it as a separate step only if you bypass the processor and feed tensors to the model yourself.
If your chest X-rays are raw (e.g., DICOM-derived PNGs with black borders or 8-/16-bit depth), preprocess them first:
import cv2
from preprocess import preprocess_image
img = preprocess_image("raw_cxr.png")
cv2.imwrite("preprocessed.png", img)
The pipeline applies:
- Read as-is: preserves original bit depth (supports 8-bit and 16-bit PNGs)
- Windowing: clips to `mean +/- 2*std`, normalizes to [0, 1]
- Black padding removal: contour-based crop
- Aspect-ratio-preserving resize: longest side to 512px (configurable)
# CLI usage
python preprocess.py --input raw.png --output preprocessed.png
If your images are already preprocessed (contrast-normalized, cropped, resized grayscale PNGs), you can skip this step and feed them directly to the model.
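The windowing step above can be sketched in a few lines. This mirrors the documented behavior (clip to `mean +/- 2*std`, rescale to [0, 1]); the actual `preprocess_image` implementation may differ in detail:

```python
import numpy as np

def window(img: np.ndarray) -> np.ndarray:
    """Clip intensities to mean +/- 2*std, then rescale to [0, 1]."""
    img = img.astype(np.float32)
    lo = img.mean() - 2 * img.std()
    hi = img.mean() + 2 * img.std()
    img = np.clip(img, lo, hi)
    return (img - lo) / max(hi - lo, 1e-8)

# Intensity-relative, so it handles 8-bit and 16-bit inputs alike.
raw = np.random.default_rng(0).integers(0, 4096, size=(512, 512)).astype(np.uint16)
out = window(raw)
```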
Input Format
- Image format: grayscale chest X-ray (PNG, JPEG)
- Model input: resize to 512px (shorter side), center crop to 448x448, repeat to 3 channels (handled by the transform in `inference.py`)
- Pair: current (follow-up) image + previous (baseline) image of the same patient
- Dtype: `torch.bfloat16` recommended on GPU, `torch.float32` on CPU
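The crop-and-repeat part of that transform can be sketched in numpy (the shorter-side resize to 512px is left to `inference.py`'s transform; this only illustrates the geometry and channel layout):

```python
import numpy as np

def center_crop_3ch(img: np.ndarray, size: int = 448) -> np.ndarray:
    """Center crop a 2D grayscale image and repeat to 3 channels (CHW)."""
    h, w = img.shape
    top, left = (h - size) // 2, (w - size) // 2
    crop = img[top:top + size, left:left + size]
    return np.repeat(crop[None, :, :], 3, axis=0)

gray = np.ones((512, 640), dtype=np.float32)  # shorter side already 512
x = center_crop_3ch(gray)  # shape (3, 448, 448)
```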
Files
| File | Description |
|---|---|
| `model.safetensors` | Model weights (613 MB, image + text + classifier) |
| `config.json` | Model configuration (for AutoModel support) |
| `configuration_tila.py` | TILAConfig class |
| `model.py` | Self-contained model architecture |
| `processor.py` | Image processor (raw image → model-ready tensor) |
| `preprocess.py` | Medical image preprocessing utilities |
| `inference.py` | Example inference script |
Citation
If you use this model, please cite:
@inproceedings{ko2026tila,
title={Temporal Inversion for Learning Interval Change in Chest X-Rays},
author={Ko, Hanbin and Jeon, Kyeongmin and Choi, Doowoong and Park, Chang Min},
booktitle={CVPR},
year={2026},
url={http://arxiv.org/abs/2604.04563}
}
Acknowledgements
This model builds upon BioViL-T by Microsoft Research:
@inproceedings{bannur2023biovilt,
title={Learning to Exploit Temporal Structure for Biomedical Vision-Language Processing},
author={Bannur, Shruthi and Hyland, Stephanie and Liu, Qianchu and Perez-Garcia, Fernando and Oktay, Ozan and Naumann, Tristan and Nori, Aditya and Alvarez-Valle, Javier},
booktitle={CVPR},
year={2023}
}
License
This model is released under the MIT License.