chatpig commited on
Commit
409857b
·
verified ·
1 Parent(s): 02bacb8

Upload 2 files

Browse files
Files changed (2) hide show
  1. generator.py +360 -0
  2. trainer.py +244 -0
generator.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import matplotlib.pyplot as plt
7
+ from pathlib import Path
8
+ from tkinter import *
9
+ from PIL import ImageTk, Image
10
+ import random
11
+ from safetensors.torch import load_file
12
+
13
+ # Generator model definition (must match the training architecture)
14
+ class Generator(nn.Module):
15
+ def __init__(self, codings_size, image_size, image_channels):
16
+ super(Generator, self).__init__()
17
+
18
+ self.fc = nn.Linear(codings_size, 6 * 6 * 256, bias=False)
19
+ self.bn1 = nn.BatchNorm1d(6 * 6 * 256)
20
+ self.leaky_relu = nn.LeakyReLU(0.2)
21
+
22
+ self.conv_transpose1 = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
23
+ self.bn2 = nn.BatchNorm2d(128)
24
+
25
+ self.conv_transpose2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
26
+ self.bn3 = nn.BatchNorm2d(64)
27
+
28
+ self.conv_transpose3 = nn.ConvTranspose2d(64, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
29
+ self.tanh = nn.Tanh()
30
+
31
+ def forward(self, x):
32
+ x = self.fc(x)
33
+ x = self.bn1(x)
34
+ x = self.leaky_relu(x)
35
+ x = x.view(-1, 256, 6, 6)
36
+
37
+ x = self.conv_transpose1(x)
38
+ x = self.bn2(x)
39
+ x = self.leaky_relu(x)
40
+
41
+ x = self.conv_transpose2(x)
42
+ x = self.bn3(x)
43
+ x = self.leaky_relu(x)
44
+
45
+ x = self.conv_transpose3(x)
46
+ x = self.tanh(x)
47
+
48
+ return x
49
+
50
+ def load_model(model_path, device='cpu'):
51
+ """
52
+ Load the trained generator model from safetensors format.
53
+
54
+ Args:
55
+ model_path: Path to the .safetensors model file
56
+ device: Device to load the model on ('cpu' or 'cuda')
57
+
58
+ Returns:
59
+ Loaded generator model and configuration parameters
60
+ """
61
+ # Load state dict and metadata from safetensors
62
+ state_dict = load_file(model_path)
63
+
64
+ # Load metadata from safetensors file
65
+ from safetensors import safe_open
66
+ with safe_open(model_path, framework="pt", device=str(device)) as f:
67
+ metadata = f.metadata()
68
+
69
+ # Extract model configuration from metadata
70
+ codings_size = int(metadata['codings_size'])
71
+ image_size = int(metadata['image_size'])
72
+ image_channels = int(metadata['image_channels'])
73
+
74
+ # Create generator model
75
+ model = Generator(codings_size, image_size, image_channels)
76
+ model.load_state_dict(state_dict)
77
+ model.to(device)
78
+ model.eval()
79
+
80
+ print(f"Model configuration: codings_size={codings_size}, image_size={image_size}, image_channels={image_channels}")
81
+
82
+ return model, codings_size, image_size, image_channels
83
+
84
+ def generate_images(model, num_images, codings_size=100, seed=None, device='cpu'):
85
+ """
86
+ Generate images using the trained GAN generator model.
87
+
88
+ Args:
89
+ model: Loaded PyTorch generator model
90
+ num_images: Number of images to generate
91
+ codings_size: Size of the latent vector (default: 100)
92
+ seed: Random seed for reproducibility
93
+ device: Device to run generation on
94
+
95
+ Returns:
96
+ Generated images as numpy array (scaled to [0, 1])
97
+ """
98
+ if seed is not None:
99
+ torch.manual_seed(seed)
100
+ np.random.seed(seed)
101
+
102
+ # Generate random noise as input
103
+ noise = torch.randn(num_images, codings_size, device=device)
104
+
105
+ # Generate images
106
+ with torch.no_grad():
107
+ generated_images = model(noise)
108
+
109
+ # Convert from CHW to HWC format and scale from [-1, 1] to [0, 1]
110
+ generated_images = generated_images.permute(0, 2, 3, 1).cpu().numpy()
111
+ generated_images = (generated_images + 1) / 2 # Scale to [0, 1]
112
+
113
+ return generated_images
114
+
115
+ def save_image_grid(images, output_path, grid_size=None):
116
+ """
117
+ Save generated images as a grid visualization.
118
+
119
+ Args:
120
+ images: Array of generated images
121
+ output_path: Path to save the grid image
122
+ grid_size: Optional grid size (rows, cols). If None, auto-calculate square grid
123
+ """
124
+ num_images = images.shape[0]
125
+
126
+ if grid_size is None:
127
+ # Auto-calculate square grid
128
+ grid_rows = int(np.sqrt(num_images))
129
+ grid_cols = int(np.ceil(num_images / grid_rows))
130
+ else:
131
+ grid_rows, grid_cols = grid_size
132
+
133
+ fig = plt.figure(figsize=(grid_cols * 2, grid_rows * 2))
134
+
135
+ for i in range(min(num_images, grid_rows * grid_cols)):
136
+ plt.subplot(grid_rows, grid_cols, i + 1)
137
+
138
+ # Handle different image formats
139
+ if images.shape[-1] == 1:
140
+ # Grayscale
141
+ plt.imshow(images[i, :, :, 0], cmap='gray')
142
+ else:
143
+ # RGB or RGBA
144
+ plt.imshow(images[i])
145
+
146
+ plt.axis('off')
147
+
148
+ plt.tight_layout()
149
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
150
+ plt.close()
151
+
152
+ def save_individual_images(images, output_dir, prefix="generated"):
153
+ """
154
+ Save each generated image as a separate file.
155
+
156
+ Args:
157
+ images: Array of generated images
158
+ output_dir: Directory to save individual images
159
+ prefix: Prefix for image filenames
160
+ """
161
+ output_dir = Path(output_dir)
162
+ output_dir.mkdir(parents=True, exist_ok=True)
163
+
164
+ for i, img in enumerate(images):
165
+ # Convert to uint8 format (0-255)
166
+ img_uint8 = (img * 255).astype(np.uint8)
167
+
168
+ # Save using matplotlib to handle RGBA correctly
169
+ output_path = output_dir / f"{prefix}_{i:04d}.png"
170
+ plt.imsave(output_path, img_uint8)
171
+
172
+ print(f"Saved {len(images)} individual images to: {output_dir}")
173
+
174
+
175
+ # ============ TKINTER UI MODE ============
176
+
177
+ def run_gui(model_path, output_path):
178
+ """
179
+ Run Tkinter GUI for interactive image generation.
180
+ """
181
+ # Set device
182
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
183
+ print(f"Using device: {device}")
184
+
185
+ # Load model once at startup
186
+ print(f"Loading model from: {model_path}")
187
+ try:
188
+ model, codings_size, image_size, image_channels = load_model(model_path, device)
189
+ print("Model loaded successfully!")
190
+ except Exception as e:
191
+ print(f"Error loading model: {e}")
192
+ import traceback
193
+ traceback.print_exc()
194
+ return
195
+
196
+ # Create output directory
197
+ output_dir = Path(output_path).parent
198
+ output_dir.mkdir(parents=True, exist_ok=True)
199
+
200
+ # Initialize Tkinter window
201
+ root = Tk()
202
+ root.title("CryptoPunk Generator")
203
+ root.columnconfigure([0, 1, 2, 3], minsize=200)
204
+
205
+ # Create a placeholder image if output doesn't exist
206
+ if not os.path.exists(output_path):
207
+ fig = plt.figure(figsize=(4, 4))
208
+ plt.text(0.5, 0.5, 'Click a button to generate!',
209
+ ha='center', va='center', fontsize=16)
210
+ plt.axis('off')
211
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
212
+ plt.close()
213
+
214
+ # Load and display initial image
215
+ img = ImageTk.PhotoImage(Image.open(output_path))
216
+ panel = Label(root, image=img)
217
+ panel.grid(row=1, columnspan=4, sticky="nsew")
218
+
219
+ def update_img():
220
+ """Update the displayed image"""
221
+ new_img = ImageTk.PhotoImage(Image.open(output_path))
222
+ panel.configure(image=new_img)
223
+ panel.image = new_img
224
+
225
+ def generate(grid_size):
226
+ """Generate images in a grid"""
227
+ print(f"Generating {grid_size}x{grid_size} grid...")
228
+ n_img = grid_size * grid_size
229
+ seed = random.getrandbits(32)
230
+
231
+ # Generate images
232
+ images = generate_images(model, n_img, codings_size, seed, device)
233
+
234
+ # Create grid visualization
235
+ fig = plt.figure(figsize=(8, 8))
236
+ for i in range(n_img):
237
+ plt.subplot(grid_size, grid_size, i + 1)
238
+ plt.imshow(images[i, :, :, :])
239
+ plt.axis('off')
240
+ plt.tight_layout()
241
+ plt.savefig(output_path, dpi=150, bbox_inches='tight')
242
+ plt.close()
243
+
244
+ print(f"Generated with seed: {seed}")
245
+ update_img()
246
+
247
+ # Create buttons
248
+ btn_1 = Button(root, text="Generate 1 cryptopunk", command=lambda: generate(1))
249
+ btn_3 = Button(root, text="Generate 3x3 cryptopunks", command=lambda: generate(3))
250
+ btn_5 = Button(root, text="Generate 5x5 cryptopunks", command=lambda: generate(5))
251
+ btn_q = Button(root, text="Terminate", command=root.quit)
252
+
253
+ btn_1.grid(row=0, column=0, sticky="nsew")
254
+ btn_3.grid(row=0, column=1, sticky="nsew")
255
+ btn_5.grid(row=0, column=2, sticky="nsew")
256
+ btn_q.grid(row=0, column=3, sticky="nsew")
257
+
258
+ print("\nGUI started! Click buttons to generate images.")
259
+ root.mainloop()
260
+
261
+
262
+ # ============ CLI MODE ============
263
+
264
+ def run_cli(args):
265
+ """
266
+ Run command-line interface for batch image generation.
267
+ """
268
+ # Check if model exists
269
+ if not os.path.exists(args.model_path):
270
+ print(f"Error: Model not found at {args.model_path}")
271
+ print("Please train the model first using trainer.py")
272
+ return
273
+
274
+ # Set device
275
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
276
+ print(f"Using device: {device}")
277
+
278
+ # Load the trained model
279
+ print(f"Loading model from: {args.model_path}")
280
+ try:
281
+ model, codings_size, image_size, image_channels = load_model(args.model_path, device)
282
+ print("Model loaded successfully!")
283
+ except Exception as e:
284
+ print(f"Error loading model: {e}")
285
+ import traceback
286
+ traceback.print_exc()
287
+ return
288
+
289
+ # Calculate actual number of images for grid
290
+ if args.grid_size is not None:
291
+ num_images = args.grid_size * args.grid_size
292
+ grid_size = (args.grid_size, args.grid_size)
293
+ print(f"Generating {num_images} images in a {args.grid_size}x{args.grid_size} grid")
294
+ else:
295
+ num_images = args.num_images
296
+ grid_size = None
297
+ print(f"Generating {num_images} images")
298
+
299
+ # Generate images
300
+ print("Generating images...")
301
+ images = generate_images(model, num_images, codings_size, args.seed, device)
302
+ print(f"Generated images shape: {images.shape}")
303
+ print(f"Image value range: [{images.min():.3f}, {images.max():.3f}]")
304
+
305
+ # Create output directory if needed
306
+ output_dir = Path(args.output_path).parent
307
+ output_dir.mkdir(parents=True, exist_ok=True)
308
+
309
+ # Save grid visualization
310
+ save_image_grid(images, args.output_path, grid_size)
311
+ print(f"Grid image saved to: {args.output_path}")
312
+
313
+ # Optionally save individual images
314
+ if args.save_individual:
315
+ save_individual_images(images, args.individual_output_dir)
316
+
317
+ print("\nGeneration complete!")
318
+ if args.seed is not None:
319
+ print(f"Seed used: {args.seed} (use same seed to reproduce these images)")
320
+
321
+
322
+ # ============ MAIN ============
323
+
324
+ def main():
325
+ parser = argparse.ArgumentParser(description="Generate images using trained GAN model")
326
+ parser.add_argument("--gui", action="store_true",
327
+ help="Launch Tkinter GUI interface (default if no other args)")
328
+ parser.add_argument("--model_path", type=str, default="./models/generator_model.safetensors",
329
+ help="Path to the trained generator model (.safetensors file)")
330
+ parser.add_argument("--output_path", type=str, default="./generated/output.png",
331
+ help="Path to save the generated image grid")
332
+ parser.add_argument("--num_images", type=int, default=16,
333
+ help="Number of images to generate (CLI mode, default: 16)")
334
+ parser.add_argument("--grid_size", type=int, default=None,
335
+ help="Grid size N for NxN layout (CLI mode)")
336
+ parser.add_argument("--seed", type=int, default=None,
337
+ help="Random seed for reproducibility (CLI mode only)")
338
+ parser.add_argument("--save_individual", action="store_true",
339
+ help="Save each generated image as a separate file (CLI mode)")
340
+ parser.add_argument("--individual_output_dir", type=str, default="./generated/individual/",
341
+ help="Directory to save individual images (CLI mode)")
342
+
343
+ args = parser.parse_args()
344
+
345
+ # Determine mode: GUI if --gui flag or if no CLI-specific args provided
346
+ cli_args_provided = (args.grid_size is not None or
347
+ args.num_images != 16 or
348
+ args.seed is not None or
349
+ args.save_individual)
350
+
351
+ if args.gui or not cli_args_provided:
352
+ # GUI mode
353
+ run_gui(args.model_path, args.output_path)
354
+ else:
355
+ # CLI mode
356
+ run_cli(args)
357
+
358
+
359
+ if __name__ == "__main__":
360
+ main()
trainer.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import argparse
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ from torch.utils.data import Dataset, DataLoader
10
+ from torchvision import transforms
11
+ from PIL import Image
12
+ import matplotlib.pyplot as plt
13
+ from safetensors.torch import save_file
14
+
15
+ def plot_multiple_images(images, n_cols, epoch):
16
+ n_cols = n_cols or len(images)
17
+ n_rows = (len(images) - 1) // n_cols + 1
18
+ # Convert from CHW to HWC format for plotting
19
+ images = images.permute(0, 2, 3, 1).cpu().numpy()
20
+ if images.shape[-1] == 1:
21
+ images = np.squeeze(images, axis=-1)
22
+ plt.figure(figsize=(n_cols, n_rows))
23
+ for index, image in enumerate(images):
24
+ image = ((image + 1) / 2) # scale back
25
+ plt.subplot(n_rows, n_cols, index + 1)
26
+ plt.imshow(image, cmap="binary")
27
+ plt.axis("off")
28
+ plt.savefig(f'{args.images_output_path}epoch_{epoch}.png')
29
+ plt.close() # Close the figure to free memory
30
+
31
+ class ImageDataset(Dataset):
32
+ def __init__(self, file_paths, image_size, image_channels):
33
+ self.file_paths = file_paths
34
+ self.image_size = image_size
35
+ self.image_channels = image_channels
36
+ self.transform = transforms.Compose([
37
+ transforms.Resize((image_size, image_size)),
38
+ transforms.ToTensor(),
39
+ transforms.Normalize([0.5] * image_channels, [0.5] * image_channels) # Scale to [-1, 1]
40
+ ])
41
+
42
+ def __len__(self):
43
+ return len(self.file_paths)
44
+
45
+ def __getitem__(self, idx):
46
+ img_path = self.file_paths[idx]
47
+ image = Image.open(img_path).convert('RGBA' if self.image_channels == 4 else 'RGB')
48
+ image = self.transform(image)
49
+ return image
50
+
51
+ def get_dataloader(inputs, batch_size, image_size, image_channels):
52
+ if type(inputs) == dict:
53
+ file_paths = inputs["paths"].tolist()
54
+ else:
55
+ file_paths = glob.glob(f"{inputs}/*")
56
+
57
+ dataset = ImageDataset(file_paths, image_size, image_channels)
58
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=2)
59
+ return dataloader
60
+
61
+ def discriminator_loss(real_output, fake_output, criterion):
62
+ real_loss = criterion(real_output, torch.ones_like(real_output))
63
+ fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
64
+ total_loss = real_loss + fake_loss
65
+ return total_loss
66
+
67
+ def generator_loss(fake_output, criterion):
68
+ return criterion(fake_output, torch.ones_like(fake_output))
69
+
70
+ def train_step(images, batch_size, codings_size, generator, discriminator, gen_optimizer, disc_optimizer, criterion, device):
71
+ noise = torch.randn(batch_size, codings_size, device=device)
72
+
73
+ # Train Discriminator
74
+ disc_optimizer.zero_grad()
75
+ generated_images = generator(noise)
76
+ real_output = discriminator(images)
77
+ fake_output = discriminator(generated_images.detach())
78
+ disc_loss = discriminator_loss(real_output, fake_output, criterion)
79
+ disc_loss.backward()
80
+ disc_optimizer.step()
81
+
82
+ # Train Generator
83
+ gen_optimizer.zero_grad()
84
+ fake_output = discriminator(generated_images)
85
+ gen_loss = generator_loss(fake_output, criterion)
86
+ gen_loss.backward()
87
+ gen_optimizer.step()
88
+
89
+ return gen_loss.item(), disc_loss.item()
90
+
91
+ def train(dataloader, epochs, batch_size, codings_size, generator, discriminator, gen_optimizer, disc_optimizer, criterion, device):
92
+ generator.train()
93
+ discriminator.train()
94
+
95
+ for epoch in range(epochs):
96
+ for image_batch in dataloader:
97
+ image_batch = image_batch.to(device)
98
+ gen_loss, disc_loss = train_step(image_batch, batch_size, codings_size, generator, discriminator,
99
+ gen_optimizer, disc_optimizer, criterion, device)
100
+
101
+ print(f"Epoch {epoch+1}/{epochs} - Gen Loss: {gen_loss:.4f}, Disc Loss: {disc_loss:.4f}")
102
+ if args.images_output_path:
103
+ generator.eval()
104
+ with torch.no_grad():
105
+ noise = torch.randn(batch_size, codings_size, device=device)
106
+ display_images = generator(noise)
107
+ plot_multiple_images(display_images, 8, epoch)
108
+ generator.train()
109
+
110
+ class Generator(nn.Module):
111
+ def __init__(self, codings_size, image_size, image_channels):
112
+ super(Generator, self).__init__()
113
+
114
+ self.fc = nn.Linear(codings_size, 6 * 6 * 256, bias=False)
115
+ self.bn1 = nn.BatchNorm1d(6 * 6 * 256)
116
+ self.leaky_relu = nn.LeakyReLU(0.2)
117
+
118
+ self.conv_transpose1 = nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2, bias=False)
119
+ self.bn2 = nn.BatchNorm2d(128)
120
+
121
+ self.conv_transpose2 = nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
122
+ self.bn3 = nn.BatchNorm2d(64)
123
+
124
+ self.conv_transpose3 = nn.ConvTranspose2d(64, image_channels, kernel_size=5, stride=2, padding=2, output_padding=1, bias=False)
125
+ self.tanh = nn.Tanh()
126
+
127
+ def forward(self, x):
128
+ x = self.fc(x)
129
+ x = self.bn1(x)
130
+ x = self.leaky_relu(x)
131
+ x = x.view(-1, 256, 6, 6)
132
+
133
+ x = self.conv_transpose1(x)
134
+ x = self.bn2(x)
135
+ x = self.leaky_relu(x)
136
+
137
+ x = self.conv_transpose2(x)
138
+ x = self.bn3(x)
139
+ x = self.leaky_relu(x)
140
+
141
+ x = self.conv_transpose3(x)
142
+ x = self.tanh(x)
143
+
144
+ return x
145
+
146
+ class Discriminator(nn.Module):
147
+ def __init__(self, image_size, image_channels):
148
+ super(Discriminator, self).__init__()
149
+
150
+ self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=4, stride=2, padding=1)
151
+ self.leaky_relu1 = nn.LeakyReLU(0.2)
152
+ self.dropout1 = nn.Dropout(0.4)
153
+
154
+ self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
155
+ self.leaky_relu2 = nn.LeakyReLU(0.2)
156
+ self.dropout2 = nn.Dropout(0.4)
157
+
158
+ self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
159
+ self.leaky_relu3 = nn.LeakyReLU(0.2)
160
+ self.dropout3 = nn.Dropout(0.4)
161
+
162
+ self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
163
+ self.fc = nn.Linear(256, 1)
164
+ self.sigmoid = nn.Sigmoid()
165
+
166
+ def forward(self, x):
167
+ x = self.conv1(x)
168
+ x = self.leaky_relu1(x)
169
+ x = self.dropout1(x)
170
+
171
+ x = self.conv2(x)
172
+ x = self.leaky_relu2(x)
173
+ x = self.dropout2(x)
174
+
175
+ x = self.conv3(x)
176
+ x = self.leaky_relu3(x)
177
+ x = self.dropout3(x)
178
+
179
+ x = self.global_avg_pool(x)
180
+ x = x.view(x.size(0), -1)
181
+ x = self.fc(x)
182
+ x = self.sigmoid(x)
183
+
184
+ return x
185
+
186
+
187
+ if __name__ == "__main__":
188
+ parser = argparse.ArgumentParser()
189
+ parser.add_argument("--data_path", default="./data/attributes.csv", help="Path to dataset (attributes.csv)")
190
+ parser.add_argument("--images_path", default="./data/images/", help="Path to images")
191
+ parser.add_argument("--model_output_path", default="./models/", help="Path to output the generator model")
192
+ parser.add_argument("--images_output_path", default="./gen_images/", help="Path to output generated images during training")
193
+ parser.add_argument("--codings_size", type=int, default=100, help="Size of the latent z vector")
194
+ parser.add_argument("--image_size", type=int, default=24, help="Images size")
195
+ parser.add_argument("--image_channels", type=int, default=4, help="Images channels")
196
+ parser.add_argument("--batch_size", type=int, default=16, help="Input batch size")
197
+ parser.add_argument("--epochs", type=int, default=50, help="Number of epochs")
198
+ args = parser.parse_args()
199
+ print(args)
200
+
201
+ # Set device
202
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
203
+ print(f"Using device: {device}")
204
+
205
+ if args.images_output_path and (os.path.exists(args.images_output_path) == False):
206
+ print(f"Saving generated images during training at: {args.images_output_path}")
207
+ os.mkdir(args.images_output_path)
208
+
209
+ print("Loading the dataset...")
210
+ df = pd.read_csv(args.data_path)
211
+ df.id = df.id.apply(lambda x: f"{args.images_path}punk{x:03d}.png")
212
+
213
+ print("Creating PyTorch DataLoader...")
214
+ dataloader = get_dataloader({"paths": df.id}, args.batch_size, args.image_size, args.image_channels)
215
+
216
+ generator = Generator(args.codings_size, args.image_size, args.image_channels).to(device)
217
+ print("Generator architecture:")
218
+ print(generator)
219
+
220
+ discriminator = Discriminator(args.image_size, args.image_channels).to(device)
221
+ print("Discriminator architecture:")
222
+ print(discriminator)
223
+
224
+ gen_optimizer = optim.RMSprop(generator.parameters(), lr=0.001)
225
+ disc_optimizer = optim.RMSprop(discriminator.parameters(), lr=0.001)
226
+ criterion = nn.BCELoss()
227
+
228
+ print("Training model...")
229
+ train(dataloader, args.epochs, args.batch_size, args.codings_size, generator, discriminator,
230
+ gen_optimizer, disc_optimizer, criterion, device)
231
+
232
+ print(f"Saving model at: {args.model_output_path}...")
233
+ os.makedirs(args.model_output_path, exist_ok=True)
234
+ model_path = args.model_output_path if args.model_output_path.endswith('.safetensors') else os.path.join(args.model_output_path, 'generator_model.safetensors')
235
+
236
+ # Save the generator model in safetensors format
237
+ # Metadata is stored as strings in safetensors
238
+ metadata = {
239
+ 'codings_size': str(args.codings_size),
240
+ 'image_size': str(args.image_size),
241
+ 'image_channels': str(args.image_channels)
242
+ }
243
+ save_file(generator.state_dict(), model_path, metadata=metadata)
244
+ print(f"Model saved to: {model_path}")