File size: 4,731 Bytes
6794be9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Convenience loader for the STELLAR models hosted on the Hugging Face Hub.

Requires the STELLAR code (https://github.com/microsoft/STELLAR) to be installed
and importable, plus ``huggingface_hub`` and ``safetensors``::

    pip install huggingface_hub safetensors
    # and install the STELLAR repo so that `src` is importable

The released checkpoints contain every trained STELLAR module, so a single file
serves three purposes via the ``purpose`` argument:

* ``"encode"``      – feature extraction (encoder only). No extra files needed.
* ``"reconstruct"`` – image reconstruction (encoder + decoder). Needs the
  MaskGIT-VQGAN tokenizer (``vq_model=<path>``).
* ``"pretrain"``    – continued pretraining (all modules). Needs ``vq_model``.

Usage::

    from load_stellar import load_stellar, list_models

    print(list_models())                       # available checkpoints
    model = load_stellar("stellar-b16")         # purpose="encode" (default)

    import torch
    img = torch.rand(1, 3, 224, 224)            # values in [0, 1]
    with torch.no_grad():
        out = model.encode(img)
    out["sparse"]   # (1, K, D) sparse concept tokens
    out["dense"]    # (1, P, D) dense patch features
    out["spatial"]  # (1, P, K) per-token spatial maps
"""
import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from src.models.stellar_model import STELLARModel

# Default Hugging Face Hub repo id; used as the default ``repo_id`` argument of
# every function in this module.
REPO_ID = "microsoft/STELLAR"
# Accepted values for the ``purpose`` argument of ``load_stellar``.
PURPOSES = ("encode", "reconstruct", "pretrain")


def _load_index(repo_id=REPO_ID):
    """Download and parse the ``config.json`` index of a Hub repo.

    Args:
        repo_id: Hugging Face repo id to fetch ``config.json`` from.

    Returns:
        The decoded JSON content of ``config.json``.
    """
    # Downloading config.json is what registers a download on the Hub, so always
    # fetch it through hf_hub_download (do not read a local copy).
    index_path = hf_hub_download(repo_id, "config.json")
    # Use a context manager so the file handle is closed deterministically, and
    # decode as UTF-8 (the JSON interchange encoding) instead of relying on the
    # platform's default locale encoding.
    with open(index_path, encoding="utf-8") as index_file:
        return json.load(index_file)


def list_models(repo_id=REPO_ID):
    """List the model names found under ``"models"`` in the repo's index.

    Args:
        repo_id: Hugging Face repo id whose index is queried.

    Returns:
        A new list holding the keys of the index's ``"models"`` mapping.
    """
    catalog = _load_index(repo_id)
    return [*catalog["models"].keys()]


def load_stellar(model="stellar-b16", purpose="encode", vq_model=None,
                 repo_id=REPO_ID, device="cpu"):
    """Download a STELLAR checkpoint from the Hub and build a ready-to-use model.

    All argument validation (``purpose``, ``model``, ``vq_model``) happens
    before the weight file is downloaded, so a bad call fails fast.

    Args:
        model: one of ``list_models()`` (e.g. ``"stellar-b16"``).
        purpose: ``"encode"`` (features), ``"reconstruct"`` (decoder) or
            ``"pretrain"`` (all modules).
        vq_model: path to the MaskGIT-VQGAN tokenizer
            (``maskgit-vqgan-imagenet-f16-256.bin``). Required for
            ``reconstruct``/``pretrain`` when the model decodes into VQ tokens.
            Ignored when ``purpose="encode"``.
        repo_id: Hugging Face repo id (default ``microsoft/STELLAR``).
        device: device to move the model to.

    Returns:
        The constructed ``STELLARModel`` with the checkpoint weights loaded,
        moved to ``device``. It is switched to eval mode only when
        ``purpose="encode"``.

    Raises:
        ValueError: if ``purpose`` is not one of ``PURPOSES``, if ``model`` is
            not listed in the repo index, or if the selected model needs a VQ
            tokenizer for the requested purpose and ``vq_model`` is ``None``.
        RuntimeError: if the built model expects weights (other than
            ``tokenizer.*``) that the checkpoint does not provide.
    """
    if purpose not in PURPOSES:
        raise ValueError(f"purpose must be one of {PURPOSES}, got {purpose!r}")

    index = _load_index(repo_id)
    if model not in index["models"]:
        raise ValueError(f"Unknown model '{model}'. Available: {list(index['models'])}")
    cfg = index["models"][model]

    # Every purpose except "encode" builds the reconstruction path.
    with_decoder = purpose != "encode"

    # Check the tokenizer requirement *before* fetching the weights, so a
    # missing `vq_model` does not cost a full checkpoint download.
    needs_vq = cfg.get("recon_type") == "vq"
    if with_decoder and needs_vq and vq_model is None:
        raise ValueError(
            f"'{model}' reconstructs into VQ tokens; pass vq_model=<path to "
            "maskgit-vqgan-imagenet-f16-256.bin> (download it from "
            "https://huggingface.co/fun-research/TiTok).")

    weights_path = hf_hub_download(repo_id, cfg["weights"])
    state_dict = load_file(weights_path)

    net = STELLARModel(
        num_sparse_tokens=cfg["num_sparse_tokens"],
        num_decoder_layers=cfg["num_decoder_layers"],
        spatial_temp=cfg["spatial_temp"],
        do_cls=cfg["do_cls"],
        vit_pretrained=cfg["backbone"],
        do_recon=with_decoder,
        # Clustering is only enabled for continued pretraining.
        do_clustering=(purpose == "pretrain"),
        # The encoder-only model never receives a tokenizer path.
        vq_model=vq_model if with_decoder else None,
    )

    # strict=False: the checkpoint holds every trained module, while the built
    # model may only contain a subset of them (unexpected keys are ignored).
    missing, _unexpected = net.load_state_dict(state_dict, strict=False)
    # `tokenizer.*` come from `vq_model`, never from our weights. Anything else
    # the built model still needs is a genuine error.
    essential_missing = [k for k in missing if not k.startswith("tokenizer.")]
    if essential_missing:
        raise RuntimeError(f"Missing essential weights: {essential_missing[:8]}")

    net.to(device)
    if purpose == "encode":
        net.eval()
    return net


if __name__ == "__main__":
    # Smoke test: list the published checkpoints, load the default encoder and
    # print the shape of every tensor it returns for one random image batch.
    print("Available models:", list_models())
    encoder = load_stellar("stellar-b16")
    dummy_batch = torch.rand(1, 3, 224, 224)
    with torch.no_grad():
        features = encoder.encode(dummy_batch)
    print({
        name: tuple(value.shape)
        for name, value in features.items()
        if torch.is_tensor(value)
    })