# BitPixelLM / generate.py
# Source: Hugging Face model repository upload "Upload BitPixelLM model artifacts"
# Author: BlakePeavy — commit 72e872c (verified)
"""
PixelArtGen — Generate pixel art from text prompts.
Usage:
python generate.py --prompt "a red pixel art sword" --output output.png
python generate.py --prompt "a blue pixel art heart" --output heart.png --temperature 0.7
python generate.py --batch-prompts prompts.txt --output-dir outputs/
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
from pathlib import Path
from PIL import Image
sys.path.insert(0, str(Path(__file__).parent))
from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
from model.pixel_decoder import PixelLMDecoder, PixelLM
def load_model(checkpoint_path: str, data_dir: str, device: torch.device):
    """Load a trained PixelLM model and its tokenizers from a checkpoint.

    Args:
        checkpoint_path: Path to a ``.pt`` checkpoint containing
            ``"model_state_dict"`` and (optionally) an ``"args"`` dict of
            the hyperparameters used at training time.
        data_dir: Directory holding ``palette_256.npy`` and ``vocab.json``.
        device: torch device the model is moved to.

    Returns:
        Tuple ``(model, palette_tok, text_tok)`` with the model in eval mode.
    """
    data_dir = Path(data_dir)
    # SECURITY NOTE: weights_only=False makes torch.load unpickle arbitrary
    # objects from the checkpoint file — only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    model_args = checkpoint.get("args", {})
    # Load tokenizers
    palette_tok = PaletteTokenizer(palette_path=str(data_dir / "palette_256.npy"))
    # Explicit encoding so the vocab loads identically on every platform.
    with open(data_dir / "vocab.json", encoding="utf-8") as f:
        vocab = json.load(f)
    text_tok = TextTokenizer(vocab)
    # Rebuild the model from checkpoint hyperparameters, falling back to the
    # training defaults when a key is absent.
    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=model_args.get("text_layers", 3),
        dim_feedforward=dim_ff,
        max_seq_len=model_args.get("max_text_len", 32),
        dropout=dropout,
    )
    pixel_decoder = PixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=model_args.get("pixel_layers", 6),
        dim_feedforward=dim_ff,
        # img_size was previously hard-coded to 32; read it from the checkpoint
        # args when present (default unchanged) so other resolutions round-trip.
        img_size=model_args.get("img_size", 32),
        dropout=dropout,
    )
    model = PixelLM(text_encoder, pixel_decoder).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, palette_tok, text_tok
def generate_pixel_art(
    model: PixelLM,
    palette_tok: PaletteTokenizer,
    text_tok: TextTokenizer,
    prompt: str,
    device: torch.device,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    scale: int = 8,
) -> Image.Image:
    """
    Generate a pixel art image from a text prompt.

    Args:
        model: Trained PixelLM model.
        palette_tok: Color palette tokenizer (maps tokens <-> RGB colors).
        text_tok: Text tokenizer for the prompt.
        prompt: Text description of the desired image.
        device: torch device inference runs on.
        temperature: Sampling temperature (lower = more deterministic).
        top_k: Top-k filtering applied during sampling.
        top_p: Nucleus (top-p) sampling threshold.
        scale: Integer upscale factor for display (8 => 256x256 for a
            32x32 decoder).

    Returns:
        PIL Image of size (width*scale, height*scale), where width/height
        come from the decoded image (32x32 with the default decoder).
    """
    # Tokenize the prompt and add a batch dimension of 1.
    text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device)
    # Autoregressively sample pixel tokens; no gradients needed at inference.
    with torch.no_grad():
        generated_tokens = model.generate(
            text_tokens,
            sos_token=palette_tok.sos_token,
            eos_token=palette_tok.eos_token,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
    # Decode the token ids back to an RGB array via the palette.
    token_list = generated_tokens[0].cpu().tolist()
    img_array = palette_tok.decode_tokens(token_list)
    img = Image.fromarray(img_array, "RGB")
    # Upscale with nearest-neighbor to keep hard pixel edges. Use the decoded
    # image's own size (previously hard-coded 32) so non-32x32 decoders work.
    if scale > 1:
        img = img.resize((img.width * scale, img.height * scale), Image.NEAREST)
    return img
def main():
    """CLI entry point: parse arguments, load the model, generate image(s).

    Prompt selection precedence: --batch-prompts file > --prompt > a small
    built-in set of demo prompts.
    """
    parser = argparse.ArgumentParser(description="Generate pixel art from text")
    parser.add_argument("--prompt", type=str, help="Text prompt")
    parser.add_argument("--output", type=str, default="output.png", help="Output file")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt")
    parser.add_argument("--data-dir", type=str, default=r"D:\PixelArtGen_Data\processed")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--scale", type=int, default=8, help="Upscale factor")
    parser.add_argument("--num-samples", type=int, default=1, help="Number of images to generate")
    parser.add_argument("--batch-prompts", type=str, help="File with prompts (one per line)")
    parser.add_argument("--output-dir", type=str, default="outputs")
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    # Load model + tokenizers
    print(f"Loading model from {args.checkpoint}...")
    model, palette_tok, text_tok = load_model(args.checkpoint, args.data_dir, device)
    print(f" Model: {model.count_parameters():,} parameters")
    # Collect prompts (precedence: batch file > single prompt > demo set).
    if args.batch_prompts:
        with open(args.batch_prompts) as f:
            prompts = [line.strip() for line in f if line.strip()]
    elif args.prompt:
        prompts = [args.prompt]
    else:
        prompts = [
            "a red pixel art sword",
            "a blue pixel art heart",
            "a green pixel art tree",
            "a purple pixel art gem",
        ]
    # Generate
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for i, prompt in enumerate(prompts):
        print(f"\nGenerating: \"{prompt}\"")
        for j in range(args.num_samples):
            img = generate_pixel_art(
                model, palette_tok, text_tok, prompt, device,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                scale=args.scale,
            )
            if len(prompts) == 1 and args.num_samples == 1:
                # Single image: honor --output exactly.
                out_path = args.output
            else:
                # Sanitize the prompt for use as a filename: keep alphanumerics
                # and replace everything else (spaces, slashes, quotes, ...)
                # with "_" so arbitrary prompts cannot produce invalid paths or
                # escape output_dir. Previously only spaces were replaced.
                safe_name = "".join(c if c.isalnum() else "_" for c in prompt)[:30]
                out_path = output_dir / f"{safe_name}_{j}.png"
            img.save(str(out_path))
            print(f" Saved: {out_path}")
    print("\nDone!")
# Standard script guard: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()