# BitPixelLM-demo / app.py
# Last commit df5e01f by BlakePeavy: "Deploy BitPixelLM Gradio Space"
"""
BitPixelLM - Hugging Face Gradio Space
Generates 32x32 pixel art from text prompts.
"""
import sys
import json
import traceback
import torch
import gradio as gr
from pathlib import Path
from PIL import Image
# Directory containing this file. It is prepended to sys.path so that the
# bundled `model` package imported below resolves regardless of the
# process's working directory.
HERE = Path(__file__).parents[0]
sys.path[:0] = [str(HERE)]
from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM
# Model artifacts expected next to this file.
PALETTE_PATH = HERE / "palette_256.npy"
VOCAB_PATH = HERE / "vocab.json"
CHECKPOINT_PATH = HERE / "best.pt"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Filled in by load_all(); they stay None / "" until loading succeeds.
model = palette_tok = text_tok = None
VOCAB_DISPLAY = ""
def load_all():
    """Load the tokenizers and the BitPixelLM checkpoint into module globals.

    Populates ``palette_tok``, ``text_tok``, ``VOCAB_DISPLAY`` and ``model``.
    ``model`` is assigned only after the weights are loaded and the network
    is switched to eval mode. A failed load therefore leaves it as ``None``.

    Raises:
        Exception: any error while reading the artifacts or building the
            model. It is printed with its traceback and then re-raised.
    """
    global model, palette_tok, text_tok, VOCAB_DISPLAY
    try:
        palette_tok = PaletteTokenizer(palette_path=str(PALETTE_PATH))

        with open(VOCAB_PATH, encoding="utf-8") as f:
            vocab = json.load(f)
        text_tok = TextTokenizer(vocab)
        # Hide entries starting with "<" (special tokens such as "<pad>").
        VOCAB_DISPLAY = ", ".join(sorted(w for w in vocab if not w.startswith("<")))

        # SECURITY: weights_only=False unpickles arbitrary objects. This is
        # acceptable only because best.pt is shipped alongside this file.
        # Never point this at an untrusted checkpoint.
        ckpt = torch.load(str(CHECKPOINT_PATH), map_location=device, weights_only=False)

        # Training hyperparameters; any missing key falls back to the default
        # below. Accept a plain dict or an argparse-style namespace, and
        # treat a missing/None entry as "no overrides".
        args = ckpt.get("args") or {}
        if not isinstance(args, dict):
            args = vars(args)

        # Settings shared by the text encoder and the pixel decoder.
        d_model = args.get("d_model", 256)
        nhead = args.get("nhead", 8)
        dim_ff = args.get("dim_ff", 512)
        dropout = args.get("dropout", 0.1)

        text_encoder = TextEncoder(
            vocab_size=text_tok.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("text_layers", 3),
            dim_feedforward=dim_ff,
            max_seq_len=args.get("max_text_len", 32),
            dropout=dropout,
        )
        pixel_decoder = BitPixelLMDecoder(
            vocab_size=palette_tok.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("pixel_layers", 6),
            dim_feedforward=dim_ff,
            img_size=32,
            dropout=dropout,
        )

        m = BitPixelLM(text_encoder, pixel_decoder).to(device)
        m.load_state_dict(ckpt["model_state_dict"])
        m.eval()
        # Publish the model only once it is fully loaded.
        model = m
        print(f"BitPixelLM loaded on {device} | vocab={text_tok.vocab_size} words")
    except Exception:
        # Print the full traceback to the logs before propagating.
        traceback.print_exc()
        raise
