"""
TILA — Inference Example

Usage:
    # From raw images (full preprocessing applied automatically):
    python inference.py --current_image raw_current.png --previous_image raw_previous.png

    # From already-preprocessed images:
    python inference.py --current_image prep_current.png --previous_image prep_previous.png --no-preprocess
"""

import argparse

import torch

from model import TILAModel
from processor import TILAProcessor


def main():
    """Run TILA on a (current, previous) image pair and print the results.

    Loads the checkpoint given by ``--checkpoint``, and turns both images into
    model inputs with ``TILAProcessor``. Raw preprocessing is on unless
    ``--no-preprocess`` is passed. It then prints:

    1. the pair embedding's shape and its first 8 values;
    2. the interval-change prediction (threshold, probability, label) for each
       of the threshold modes "default", "bestf1" and "spec95".
    """
    parser = argparse.ArgumentParser(description="TILA Inference")
    parser.add_argument("--checkpoint", type=str, default="model.safetensors")
    parser.add_argument("--current_image", type=str, required=True)
    parser.add_argument("--previous_image", type=str, required=True)
    parser.add_argument(
        "--no-preprocess",
        action="store_true",
        help="Skip medical preprocessing (use if images are already preprocessed)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_args()

    device = args.device
    # bfloat16 for any CUDA device string (e.g. "cuda", "cuda:1"); float32 otherwise.
    dtype = torch.bfloat16 if "cuda" in device else torch.float32

    # Load model
    model = TILAModel.from_pretrained(args.checkpoint, device=device)
    model = model.to(dtype=dtype)
    # This script only does inference, so disable training-mode behaviour
    # (dropout, batch-norm statistic updates). Assumes TILAModel is a
    # torch.nn.Module, as its `.to(dtype=...)` usage suggests.
    model.eval()

    # Load and process images (preprocessing is built into the processor)
    processor = TILAProcessor(
        raw_preprocess=not args.no_preprocess,
        dtype=dtype,
        device=device,
    )
    current = processor(args.current_image)
    previous = processor(args.previous_image)

    # Nothing below needs gradients; skipping autograd bookkeeping saves
    # memory and time during the forward passes.
    with torch.no_grad():
        # 1. Get embeddings (128-dim, L2-normalized)
        embeddings = model.get_embeddings(current, previous)
        print(f"Embedding shape: {embeddings.shape}")
        print(f"Embedding (first 8 dims): {embeddings[0, :8].float().tolist()}")

        # 2. Get interval change prediction (3 modes available).
        # Each mode reports its own decision threshold alongside the result.
        for mode in ("default", "bestf1", "spec95"):
            result = model.get_interval_change_prediction(current, previous, mode=mode)
            prob = result["probabilities"].item()
            pred = result["predictions"].item()
            thresh = result["threshold"]
            label = "CHANGE" if pred == 1 else "NO CHANGE"
            print(f"[{mode}] threshold={thresh:.4f}, prob={prob:.4f} -> {label}")


if __name__ == "__main__":
    main()