---
license: mit
library_name: stellar
pipeline_tag: image-feature-extraction
datasets:
  - imagenet-1k
tags:
  - vision
  - self-supervised-learning
  - representation-learning
  - sparse-tokens
  - vision-transformer
  - image-feature-extraction
---

# STELLAR — Sparse Visual Representations via Spatial–Semantic Factorization

**STELLAR** learns a **unified sparse visual representation** that supports both
**reconstruction** and **semantics** using as few as **16 tokens**. By factorizing
*"what"* (semantics) from *"where"* (spatial layout), each image is encoded as the
low-rank product of a **localization** matrix and a **semantics** matrix.

<p align="center">
  <img src="factorization.svg" alt="Spatial–semantic factorization" width="720">
</p>

- 📄 **Paper:** [arXiv:2602.01905](https://arxiv.org/abs/2602.01905) (ICML 2026)
- 💻 **Code:** [github.com/microsoft/STELLAR](https://github.com/microsoft/STELLAR)

These checkpoints contain the **full set of trained STELLAR modules** (encoder, sparse
tokens, projections, reconstruction decoder, and clustering heads), so a single file
supports feature extraction, image reconstruction, and continued pretraining. All
models are pretrained with self-supervision on **ImageNet-1K** at 224×224.

## Highlights

- **Sparse & unified** — one small set of tokens serves both high-level semantics
  and pixel-level reconstruction.
- **Factorized latents** — each token captures a concept (*what*) together with a
  spatial map of *where* it appears.
- **Strong on both axes** — STELLAR-H reaches **2.60 FID** (reconstruction) and
  **79.1%** ImageNet linear-probing accuracy with just **16 tokens**.

## Available models

| Model | Backbone | Tokens | Params | Type | File |
| :--- | :--- | :---: | :---: | :--- | :--- |
| `stellar-b16` | ViT-B/16 | 16 | 88M | main | [`stellar-b16.safetensors`](stellar-b16.safetensors) |
| `stellar-l16` | ViT-L/16 | 16 | 307M | main | [`stellar-l16.safetensors`](stellar-l16.safetensors) |
| `stellar-h16` | ViT-H/14 | 16 | 636M | main | [`stellar-h16.safetensors`](stellar-h16.safetensors) |
| `stellar-b8`  | ViT-B/16 | 8  | 88M | ablation | [`stellar-b8.safetensors`](stellar-b8.safetensors) |
| `stellar-b24` | ViT-B/16 | 24 | 88M | ablation | [`stellar-b24.safetensors`](stellar-b24.safetensors) |

The main models (`b16`, `l16`, `h16`) are recommended for downstream use; the 8- and
24-token base models are ablations on the number of sparse tokens.

## Usage

Install the STELLAR code and the Hub helpers:

```bash
pip install huggingface_hub safetensors
git clone https://github.com/microsoft/STELLAR && cd STELLAR
pip install -r requirements.txt
```

### Quick start

From the STELLAR code directory, use the [`load_stellar.py`](load_stellar.py) helper
(it downloads the weights from the Hub for you):

```python
import torch
from load_stellar import load_stellar, list_models

print(list_models())                   # ['stellar-b16', 'stellar-l16', ...]
model = load_stellar("stellar-b16")     # purpose="encode" (default)

# RGB image in [0, 1], resized to 224×224 (ImageNet normalization is applied internally)
image = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    out = model.encode(image)

out["sparse"]    # (1, K, D)   sparse concept tokens   ("what")
out["spatial"]   # (1, P, K)   per-token spatial maps  ("where")
out["dense"]     # (1, P, D)   dense per-patch features
out["cls"]       # (1, 1, D)   global image token
```
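
As a rough illustration of how the `spatial` output might be used for visualization, the sketch below reshapes the per-token maps onto the patch grid and assigns each patch to its strongest token. It uses a random NumPy array as a stand-in for `out["spatial"][0]`. It assumes that patches are ordered row by row and that larger values mean stronger association; neither is confirmed by this card. `spatial_to_grids` and `hard_assignment` are hypothetical helper names.

```python
import numpy as np

def spatial_to_grids(spatial):
    """Reshape a (P, K) spatial map into K square (side, side) grids."""
    num_patches, num_tokens = spatial.shape
    side = int(round(num_patches ** 0.5))
    if side * side != num_patches:
        raise ValueError(f"expected a square patch grid, got P={num_patches}")
    # (P, K) -> (K, P) -> (K, side, side), assuming row-major patch order
    return spatial.T.reshape(num_tokens, side, side)

def hard_assignment(spatial):
    """Label every patch with the index of its highest-scoring token."""
    grids = spatial_to_grids(spatial)
    return grids.argmax(axis=0)  # (side, side) integer map

# Stand-in for out["spatial"][0] from stellar-b16: P = 196 patches, K = 16 tokens
rng = np.random.default_rng(0)
spatial = rng.random((196, 16))

grids = spatial_to_grids(spatial)
labels = hard_assignment(spatial)
print(grids.shape, labels.shape)  # (16, 14, 14) (14, 14)
```

With a real model, `out["spatial"][0].cpu().numpy()` would take the place of the random array.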

### Reconstruction & continued pretraining

The same checkpoint can be loaded for other purposes via the `purpose` argument. Image
reconstruction and continued pretraining use the decoder, which predicts
[MaskGIT-VQGAN](https://huggingface.co/fun-research/TiTok) tokens — pass the tokenizer
path as `vq_model`:

```python
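# Assumes `image` is the (B, 3, 224, 224) tensor in [0, 1] from the quick start and
# `VQGAN_PATH` is the path to the separately downloaded MaskGIT-VQGAN tokenizer.
from load_stellar import load_stellar

# 1. encode -> factorized features (sparse concept tokens + spatial maps)
model = load_stellar("stellar-b16", purpose="reconstruct", vq_model=VQGAN_PATH)
features = model.encode(image)          # dict: sparse (B,K,D), spatial (B,P,K), ...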

# 2. decode the factorized features -> VQGAN decoder -> pixels
out = model.reconstruct(features)       # or model.reconstruct(features["sparse"], features["spatial"])
pixels = out["reconstruction"]          # (B, 3, H, W) RGB in [0, 1]
#                                         224x224 for /16 models, 256x256 for the /14 H model
#   out["tokens"] : (B, P) predicted VQGAN token ids
#   out["logits"] : (B, P, 1024) raw codebook logits

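# continued pretraining (all modules, gradients enabled); the `...` below stands for
# the remaining batch keys, so that line is illustrative rather than runnable as written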
model = load_stellar("stellar-b16", purpose="pretrain", vq_model=VQGAN_PATH)
losses = model({"image": image, "labels": labels, ...})["predictions"]
```

`reconstruct` is the **decoder half** of STELLAR: it takes the factorized features,
reassembles the low-rank dense map, and passes it through the ViT decoder and then the
VQGAN decoder to return RGB pixels. See
[`examples/reconstruction.ipynb`](examples/reconstruction.ipynb) for an end-to-end demo
that loads an image and displays the reconstruction.


### What the model returns

| Key | Shape | Description | Typical use |
| :--- | :--- | :--- | :--- |
| `sparse` | `(B, K, D)` | sparse concept tokens | classification, retrieval |
| `spatial` | `(B, P, K)` | spatial map of each token | segmentation, visualization |
| `dense` | `(B, P, D)` | dense per-patch features | segmentation |
| `lowrank` | `(B, P, D)` | reassembled dense map | reconstruction |
| `cls` | `(B, 1, D)` | global representation | classification |

`B` = batch size, `K` = number of sparse tokens, `P` = number of patches at 224² input
(196 for the /16 backbones, 256 for the /14 backbone), `D` = embedding dim
(768 / 1024 / 1280 for B / L / H).
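
The shape arithmetic above can be checked with a small sketch. It also illustrates the low-rank idea behind `lowrank`: multiplying a `(P, K)` map by a `(K, D)` token matrix yields a `(P, D)` dense map of rank at most `K`. The arrays are random NumPy stand-ins, and the plain matrix product is an assumption made for illustration; the released model may apply extra projections or normalization before reassembling the map. `num_patches` is a hypothetical helper.

```python
import numpy as np

def num_patches(image_size, patch_size):
    """Number of non-overlapping square patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    return (image_size // patch_size) ** 2

print(num_patches(224, 16))  # 196
print(num_patches(224, 14))  # 256

# Random stand-ins with the stellar-b16 shapes: P = 196, K = 16, D = 768
P, K, D = num_patches(224, 16), 16, 768
rng = np.random.default_rng(0)
spatial = rng.random((P, K))          # "where": one column per token
sparse = rng.standard_normal((K, D))  # "what": one row per token

lowrank = spatial @ sparse            # (P, D) dense map, rank at most K
rank = np.linalg.matrix_rank(lowrank)
print(lowrank.shape, rank)

# Values stored by the two factors versus a full dense map
factored, dense = P * K + K * D, P * D
print(factored, dense)                # 15424 150528
```

Under these assumptions the two factors hold roughly a tenth of the values of the full dense map, which is one way to read the "sparse" in sparse tokens.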

### Loading the weights manually

```python
import json

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from src.models.stellar_model import STELLARModel

repo = "microsoft/STELLAR"

# config.json holds one entry per model, including its weight file name
with open(hf_hub_download(repo, "config.json")) as f:
    cfg = json.load(f)["models"]["stellar-b16"]
state = load_file(hf_hub_download(repo, cfg["weights"]))

# Encoder-only build: no reconstruction decoder, clustering heads, or VQGAN tokenizer
model = STELLARModel(
    num_sparse_tokens=cfg["num_sparse_tokens"],
    num_decoder_layers=cfg["num_decoder_layers"],
    spatial_temp=cfg["spatial_temp"],
    vit_pretrained=cfg["backbone"],
    do_recon=False, do_clustering=False, vq_model=None,
)
model.load_state_dict(state, strict=False)   # encoder-only build ignores decoder/heads
model.eval()
with torch.no_grad():                        # inference only, no gradients needed
    features = model.encode(torch.rand(1, 3, 224, 224))
```
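
For intuition about `strict=False`, the sketch below mimics the key comparison that PyTorch's `load_state_dict` performs, using plain Python sets instead of real tensors. All key names are hypothetical and do not reflect the actual STELLAR checkpoint layout. With `strict=False`, checkpoint keys that the model does not define (here the decoder and head entries) are skipped rather than raising an error.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as load_state_dict reports them."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # in model, not in file
    unexpected = sorted(checkpoint_keys - model_keys)   # in file, not in model
    return missing, unexpected

# Hypothetical key names, for illustration only
encoder_only_model = ["encoder.blocks.0.weight", "sparse_tokens"]
full_checkpoint = [
    "encoder.blocks.0.weight",
    "sparse_tokens",
    "decoder.blocks.0.weight",
    "cluster_head.weight",
]

missing, unexpected = compare_keys(encoder_only_model, full_checkpoint)
print(missing)      # []
print(unexpected)   # ['cluster_head.weight', 'decoder.blocks.0.weight']
```

In PyTorch itself, `model.load_state_dict(state, strict=False)` returns an object with `missing_keys` and `unexpected_keys` attributes, so printing both after loading is a quick way to confirm that only decoder and head weights were skipped.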

> **Tip:** download with `huggingface_hub` (as above) rather than `git clone` so that
> downloads are registered on the Hub — `git clone` is not counted in download stats.

## Model details

- **Architecture:** ViT encoder (MAE-initialized) + learned sparse latent queries with
  spatial–semantic factorization.
- **Pretraining data:** ImageNet-1K (self-supervised; labels not used).
- **Input:** RGB images in `[0, 1]`, resized to 224×224 (bicubic). ImageNet mean/std
  normalization is applied **inside** the model — pass raw `[0, 1]` images.
- **Weights:** the complete set of trained STELLAR modules (encoder, sparse tokens,
  projections, reconstruction decoder, and clustering heads), stored in `safetensors`.
  Only the third-party MaskGIT-VQGAN tokenizer is excluded — it is downloaded separately
  (from [TiTok](https://huggingface.co/fun-research/TiTok)) and passed via `vq_model`.
- **Framework:** PyTorch.
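
The input convention above can be sketched as follows. The example converts an already resized 224×224 `uint8` RGB array into the batched, channel-first `[0, 1]` layout that the card describes. It deliberately skips mean/std normalization because the model applies that internally. Resizing is left out; a library such as Pillow can do the bicubic resize beforehand. `to_model_input` is a hypothetical helper, not part of the STELLAR code.

```python
import numpy as np

def to_model_input(rgb_uint8):
    """Convert an (H, W, 3) uint8 RGB array to a (1, 3, H, W) float32 array in [0, 1]."""
    if rgb_uint8.dtype != np.uint8 or rgb_uint8.ndim != 3 or rgb_uint8.shape[2] != 3:
        raise ValueError("expected an (H, W, 3) uint8 RGB array")
    scaled = rgb_uint8.astype(np.float32) / 255.0   # [0, 255] -> [0, 1]
    chw = np.transpose(scaled, (2, 0, 1))           # HWC -> CHW
    return np.ascontiguousarray(chw[None])          # add the batch dimension

# Stand-in for a decoded, already resized 224x224 photo
rng = np.random.default_rng(0)
photo = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

batch = to_model_input(photo)
print(batch.shape, batch.dtype)   # (1, 3, 224, 224) float32
```

`torch.from_numpy(batch)` should then give a tensor in the layout that `model.encode` expects.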

## Intended uses & limitations

- **Intended use:** extracting compact sparse/dense visual features for downstream
  recognition, segmentation, retrieval, reconstruction, and analysis.
- **Limitations:** pretrained on ImageNet-1K at 224×224, so features reflect that
  distribution; very different domains (e.g. medical, satellite) may require
  fine-tuning to reach good performance. The models are research artifacts and have
  not been safety-tested for production decision-making.

## Citation

```bibtex
@inproceedings{zhao2026stellar,
  title     = {Learning Sparse Visual Representations via Spatial-Semantic Factorization},
  author    = {Zhao, Theodore Zhengde and Kiblawi, Sid and Yang, Jianwei and Usuyama, Naoto and Tan, Reuben and Codella, Noel C and Naumann, Tristan and Poon, Hoifung and Wei, Mu},
  booktitle = {International Conference on Machine Learning (ICML)},
  year      = {2026},
  url       = {https://arxiv.org/abs/2602.01905},
}
```

## License

Released under the [MIT License](LICENSE).