| """ |
| TILA (Temporal Inversion-aware Learning and Alignment) — Model Architecture |
| |
| Paper: "Temporal Inversion for Learning Interval Change in Chest X-Rays" (CVPR 2026) |
| http://arxiv.org/abs/2604.04563 |
| |
| This module contains the full model architecture for TILA, built on top of the |
| BioViL-T (ResNet-50 + Vision Transformer pooler) backbone and CXR-BERT text encoder. |
| |
| Dependencies: |
| pip install torch torchvision timm transformers safetensors |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from functools import partial |
| from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from timm.layers import DropPath, Mlp, trunc_normal_ |
| from torchvision.models.resnet import Bottleneck, conv1x1 |
|
|
|
|
| |
| |
| |
|
|
|
|
@dataclass
class ImageModelOutput:
    """Bundle of tensors produced by one forward pass of the image encoder.

    Field order is part of the positional constructor signature and must not
    be rearranged.
    """

    # Spatially average-pooled version of ``patch_embeddings``.
    img_embedding: torch.Tensor
    # Per-patch feature map before the projection head.
    patch_embeddings: torch.Tensor
    # Spatial mean of ``projected_patch_embeddings``.
    projected_global_embedding: torch.Tensor
    # Optional classification logits; the encoder in this file sets it to None.
    class_logits: Optional[torch.Tensor]
    # Per-patch feature map after the projection head.
    projected_patch_embeddings: torch.Tensor
|
|
|
|
| |
| |
| |
|
|
|
|
class ResNet(nn.Module):
    """ResNet-50 style convolutional trunk that returns the ``layer4`` feature map.

    The module layout follows torchvision's bottleneck ResNet.  ``avgpool``
    and ``fc`` are still constructed but ``forward`` never calls them
    (presumably kept so torchvision-style checkpoints load without key
    mismatches -- confirm against the released weights).

    Args:
        layers: Number of bottleneck blocks in each of the four stages.
        num_classes: Output size of the ``fc`` layer (unused by ``forward``).
        zero_init_residual: If True, the weight of every bottleneck's ``bn3``
            is initialised to zero.
        replace_stride_with_dilation: Three flags, one each for ``layer2``,
            ``layer3`` and ``layer4``; a True flag replaces that stage's
            stride-2 downsampling with dilation.  ``None`` means all False.

    Raises:
        ValueError: If ``layers`` does not have exactly four entries or
            ``replace_stride_with_dilation`` does not have exactly three.
    """

    def __init__(
        self,
        layers: Sequence[int] = (3, 4, 6, 3),
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        replace_stride_with_dilation: Optional[Sequence[bool]] = None,
    ):
        super().__init__()
        block = Bottleneck
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]

        # Fail early with a clear message instead of an IndexError (too few
        # entries) or silently ignoring the surplus (too many entries).
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None or a 3-element "
                f"sequence, got {replace_stride_with_dilation}"
            )
        if len(layers) != 4:
            raise ValueError(f"layers should be a 4-element sequence, got {layers}")

        # Stem.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four residual stages.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        # Classification head; constructed but not used in forward().
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight initialisation for convolutions and batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage of ``blocks`` consecutive ``block`` modules.

        Only the first block of the stage may change resolution or channel
        count; it receives a 1x1-conv + BatchNorm ``downsample`` shortcut
        whenever the stride or the channel count differs from the input.
        Updates ``self.inplanes`` (and ``self.dilation`` when ``dilate``).
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the spatial stride for an equivalent dilation factor.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )
        # Positional arguments after ``downsample`` are groups=1, base_width=64.
        layers = [block(self.inplanes, planes, stride, downsample, 1, 64, previous_dilation, nn.BatchNorm2d)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=self.dilation, norm_layer=nn.BatchNorm2d))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the stem and the four stages; return the ``layer4`` feature map."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
|
|
|
|
| |
| |
| |
|
|
|
|
class SinePositionEmbedding:
    """Fixed 2-D sine/cosine position embedding computed from a grid mask.

    Args:
        embedding_dim: Number of channels produced per spatial axis; the
            returned embedding has ``2 * embedding_dim`` channels (y then x).
            Must be even, because sine and cosine channels are interleaved.
        temperature: Base of the geometric frequency progression.
        normalize: If True, cumulative positions are divided by their final
            value along each axis and multiplied by ``scale``.
        scale: Multiplier used when ``normalize`` is True; defaults to 2*pi.
    """

    def __init__(self, embedding_dim: int = 64, temperature: int = 10000,
                 normalize: bool = False, scale: Optional[float] = None):
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale if scale is not None else 2 * math.pi

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        """Return position embeddings for a ``[B, H, W]`` mask.

        Args:
            mask: Tensor of shape ``[B, H, W]``; positions are its cumulative
                sums along the H and W axes.

        Returns:
            Float32 tensor of shape ``[B, H * W, 2 * embedding_dim]`` on the
            same device as ``mask``.
        """
        B, H, W = mask.shape
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # The epsilon guards against division by zero for an all-zero mask.
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
        # Build the frequency table on the mask's device; creating it on the
        # default device breaks the divisions below for non-CPU masks.
        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=mask.device)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels).
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)
|
|
|
|
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention with separate q/k/v projections.

    Args:
        dim: Channel width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``dim``.
        qkv_bias: Whether the query/key/value projections carry a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False,
                 attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        if dim % num_heads != 0:
            # Otherwise the per-head reshape in forward() fails with an
            # opaque shape error.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, k, q, v):
        """Attend from ``q`` over ``k``/``v``.

        Note the argument order: key first, then query, then value.

        Args:
            k: Keys, ``[B, N_kv, C]``.
            q: Queries, ``[B, N_q, C]``.
            v: Values, ``[B, N_kv, C]``.

        Returns:
            Tensor of shape ``[B, N_q, C]``.
        """
        B, n_q, C = q.shape
        n_kv = k.shape[1]
        h = self.num_heads
        # Each stream is reshaped with its own token count so that the query
        # length may differ from the key/value length.
        wq = self.proj_q(q).reshape(B, n_q, h, C // h).permute(0, 2, 1, 3)
        wk = self.proj_k(k).reshape(B, n_kv, h, C // h).permute(0, 2, 1, 3)
        wv = self.proj_v(v).reshape(B, n_kv, h, C // h).permute(0, 2, 1, 3)
        attn = (wq @ wk.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Merge the heads back into the channel dimension.
        o = (attn @ wv).transpose(1, 2).reshape(B, n_q, C)
        return self.proj_drop(self.proj(o))
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, both residual.

    An optional positional/type embedding is added to the normalised tokens
    before attention only; the residual stream itself never receives it.
    """

    def __init__(self, dim, num_heads, mlp_ratio=1.0, qkv_bias=False,
                 drop=0.0, attn_drop=0.0, drop_path=0.0,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttentionLayer(dim, num_heads, qkv_bias, attn_drop, drop)
        # Stochastic depth is only instantiated when it would actually drop.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)

    def forward(self, x, pos_and_type_embed=None):
        """Apply the attention and MLP sub-layers to ``x`` and return the result."""
        tokens = self.norm1(x)
        if pos_and_type_embed is not None:
            tokens = tokens + pos_and_type_embed

        attended = self.attn(tokens, tokens, tokens)
        x = x + self.drop_path(attended)

        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)
