TILA: Temporal Inversion-aware Learning and Alignment

Temporal Inversion for Learning Interval Change in Chest X-Rays

Accepted at CVPR 2026

TILA is a vision-language framework that uses temporal inversion (reversing the order of an image pair) as a supervisory signal to make temporal vision-language models more sensitive to the direction of change in chest X-rays. Given a current and a prior radiograph, TILA can:

  1. Extract temporal-aware image embeddings (128-dim) that capture both the static anatomy and the interval change between the two images.
  2. Encode radiology text into the same 128-dim space for zero-shot classification via image-text similarity.
  3. Predict interval change (binary: change vs. no change) using a lightweight classification head.

The image encoder is based on the BioViL-T architecture (ResNet-50 + Vision Transformer temporal pooler), and the text encoder is CXR-BERT, both fine-tuned with temporal inversion-aware alignment.

Quick Start

Installation

pip install "torch>=2.0" "torchvision>=0.15" "timm>=0.9" "transformers>=4.30" "safetensors>=0.4" pillow opencv-python numpy

Load Model and Processor

import torch
from transformers import AutoModel

# Load from HuggingFace Hub
model = AutoModel.from_pretrained("lukeingawesome/TILA", trust_remote_code=True)
model = model.to("cuda", dtype=torch.bfloat16)

Or load locally:

from model import TILAModel
from processor import TILAProcessor

model = TILAModel.from_pretrained("model.safetensors")
model = model.to("cuda", dtype=torch.bfloat16)

# Processor handles everything: raw image -> model-ready tensor
processor = TILAProcessor(dtype=torch.bfloat16, device="cuda")

Extract Embeddings

current = processor("current_cxr.png")    # accepts file paths, numpy arrays, or PIL images
previous = processor("previous_cxr.png")

# 128-dim L2-normalized embeddings
embeddings = model.get_embeddings(current, previous)

The processor automatically applies medical image preprocessing (windowing, black padding removal, resize) followed by model transforms (center crop to 448x448, expand to 3 channels). If your images are already preprocessed, skip the medical preprocessing:

processor = TILAProcessor(raw_preprocess=False, dtype=torch.bfloat16, device="cuda")

The embeddings encode both the current image state and the temporal difference from the prior. They can be used for retrieval, similarity search, or as features for downstream tasks.
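
As a sketch of retrieval with these L2-normalized embeddings, a plain dot product already gives cosine similarity. The embedding bank below is random stand-in data, not actual `model.get_embeddings` output:

```python
import numpy as np

def retrieve_top_k(query, bank, k=3):
    """Return indices of the k most similar embeddings.
    Embeddings are L2-normalized, so a dot product is cosine similarity."""
    sims = bank @ query            # [N] similarity scores
    return np.argsort(-sims)[:k]   # highest similarity first

# Hypothetical 128-dim embedding bank (stand-in for stored study embeddings)
rng = np.random.default_rng(0)
bank = rng.normal(size=(10, 128))
bank /= np.linalg.norm(bank, axis=1, keepdims=True)

query = bank[4]  # an exact match should rank first
top = retrieve_top_k(query, bank, k=3)
print(top[0])  # 4
```

In practice the bank would hold embeddings from `model.get_embeddings` (converted to float32 NumPy arrays) for each stored study pair.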

Encode Text

text_emb = model.encode_text([
    "Improved pulmonary edema.",
    "Stable pulmonary edema.",
    "Worsening pulmonary edema.",
])

# Zero-shot classification via image-text similarity
similarities = embeddings @ text_emb.T  # [1, 3]
prediction = similarities.argmax(dim=1)  # 0=improving, 1=stable, 2=worsening
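
To turn the raw similarity scores into a label and a probability-like distribution, a small sketch (the label names and the softmax step are illustrative conveniences, not part of the model API; the similarity values are stand-ins):

```python
import numpy as np

labels = ["improving", "stable", "worsening"]

# Stand-in similarity scores (in practice: embeddings @ text_emb.T)
similarities = np.array([[0.12, 0.31, 0.57]])

# Softmax turns raw similarities into a probability-like distribution
exp = np.exp(similarities - similarities.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

pred = int(similarities.argmax(axis=1)[0])
print(labels[pred])  # worsening
```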

Predict Interval Change

result = model.get_interval_change_prediction(current, previous, mode="bestf1")

print(result["probabilities"])  # Raw change probability
print(result["predictions"])    # Binary: 0 = no change, 1 = change
print(result["threshold"])      # Threshold used

Three threshold modes are available:

| Mode        | Threshold | Description                                                   |
|-------------|-----------|---------------------------------------------------------------|
| `"bestf1"`  | 0.29      | Maximizes F1 score (balanced sensitivity/specificity)         |
| `"default"` | 0.50      | Standard sigmoid cutoff                                       |
| `"spec95"`  | 0.64      | Targets 95% specificity (conservative, fewer false positives) |
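
Applying these operating points yourself, e.g. to probabilities you have already computed, can be sketched as (`THRESHOLDS` is a convenience map built from the table above, not part of the model API):

```python
import numpy as np

# Operating points from the table above
THRESHOLDS = {"bestf1": 0.29, "default": 0.50, "spec95": 0.64}

def apply_threshold(probabilities, mode="bestf1"):
    """Binarize change probabilities at the chosen operating point."""
    t = THRESHOLDS[mode]
    return (np.asarray(probabilities) >= t).astype(int)

probs = np.array([0.20, 0.40, 0.70])
print(apply_threshold(probs, "bestf1"))   # [0 1 1]
print(apply_threshold(probs, "spec95"))   # [0 0 1]
```

Lower thresholds trade specificity for sensitivity: the same 0.40 probability counts as "change" under `"bestf1"` but not under `"default"` or `"spec95"`.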

CLI Example

python inference.py \
    --checkpoint model.safetensors \
    --current_image /path/to/current.png \
    --previous_image /path/to/previous.png

Preprocessing Raw Images

Note: when the processor is created with default settings (`raw_preprocess=True`), this preprocessing is applied for you. Run the standalone script below only if you want to preprocess images as a separate step and save them to disk.

If your chest X-rays are raw (e.g., DICOM-derived PNGs with 8- or 16-bit depth or black borders), preprocess them first:

import cv2
from preprocess import preprocess_image

img = preprocess_image("raw_cxr.png")
cv2.imwrite("preprocessed.png", img)

The pipeline applies:

  1. Read as-is: preserves original bit depth (supports 8-bit and 16-bit PNGs)
  2. Windowing: clips to mean +/- 2*std, normalizes to [0, 1]
  3. Black padding removal: contour-based crop
  4. Aspect-ratio-preserving resize: longest side to 512px (configurable)
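
Steps 2 and 4 can be sketched as follows (a simplified stand-in, not the actual preprocess.py code):

```python
import numpy as np

def window(img, n_std=2.0):
    """Step 2: clip to mean +/- n_std * std, then normalize to [0, 1]."""
    img = img.astype(np.float32)
    lo = img.mean() - n_std * img.std()
    hi = img.mean() + n_std * img.std()
    clipped = np.clip(img, lo, hi)
    return (clipped - lo) / max(hi - lo, 1e-8)

def resize_longest_side(shape, target=512):
    """Step 4: aspect-ratio-preserving output size, longest side -> target."""
    h, w = shape
    scale = target / max(h, w)
    return round(h * scale), round(w * scale)

# Stand-in 16-bit image
img = np.random.default_rng(0).integers(0, 65535, size=(600, 480)).astype(np.uint16)
out = window(img)
print(resize_longest_side(img.shape))  # (512, 410)
```
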

# CLI usage
python preprocess.py --input raw.png --output preprocessed.png

If your images are already preprocessed (contrast-normalized, cropped, resized grayscale PNGs), you can skip this step and feed them directly to the model.

Input Format

  • Image format: Grayscale chest X-ray (PNG, JPEG)
  • Model input: Resize to 512px (shorter side), center crop to 448x448, repeat to 3 channels (handled by the transform in inference.py)
  • Pair: Current (follow-up) image + Previous (baseline) image of the same patient
  • Dtype: torch.bfloat16 recommended on GPU, torch.float32 on CPU

Files

| File | Description |
|------|-------------|
| model.safetensors | Model weights (613 MB, image + text + classifier) |
| config.json | Model configuration (for AutoModel support) |
| configuration_tila.py | TILAConfig class |
| model.py | Self-contained model architecture |
| processor.py | Image processor (raw image -> model-ready tensor) |
| preprocess.py | Medical image preprocessing utilities |
| inference.py | Example inference script |

Citation

If you use this model, please cite:

@inproceedings{ko2026tila,
  title={Temporal Inversion for Learning Interval Change in Chest X-Rays},
  author={Ko, Hanbin and Jeon, Kyeongmin and Choi, Doowoong and Park, Chang Min},
  booktitle={CVPR},
  year={2026},
  url={http://arxiv.org/abs/2604.04563}
}

Acknowledgements

This model builds upon BioViL-T by Microsoft Research:

@inproceedings{bannur2023biovilt,
  title={Learning to Exploit Temporal Structure for Biomedical Vision-Language Processing},
  author={Bannur, Shruthi and Hyland, Stephanie and Liu, Qianchu and Perez-Garcia, Fernando and Oktay, Ozan and Naumann, Tristan and Nori, Aditya and Alvarez-Valle, Javier},
  booktitle={CVPR},
  year={2023}
}

License

This model is released under the MIT License.
