| """
|
| PixelArtGen — Gradio Web UI
|
|
|
| Interactive UI to generate pixel art from text prompts using
|
| BitPixelLM — a 1.58-bit ternary transformer (BitNet b1.58).
|
|
|
| Launch:
|
| python app.py
|
| Then open http://localhost:7860 in your browser.
|
| """
|
|
|
| import sys
|
| import json
|
| import torch
|
| import numpy as np
|
| import gradio as gr
|
| from pathlib import Path
|
| from PIL import Image
|
|
|
# Put this file's directory at the front of the import path so the
# ``model.*`` imports that follow can resolve against it.
_HERE = Path(__file__).parent
sys.path.insert(0, str(_HERE))
|
|
|
| from model.tokenizer import PaletteTokenizer
|
| from model.text_encoder import TextTokenizer, TextEncoder
|
| from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM
|
|
|
|
|
# --- Paths -------------------------------------------------------------------
# Directory the palette (``palette_256.npy``) and vocabulary (``vocab.json``)
# are read from.
DATA_DIR = Path(r"D:\PixelArtGen_Data\processed")
# Checkpoint restored by ``load_model``; relative to the working directory.
CHECKPOINT_PATH = Path("checkpoints_bit") / "best.pt"

# --- Runtime state -----------------------------------------------------------
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level singletons; they stay ``None`` until the loader functions
# (``load_model`` / ``load_tokenizers``) populate them.
model = None
palette_tok = None
text_tok = None
|
|
|
|
|
def load_tokenizers():
    """Load the shared palette and text tokenizers into module globals.

    Builds a ``PaletteTokenizer`` from ``DATA_DIR / "palette_256.npy"`` and a
    ``TextTokenizer`` from the JSON vocabulary in ``DATA_DIR / "vocab.json"``,
    storing them in ``palette_tok`` and ``text_tok`` respectively.

    Raises:
        OSError: If ``vocab.json`` cannot be opened.
        json.JSONDecodeError: If ``vocab.json`` does not contain valid JSON.
    """
    global palette_tok, text_tok

    palette_tok = PaletteTokenizer(palette_path=str(DATA_DIR / "palette_256.npy"))

    # JSON is UTF-8 by specification. Pin the encoding so the vocabulary is
    # decoded identically regardless of the platform's locale default.
    with open(DATA_DIR / "vocab.json", encoding="utf-8") as f:
        vocab = json.load(f)

    text_tok = TextTokenizer(vocab)
