File size: 12,702 Bytes
409857b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from pathlib import Path
from tkinter import *
from PIL import ImageTk, Image
import random
from safetensors.torch import load_file

# Generator model definition (must match the training architecture)
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent code to an image in [-1, 1].

    The layer/attribute names (fc, bn1, conv_transpose1, ...) must match the
    training script exactly so the saved state_dict keys line up on load.
    """

    def __init__(self, codings_size, image_size, image_channels):
        super(Generator, self).__init__()

        # Project the latent vector into a 6x6x256 feature map.
        self.fc = nn.Linear(codings_size, 6 * 6 * 256, bias=False)
        self.bn1 = nn.BatchNorm1d(6 * 6 * 256)
        self.leaky_relu = nn.LeakyReLU(0.2)

        # 6x6 -> 6x6 (stride 1): refine features without upsampling.
        self.conv_transpose1 = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(128)

        # 6x6 -> 12x12.
        self.conv_transpose2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)

        # 12x12 -> 24x24, mapped to the requested channel count.
        self.conv_transpose3 = nn.ConvTranspose2d(64, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Map latent codes of shape (batch, codings_size) to (batch, C, 24, 24) images."""
        act = self.leaky_relu
        out = act(self.bn1(self.fc(x))).view(-1, 256, 6, 6)
        out = act(self.bn2(self.conv_transpose1(out)))
        out = act(self.bn3(self.conv_transpose2(out)))
        # tanh squashes the output into [-1, 1], matching the training target range.
        return self.tanh(self.conv_transpose3(out))

def load_model(model_path, device='cpu'):
    """
    Load the trained generator model from safetensors format.

    Reads the file once via safe_open, pulling both the tensors (loaded
    directly onto the target device) and the metadata that records the
    architecture hyperparameters.

    Args:
        model_path: Path to the .safetensors model file
        device: Device to load the model on ('cpu' or 'cuda')

    Returns:
        Tuple of (model, codings_size, image_size, image_channels).

    Raises:
        ValueError: If the file has no metadata or lacks a required key.
    """
    from safetensors import safe_open

    # Single pass over the file: metadata + tensors (original code read the
    # file twice — once with load_file, once with safe_open).
    state_dict = {}
    with safe_open(model_path, framework="pt", device=str(device)) as f:
        metadata = f.metadata()
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)

    # metadata() returns None when the file carries no metadata; fail with a
    # clear message instead of a TypeError on subscripting.
    if not metadata:
        raise ValueError(
            f"No metadata found in {model_path}; cannot infer model configuration"
        )

    try:
        codings_size = int(metadata['codings_size'])
        image_size = int(metadata['image_size'])
        image_channels = int(metadata['image_channels'])
    except KeyError as err:
        raise ValueError(f"Missing metadata key in {model_path}: {err}") from err

    # Rebuild the training-time architecture and install the weights.
    model = Generator(codings_size, image_size, image_channels)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference only; BatchNorm uses running statistics

    print(f"Model configuration: codings_size={codings_size}, image_size={image_size}, image_channels={image_channels}")

    return model, codings_size, image_size, image_channels

def generate_images(model, num_images, codings_size=100, seed=None, device='cpu'):
    """
    Sample a batch of images from a trained GAN generator.

    Args:
        model: Loaded PyTorch generator model
        num_images: Number of images to generate
        codings_size: Size of the latent vector (default: 100)
        seed: Random seed for reproducibility
        device: Device to run generation on

    Returns:
        Generated images as a numpy array in HWC layout, scaled to [0, 1]
    """
    # Seed both RNGs so the same seed reproduces the same batch.
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Draw latent codes and run the generator without tracking gradients.
    latent = torch.randn(num_images, codings_size, device=device)
    with torch.no_grad():
        batch = model(latent)

    # NCHW -> NHWC for plotting, then rescale tanh output [-1, 1] -> [0, 1].
    images = batch.permute(0, 2, 3, 1).cpu().numpy()
    return (images + 1) / 2

