---
license: mit
tags:
- medical-imaging
- chest-x-ray
- temporal-analysis
- interval-change
- radiology
- vision-language
language:
- en
library_name: pytorch
pipeline_tag: image-feature-extraction
---
# TILA: Temporal Inversion-aware Learning and Alignment
**[Temporal Inversion for Learning Interval Change in Chest X-Rays](http://arxiv.org/abs/2604.04563)**
*Accepted at CVPR 2026*
TILA is a vision-language framework that uses *temporal inversion* (reversing the order of an image pair) as a supervisory signal to enhance the sensitivity of temporal vision-language models to directional change in chest X-rays. Given a current and a prior radiograph, TILA can:
1. **Extract temporal-aware image embeddings** (128-dim) that capture both the static anatomy and the interval change between the two images.
2. **Encode radiology text** into the same 128-dim space for zero-shot classification via image-text similarity.
3. **Predict interval change** (binary: change vs. no change) using a lightweight classification head.
The image encoder is based on the [BioViL-T](https://huggingface.co/microsoft/BiomedVLP-BioViL-T) architecture (ResNet-50 + Vision Transformer temporal pooler), and the text encoder is CXR-BERT, both fine-tuned with temporal inversion-aware alignment.
## Quick Start
### Installation
```bash
pip install "torch>=2.0" "torchvision>=0.15" "timm>=0.9" "transformers>=4.30" "safetensors>=0.4" pillow opencv-python numpy
```
### Load Model and Processor
```python
import torch
from transformers import AutoModel
# Load from HuggingFace Hub
model = AutoModel.from_pretrained("lukeingawesome/TILA", trust_remote_code=True)
model = model.to("cuda", dtype=torch.bfloat16)
```
Or load locally:
```python
from model import TILAModel
model = TILAModel.from_pretrained("model.safetensors")
model = model.to("cuda", dtype=torch.bfloat16)
```
```python
from processor import TILAProcessor
# Processor handles everything: raw image -> model-ready tensor
processor = TILAProcessor(dtype=torch.bfloat16, device="cuda")
```
### Extract Embeddings
```python
current = processor("current_cxr.png") # accepts file paths, numpy arrays, or PIL images
previous = processor("previous_cxr.png")
# 128-dim L2-normalized embeddings
embeddings = model.get_embeddings(current, previous)
```
The processor automatically applies medical image preprocessing (windowing, black padding removal, resize) followed by model transforms (center crop to 448x448, expand to 3 channels). If your images are already preprocessed, skip the medical preprocessing:
```python
processor = TILAProcessor(raw_preprocess=False, dtype=torch.bfloat16, device="cuda")
```
The embeddings encode both the current image state and the temporal difference from the prior.
They can be used for retrieval, similarity search, or as features for downstream tasks.
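As a minimal sketch of embedding-based retrieval (assuming a precomputed gallery of 128-dim embeddings converted to NumPy, e.g. via `.float().cpu().numpy()`; the gallery here is synthetic), nearest-neighbor search reduces to a matrix-vector product because the embeddings are L2-normalized:

```python
import numpy as np

# Hypothetical gallery: N studies, each a 128-dim L2-normalized embedding.
rng = np.random.default_rng(0)
gallery = rng.normal(size=(1000, 128))
gallery /= np.linalg.norm(gallery, axis=1, keepdims=True)

# Query embedding (in practice: the output of model.get_embeddings).
query = gallery[42] + 0.01 * rng.normal(size=128)
query /= np.linalg.norm(query)

# For unit-norm vectors, the dot product equals cosine similarity.
scores = gallery @ query
top5 = np.argsort(scores)[::-1][:5]
print(top5[0])  # index of the most similar gallery study
```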
### Encode Text
```python
text_emb = model.encode_text([
"Improved pulmonary edema.",
"Stable pulmonary edema.",
"Worsening pulmonary edema.",
])
# Zero-shot classification via image-text similarity
similarities = embeddings @ text_emb.T # [1, 3]
prediction = similarities.argmax(dim=1) # 0=improving, 1=stable, 2=worsening
```
### Predict Interval Change
```python
result = model.get_interval_change_prediction(current, previous, mode="bestf1")
print(result["probabilities"]) # Raw change probability
print(result["predictions"]) # Binary: 0 = no change, 1 = change
print(result["threshold"]) # Threshold used
```
Three threshold modes are available:
| Mode | Threshold | Description |
|------|-----------|-------------|
| `"bestf1"` | 0.29 | Maximizes F1 score (balanced sensitivity/specificity) |
| `"default"` | 0.50 | Standard sigmoid cutoff |
| `"spec95"` | 0.64 | Targets 95% specificity (conservative, fewer false positives) |
### CLI Example
```bash
python inference.py \
--checkpoint model.safetensors \
--current_image /path/to/current.png \
--previous_image /path/to/previous.png
```
## Preprocessing Raw Images
> **Note:** This preprocessing is **not** applied automatically. Run it as a separate step before model inference.
If your chest X-rays are raw (e.g., DICOM-derived PNGs with varying bit depths, black borders, or no contrast normalization), preprocess them first:
```python
import cv2
from preprocess import preprocess_image
img = preprocess_image("raw_cxr.png")
cv2.imwrite("preprocessed.png", img)
```
The pipeline applies:
1. **Read as-is** β€” preserves original bit depth (supports 8-bit and 16-bit PNGs)
2. **Windowing** β€” clips to `mean +/- 2*std`, normalizes to [0, 1]
3. **Black padding removal** β€” contour-based crop
4. **Aspect-ratio-preserving resize** β€” longest side to 512px (configurable)
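The windowing step (2) can be sketched in isolation. This is an assumption-laden approximation of the behavior described above, not the actual code of `preprocess_image`; it is scale-free, so it handles 8-bit and 16-bit inputs alike:

```python
import numpy as np

def window(img: np.ndarray, n_std: float = 2.0) -> np.ndarray:
    """Clip intensities to mean +/- n_std*std, then rescale to [0, 1]."""
    img = img.astype(np.float32)
    lo = img.mean() - n_std * img.std()
    hi = img.mean() + n_std * img.std()
    img = np.clip(img, lo, hi)
    return (img - lo) / max(hi - lo, 1e-8)

x = np.array([[0, 1000], [2000, 65535]], dtype=np.uint16)
y = window(x)
print(y.min(), y.max())  # both values lie in [0, 1]
```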
```bash
# CLI usage
python preprocess.py --input raw.png --output preprocessed.png
```
If your images are already preprocessed (contrast-normalized, cropped, resized grayscale PNGs), you can skip this step and feed them directly to the model.
## Input Format
- **Image format**: Grayscale chest X-ray (PNG, JPEG)
- **Model input**: Resize to 512px (shorter side), center crop to 448x448, repeat to 3 channels (handled by the transform in `inference.py`)
- **Pair**: Current (follow-up) image + Previous (baseline) image of the same patient
- **Dtype**: `torch.bfloat16` recommended on GPU, `torch.float32` on CPU
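The crop-and-repeat part of the model input pipeline can be sketched with NumPy (resizing is omitted for brevity, and the actual transform in `inference.py` should be preferred; the helper below is hypothetical):

```python
import numpy as np

def center_crop_3ch(img: np.ndarray, size: int = 448) -> np.ndarray:
    """Center-crop a grayscale (H, W) image to size x size and
    repeat it to 3 channels, matching the model's expected input."""
    h, w = img.shape
    top = (h - size) // 2
    left = (w - size) // 2
    crop = img[top:top + size, left:left + size]
    return np.repeat(crop[None, :, :], 3, axis=0)  # (3, size, size)

gray = np.zeros((512, 512), dtype=np.float32)  # stand-in for a resized CXR
x = center_crop_3ch(gray)
print(x.shape)  # (3, 448, 448)
```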
## Files
| File | Description |
|------|-------------|
| `model.safetensors` | Model weights (613 MB, image + text + classifier) |
| `config.json` | Model configuration (for AutoModel support) |
| `configuration_tila.py` | TILAConfig class |
| `model.py` | Self-contained model architecture |
| `processor.py` | Image processor (raw image -> model-ready tensor) |
| `preprocess.py` | Medical image preprocessing utilities |
| `inference.py` | Example inference script |
## Citation
If you use this model, please cite:
```bibtex
@article{ko2026temporal,
title={Temporal Inversion for Learning Interval Change in Chest X-Rays},
author={Ko, Hanbin and Jeon, Kyeongmin and Choi, Doowoong and Park, Chang Min},
journal={arXiv preprint arXiv:2604.04563},
year={2026}
}
```
## Acknowledgements
This model builds upon [BioViL-T](https://huggingface.co/microsoft/BiomedVLP-BioViL-T) by Microsoft Research:
```bibtex
@inproceedings{bannur2023biovilt,
title={Learning to Exploit Temporal Structure for Biomedical Vision-Language Processing},
author={Bannur, Shruthi and Hyland, Stephanie and Liu, Qianchu and Perez-Garcia, Fernando and Oktay, Ozan and Naumann, Tristan and Nori, Aditya and Alvarez-Valle, Javier},
booktitle={CVPR},
year={2023}
}
```
## License
This model is released under the [MIT License](LICENSE) following BioViL-T.