|
|
|
|
|
def load_model():
    """Load the BitPixelLM model from the checkpoint and cache it.

    The first successful call builds the text encoder and pixel decoder from
    the hyperparameters stored under the checkpoint's ``"args"`` entry
    (missing keys fall back to the defaults below), restores the weights from
    ``"model_state_dict"``, switches the model to eval mode and stores it in
    the module-level ``model``. Subsequent calls return the cached instance.

    Returns:
        The loaded ``BitPixelLM`` instance, moved to ``device``.

    Raises:
        FileNotFoundError: If ``CHECKPOINT_PATH`` does not exist.
        RuntimeError: If the tokenizers have not been loaded yet.
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    global model
    if model is not None:
        return model

    if not CHECKPOINT_PATH.exists():
        raise FileNotFoundError(
            f"Checkpoint not found: {CHECKPOINT_PATH}\n"
            "BitPixelLM is still training — check back once training completes."
        )

    # Both tokenizers supply the vocabulary sizes used below; fail with a
    # clear message instead of an AttributeError on ``None``.
    if text_tok is None or palette_tok is None:
        raise RuntimeError(
            "Tokenizers are not loaded; call load_tokenizers() before load_model()."
        )

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects. Only load checkpoints that come from a trusted source.
    checkpoint = torch.load(str(CHECKPOINT_PATH), map_location=device, weights_only=False)

    model_args = checkpoint.get("args") or {}
    if not isinstance(model_args, dict):
        # NOTE(review): tolerates hyperparameters saved as a Namespace-like
        # object instead of a dict — confirm against the training script.
        model_args = vars(model_args)

    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    text_layers = model_args.get("text_layers", 3)
    pixel_layers = model_args.get("pixel_layers", 6)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    max_text_len = model_args.get("max_text_len", 32)

    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=text_layers,
        dim_feedforward=dim_ff,
        max_seq_len=max_text_len,
        dropout=dropout,
    )

    pixel_decoder = BitPixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=pixel_layers,
        dim_feedforward=dim_ff,
        img_size=32,
        dropout=dropout,
    )

    m = BitPixelLM(text_encoder, pixel_decoder).to(device)
    m.load_state_dict(checkpoint["model_state_dict"])
    m.eval()

    # Publish the model only after it is fully initialised, so a failure above
    # leaves ``model`` as ``None``.
    model = m
    return model
|
|
|
|
|
def generate(
    prompt: str,
    temperature: float,
    top_k: int,
    top_p: float,
    num_samples: int,
    scale: int,
):
    """Generate pixel art from a text prompt.

    Args:
        prompt: Text description to condition on; must not be blank.
        temperature: Sampling temperature, forwarded to ``model.generate``.
        top_k: Top-K setting, forwarded to ``model.generate`` as an int.
        top_p: Top-P setting, forwarded to ``model.generate``.
        num_samples: Number of images to generate.
        scale: Integer upscale factor applied with nearest-neighbour
            resampling; values of 1 or less leave the image unscaled.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per sample.

    Raises:
        gr.Error: If the prompt is blank, the model is not loaded, or any
            step of generation/decoding fails.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    if model is None:
        raise gr.Error(
            "BitPixelLM is not loaded yet. "
            "It may still be training — check back once training completes."
        )

    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    # Collect words missing from the vocabulary, de-duplicated in first-seen
    # order, so the user can be warned once per word.
    special_tokens = ("<pad>", "<sos>", "<eos>", "<unk>")
    unknown = list(
        dict.fromkeys(
            w
            for w in prompt.lower().strip().split()
            if w not in text_tok.word2idx and w not in special_tokens
        )
    )

    images = []
    try:
        # UI sliders may deliver floats; coerce the integer-valued settings
        # once, up front.
        k = int(top_k)
        factor = int(scale)

        with torch.no_grad():
            for _ in range(int(num_samples)):
                generated_tokens = model.generate(
                    text_tokens,
                    sos_token=palette_tok.sos_token,
                    eos_token=palette_tok.eos_token,
                    temperature=temperature,
                    top_k=k,
                    top_p=top_p,
                )

                token_list = generated_tokens[0].cpu().tolist()
                img_array = palette_tok.decode_tokens(token_list)
                img = Image.fromarray(img_array, "RGB")

                if factor > 1:
                    # Scale from the decoded image's own size rather than a
                    # hard-coded 32; NEAREST keeps pixel edges crisp.
                    width, height = img.size
                    img = img.resize((width * factor, height * factor), Image.NEAREST)

                images.append(img)
    except Exception as e:
        # Surface the failure in the UI while keeping the original traceback
        # attached as the cause.
        raise gr.Error(f"Generation failed: {e}") from e

    if unknown:
        gr.Warning(
            f"Unknown words treated as <unk>: {', '.join(unknown)}. "
            "Try using words from the vocabulary list below."
        )

    return images
|
|
|
|
|
|
|
|
|
|
|
def _load_vocab_words():
    """Return the sorted vocabulary words shown in the UI.

    Reads ``vocab.json`` from ``DATA_DIR`` and drops every entry that starts
    with ``<`` (e.g. ``<pad>`` or ``<unk>``). If the file cannot be read or
    processed for any reason, a short built-in word list is returned instead.
    """
    try:
        # Pin UTF-8 so decoding does not depend on the platform locale.
        with open(DATA_DIR / "vocab.json", encoding="utf-8") as f:
            vocab = json.load(f)
        return sorted(w for w in vocab if not w.startswith("<"))
    except Exception:
        # Deliberately broad: this runs at import time and only feeds a help
        # text, so any failure degrades to a placeholder list.
        return ["pixel", "art", "sword", "red", "blue", "green"]


