| | """ |
| | VibeToken-Gen Gradio Demo |
| | Class-conditional ImageNet generation with dynamic resolution support. |
| | """ |
| | import spaces |
| |
|
| | import os |
| | import random |
| |
|
| | import gradio as gr |
| | import numpy as np |
| | import torch |
| |
|
# Let float32 matmuls/convolutions use TF32 tensor cores where available.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")

# Inference-only demo: turn autograd off (note: grad mode is per-thread in torch).
torch.set_grad_enabled(False)

# Skip default parameter initialisation for these layer types; the weights are
# loaded from checkpoints afterwards in load_models().
for _layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
    setattr(_layer_cls, "reset_parameters", lambda self: None)
del _layer_cls
| |
|
| | from huggingface_hub import hf_hub_download |
| | from PIL import Image |
| |
|
| | from vibetokengen.generate import generate |
| | from vibetokengen.model import GPT_models |
| | from vibetoken import VibeTokenTokenizer |
| |
|
| | |
| | |
| | |
| |
|
HF_REPO = "mpatel57/VibeToken"
USE_XXL = 1

# Both generator variants share the same output-head settings; only the model
# size and its checkpoint file differ.
NUM_OUTPUT_LAYER = 4
EXTRA_LAYERS = "QKV"
GPT_MODEL_NAME, GPT_CKPT_FILENAME = (
    ("GPT-XXL", "VibeTokenGen-xxl-dynamic-65_750k.pt")
    if USE_XXL
    else ("GPT-B", "VibeTokenGen-b-fixed65_dynamic_1500k.pt")
)

# Tokenizer checkpoint (downloaded from HF_REPO) and its YAML config, which is
# resolved relative to this file.
VQ_CKPT_FILENAME = "VibeToken_LL.bin"
CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "vibetoken_ll.yaml")
| |
|
# Generator / tokenizer hyper-parameters passed to the GPT constructor in
# load_models() and to generate(); they must agree with the downloaded checkpoints.
CODEBOOK_SIZE = 32768   # passed as the GPT vocab_size
NUM_CODEBOOKS = 8
LATENT_SIZE = 65        # GPT block_size / number of latent tokens sampled
NUM_CLASSES = 1000      # ImageNet-1k class count; custom IDs are clamped to it
CLS_TOKEN_NUM = 1
CLASS_DROPOUT_PROB = 0.1
CAPPING = 50.0

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float32

# torch.compile is opt-in. Setting VIBETOKEN_NO_COMPILE=0 in the environment
# enables it. This is the switch the "Skipping torch.compile" log message in
# load_models() tells users about; previously the value was hard-coded and the
# advertised environment variable was never read. The default (unset) stays 0.
COMPILE = int(os.environ.get("VIBETOKEN_NO_COMPILE", "1").strip() == "0")
| |
|
| | |
| | |
| | |
| |
|
# Curated dropdown entries: display name -> ImageNet class index.
# generate_image() looks the selected name up here. The dropdown lists these
# keys sorted alphabetically, so the insertion order below is not user-visible.
# NOTE(review): "Sailboat" maps to 914, which is presumably the ImageNet "yawl"
# class rather than a generic sailboat — confirm the label is acceptable.
IMAGENET_CLASSES = {
    "Golden Retriever": 207,
    "Labrador Retriever": 208,
    "German Shepherd": 235,
    "Siberian Husky": 250,
    "Pembroke Corgi": 263,
    "Tabby Cat": 281,
    "Persian Cat": 283,
    "Siamese Cat": 284,
    "Tiger": 292,
    "Lion": 291,
    "Cheetah": 293,
    "Brown Bear": 294,
    "Giant Panda": 388,
    "Red Fox": 277,
    "Arctic Fox": 279,
    "Timber Wolf": 269,
    "Bald Eagle": 22,
    "Macaw": 88,
    "Flamingo": 130,
    "Peacock": 84,
    "Goldfish": 1,
    "Great White Shark": 2,
    "Jellyfish": 107,
    "Monarch Butterfly": 323,
    "Ladybug": 301,
    "Snail": 113,
    "Red Sports Car": 817,
    "School Bus": 779,
    "Steam Locomotive": 820,
    "Sailboat": 914,
    "Space Shuttle": 812,
    "Castle": 483,
    "Church": 497,
    "Lighthouse": 437,
    "Volcano": 980,
    "Lakeside": 975,
    "Cliff": 972,
    "Coral Reef": 973,
    "Valley": 979,
    "Seashore": 978,
    "Mushroom": 947,
    "Broccoli": 937,
    "Pizza": 963,
    "Ice Cream": 928,
    "Cheeseburger": 933,
    "Espresso": 967,
    "Acoustic Guitar": 402,
    "Grand Piano": 579,
    "Violin": 889,
    "Balloon": 417,
}
| |
|
# Generator (AR model) resolutions: label -> (height, width).
# generate_image() unpacks each value as `gen_h, gen_w`, so the FIRST number in
# a label is the height. The values are normalised by 1536 in _make_res_tensors().
GENERATOR_RESOLUTION_PRESETS = {
    "256 × 256": (256, 256),
    "384 × 256": (384, 256),
    "256 × 384": (256, 384),
    "384 × 384": (384, 384),
    "512 × 256": (512, 256),
    "256 × 512": (256, 512),
    "512 × 512": (512, 512),
}

