| import os |
| import argparse |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import matplotlib.pyplot as plt |
| from pathlib import Path |
| from tkinter import * |
| from PIL import ImageTk, Image |
| import random |
| from safetensors.torch import load_file |
|
|
| |
class Generator(nn.Module):
    """Generator network that maps latent vectors to images.

    The latent vector is projected by a bias-free linear layer and reshaped
    into a 256-channel 6x6 feature map. Three transposed convolutions then
    reduce the channel count and upsample it (6x6 -> 6x6 -> 12x12 -> 24x24),
    and a final tanh bounds the output to [-1, 1].

    Args:
        codings_size: Length of the latent input vector.
        image_size: Accepted as part of the constructor signature; none of
            the layers defined here read it (the spatial sizes are fixed).
        image_channels: Number of channels in the generated image.
    """

    def __init__(self, codings_size, image_size, image_channels):
        super().__init__()

        # Attribute names are also the state_dict keys, so they must not
        # change: saved checkpoints are loaded by these names.
        base_features = 6 * 6 * 256
        self.fc = nn.Linear(codings_size, base_features, bias=False)
        self.bn1 = nn.BatchNorm1d(base_features)
        self.leaky_relu = nn.LeakyReLU(0.2)

        # Stride 1 with padding 2 and a 5x5 kernel keeps the 6x6 resolution.
        self.conv_transpose1 = nn.ConvTranspose2d(
            256, 128, kernel_size=5, stride=1, padding=2, bias=False
        )
        self.bn2 = nn.BatchNorm2d(128)

        # Both remaining layers share settings that double the resolution.
        doubling = dict(
            kernel_size=5, stride=2, padding=2, output_padding=1, bias=False
        )
        self.conv_transpose2 = nn.ConvTranspose2d(128, 64, **doubling)
        self.bn3 = nn.BatchNorm2d(64)

        self.conv_transpose3 = nn.ConvTranspose2d(64, image_channels, **doubling)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Turn a batch of latent vectors into a batch of images.

        Args:
            x: Tensor of shape (batch, codings_size).

        Returns:
            Tensor of shape (batch, image_channels, 24, 24) with values in
            [-1, 1].
        """
        # Project, normalise, activate, then fold into a 256x6x6 feature map.
        features = self.leaky_relu(self.bn1(self.fc(x)))
        features = features.view(-1, 256, 6, 6)

        # Two normalised/activated transposed-convolution stages.
        features = self.leaky_relu(self.bn2(self.conv_transpose1(features)))
        features = self.leaky_relu(self.bn3(self.conv_transpose2(features)))

        # Output stage: no normalisation, tanh squashes to [-1, 1].
        return self.tanh(self.conv_transpose3(features))
|
|
def load_model(model_path, device='cpu'):
    """
    Load the trained generator model from safetensors format.

    The checkpoint is opened once: its header metadata supplies the
    generator configuration and its tensors supply the weights.

    Args:
        model_path: Path to the .safetensors model file
        device: Device to load the model on ('cpu' or 'cuda')

    Returns:
        Tuple of (model, codings_size, image_size, image_channels), where
        model is the Generator in eval mode on the requested device.

    Raises:
        ValueError: If the checkpoint has no metadata, lacks one of the
            'codings_size', 'image_size' or 'image_channels' entries, or
            one of those entries is not an integer string.
    """
    from safetensors import safe_open

    required_keys = ('codings_size', 'image_size', 'image_channels')

    # Tensors are read on the CPU and moved by model.to(device) below.
    with safe_open(model_path, framework="pt", device="cpu") as f:
        # metadata() returns None when the file was saved without any.
        metadata = f.metadata() or {}
        missing = [key for key in required_keys if key not in metadata]
        if missing:
            raise ValueError(
                f"Checkpoint {model_path} is missing required metadata: "
                f"{', '.join(missing)}"
            )
        state_dict = {name: f.get_tensor(name) for name in f.keys()}

    # Metadata values are stored as strings in the safetensors header.
    codings_size = int(metadata['codings_size'])
    image_size = int(metadata['image_size'])
    image_channels = int(metadata['image_channels'])

    model = Generator(codings_size, image_size, image_channels)
    model.load_state_dict(state_dict)
    model.to(device)
    # Inference mode: batch norm uses its stored running statistics.
    model.eval()

    print(f"Model configuration: codings_size={codings_size}, image_size={image_size}, image_channels={image_channels}")

    return model, codings_size, image_size, image_channels
|
|
def generate_images(model, num_images, codings_size=100, seed=None, device='cpu'):
    """
    Sample latent noise and run it through the generator.

    Args:
        model: Callable generator (a loaded PyTorch model) taking a
            (num_images, codings_size) tensor
        num_images: How many images to produce
        codings_size: Length of each latent vector (default: 100)
        seed: Optional seed; when given, the global torch and numpy random
            generators are both seeded with it
        device: Device on which the noise is created

    Returns:
        Numpy array in channels-last layout, rescaled from the generator's
        [-1, 1] output range to [0, 1]
    """
    reproducible = seed is not None
    if reproducible:
        torch.manual_seed(seed)
        np.random.seed(seed)

    latent = torch.randn(num_images, codings_size, device=device)

    # Inference only: no gradients need to be tracked.
    with torch.no_grad():
        batch = model(latent)

    # (N, C, H, W) -> (N, H, W, C), then hand over to numpy on the CPU.
    channels_last = batch.permute(0, 2, 3, 1)
    as_numpy = channels_last.cpu().numpy()

    # Shift/scale [-1, 1] into [0, 1].
    return (as_numpy + 1) / 2
|
|
def save_image_grid(images, output_path, grid_size=None):
    """
    Save generated images as a grid visualization.

    Args:
        images: Array of generated images, indexed as
            (image, row, column, channel)
        output_path: Path to save the grid image
        grid_size: Optional grid size (rows, cols). If None, a near-square
            grid is computed from the number of images

    Notes:
        If the grid has fewer cells than there are images, only the first
        rows * cols images are drawn.
    """
    num_images = images.shape[0]

    if grid_size is None:
        # Clamp to at least 1x1 so an empty batch yields a valid layout
        # instead of dividing by zero.
        grid_rows = max(1, int(np.sqrt(num_images)))
        grid_cols = max(1, int(np.ceil(num_images / grid_rows)))
    else:
        grid_rows, grid_cols = grid_size

    # Single-channel images are drawn as 2-D arrays with a gray colormap.
    single_channel = images.shape[-1] == 1

    fig = plt.figure(figsize=(grid_cols * 2, grid_rows * 2))
    try:
        for i in range(min(num_images, grid_rows * grid_cols)):
            ax = fig.add_subplot(grid_rows, grid_cols, i + 1)

            if single_channel:
                ax.imshow(images[i, :, :, 0], cmap='gray')
            else:
                ax.imshow(images[i])

            ax.axis('off')

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
|
|
def save_individual_images(images, output_dir, prefix="generated"):
    """
    Save each generated image as a separate file.

    Files are written as "<prefix>_<index>.png" with a zero-padded
    four-digit index; the output directory is created if needed.

    Args:
        images: Array of generated images with values expected in [0, 1]
        output_dir: Directory to save individual images
        prefix: Prefix for image filenames
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    for i, img in enumerate(images):
        # Clip before casting: values outside [0, 1] would otherwise wrap
        # around when converted to uint8.
        img_uint8 = (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)

        output_path = output_dir / f"{prefix}_{i:04d}.png"

        if img_uint8.ndim == 3 and img_uint8.shape[-1] == 1:
            # imsave rejects (H, W, 1) arrays, so drop the channel axis and
            # save as grayscale. The fixed vmin/vmax stop it from
            # stretching each image's contrast to its own min/max.
            plt.imsave(output_path, img_uint8[:, :, 0], cmap='gray', vmin=0, vmax=255)
        else:
            plt.imsave(output_path, img_uint8)

    print(f"Saved {len(images)} individual images to: {output_dir}")
