Spaces:
Sleeping
Sleeping
| """ | |
| BitPixelLM - Hugging Face Gradio Space | |
| Generates 32x32 pixel art from text prompts. | |
| """ | |
| import sys | |
| import json | |
| import traceback | |
| import torch | |
| import gradio as gr | |
| from pathlib import Path | |
| from PIL import Image | |
| HERE = Path(__file__).parent | |
| sys.path.insert(0, str(HERE)) | |
| from model.tokenizer import PaletteTokenizer | |
| from model.text_encoder import TextTokenizer, TextEncoder | |
| from model.bit_pixel_decoder import BitPixelLMDecoder, BitPixelLM | |
# Model artifacts are expected to live in the same directory as this script.
PALETTE_PATH = HERE.joinpath("palette_256.npy")
VOCAB_PATH = HERE.joinpath("vocab.json")
CHECKPOINT_PATH = HERE.joinpath("best.pt")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Module-level state populated by load_all().
model = palette_tok = text_tok = None
VOCAB_DISPLAY = ""
def load_all():
    """Load the tokenizers and the BitPixelLM checkpoint into module globals.

    Builds a ``PaletteTokenizer`` from ``PALETTE_PATH``, a ``TextTokenizer``
    from the JSON vocabulary at ``VOCAB_PATH``, and a ``BitPixelLM`` whose
    weights come from ``CHECKPOINT_PATH``. Model hyper-parameters are read
    from the checkpoint's ``"args"`` entry, with built-in defaults for any
    missing key.

    On success the globals ``model``, ``palette_tok``, ``text_tok`` and
    ``VOCAB_DISPLAY`` are set together; on failure none of them is modified.

    Raises:
        Exception: any error raised while loading is printed with its
            traceback and then re-raised unchanged.
    """
    global model, palette_tok, text_tok, VOCAB_DISPLAY
    try:
        palette = PaletteTokenizer(palette_path=str(PALETTE_PATH))

        # Explicit encoding so decoding does not depend on the host locale.
        with open(VOCAB_PATH, encoding="utf-8") as f:
            vocab = json.load(f)
        text = TextTokenizer(vocab)
        # Human-readable vocabulary listing; entries starting with "<" are skipped.
        vocab_display = ", ".join(
            sorted(w for w in vocab if not w.startswith("<"))
        )

        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only ever point CHECKPOINT_PATH at a trusted checkpoint file.
        ckpt = torch.load(
            str(CHECKPOINT_PATH), map_location=device, weights_only=False
        )
        # Tolerate checkpoints where "args" is absent or stored as None.
        args = ckpt.get("args") or {}

        # Hyper-parameters shared by the text encoder and the pixel decoder.
        d_model = args.get("d_model", 256)
        nhead = args.get("nhead", 8)
        dim_ff = args.get("dim_ff", 512)
        dropout = args.get("dropout", 0.1)

        text_encoder = TextEncoder(
            vocab_size=text.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("text_layers", 3),
            dim_feedforward=dim_ff,
            max_seq_len=args.get("max_text_len", 32),
            dropout=dropout,
        )
        pixel_decoder = BitPixelLMDecoder(
            vocab_size=palette.vocab_size,
            d_model=d_model,
            nhead=nhead,
            num_layers=args.get("pixel_layers", 6),
            dim_feedforward=dim_ff,
            img_size=32,
            dropout=dropout,
        )

        m = BitPixelLM(text_encoder, pixel_decoder).to(device)
        m.load_state_dict(ckpt["model_state_dict"])
        m.eval()

        # Publish everything at once so a failure above never leaves the
        # module with tokenizers set but no model.
        palette_tok = palette
        text_tok = text
        VOCAB_DISPLAY = vocab_display
        model = m
        print(f"BitPixelLM loaded on {device} | vocab={text.vocab_size} words")
    except Exception:
        traceback.print_exc()
        raise
