STELLAR / load_stellar.py
tedzh17's picture
initial commit
6794be9
Raw
History Blame Contribute Delete
4.73 kB
"""Convenience loader for the STELLAR models hosted on the Hugging Face Hub.
Requires the STELLAR code (https://github.com/microsoft/STELLAR) to be installed
and importable, plus ``huggingface_hub`` and ``safetensors``::

    pip install huggingface_hub safetensors
    # and install the STELLAR repo so that `src` is importable

The released checkpoints contain every trained STELLAR module, so a single file
serves three purposes, selected with the ``purpose`` argument:

* ``"encode"`` – feature extraction (encoder only). No extra files needed.
* ``"reconstruct"`` – image reconstruction (encoder + decoder). Models that
  reconstruct into VQ tokens also need the MaskGIT-VQGAN tokenizer
  (``vq_model=<path>``).
* ``"pretrain"`` – continued pretraining (all modules). Has the same
  ``vq_model`` requirement as ``"reconstruct"``.

Usage::

    from load_stellar import load_stellar, list_models
    print(list_models())                  # available checkpoints
    model = load_stellar("stellar-b16")   # purpose="encode" (default)

    import torch
    img = torch.rand(1, 3, 224, 224)      # values in [0, 1]
    with torch.no_grad():
        out = model.encode(img)
    out["sparse"]    # (1, K, D) sparse concept tokens
    out["dense"]     # (1, P, D) dense patch features
    out["spatial"]   # (1, P, K) per-token spatial maps
"""
import json
import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from src.models.stellar_model import STELLARModel
# Default Hugging Face Hub repository: source of config.json (the model index)
# and of the checkpoint weight files it references.
REPO_ID = "microsoft/STELLAR"
# Accepted values for ``load_stellar(purpose=...)``; anything else is rejected
# with a ValueError before any download happens.
PURPOSES = ("encode", "reconstruct", "pretrain")
def _load_index(repo_id=REPO_ID):
    """Download and parse the repository's ``config.json`` model index.

    Args:
        repo_id: Hugging Face repo id to fetch the index from.

    Returns:
        The decoded JSON object. Callers in this module read its ``"models"``
        mapping (model name -> per-model config).
    """
    # Downloading config.json is what registers a download on the Hub, so always
    # fetch it through hf_hub_download (do not read a local copy).
    index_path = hf_hub_download(repo_id, "config.json")
    # A context manager closes the handle deterministically; the previous
    # ``json.load(open(...))`` left it open until garbage collection. The
    # explicit encoding avoids depending on the platform's locale default.
    with open(index_path, encoding="utf-8") as fh:
        return json.load(fh)
def list_models(repo_id=REPO_ID):
    """Return the names of the checkpoints published in ``repo_id``.

    Args:
        repo_id: Hugging Face repo id whose ``config.json`` index is read.

    Returns:
        A list of model names, in the order they appear in the index.
    """
    catalogue = _load_index(repo_id)["models"]
    return list(catalogue)
def load_stellar(model="stellar-b16", purpose="encode", vq_model=None,
                 repo_id=REPO_ID, device="cpu"):
    """Download a STELLAR checkpoint from the Hub and build a ready-to-use model.

    Args:
        model: one of ``list_models()`` (e.g. ``"stellar-b16"``).
        purpose: ``"encode"`` (features), ``"reconstruct"`` (decoder) or
            ``"pretrain"`` (all modules).
        vq_model: path to the MaskGIT-VQGAN tokenizer
            (``maskgit-vqgan-imagenet-f16-256.bin``). Required for
            ``reconstruct``/``pretrain`` when the model decodes into VQ tokens.
        repo_id: Hugging Face repo id (default ``microsoft/STELLAR``).
        device: device to move the model to.

    Returns:
        The ``STELLARModel`` with the checkpoint weights loaded and moved to
        ``device``. It is switched to eval mode only for ``purpose="encode"``.

    Raises:
        ValueError: ``purpose`` or ``model`` is unknown, or ``vq_model`` is
            missing for a VQ-reconstructing model used to reconstruct/pretrain.
        RuntimeError: the checkpoint lacks weights the built model needs
            (other than ``tokenizer.*``, which come from ``vq_model``).
    """
    if purpose not in PURPOSES:
        raise ValueError(f"purpose must be one of {PURPOSES}, got {purpose!r}")

    index = _load_index(repo_id)
    if model not in index["models"]:
        raise ValueError(f"Unknown model '{model}'. Available: {list(index['models'])}")
    cfg = index["models"][model]

    # Validate the argument requirements *before* fetching the weights, so a
    # missing vq_model fails immediately instead of after a large download.
    needs_vq = cfg.get("recon_type") == "vq"
    if purpose in ("reconstruct", "pretrain") and needs_vq and vq_model is None:
        raise ValueError(
            f"'{model}' reconstructs into VQ tokens; pass vq_model=<path to "
            "maskgit-vqgan-imagenet-f16-256.bin> (download it from "
            "https://huggingface.co/fun-research/TiTok).")

    weights_path = hf_hub_download(repo_id, cfg["weights"])
    state_dict = load_file(weights_path)

    # Which optional modules each purpose builds: (do_recon, do_clustering).
    do_recon, do_clustering = {
        "encode": (False, False),
        "reconstruct": (True, False),
        "pretrain": (True, True),
    }[purpose]
    net = STELLARModel(
        num_sparse_tokens=cfg["num_sparse_tokens"],
        num_decoder_layers=cfg["num_decoder_layers"],
        spatial_temp=cfg["spatial_temp"],
        do_cls=cfg["do_cls"],
        vit_pretrained=cfg["backbone"],
        do_recon=do_recon,
        do_clustering=do_clustering,
        # The encoder-only model never receives the VQ tokenizer.
        vq_model=vq_model if do_recon else None,
    )

    # strict=False: the checkpoint holds every trained module, so keys for
    # modules not built for this purpose are expected to be unexpected.
    missing, _ = net.load_state_dict(state_dict, strict=False)
    # load_state_dict copied the tensors into the module; drop the parsed
    # checkpoint so it is not kept alive alongside the model.
    del state_dict
    # `tokenizer.*` come from `vq_model`, never from our weights. Anything else
    # the built model still needs is a genuine error.
    essential_missing = [k for k in missing if not k.startswith("tokenizer.")]
    if essential_missing:
        raise RuntimeError(f"Missing essential weights: {essential_missing[:8]}")

    net.to(device)
    if purpose == "encode":
        net.eval()
    return net
if __name__ == "__main__":
    # Smoke test: list the Hub checkpoints, encode one random image with the
    # default model, and print the shape of every tensor it returns.
    print("Available models:", list_models())
    encoder = load_stellar("stellar-b16")
    with torch.no_grad():
        dummy = torch.rand(1, 3, 224, 224)
        outputs = encoder.encode(dummy)
    shapes = {}
    for key, value in outputs.items():
        if torch.is_tensor(value):
            shapes[key] = tuple(value.shape)
    print(shapes)