| """ |
| TILA — Inference Example |
| |
| Usage: |
| # From raw images (full preprocessing applied automatically): |
| python inference.py --current_image raw_current.png --previous_image raw_previous.png |
| |
| # From already-preprocessed images: |
| python inference.py --current_image prep_current.png --previous_image prep_previous.png --no-preprocess |
| """ |
|
|
| import argparse |
| import torch |
| from model import TILAModel |
| from processor import TILAProcessor |
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the TILA inference example."""
    parser = argparse.ArgumentParser(description="TILA Inference")
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="model.safetensors",
        help="Weights file passed to TILAModel.from_pretrained "
             "(default: model.safetensors)",
    )
    parser.add_argument("--current_image", type=str, required=True,
                        help="Path to the current image")
    parser.add_argument("--previous_image", type=str, required=True,
                        help="Path to the previous image to compare against")
    parser.add_argument("--no-preprocess", action="store_true",
                        help="Skip medical preprocessing (use if images are already preprocessed)")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Torch device string, e.g. 'cpu', 'cuda', 'cuda:1' "
             "(default: cuda if available, else cpu)",
    )
    return parser


def main(argv=None):
    """Run TILA inference on one (current, previous) image pair and print results.

    Loads the checkpoint, preprocesses both images with ``TILAProcessor``,
    prints the embedding shape and its first 8 values, then prints the
    interval-change prediction for each threshold mode ("default", "bestf1",
    "spec95").

    Args:
        argv: Optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv``, so
            the zero-argument call from ``__main__`` behaves as before.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Validate the device string up front. A typo then produces a CLI usage
    # error rather than a traceback from deep inside model loading.
    try:
        device_type = torch.device(args.device).type
    except (RuntimeError, ValueError) as err:
        parser.error(f"invalid --device {args.device!r}: {err}")

    device = args.device
    # Use bfloat16 on CUDA devices (including indexed ones like "cuda:1"), and float32 otherwise.
    dtype = torch.bfloat16 if device_type == "cuda" else torch.float32

    model = TILAModel.from_pretrained(args.checkpoint, device=device)
    model = model.to(dtype=dtype)
    if isinstance(model, torch.nn.Module):
        # Put dropout/batch-norm layers in inference mode. This is idempotent,
        # so it is harmless if from_pretrained already did it.
        model.eval()

    processor = TILAProcessor(
        raw_preprocess=not args.no_preprocess,
        dtype=dtype,
        device=device,
    )

    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        current = processor(args.current_image)
        previous = processor(args.previous_image)

        embeddings = model.get_embeddings(current, previous)
        print(f"Embedding shape: {embeddings.shape}")
        # Assumes embeddings is (batch, dim) with batch >= 1 (TODO confirm).
        print(f"Embedding (first 8 dims): {embeddings[0, :8].float().tolist()}")

        for mode in ["default", "bestf1", "spec95"]:
            result = model.get_interval_change_prediction(current, previous, mode=mode)
            prob = result["probabilities"].item()
            pred = result["predictions"].item()
            # float() accepts a Python float, a numpy scalar, or a 1-element
            # tensor. A 1-D tensor would raise TypeError under ':.4f'.
            thresh = float(result["threshold"])
            label = "CHANGE" if pred == 1 else "NO CHANGE"
            print(f"[{mode}] threshold={thresh:.4f}, prob={prob:.4f} -> {label}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0.
    # Importing this module does not start inference.
    raise SystemExit(main())
|
|