|
|
|
|
class VisionTransformerPooler(nn.Module):
    """Transformer that fuses current and (optional) previous image feature maps.

    Both feature maps are flattened into token sequences, concatenated,
    tagged with a fixed sine position embedding plus a learned per-image type
    embedding, and passed through a stack of ``Block`` layers.  Only the
    tokens belonging to the current image are returned, reshaped back to a
    feature map.

    Args:
        input_dim: Channel width of the incoming feature maps.
        grid_shape: ``(H, W)`` of the feature maps this pooler is built for.
        num_heads: Attention heads per block.
        num_blocks: Number of transformer blocks.
        norm_layer: Factory for the normalisation layers.
    """

    def __init__(self, input_dim: int, grid_shape: Tuple[int, int],
                 num_heads: int = 8, num_blocks: int = 3,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1.0,
                            drop=0.10, attn_drop=0.10, drop_path=0.25,
                            act_layer=nn.GELU, norm_layer=norm_layer)
        self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
        self.norm_post = norm_layer(input_dim)
        self.grid_shape = grid_shape
        self.num_patches = grid_shape[0] * grid_shape[1]

        # Index 0 tags current-image tokens, index 1 previous-image tokens.
        self.type_embed = nn.Parameter(torch.zeros(2, 1, input_dim))
        trunc_normal_(self.type_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=0.10)
        # Fixed (non-learned) position table; not stored in checkpoints.
        pos_embed = SinePositionEmbedding(input_dim // 2, normalize=True)(
            torch.ones([1, grid_shape[0], grid_shape[1]]))
        self.register_buffer("pos_embed", pos_embed, persistent=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers, identity init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, current_image, previous_image=None):
        """Fuse the feature maps and return the updated current-image map.

        Args:
            current_image: Feature map ``[B, C, H, W]`` with
                ``H * W == num_patches``.
            previous_image: Optional feature map with the same shape.

        Returns:
            Tensor of shape ``[B, C, H, W]``.

        Raises:
            ValueError: If ``H * W`` does not match the configured grid.
        """
        B, C, H, W = current_image.shape
        if H * W != self.num_patches:
            # The position table is fixed to grid_shape; a mismatch would
            # otherwise surface as an obscure broadcasting error (or, for a
            # single-token map, as a silent broadcast).
            raise ValueError(
                f"Expected a feature map with {self.num_patches} patches "
                f"(grid {tuple(self.grid_shape)}), got {H}x{W}"
            )

        if previous_image is not None:
            prev = previous_image.view(B, C, H * W).transpose(1, 2)
        else:
            prev = None
        cur = current_image.view(B, C, H * W).transpose(1, 2)
        # expand() broadcasts the shared table without copying it per sample.
        pos = self.pos_embed.expand(B, -1, -1)

        L = cur.shape[1]
        type_emb = self.type_embed[0].expand(B, L, -1)
        if prev is not None:
            x = torch.cat((cur, prev), dim=1)
            pos = torch.cat((pos, pos), dim=1)
            type_emb = torch.cat((type_emb, self.type_embed[1].expand(B, L, -1)), dim=1)
        else:
            x = cur

        pos_type = pos + type_emb
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x, pos_type)
        x = self.norm_post(x)

        # Keep only the current-image tokens and restore the spatial layout.
        return x[:, :self.num_patches].transpose(1, 2).view(B, C, H, W)