|
|
|
|
| |
|
|
def run_gui(model_path, output_path):
    """
    Open a Tkinter window for generating image grids on demand.

    The window has three generate buttons (1, 3x3 and 5x5 grids) and a
    terminate button. Each generation overwrites output_path with the new
    grid and refreshes the image shown in the window.

    Args:
        model_path: Path to the generator checkpoint to load
        output_path: Image file that is displayed and overwritten
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print(f"Loading model from: {model_path}")
    try:
        model, codings_size, image_size, image_channels = load_model(model_path, device)
        print("Model loaded successfully!")
    except Exception as exc:
        # Report the failure and leave without opening a window.
        print(f"Error loading model: {exc}")
        import traceback
        traceback.print_exc()
        return

    # Make sure the directory for the output image exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    root = Tk()
    root.title("CryptoPunk Generator")
    root.columnconfigure([0, 1, 2, 3], minsize=200)

    # With no previous output on disk, render a placeholder to display.
    if not os.path.exists(output_path):
        plt.figure(figsize=(4, 4))
        plt.text(0.5, 0.5, 'Click a button to generate!',
                 ha='center', va='center', fontsize=16)
        plt.axis('off')
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

    def load_photo():
        """Read the current output image as a Tk-compatible photo."""
        return ImageTk.PhotoImage(Image.open(output_path))

    img = load_photo()
    panel = Label(root, image=img)
    panel.grid(row=1, columnspan=4, sticky="nsew")

    def update_img():
        """Update the displayed image"""
        refreshed = load_photo()
        panel.configure(image=refreshed)
        # Keep a reference on the widget so the photo is not garbage
        # collected while it is being displayed.
        panel.image = refreshed

    def generate(grid_size):
        """Generate images in a grid"""
        print(f"Generating {grid_size}x{grid_size} grid...")
        seed = random.getrandbits(32)
        n_img = grid_size * grid_size

        images = generate_images(model, n_img, codings_size, seed, device)

        plt.figure(figsize=(8, 8))
        for cell, image in enumerate(images, start=1):
            plt.subplot(grid_size, grid_size, cell)
            plt.imshow(image)
            plt.axis('off')
        plt.tight_layout()
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close()

        print(f"Generated with seed: {seed}")
        update_img()

    # One button per column of row 0, in left-to-right order.
    button_specs = [
        ("Generate 1 cryptopunk", lambda: generate(1)),
        ("Generate 3x3 cryptopunks", lambda: generate(3)),
        ("Generate 5x5 cryptopunks", lambda: generate(5)),
        ("Terminate", root.quit),
    ]
    for column, (label, action) in enumerate(button_specs):
        Button(root, text=label, command=action).grid(row=0, column=column, sticky="nsew")

    print("\nGUI started! Click buttons to generate images.")
    root.mainloop()
|
|
|
|
| |
|
|
def run_cli(args):
    """
    Run command-line interface for batch image generation.

    Generates either args.grid_size**2 images (when --grid_size is given)
    or args.num_images images, saves them as one grid image at
    args.output_path and, if args.save_individual is set, as separate
    files in args.individual_output_dir.

    Invalid counts, a missing model file and model loading failures are
    reported on stdout and make the function return early.

    Args:
        args: Parsed command-line namespace with model_path, output_path,
            num_images, grid_size, seed, save_individual and
            individual_output_dir attributes
    """
    # Validate the requested amount first: a non-positive count would
    # otherwise fail much later with an unrelated error, and a negative
    # grid size would even yield a positive image count.
    if args.grid_size is not None:
        if args.grid_size < 1:
            print(f"Error: --grid_size must be at least 1 (got {args.grid_size})")
            return
        num_images = args.grid_size * args.grid_size
        grid_size = (args.grid_size, args.grid_size)
    else:
        if args.num_images < 1:
            print(f"Error: --num_images must be at least 1 (got {args.num_images})")
            return
        num_images = args.num_images
        grid_size = None

    if not os.path.exists(args.model_path):
        print(f"Error: Model not found at {args.model_path}")
        print("Please train the model first using trainer.py")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print(f"Loading model from: {args.model_path}")
    try:
        model, codings_size, image_size, image_channels = load_model(args.model_path, device)
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return

    if grid_size is not None:
        print(f"Generating {num_images} images in a {args.grid_size}x{args.grid_size} grid")
    else:
        print(f"Generating {num_images} images")

    print("Generating images...")
    images = generate_images(model, num_images, codings_size, args.seed, device)
    print(f"Generated images shape: {images.shape}")
    print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")

    # Make sure the directory for the grid image exists.
    output_dir = Path(args.output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)

    save_image_grid(images, args.output_path, grid_size)
    print(f"Grid image saved to: {args.output_path}")

    if args.save_individual:
        save_individual_images(images, args.individual_output_dir)

    print("\nGeneration complete!")
    if args.seed is not None:
        print(f"Seed used: {args.seed} (use same seed to reproduce these images)")
|
|
|
|
| |
|
|
def main():
    """
    Parse command-line arguments and start the GUI or the CLI.

    CLI mode is chosen when any of --grid_size, --num_images, --seed or
    --save_individual is given; otherwise, or whenever --gui is passed,
    the Tkinter GUI is launched.
    """
    default_num_images = 16

    parser = argparse.ArgumentParser(description="Generate images using trained GAN model")
    parser.add_argument("--gui", action="store_true",
                        help="Launch Tkinter GUI interface (default if no other args)")
    parser.add_argument("--model_path", type=str, default="./models/generator_model.safetensors",
                        help="Path to the trained generator model (.safetensors file)")
    parser.add_argument("--output_path", type=str, default="./generated/output.png",
                        help="Path to save the generated image grid")
    # Parsed with a None default so an explicit "--num_images 16" can be
    # told apart from the option being absent; the real default is filled
    # in below.
    parser.add_argument("--num_images", type=int, default=None,
                        help="Number of images to generate (CLI mode, default: 16)")
    parser.add_argument("--grid_size", type=int, default=None,
                        help="Grid size N for NxN layout (CLI mode)")
    parser.add_argument("--seed", type=int, default=None,
                        help="Random seed for reproducibility (CLI mode only)")
    parser.add_argument("--save_individual", action="store_true",
                        help="Save each generated image as a separate file (CLI mode)")
    parser.add_argument("--individual_output_dir", type=str, default="./generated/individual/",
                        help="Directory to save individual images (CLI mode)")

    args = parser.parse_args()

    # Any CLI-only option on the command line selects CLI mode.
    cli_args_provided = (args.grid_size is not None or
                         args.num_images is not None or
                         args.seed is not None or
                         args.save_individual)

    # Downstream code always sees an integer image count.
    if args.num_images is None:
        args.num_images = default_num_images

    if args.gui or not cli_args_provided:
        run_gui(args.model_path, args.output_path)
    else:
        run_cli(args)
|
|
|
|
# Run the entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|