""" PixelArtGen — Generate pixel art from text prompts. Usage: python generate.py --prompt "a red pixel art sword" --output output.png python generate.py --prompt "a blue pixel art heart" --output heart.png --temperature 0.7 python generate.py --batch-prompts prompts.txt --output-dir outputs/ """ import os import sys import json import argparse import numpy as np import torch from pathlib import Path from PIL import Image sys.path.insert(0, str(Path(__file__).parent)) from model.tokenizer import PaletteTokenizer from model.text_encoder import TextTokenizer, TextEncoder from model.pixel_decoder import PixelLMDecoder, PixelLM def load_model(checkpoint_path: str, data_dir: str, device: torch.device): """Load a trained PixelLM model from checkpoint.""" data_dir = Path(data_dir) # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) model_args = checkpoint.get("args", {}) # Load tokenizers palette_tok = PaletteTokenizer(palette_path=str(data_dir / "palette_256.npy")) with open(data_dir / "vocab.json") as f: vocab = json.load(f) text_tok = TextTokenizer(vocab) # Rebuild model d_model = model_args.get("d_model", 256) nhead = model_args.get("nhead", 8) text_layers = model_args.get("text_layers", 3) pixel_layers = model_args.get("pixel_layers", 6) dim_ff = model_args.get("dim_ff", 512) dropout = model_args.get("dropout", 0.1) max_text_len = model_args.get("max_text_len", 32) text_encoder = TextEncoder( vocab_size=text_tok.vocab_size, d_model=d_model, nhead=nhead, num_layers=text_layers, dim_feedforward=dim_ff, max_seq_len=max_text_len, dropout=dropout, ) pixel_decoder = PixelLMDecoder( vocab_size=palette_tok.vocab_size, d_model=d_model, nhead=nhead, num_layers=pixel_layers, dim_feedforward=dim_ff, img_size=32, dropout=dropout, ) model = PixelLM(text_encoder, pixel_decoder).to(device) model.load_state_dict(checkpoint["model_state_dict"]) model.eval() return model, palette_tok, text_tok def generate_pixel_art( model: PixelLM, palette_tok: PaletteTokenizer, text_tok: TextTokenizer, prompt: str, device: torch.device, temperature: float = 0.8, top_k: int = 40, top_p: float = 0.9, scale: int = 8, ) -> Image.Image: """ Generate a 32×32 pixel art image from a text prompt. Args: model: Trained PixelLM model palette_tok: Color palette tokenizer text_tok: Text tokenizer prompt: Text description device: torch device temperature: Sampling temperature (lower = more deterministic) top_k: Top-k filtering top_p: Nucleus sampling threshold scale: Upscale factor for display (8 = 256×256 output) Returns: PIL Image (32*scale × 32*scale) """ # Tokenize prompt text_tokens = text_tok.encode(prompt).unsqueeze(0).to(device) # Generate with torch.no_grad(): generated_tokens = model.generate( text_tokens, sos_token=palette_tok.sos_token, eos_token=palette_tok.eos_token, temperature=temperature, top_k=top_k, top_p=top_p, ) # Decode to image token_list = generated_tokens[0].cpu().tolist() img_array = palette_tok.decode_tokens(token_list) img = Image.fromarray(img_array, "RGB") # Upscale with nearest-neighbor (pixel art style) if scale > 1: img = img.resize((32 * scale, 32 * scale), Image.NEAREST) return img def main(): parser = argparse.ArgumentParser(description="Generate pixel art from text") parser.add_argument("--prompt", type=str, help="Text prompt") parser.add_argument("--output", type=str, default="output.png", help="Output file") parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt") parser.add_argument("--data-dir", type=str, default=r"D:\PixelArtGen_Data\processed") parser.add_argument("--temperature", type=float, default=0.8) parser.add_argument("--top-k", type=int, default=40) parser.add_argument("--top-p", type=float, default=0.9) parser.add_argument("--scale", type=int, default=8, help="Upscale factor") parser.add_argument("--num-samples", type=int, default=1, help="Number of images to generate") parser.add_argument("--batch-prompts", type=str, help="File with prompts (one per line)") parser.add_argument("--output-dir", type=str, default="outputs") args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") # Load model print(f"Loading model from {args.checkpoint}...") model, palette_tok, text_tok = load_model(args.checkpoint, args.data_dir, device) print(f" Model: {model.count_parameters():,} parameters") # Collect prompts if args.batch_prompts: with open(args.batch_prompts) as f: prompts = [line.strip() for line in f if line.strip()] elif args.prompt: prompts = [args.prompt] else: prompts = [ "a red pixel art sword", "a blue pixel art heart", "a green pixel art tree", "a purple pixel art gem", ] # Generate output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) for i, prompt in enumerate(prompts): print(f"\nGenerating: \"{prompt}\"") for j in range(args.num_samples): img = generate_pixel_art( model, palette_tok, text_tok, prompt, device, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, scale=args.scale, ) if len(prompts) == 1 and args.num_samples == 1: out_path = args.output else: safe_name = prompt.replace(" ", "_")[:30] out_path = output_dir / f"{safe_name}_{j}.png" img.save(str(out_path)) print(f" Saved: {out_path}") print("\nDone!") if __name__ == "__main__": main()