"""
TILA (Temporal Inversion-aware Learning and Alignment) - Model Architecture
Paper: "Temporal Inversion for Learning Interval Change in Chest X-Rays" (CVPR 2026)
http://arxiv.org/abs/2604.04563
This module contains the full model architecture for TILA, built on top of the
BioViL-T (ResNet-50 + Vision Transformer pooler) backbone and CXR-BERT text encoder.
Dependencies:
    pip install torch torchvision timm transformers safetensors huggingface_hub
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from functools import partial
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath, Mlp, trunc_normal_
from torchvision.models.resnet import Bottleneck, conv1x1
# ──────────────────────────────────────────────────────────────────────────────
# Output types
# ──────────────────────────────────────────────────────────────────────────────
@dataclass
class ImageModelOutput:
img_embedding: torch.Tensor
patch_embeddings: torch.Tensor
projected_global_embedding: torch.Tensor
class_logits: Optional[torch.Tensor]
projected_patch_embeddings: torch.Tensor
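# Reference shapes for the default 448x448 input (14x14 feature grid); these
# follow from TILAImageEncoder.forward below:
#   img_embedding:              [B, 512]          pooled fused features
#   patch_embeddings:           [B, 512, 14, 14]  current + difference channels
#   projected_global_embedding: [B, 128]
#   projected_patch_embeddings: [B, 128, 14, 14]
#   class_logits:               None (unused here)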
# ──────────────────────────────────────────────────────────────────────────────
# ResNet-50 backbone
# ──────────────────────────────────────────────────────────────────────────────
class ResNet(nn.Module):
"""Standard ResNet-50 (torchvision-compatible) without the final FC layer in forward."""
def __init__(
self,
layers: Sequence[int] = (3, 4, 6, 3),
num_classes: int = 1000,
zero_init_residual: bool = False,
replace_stride_with_dilation: Optional[Sequence[bool]] = None,
):
super().__init__()
block = Bottleneck
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)  # unused in forward; kept for torchvision weight compatibility
# Weight init
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck) and m.bn3.weight is not None:
nn.init.constant_(m.bn3.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
nn.BatchNorm2d(planes * block.expansion),
)
layers = [block(self.inplanes, planes, stride, downsample, 1, 64, previous_dilation, nn.BatchNorm2d)]
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, dilation=self.dilation, norm_layer=nn.BatchNorm2d))
return nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
return x # patch features [B, 2048, H, W]
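# Shape sanity check (illustrative sketch, not executed on import):
#   >>> backbone = ResNet()
#   >>> backbone(torch.randn(1, 3, 448, 448)).shape   # stride-32 encoder
#   torch.Size([1, 2048, 14, 14])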
# ──────────────────────────────────────────────────────────────────────────────
# Vision Transformer Pooler (temporal attention)
# ──────────────────────────────────────────────────────────────────────────────
class SinePositionEmbedding:
def __init__(self, embedding_dim: int = 64, temperature: int = 10000,
normalize: bool = False, scale: Optional[float] = None):
self.embedding_dim = embedding_dim
self.temperature = temperature
self.normalize = normalize
self.scale = scale if scale is not None else 2 * math.pi
def __call__(self, mask: torch.Tensor) -> torch.Tensor:
B, H, W = mask.shape
y_embed = mask.cumsum(1, dtype=torch.float32)
x_embed = mask.cumsum(2, dtype=torch.float32)
if self.normalize:
y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
return torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)
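# Usage sketch: VisionTransformerPooler below calls this once at init with an
# all-ones mask, yielding a fixed sinusoidal grid of shape [B, H*W, 2*embedding_dim]:
#   >>> pe = SinePositionEmbedding(embedding_dim=128, normalize=True)
#   >>> pe(torch.ones(1, 14, 14)).shape
#   torch.Size([1, 196, 256])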
class MultiHeadAttentionLayer(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False,
attn_drop: float = 0.0, proj_drop: float = 0.0):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, k, q, v):
B, N, C = v.shape
h = self.num_heads
wq = self.proj_q(q).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
wk = self.proj_k(k).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
wv = self.proj_v(v).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
attn = (wq @ wk.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
o = (attn @ wv).transpose(1, 2).reshape(B, N, C)
return self.proj_drop(self.proj(o))
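# Note the (k, q, v) argument order; Block below passes the same tensor for all
# three, i.e. plain self-attention:
#   >>> attn = MultiHeadAttentionLayer(dim=256, num_heads=8)
#   >>> tokens = torch.randn(2, 196, 256)
#   >>> attn(tokens, tokens, tokens).shape
#   torch.Size([2, 196, 256])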
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=1.0, qkv_bias=False,
drop=0.0, attn_drop=0.0, drop_path=0.0,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = MultiHeadAttentionLayer(dim, num_heads, qkv_bias, attn_drop, drop)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
def forward(self, x, pos_and_type_embed=None):
x_norm = self.norm1(x)
if pos_and_type_embed is not None:
x_norm = x_norm + pos_and_type_embed
x = x + self.drop_path(self.attn(x_norm, x_norm, x_norm))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class VisionTransformerPooler(nn.Module):
def __init__(self, input_dim: int, grid_shape: Tuple[int, int],
num_heads: int = 8, num_blocks: int = 3,
norm_layer=partial(nn.LayerNorm, eps=1e-6)):
super().__init__()
block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1.0,
drop=0.10, attn_drop=0.10, drop_path=0.25,
act_layer=nn.GELU, norm_layer=norm_layer)
self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
self.norm_post = norm_layer(input_dim)
self.grid_shape = grid_shape
self.num_patches = grid_shape[0] * grid_shape[1]
self.type_embed = nn.Parameter(torch.zeros(2, 1, input_dim))
trunc_normal_(self.type_embed, std=0.02)
self.pos_drop = nn.Dropout(p=0.10)
pos_embed = SinePositionEmbedding(input_dim // 2, normalize=True)(
torch.ones([1, grid_shape[0], grid_shape[1]]))
self.register_buffer("pos_embed", pos_embed, persistent=False)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, current_image, previous_image=None):
B, C, H, W = current_image.shape
if previous_image is not None:
prev = previous_image.view(B, C, H * W).transpose(1, 2)
else:
prev = None
cur = current_image.view(B, C, H * W).transpose(1, 2)
pos = self.pos_embed.repeat(B, 1, 1)
L = cur.shape[1]
type_emb = self.type_embed[0].expand(B, L, -1)
if prev is not None:
x = torch.cat((cur, prev), dim=1)
pos = torch.cat((pos, pos), dim=1)
type_emb = torch.cat((type_emb, self.type_embed[1].expand(B, L, -1)), dim=1)
else:
x = cur
pos_type = pos + type_emb
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x, pos_type)
x = self.norm_post(x)
return x[:, :self.num_patches].transpose(1, 2).view(B, C, H, W)
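# Fusion sketch: with a previous image the 2*196 tokens attend jointly, and only
# the current-image half is returned as a difference-aware feature map:
#   >>> pooler = VisionTransformerPooler(input_dim=256, grid_shape=(14, 14))
#   >>> cur = torch.randn(2, 256, 14, 14); prev = torch.randn(2, 256, 14, 14)
#   >>> pooler(cur, prev).shape
#   torch.Size([2, 256, 14, 14])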
# ──────────────────────────────────────────────────────────────────────────────
# Multi-image encoder (temporal)
# ──────────────────────────────────────────────────────────────────────────────
class MLP(nn.Module):
"""Projection MLP (1x1 conv based)."""
def __init__(self, input_dim, output_dim, hidden_dim=None, use_1x1_convs=False):
super().__init__()
if use_1x1_convs and hidden_dim is not None:
self.model = nn.Sequential(
nn.Conv2d(input_dim, hidden_dim, 1, bias=False),
nn.BatchNorm2d(hidden_dim),
nn.ReLU(inplace=True),
nn.Conv2d(hidden_dim, output_dim, 1, bias=True),
)
elif hidden_dim is not None:
self.model = nn.Sequential(
nn.Linear(input_dim, hidden_dim, bias=False),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Linear(hidden_dim, output_dim, bias=True),
)
else:
self.model = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.model(x)
class MultiImageEncoder(nn.Module):
"""BioViL-T style multi-image encoder: ResNet-50 backbone + ViT temporal pooler."""
def __init__(self):
super().__init__()
self.encoder = ResNet()
backbone_out_dim = 2048 # ResNet-50 output channels
output_dim = 256
self.backbone_to_vit = nn.Conv2d(backbone_out_dim, output_dim, 1, bias=False)
self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=(14, 14))
self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
trunc_normal_(self.missing_previous_emb, std=0.02)
def forward(self, current_image, previous_image=None, return_patch_embeddings=False):
B = current_image.shape[0]
if previous_image is not None:
x = torch.cat([current_image, previous_image], dim=0)
x = self.encoder(x)
x = self.backbone_to_vit(x)
patch_x, patch_prev = x[:B], x[B:]
diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_prev)
else:
x = self.encoder(current_image)
patch_x = self.backbone_to_vit(x)
            _, _, H, W = patch_x.shape
            diff_x = self.missing_previous_emb.repeat(B, 1, H, W)
patch_fused = torch.cat([patch_x, diff_x], dim=1) # [B, 512, H, W]
avg_pooled = torch.flatten(F.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)
if return_patch_embeddings:
return patch_fused, avg_pooled
return avg_pooled
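# Pooled-output sketch (256 current + 256 difference channels, then global pooling):
#   >>> enc = MultiImageEncoder()
#   >>> enc(torch.randn(1, 3, 448, 448), torch.randn(1, 3, 448, 448)).shape
#   torch.Size([1, 512])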
class TILAImageEncoder(nn.Module):
"""Full TILA image encoder: MultiImageEncoder + projection head.
    Outputs 128-dim projected embeddings; TILAModel.get_embeddings L2-normalizes
    them for CLIP-style retrieval.
"""
JOINT_FEATURE_SIZE = 128
def __init__(self):
super().__init__()
self.encoder = MultiImageEncoder()
self.projector = MLP(
input_dim=512, # patch_x (256) + diff_x (256)
output_dim=self.JOINT_FEATURE_SIZE,
hidden_dim=self.JOINT_FEATURE_SIZE,
use_1x1_convs=True,
)
def forward(self, current_image, previous_image=None):
patch_fused, pooled = self.encoder(current_image, previous_image, return_patch_embeddings=True)
projected_patch = self.projector(patch_fused)
projected_global = torch.mean(projected_patch, dim=(2, 3))
return ImageModelOutput(
img_embedding=pooled,
patch_embeddings=patch_fused,
class_logits=None,
projected_patch_embeddings=projected_patch,
projected_global_embedding=projected_global,
)
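# Output sketch for a single image (the learned missing-previous embedding fills
# the difference half):
#   >>> enc = TILAImageEncoder()
#   >>> out = enc(torch.randn(1, 3, 448, 448))
#   >>> out.projected_global_embedding.shape, out.projected_patch_embeddings.shape
#   (torch.Size([1, 128]), torch.Size([1, 128, 14, 14]))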
# ──────────────────────────────────────────────────────────────────────────────
# Text encoder (BioViL-T CXR-BERT + projection)
# ──────────────────────────────────────────────────────────────────────────────
TEXT_MODEL_NAME = "microsoft/BiomedVLP-BioViL-T"
class TextEncoder(nn.Module):
"""CXR-BERT text encoder with a projection head to 128-dim.
Loads the pretrained BioViL-T text model and adds a LayerNorm + Linear
projection from 768-dim CLS embeddings to 128-dim joint space.
"""
def __init__(self):
super().__init__()
from transformers import AutoConfig, AutoModel
config = AutoConfig.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True)
self.model = AutoModel.from_pretrained(
TEXT_MODEL_NAME, config=config, trust_remote_code=True,
)
self.projection = nn.Sequential(
nn.LayerNorm(config.hidden_size),
nn.Linear(config.hidden_size, 128),
)
def forward(self, text_inputs: dict) -> torch.Tensor:
"""Encode tokenized text to 128-dim embeddings.
Args:
text_inputs: Dict from tokenizer (input_ids, attention_mask, etc.)
Returns:
Projected CLS embeddings [B, 128]
"""
outputs = self.model(**text_inputs)
cls_emb = outputs.last_hidden_state[:, 0, :]
if cls_emb.dtype != next(self.projection.parameters()).dtype:
cls_emb = cls_emb.to(next(self.projection.parameters()).dtype)
return self.projection(cls_emb)
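# Usage sketch (downloads the pretrained weights on first use; mirrors the
# tokenizer call in TILAModel.encode_text below):
#   >>> from transformers import AutoTokenizer
#   >>> tok = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, padding_side="right")
#   >>> enc = TextEncoder()
#   >>> enc(dict(tok(["Improved pulmonary edema."], return_tensors="pt"))).shape
#   torch.Size([1, 128])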
# ──────────────────────────────────────────────────────────────────────────────
# Interval change classifier head
# ──────────────────────────────────────────────────────────────────────────────
class IntervalChangeClassifier(nn.Module):
"""Binary classifier head for interval change detection.
Takes 128-dim projected embeddings and outputs a change probability.
"""
def __init__(self):
super().__init__()
self.head = nn.Sequential(
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 1),
)
def forward(self, embedding: torch.Tensor) -> torch.Tensor:
"""Returns logit (pre-sigmoid). Apply torch.sigmoid() to get probability."""
return self.head(embedding).squeeze(-1)
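# Usage sketch: operates on the 128-dim projected global embedding:
#   >>> clf = IntervalChangeClassifier()
#   >>> torch.sigmoid(clf(torch.randn(4, 128))).shape
#   torch.Size([4])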
# ──────────────────────────────────────────────────────────────────────────────
# Full model wrapper
# ──────────────────────────────────────────────────────────────────────────────
try:
from transformers import PreTrainedModel
from configuration_tila import TILAConfig
_BASE_CLASS = PreTrainedModel
_HAS_TRANSFORMERS = True
except ImportError:
_BASE_CLASS = nn.Module
_HAS_TRANSFORMERS = False
class TILAModel(_BASE_CLASS):
"""TILA model with image encoder, text encoder, and interval change classifier.
Usage:
# Load from local safetensors
model = TILAModel.from_pretrained("model.safetensors")
# Load via AutoModel (requires config.json + trust_remote_code)
from transformers import AutoModel
model = AutoModel.from_pretrained("lukeingawesome/TILA", trust_remote_code=True)
# Get 128-dim image embeddings
emb = model.get_embeddings(current_img, previous_img)
# Get 128-dim text embeddings
text_emb = model.encode_text(["Improved pulmonary edema."])
# Predict interval change
result = model.get_interval_change_prediction(current_img, previous_img)
"""
if _HAS_TRANSFORMERS:
config_class = TILAConfig
def __init__(self, config=None):
if _HAS_TRANSFORMERS and config is None:
config = TILAConfig()
if _HAS_TRANSFORMERS:
super().__init__(config)
else:
super().__init__()
self.image_encoder = TILAImageEncoder()
self.text_encoder = TextEncoder()
self.change_classifier = IntervalChangeClassifier()
@torch.no_grad()
def encode_text(self, texts: list) -> torch.Tensor:
"""Encode text prompts to 128-dim normalized embeddings.
Args:
texts: List of text strings
Returns:
Normalized text embeddings [N, 128]
"""
from transformers import AutoTokenizer
if not hasattr(self, '_tokenizer'):
self._tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, padding_side="right")
device = next(self.parameters()).device
tokens = self._tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
tokens = {k: v.to(device) for k, v in tokens.items()}
self.eval()
emb = self.text_encoder(tokens)
return F.normalize(emb.float(), p=2, dim=1)
@torch.no_grad()
def get_embeddings(
self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Extract 128-dim projected global embeddings from a pair of chest X-rays.
Args:
current_image: Current CXR tensor [B, 3, 448, 448]
previous_image: Previous CXR tensor [B, 3, 448, 448] (optional)
Returns:
Normalized 128-dim embeddings [B, 128]
"""
self.eval()
out = self.image_encoder(current_image, previous_image)
return F.normalize(out.projected_global_embedding.float(), p=2, dim=1)
# Thresholds calibrated on validation set (AUC=0.7558)
THRESHOLDS = {
"default": 0.5000, # Standard sigmoid midpoint
"bestf1": 0.2886, # Youden's J β€” best F1=0.7210, sens=0.7798, spec=0.6166
"spec95": 0.6370, # Specificity ~0.95 β€” sens=0.1752, spec=0.9502
}
@torch.no_grad()
def get_interval_change_prediction(
self,
current_image: torch.Tensor,
previous_image: torch.Tensor,
mode: str = "bestf1",
    ) -> dict:
"""Predict interval change between two chest X-rays.
Args:
current_image: Current CXR tensor [B, 3, 448, 448]
previous_image: Previous CXR tensor [B, 3, 448, 448]
            mode: Threshold mode - one of:
"default" : threshold=0.50 (standard sigmoid cutoff)
"bestf1" : threshold=0.29 (maximizes F1, balanced sens/spec)
"spec95" : threshold=0.64 (targets 95% specificity, conservative)
Returns:
Dict with keys:
"probabilities": raw change probabilities [B]
"predictions": binary predictions [B] (0=no change, 1=change)
"threshold": threshold used (float)
"""
if mode not in self.THRESHOLDS:
raise ValueError(f"mode must be one of {list(self.THRESHOLDS.keys())}, got '{mode}'")
self.eval()
out = self.image_encoder(current_image, previous_image)
logits = self.change_classifier(out.projected_global_embedding)
probs = torch.sigmoid(logits.float())
threshold = self.THRESHOLDS[mode]
preds = (probs >= threshold).long()
return {"probabilities": probs, "predictions": preds, "threshold": threshold}
@classmethod
def from_pretrained(cls, path_or_repo: str, device: str = "cpu", **kwargs) -> "TILAModel":
"""Load model from a local file or HuggingFace Hub.
Args:
path_or_repo: Local path to model.safetensors, or HF repo ID (e.g. "lukeingawesome/TILA")
device: Device to load onto
Examples:
model = TILAModel.from_pretrained("model.safetensors")
model = TILAModel.from_pretrained("lukeingawesome/TILA")
"""
import os
        # HF's loading machinery may pass a config; pop it from kwargs and hand
        # it to the constructor below.
config = kwargs.pop("config", None)
# Determine if this is a local file or a HF repo
if os.path.isfile(path_or_repo):
safetensors_path = path_or_repo
elif os.path.isdir(path_or_repo):
safetensors_path = os.path.join(path_or_repo, "model.safetensors")
else:
from huggingface_hub import hf_hub_download
safetensors_path = hf_hub_download(
repo_id=path_or_repo,
filename="model.safetensors",
)
model = cls(config=config)
if safetensors_path.endswith(".safetensors"):
from safetensors.torch import load_file
state_dict = load_file(safetensors_path, device=device)
            # Some exports store BatchNorm's num_batches_tracked as shape [1];
            # squeeze back to the scalar shape load_state_dict expects.
            for k, v in list(state_dict.items()):
                if v.dim() == 1 and v.shape[0] == 1 and "num_batches_tracked" in k:
                    state_dict[k] = v.squeeze(0)
else:
state_dict = torch.load(safetensors_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict, strict=False)
model.eval()
return model
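if __name__ == "__main__":
    # Minimal smoke test (sketch): random tensors stand in for preprocessed CXRs,
    # and instantiating TILAModel pulls the CXR-BERT text weights, so the first
    # run needs network access.
    model = TILAModel()
    cur = torch.randn(2, 3, 448, 448)
    prev = torch.randn(2, 3, 448, 448)
    emb = model.get_embeddings(cur, prev)
    txt = model.encode_text(["Improved pulmonary edema.", "No interval change."])
    result = model.get_interval_change_prediction(cur, prev, mode="bestf1")
    print(emb.shape, txt.shape, result["probabilities"].shape, result["threshold"])
    # Expected: torch.Size([2, 128]) torch.Size([2, 128]) torch.Size([2]) 0.2886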