Spaces:
Sleeping
Sleeping
File size: 5,337 Bytes
4e8cb8d 78a51fb 4e8cb8d 78a51fb 4e8cb8d 78a51fb 4e8cb8d 78a51fb 4e8cb8d 78a51fb 4e8cb8d 78a51fb 70cf296 78a51fb 70cf296 78a51fb 70cf296 78a51fb 70cf296 78a51fb 70cf296 4e8cb8d 70cf296 4e8cb8d df5e01f 885a6d6 df5e01f 885a6d6 df5e01f 885a6d6 df5e01f 885a6d6 4e8cb8d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | """
BitPixelLM - Hugging Face Gradio Space
Generates 32x32 pixel art from text prompts.
"""
import sys
import json
import traceback
import torch
import gradio as gr
from pathlib import Path
from PIL import Image
HERE = Path(__file__).parent
sys.path.insert(0, str(HERE))
from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM
# Asset files expected to sit in the same directory as this script.
PALETTE_PATH = HERE.joinpath("palette_256.npy")
VOCAB_PATH = HERE.joinpath("vocab.json")
CHECKPOINT_PATH = HERE.joinpath("best.pt")

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level state assigned by load_all().
model = palette_tok = text_tok = None
VOCAB_DISPLAY = ""
def load_all():
    """Load the tokenizers and the BitPixelLM checkpoint into module globals.

    Builds the palette tokenizer from ``PALETTE_PATH``, the text tokenizer
    from the JSON vocabulary at ``VOCAB_PATH``, and the model from the
    checkpoint at ``CHECKPOINT_PATH``. Model hyper-parameters are read from
    the checkpoint's ``"args"`` entry, with hard-coded fallbacks for any
    missing key. On success the ``model``, ``palette_tok``, ``text_tok`` and
    ``VOCAB_DISPLAY`` globals are set; ``model`` is assigned last, so it is
    only non-None once the weights have been loaded and the model is in
    eval mode.

    Raises:
        Exception: Whatever went wrong during loading. The traceback is
            printed first, then the exception is re-raised unchanged.
    """
    global model, palette_tok, text_tok, VOCAB_DISPLAY
    try:
        palette_tok = PaletteTokenizer(palette_path=str(PALETTE_PATH))

        # Read the vocabulary as UTF-8 explicitly instead of relying on the
        # platform's default text encoding.
        with open(VOCAB_PATH, encoding="utf-8") as f:
            vocab = json.load(f)
        text_tok = TextTokenizer(vocab)
        # Human-readable word list; entries starting with "<" are skipped.
        VOCAB_DISPLAY = ", ".join(sorted(w for w in vocab if not w.startswith("<")))

        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary objects. Only acceptable because the checkpoint is a
        # trusted file shipped next to this script; never point this at
        # user-supplied data.
        ckpt = torch.load(str(CHECKPOINT_PATH), map_location=device, weights_only=False)
        # Tolerate both a missing "args" key and one stored as None.
        args = ckpt.get("args") or {}

        # Hyper-parameters shared by the text encoder and the pixel decoder.
        d_model = args.get("d_model", 256)
        nhead = args.get("nhead", 8)
        dim_ff = args.get("dim_ff", 512)
        dropout = args.get("dropout", 0.1)

        text_encoder = TextEncoder(
            vocab_size=text_tok.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("text_layers", 3),
            dim_feedforward=dim_ff,
            max_seq_len=args.get("max_text_len", 32),
            dropout=dropout,
        )
        pixel_decoder = BitPixelLMDecoder(
            vocab_size=palette_tok.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("pixel_layers", 6),
            dim_feedforward=dim_ff,
            img_size=32,
            dropout=dropout,
        )

        m = BitPixelLM(text_encoder, pixel_decoder).to(device)
        m.load_state_dict(ckpt["model_state_dict"])
        m.eval()
        # Publish the model only after it is fully initialised.
        model = m
        print(f"BitPixelLM loaded on {device} | vocab={text_tok.vocab_size} words")
    except Exception:
        # Make the failure visible in the logs, then let it propagate.
        traceback.print_exc()
        raise