|
|
|
|
| |
| |
| |
|
|
|
|
class MLP(nn.Module):
    """Projection head: a single Linear layer or a two-layer bottleneck.

    With ``hidden_dim`` set, the head is ``layer -> BatchNorm -> ReLU -> layer``
    built from 1x1 convolutions when ``use_1x1_convs`` is True and from Linear
    layers otherwise.  Without ``hidden_dim`` it is one Linear layer, and
    ``use_1x1_convs`` has no effect.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=None, use_1x1_convs=False):
        super().__init__()
        if hidden_dim is None:
            self.model = nn.Linear(input_dim, output_dim)
            return

        if use_1x1_convs:
            first = nn.Conv2d(input_dim, hidden_dim, 1, bias=False)
            norm = nn.BatchNorm2d(hidden_dim)
            last = nn.Conv2d(hidden_dim, output_dim, 1, bias=True)
        else:
            first = nn.Linear(input_dim, hidden_dim, bias=False)
            norm = nn.BatchNorm1d(hidden_dim)
            last = nn.Linear(hidden_dim, output_dim, bias=True)

        # Sub-module positions (0..3) determine the state-dict key names.
        self.model = nn.Sequential(first, norm, nn.ReLU(inplace=True), last)

    def forward(self, x):
        """Apply the projection to ``x``."""
        return self.model(x)
|
|
|
|
class MultiImageEncoder(nn.Module):
    """BioViL-T style encoder: ResNet trunk, 1x1 channel reduction, temporal pooler.

    The output concatenates the current image's patch features with a
    "difference" feature map: the pooler output when a previous image is
    given, otherwise a learned placeholder embedding tiled over the grid.
    """

    def __init__(self):
        super().__init__()
        backbone_out_dim = 2048
        output_dim = 256

        self.encoder = ResNet()
        self.backbone_to_vit = nn.Conv2d(backbone_out_dim, output_dim, 1, bias=False)
        self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=(14, 14))
        # Stand-in for the difference features when no prior image exists.
        self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
        trunc_normal_(self.missing_previous_emb, std=0.02)

    def forward(self, current_image, previous_image=None, return_patch_embeddings=False):
        """Encode one image, or a (current, previous) pair.

        Returns:
            The spatially average-pooled fused features, or the tuple
            ``(patch_fused, avg_pooled)`` when ``return_patch_embeddings``
            is True.
        """
        batch = current_image.shape[0]

        if previous_image is None:
            patch_x = self.backbone_to_vit(self.encoder(current_image))
            rows, cols = patch_x.shape[2], patch_x.shape[3]
            diff_x = self.missing_previous_emb.repeat(batch, 1, rows, cols)
        else:
            # Run both images through the trunk as one stacked batch.
            stacked = torch.cat([current_image, previous_image], dim=0)
            features = self.backbone_to_vit(self.encoder(stacked))
            patch_x = features[:batch]
            patch_prev = features[batch:]
            diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_prev)

        patch_fused = torch.cat([patch_x, diff_x], dim=1)
        avg_pooled = torch.flatten(F.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)

        if not return_patch_embeddings:
            return avg_pooled
        return patch_fused, avg_pooled
|
|
|
|
class TILAImageEncoder(nn.Module):
    """Full TILA image encoder: ``MultiImageEncoder`` plus a projection head.

    The projection head maps the fused patch features to
    ``JOINT_FEATURE_SIZE`` channels.  The embeddings returned here are not
    L2-normalised; callers that need unit-norm vectors must normalise them.
    """

    JOINT_FEATURE_SIZE = 128

    def __init__(self):
        super().__init__()
        joint_dim = self.JOINT_FEATURE_SIZE
        self.encoder = MultiImageEncoder()
        self.projector = MLP(
            input_dim=512,
            output_dim=joint_dim,
            hidden_dim=joint_dim,
            use_1x1_convs=True,
        )

    def forward(self, current_image, previous_image=None):
        """Encode the image(s) and return an ``ImageModelOutput``."""
        fused_patches, pooled_embedding = self.encoder(
            current_image, previous_image, return_patch_embeddings=True
        )
        projected_patches = self.projector(fused_patches)
        # Global embedding is the mean of the projected patches over H and W.
        projected_global = projected_patches.mean(dim=(2, 3))
        return ImageModelOutput(
            img_embedding=pooled_embedding,
            patch_embeddings=fused_patches,
            projected_global_embedding=projected_global,
            class_logits=None,
            projected_patch_embeddings=projected_patches,
        )
|
|
|
|
| |
| |
| |
|
|
|
|
# Hugging Face Hub id of the pretrained text model (also used for the tokenizer).
TEXT_MODEL_NAME = "microsoft/BiomedVLP-BioViL-T"


class TextEncoder(nn.Module):
    """Pretrained text model plus a LayerNorm + Linear projection to 128-dim.

    The backbone is loaded from ``TEXT_MODEL_NAME``; the projection maps the
    first-token hidden state (width ``config.hidden_size``) to 128 channels.
    """

    def __init__(self):
        super().__init__()
        # Imported lazily so the rest of the module works without transformers.
        from transformers import AutoConfig, AutoModel

        # NOTE: trust_remote_code=True executes Python code shipped in the
        # model repository.
        hf_config = AutoConfig.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(
            TEXT_MODEL_NAME, config=hf_config, trust_remote_code=True,
        )
        width = hf_config.hidden_size
        self.projection = nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, 128),
        )

    def forward(self, text_inputs: dict) -> torch.Tensor:
        """Encode tokenized text to 128-dim embeddings.

        Args:
            text_inputs: Dict from tokenizer (input_ids, attention_mask, etc.)

        Returns:
            Projected first-token embeddings [B, 128]
        """
        hidden_states = self.model(**text_inputs).last_hidden_state
        cls_emb = hidden_states[:, 0, :]
        # The backbone may run in a different precision than the projection.
        target_dtype = next(self.projection.parameters()).dtype
        if cls_emb.dtype != target_dtype:
            cls_emb = cls_emb.to(target_dtype)
        return self.projection(cls_emb)
|
|
|
|
| |
| |
| |
|
|
|
|
class IntervalChangeClassifier(nn.Module):
    """Two-layer binary classification head for interval change detection.

    Maps a 128-dim embedding to a single logit via Linear(128, 64) -> ReLU ->
    Linear(64, 1).
    """

    def __init__(self):
        super().__init__()
        # Positions 0 and 2 hold the parameters (state-dict keys head.0.*, head.2.*).
        stages = [
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.head = nn.Sequential(*stages)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        """Return the pre-sigmoid logit with the trailing singleton axis removed.

        Apply ``torch.sigmoid()`` to the result to obtain a probability.
        """
        logits = self.head(embedding)
        return torch.squeeze(logits, dim=-1)
|
|
|
|
| |
| |
| |
|
|
|
|
# Optional Hugging Face integration.  If both imports succeed the model class
# below derives from PreTrainedModel; if either one raises ImportError it
# falls back to a plain nn.Module.
try:
    from transformers import PreTrainedModel
    from configuration_tila import TILAConfig
    _BASE_CLASS = PreTrainedModel
    _HAS_TRANSFORMERS = True
except ImportError:
    # NOTE: this branch is also taken when transformers is installed but
    # configuration_tila cannot be imported, so _HAS_TRANSFORMERS really means
    # "PreTrainedModel and TILAConfig are both available".
    _BASE_CLASS = nn.Module
    _HAS_TRANSFORMERS = False
|
|
|
|
class TILAModel(_BASE_CLASS):
    """TILA model with image encoder, text encoder, and interval change classifier.

    Usage:
        # Load from local safetensors
        model = TILAModel.from_pretrained("model.safetensors")

        # Load via AutoModel (requires config.json + trust_remote_code)
        from transformers import AutoModel
        model = AutoModel.from_pretrained("lukeingawesome/TILA", trust_remote_code=True)

        # Get 128-dim image embeddings
        emb = model.get_embeddings(current_img, previous_img)

        # Get 128-dim text embeddings
        text_emb = model.encode_text(["Improved pulmonary edema."])

        # Predict interval change
        result = model.get_interval_change_prediction(current_img, previous_img)
    """

    if _HAS_TRANSFORMERS:
        config_class = TILAConfig

    # Probability cut-offs selectable through ``mode`` in
    # ``get_interval_change_prediction``.
    THRESHOLDS = {
        "default": 0.5000,
        "bestf1": 0.2886,
        "spec95": 0.6370,
    }

    def __init__(self, config=None):
        """Build the three sub-modules.

        Args:
            config: Optional ``TILAConfig``; only used when the class derives
                from ``PreTrainedModel``, in which case a default config is
                created if none is given.
        """
        if _HAS_TRANSFORMERS and config is None:
            config = TILAConfig()
        if _HAS_TRANSFORMERS:
            super().__init__(config)
        else:
            super().__init__()
        self.image_encoder = TILAImageEncoder()
        self.text_encoder = TextEncoder()
        self.change_classifier = IntervalChangeClassifier()

    @torch.no_grad()
    def encode_text(self, texts: list) -> torch.Tensor:
        """Encode text prompts to 128-dim normalized embeddings.

        Switches the model to eval mode as a side effect.  The tokenizer is
        created on first use and cached on the instance.

        Args:
            texts: List of text strings

        Returns:
            Normalized text embeddings [N, 128]
        """
        from transformers import AutoTokenizer
        if not hasattr(self, '_tokenizer'):
            self._tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, padding_side="right")
        device = next(self.parameters()).device
        tokens = self._tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
        tokens = {k: v.to(device) for k, v in tokens.items()}
        self.eval()
        emb = self.text_encoder(tokens)
        return F.normalize(emb.float(), p=2, dim=1)

    @torch.no_grad()
    def get_embeddings(
        self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Extract 128-dim projected global embeddings from a pair of chest X-rays.

        Switches the model to eval mode as a side effect.

        Args:
            current_image: Current CXR tensor [B, 3, 448, 448]
            previous_image: Previous CXR tensor [B, 3, 448, 448] (optional)

        Returns:
            Normalized 128-dim embeddings [B, 128]
        """
        self.eval()
        out = self.image_encoder(current_image, previous_image)
        return F.normalize(out.projected_global_embedding.float(), p=2, dim=1)

    @torch.no_grad()
    def get_interval_change_prediction(
        self,
        current_image: torch.Tensor,
        previous_image: torch.Tensor,
        mode: str = "bestf1",
    ) -> dict[str, Any]:
        """Predict interval change between two chest X-rays.

        Switches the model to eval mode as a side effect.

        Args:
            current_image: Current CXR tensor [B, 3, 448, 448]
            previous_image: Previous CXR tensor [B, 3, 448, 448]
            mode: Threshold mode, one of:
                "default" : threshold=0.50 (standard sigmoid cutoff)
                "bestf1"  : threshold=0.29 (maximizes F1, balanced sens/spec)
                "spec95"  : threshold=0.64 (targets 95% specificity, conservative)

        Returns:
            Dict with keys:
                "probabilities": raw change probabilities [B]
                "predictions": binary predictions [B] (0=no change, 1=change)
                "threshold": threshold used (float)

        Raises:
            ValueError: If ``mode`` is not a key of ``THRESHOLDS``.
        """
        if mode not in self.THRESHOLDS:
            raise ValueError(f"mode must be one of {list(self.THRESHOLDS.keys())}, got '{mode}'")

        self.eval()
        out = self.image_encoder(current_image, previous_image)
        logits = self.change_classifier(out.projected_global_embedding)
        probs = torch.sigmoid(logits.float())

        threshold = self.THRESHOLDS[mode]
        preds = (probs >= threshold).long()

        return {"probabilities": probs, "predictions": preds, "threshold": threshold}

    @classmethod
    def from_pretrained(cls, path_or_repo: str, device: str = "cpu", **kwargs) -> "TILAModel":
        """Load model from a local file or HuggingFace Hub.

        Args:
            path_or_repo: Local path to a weights file, a local directory
                containing ``model.safetensors``, or an HF repo ID
                (e.g. "lukeingawesome/TILA")
            device: Device the weights are loaded onto and the returned model
                is moved to
            **kwargs: Only ``config`` is used (forwarded to the constructor);
                any other keyword is ignored.

        Returns:
            The model in eval mode on ``device``.

        Examples:
            model = TILAModel.from_pretrained("model.safetensors")
            model = TILAModel.from_pretrained("lukeingawesome/TILA")
        """
        import logging
        import os

        config = kwargs.pop("config", None)

        # Resolve the weights file: explicit file, directory, or Hub repo.
        if os.path.isfile(path_or_repo):
            weights_path = path_or_repo
        elif os.path.isdir(path_or_repo):
            weights_path = os.path.join(path_or_repo, "model.safetensors")
        else:
            from huggingface_hub import hf_hub_download
            weights_path = hf_hub_download(
                repo_id=path_or_repo,
                filename="model.safetensors",
            )
        # Accept path-like objects as well as plain strings.
        weights_path = os.fspath(weights_path)

        model = cls(config=config)

        if weights_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            raw_state = load_file(weights_path, device=device)
            # BatchNorm's num_batches_tracked is a 0-dim buffer; entries that
            # were stored with shape [1] are squeezed back.  A new dict is
            # built rather than mutating the one being iterated.
            state_dict = {
                k: v.squeeze(0)
                if v.dim() == 1 and v.shape[0] == 1 and "num_batches_tracked" in k
                else v
                for k, v in raw_state.items()
            }
        else:
            state_dict = torch.load(weights_path, map_location=device, weights_only=True)

        # Loading stays non-strict (best effort), but mismatches are reported
        # instead of being silently dropped.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            logging.getLogger(__name__).warning(
                "Checkpoint %s loaded non-strictly: %d missing key(s) %s, %d unexpected key(s) %s",
                weights_path,
                len(load_result.missing_keys),
                load_result.missing_keys[:5],
                len(load_result.unexpected_keys),
                load_result.unexpected_keys[:5],
            )

        # load_state_dict copies values into the existing (CPU-initialised)
        # parameters, so the module itself must be moved explicitly.
        model.to(device)
        model.eval()
        return model
|
|