File size: 1,610 Bytes
447fe2d
85617a2
 
2a061bd
 
 
 
 
 
 
 
 
 
 
 
c887cca
4718e54
447fe2d
 
 
 
 
 
 
85617a2
447fe2d
 
 
 
2a061bd
 
85617a2
447fe2d
85617a2
 
 
447fe2d
85617a2
 
4718e54
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import argparse
import mithridatium.evaluator as evaluator
import mithridatium.loader as loader
from mithridatium.data import build_dataloader
from mithridatium.io import load_preprocess_config

def test_build_dataloader_one_batch():
    """Smoke-test that the CIFAR-10 test dataloader yields one well-formed batch.

    Builds the loader from the preprocessing config resolved for
    ``models/resnet18_bd.pth``, pulls a single batch of 8 and checks that the
    inputs are 4-D with 3 channels in dim 1, the labels are 1-D, and the
    spatial dims match ``pp.input_size``.
    """
    # expects models/resnet18_bd.json (from Issue 1) to be resolvable from the checkpoint path
    pp = load_preprocess_config("models/resnet18_bd.pth")
    # Named `test_loader` rather than `loader` so it does not shadow the
    # module-level `mithridatium.loader` import.
    test_loader = build_dataloader("cifar10", "test", pp, batch_size=8)
    x, y = next(iter(test_loader))
    assert x.ndim == 4 and x.shape[1] == 3  # NCHW, 3 channels (RGB)
    assert y.ndim == 1
    # Verify spatial dims match the config. Compare as tuples: the shape slice
    # is a tuple subclass (e.g. torch.Size), while pp.input_size may be a list
    # if it was parsed from JSON, and a list never compares equal to a tuple.
    assert tuple(x.shape[-2:]) == tuple(pp.input_size)
        
def main():
    """Sanity-check the evaluator on a ResNet-18 checkpoint.

    Loads the model via ``loader.load_resnet18``, builds the CIFAR-10 test
    dataloader from the preprocessing config resolved for the same model
    path, extracts embeddings (printing their shape), then evaluates and
    prints test accuracy and loss.

    Example:
        .venv/bin/python -m scripts.check_evaluator --model models/resnet18_poison.pth
    """
    parser = argparse.ArgumentParser(
        description="Sanity-check the evaluator on a ResNet-18 checkpoint using the CIFAR-10 test set.",
        epilog="example: .venv/bin/python -m scripts.check_evaluator --model models/resnet18_poison.pth",
    )
    parser.add_argument("--model", type=str, default="models/resnet18_bd.pth", help="Path to model checkpoint")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size for evaluation")
    args = parser.parse_args()
    if args.batch_size < 1:
        # Fail fast with a usage message instead of an obscure dataloader error.
        parser.error("--batch_size must be a positive integer")

    # Load model from checkpoint
    model, feature_module = loader.load_resnet18(args.model)

    # Prepare CIFAR-10 test set using the preprocessing config tied to the model path
    pp = load_preprocess_config(args.model)
    test_loader = build_dataloader("cifar10", "test", pp, batch_size=args.batch_size)

    # Extract embeddings (labels are returned too but not needed for this check)
    embs, _ = evaluator.extract_embeddings(model, test_loader, feature_module)
    print(f"Embeddings shape: {embs.shape}")

    # Evaluate accuracy
    loss, accy = evaluator.evaluate(model, test_loader)
    print(f"Test accuracy: {accy*100:.2f}% | Test loss: {loss:.4f}")


if __name__ == "__main__":
    main()