# TILA / inference.py
# Uploader: lukeingawesome
# Commit: "Upload folder using huggingface_hub" (f46fb4d, verified)
"""
TILA — Inference Example

Usage:
    # From raw images (full preprocessing applied automatically):
    python inference.py --current_image raw_current.png --previous_image raw_previous.png

    # From already-preprocessed images:
    python inference.py --current_image prep_current.png --previous_image prep_previous.png --no-preprocess
"""
import argparse
import torch
from model import TILAModel
from processor import TILAProcessor
def main(argv=None) -> None:
    """Run the TILA inference example from the command line.

    Parses the CLI options, loads the model checkpoint, runs the current and
    previous images through the processor, then prints the embedding and the
    interval-change prediction for each threshold mode.

    Args:
        argv: Optional list of command-line arguments. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, so calling
            ``main()`` behaves exactly as before.

    Raises:
        SystemExit: Raised by argparse when required arguments are missing
            or invalid (exit status 2), or when ``--help`` is requested.
    """
    parser = argparse.ArgumentParser(description="TILA Inference")
    parser.add_argument("--checkpoint", type=str, default="model.safetensors")
    parser.add_argument("--current_image", type=str, required=True)
    parser.add_argument("--previous_image", type=str, required=True)
    parser.add_argument(
        "--no-preprocess",
        action="store_true",
        help="Skip medical preprocessing (use if images are already preprocessed)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_args(argv)

    device = args.device
    # Reduced precision is only selected for CUDA devices; everything else
    # (e.g. "cpu") runs in float32.
    dtype = torch.bfloat16 if "cuda" in device else torch.float32

    # Load model and cast its parameters to the selected dtype.
    model = TILAModel.from_pretrained(args.checkpoint, device=device)
    model = model.to(dtype=dtype)

    # Load and process images (preprocessing is built into the processor).
    processor = TILAProcessor(
        raw_preprocess=not args.no_preprocess,
        dtype=dtype,
        device=device,
    )
    current = processor(args.current_image)
    previous = processor(args.previous_image)

    # This script only reads model outputs, so gradient tracking is disabled
    # to avoid building an autograd graph during the forward passes.
    with torch.no_grad():
        # 1. Get embeddings (128-dim, L2-normalized per the original note).
        embeddings = model.get_embeddings(current, previous)
        print(f"Embedding shape: {embeddings.shape}")
        # Cast to float before tolist() so bfloat16 values print as plain floats.
        print(f"Embedding (first 8 dims): {embeddings[0, :8].float().tolist()}")

        # 2. Get interval change prediction (3 modes available).
        for mode in ("default", "bestf1", "spec95"):
            result = model.get_interval_change_prediction(current, previous, mode=mode)
            # .item() assumes a single image pair (one-element tensors).
            prob = result["probabilities"].item()
            pred = result["predictions"].item()
            thresh = result["threshold"]
            label = "CHANGE" if pred == 1 else "NO CHANGE"
            print(f"[{mode}] threshold={thresh:.4f}, prob={prob:.4f} -> {label}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()