def generate(prompt, temperature, top_k, top_p, num_samples, scale):
    """Sample ``num_samples`` pixel-art images for ``prompt``.

    Args:
        prompt: Text description. It is encoded with ``text_tok.encode``. A
            lower-cased whitespace split is used only to report words that
            are not in the vocabulary.
        temperature: Sampling temperature, passed to ``model.generate`` as float.
        top_k: Top-k cutoff, passed to ``model.generate`` as int.
        top_p: Nucleus cutoff, passed to ``model.generate`` as float.
        num_samples: Number of independent samples to draw.
        scale: Integer upscale factor (nearest neighbour); 1 or less keeps
            the decoded size.

    Returns:
        list of PIL RGB images, one per sample.

    Raises:
        gr.Error: if the model is not loaded, the prompt is empty, or
            sampling fails.
    """
    if model is None:
        raise gr.Error("Model failed to load. Check Space logs.")
    # Guard against None as well as blank strings, so a missing value gives
    # the friendly error instead of an AttributeError.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    special = ("<pad>", "<sos>", "<eos>", "<unk>")
    unknown = [
        w for w in prompt.lower().split()
        if w not in text_tok.word2idx and w not in special
    ]
    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    images = []
    try:
        n = int(num_samples)
        s = int(scale)
        # Inference only: disable autograd once for the whole sampling loop.
        with torch.no_grad():
            for _ in range(n):
                toks = model.generate(
                    text_tokens,
                    sos_token=palette_tok.sos_token,
                    eos_token=palette_tok.eos_token,
                    temperature=float(temperature),
                    top_k=int(top_k),
                    top_p=float(top_p),
                )
                arr = palette_tok.decode_tokens(toks[0].cpu().tolist())
                img = Image.fromarray(arr, "RGB")
                if s > 1:
                    # NEAREST keeps hard pixel edges. Scale from the image's
                    # actual size rather than assuming 32x32.
                    img = img.resize((img.width * s, img.height * s), Image.NEAREST)
                images.append(img)
    except Exception as e:
        # Keep the full traceback in the logs; the user sees a short message.
        traceback.print_exc()
        raise gr.Error(f"Generation error: {e}") from e

    if unknown:
        gr.Warning("Unknown words (ignored): " + ", ".join(unknown))
    return images
# Prompts offered as clickable examples in the UI. Each one is paired with
# the default sampling settings when the example rows are built.
EXAMPLES = [
    "a red pixel art sword",
    "a blue pixel art knight",
    "a green pixel art dragon",
    "a purple pixel art wizard",
    "a gold pixel art crown",
    "a dark pixel art skeleton",
]
def _tile_images(imgs, max_cols=4, background=(30, 30, 30)):
    """Paste ``imgs`` row-major onto one RGB canvas, ``max_cols`` per row.

    Every cell takes the size of the first image. ``imgs`` must be non-empty.
    """
    w, h = imgs[0].size
    n = len(imgs)
    cols = min(n, max_cols)
    rows = (n + cols - 1) // cols  # ceiling division
    canvas = Image.new("RGB", (cols * w, rows * h), background)
    for i, im in enumerate(imgs):
        row, col = divmod(i, cols)
        canvas.paste(im, (col * w, row * h))
    return canvas


def generate_tiled(prompt, temperature, top_k, top_p, num_samples, scale):
    """Return all samples tiled into a single image.

    Takes the same arguments as ``generate``. Returns None when no images
    were produced.
    """
    imgs = generate(prompt, temperature, top_k, top_p, num_samples, scale)
    if not imgs:
        return None
    return _tile_images(imgs)
# Slider defaults, in input order after the prompt. The example rows reuse
# them so the examples always match the UI's initial settings.
_DEFAULT_TEMPERATURE = 0.8
_DEFAULT_TOP_K = 40
_DEFAULT_TOP_P = 0.9
_DEFAULT_SAMPLES = 4
_DEFAULT_SCALE = 8

# Inputs are listed in the same order as generate_tiled's parameters:
# (prompt, temperature, top_k, top_p, num_samples, scale).
demo = gr.Interface(
    fn=generate_tiled,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="a red pixel art sword"),
        gr.Slider(0.1, 2.0, value=_DEFAULT_TEMPERATURE, step=0.05, label="Temperature"),
        gr.Slider(0, 256, value=_DEFAULT_TOP_K, step=1, label="Top-K (0=off)"),
        gr.Slider(0.1, 1.0, value=_DEFAULT_TOP_P, step=0.05, label="Top-P"),
        gr.Slider(1, 8, value=_DEFAULT_SAMPLES, step=1, label="Samples"),
        gr.Slider(1, 16, value=_DEFAULT_SCALE, step=1, label="Upscale (8=256px)"),
    ],
    outputs=gr.Image(label="Generated Pixel Art", type="pil"),
    title="BitPixelLM - Pixel Art Generator",
    description="Generate 32x32 pixel art sprites from text prompts using BitPixelLM (BitNet b1.58, 7.4M params). Samples are tiled into one image.",
    examples=[
        [
            prompt,
            _DEFAULT_TEMPERATURE,
            _DEFAULT_TOP_K,
            _DEFAULT_TOP_P,
            _DEFAULT_SAMPLES,
            _DEFAULT_SCALE,
        ]
        for prompt in EXAMPLES
    ],
    cache_examples=False,
)
# Load eagerly so `demo` has a model even when this module is imported rather
# than run as a script. Load errors still propagate and abort startup.
load_all()

if __name__ == "__main__":
    # 0.0.0.0 binds all interfaces so the server is reachable from outside
    # localhost.
    demo.launch(server_name="0.0.0.0", server_port=7860)