# Decoder output resolutions: label -> (height, width), or None to decode at
# the generator resolution. generate_image() unpacks values as `out_h, out_w`.
# NOTE(review): with height-first unpacking, "768 × 512 (3:2)" decodes a
# 512-wide, 768-tall (portrait) image; confirm this matches how the labels are
# meant to be read.
OUTPUT_RESOLUTION_PRESETS = {
    "Same as generator": None,
    "256 × 256": (256, 256),
    "384 × 384": (384, 384),
    "512 × 512": (512, 512),
    "768 × 768": (768, 768),
    "1024 × 1024": (1024, 1024),
    "512 × 256 (2:1)": (512, 256),
    "256 × 512 (1:2)": (256, 512),
    "768 × 512 (3:2)": (768, 512),
    "512 × 768 (2:3)": (512, 768),
    "1024 × 512 (2:1)": (1024, 512),
    "512 × 1024 (1:2)": (512, 1024),
}
| |
|
| | |
| | |
| | |
| |
|
# Module-level model handles: both stay None until load_models() assigns them,
# and generate_image() refuses to run while either is still None.
vq_model = gpt_model = None
| |
|
| |
|
def download_checkpoint(filename: str) -> str:
    """Fetch ``filename`` from the HF_REPO model repo and return its local path."""
    local_path = hf_hub_download(repo_id=HF_REPO, filename=filename)
    return local_path
| |
|
| |
|
def _make_res_tensors(gen_h: int, gen_w: int, multiplier: int):
    """Create normalized resolution tensors for the GPT generator.

    Each side length is divided by 1536 and broadcast into a tensor of shape
    ``(multiplier, 1)`` on DEVICE with DTYPE. Callers pass ``multiplier=2``
    when CFG is active (cfg_scale > 1.0) and 1 otherwise.

    Returns:
        ``(height_tensor, width_tensor)``.
    """
    def _filled(side: int):
        return torch.full((multiplier, 1), side / 1536, device=DEVICE, dtype=DTYPE)

    return _filled(gen_h), _filled(gen_w)