def save_image_grid(images, output_path, grid_size=None):
    """
    Save generated images as a grid visualization.

    Args:
        images: Array of generated images, shape (N, H, W, C)
        output_path: Path to save the grid image
        grid_size: Optional grid size (rows, cols). If None, auto-calculate square grid
    """
    num_images = images.shape[0]

    # Guard the empty case: int(sqrt(0)) == 0 would divide by zero below.
    if num_images == 0:
        print("No images to save; skipping grid.")
        return

    if grid_size is None:
        # Auto-calculate a roughly square grid.
        grid_rows = max(1, int(np.sqrt(num_images)))
        grid_cols = int(np.ceil(num_images / grid_rows))
    else:
        grid_rows, grid_cols = grid_size

    plt.figure(figsize=(grid_cols * 2, grid_rows * 2))

    for i in range(min(num_images, grid_rows * grid_cols)):
        plt.subplot(grid_rows, grid_cols, i + 1)

        if images.shape[-1] == 1:
            # imshow rejects (H, W, 1); drop the channel axis and map to gray.
            plt.imshow(images[i, :, :, 0], cmap='gray')
        else:
            # RGB or RGBA are handled directly.
            plt.imshow(images[i])

        plt.axis('off')

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()

def save_individual_images(images, output_dir, prefix="generated"):
    """
    Save each generated image as a separate PNG file.

    Args:
        images: Array of generated images, shape (N, H, W, C), values in [0, 1]
        output_dir: Directory to save individual images
        prefix: Prefix for image filenames
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, img in enumerate(images):
        # Convert to uint8 format (0-255) for image export.
        img_uint8 = (img * 255).astype(np.uint8)

        output_path = output_dir / f"{prefix}_{i:04d}.png"
        if img_uint8.shape[-1] == 1:
            # plt.imsave rejects (H, W, 1) arrays — squeeze the channel axis
            # and render single-channel output with a gray colormap.
            plt.imsave(output_path, img_uint8[:, :, 0], cmap='gray')
        else:
            # RGB / RGBA arrays are handled directly.
            plt.imsave(output_path, img_uint8)

    print(f"Saved {len(images)} individual images to: {output_dir}")


# ============ TKINTER UI MODE ============

def run_gui(model_path, output_path):
    """
    Run a Tkinter GUI for interactive image generation.

    The GUI renders each generated grid to `output_path` via matplotlib and
    reloads that PNG into the window.

    Args:
        model_path: Path to the trained generator (.safetensors file)
        output_path: PNG path used as the render/display buffer
    """
    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model once at startup; bail out with a traceback on failure.
    print(f"Loading model from: {model_path}")
    try:
        model, codings_size, image_size, image_channels = load_model(model_path, device)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # Ensure the render target's directory exists.
    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize Tkinter window with four equal-width button columns.
    root = Tk()
    root.title("CryptoPunk Generator")
    root.columnconfigure([0, 1, 2, 3], minsize=200)

    # Create a placeholder image if no previous output exists.
    if not os.path.exists(output_path):
        plt.figure(figsize=(4, 4))
        plt.text(0.5, 0.5, 'Click a button to generate!', 
                ha='center', va='center', fontsize=16)
        plt.axis('off')
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

    # Load and display the initial image.
    img = ImageTk.PhotoImage(Image.open(output_path))
    panel = Label(root, image=img)
    panel.grid(row=1, columnspan=4, sticky="nsew")

    def update_img():
        """Reload output_path into the label."""
        new_img = ImageTk.PhotoImage(Image.open(output_path))
        panel.configure(image=new_img)
        # Keep a reference on the widget — otherwise Python GC would free the
        # PhotoImage and Tk would show a blank panel.
        panel.image = new_img

    def generate(grid_size):
        """Generate a grid_size x grid_size grid with a fresh random seed."""
        print(f"Generating {grid_size}x{grid_size} grid...")
        n_img = grid_size * grid_size
        seed = random.getrandbits(32)

        # Generate images (NHWC, values in [0, 1]).
        images = generate_images(model, n_img, codings_size, seed, device)

        # Render the grid to output_path.
        plt.figure(figsize=(8, 8))
        for i in range(n_img):
            plt.subplot(grid_size, grid_size, i + 1)
            if images.shape[-1] == 1:
                # imshow rejects (H, W, 1); squeeze and render as grayscale.
                plt.imshow(images[i, :, :, 0], cmap='gray')
            else:
                plt.imshow(images[i])
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"Generated with seed: {seed}")
        update_img()

    # Create buttons: three grid sizes plus quit.
    btn_1 = Button(root, text="Generate 1 cryptopunk", command=lambda: generate(1))
    btn_3 = Button(root, text="Generate 3x3 cryptopunks", command=lambda: generate(3))
    btn_5 = Button(root, text="Generate 5x5 cryptopunks", command=lambda: generate(5))
    btn_q = Button(root, text="Terminate", command=root.quit)

    btn_1.grid(row=0, column=0, sticky="nsew")
    btn_3.grid(row=0, column=1, sticky="nsew")
    btn_5.grid(row=0, column=2, sticky="nsew")
    btn_q.grid(row=0, column=3, sticky="nsew")

    print("\nGUI started! Click buttons to generate images.")
    root.mainloop()


# ============ CLI MODE ============

def run_cli(args):
    """
    Command-line entry point for batch image generation.

    Args:
        args: Parsed argparse namespace with model_path, output_path,
              num_images, grid_size, seed, save_individual and
              individual_output_dir attributes.
    """
    # Guard: the model file must exist before anything else happens.
    if not os.path.exists(args.model_path):
        print(f"Error: Model not found at {args.model_path}")
        print("Please train the model first using trainer.py")
        return

    # Prefer GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the trained model; report and abort on failure.
    print(f"Loading model from: {args.model_path}")
    try:
        model, codings_size, image_size, image_channels = load_model(args.model_path, device)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    # Decide the batch size and whether a fixed NxN grid was requested.
    grid_size = None
    if args.grid_size is not None:
        num_images = args.grid_size * args.grid_size
        grid_size = (args.grid_size, args.grid_size)
        print(f"Generating {num_images} images in a {args.grid_size}x{args.grid_size} grid")
    else:
        num_images = args.num_images
        print(f"Generating {num_images} images")

    print("Generating images...")
    images = generate_images(model, num_images, codings_size, args.seed, device)
    print(f"Generated images shape: {images.shape}")
    print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")

    # Make sure the output directory exists, then write the grid.
    Path(args.output_path).parent.mkdir(parents=True, exist_ok=True)
    save_image_grid(images, args.output_path, grid_size)
    print(f"Grid image saved to: {args.output_path}")

    # Optionally write one file per image.
    if args.save_individual:
        save_individual_images(images, args.individual_output_dir)

    print("\nGeneration complete!")
    if args.seed is not None:
        print(f"Seed used: {args.seed} (use same seed to reproduce these images)")


# ============ MAIN ============

def main():
    """Parse command-line arguments and dispatch to GUI or CLI mode."""
    parser = argparse.ArgumentParser(description="Generate images using trained GAN model")
    parser.add_argument("--gui", action="store_true",
                        help="Launch Tkinter GUI interface (default if no other args)")
    parser.add_argument("--model_path", type=str, default="./models/generator_model.safetensors",
                        help="Path to the trained generator model (.safetensors file)")
    parser.add_argument("--output_path", type=str, default="./generated/output.png",
                        help="Path to save the generated image grid")
    parser.add_argument("--num_images", type=int, default=16,
                        help="Number of images to generate (CLI mode, default: 16)")
    parser.add_argument("--grid_size", type=int, default=None,
                        help="Grid size N for NxN layout (CLI mode)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducibility (CLI mode only)")
    parser.add_argument("--save_individual", action="store_true",
                        help="Save each generated image as a separate file (CLI mode)")
    parser.add_argument("--individual_output_dir", type=str, default="./generated/individual/",
                        help="Directory to save individual images (CLI mode)")

    args = parser.parse_args()

    # GUI is the default; CLI only runs when a CLI-specific option deviates
    # from its default (note: --num_images 16 explicitly still counts as GUI).
    wants_cli = any((
        args.grid_size is not None,
        args.num_images != 16,
        args.seed is not None,
        args.save_individual,
    ))

    if args.gui or not wants_cli:
        run_gui(args.model_path, args.output_path)
    else:
        run_cli(args)


# Script entry point: dispatch via main() so the module stays importable.
if __name__ == "__main__":
    main()