{ "cells": [ { "cell_type": "markdown", "id": "a0703884", "metadata": {}, "source": [ "# STELLAR image reconstruction\n", "\n", "This notebook runs the **full reconstruction pipeline** end to end:\n", "\n", "```\n", "image → encoder → sparse + spatial tokens → low-rank dense map\n", " → ViT decoder → VQGAN decoder → RGB pixels\n", "```\n", "\n", "The released STELLAR checkpoints predict [MaskGIT-VQGAN](https://huggingface.co/fun-research/TiTok)\n", "tokens, so you also need the VQGAN tokenizer to decode tokens back to pixels.\n", "\n", "**Requirements:** `torch`, `safetensors`, `huggingface_hub`, `pillow`, `matplotlib`, and the\n", "STELLAR model code on your `PYTHONPATH` (so `from src.models.stellar_model import STELLARModel` works)." ] }, { "cell_type": "code", "execution_count": null, "id": "3f852740", "metadata": {}, "outputs": [], "source": [ "import os, sys\n", "# load_stellar.py and the `src/` model code live in the repo root (one level up).\n", "sys.path.insert(0, os.path.abspath(\"..\"))\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from huggingface_hub import hf_hub_download\n", "from load_stellar import load_stellar, list_models\n", "\n", "print(\"Available models:\", list_models())\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"" ] }, { "cell_type": "markdown", "id": "fa4219aa", "metadata": {}, "source": [ "## 1. Download the MaskGIT-VQGAN tokenizer\n", "\n", "STELLAR reuses the MaskGIT-VQGAN tokenizer from TiTok. We only need its decoder to turn\n", "predicted tokens back into pixels." ] }, { "cell_type": "code", "execution_count": null, "id": "8a6807dc", "metadata": {}, "outputs": [], "source": [ "vq_path = hf_hub_download(\n", "    repo_id=\"fun-research/TiTok\",\n", "    filename=\"maskgit-vqgan-imagenet-f16-256.bin\",\n", ")\n", "print(\"VQGAN tokenizer:\", vq_path)" ] },
{ "cell_type": "markdown", "id": "d3ebc639", "metadata": {}, "source": [ "## 2. Load STELLAR for reconstruction\n", "\n", "`purpose=\"reconstruct\"` builds the encoder **and** decoder and attaches the VQGAN tokenizer." ] }, { "cell_type": "code", "execution_count": null, "id": "8b2e344f", "metadata": {}, "outputs": [], "source": [ "model = load_stellar(\"stellar-b16\", purpose=\"reconstruct\", vq_model=vq_path, device=device)\n", "model.eval();" ] }, { "cell_type": "markdown", "id": "0c143298", "metadata": {}, "source": [ "## 3. Load and preprocess an image\n", "\n", "Pass **raw** images in `[0, 1]`, resized to `224×224` with bicubic interpolation.\n", "ImageNet mean/std normalization is applied **inside** the model, so do not normalize the input yourself." ] }, { "cell_type": "code", "execution_count": null, "id": "9deaef69", "metadata": {}, "outputs": [], "source": [ "from PIL import Image\n", "import numpy as np\n", "\n", "# Replace with your own image path.\n", "img_path = hf_hub_download(repo_id=\"huggingface/documentation-images\",\n", "                           filename=\"beignets-task-guide.png\", repo_type=\"dataset\")\n", "pil = Image.open(img_path).convert(\"RGB\")\n", "\n", "image = torch.from_numpy(np.asarray(pil)).permute(2, 0, 1).float() / 255.0  # (3, H, W) in [0,1]\n", "image = F.interpolate(image[None], size=(224, 224), mode=\"bicubic\", align_corners=False).clamp(0, 1)\n", "image = image.to(device)\n", "print(\"input:\", tuple(image.shape), float(image.min()), float(image.max()))" ] }, { "cell_type": "markdown", "id": "f845698f", "metadata": {}, "source": [ "## 4. Encode, then reconstruct\n", "\n", "First `encode` the image into its factorized features (sparse concept tokens + spatial\n", "maps). `reconstruct` is the decoder half of STELLAR: it takes those factorized features\n", "and decodes them all the way to pixels. It returns the RGB pixels (`reconstruction`), the\n", "predicted VQGAN token ids (`tokens`), and the raw codebook logits (`logits`)."
] }, { "cell_type": "code", "execution_count": null, "id": "6e986c51", "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", "    features = model.encode(image)       # factorized features\n", "    out = model.reconstruct(features)    # or model.reconstruct(features[\"sparse\"], features[\"spatial\"])\n", "\n", "# Shapes of the encoder's factorized features.\n", "print(\"factorized features:\")\n", "for k in (\"sparse\", \"spatial\"):\n", "    print(f\"  {k:8s}\", tuple(features[k].shape))\n", "\n", "# Shapes of everything the decoder returns.\n", "print(\"reconstruction outputs:\")\n", "for k, v in out.items():\n", "    print(f\"  {k:14s}\", tuple(v.shape))" ] }, { "cell_type": "markdown", "id": "e74f4664", "metadata": {}, "source": [ "## 5. Display original vs. reconstruction" ] }, { "cell_type": "code", "execution_count": null, "id": "d552f2c2", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# (C, H, W) -> (H, W, C), the layout matplotlib expects.\n", "orig = image[0].permute(1, 2, 0).cpu().numpy()\n", "recon = out[\"reconstruction\"][0].permute(1, 2, 0).cpu().numpy()\n", "\n", "fig, ax = plt.subplots(1, 2, figsize=(8, 4))\n", "ax[0].imshow(orig); ax[0].set_title(\"original\"); ax[0].axis(\"off\")\n", "ax[1].imshow(recon); ax[1].set_title(\"STELLAR reconstruction\"); ax[1].axis(\"off\")\n", "plt.tight_layout(); plt.show()" ] }, { "cell_type": "markdown", "id": "10c7f3d1", "metadata": {}, "source": [ "## Notes\n", "\n", "- Output resolution: **224×224** for the `/16` models (`stellar-b16`, `stellar-l16`, and the\n", "  `b8`/`b24` ablations) and **256×256** for the `/14` `stellar-h16` model.\n", "- For larger / higher-fidelity reconstructions, swap in `stellar-l16` or `stellar-h16`.\n", "- `reconstruct` takes the factorized features (not the image), so you can edit / analyze the\n", "  sparse tokens before decoding. It is gradient-transparent: wrap it in `torch.no_grad()`\n", "  for inference, or keep gradients enabled (e.g. for feature-inversion experiments)." ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }