"""
Example: load a shipped StradaViT checkpoint and extract embeddings.
This mirrors the embedding policy:
- Use the ViT encoder's `last_hidden_state`
- Mean-pool patch tokens (drop CLS): `hs[:, 1:, :].mean(dim=1)`
Expected checkpoint layout (from our training scripts):
<RUN_ROOT>/checkpoints/
- config.json
- pytorch_model.bin (or model.safetensors)
- preprocessor_config.json
- (optional) tokenizer/feature extractor extras
Usage:
python3 examples/use_shipped_stradavit_model.py \\
--checkpoint /path/to/run/checkpoints \\
--image /path/to/image.png
"""
from __future__ import annotations
import argparse
import os
from typing import Any
import torch
import StradaViTModel
def load_model_and_processor(checkpoint_dir: str):
    """
    Load a StradaViT checkpoint together with its matching HF image processor.

    Returns a (model, processor, config) triple; the model is put into
    eval mode before being returned.
    """
    from transformers import ViTImageProcessor, ViTMAEConfig

    cfg = ViTMAEConfig.from_pretrained(checkpoint_dir)
    image_processor = ViTImageProcessor.from_pretrained(checkpoint_dir)
    vit = StradaViTModel.from_pretrained(checkpoint_dir)
    vit.eval()
    return vit, image_processor, cfg
def load_image(path: str):
    """Open the image at *path* with PIL and return it converted to RGB."""
    from PIL import Image

    return Image.open(path).convert("RGB")
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: load a shipped checkpoint, embed one image, print a summary.

    Parameters
    ----------
    argv:
        Optional argument list; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns
    -------
    int
        Process exit code (0 on success).

    Raises
    ------
    FileNotFoundError
        If ``--checkpoint`` is not a directory or ``--image`` is not a file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True, help="Path to <run_root>/checkpoints (contains config + weights)")
    ap.add_argument("--image", required=True, help="Path to an image file (png/jpg/...)")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args(argv)

    ckpt = os.path.abspath(args.checkpoint)
    if not os.path.isdir(ckpt):
        raise FileNotFoundError(f"--checkpoint must be a directory: {ckpt}")
    # Validate the image path up front (mirrors the --checkpoint check) so a
    # typo fails with a clear message instead of an opaque PIL error later.
    image_path = os.path.abspath(args.image)
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"--image must be an existing file: {image_path}")

    device = torch.device(args.device)
    model, processor, config = load_model_and_processor(ckpt)
    model.to(device)

    img = load_image(image_path)
    inputs: dict[str, Any] = processor(images=img, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    # inference_mode disables autograd bookkeeping for a pure forward pass.
    with torch.inference_mode():
        out = model(pixel_values=pixel_values)
        emb = out.embedding

    # Summarize what was loaded; getattr defaults keep this robust if the
    # config lacks optional fields (older checkpoints).
    print(
        f"Loaded checkpoint: {ckpt}\n"
        f" model_type={getattr(config, 'model_type', None)} use_dino_encoder={bool(getattr(config, 'use_dino_encoder', False))} "
        f"n_registers={int(getattr(config, 'n_registers', 0) or 0)}\n"
        f" image_size={int(getattr(config, 'image_size', 0) or 0)} patch_size={int(getattr(config, 'patch_size', 0) or 0)}\n"
        f"Embedding shape: {tuple(emb.shape)} dtype={emb.dtype} device={emb.device}"
    )
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return code as the process exit status.
    raise SystemExit(main())