def generate(prompt, temperature, top_k, top_p, num_samples, scale):
    """Generate pixel-art images for a text prompt.

    Args:
        prompt: Text prompt. ``None``, empty or whitespace-only prompts are
            rejected.
        temperature: Sampling temperature, converted with ``float``.
        top_k: Top-k value, converted with ``int``.
        top_p: Top-p value, converted with ``float``.
        num_samples: Number of images to generate, converted with ``int``.
        scale: Integer upscale factor. Values above 1 resize each image to
            ``32 * scale`` pixels per side using nearest-neighbour
            resampling.

    Returns:
        A list of PIL images, one per sample.

    Raises:
        gr.Error: If the model is not loaded, the prompt is blank, or
            anything fails while sampling or decoding. In the last case the
            original exception is chained as the cause.
    """
    if model is None:
        raise gr.Error("Model failed to load. Check Space logs.")
    # Guard against None as well as blank text so that a missing prompt is
    # reported as a user error instead of an AttributeError.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    special_tokens = ("<pad>", "<sos>", "<eos>", "<unk>")
    # dict.fromkeys removes repeats while keeping first-seen order, so each
    # unknown word is reported once.
    unknown = list(dict.fromkeys(
        w for w in prompt.lower().split()
        if w not in text_tok.word2idx and w not in special_tokens
    ))

    # unsqueeze(0) adds a leading dimension — presumably the batch axis
    # expected by model.generate; confirm against the model code.
    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    images = []
    try:
        # Convert the UI values once. The conversions stay inside the try so
        # a bad value is still surfaced as a gr.Error.
        count = int(num_samples)
        factor = int(scale)
        sampling = {
            "temperature": float(temperature),
            "top_k": int(top_k),
            "top_p": float(top_p),
        }
        with torch.no_grad():
            for _ in range(count):
                toks = model.generate(
                    text_tokens,
                    sos_token=palette_tok.sos_token,
                    eos_token=palette_tok.eos_token,
                    **sampling,
                )
                arr = palette_tok.decode_tokens(toks[0].cpu().tolist())
                img = Image.fromarray(arr, "RGB")
                if factor > 1:
                    # Nearest-neighbour keeps hard pixel edges when enlarging.
                    img = img.resize((32 * factor, 32 * factor), Image.NEAREST)
                images.append(img)
    except Exception as e:
        # Chain the cause so the full traceback stays available in the logs.
        raise gr.Error(f"Generation error: {e}") from e

    if unknown:
        gr.Warning("Unknown words (ignored): " + ", ".join(unknown))
    return images
# Sample prompts; the Interface definition below turns each one into an
# example row by pairing it with fixed slider values.
EXAMPLES = [
    "a red pixel art sword",
    "a blue pixel art knight",
    "a green pixel art dragon",
    "a purple pixel art wizard",
    "a gold pixel art crown",
    "a dark pixel art skeleton",
]
def generate_tiled(prompt, temperature, top_k, top_p, num_samples, scale):
    """Return all samples tiled into a single image.

    Calls ``generate`` with the same arguments and pastes the resulting
    images row by row onto a dark-grey canvas, at most four per row. Every
    cell is sized from the first image.

    Args:
        prompt: Text prompt, forwarded to ``generate``.
        temperature: Forwarded to ``generate``.
        top_k: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        num_samples: Forwarded to ``generate``.
        scale: Forwarded to ``generate``.

    Returns:
        A single RGB PIL image containing every sample, or ``None`` when
        ``generate`` returns no images.

    Raises:
        gr.Error: Propagated from ``generate``.
    """
    imgs = generate(prompt, temperature, top_k, top_p, num_samples, scale)
    if not imgs:
        return None

    w, h = imgs[0].size
    n = len(imgs)
    cols = min(n, 4)
    rows = (n + cols - 1) // cols  # ceiling division

    canvas = Image.new("RGB", (cols * w, rows * h), (30, 30, 30))
    for i, im in enumerate(imgs):
        row, col = divmod(i, cols)
        canvas.paste(im, (col * w, row * h))
    return canvas
# Gradio UI: a prompt box plus five sliders feed generate_tiled, and the tiled
# result is shown as a single PIL image.
demo = gr.Interface(
    fn=generate_tiled,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="a red pixel art sword"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(minimum=0, maximum=256, value=40, step=1, label="Top-K (0=off)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P"),
        gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Samples"),
        gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Upscale (8=256px)"),
    ],
    outputs=gr.Image(label="Generated Pixel Art", type="pil"),
    title="BitPixelLM - Pixel Art Generator",
    description="Generate 32x32 pixel art sprites from text prompts using BitPixelLM (BitNet b1.58, 7.4M params). Samples are tiled into one image.",
    # Each example row is a sample prompt followed by the same values the
    # sliders above start with.
    examples=[[sample_prompt, 0.8, 40, 0.9, 4, 8] for sample_prompt in EXAMPLES],
    cache_examples=False,
)
# Load the tokenizers and model before serving. load_all() re-raises on
# failure, so the server below is not started without a model.
load_all()
# Serve the interface on port 7860, bound to 0.0.0.0.
demo.launch(server_name="0.0.0.0", server_port=7860)
|