File size: 6,490 Bytes
72e872c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""

PixelArtGen — Generate pixel art from text prompts.

Usage:
    python generate.py --prompt "a red pixel art sword" --output output.png
    python generate.py --prompt "a blue pixel art heart" --output heart.png --temperature 0.7
    python generate.py --batch-prompts prompts.txt --output-dir outputs/

"""

import os
import sys
import json
import argparse
import numpy as np
import torch
from pathlib import Path
from PIL import Image

# Put this script's own directory first on sys.path so the `model` package
# imported below is looked up next to this file regardless of the current
# working directory (presumably `model/` is a sibling directory — confirm).
sys.path.insert(0, str(Path(__file__).parent))

from model.tokenizer import PaletteTokenizer
from model.text_encoder import TextTokenizer, TextEncoder
from model.pixel_decoder import PixelLMDecoder, PixelLM


def load_model(checkpoint_path: str, data_dir: str, device: torch.device):
    """Load a trained PixelLM model and its tokenizers from disk.

    Args:
        checkpoint_path: Path to a checkpoint written by ``torch.save``. It must
            contain ``"model_state_dict"`` and may contain ``"args"`` with the
            hyper-parameters used for training.
        data_dir: Directory holding ``palette_256.npy`` and ``vocab.json``.
        device: Device the checkpoint tensors and the model are placed on.

    Returns:
        A ``(model, palette_tok, text_tok)`` tuple; the model is in eval mode.

    Raises:
        FileNotFoundError: If ``vocab.json`` (or the checkpoint) is missing.
        KeyError: If the checkpoint has no ``"model_state_dict"`` entry.
    """
    data_dir = Path(data_dir)

    # SECURITY: weights_only=False makes torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Hyper-parameters may be absent, None, or stored as an argparse.Namespace
    # rather than a dict; normalise to a dict so the .get() lookups below work.
    model_args = checkpoint.get("args") or {}
    if not isinstance(model_args, dict):
        model_args = vars(model_args)

    # Load tokenizers
    palette_tok = PaletteTokenizer(palette_path=str(data_dir / "palette_256.npy"))

    # JSON is UTF-8 by specification; do not depend on the platform default
    # encoding (e.g. cp1252 on Windows).
    with open(data_dir / "vocab.json", encoding="utf-8") as f:
        vocab = json.load(f)
    text_tok = TextTokenizer(vocab)

    # Rebuild the model with the training hyper-parameters, falling back to
    # the defaults below when a value was not stored in the checkpoint.
    d_model = model_args.get("d_model", 256)
    nhead = model_args.get("nhead", 8)
    text_layers = model_args.get("text_layers", 3)
    pixel_layers = model_args.get("pixel_layers", 6)
    dim_ff = model_args.get("dim_ff", 512)
    dropout = model_args.get("dropout", 0.1)
    max_text_len = model_args.get("max_text_len", 32)

    text_encoder = TextEncoder(
        vocab_size=text_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=text_layers,
        dim_feedforward=dim_ff,
        max_seq_len=max_text_len,
        dropout=dropout,
    )

    pixel_decoder = PixelLMDecoder(
        vocab_size=palette_tok.vocab_size,
        d_model=d_model,
        nhead=nhead,
        num_layers=pixel_layers,
        dim_feedforward=dim_ff,
        img_size=32,
        dropout=dropout,
    )

    model = PixelLM(text_encoder, pixel_decoder).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # inference only: disables dropout

    return model, palette_tok, text_tok


def generate_pixel_art(
    model: PixelLM,
    palette_tok: PaletteTokenizer,
    text_tok: TextTokenizer,
    prompt: str,
    device: torch.device,
    temperature: float = 0.8,
    top_k: int = 40,
    top_p: float = 0.9,
    scale: int = 8,
) -> Image.Image:
    """
    Generate a 32×32 pixel art image from a text prompt.

    Args:
        model: Trained PixelLM model
        palette_tok: Color palette tokenizer (supplies SOS/EOS tokens and
            decodes generated tokens back to pixels)
        text_tok: Text tokenizer used to encode the prompt
        prompt: Text description
        device: torch device the encoded prompt is moved to
        temperature: Sampling temperature (lower = more deterministic)
        top_k: Top-k filtering
        top_p: Nucleus sampling threshold
        scale: Upscale factor for display (8 = 256×256 output); values of 1
            or less leave the image at its decoded size

    Returns:
        PIL Image (32*scale × 32*scale)
    """
    # Encode the prompt and add a leading batch dimension of size one.
    encoded_prompt = text_tok.encode(prompt)
    prompt_batch = encoded_prompt.unsqueeze(0).to(device)

    # Sample pixel tokens; gradients are not needed at inference time.
    sampling_options = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
    }
    with torch.no_grad():
        sampled = model.generate(
            prompt_batch,
            sos_token=palette_tok.sos_token,
            eos_token=palette_tok.eos_token,
            **sampling_options,
        )

    # Only the first (and only) sequence of the batch is turned into an image.
    pixels = palette_tok.decode_tokens(sampled[0].cpu().tolist())
    picture = Image.fromarray(pixels, "RGB")

    if scale <= 1:
        return picture

    # Nearest-neighbour keeps hard pixel edges when enlarging.
    side = 32 * scale
    return picture.resize((side, side), Image.NEAREST)


def _prompt_to_stem(prompt: str, max_len: int = 30) -> str:
    """Turn a prompt into a file-name stem of at most ``max_len`` characters.

    Spaces, control characters and characters that are invalid or act as path
    separators on common platforms are replaced by ``_``; prompts made only of
    other characters map exactly as ``prompt.replace(" ", "_")[:max_len]``.
    """
    table = {ord(ch): "_" for ch in ' <>:"/\\|?*'}
    table.update({code: "_" for code in range(32)})
    return prompt.translate(table)[:max_len]


def main():
    """Command-line entry point: load the model and render one image per prompt/sample.

    Prompts come from ``--batch-prompts`` (one per line, blank lines skipped),
    else ``--prompt``, else a small built-in demo list. With exactly one prompt
    and one sample the image is written to ``--output``; otherwise images go to
    ``--output-dir`` as ``<prompt-stem>_<sample-index>.png``.
    """
    parser = argparse.ArgumentParser(description="Generate pixel art from text")
    parser.add_argument("--prompt", type=str, help="Text prompt")
    parser.add_argument("--output", type=str, default="output.png", help="Output file")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt")
    parser.add_argument("--data-dir", type=str, default=r"D:\PixelArtGen_Data\processed")
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    parser.add_argument("--top-p", type=float, default=0.9)
    parser.add_argument("--scale", type=int, default=8, help="Upscale factor")
    parser.add_argument("--num-samples", type=int, default=1, help="Number of images to generate")
    parser.add_argument("--batch-prompts", type=str, help="File with prompts (one per line)")
    parser.add_argument("--output-dir", type=str, default="outputs")

    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Load model
    print(f"Loading model from {args.checkpoint}...")
    model, palette_tok, text_tok = load_model(args.checkpoint, args.data_dir, device)
    print(f"  Model: {model.count_parameters():,} parameters")

    # Collect prompts
    if args.batch_prompts:
        # utf-8-sig reads plain UTF-8 and also strips a leading BOM, instead of
        # relying on the platform default encoding.
        with open(args.batch_prompts, encoding="utf-8-sig") as f:
            prompts = [line.strip() for line in f if line.strip()]
    elif args.prompt:
        prompts = [args.prompt]
    else:
        prompts = [
            "a red pixel art sword",
            "a blue pixel art heart",
            "a green pixel art tree",
            "a purple pixel art gem",
        ]

    # Only create the directory that will actually receive files.
    single_output = len(prompts) == 1 and args.num_samples == 1
    output_dir = Path(args.output_dir)
    if single_output:
        Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    else:
        output_dir.mkdir(parents=True, exist_ok=True)

    # Stems already handed out, so that duplicate prompts or prompts sharing
    # the same truncated prefix do not overwrite each other's images.
    used_stems = set()

    for i, prompt in enumerate(prompts):
        print(f"\nGenerating: \"{prompt}\"")

        stem = _prompt_to_stem(prompt)
        if stem in used_stems:
            stem = f"{stem}_p{i}"  # disambiguate with the prompt index
        used_stems.add(stem)

        for j in range(args.num_samples):
            img = generate_pixel_art(
                model, palette_tok, text_tok, prompt, device,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                scale=args.scale,
            )

            if single_output:
                out_path = args.output
            else:
                out_path = output_dir / f"{stem}_{j}.png"

            img.save(str(out_path))
            print(f"  Saved: {out_path}")

    print("\nDone!")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()