"""Convenience loader for the STELLAR models hosted on the Hugging Face Hub.

Requires the STELLAR code (https://github.com/microsoft/STELLAR) to be installed
and importable, plus ``huggingface_hub`` and ``safetensors``::

    pip install huggingface_hub safetensors
    # and install the STELLAR repo so that `src` is importable

The released checkpoints contain every trained STELLAR module, so a single file
serves three purposes via the ``purpose`` argument:

* ``"encode"`` – feature extraction (encoder only). No extra files needed.
* ``"reconstruct"`` – image reconstruction (encoder + decoder). Needs the
  MaskGIT-VQGAN tokenizer (``vq_model=<path>``) when the checkpoint
  reconstructs into VQ tokens.
* ``"pretrain"`` – continued pretraining (all modules). Needs ``vq_model``
  under the same condition.

Only the ``"encode"`` model is switched to eval mode; call ``model.eval()``
yourself before running reconstruction inference.

Usage::

    from load_stellar import load_stellar, list_models

    print(list_models())                 # available checkpoints
    model = load_stellar("stellar-b16")  # purpose="encode" (default)

    import torch
    img = torch.rand(1, 3, 224, 224)     # values in [0, 1]
    with torch.no_grad():
        out = model.encode(img)
    out["sparse"]    # (1, K, D) sparse concept tokens
    out["dense"]     # (1, P, D) dense patch features
    out["spatial"]   # (1, P, K) per-token spatial maps
"""
| import json |
|
|
| import torch |
| from huggingface_hub import hf_hub_download |
| from safetensors.torch import load_file |
|
|
| from src.models.stellar_model import STELLARModel |
|
|
# Hugging Face Hub repository that hosts the STELLAR checkpoints and the
# ``config.json`` model index.
REPO_ID = "microsoft/STELLAR"

# Accepted values for ``load_stellar(purpose=...)``, from the smallest module
# set (encoder only) to the largest (every trained module).
PURPOSES = (
    "encode",
    "reconstruct",
    "pretrain",
)
|
|
|
|
def _load_index(repo_id=REPO_ID):
    """Download and parse the hub-side ``config.json`` model index.

    The parsed document is expected to contain a ``"models"`` mapping from
    model name to per-model settings; ``list_models`` and ``load_stellar``
    read it that way.

    Args:
        repo_id: Hugging Face repo id to read the index from.

    Returns:
        The parsed JSON document.
    """
    index_path = hf_hub_download(repo_id, "config.json")
    # Close the handle deterministically rather than leaving it to the garbage
    # collector, and decode explicitly as UTF-8 (the JSON standard encoding)
    # instead of the platform default.
    with open(index_path, encoding="utf-8") as fh:
        return json.load(fh)
|
|
|
|
def list_models(repo_id=REPO_ID):
    """Return the list of available model names.

    Args:
        repo_id: Hugging Face repo id whose ``config.json`` index is queried.

    Returns:
        The keys of the index's ``"models"`` mapping, as a list.
    """
    catalogue = _load_index(repo_id)["models"]
    return [*catalogue.keys()]
|
|
|
|
def load_stellar(model="stellar-b16", purpose="encode", vq_model=None,
                 repo_id=REPO_ID, device="cpu"):
    """Download a STELLAR checkpoint from the Hub and build a ready-to-use model.

    All argument checks (``purpose``, ``model`` and the ``vq_model``
    requirement) run before the checkpoint weights are downloaded. A bad call
    therefore fails fast instead of after a potentially large download.

    Args:
        model: one of ``list_models()`` (e.g. ``"stellar-b16"``).
        purpose: ``"encode"`` (features), ``"reconstruct"`` (decoder) or
            ``"pretrain"`` (all modules).
        vq_model: path to the MaskGIT-VQGAN tokenizer
            (``maskgit-vqgan-imagenet-f16-256.bin``). Required for
            ``reconstruct``/``pretrain`` when the model decodes into VQ tokens.
            Not passed to the model for ``encode``.
        repo_id: Hugging Face repo id (default ``microsoft/STELLAR``).
        device: device to move the model to.

    Returns:
        The ``STELLARModel`` moved to ``device``. Only the ``"encode"`` model
        is switched to eval mode; call ``.eval()`` yourself for reconstruction
        inference.

    Raises:
        ValueError: unknown ``purpose`` or ``model``, or ``vq_model`` missing
            for a checkpoint that reconstructs into VQ tokens.
        RuntimeError: the checkpoint lacks weights the built model needs
            (any missing key outside the ``tokenizer.`` prefix).
    """
    if purpose not in PURPOSES:
        raise ValueError(f"purpose must be one of {PURPOSES}, got {purpose!r}")

    index = _load_index(repo_id)
    if model not in index["models"]:
        raise ValueError(f"Unknown model '{model}'. Available: {list(index['models'])}")
    cfg = index["models"][model]

    # Check the tokenizer requirement *before* fetching the weights.
    needs_vq = cfg.get("recon_type") == "vq"
    if purpose in ("reconstruct", "pretrain") and needs_vq and vq_model is None:
        raise ValueError(
            f"'{model}' reconstructs into VQ tokens; pass vq_model=<path to "
            "maskgit-vqgan-imagenet-f16-256.bin> (download it from "
            "https://huggingface.co/fun-research/TiTok).")

    # Optional modules built for each purpose: (do_recon, do_clustering).
    do_recon, do_clustering = {
        "encode": (False, False),
        "reconstruct": (True, False),
        "pretrain": (True, True),
    }[purpose]

    weights_path = hf_hub_download(repo_id, cfg["weights"])
    state_dict = load_file(weights_path)

    net = STELLARModel(
        num_sparse_tokens=cfg["num_sparse_tokens"],
        num_decoder_layers=cfg["num_decoder_layers"],
        spatial_temp=cfg["spatial_temp"],
        do_cls=cfg["do_cls"],
        vit_pretrained=cfg["backbone"],
        do_recon=do_recon,
        do_clustering=do_clustering,
        # Encoder-only builds are never handed a tokenizer.
        vq_model=vq_model if do_recon else None,
    )

    # The checkpoint holds every trained module, so keys belonging to modules
    # not built for this purpose are expected and ignored (strict=False).
    missing, _unexpected = net.load_state_dict(state_dict, strict=False)
    # ``tokenizer.*`` keys are tolerated as missing; presumably STELLARModel
    # fills them from ``vq_model`` -- TODO confirm.
    essential_missing = [k for k in missing if not k.startswith("tokenizer.")]
    if essential_missing:
        raise RuntimeError(f"Missing essential weights: {essential_missing[:8]}")

    net.to(device)
    if purpose == "encode":
        net.eval()
    return net
|
|
|
|
if __name__ == "__main__":
    # Smoke test: list the hub checkpoints, load the default encoder and print
    # the shape of every tensor it returns for one random image.
    available = list_models()
    print("Available models:", available)

    encoder = load_stellar("stellar-b16")
    with torch.no_grad():
        dummy_batch = torch.rand(1, 3, 224, 224)
        features = encoder.encode(dummy_batch)

    shapes = {}
    for name, value in features.items():
        if not torch.is_tensor(value):
            continue
        shapes[name] = tuple(value.shape)
    print(shapes)
|
|