"""
BitPixelLM - Hugging Face Gradio Space
Generates 32x32 pixel art from text prompts.
"""
import sys
import json
import traceback
from pathlib import Path

import torch
import gradio as gr
from PIL import Image

HERE = Path(__file__).parent
# Make the local `model` package importable regardless of the working directory.
sys.path.insert(0, str(HERE))

from model.tokenizer import PaletteTokenizer  # noqa: E402  (needs sys.path above)
from model.text_encoder import TextTokenizer, TextEncoder  # noqa: E402
from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM  # noqa: E402

PALETTE_PATH = HERE / "palette_256.npy"
VOCAB_PATH = HERE / "vocab.json"
CHECKPOINT_PATH = HERE / "best.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Populated by load_all(); `model` stays None if loading fails.
model = None
palette_tok = None
text_tok = None
VOCAB_DISPLAY = ""


def load_all():
    """Load the tokenizers, vocabulary and checkpoint into module globals.

    Sets ``palette_tok``, ``text_tok``, ``VOCAB_DISPLAY`` and ``model``.
    Hyperparameters are read from the checkpoint's ``"args"`` entry; any
    missing key falls back to the default given below.

    Raises:
        Exception: any loading error is printed with its traceback and re-raised.
    """
    global model, palette_tok, text_tok, VOCAB_DISPLAY
    try:
        palette_tok = PaletteTokenizer(palette_path=str(PALETTE_PATH))

        with open(VOCAB_PATH, encoding="utf-8") as f:
            vocab = json.load(f)
        text_tok = TextTokenizer(vocab)
        # Entries starting with "<" are treated as special tokens and hidden from the listing.
        VOCAB_DISPLAY = ", ".join(sorted(w for w in vocab if not w.startswith("<")))

        # NOTE(security): weights_only=False unpickles arbitrary objects. Acceptable only
        # because the checkpoint ships with the Space; never point this at an untrusted file.
        ckpt = torch.load(str(CHECKPOINT_PATH), map_location=device, weights_only=False)
        args = ckpt.get("args", {})
        d_model = args.get("d_model", 256)
        nhead = args.get("nhead", 8)
        dim_ff = args.get("dim_ff", 512)
        dropout = args.get("dropout", 0.1)

        text_encoder = TextEncoder(
            vocab_size=text_tok.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("text_layers", 3),
            dim_feedforward=dim_ff,
            max_seq_len=args.get("max_text_len", 32),
            dropout=dropout,
        )
        pixel_decoder = BitPixelLMDecoder(
            vocab_size=palette_tok.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("pixel_layers", 6),
            dim_feedforward=dim_ff,
            img_size=32,
            dropout=dropout,
        )

        m = BitPixelLM(text_encoder, pixel_decoder).to(device)
        m.load_state_dict(ckpt["model_state_dict"])
        m.eval()
        # Publish the model only once it is fully loaded.
        model = m
        print(f"BitPixelLM loaded on {device} | vocab={text_tok.vocab_size} words")
    except Exception:
        traceback.print_exc()
        raise


def generate(prompt, temperature, top_k, top_p, num_samples, scale):
    """Sample ``num_samples`` images for ``prompt``.

    Returns:
        A list of PIL RGB images, one per sample. When ``scale`` > 1, each image
        is upscaled by that integer factor with nearest-neighbour resampling.

    Raises:
        gr.Error: if the model is not loaded, the prompt is blank, or sampling fails.
    """
    if model is None:
        raise gr.Error("Model failed to load. Check Space logs.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    # Collect words absent from the vocabulary so the user can be warned afterwards.
    # Words starting with "<" follow the same special-token convention as VOCAB_DISPLAY.
    unknown = [
        w
        for w in prompt.lower().split()
        if w not in text_tok.word2idx and not w.startswith("<")
    ]

    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)
    images = []
    try:
        s = int(scale)
        with torch.no_grad():
            for _ in range(int(num_samples)):
                toks = model.generate(
                    text_tokens,
                    sos_token=palette_tok.sos_token,
                    eos_token=palette_tok.eos_token,
                    temperature=float(temperature),
                    top_k=int(top_k),
                    top_p=float(top_p),
                )
                arr = palette_tok.decode_tokens(toks[0].cpu().tolist())
                img = Image.fromarray(arr, "RGB")
                if s > 1:
                    # NEAREST keeps hard pixel edges when upscaling.
                    img = img.resize((img.width * s, img.height * s), Image.NEAREST)
                images.append(img)
    except Exception as e:
        # Log the full traceback for the Space logs; show a short message in the UI.
        traceback.print_exc()
        raise gr.Error(f"Generation error: {e}") from e

    if unknown:
        gr.Warning("Unknown words (ignored): " + ", ".join(unknown))
    return images


EXAMPLES = [
    "a red pixel art sword",
    "a blue pixel art knight",
    "a green pixel art dragon",
    "a purple pixel art wizard",
    "a gold pixel art crown",
    "a dark pixel art skeleton",
]


def generate_tiled(prompt, temperature, top_k, top_p, num_samples, scale):
    """Return all samples tiled into a single image.

    Samples are laid out row-major, at most 4 per row, on a dark-grey
    (30, 30, 30) background. Returns None if no images were produced.
    """
    imgs = generate(prompt, temperature, top_k, top_p, num_samples, scale)
    if not imgs:
        return None
    w, h = imgs[0].size
    n = len(imgs)
    cols = min(n, 4)
    rows = (n + cols - 1) // cols  # ceiling division
    canvas = Image.new("RGB", (cols * w, rows * h), (30, 30, 30))
    for i, im in enumerate(imgs):
        canvas.paste(im, ((i % cols) * w, (i // cols) * h))
    return canvas


demo = gr.Interface(
    fn=generate_tiled,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="a red pixel art sword"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(0, 256, value=40, step=1, label="Top-K (0=off)"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
        gr.Slider(1, 8, value=4, step=1, label="Samples"),
        gr.Slider(1, 16, value=8, step=1, label="Upscale (8=256px)"),
    ],
    outputs=gr.Image(label="Generated Pixel Art", type="pil"),
    title="BitPixelLM - Pixel Art Generator",
    description="Generate 32x32 pixel art sprites from text prompts using BitPixelLM (BitNet b1.58, 7.4M params). Samples are tiled into one image.",
    examples=[[ex, 0.8, 40, 0.9, 4, 8] for ex in EXAMPLES],
    cache_examples=False,
)

try:
    load_all()
except Exception:
    # load_all() has already printed the traceback. Keep the UI running so that
    # generate() can tell users the model failed to load, instead of the Space crashing.
    print("BitPixelLM failed to load; launching UI anyway (see traceback above).")
demo.launch(server_name="0.0.0.0", server_port=7860)