VOCAB_WORDS = _load_vocab_words()
|
|
|
# Clickable example prompts, expanded from (colour, subject) pairs into the
# fixed "a <colour> pixel art <subject>" phrasing.
EXAMPLE_PROMPTS = [
    f"a {colour} pixel art {subject}"
    for colour, subject in (
        ("red", "sword"),
        ("green", "dragon"),
        ("purple", "crystal"),
        ("blue", "knight"),
        ("gold", "castle"),
        ("red", "phoenix"),
        ("dark", "skeleton"),
        ("teal", "wizard"),
        ("silver", "robot"),
        ("orange", "fox"),
    )
]
|
|
|
|
|
def build_ui():
    """Build the Gradio Blocks application and return it (not yet launched).

    The layout has a header, a controls column (prompt, generate button,
    sample count, advanced sampling settings, vocabulary list), a gallery
    column, clickable example prompts and an architecture note. Both the
    Generate button and pressing Enter in the prompt box call ``generate``
    and send its result to the gallery.

    Returns:
        The constructed ``gr.Blocks`` app.
    """
    with gr.Blocks(
        title="PixelArtGen",
        theme=gr.themes.Soft(primary_hue="purple"),
        # Keep upscaled pixel art crisp instead of letting the browser blur it.
        css="""
        .gallery-item img { image-rendering: pixelated !important; }
        .output-gallery img { image-rendering: pixelated !important; }
        #gallery img { image-rendering: pixelated !important; }
        """,
    ) as app:
        gr.Markdown(
            """
            # PixelArtGen
            ### Generate 32x32 pixel art from text prompts

            Powered by **BitPixelLM** — a custom 1.58-bit ternary transformer built from scratch
            using BitNet b1.58 with RMSNorm, SwiGLU, and 2D positional encoding.
            7.3M parameters (75% ternary weights at 1.58 bits per weight).
            """
        )

        with gr.Row():
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="a red pixel art sword",
                    lines=2,
                )
                with gr.Row():
                    generate_btn = gr.Button("Generate", variant="primary", scale=2)
                    num_samples = gr.Slider(1, 8, value=4, step=1, label="Samples")

                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(
                        0.1, 2.0, value=0.8, step=0.05,
                        label="Temperature",
                        info="Lower = more deterministic, higher = more creative"
                    )
                    top_k = gr.Slider(
                        0, 256, value=40, step=1,
                        label="Top-K",
                        info="0 = disabled. Limits sampling to top K tokens."
                    )
                    top_p = gr.Slider(
                        0.1, 1.0, value=0.9, step=0.05,
                        label="Top-P (Nucleus)",
                        info="Cumulative probability threshold for sampling."
                    )
                    scale = gr.Slider(
                        1, 16, value=8, step=1,
                        label="Upscale Factor",
                        info="8x = 256x256, 16x = 512x512"
                    )

                gr.Markdown(
                    f"**Known vocabulary:** {', '.join(VOCAB_WORDS)}"
                )

            with gr.Column(scale=2):
                gallery = gr.Gallery(
                    label="Generated Pixel Art",
                    columns=4,
                    rows=2,
                    height=520,
                    object_fit="contain",
                    elem_id="gallery",
                )

        gr.Markdown("### Examples")
        gr.Examples(
            examples=EXAMPLE_PROMPTS,
            inputs=[prompt],
            label="Click to try",
        )

        gr.Markdown(
            """
            ---
            **Architecture:**
            BitPixelLM treats pixel art generation as language modeling — each pixel is a token from a 256-color palette,
            generated left-to-right, top-to-bottom via a causal transformer with 2D positional encoding and cross-attention to text.
            Uses 1.58-bit ternary weights (BitNet b1.58) with RMSNorm and SwiGLU for extreme parameter efficiency.
            """
        )

        # Button click and Enter-in-textbox share one wiring so their inputs
        # can never drift apart. Order must match ``generate``'s signature.
        generate_inputs = [prompt, temperature, top_k, top_p, num_samples, scale]
        for register in (generate_btn.click, prompt.submit):
            register(
                fn=generate,
                inputs=generate_inputs,
                outputs=gallery,
            )

    return app
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Step 1: tokenizers. Errors here are not caught and abort startup.
    print("Loading tokenizers...")
    load_tokenizers()
    print(f" Palette: {palette_tok.vocab_size} tokens")
    print(f" Text: {text_tok.vocab_size} words")
    print(f" Device: {device}")

    # Step 2: model. A load failure is reported but does not stop the UI from
    # starting.
    print(f"Loading BitPixelLM from {CHECKPOINT_PATH}...")
    try:
        load_model()
    except FileNotFoundError as missing:
        print(f" {missing}")
        print(" UI will launch but generation will be unavailable until training completes.")
    except Exception as err:
        print(f" Failed to load BitPixelLM: {err}")
    else:
        print(" BitPixelLM loaded successfully.")

    # Step 3: build and serve the UI.
    print("\nLaunching UI...")
    launch_options = {
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
        "share": False,
        "inbrowser": True,
    }
    app = build_ui()
    app.launch(**launch_options)
|
|
|