| --- |
| license: mit |
| tags: |
| - medical-imaging |
| - chest-x-ray |
| - temporal-analysis |
| - interval-change |
| - radiology |
| - vision-language |
| language: |
| - en |
| library_name: pytorch |
| pipeline_tag: image-feature-extraction |
| --- |
| |
| # TILA: Temporal Inversion-aware Learning and Alignment |
|
|
| **[Temporal Inversion for Learning Interval Change in Chest X-Rays](http://arxiv.org/abs/2604.04563)** |
| *Accepted at CVPR 2026* |
|
|
| TILA is a vision-language framework that uses *temporal inversion* (reversing the order of an image pair) as a supervisory signal to enhance the sensitivity of temporal vision-language models to directional change in chest X-rays. Given a current and a prior radiograph, TILA can: |
|
|
| 1. **Extract temporal-aware image embeddings** (128-dim) that capture both the static anatomy and the interval change between the two images. |
| 2. **Encode radiology text** into the same 128-dim space for zero-shot classification via image-text similarity. |
| 3. **Predict interval change** (binary: change vs. no change) using a lightweight classification head. |
|
|
| The image encoder is based on the [BioViL-T](https://huggingface.co/microsoft/BiomedVLP-BioViL-T) architecture (ResNet-50 + Vision Transformer temporal pooler), and the text encoder is CXR-BERT, both fine-tuned with temporal inversion-aware alignment. |
|
|
| ## Quick Start |
|
|
| ### Installation |
|
|
| ```bash |
| # Quote the specifiers so the shell does not treat ">" as a redirect |
| pip install "torch>=2.0" "torchvision>=0.15" "timm>=0.9" "transformers>=4.30" "safetensors>=0.4" pillow opencv-python numpy |
| ``` |
|
|
| ### Load Model and Processor |
|
|
| ```python |
| import torch |
| from transformers import AutoModel |
| |
| # Load from HuggingFace Hub |
| model = AutoModel.from_pretrained("lukeingawesome/TILA", trust_remote_code=True) |
| model = model.to("cuda", dtype=torch.bfloat16) |
| ``` |
|
|
| Or load locally: |
|
|
| ```python |
| from model import TILAModel |
| model = TILAModel.from_pretrained("model.safetensors") |
| model = model.to("cuda", dtype=torch.bfloat16) |
| ``` |
|
|
| ```python |
| from processor import TILAProcessor |
| |
| # Processor handles everything: raw image → model-ready tensor |
| processor = TILAProcessor(dtype=torch.bfloat16, device="cuda") |
| ``` |
|
|
| ### Extract Embeddings |
|
|
| ```python |
| current = processor("current_cxr.png") # accepts file paths, numpy arrays, or PIL images |
| previous = processor("previous_cxr.png") |
| |
| # 128-dim L2-normalized embeddings |
| embeddings = model.get_embeddings(current, previous) |
| ``` |
|
|
| The processor automatically applies medical image preprocessing (windowing, black padding removal, resize) followed by model transforms (center crop to 448x448, expand to 3 channels). If your images are already preprocessed, skip the medical preprocessing: |
|
|
| ```python |
| processor = TILAProcessor(raw_preprocess=False, dtype=torch.bfloat16, device="cuda") |
| ``` |
|
|
| The embeddings encode both the current image state and the temporal difference from the prior. |
| They can be used for retrieval, similarity search, or as features for downstream tasks. |
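|
| For instance, retrieval reduces to a dot product: since the embeddings are L2-normalized, the dot product equals cosine similarity. A minimal sketch (with random unit vectors standing in for real TILA embeddings, and the helper `rank_by_similarity` being illustrative, not part of this repo):
|
| ```python
| import numpy as np
|
| def rank_by_similarity(query, bank):
|     """Rank a bank of embeddings by cosine similarity to a query.
|
|     Because the embeddings are L2-normalized, the dot product
|     is exactly the cosine similarity.
|     """
|     sims = bank @ query          # (N,) cosine similarities
|     order = np.argsort(-sims)    # indices, most similar first
|     return order, sims[order]
|
| # Toy stand-ins for stored 128-dim TILA embeddings
| rng = np.random.default_rng(0)
| bank = rng.normal(size=(5, 128))
| bank /= np.linalg.norm(bank, axis=1, keepdims=True)
|
| # A query that is a slightly perturbed copy of bank[2]
| query = bank[2] + 0.02 * rng.normal(size=128)
| query /= np.linalg.norm(query)
|
| order, sims = rank_by_similarity(query, bank)
| # order[0] == 2: the near-duplicate ranks first
| ```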
|
|
| ### Encode Text |
|
|
| ```python |
| text_emb = model.encode_text([ |
| "Improved pulmonary edema.", |
| "Stable pulmonary edema.", |
| "Worsening pulmonary edema.", |
| ]) |
| |
| # Zero-shot classification via image-text similarity |
| similarities = embeddings @ text_emb.T # [1, 3] |
| prediction = similarities.argmax(dim=1) # 0=improving, 1=stable, 2=worsening |
| ``` |
|
|
| ### Predict Interval Change |
|
|
| ```python |
| result = model.get_interval_change_prediction(current, previous, mode="bestf1") |
| |
| print(result["probabilities"]) # Raw change probability |
| print(result["predictions"]) # Binary: 0 = no change, 1 = change |
| print(result["threshold"]) # Threshold used |
| ``` |
|
|
| Three threshold modes are available: |
|
|
| | Mode | Threshold | Description | |
| |------|-----------|-------------| |
| | `"bestf1"` | 0.29 | Maximizes F1 score (balanced sensitivity/specificity) | |
| | `"default"` | 0.50 | Standard sigmoid cutoff | |
| | `"spec95"` | 0.64 | Targets 95% specificity (conservative, fewer false positives) | |
|
|
| ### CLI Example |
|
|
| ```bash |
| python inference.py \ |
| --checkpoint model.safetensors \ |
| --current_image /path/to/current.png \ |
| --previous_image /path/to/previous.png |
| ``` |
|
|
|
|
| ## Preprocessing Raw Images |
|
|
| > **Note:** `TILAProcessor` applies this preprocessing for you by default. Run it as a separate step only if you bypass the processor or construct it with `raw_preprocess=False`. |
|
|
| If your chest X-rays are raw (e.g., DICOM-derived PNGs with varying bit depths, black borders, or 16-bit depth), preprocess them first: |
|
|
| ```python |
| import cv2 |
| from preprocess import preprocess_image |
| |
| img = preprocess_image("raw_cxr.png") |
| cv2.imwrite("preprocessed.png", img) |
| ``` |
|
|
| The pipeline applies: |
| 1. **Read as-is**: preserves the original bit depth (supports 8-bit and 16-bit PNGs) |
| 2. **Windowing**: clips to `mean +/- 2*std`, normalizes to [0, 1] |
| 3. **Black padding removal**: contour-based crop |
| 4. **Aspect-ratio-preserving resize**: longest side to 512px (configurable) |
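|
| The windowing step (2) can be sketched in a few lines of NumPy; this is a hedged approximation of what `preprocess.py` does, not the exact implementation:
|
| ```python
| import numpy as np
|
| def window_normalize(img):
|     """Clip intensities to mean +/- 2*std and rescale to [0, 1].
|
|     Accepts 8-bit or 16-bit grayscale arrays, since the read step
|     preserves the original bit depth.
|     """
|     img = img.astype(np.float32)
|     lo = img.mean() - 2 * img.std()
|     hi = img.mean() + 2 * img.std()
|     img = np.clip(img, lo, hi)
|     return (img - lo) / (hi - lo + 1e-8)  # epsilon guards flat images
|
| # 16-bit toy image with one bright outlier (e.g., a burned-in marker)
| raw = np.array([[0, 1000, 2000], [3000, 4000, 65535]], dtype=np.uint16)
| win = window_normalize(raw)
| # The outlier is clipped to the window before rescaling to [0, 1]
| ```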
|
|
| ```bash |
| # CLI usage |
| python preprocess.py --input raw.png --output preprocessed.png |
| ``` |
|
|
| If your images are already preprocessed (contrast-normalized, cropped, resized grayscale PNGs), you can skip this step and feed them directly to the model. |
|
|
| ## Input Format |
|
|
| - **Image format**: Grayscale chest X-ray (PNG, JPEG) |
| - **Model input**: Resize to 512px (shorter side), center crop to 448x448, repeat to 3 channels (handled by the transform in `inference.py`) |
| - **Pair**: Current (follow-up) image + Previous (baseline) image of the same patient |
| - **Dtype**: `torch.bfloat16` recommended on GPU, `torch.float32` on CPU |
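|
| The model-input transform above can be sketched with standard torchvision components; the exact transform lives in `inference.py`, so treat this as an illustrative approximation:
|
| ```python
| import torch
| from PIL import Image
| from torchvision import transforms
|
| tfm = transforms.Compose([
|     transforms.Resize(512),                       # shorter side -> 512 px
|     transforms.CenterCrop(448),                   # center crop to 448x448
|     transforms.Grayscale(num_output_channels=3),  # repeat gray channel x3
|     transforms.ToTensor(),                        # HWC [0,255] -> CHW float [0,1]
| ])
|
| img = Image.new("L", (600, 800))  # stand-in grayscale CXR (width x height)
| x = tfm(img)                      # tensor of shape (3, 448, 448)
| ```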
|
|
| ## Files |
|
|
| | File | Description | |
| |------|-------------| |
| | `model.safetensors` | Model weights (613 MB, image + text + classifier) | |
| | `config.json` | Model configuration (for AutoModel support) | |
| | `configuration_tila.py` | TILAConfig class | |
| | `model.py` | Self-contained model architecture | |
| | `processor.py` | Image processor (raw image → model-ready tensor) | |
| | `preprocess.py` | Medical image preprocessing utilities | |
| | `inference.py` | Example inference script | |
|
|
| ## Citation |
|
|
| If you use this model, please cite: |
|
|
| ```bibtex |
| @article{ko2026temporal, |
| title={Temporal Inversion for Learning Interval Change in Chest X-Rays}, |
| author={Ko, Hanbin and Jeon, Kyeongmin and Choi, Doowoong and Park, Chang Min}, |
| journal={arXiv preprint arXiv:2604.04563}, |
| year={2026} |
| } |
| ``` |
|
|
| ## Acknowledgements |
|
|
| This model builds upon [BioViL-T](https://huggingface.co/microsoft/BiomedVLP-BioViL-T) by Microsoft Research: |
|
|
| ```bibtex |
| @inproceedings{bannur2023biovilt, |
| title={Learning to Exploit Temporal Structure for Biomedical Vision-Language Processing}, |
| author={Bannur, Shruthi and Hyland, Stephanie and Liu, Qianchu and Perez-Garcia, Fernando and Oktay, Ozan and Naumann, Tristan and Nori, Aditya and Alvarez-Valle, Javier}, |
| booktitle={CVPR}, |
| year={2023} |
| } |
| ``` |
|
|
| ## License |
|
|
| This model is released under the [MIT License](LICENSE) following BioViL-T. |
|
|