"""TILA (Temporal Inversion-aware Learning and Alignment) — model architecture.

Paper: "Temporal Inversion for Learning Interval Change in Chest X-Rays" (CVPR 2026)
http://arxiv.org/abs/2604.04563

This module defines the complete TILA model: a BioViL-T image backbone
(ResNet-50 followed by a Vision Transformer pooler) paired with the CXR-BERT
text encoder.

Dependencies:
    pip install torch torchvision timm transformers safetensors
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath, Mlp, trunc_normal_
from torchvision.models.resnet import Bottleneck, conv1x1

# ──────────────────────────────────────────────────────────────────────────────
# Output types
# ──────────────────────────────────────────────────────────────────────────────


@dataclass
class ImageModelOutput:
    """Bundle of tensors produced by the image encoder for one batch.

    Field order is part of the positional constructor signature; keep it stable.
    """

    img_embedding: torch.Tensor
    patch_embeddings: torch.Tensor
    projected_global_embedding: torch.Tensor
    class_logits: Optional[torch.Tensor]
    projected_patch_embeddings: torch.Tensor


# ──────────────────────────────────────────────────────────────────────────────
# ResNet-50 backbone
# ──────────────────────────────────────────────────────────────────────────────


class ResNet(nn.Module):
    """ResNet-50 trunk laid out like torchvision's implementation.

    ``forward`` stops after ``layer4`` and returns the spatial feature map. The
    ``avgpool`` and ``fc`` modules are still constructed (presumably so parameter
    names line up with torchvision-style checkpoints) but are never called.

    Args:
        layers: Number of Bottleneck blocks in each of the four stages.
        num_classes: Output size of the (unused in ``forward``) ``fc`` layer.
        zero_init_residual: If true, zero the last BatchNorm gamma in every Bottleneck.
        replace_stride_with_dilation: Three flags for stages 2-4; a true flag keeps
            resolution by dilating instead of striding. Defaults to all False.
    """

    def __init__(
        self,
        layers: Sequence[int] = (3, 4, 6, 3),
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        replace_stride_with_dilation: Optional[Sequence[bool]] = None,
    ):
        super().__init__()
        block = Bottleneck
        dilate_flags = (
            [False, False, False] if replace_stride_with_dilation is None else replace_stride_with_dilation
        )
        # Running state consumed and updated by _make_layer.
        self.inplanes = 64
        self.dilation = 1

        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages. Creation order is kept fixed so that RNG consumption and
        # parameter ordering match the reference layout.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=dilate_flags[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=dilate_flags[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=dilate_flags[2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        self._reset_weights(zero_init_residual)

    def _reset_weights(self, zero_init_residual: bool) -> None:
        """Kaiming-init every conv, set BN affine to (1, 0), optionally zero Bottleneck bn3."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

        if not zero_init_residual:
            return
        for module in self.modules():
            if isinstance(module, Bottleneck) and module.bn3.weight is not None:
                nn.init.constant_(module.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` Bottlenecks with ``planes`` base width."""
        prior_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation so the spatial resolution is preserved.
            self.dilation *= stride
            stride = 1

        out_channels = planes * block.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride),
                nn.BatchNorm2d(out_channels),
            )

        # Positional args follow torchvision's Bottleneck signature:
        # (inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer).
        head = block(self.inplanes, planes, stride, shortcut, 1, 64, prior_dilation, nn.BatchNorm2d)
        self.inplanes = out_channels
        tail = [
            block(self.inplanes, planes, dilation=self.dilation, norm_layer=nn.BatchNorm2d)
            for _ in range(1, blocks)
        ]
        return nn.Sequential(head, *tail)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem and all four stages; returns patch features [B, 2048, H, W]."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
        )
        for stage in pipeline:
            x = stage(x)
        return x


# ──────────────────────────────────────────────────────────────────────────────
# Vision Transformer Pooler (temporal attention)
# ──────────────────────────────────────────────────────────────────────────────
class SinePositionEmbedding:
    """Fixed 2-D sine/cosine position embedding computed from a mask.

    Calling the object on a ``[B, H, W]`` mask returns a ``[B, H * W, 2 * embedding_dim]``
    tensor. The first ``embedding_dim`` channels encode the row position and the last
    ``embedding_dim`` channels encode the column position. ``embedding_dim`` must be even,
    because sine and cosine each fill half of the channels.
    """

    def __init__(self, embedding_dim: int = 64, temperature: int = 10000, normalize: bool = False,
                 scale: Optional[float] = None):
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale if scale is not None else 2 * math.pi

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        B, H, W = mask.shape
        # Positions are running sums of the mask along each spatial axis (1-based for an all-ones mask).
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

        # Build the frequency table on the mask's device; a CPU-only table would make the
        # embedding fail for masks living on an accelerator.
        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=mask.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)


class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention with separate Q/K/V projections.

    Args:
        dim: Token feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the Q/K/V projections have a bias.
        attn_drop: Dropout applied to the attention weights.
        proj_drop: Dropout applied after the output projection.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, attn_drop: float = 0.0,
                 proj_drop: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, k, q, v):
        """Attend from ``q`` over ``k``/``v``; returns a tensor shaped like ``v``.

        Note the argument order is (k, q, v). All three are reshaped using the
        [B, N, C] shape of ``v``, so they must share that shape.
        """
        B, N, C = v.shape
        h = self.num_heads
        wq = self.proj_q(q).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
        wk = self.proj_k(k).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
        wv = self.proj_v(v).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
        attn = (wq @ wk.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        o = (attn @ wv).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(o))


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a (drop-path) residual."""

    def __init__(self, dim, num_heads, mlp_ratio=1.0, qkv_bias=False, drop=0.0, attn_drop=0.0,
                 drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttentionLayer(dim, num_heads, qkv_bias, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x, pos_and_type_embed=None):
        """Apply the block; ``pos_and_type_embed`` (if given) is added to the normed tokens."""
        x_norm = self.norm1(x)
        if pos_and_type_embed is not None:
            x_norm = x_norm + pos_and_type_embed
        # The embedded tokens serve as keys, queries and values alike.
        x = x + self.drop_path(self.attn(x_norm, x_norm, x_norm))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformerPooler(nn.Module):
    """Self-attention pooler over the current (and optionally previous) patch grids.

    Tokens of both images share the same fixed sine position embedding and are told apart
    by a learned type embedding (row 0: current image, row 1: previous image). Only the
    current-image tokens are returned, reshaped back to ``[B, C, H, W]``.

    The position embedding is precomputed for ``grid_shape``, so the input grids must have
    ``H * W == grid_shape[0] * grid_shape[1]``.
    """

    def __init__(self, input_dim: int, grid_shape: Tuple[int, int], num_heads: int = 8, num_blocks: int = 3,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1.0, drop=0.10, attn_drop=0.10,
                            drop_path=0.25, act_layer=nn.GELU, norm_layer=norm_layer)
        self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
        self.norm_post = norm_layer(input_dim)
        self.grid_shape = grid_shape
        self.num_patches = grid_shape[0] * grid_shape[1]
        self.type_embed = nn.Parameter(torch.zeros(2, 1, input_dim))
        trunc_normal_(self.type_embed, std=0.02)
        self.pos_drop = nn.Dropout(p=0.10)
        # input_dim // 2 per axis -> input_dim channels in total, matching the token width.
        pos_embed = SinePositionEmbedding(input_dim // 2, normalize=True)(
            torch.ones([1, grid_shape[0], grid_shape[1]]))
        self.register_buffer("pos_embed", pos_embed, persistent=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; zero biases; LayerNorm to (1, 0)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, current_image, previous_image=None):
        """Fuse patch grids ``[B, C, H, W]`` and return the refined current grid ``[B, C, H, W]``."""
        B, C, H, W = current_image.shape
        if previous_image is not None:
            prev = previous_image.view(B, C, H * W).transpose(1, 2)
        else:
            prev = None
        cur = current_image.view(B, C, H * W).transpose(1, 2)
        pos = self.pos_embed.repeat(B, 1, 1)
        L = cur.shape[1]
        type_emb = self.type_embed[0].expand(B, L, -1)

        if prev is not None:
            # Current tokens first, previous tokens second; both reuse the same positions.
            x = torch.cat((cur, prev), dim=1)
            pos = torch.cat((pos, pos), dim=1)
            type_emb = torch.cat((type_emb, self.type_embed[1].expand(B, L, -1)), dim=1)
        else:
            x = cur

        pos_type = pos + type_emb
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x, pos_type)
        x = self.norm_post(x)
        # Keep only the current-image tokens.
        return x[:, :self.num_patches].transpose(1, 2).view(B, C, H, W)


# ──────────────────────────────────────────────────────────────────────────────
# Multi-image encoder (temporal)
# ──────────────────────────────────────────────────────────────────────────────


class MLP(nn.Module):
    """Projection head.

    * ``use_1x1_convs`` and ``hidden_dim`` set: Conv1x1 -> BatchNorm2d -> ReLU -> Conv1x1
      (operates on ``[B, C, H, W]``).
    * only ``hidden_dim`` set: Linear -> BatchNorm1d -> ReLU -> Linear.
    * otherwise: a single Linear layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=None, use_1x1_convs=False):
        super().__init__()
        if use_1x1_convs and hidden_dim is not None:
            self.model = nn.Sequential(
                nn.Conv2d(input_dim, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, output_dim, 1, bias=True),
            )
        elif hidden_dim is not None:
            self.model = nn.Sequential(
                nn.Linear(input_dim, hidden_dim, bias=False),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, output_dim, bias=True),
            )
        else:
            self.model = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.model(x)


class MultiImageEncoder(nn.Module):
    """BioViL-T style multi-image encoder: ResNet-50 backbone + ViT temporal pooler.

    Produces a 512-channel fused grid: 256 channels of current-image patch features
    concatenated with 256 channels of temporal ("difference") features. When no previous
    image is given, the temporal half is a learned constant ``missing_previous_emb``.
    """

    def __init__(self):
        super().__init__()
        self.encoder = ResNet()
        backbone_out_dim = 2048  # ResNet-50 output channels
        output_dim = 256
        self.backbone_to_vit = nn.Conv2d(backbone_out_dim, output_dim, 1, bias=False)
        self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=(14, 14))
        self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
        trunc_normal_(self.missing_previous_emb, std=0.02)

    def forward(self, current_image, previous_image=None, return_patch_embeddings=False):
        """Encode one or two images.

        Returns:
            ``avg_pooled`` [B, 512], or ``(patch_fused [B, 512, H, W], avg_pooled)`` when
            ``return_patch_embeddings`` is true.
        """
        B = current_image.shape[0]
        if previous_image is not None:
            # Run both images through the backbone in a single batch, then split them back.
            x = torch.cat([current_image, previous_image], dim=0)
            x = self.encoder(x)
            x = self.backbone_to_vit(x)
            patch_x, patch_prev = x[:B], x[B:]
            diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_prev)
        else:
            x = self.encoder(current_image)
            patch_x = self.backbone_to_vit(x)
            _, _, H, W = patch_x.shape
            diff_x = self.missing_previous_emb.repeat(B, 1, H, W)

        patch_fused = torch.cat([patch_x, diff_x], dim=1)  # [B, 512, H, W]
        avg_pooled = torch.flatten(F.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)

        if return_patch_embeddings:
            return patch_fused, avg_pooled
        return avg_pooled


class TILAImageEncoder(nn.Module):
    """Full TILA image encoder: MultiImageEncoder + 1x1-conv projection head.

    The projected global embedding is 128-dim and is the spatial mean of the projected patch
    grid. It is not normalized here; ``TILAModel.get_embeddings`` L2-normalizes it.
    """

    JOINT_FEATURE_SIZE = 128

    def __init__(self):
        super().__init__()
        self.encoder = MultiImageEncoder()
        self.projector = MLP(
            input_dim=512,  # patch_x (256) + diff_x (256)
            output_dim=self.JOINT_FEATURE_SIZE,
            hidden_dim=self.JOINT_FEATURE_SIZE,
            use_1x1_convs=True,
        )

    def forward(self, current_image, previous_image=None):
        patch_fused, pooled = self.encoder(current_image, previous_image, return_patch_embeddings=True)
        projected_patch = self.projector(patch_fused)
        projected_global = torch.mean(projected_patch, dim=(2, 3))
        return ImageModelOutput(
            img_embedding=pooled,
            patch_embeddings=patch_fused,
            class_logits=None,
            projected_patch_embeddings=projected_patch,
            projected_global_embedding=projected_global,
        )


# ──────────────────────────────────────────────────────────────────────────────
# Text encoder (BioViL-T CXR-BERT + projection)
# ──────────────────────────────────────────────────────────────────────────────

TEXT_MODEL_NAME = "microsoft/BiomedVLP-BioViL-T"


class TextEncoder(nn.Module):
    """CXR-BERT text encoder with a projection head to 128-dim.

    Loads the pretrained BioViL-T text model (downloaded via ``transformers`` on construction)
    and adds a LayerNorm + Linear projection from the hidden-size CLS embedding to 128-dim.
    """

    def __init__(self):
        super().__init__()
        from transformers import AutoConfig, AutoModel

        config = AutoConfig.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            TEXT_MODEL_NAME,
            config=config,
            trust_remote_code=True,
        )
        self.projection = nn.Sequential(
            nn.LayerNorm(config.hidden_size),
            nn.Linear(config.hidden_size, 128),
        )

    def forward(self, text_inputs: dict) -> torch.Tensor:
        """Encode tokenized text to 128-dim embeddings.

        Args:
            text_inputs: Dict from the tokenizer (input_ids, attention_mask, etc.).

        Returns:
            Projected (unnormalized) CLS embeddings [B, 128].
        """
        outputs = self.model(**text_inputs)
        cls_emb = outputs.last_hidden_state[:, 0, :]
        # Match the projection head's dtype (e.g. when the backbone runs in a different precision).
        proj_dtype = next(self.projection.parameters()).dtype
        if cls_emb.dtype != proj_dtype:
            cls_emb = cls_emb.to(proj_dtype)
        return self.projection(cls_emb)


# ──────────────────────────────────────────────────────────────────────────────
# Interval change classifier head
# ──────────────────────────────────────────────────────────────────────────────


class IntervalChangeClassifier(nn.Module):
    """Binary classifier head for interval change detection.

    Maps 128-dim projected embeddings to one change logit per sample.
    """

    def __init__(self):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """Returns logits [B] (pre-sigmoid). Apply torch.sigmoid() to get probabilities."""
        return self.head(embedding).squeeze(-1)


# ──────────────────────────────────────────────────────────────────────────────
# Full model wrapper
# ──────────────────────────────────────────────────────────────────────────────

# Use transformers' PreTrainedModel as the base when both transformers and the local
# TILAConfig are importable; otherwise fall back to a plain nn.Module.
try:
    from transformers import PreTrainedModel
    from configuration_tila import TILAConfig

    _BASE_CLASS = PreTrainedModel
    _HAS_TRANSFORMERS = True
except ImportError:
    _BASE_CLASS = nn.Module
    _HAS_TRANSFORMERS = False


class TILAModel(_BASE_CLASS):
    """TILA model with image encoder, text encoder, and interval change classifier.

    Usage:
        # Load from local safetensors
        model = TILAModel.from_pretrained("model.safetensors")

        # Load via AutoModel (requires config.json + trust_remote_code)
        from transformers import AutoModel
        model = AutoModel.from_pretrained("lukeingawesome/TILA", trust_remote_code=True)

        # Get 128-dim image embeddings
        emb = model.get_embeddings(current_img, previous_img)

        # Get 128-dim text embeddings
        text_emb = model.encode_text(["Improved pulmonary edema."])

        # Predict interval change
        result = model.get_interval_change_prediction(current_img, previous_img)

    Note: the inference helpers below call ``self.eval()``, so the model stays in eval mode
    after using them.
    """

    if _HAS_TRANSFORMERS:
        config_class = TILAConfig

    def __init__(self, config=None):
        if _HAS_TRANSFORMERS and config is None:
            config = TILAConfig()
        if _HAS_TRANSFORMERS:
            super().__init__(config)
        else:
            super().__init__()
        self.image_encoder = TILAImageEncoder()
        self.text_encoder = TextEncoder()
        self.change_classifier = IntervalChangeClassifier()

    @torch.no_grad()
    def encode_text(self, texts: list) -> torch.Tensor:
        """Encode text prompts to 128-dim L2-normalized embeddings.

        The tokenizer is created lazily on first use and cached on the instance. Inputs are
        padded to the longest prompt and truncated to 256 tokens.

        Args:
            texts: List of text strings.

        Returns:
            Normalized float32 text embeddings [N, 128].
        """
        from transformers import AutoTokenizer

        if not hasattr(self, '_tokenizer'):
            self._tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, padding_side="right")
        device = next(self.parameters()).device
        tokens = self._tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
        tokens = {k: v.to(device) for k, v in tokens.items()}
        self.eval()
        emb = self.text_encoder(tokens)
        return F.normalize(emb.float(), p=2, dim=1)

    @torch.no_grad()
    def get_embeddings(
        self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Extract 128-dim projected global embeddings from a pair of chest X-rays.

        Args:
            current_image: Current CXR tensor [B, 3, 448, 448].
            previous_image: Previous CXR tensor [B, 3, 448, 448] (optional).

        Returns:
            L2-normalized float32 embeddings [B, 128].
        """
        self.eval()
        out = self.image_encoder(current_image, previous_image)
        return F.normalize(out.projected_global_embedding.float(), p=2, dim=1)

    # Thresholds calibrated on validation set (AUC=0.7558)
    THRESHOLDS = {
        "default": 0.5000,  # Standard sigmoid midpoint
        "bestf1": 0.2886,   # Best-F1 operating point — F1=0.7210, sens=0.7798, spec=0.6166
                            # (earlier note also cited Youden's J; confirm which criterion chose it)
        "spec95": 0.6370,   # Specificity ~0.95 — sens=0.1752, spec=0.9502
    }

    @torch.no_grad()
    def get_interval_change_prediction(
        self,
        current_image: torch.Tensor,
        previous_image: torch.Tensor,
        mode: str = "bestf1",
    ) -> dict:
        """Predict interval change between two chest X-rays.

        Args:
            current_image: Current CXR tensor [B, 3, 448, 448].
            previous_image: Previous CXR tensor [B, 3, 448, 448].
            mode: Threshold mode — one of:
                "default" : threshold=0.50 (standard sigmoid cutoff)
                "bestf1"  : threshold=0.29 (maximizes F1 on the validation set)
                "spec95"  : threshold=0.64 (targets 95% specificity, conservative)

        Returns:
            Dict with keys:
                "probabilities": change probabilities [B] (float32)
                "predictions": binary predictions [B] (0=no change, 1=change)
                "threshold": threshold used (float)

        Raises:
            ValueError: If ``mode`` is not a key of ``THRESHOLDS``.
        """
        if mode not in self.THRESHOLDS:
            raise ValueError(f"mode must be one of {list(self.THRESHOLDS.keys())}, got '{mode}'")

        self.eval()
        out = self.image_encoder(current_image, previous_image)
        logits = self.change_classifier(out.projected_global_embedding)
        probs = torch.sigmoid(logits.float())
        threshold = self.THRESHOLDS[mode]
        preds = (probs >= threshold).long()
        return {"probabilities": probs, "predictions": preds, "threshold": threshold}

    @classmethod
    def from_pretrained(cls, path_or_repo: str, device: str = "cpu", **kwargs) -> "TILAModel":
        """Load model from a local file, a local directory, or the HuggingFace Hub.

        Args:
            path_or_repo: Local path to a checkpoint file (``.safetensors`` or a
                ``torch.load``-able file), a directory containing ``model.safetensors``, or an
                HF repo ID (e.g. "lukeingawesome/TILA").
            device: Device the checkpoint is loaded onto and the returned model is moved to.
            **kwargs: ``config`` (optional) is passed to the constructor. ``revision``,
                ``cache_dir`` and ``token`` are forwarded to ``hf_hub_download`` when
                downloading from the Hub. Other keyword arguments are ignored.

        Returns:
            The loaded model, on ``device`` and in eval mode.

        Examples:
            model = TILAModel.from_pretrained("model.safetensors")
            model = TILAModel.from_pretrained("lukeingawesome/TILA")
        """
        import os
        import warnings

        # transformers' AutoModel passes the resolved config as the ``config`` keyword argument.
        config = kwargs.pop("config", None)

        # Determine if this is a local file, a local directory, or a HF repo
        if os.path.isfile(path_or_repo):
            safetensors_path = path_or_repo
        elif os.path.isdir(path_or_repo):
            safetensors_path = os.path.join(path_or_repo, "model.safetensors")
        else:
            from huggingface_hub import hf_hub_download

            hub_kwargs = {
                key: kwargs[key]
                for key in ("revision", "cache_dir", "token")
                if kwargs.get(key) is not None
            }
            safetensors_path = hf_hub_download(
                repo_id=path_or_repo,
                filename="model.safetensors",
                **hub_kwargs,
            )

        model = cls(config=config)

        if safetensors_path.endswith(".safetensors"):
            from safetensors.torch import load_file

            state_dict = load_file(safetensors_path, device=device)
            # BatchNorm step counters may be stored as shape-[1] tensors; the modules expect scalars.
            for k, v in list(state_dict.items()):
                if v.dim() == 1 and v.shape[0] == 1 and "num_batches_tracked" in k:
                    state_dict[k] = v.squeeze(0)
        else:
            state_dict = torch.load(safetensors_path, map_location=device, weights_only=True)

        incompatible = model.load_state_dict(state_dict, strict=False)
        missing, unexpected = incompatible.missing_keys, incompatible.unexpected_keys
        if missing or unexpected:
            # strict=False is kept for compatibility, but mismatches are surfaced instead of
            # silently leaving parameters at their initial values.
            warnings.warn(
                f"TILAModel.from_pretrained: {len(missing)} missing and {len(unexpected)} unexpected "
                f"keys when loading {safetensors_path} "
                f"(missing e.g. {missing[:5]}, unexpected e.g. {unexpected[:5]})",
                stacklevel=2,
            )

        # load_state_dict copies values into the existing (CPU-constructed) parameters, so the
        # model itself must be moved for ``device`` to take effect.
        model = model.to(device)
        model.eval()
        return model