| |
|
| |
|
def _warmup(model):
    """Run one throwaway 256x256 generation to trigger compilation and warm caches.

    Uses class 0 with CFG 4.0 (hence two resolution rows). On CUDA it blocks
    until the GPU has finished, so the completion message is accurate.
    """
    print("Warming up (first call triggers compilation, may take ~30-60s)...")
    placeholder_class = torch.tensor([0], device=DEVICE)
    res_h, res_w = _make_res_tensors(256, 256, multiplier=2)
    sampling_kwargs = {
        "cfg_scale": 4.0,
        "cfg_interval": -1,
        "target_h": res_h,
        "target_w": res_w,
        "temperature": 1.0,
        "top_k": 500,
        "top_p": 1.0,
        "sample_logits": True,
    }
    with torch.inference_mode():
        generate(model, placeholder_class, LATENT_SIZE, NUM_CODEBOOKS, **sampling_kwargs)
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    print("Warmup complete — subsequent generations will be fast.")
| |
|
| |
|
def _extract_state_dict(checkpoint):
    """Return the parameter mapping stored inside a loaded GPT checkpoint.

    The first of ``"model"``, ``"module"`` or ``"state_dict"`` present in the
    checkpoint wins. A checkpoint with none of those keys is treated as a bare
    state dict.
    """
    for key in ("model", "module", "state_dict"):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def load_models():
    """Download both checkpoints and populate the module-level model globals.

    Sets ``vq_model`` (the VibeToken tokenizer, used for decoding) and
    ``gpt_model`` (the class-conditional generator, switched to eval mode).
    When ``COMPILE`` is truthy, the generator is wrapped with ``torch.compile``
    and warmed up once, so the first user request does not pay the compile cost.
    """
    global vq_model, gpt_model

    print("Downloading checkpoints (if needed)...")
    vq_path = download_checkpoint(VQ_CKPT_FILENAME)
    gpt_path = download_checkpoint(GPT_CKPT_FILENAME)

    print(f"Loading VibeToken tokenizer from {vq_path}...")
    vq_model = VibeTokenTokenizer.from_config(
        CONFIG_PATH, vq_path, device=DEVICE, dtype=DTYPE,
    )
    print("VibeToken tokenizer loaded.")

    print(f"Loading {GPT_MODEL_NAME} from {gpt_path}...")
    gpt_model = GPT_models[GPT_MODEL_NAME](
        vocab_size=CODEBOOK_SIZE,
        block_size=LATENT_SIZE,
        num_classes=NUM_CLASSES,
        cls_token_num=CLS_TOKEN_NUM,
        model_type="c2i",
        num_codebooks=NUM_CODEBOOKS,
        n_output_layer=NUM_OUTPUT_LAYER,
        class_dropout_prob=CLASS_DROPOUT_PROB,
        extra_layers=EXTRA_LAYERS,
        capping=CAPPING,
    ).to(device=DEVICE, dtype=DTYPE)

    # SECURITY: weights_only=False unpickles arbitrary Python objects. This is
    # acceptable only because the file comes from the project's own Hub repo
    # (HF_REPO). Never point this call at an untrusted checkpoint.
    checkpoint = torch.load(gpt_path, map_location="cpu", weights_only=False)
    weights = _extract_state_dict(checkpoint)
    gpt_model.load_state_dict(weights, strict=True)
    gpt_model.eval()
    # Drop BOTH references. `weights` normally aliases a sub-dict of
    # `checkpoint`, so deleting only `checkpoint` would keep the CPU copy of
    # every tensor alive through the (potentially long) compile/warmup below.
    del checkpoint, weights
    print(f"{GPT_MODEL_NAME} loaded.")

    if COMPILE:
        print("Compiling GPT model with torch.compile (max-autotune)...")
        gpt_model = torch.compile(gpt_model, mode="max-autotune", fullgraph=True)
        _warmup(gpt_model)
    else:
        print("Skipping torch.compile (set VIBETOKEN_NO_COMPILE=0 to enable).")
| |
|
| |
|
| | |
| | |
| | |
| |
|
def auto_decoder_patch_size(h: int, w: int) -> tuple[int, int]:
    """Choose a square decoder patch size from the longest output side.

    Longest side <= 256 gives 8, <= 512 gives 16, and anything larger gives 32.
    """
    longest_side = h if h >= w else w
    for upper_bound, patch in ((256, 8), (512, 16)):
        if longest_side <= upper_bound:
            return (patch, patch)
    return (32, 32)
