lukeingawesome committed
Commit f46fb4d · verified · 1 Parent(s): 5853c2d

Upload folder using huggingface_hub

Files changed (6)
  1. README.md +213 -0
  2. inference.py +60 -0
  3. model.py +547 -0
  4. model.safetensors +3 -0
  5. preprocess.py +111 -0
  6. processor.py +117 -0
README.md ADDED
@@ -0,0 +1,213 @@
---
license: mit
tags:
- medical-imaging
- chest-x-ray
- temporal-analysis
- interval-change
- radiology
language:
- en
library_name: pytorch
pipeline_tag: image-feature-extraction
---

# TILA — Temporal Image-Language Alignment for Chest X-rays

TILA is a vision-language model for analyzing **temporal changes** between pairs of chest X-rays. Given a current and a prior radiograph, TILA can:

1. **Extract temporally aware image embeddings** (128-dim) that capture both the static anatomy and the interval change between the two images.
2. **Encode radiology text** into the same 128-dim space for zero-shot classification via image-text similarity.
3. **Predict interval change** (binary: change vs. no change) using a lightweight classification head.

The image encoder is based on the [BioViL-T](https://huggingface.co/microsoft/BiomedVLP-BioViL-T) architecture (ResNet-50 + Vision Transformer temporal pooler), and the text encoder is CXR-BERT, both fine-tuned with temporal image-language alignment.

## Quick Start

### Installation

```bash
pip install "torch>=2.0" "torchvision>=0.15" "timm>=0.9" "transformers>=4.30" "safetensors>=0.4" pillow opencv-python numpy
```

### Load Model and Processor

```python
import torch
from model import TILAModel
from processor import TILAProcessor

model = TILAModel.from_pretrained("model.safetensors")
model = model.to("cuda", dtype=torch.bfloat16)

# Processor handles everything: raw image → model-ready tensor
processor = TILAProcessor(dtype=torch.bfloat16, device="cuda")
```

### Extract Embeddings

```python
current = processor("current_cxr.png")    # accepts file paths, numpy arrays, or PIL images
previous = processor("previous_cxr.png")

# 128-dim L2-normalized embeddings
embeddings = model.get_embeddings(current, previous)
```

The processor automatically applies medical image preprocessing (windowing, black padding removal, resize) followed by model transforms (center crop to 448x448, expand to 3 channels). If your images are already preprocessed, skip the medical preprocessing:

```python
processor = TILAProcessor(raw_preprocess=False, dtype=torch.bfloat16, device="cuda")
```

The embeddings encode both the current image state and the temporal difference from the prior.
They can be used for retrieval, similarity search, or as features for downstream tasks.

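As a sketch of how such embeddings support retrieval (plain NumPy, illustrative only and not part of this repo): because the rows are L2-normalized, cosine similarity reduces to a dot product.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical gallery of five 128-dim study embeddings, L2-normalized like the model's output
gallery = rng.normal(size=(5, 128))
gallery /= np.linalg.norm(gallery, axis=1, keepdims=True)

# Query: a slightly perturbed copy of gallery row 2
query = gallery[2] + 0.05 * rng.normal(size=128)
query /= np.linalg.norm(query)

scores = gallery @ query       # cosine similarities, shape (5,)
ranking = np.argsort(-scores)  # best match first
print(ranking[0])              # 2 (the perturbed source study)
```

The same dot-product scoring is what the zero-shot text-matching example below relies on.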
### Encode Text

```python
text_emb = model.encode_text([
    "Improved pulmonary edema.",
    "Stable pulmonary edema.",
    "Worsening pulmonary edema.",
])

# Zero-shot classification via image-text similarity
similarities = embeddings @ text_emb.T  # [1, 3]
prediction = similarities.argmax(dim=1)  # 0=improving, 1=stable, 2=worsening
```

### Predict Interval Change

```python
result = model.get_interval_change_prediction(current, previous, mode="bestf1")

print(result["probabilities"])  # Raw change probability
print(result["predictions"])    # Binary: 0 = no change, 1 = change
print(result["threshold"])      # Threshold used
```

Three threshold modes are available:

| Mode | Threshold | Description |
|------|-----------|-------------|
| `"bestf1"` | 0.29 | Maximizes F1 score (balanced sensitivity/specificity) |
| `"default"` | 0.50 | Standard sigmoid cutoff |
| `"spec95"` | 0.64 | Targets 95% specificity (conservative, fewer false positives) |

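The decision rule behind these modes is a plain threshold comparison. A standalone sketch (`classify_change` is a made-up helper name for this illustration; the exact values come from `THRESHOLDS` in `model.py`):

```python
THRESHOLDS = {"default": 0.5000, "bestf1": 0.2886, "spec95": 0.6370}

def classify_change(prob: float, mode: str = "bestf1") -> int:
    """Return 1 (change) if prob clears the mode's threshold, else 0 (no change)."""
    if mode not in THRESHOLDS:
        raise ValueError(f"mode must be one of {list(THRESHOLDS)}")
    return int(prob >= THRESHOLDS[mode])

# A borderline probability is flagged only by the sensitive bestf1 mode
print([classify_change(0.40, m) for m in ("bestf1", "default", "spec95")])  # [1, 0, 0]
```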
### CLI Example

```bash
python inference.py \
    --checkpoint model.safetensors \
    --current_image /path/to/current.png \
    --previous_image /path/to/previous.png
```

## Model Architecture

```
IMAGE ENCODER:
  Input: current CXR [B, 3, 448, 448] + previous CXR [B, 3, 448, 448]
    |
    +-- ResNet-50 backbone (shared weights, processes both images)
    |     -> patch features [B, 2048, 14, 14]
    |
    +-- 1x1 Conv projection (2048 -> 256)
    |
    +-- Vision Transformer Pooler (3 blocks, 8 heads)
    |     -> temporal difference features [B, 256, 14, 14]
    |
    +-- Concatenate [static, temporal] -> [B, 512, 14, 14]
    |
    +-- MLP Projector (512 -> 128)
          -> image embedding [B, 128]  <-- get_embeddings()

TEXT ENCODER:
  Input: tokenized text
    |
    +-- CXR-BERT (12 layers, 768-dim)
    |     -> CLS token [B, 768]
    |
    +-- LayerNorm + Linear (768 -> 128)
          -> text embedding [B, 128]  <-- encode_text()

CLASSIFIER:
  image embedding [B, 128]
    |
    +-- Linear (128 -> 64) -> ReLU -> Linear (64 -> 1)
          -> change probability [B]  <-- get_interval_change_prediction()
```

## Preprocessing Raw Images

> **Note:** `TILAProcessor` runs this preprocessing for you when `raw_preprocess=True` (the default). Use the standalone pipeline below only if you want to preprocess images offline, e.g. ahead of batch inference.

If your chest X-rays are raw (e.g., DICOM-derived PNGs with varying bit depths, black borders, or 16-bit depth), preprocess them first:

```python
import cv2
from preprocess import preprocess_image

img = preprocess_image("raw_cxr.png")
cv2.imwrite("preprocessed.png", img)
```

The pipeline applies:
1. **Read as-is** — preserves original bit depth (supports 8-bit and 16-bit PNGs)
2. **Windowing** — clips to `mean +/- 2*std`, normalizes to [0, 1]
3. **Black padding removal** — contour-based crop
4. **Aspect-ratio-preserving resize** — longest side to 512px (configurable)

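Step 2 (windowing) amounts to a clip-and-rescale. A NumPy sketch of that single step, for intuition only (`preprocess.py` remains the authoritative implementation):

```python
import numpy as np

def window(img: np.ndarray) -> np.ndarray:
    """Clip intensities to mean +/- 2*std, then min-max rescale to [0, 1]."""
    img = img.astype(np.float32)
    lo, hi = img.mean() - 2 * img.std(), img.mean() + 2 * img.std()
    clipped = np.clip(img, lo, hi)
    return (clipped - clipped.min()) / (clipped.max() - clipped.min() + 1e-8)

raw = np.array([[0, 200, 800], [20000, 40000, 65535]], dtype=np.uint16)  # tiny 16-bit sample
out = window(raw)
print(out.min(), out.max())  # 0.0 and ~1.0
```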
```bash
# CLI usage
python preprocess.py --input raw.png --output preprocessed.png
```

If your images are already preprocessed (contrast-normalized, cropped, resized grayscale PNGs), you can skip this step and feed them directly to the model.

## Input Format

- **Image format**: Grayscale chest X-ray (PNG, JPEG)
- **Model input**: Resize to 512px (shorter side), center crop to 448x448, repeat to 3 channels (handled by the transform in `inference.py`)
- **Pair**: Current (follow-up) image + Previous (baseline) image of the same patient
- **Dtype**: `torch.bfloat16` recommended on GPU, `torch.float32` on CPU

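For intuition, the model-side transform above can be approximated in NumPy (nearest-neighbor resize is used here purely for illustration; the actual transform in `inference.py` uses proper interpolation):

```python
import numpy as np

def model_transform(gray: np.ndarray, resize: int = 512, crop: int = 448) -> np.ndarray:
    """Resize shorter side to `resize` (nearest neighbor), center-crop, repeat to 3 channels."""
    h, w = gray.shape
    scale = resize / min(h, w)
    nh, nw = round(h * scale), round(w * scale)
    ys = (np.arange(nh) * h / nh).astype(int)   # nearest-neighbor row indices
    xs = (np.arange(nw) * w / nw).astype(int)   # nearest-neighbor column indices
    resized = gray[ys][:, xs]
    top, left = (nh - crop) // 2, (nw - crop) // 2
    cropped = resized[top:top + crop, left:left + crop]
    return np.repeat(cropped[None, :, :], 3, axis=0)  # [3, 448, 448]

x = model_transform(np.zeros((1024, 896), dtype=np.float32))
print(x.shape)  # (3, 448, 448)
```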
## Files

| File | Description |
|------|-------------|
| `model.safetensors` | Model weights (613 MB, image + text + classifier) |
| `model.py` | Self-contained model architecture |
| `processor.py` | Image processor (raw image → model-ready tensor) |
| `preprocess.py` | Medical image preprocessing utilities |
| `inference.py` | Example inference script |

## Citation

If you use this model, please cite:

```bibtex
@article{tila2026,
  title={TILA: Temporal Image-Language Alignment for Chest X-rays},
  year={2026}
}
```

## Acknowledgements

This model builds upon [BioViL-T](https://huggingface.co/microsoft/BiomedVLP-BioViL-T) by Microsoft Research:

```bibtex
@inproceedings{bannur2023biovilt,
  title={Learning to Exploit Temporal Structure for Biomedical Vision-Language Processing},
  author={Bannur, Shruthi and Hyland, Stephanie and Liu, Qianchu and Perez-Garcia, Fernando and Oktay, Ozan and Naumann, Tristan and Nori, Aditya and Alvarez-Valle, Javier},
  booktitle={CVPR},
  year={2023}
}
```

## License

This model is released under the [MIT License](LICENSE).
The BioViL-T architecture and CXR-BERT text encoder are by Microsoft Research, also released under MIT.
inference.py ADDED
@@ -0,0 +1,60 @@
"""
TILA — Inference Example

Usage:
    # From raw images (full preprocessing applied automatically):
    python inference.py --current_image raw_current.png --previous_image raw_previous.png

    # From already-preprocessed images:
    python inference.py --current_image prep_current.png --previous_image prep_previous.png --no-preprocess
"""

import argparse

import torch

from model import TILAModel
from processor import TILAProcessor


def main():
    parser = argparse.ArgumentParser(description="TILA Inference")
    parser.add_argument("--checkpoint", type=str, default="model.safetensors")
    parser.add_argument("--current_image", type=str, required=True)
    parser.add_argument("--previous_image", type=str, required=True)
    parser.add_argument("--no-preprocess", action="store_true",
                        help="Skip medical preprocessing (use if images are already preprocessed)")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    device = args.device
    dtype = torch.bfloat16 if "cuda" in device else torch.float32

    # Load model (move to target device and dtype explicitly)
    model = TILAModel.from_pretrained(args.checkpoint, device=device)
    model = model.to(device=device, dtype=dtype)

    # Load and process images (preprocessing is built into the processor)
    processor = TILAProcessor(
        raw_preprocess=not args.no_preprocess,
        dtype=dtype,
        device=device,
    )
    current = processor(args.current_image)
    previous = processor(args.previous_image)

    # 1. Get embeddings (128-dim, L2-normalized)
    embeddings = model.get_embeddings(current, previous)
    print(f"Embedding shape: {embeddings.shape}")
    print(f"Embedding (first 8 dims): {embeddings[0, :8].float().tolist()}")

    # 2. Get interval change prediction (3 modes available)
    for mode in ["default", "bestf1", "spec95"]:
        result = model.get_interval_change_prediction(current, previous, mode=mode)
        prob = result["probabilities"].item()
        pred = result["predictions"].item()
        thresh = result["threshold"]
        label = "CHANGE" if pred == 1 else "NO CHANGE"
        print(f"[{mode}] threshold={thresh:.4f}, prob={prob:.4f} -> {label}")


if __name__ == "__main__":
    main()
model.py ADDED
@@ -0,0 +1,547 @@
"""
TILA (Temporal Image-Language Alignment) — Model Architecture

This module contains the full model architecture for the TILA image encoder,
built on top of the BioViL-T (ResNet-50 + Vision Transformer pooler) backbone.

Dependencies:
    pip install torch torchvision timm
"""

from __future__ import annotations

import math
from dataclasses import dataclass
from functools import partial
from typing import Optional, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath, Mlp, trunc_normal_
from torchvision.models.resnet import Bottleneck, conv1x1


# ──────────────────────────────────────────────────────────────────────────────
# Output types
# ──────────────────────────────────────────────────────────────────────────────


@dataclass
class ImageModelOutput:
    img_embedding: torch.Tensor
    patch_embeddings: torch.Tensor
    projected_global_embedding: torch.Tensor
    class_logits: Optional[torch.Tensor]
    projected_patch_embeddings: torch.Tensor

# ──────────────────────────────────────────────────────────────────────────────
# ResNet-50 backbone
# ──────────────────────────────────────────────────────────────────────────────


class ResNet(nn.Module):
    """Standard ResNet-50 (torchvision-compatible) without the final FC layer in forward."""

    def __init__(
        self,
        layers: Sequence[int] = (3, 4, 6, 3),
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        replace_stride_with_dilation: Optional[Sequence[bool]] = None,
    ):
        super().__init__()
        block = Bottleneck
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Weight init
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample, 1, 64, previous_dilation, nn.BatchNorm2d)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=self.dilation, norm_layer=nn.BatchNorm2d))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x  # patch features [B, 2048, H, W]


# ──────────────────────────────────────────────────────────────────────────────
# Vision Transformer Pooler (temporal attention)
# ──────────────────────────────────────────────────────────────────────────────


class SinePositionEmbedding:
    def __init__(self, embedding_dim: int = 64, temperature: int = 10000,
                 normalize: bool = False, scale: Optional[float] = None):
        self.embedding_dim = embedding_dim
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale if scale is not None else 2 * math.pi

    def __call__(self, mask: torch.Tensor) -> torch.Tensor:
        B, H, W = mask.shape
        y_embed = mask.cumsum(1, dtype=torch.float32)
        x_embed = mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale
        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32)
        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).view(B, H * W, self.embedding_dim * 2)


class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False,
                 attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.proj_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_k = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj_v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, k, q, v):
        B, N, C = v.shape
        h = self.num_heads
        wq = self.proj_q(q).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
        wk = self.proj_k(k).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
        wv = self.proj_v(v).reshape(B, N, h, C // h).permute(0, 2, 1, 3)
        attn = (wq @ wk.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        o = (attn @ wv).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(o))


class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=1.0, qkv_bias=False,
                 drop=0.0, attn_drop=0.0, drop_path=0.0,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MultiHeadAttentionLayer(dim, num_heads, qkv_bias, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x, pos_and_type_embed=None):
        x_norm = self.norm1(x)
        if pos_and_type_embed is not None:
            x_norm = x_norm + pos_and_type_embed
        x = x + self.drop_path(self.attn(x_norm, x_norm, x_norm))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformerPooler(nn.Module):
    def __init__(self, input_dim: int, grid_shape: Tuple[int, int],
                 num_heads: int = 8, num_blocks: int = 3,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        block_kwargs = dict(dim=input_dim, num_heads=num_heads, mlp_ratio=1.0,
                            drop=0.10, attn_drop=0.10, drop_path=0.25,
                            act_layer=nn.GELU, norm_layer=norm_layer)
        self.blocks = nn.ModuleList([Block(**block_kwargs) for _ in range(num_blocks)])
        self.norm_post = norm_layer(input_dim)
        self.grid_shape = grid_shape
        self.num_patches = grid_shape[0] * grid_shape[1]

        self.type_embed = nn.Parameter(torch.zeros(2, 1, input_dim))
        trunc_normal_(self.type_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=0.10)
        pos_embed = SinePositionEmbedding(input_dim // 2, normalize=True)(
            torch.ones([1, grid_shape[0], grid_shape[1]]))
        self.register_buffer("pos_embed", pos_embed, persistent=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, current_image, previous_image=None):
        B, C, H, W = current_image.shape
        if previous_image is not None:
            prev = previous_image.view(B, C, H * W).transpose(1, 2)
        else:
            prev = None
        cur = current_image.view(B, C, H * W).transpose(1, 2)
        pos = self.pos_embed.repeat(B, 1, 1)

        L = cur.shape[1]
        type_emb = self.type_embed[0].expand(B, L, -1)
        if prev is not None:
            x = torch.cat((cur, prev), dim=1)
            pos = torch.cat((pos, pos), dim=1)
            type_emb = torch.cat((type_emb, self.type_embed[1].expand(B, L, -1)), dim=1)
        else:
            x = cur

        pos_type = pos + type_emb
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x, pos_type)
        x = self.norm_post(x)

        return x[:, :self.num_patches].transpose(1, 2).view(B, C, H, W)

248
+
249
+ # ──────────────────────────────────────────────────────────────────────────────
250
+ # Multi-image encoder (temporal)
251
+ # ──────────────────────────────────────────────────────────────────────────────
252
+
253
+
254
+ class MLP(nn.Module):
255
+ """Projection MLP (1x1 conv based)."""
256
+ def __init__(self, input_dim, output_dim, hidden_dim=None, use_1x1_convs=False):
257
+ super().__init__()
258
+ if use_1x1_convs and hidden_dim is not None:
259
+ self.model = nn.Sequential(
260
+ nn.Conv2d(input_dim, hidden_dim, 1, bias=False),
261
+ nn.BatchNorm2d(hidden_dim),
262
+ nn.ReLU(inplace=True),
263
+ nn.Conv2d(hidden_dim, output_dim, 1, bias=True),
264
+ )
265
+ elif hidden_dim is not None:
266
+ self.model = nn.Sequential(
267
+ nn.Linear(input_dim, hidden_dim, bias=False),
268
+ nn.BatchNorm1d(hidden_dim),
269
+ nn.ReLU(inplace=True),
270
+ nn.Linear(hidden_dim, output_dim, bias=True),
271
+ )
272
+ else:
273
+ self.model = nn.Linear(input_dim, output_dim)
274
+
275
+ def forward(self, x):
276
+ return self.model(x)
277
+
278
+
279
+ class MultiImageEncoder(nn.Module):
280
+ """BioViL-T style multi-image encoder: ResNet-50 backbone + ViT temporal pooler."""
281
+
282
+ def __init__(self):
283
+ super().__init__()
284
+ self.encoder = ResNet()
285
+ backbone_out_dim = 2048 # ResNet-50 output channels
286
+ output_dim = 256
287
+
288
+ self.backbone_to_vit = nn.Conv2d(backbone_out_dim, output_dim, 1, bias=False)
289
+ self.vit_pooler = VisionTransformerPooler(input_dim=output_dim, grid_shape=(14, 14))
290
+ self.missing_previous_emb = nn.Parameter(torch.zeros(1, output_dim, 1, 1))
291
+ trunc_normal_(self.missing_previous_emb, std=0.02)
292
+
293
+ def forward(self, current_image, previous_image=None, return_patch_embeddings=False):
294
+ B = current_image.shape[0]
295
+ if previous_image is not None:
296
+ x = torch.cat([current_image, previous_image], dim=0)
297
+ x = self.encoder(x)
298
+ x = self.backbone_to_vit(x)
299
+ patch_x, patch_prev = x[:B], x[B:]
300
+ diff_x = self.vit_pooler(current_image=patch_x, previous_image=patch_prev)
301
+ else:
302
+ x = self.encoder(current_image)
303
+ patch_x = self.backbone_to_vit(x)
304
+ _, _, W, H = patch_x.shape
305
+ diff_x = self.missing_previous_emb.repeat(B, 1, W, H)
306
+
307
+ patch_fused = torch.cat([patch_x, diff_x], dim=1) # [B, 512, H, W]
308
+ avg_pooled = torch.flatten(F.adaptive_avg_pool2d(patch_fused, (1, 1)), 1)
309
+
310
+ if return_patch_embeddings:
311
+ return patch_fused, avg_pooled
312
+ return avg_pooled
313
+
314
+
315
+ class TILAImageEncoder(nn.Module):
316
+ """Full TILA image encoder: MultiImageEncoder + projection head.
317
+
318
+ Outputs 128-dim normalized embeddings suitable for CLIP-style retrieval.
319
+ """
320
+
321
+ JOINT_FEATURE_SIZE = 128
322
+
323
+ def __init__(self):
324
+ super().__init__()
325
+ self.encoder = MultiImageEncoder()
326
+ self.projector = MLP(
327
+ input_dim=512, # patch_x (256) + diff_x (256)
328
+ output_dim=self.JOINT_FEATURE_SIZE,
329
+ hidden_dim=self.JOINT_FEATURE_SIZE,
330
+ use_1x1_convs=True,
331
+ )
332
+
333
+ def forward(self, current_image, previous_image=None):
334
+ patch_fused, pooled = self.encoder(current_image, previous_image, return_patch_embeddings=True)
335
+ projected_patch = self.projector(patch_fused)
336
+ projected_global = torch.mean(projected_patch, dim=(2, 3))
337
+ return ImageModelOutput(
338
+ img_embedding=pooled,
339
+ patch_embeddings=patch_fused,
340
+ class_logits=None,
341
+ projected_patch_embeddings=projected_patch,
342
+ projected_global_embedding=projected_global,
343
+ )
344
+
345
+
346
+ # ──────────────────────────────────────────────────────────────────────────────
347
+ # Text encoder (BioViL-T CXR-BERT + projection)
348
+ # ──────────────────────────────────────────────────────────────────────────────
349
+
350
+
351
+ TEXT_MODEL_NAME = "microsoft/BiomedVLP-BioViL-T"
352
+
353
+
354
+ class TextEncoder(nn.Module):
355
+ """CXR-BERT text encoder with a projection head to 128-dim.
356
+
357
+ Loads the pretrained BioViL-T text model and adds a LayerNorm + Linear
358
+ projection from 768-dim CLS embeddings to 128-dim joint space.
359
+ """
360
+
361
+ def __init__(self):
362
+ super().__init__()
363
+ from transformers import AutoConfig, AutoModel
364
+
365
+ config = AutoConfig.from_pretrained(TEXT_MODEL_NAME, trust_remote_code=True)
366
+ self.model = AutoModel.from_pretrained(
367
+ TEXT_MODEL_NAME, config=config, trust_remote_code=True,
368
+ )
369
+ self.projection = nn.Sequential(
370
+ nn.LayerNorm(config.hidden_size),
371
+ nn.Linear(config.hidden_size, 128),
372
+ )
373
+
374
+ def forward(self, text_inputs: dict) -> torch.Tensor:
375
+ """Encode tokenized text to 128-dim embeddings.
376
+
377
+ Args:
378
+ text_inputs: Dict from tokenizer (input_ids, attention_mask, etc.)
379
+
380
+ Returns:
381
+ Projected CLS embeddings [B, 128]
382
+ """
383
+ outputs = self.model(**text_inputs)
384
+ cls_emb = outputs.last_hidden_state[:, 0, :]
385
+ if cls_emb.dtype != next(self.projection.parameters()).dtype:
386
+ cls_emb = cls_emb.to(next(self.projection.parameters()).dtype)
387
+ return self.projection(cls_emb)
388
+
389
+
390
+ # ──────────────────────────────────────────────────────────────────────────────
391
+ # Interval change classifier head
392
+ # ──────────────────────────────────────────────────────────────────────────────
393
+
394
+
395
+ class IntervalChangeClassifier(nn.Module):
396
+ """Binary classifier head for interval change detection.
397
+
398
+ Takes 128-dim projected embeddings and outputs a change probability.
399
+ """
400
+
401
+ def __init__(self):
402
+ super().__init__()
403
+ self.head = nn.Sequential(
404
+ nn.Linear(128, 64),
405
+ nn.ReLU(),
406
+ nn.Linear(64, 1),
407
+ )
408
+
409
+ def forward(self, embedding: torch.Tensor) -> torch.Tensor:
410
+ """Returns logit (pre-sigmoid). Apply torch.sigmoid() to get probability."""
411
+ return self.head(embedding).squeeze(-1)
412
+
413
+
414
+ # ──────────────────────────────────────────────────────────────────────────────
415
+ # Full model wrapper
416
+ # ──────────────────────────────────────────────────────────────────────────────
417
+
418
+
419
+ class TILAModel(nn.Module):
420
+ """TILA model with image encoder, text encoder, and interval change classifier.
421
+
422
+ Usage:
423
+ model = TILAModel.from_pretrained("model.safetensors")
424
+
425
+ # Get 128-dim image embeddings
426
+ emb = model.get_embeddings(current_img, previous_img)
427
+
428
+ # Get 128-dim text embeddings
429
+ text_emb = model.encode_text(["Improved pulmonary edema."])
430
+
431
+ # Predict interval change
432
+ result = model.get_interval_change_prediction(current_img, previous_img)
433
+ """
434
+
435
+ def __init__(self):
436
+ super().__init__()
437
+ self.image_encoder = TILAImageEncoder()
438
+ self.text_encoder = TextEncoder()
439
+ self.change_classifier = IntervalChangeClassifier()
440
+
441
+ @torch.no_grad()
442
+ def encode_text(self, texts: list) -> torch.Tensor:
443
+ """Encode text prompts to 128-dim normalized embeddings.
444
+
445
+ Args:
446
+ texts: List of text strings
447
+
448
+ Returns:
449
+ Normalized text embeddings [N, 128]
450
+ """
451
+ from transformers import AutoTokenizer
452
+ if not hasattr(self, '_tokenizer'):
453
+ self._tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME, padding_side="right")
454
+ device = next(self.parameters()).device
455
+ tokens = self._tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
456
+ tokens = {k: v.to(device) for k, v in tokens.items()}
457
+ self.eval()
458
+ # Run text encoder in float32 for numerical stability
459
+ with torch.autocast(device_type=device.type if isinstance(device, torch.device) else "cuda", enabled=False):
460
+ self.text_encoder.float()
461
+ emb = self.text_encoder(tokens)
462
+ self.text_encoder.to(next(self.image_encoder.parameters()).dtype)
463
+ return F.normalize(emb.float(), p=2, dim=1)
464
+
465
+ @torch.no_grad()
466
+ def get_embeddings(
467
+ self, current_image: torch.Tensor, previous_image: Optional[torch.Tensor] = None
468
+ ) -> torch.Tensor:
469
+ """Extract 128-dim projected global embeddings from a pair of chest X-rays.
470
+
471
+ Args:
472
+ current_image: Current CXR tensor [B, 3, 448, 448]
473
+ previous_image: Previous CXR tensor [B, 3, 448, 448] (optional)
474
+
475
+ Returns:
476
+ Normalized 128-dim embeddings [B, 128]
477
+ """
478
+ self.eval()
479
+ out = self.image_encoder(current_image, previous_image)
480
+ return F.normalize(out.projected_global_embedding.float(), p=2, dim=1)
481
+
482
+ # Thresholds calibrated on validation set (AUC=0.7558)
483
+ THRESHOLDS = {
484
+ "default": 0.5000, # Standard sigmoid midpoint
485
+ "bestf1": 0.2886, # Youden's J — best F1=0.7210, sens=0.7798, spec=0.6166
486
+ "spec95": 0.6370, # Specificity ~0.95 — sens=0.1752, spec=0.9502
487
+ }
488
+
+    @torch.no_grad()
+    def get_interval_change_prediction(
+        self,
+        current_image: torch.Tensor,
+        previous_image: torch.Tensor,
+        mode: str = "bestf1",
+    ) -> dict:
+        """Predict interval change between two chest X-rays.
+
+        Args:
+            current_image: Current CXR tensor [B, 3, 448, 448]
+            previous_image: Previous CXR tensor [B, 3, 448, 448]
+            mode: Threshold mode, one of:
+                "default": threshold=0.50 (standard sigmoid cutoff)
+                "bestf1" : threshold=0.29 (maximizes F1, balanced sens/spec)
+                "spec95" : threshold=0.64 (targets 95% specificity, conservative)
+
+        Returns:
+            Dict with keys:
+                "probabilities": raw change probabilities [B]
+                "predictions": binary predictions [B] (0=no change, 1=change)
+                "threshold": threshold used (float)
+        """
+        if mode not in self.THRESHOLDS:
+            raise ValueError(f"mode must be one of {list(self.THRESHOLDS.keys())}, got '{mode}'")
+
+        self.eval()
+        out = self.image_encoder(current_image, previous_image)
+        logits = self.change_classifier(out.projected_global_embedding)
+        probs = torch.sigmoid(logits.float())
+
+        threshold = self.THRESHOLDS[mode]
+        preds = (probs >= threshold).long()
+
+        return {"probabilities": probs, "predictions": preds, "threshold": threshold}
+
+    @classmethod
+    def from_pretrained(cls, checkpoint_path: str, device: str = "cpu") -> "TILAModel":
+        """Load model weights from a checkpoint file.
+
+        Args:
+            checkpoint_path: Path to model.safetensors (or pytorch_model.bin)
+            device: Device to load the weights onto
+        """
+        model = cls()
+
+        if checkpoint_path.endswith(".safetensors"):
+            from safetensors.torch import load_file
+            state_dict = load_file(checkpoint_path, device=device)
+            # safetensors stores scalar tensors as 1-d; squeeze them back
+            for k, v in state_dict.items():
+                if v.dim() == 1 and v.shape[0] == 1 and "num_batches_tracked" in k:
+                    state_dict[k] = v.squeeze(0)
+        else:
+            state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
+
+        model.load_state_dict(state_dict)
+        model.eval()
+        return model
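The three calibrated cutoffs in `THRESHOLDS` trade sensitivity against specificity on the same sigmoid scores. A minimal NumPy sketch (with made-up probabilities; only the cutoff values are taken from the model) of how the mode changes the binary calls:

```python
import numpy as np

# Calibrated cutoffs from TILAModel.THRESHOLDS
THRESHOLDS = {"default": 0.5000, "bestf1": 0.2886, "spec95": 0.6370}

# Hypothetical sigmoid outputs for four study pairs
probs = np.array([0.15, 0.35, 0.55, 0.80])

# Binarize the same scores under each mode
preds = {mode: (probs >= t).astype(int) for mode, t in THRESHOLDS.items()}

print(preds["bestf1"])   # lenient cutoff flags more pairs as "change"
print(preds["spec95"])   # conservative cutoff flags fewer
```

The lenient `bestf1` cutoff labels three of the four pairs as changed, while `spec95` keeps only the highest-scoring pair, which is why it is the conservative choice when false positives are costly.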
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b16b6bcf47ac6e4e79c4d9da2db88055b297adca22715935e4522184f87ce101
+ size 642508642
preprocess.py ADDED
@@ -0,0 +1,111 @@
+ """
+ TILA — Image Preprocessing
+
+ Converts raw chest X-ray images (DICOM-derived PNGs or raw PNGs) into the
+ normalized format expected by the TILA model.
+
+ Pipeline:
+     1. Read image as-is (preserving bit depth)
+     2. Windowing: clip to mean +/- 2*std, normalize to [0, 1]
+     3. Convert to uint8
+     4. Remove black padding (contour-based crop)
+     5. Resize preserving aspect ratio (shortest side = 512)
+
+ Usage:
+     from preprocess import preprocess_image
+
+     img = preprocess_image("raw_cxr.png")
+     cv2.imwrite("preprocessed.png", img)
+ """
+
+ import cv2
+ import numpy as np
+ from pathlib import Path
+
+
+ def apply_windowing(image: np.ndarray, width_param: float = 4.0) -> np.ndarray:
+     """Apply intensity windowing based on image statistics.
+
+     Clips intensities to [mean - width_param/2 * std, mean + width_param/2 * std]
+     and normalizes to [0, 1].
+     """
+     image = image.astype(np.float64)
+     mean = np.mean(image)
+     std = np.std(image)
+     window_center = mean
+     window_width = width_param * std
+     img_min = window_center - window_width / 2
+     img_max = window_center + window_width / 2
+     image = np.clip(image, img_min, img_max)
+     image = (image - img_min) / (img_max - img_min + 1e-8)
+     return image
+
+
+ def remove_black_padding(image: np.ndarray) -> np.ndarray:
+     """Remove black padded borders by cropping to the largest contour."""
+     _, thresh = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY)
+     contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+     if not contours:
+         return image
+     largest = max(contours, key=cv2.contourArea)
+     x, y, w, h = cv2.boundingRect(largest)
+     return image[y:y + h, x:x + w]
+
+
+ def resize_preserve_aspect_ratio(image: np.ndarray, max_size: int = 512) -> np.ndarray:
+     """Resize so the shortest side equals max_size, preserving aspect ratio."""
+     if len(image.shape) == 3:
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+     h, w = image.shape
+     if w < h:
+         new_w = max_size
+         new_h = int(max_size * h / w)
+     else:
+         new_h = max_size
+         new_w = int(max_size * w / h)
+     return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
+
+
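The branch arithmetic in `resize_preserve_aspect_ratio` always pins the shorter side to `max_size` (note this matches the downstream `transforms.Resize(512)` + `CenterCrop(448)` model transform). A pure-Python check of the two branches, independent of OpenCV:

```python
def target_size(w: int, h: int, max_size: int = 512) -> tuple:
    # Same arithmetic as resize_preserve_aspect_ratio: shorter side -> max_size
    if w < h:
        return max_size, int(max_size * h / w)
    return int(max_size * w / h), max_size

# Returns (new_w, new_h)
print(target_size(2022, 2022))   # square image: both sides become 512
print(target_size(1024, 2048))   # portrait: width (shorter) pinned to 512
print(target_size(2500, 2048))   # landscape: height (shorter) pinned to 512
```

Pinning the shorter side guarantees both dimensions are at least 512, so the subsequent 448-pixel center crop never runs out of pixels.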
+ def preprocess_image(
+     image_path: str,
+     width_param: float = 4.0,
+     max_size: int = 512,
+ ) -> np.ndarray:
+     """Full preprocessing pipeline for a chest X-ray image.
+
+     Args:
+         image_path: Path to raw image (PNG, JPEG, etc.)
+         width_param: Windowing width in multiples of std (default: 4.0)
+         max_size: Target size for the shortest dimension (default: 512)
+
+     Returns:
+         Preprocessed uint8 grayscale image
+     """
+     # IMREAD_UNCHANGED preserves bit depth (important for 16-bit DICOM-derived PNGs)
+     image = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
+     if image is None:
+         raise ValueError(f"Could not read image: {image_path}")
+     # Convert color to grayscale if needed
+     if len(image.shape) == 3:
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+
+     image = apply_windowing(image, width_param)
+     image = (image * 255.0).astype(np.uint8)
+     image = remove_black_padding(image)
+     image = resize_preserve_aspect_ratio(image, max_size)
+     return image
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Preprocess chest X-ray images for TILA")
+     parser.add_argument("--input", required=True, help="Input image path")
+     parser.add_argument("--output", required=True, help="Output image path")
+     parser.add_argument("--width-param", type=float, default=4.0)
+     parser.add_argument("--max-size", type=int, default=512)
+     args = parser.parse_args()
+
+     img = preprocess_image(args.input, args.width_param, args.max_size)
+     cv2.imwrite(args.output, img)
+     print(f"Saved preprocessed image to {args.output} ({img.shape[1]}x{img.shape[0]})")
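The windowing step is plain NumPy and can be checked in isolation. A sketch mirroring `apply_windowing` (no OpenCV required; the synthetic hot-pixel image is made up for illustration):

```python
import numpy as np

def window(image, width_param=4.0):
    # Mirrors apply_windowing: clip to mean +/- (width_param/2)*std, scale to [0, 1]
    image = image.astype(np.float64)
    lo = image.mean() - width_param * image.std() / 2
    hi = image.mean() + width_param * image.std() / 2
    return (np.clip(image, lo, hi) - lo) / (hi - lo + 1e-8)

# Mostly uniform frame with one hot pixel, like an overexposed detector cell
img = np.full((10, 10), 100, dtype=np.uint16)
img[0, 0] = 10000

out = window(img)
# The outlier is clipped to the top of the window instead of crushing the
# contrast of the remaining pixels, and everything lands in [0, 1]
print(float(out[0, 0]), float(out[1, 1]))
```

Without the clip, rescaling by the raw min/max would map the 99 background pixels to nearly identical values near zero; windowing keeps them in the middle of the usable range.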
processor.py ADDED
@@ -0,0 +1,117 @@
+ """
+ TILA — Image Processor
+
+ Single processor that handles the full pipeline:
+     raw image (path, numpy, or PIL) → model-ready tensor [1, 3, 448, 448]
+
+ Combines:
+     1. Medical image preprocessing (windowing, padding removal, resize)
+     2. Model transforms (resize, center crop, to tensor, expand channels)
+
+ Usage:
+     from processor import TILAProcessor
+
+     processor = TILAProcessor()
+
+     # From file path (applies full preprocessing)
+     tensor = processor("raw_cxr.png")
+
+     # From PIL image (skips medical preprocessing, applies model transforms only)
+     tensor = processor(Image.open("preprocessed.png"))
+
+     # Pair of images for the model
+     current = processor("current.png")
+     previous = processor("previous.png")
+     result = model.get_interval_change_prediction(current, previous)
+ """
+
+ import cv2
+ import numpy as np
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ from typing import Union
+
+ from preprocess import preprocess_image
+
+
+ class TILAProcessor:
+     """End-to-end image processor for the TILA model.
+
+     Accepts file paths (str/Path), numpy arrays, or PIL Images.
+     - File paths: full pipeline (windowing → crop → resize → model transform)
+     - Numpy arrays: treated as raw, full pipeline applied
+     - PIL Images: assumed already preprocessed, only model transforms applied
+
+     Args:
+         raw_preprocess: Apply medical preprocessing (windowing, padding removal).
+             Set False if images are already preprocessed PNGs.
+         width_param: Windowing width parameter (default: 4.0)
+         max_size: Resize shortest side to this before model transforms (default: 512)
+         crop_size: Center crop size for model input (default: 448)
+         dtype: Output tensor dtype (default: torch.bfloat16)
+         device: Output tensor device (default: "cpu")
+     """
+
+     def __init__(
+         self,
+         raw_preprocess: bool = True,
+         width_param: float = 4.0,
+         max_size: int = 512,
+         crop_size: int = 448,
+         dtype: torch.dtype = torch.bfloat16,
+         device: str = "cpu",
+     ):
+         self.raw_preprocess = raw_preprocess
+         self.width_param = width_param
+         self.max_size = max_size
+         self.dtype = dtype
+         self.device = device
+
+         self.model_transform = transforms.Compose([
+             transforms.Resize(max_size),  # resizes the shorter edge to max_size
+             transforms.CenterCrop(crop_size),
+             transforms.ToTensor(),
+             _ExpandChannels(),
+         ])
+
+     def __call__(self, image: Union[str, np.ndarray, Image.Image]) -> torch.Tensor:
+         """Process a single image into a model-ready tensor.
+
+         Args:
+             image: File path (str or Path), numpy array, or PIL Image
+
+         Returns:
+             Tensor of shape [1, 3, 448, 448]
+         """
+         if isinstance(image, str) or hasattr(image, "__fspath__"):  # str or Path-like
+             path = str(image)
+             if self.raw_preprocess:
+                 img_np = preprocess_image(path, self.width_param, self.max_size)
+                 pil_img = Image.fromarray(img_np)
+             else:
+                 pil_img = Image.open(path).convert("L")
+         elif isinstance(image, np.ndarray):
+             if self.raw_preprocess:
+                 from preprocess import apply_windowing, remove_black_padding, resize_preserve_aspect_ratio
+                 if len(image.shape) == 3:
+                     image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+                 image = apply_windowing(image, self.width_param)
+                 image = (image * 255.0).astype(np.uint8)
+                 image = remove_black_padding(image)
+                 image = resize_preserve_aspect_ratio(image, self.max_size)
+             # Handles both raw (preprocessed above) and already-preprocessed arrays
+             pil_img = Image.fromarray(image)
+         elif isinstance(image, Image.Image):
+             pil_img = image.convert("L")
+         else:
+             raise TypeError(f"Expected str, np.ndarray, or PIL.Image, got {type(image)}")
+
+         tensor = self.model_transform(pil_img).unsqueeze(0)
+         return tensor.to(dtype=self.dtype, device=self.device)
+
+
+ class _ExpandChannels:
+     """Expand a single-channel tensor [1, H, W] to 3 channels [3, H, W]."""
+     def __call__(self, x: torch.Tensor) -> torch.Tensor:
+         if x.shape[0] == 1:
+             return x.repeat(3, 1, 1)
+         return x
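The processor's shape contract (grayscale in, `[1, 3, 448, 448]` out) can be sanity-checked without torch or PIL. A NumPy sketch of the channel-expansion step; the real `_ExpandChannels` does the same with `tensor.repeat`:

```python
import numpy as np

def expand_channels(x: np.ndarray) -> np.ndarray:
    # [1, H, W] -> [3, H, W] by copying the gray channel; 3-channel inputs pass through
    return np.repeat(x, 3, axis=0) if x.shape[0] == 1 else x

gray = np.random.rand(1, 448, 448).astype(np.float32)
rgb = expand_channels(gray)
print(rgb.shape)          # (3, 448, 448)

batch = rgb[None, ...]    # add the batch dim, as unsqueeze(0) does
print(batch.shape)        # (1, 3, 448, 448)
```

Replicating the single channel three times lets the grayscale X-ray feed a ResNet-50 stem that expects RGB input, without changing the image content.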