"""
TILA — Inference Example
Usage:
# From raw images (full preprocessing applied automatically):
python inference.py --current_image raw_current.png --previous_image raw_previous.png
# From already-preprocessed images:
python inference.py --current_image prep_current.png --previous_image prep_previous.png --no-preprocess
"""
import argparse
import torch
import torch.nn.functional
from model import TILAModel
from processor import TILAProcessor
def main():
    """Run the TILA inference example on one current/previous image pair.

    Steps:
      1. Parse command-line arguments (checkpoint path, the two image paths,
         the ``--no-preprocess`` switch and the target device).
      2. Load ``TILAModel`` from the checkpoint and cast it to the working
         dtype (bfloat16 when the device string contains "cuda", else float32).
      3. Turn both image paths into model inputs with ``TILAProcessor``.
      4. Run the image encoder once and derive from that single output both
         the L2-normalized embedding and the change probability.
      5. Print the embedding summary and a CHANGE / NO CHANGE decision for
         each threshold mode in ``model.THRESHOLDS``.

    Returns:
        None. All results are printed to stdout.

    Note:
        ``.item()`` is used on the probability tensor, so it must hold exactly
        one element; this script handles a single image pair per invocation.
    """
    parser = argparse.ArgumentParser(description="TILA Inference")
    parser.add_argument("--checkpoint", type=str, default="model.safetensors")
    parser.add_argument("--current_image", type=str, required=True)
    parser.add_argument("--previous_image", type=str, required=True)
    parser.add_argument("--no-preprocess", action="store_true",
                        help="Skip medical preprocessing (use if images are already preprocessed)")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    device = args.device
    # Reduced precision only on CUDA devices; everything else stays in float32.
    dtype = torch.bfloat16 if "cuda" in device else torch.float32

    # Load model, cast to the working dtype and switch to evaluation mode.
    model = TILAModel.from_pretrained(args.checkpoint, device=device)
    model = model.to(dtype=dtype)
    model.eval()

    # Load and process images (preprocessing is built into the processor).
    processor = TILAProcessor(
        raw_preprocess=not args.no_preprocess,
        dtype=dtype,
        device=device,
    )
    current = processor(args.current_image)
    previous = processor(args.previous_image)

    # Run the image encoder once; embeddings and all thresholds reuse `out`.
    with torch.no_grad():
        out = model.image_encoder(current, previous)
        # Upcast to float32 before normalizing / applying the sigmoid so the
        # reported numbers are not computed in the reduced-precision dtype.
        embeddings = torch.nn.functional.normalize(
            out.projected_global_embedding.float(), p=2, dim=1
        )
        logits = model.change_classifier(out.projected_global_embedding)
        probs = torch.sigmoid(logits.float())

    # 1. Embeddings (L2-normalized along dim 1; the original comment states
    #    they are 128-dim -- the actual shape is printed below).
    print(f"Embedding shape: {embeddings.shape}")
    # `embeddings` is already float32 here, so no extra cast is needed.
    print(f"Embedding (first 8 dims): {embeddings[0, :8].tolist()}")

    # 2. Interval change prediction (3 threshold modes, no re-computation).
    # The probability does not depend on the threshold mode, so it is pulled
    # out of the loop: one device-to-host transfer instead of one per mode.
    prob = probs.item()
    for mode in ["default", "bestf1", "spec95"]:
        thresh = model.THRESHOLDS[mode]
        # Comparison stays on the tensor so thresholding semantics (dtype of
        # the comparison) are exactly those of the original script.
        label = "CHANGE" if (probs >= thresh).item() else "NO CHANGE"
        print(f"[{mode}] threshold={thresh:.4f}, prob={prob:.4f} -> {label}")
# Script entry point: run the example only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|