# Snapshot metadata (from the page this file was captured from): 2,765 bytes, revision 41a6ec2.
"""
Example: load a shipped StradaViT checkpoint and extract embeddings.

The `embedding` returned by the model is expected to follow this policy:
  - Use the ViT encoder's `last_hidden_state`
  - Mean-pool patch tokens (drop CLS): `hs[:, 1:, :].mean(dim=1)`

Expected checkpoint layout (from our training scripts):
  <RUN_ROOT>/checkpoints/
    - config.json
    - pytorch_model.bin (or model.safetensors)
    - preprocessor_config.json
    - (optional) tokenizer/feature extractor extras

Usage:
  python3 examples/use_shipped_stradavit_model.py \\
    --checkpoint /path/to/run/checkpoints \\
    --image /path/to/image.png
"""

from __future__ import annotations

import argparse
import os
from typing import Any

import torch
import StradaViTModel


def load_model_and_processor(checkpoint_dir: str):
    """
    Load everything needed to embed images with a shipped StradaViT checkpoint.

    Args:
        checkpoint_dir: Checkpoint directory holding the model config, the
            weights and ``preprocessor_config.json``.

    Returns:
        A ``(model, processor, config)`` triple. The model is already in eval
        mode, the processor is a ``ViTImageProcessor`` and the config is
        parsed as a ``ViTMAEConfig``.
    """
    # Local import: transformers is only needed once a checkpoint is loaded.
    from transformers import ViTImageProcessor, ViTMAEConfig

    cfg = ViTMAEConfig.from_pretrained(checkpoint_dir)
    image_processor = ViTImageProcessor.from_pretrained(checkpoint_dir)

    net = StradaViTModel.from_pretrained(checkpoint_dir)
    net.eval()  # evaluation mode before running inference

    return net, image_processor, cfg


def load_image(path: str):
    """
    Read an image file from disk and return it as an RGB PIL image.

    ``Image.open`` is lazy and keeps the underlying file handle open. Opening
    it in a ``with`` block guarantees the handle is closed as soon as the RGB
    copy exists; ``convert`` returns a new, fully loaded image, so the result
    stays valid after the source image is closed.

    Args:
        path: Path to any image file Pillow can decode (png/jpg/...).

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        PIL.UnidentifiedImageError: If the file cannot be identified as an image.
    """
    from PIL import Image

    with Image.open(path) as img:
        return img.convert("RGB")


def main(argv: list[str] | None = None) -> int:
    """
    Command-line entry point: embed a single image with a shipped checkpoint.

    Parses the CLI, loads the model and processor from ``--checkpoint``, runs
    one forward pass on ``--image`` and prints a short summary of the config
    and of the resulting embedding tensor.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status (``0`` on success).

    Raises:
        FileNotFoundError: If ``--checkpoint`` is not an existing directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Path to <run_root>/checkpoints (contains config + weights)",
    )
    parser.add_argument("--image", required=True, help="Path to an image file (png/jpg/...)")
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    parser.add_argument("--device", default=fallback_device)
    opts = parser.parse_args(argv)

    checkpoint_dir = os.path.abspath(opts.checkpoint)
    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"--checkpoint must be a directory: {checkpoint_dir}")

    target = torch.device(opts.device)

    model, processor, config = load_model_and_processor(checkpoint_dir)
    model.to(target)

    batch: dict[str, Any] = processor(images=load_image(opts.image), return_tensors="pt")
    pixels = batch["pixel_values"].to(target)

    # inference_mode: the forward pass records no autograd graph.
    with torch.inference_mode():
        embedding = model(pixel_values=pixels).embedding

    def cfg_int(name: str) -> int:
        # Missing or None-valued config fields are reported as 0.
        return int(getattr(config, name, 0) or 0)

    summary = [
        f"Loaded checkpoint: {checkpoint_dir}",
        f"  model_type={getattr(config, 'model_type', None)}"
        f" use_dino_encoder={bool(getattr(config, 'use_dino_encoder', False))}"
        f" n_registers={cfg_int('n_registers')}",
        f"  image_size={cfg_int('image_size')} patch_size={cfg_int('patch_size')}",
        f"Embedding shape: {tuple(embedding.shape)} dtype={embedding.dtype} device={embedding.device}",
    ]
    print("\n".join(summary))
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    _exit_status = main()
    raise SystemExit(_exit_status)