"""Convenience loader for the STELLAR models hosted on the Hugging Face Hub.

Requires the STELLAR code (https://github.com/microsoft/STELLAR) to be
installed and importable, plus ``huggingface_hub`` and ``safetensors``::

    pip install huggingface_hub safetensors
    # and install the STELLAR repo so that `src` is importable

The released checkpoints contain every trained STELLAR module, so a single
file serves three purposes via the ``purpose`` argument:

* ``"encode"``      – feature extraction (encoder only). No extra files
  needed.
* ``"reconstruct"`` – image reconstruction (encoder + decoder). Needs the
  MaskGIT-VQGAN tokenizer (``vq_model=``) when the checkpoint's config says
  it reconstructs into VQ tokens.
* ``"pretrain"``    – continued pretraining (all modules). Needs
  ``vq_model`` under the same condition as ``"reconstruct"``.

Usage::

    from load_stellar import load_stellar, list_models

    print(list_models())                    # available checkpoints
    model = load_stellar("stellar-b16")     # purpose="encode" (default)

    import torch
    img = torch.rand(1, 3, 224, 224)        # values in [0, 1]
    with torch.no_grad():
        out = model.encode(img)
    out["sparse"]     # (1, K, D) sparse concept tokens
    out["dense"]      # (1, P, D) dense patch features
    out["spatial"]    # (1, P, K) per-token spatial maps
"""

import json
from typing import Optional, Union

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from src.models.stellar_model import STELLARModel

REPO_ID = "microsoft/STELLAR"
PURPOSES = ("encode", "reconstruct", "pretrain")


def _load_index(repo_id: str = REPO_ID) -> dict:
    """Fetch and parse the repository's ``config.json`` model index.

    Args:
        repo_id: Hugging Face repo id to read the index from.

    Returns:
        The parsed JSON document. Callers in this module read its
        ``"models"`` mapping (model name -> per-model config).
    """
    # Downloading config.json is what registers a download on the Hub, so always
    # fetch it through hf_hub_download (do not read a local copy).
    index_path = hf_hub_download(repo_id, "config.json")
    # Context manager so the handle is closed deterministically; JSON is read
    # as UTF-8 rather than whatever the platform default encoding happens to be.
    with open(index_path, encoding="utf-8") as fh:
        return json.load(fh)


def list_models(repo_id: str = REPO_ID) -> list:
    """Return the list of available model names.

    Args:
        repo_id: Hugging Face repo id to query.

    Returns:
        The keys of the index's ``"models"`` mapping, as a list.
    """
    return list(_load_index(repo_id)["models"])


def load_stellar(
    model: str = "stellar-b16",
    purpose: str = "encode",
    vq_model: Optional[str] = None,
    repo_id: str = REPO_ID,
    device: Union[str, torch.device] = "cpu",
) -> STELLARModel:
    """Download a STELLAR checkpoint from the Hub and build a ready-to-use model.

    Args:
        model: one of ``list_models()`` (e.g. ``"stellar-b16"``).
        purpose: ``"encode"`` (features), ``"reconstruct"`` (decoder) or
            ``"pretrain"`` (all modules).
        vq_model: path to the MaskGIT-VQGAN tokenizer
            (``maskgit-vqgan-imagenet-f16-256.bin``). Required for
            ``reconstruct``/``pretrain`` when the model decodes into VQ tokens.
        repo_id: Hugging Face repo id (default ``microsoft/STELLAR``).
        device: device to move the model to.

    Returns:
        The constructed ``STELLARModel`` with the checkpoint weights loaded,
        moved to ``device``. It is switched to eval mode only when
        ``purpose == "encode"``.

    Raises:
        ValueError: if ``purpose`` or ``model`` is unknown, or if ``vq_model``
            is required for the requested purpose but was not given.
        RuntimeError: if the checkpoint lacks weights the built model needs
            (other than ``tokenizer.*``).
    """
    if purpose not in PURPOSES:
        raise ValueError(f"purpose must be one of {PURPOSES}, got {purpose!r}")

    index = _load_index(repo_id)
    if model not in index["models"]:
        raise ValueError(f"Unknown model '{model}'. Available: {list(index['models'])}")
    cfg = index["models"][model]

    # Validate the tokenizer requirement *before* fetching the weights: the
    # check only needs the config, so a call that is bound to fail should not
    # pay for a full checkpoint download first.
    needs_vq = cfg.get("recon_type") == "vq"
    if purpose in ("reconstruct", "pretrain") and needs_vq and vq_model is None:
        raise ValueError(
            f"'{model}' reconstructs into VQ tokens; pass vq_model= (download it from "
            "https://huggingface.co/fun-research/TiTok).")

    weights_path = hf_hub_download(repo_id, cfg["weights"])
    state_dict = load_file(weights_path)

    # Constructor arguments shared by every purpose, taken from the index.
    common = dict(
        num_sparse_tokens=cfg["num_sparse_tokens"],
        num_decoder_layers=cfg["num_decoder_layers"],
        spatial_temp=cfg["spatial_temp"],
        do_cls=cfg["do_cls"],
        vit_pretrained=cfg["backbone"],
    )
    if purpose == "encode":
        net = STELLARModel(**common, do_recon=False, do_clustering=False, vq_model=None)
    elif purpose == "reconstruct":
        net = STELLARModel(**common, do_recon=True, do_clustering=False, vq_model=vq_model)
    else:  # pretrain
        net = STELLARModel(**common, do_recon=True, do_clustering=True, vq_model=vq_model)

    # strict=False because the checkpoint holds every trained module, while the
    # model built for this purpose may contain only a subset of them; the
    # unexpected keys are therefore ignored.
    missing, _ = net.load_state_dict(state_dict, strict=False)
    # `tokenizer.*` come from `vq_model`, never from our weights. Anything else
    # the built model still needs is a genuine error.
    essential_missing = [k for k in missing if not k.startswith("tokenizer.")]
    if essential_missing:
        raise RuntimeError(f"Missing essential weights: {essential_missing[:8]}")

    net.to(device)
    # NOTE(review): only the "encode" model is put in eval mode; a
    # "reconstruct" model is returned in training mode — confirm whether
    # callers are expected to call .eval() themselves before inference.
    if purpose == "encode":
        net.eval()
    return net


if __name__ == "__main__":
    print("Available models:", list_models())
    m = load_stellar("stellar-b16")
    with torch.no_grad():
        out = m.encode(torch.rand(1, 3, 224, 224))
    print({k: tuple(v.shape) for k, v in out.items() if torch.is_tensor(v)})