| |
|
| |
|
| | |
| | |
| | |
| |
|
# spaces.GPU is the OUTERMOST decorator, so torch.inference_mode() wraps the
# function body itself. Under ZeroGPU the decorated call may run in a different
# process/thread than the Gradio worker that invokes it, and torch's grad mode
# is thread-local. Entering inference mode outside spaces.GPU would therefore
# not reliably cover the code that actually runs the models.
@spaces.GPU(duration=90)
@torch.inference_mode()
def generate_image(
    class_name: str,
    class_id: int,
    gen_resolution_preset: str,
    out_resolution_preset: str,
    decoder_ps_choice: str,
    cfg_scale: float,
    temperature: float,
    top_k: int,
    top_p: float,
    seed: int,
    randomize_seed: bool,
):
    """Sample one class-conditional image and decode it at the requested size.

    Args:
        class_name: Dropdown label. Either a key of IMAGENET_CLASSES, or
            "Custom (enter ID below)" / empty to use ``class_id`` instead.
        class_id: Custom ImageNet class index, clamped to [0, NUM_CLASSES - 1].
        gen_resolution_preset: Key of GENERATOR_RESOLUTION_PRESETS.
        out_resolution_preset: Key of OUTPUT_RESOLUTION_PRESETS. A None value
            means "decode at the generator resolution".
        decoder_ps_choice: "Auto", or a patch-size string such as "16".
        cfg_scale: Classifier-free guidance scale. Values > 1.0 use two
            resolution rows (multiplier = 2).
        temperature, top_k, top_p: Sampling parameters forwarded to generate().
        seed: RNG seed. Used unless ``randomize_seed`` is set or it is empty.
        randomize_seed: Draw a fresh seed instead of using ``seed``.

    Returns:
        ``(PIL.Image.Image, int)``: the decoded image and the seed actually used.

    Raises:
        gr.Error: if the models are not loaded yet, or the custom class ID is empty.
    """
    if vq_model is None or gpt_model is None:
        raise gr.Error("Models are still loading. Please wait a moment and try again.")

    # gr.Number may deliver a float, or None when the field is cleared.
    # numpy's legacy seeding rejects floats, so normalise to a Python int.
    if randomize_seed or seed is None:
        seed = random.randint(0, 2**31 - 1)
    seed = int(seed)

    torch.manual_seed(seed)
    np.random.seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)

    if class_name and class_name != "Custom (enter ID below)":
        cid = IMAGENET_CLASSES[class_name]
    else:
        if class_id is None:
            raise gr.Error(
                f"Please enter a custom class ID between 0 and {NUM_CLASSES - 1}."
            )
        cid = int(class_id)
    cid = max(0, min(cid, NUM_CLASSES - 1))

    # Preset values are (height, width).
    gen_h, gen_w = GENERATOR_RESOLUTION_PRESETS[gen_resolution_preset]

    out_res = OUTPUT_RESOLUTION_PRESETS[out_resolution_preset]
    if out_res is None:
        out_h, out_w = gen_h, gen_w
    else:
        out_h, out_w = out_res

    if decoder_ps_choice == "Auto":
        dec_ps = auto_decoder_patch_size(out_h, out_w)
    else:
        ps = int(decoder_ps_choice)
        dec_ps = (ps, ps)

    # With CFG active, the resolution conditioning needs one row per branch.
    multiplier = 2 if cfg_scale > 1.0 else 1

    c_indices = torch.tensor([cid], device=DEVICE)
    th, tw = _make_res_tensors(gen_h, gen_w, multiplier)

    index_sample = generate(
        gpt_model,
        c_indices,
        LATENT_SIZE,
        NUM_CODEBOOKS,
        cfg_scale=cfg_scale,
        cfg_interval=-1,
        target_h=th,
        target_w=tw,
        temperature=temperature,
        # Slider values can arrive as floats; top-k filtering needs an int count.
        top_k=int(top_k),
        top_p=top_p,
        sample_logits=True,
    )

    index_sample = index_sample.unsqueeze(2)
    samples = vq_model.decode(
        index_sample,
        height=out_h,
        width=out_w,
        patch_size=dec_ps,
    )
    samples = torch.clamp(samples, 0, 1)

    # First sample, channels-first -> HWC uint8 for PIL.
    img_np = (samples[0].permute(1, 2, 0).float().cpu().numpy() * 255).astype("uint8")
    pil_img = Image.fromarray(img_np)

    return pil_img, seed
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Markdown/HTML header rendered by gr.Markdown at the top of the demo page.
# NOTE(review): the lone "|" line after the author list looks like a leftover
# separator (perhaps for affiliations) — confirm it is intended.
HEADER_MD = """
# VibeToken-Gen: Dynamic Resolution Image Generation

<p style="margin-top:4px;">
<b>Maitreya Patel, Jingtao Li, Weiming Zhuang, Yezhou Yang, Lingjuan Lyu</b>
|
</p>
<h3>CVPR 2026 (Main Conference)</h3>

<p>
<a href="https://huggingface.co/mpatel57/VibeToken" target="_blank">🤗 Model</a> |
<a href="https://github.com/patel-maitreya/VibeToken" target="_blank">💻 GitHub</a>
</p>

Generate ImageNet class-conditional images at **arbitrary resolutions** using only **65 tokens**.
VibeToken-Gen maintains a constant **179G FLOPs** regardless of output resolution.
"""
| |
|
# BibTeX citation block rendered by gr.Markdown at the bottom of the demo page.
CITATION_MD = """
### Citation
```bibtex
@inproceedings{vibetoken2026,
title = {VibeToken: Scaling 1D Image Tokenizers and Autoregressive Models for Dynamic Resolution Generations},
author = {Patel, Maitreya and Li, Jingtao and Zhuang, Weiming and Yang, Yezhou and Lyu, Lingjuan},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2026}
}
```
"""
| |
|
# Dropdown options: the "Custom" sentinel first, then curated names alphabetically.
class_choices = ["Custom (enter ID below)", *sorted(IMAGENET_CLASSES)]
| |
|
# Gradio UI: a controls column on the left, the generated image on the right.
with gr.Blocks(
    title="VibeToken-Gen Demo",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(HEADER_MD)

    with gr.Row():
        # Left column: class selection, resolutions, and sampling controls.
        with gr.Column(scale=1):
            class_dropdown = gr.Dropdown(
                label="ImageNet Class",
                choices=class_choices,
                value="Golden Retriever",
                info="Pick a class or choose 'Custom' to enter an ID manually.",
            )
            # Hidden until the "Custom (enter ID below)" entry is selected
            # (visibility is toggled by toggle_custom_id below).
            class_id_input = gr.Number(
                label="Custom Class ID (0–999)",
                value=207,
                minimum=0,
                maximum=999,
                step=1,
                visible=False,
            )
            gen_resolution_dropdown = gr.Dropdown(
                label="Generator Resolution",
                choices=list(GENERATOR_RESOLUTION_PRESETS.keys()),
                value="256 × 256",
                info="Internal resolution for the AR generator (max 512×512).",
            )
            out_resolution_dropdown = gr.Dropdown(
                label="Output Resolution (Decoder)",
                choices=list(OUTPUT_RESOLUTION_PRESETS.keys()),
                value="Same as generator",
                info="Final image resolution. Set higher for super-resolution (e.g. generate at 256, decode at 1024).",
            )
            decoder_ps_dropdown = gr.Dropdown(
                label="Decoder Patch Size",
                choices=["Auto", "8", "16", "32"],
                value="Auto",
                info="'Auto' selects based on output resolution. Larger = faster but coarser.",
            )

            with gr.Accordion("Advanced Sampling Parameters", open=False):
                cfg_slider = gr.Slider(
                    label="CFG Scale",
                    minimum=1.0, maximum=20.0, value=4.0, step=0.5,
                    info="Classifier-free guidance strength.",
                )
                temp_slider = gr.Slider(
                    label="Temperature",
                    minimum=0.1, maximum=2.0, value=1.0, step=0.05,
                )
                topk_slider = gr.Slider(
                    label="Top-k",
                    minimum=0, maximum=2000, value=500, step=10,
                    info="0 disables top-k filtering.",
                )
                topp_slider = gr.Slider(
                    label="Top-p",
                    minimum=0.0, maximum=1.0, value=1.0, step=0.05,
                    info="1.0 disables nucleus sampling.",
                )
                seed_input = gr.Number(
                    label="Seed", value=0, minimum=0, maximum=2**31 - 1, step=1,
                )
                randomize_cb = gr.Checkbox(label="Randomize seed", value=True)

            generate_btn = gr.Button("Generate", variant="primary", size="lg")

        # Right column: the generated image and the seed that produced it.
        with gr.Column(scale=2):
            output_image = gr.Image(label="Generated Image", type="pil", height=512)
            used_seed = gr.Number(label="Seed used", interactive=False)

    # Show the numeric class-ID field only while the "Custom" entry is selected.
    def toggle_custom_id(choice):
        return gr.update(visible=(choice == "Custom (enter ID below)"))

    class_dropdown.change(
        fn=toggle_custom_id,
        inputs=[class_dropdown],
        outputs=[class_id_input],
    )

    # Inputs are passed positionally, so this list must stay in the same order
    # as generate_image's parameters; outputs match its (image, seed) return.
    generate_btn.click(
        fn=generate_image,
        inputs=[
            class_dropdown,
            class_id_input,
            gen_resolution_dropdown,
            out_resolution_dropdown,
            decoder_ps_dropdown,
            cfg_slider,
            temp_slider,
            topk_slider,
            topp_slider,
            seed_input,
            randomize_cb,
        ],
        outputs=[output_image, used_seed],
    )

    gr.Markdown(CITATION_MD)
| |
|
| |
|
# Load all weights before serving, so requests don't hit generate_image's
# "Models are still loading" guard. Then enable the request queue and start
# the Gradio server.
load_models()
demo.queue()
demo.launch()
| |
|