def generate(prompt, temperature, top_k, top_p, num_samples, scale):
    """Generate pixel-art images for a text prompt.

    Args:
        prompt: Text description of the sprite to generate.
        temperature: Sampling temperature, converted with ``float``.
        top_k: Top-k sampling cutoff, converted with ``int``.
        top_p: Top-p sampling cutoff, converted with ``float``.
        num_samples: Number of images to generate, converted with ``int``.
        scale: Integer upscale factor; values above 1 resize each image to
            ``32 * scale`` pixels per side using nearest-neighbour sampling.

    Returns:
        A list of PIL images, one per requested sample.

    Raises:
        gr.Error: if the model is not loaded, the prompt is empty or
            whitespace-only, or anything fails during sampling/decoding.
    """
    if model is None:
        raise gr.Error("Model failed to load. Check Space logs.")
    # Guard against None as well as blank text.
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")

    special_tokens = ("<pad>", "<sos>", "<eos>", "<unk>")
    # dict.fromkeys de-duplicates while keeping first-seen order, so a word
    # repeated in the prompt is reported only once.
    unknown = list(dict.fromkeys(
        w for w in prompt.lower().split()
        if w not in text_tok.word2idx and w not in special_tokens
    ))

    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)

    images = []
    try:
        # Convert the UI values once; kept inside the try so a bad value is
        # still surfaced to the user as a gr.Error.
        count = int(num_samples)
        factor = int(scale)
        sampling = {
            "temperature": float(temperature),
            "top_k": int(top_k),
            "top_p": float(top_p),
        }
        with torch.no_grad():
            for _ in range(count):
                toks = model.generate(
                    text_tokens,
                    sos_token=palette_tok.sos_token,
                    eos_token=palette_tok.eos_token,
                    **sampling,
                )
                arr = palette_tok.decode_tokens(toks[0].cpu().tolist())
                img = Image.fromarray(arr, "RGB")
                if factor > 1:
                    # NEAREST keeps hard pixel edges when enlarging.
                    img = img.resize((32 * factor, 32 * factor), Image.NEAREST)
                images.append(img)
    except Exception as e:
        # Chain the cause so the full traceback shows up in the logs.
        raise gr.Error(f"Generation error: {e}") from e

    if unknown:
        gr.Warning("Unknown words (ignored): " + ", ".join(unknown))
    return images
# Sample prompts offered as one-click examples in the UI.
EXAMPLES = [
    "a red pixel art sword",
    "a blue pixel art knight",
    "a green pixel art dragon",
    "a purple pixel art wizard",
    "a gold pixel art crown",
    "a dark pixel art skeleton",
]
def generate_tiled(prompt, temperature, top_k, top_p, num_samples, scale):
    """Generate samples and tile them into a single image.

    Takes the same arguments as ``generate`` and forwards them unchanged.
    The resulting images are laid out left-to-right, top-to-bottom on a grid
    at most 4 columns wide; every cell has the size of the first image and
    unused cells keep the dark grey (30, 30, 30) background.

    Returns:
        The tiled PIL image, or ``None`` when no images were generated.

    Raises:
        gr.Error: propagated from ``generate``.
    """
    imgs = generate(prompt, temperature, top_k, top_p, num_samples, scale)
    if not imgs:
        return None

    tile_w, tile_h = imgs[0].size
    total = len(imgs)
    cols = min(total, 4)
    rows = (total + cols - 1) // cols  # ceiling division

    canvas = Image.new("RGB", (cols * tile_w, rows * tile_h), (30, 30, 30))
    for idx, tile in enumerate(imgs):
        row, col = divmod(idx, cols)
        canvas.paste(tile, (col * tile_w, row * tile_h))
    return canvas
# Gradio front-end: one prompt box plus sampling/layout sliders, producing a
# single tiled image via generate_tiled.
demo = gr.Interface(
    fn=generate_tiled,
    title="BitPixelLM - Pixel Art Generator",
    description=(
        "Generate 32x32 pixel art sprites from text prompts using BitPixelLM "
        "(BitNet b1.58, 7.4M params). Samples are tiled into one image."
    ),
    inputs=[
        gr.Textbox(label="Prompt", placeholder="a red pixel art sword"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(minimum=0, maximum=256, value=40, step=1, label="Top-K (0=off)"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P"),
        gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Samples"),
        gr.Slider(minimum=1, maximum=16, value=8, step=1, label="Upscale (8=256px)"),
    ],
    outputs=gr.Image(label="Generated Pixel Art", type="pil"),
    # Each example row is the prompt followed by the default slider values.
    examples=[[text] + [0.8, 40, 0.9, 4, 8] for text in EXAMPLES],
    cache_examples=False,
)

# Load the model and tokenizers first, then serve on all interfaces, port 7860.
load_all()
demo.launch(server_name="0.0.0.0", server_port=7860)