File size: 2,554 Bytes
f46fb4d
 
 
 
 
 
 
 
 
 
 
 
 
a9c6da4
f46fb4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9c6da4
 
 
 
 
 
 
 
 
f46fb4d
 
 
a9c6da4
f46fb4d
a9c6da4
 
 
f46fb4d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
TILA — Inference Example

Usage:
    # From raw images (full preprocessing applied automatically):
    python inference.py --current_image raw_current.png --previous_image raw_previous.png

    # From already-preprocessed images:
    python inference.py --current_image prep_current.png --previous_image prep_previous.png --no-preprocess
"""

import argparse
import torch
import torch.nn.functional
from model import TILAModel
from processor import TILAProcessor


def main():
    """Run TILA inference on a (current, previous) image pair.

    Prints the L2-normalized global embedding and the interval-change
    prediction under each of the model's three decision thresholds.
    """
    parser = argparse.ArgumentParser(description="TILA Inference")
    parser.add_argument("--checkpoint", type=str, default="model.safetensors")
    parser.add_argument("--current_image", type=str, required=True)
    parser.add_argument("--previous_image", type=str, required=True)
    parser.add_argument("--no-preprocess", action="store_true",
                        help="Skip medical preprocessing (use if images are already preprocessed)")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    device = args.device
    # bfloat16 on GPU for speed/memory; float32 on CPU, where bf16 support is limited.
    dtype = torch.bfloat16 if "cuda" in device else torch.float32

    # Load model and switch to eval mode once, right after loading
    # (previously done inside the no_grad block on every run path).
    model = TILAModel.from_pretrained(args.checkpoint, device=device)
    model = model.to(dtype=dtype)
    model.eval()

    # Load and process images (preprocessing is built into the processor;
    # --no-preprocess disables it for already-preprocessed inputs).
    processor = TILAProcessor(
        raw_preprocess=not args.no_preprocess,
        dtype=dtype,
        device=device,
    )
    current = processor(args.current_image)
    previous = processor(args.previous_image)

    # Run the image encoder once; reuse its output for both the embedding
    # and the change classifier evaluated at all thresholds.
    with torch.no_grad():
        out = model.image_encoder(current, previous)
        # Normalize in float32 so bf16 rounding doesn't skew the unit norm.
        embeddings = torch.nn.functional.normalize(out.projected_global_embedding.float(), p=2, dim=1)
        logits = model.change_classifier(out.projected_global_embedding)
        probs = torch.sigmoid(logits.float())

    # 1. Embeddings (128-dim, L2-normalized) — already float32, no extra cast needed.
    print(f"Embedding shape: {embeddings.shape}")
    print(f"Embedding (first 8 dims): {embeddings[0, :8].tolist()}")

    # 2. Interval change prediction (3 threshold modes, no re-computation).
    # Extract the scalar probability once instead of building a comparison
    # tensor per threshold.
    prob = probs.item()
    for mode in ["default", "bestf1", "spec95"]:
        thresh = model.THRESHOLDS[mode]
        pred = int(prob >= thresh)
        label = "CHANGE" if pred == 1 else "NO CHANGE"
        print(f"[{mode}] threshold={thresh:.4f}, prob={prob:.4f} -> {label}")


# Entry point guard: run inference only when executed as a script, not on import.
if __name__ == "__main__":
    main()