LH-Tech-AI commited on
Commit
2e4ca0d
·
verified ·
1 Parent(s): 2fb19e5

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +243 -0
train.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ import torchvision.transforms as transforms
5
+ from torch.utils.data import DataLoader, Dataset
6
+ from datasets import load_dataset
7
+ import matplotlib
8
+ matplotlib.use('Agg')
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ from torch.cuda.amp import autocast, GradScaler
12
+ import torchvision.utils as vutils
13
+ from IPython.display import display
14
+
15
# --- FaceGen v1 Config ---
BATCH_SIZE = 128   # samples per training batch
IMAGE_SIZE = 128   # images are resized/cropped to IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 3       # RGB
Z_DIM = 128        # latent (noise) vector dimensionality
FEATURES_G = 256   # base channel width of the generator
FEATURES_D = 128   # base channel width of the discriminator
EPOCHS = 250       # total training epochs
LR = 0.0002        # Adam learning rate (DCGAN default)
BETA1 = 0.5        # Adam beta1 (DCGAN paper value)
25
+
26
# Prefer GPU when available; the AMP autocast calls below assume CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training will run on: {device}")

print("Loading face dataset...")
# Face images pulled from the HuggingFace hub (train split only).
hf_dataset = load_dataset("SDbiaseval/faces", split="train")
31
+
32
# Preprocessing: resize + center-crop to IMAGE_SIZE, random horizontal flip
# for augmentation, then scale pixels to [-1, 1] to match the generator's
# Tanh output range.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
39
+
40
class FaceDataset(Dataset):
    """Thin torch Dataset wrapper around a HuggingFace image dataset.

    Yields only the transformed RGB image tensor per item — the GAN is
    unconditional, so no labels are exposed.
    """

    def __init__(self, hf_ds, transform):
        self.hf_ds = hf_ds
        self.transform = transform

    def __len__(self):
        return len(self.hf_ds)

    def __getitem__(self, idx):
        # Force RGB in case the source image is grayscale or has alpha.
        rgb_image = self.hf_ds[idx]['image'].convert("RGB")
        return self.transform(rgb_image)
49
+
50
dataset = FaceDataset(hf_dataset, transform)

# drop_last keeps every batch at exactly BATCH_SIZE (BatchNorm is unstable
# on tiny trailing batches); pin_memory speeds host->GPU transfers.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True
)
print(f"Dataset ready with {len(dataset)} faces.")
61
+
62
+ class Generator(nn.Module):
63
+ def __init__(self, z_dim, channels, features_g):
64
+ super(Generator, self).__init__()
65
+ self.net = nn.Sequential(
66
+ # Input: Z_DIM x 1 x 1
67
+ nn.ConvTranspose2d(z_dim, features_g * 16, 4, 1, 0, bias=False),
68
+ nn.BatchNorm2d(features_g * 16),
69
+ nn.ReLU(True),
70
+ # 4x4 -> 8x8
71
+ nn.ConvTranspose2d(features_g * 16, features_g * 8, 4, 2, 1, bias=False),
72
+ nn.BatchNorm2d(features_g * 8),
73
+ nn.ReLU(True),
74
+ # 8x8 -> 16x16
75
+ nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
76
+ nn.BatchNorm2d(features_g * 4),
77
+ nn.ReLU(True),
78
+ # 16x16 -> 32x32
79
+ nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
80
+ nn.BatchNorm2d(features_g * 2),
81
+ nn.ReLU(True),
82
+ # 32x32 -> 64x64
83
+ nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
84
+ nn.BatchNorm2d(features_g),
85
+ nn.ReLU(True),
86
+ # 64x64 -> 128x128
87
+ nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
88
+ nn.Tanh()
89
+ )
90
+
91
+ def forward(self, x):
92
+ return self.net(x)
93
+
94
+ netG = Generator(Z_DIM, CHANNELS, FEATURES_G).to(device)
95
+
96
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Maps images of shape (N, channels, 128, 128) to one raw logit per
    sample, shape (N, 1, 1, 1). No sigmoid — pair with BCEWithLogitsLoss.
    """

    def __init__(self, channels, features_d):
        super(Discriminator, self).__init__()
        widths = [features_d * m for m in (1, 2, 4, 8, 16)]

        layers = [
            # First stage deliberately has no BatchNorm (DCGAN recipe).
            # 128x128 -> 64x64
            nn.Conv2d(channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for c_in, c_out in zip(widths, widths[1:]):
            # Each stage halves H and W: 64 -> 32 -> 16 -> 8 -> 4.
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the final 4x4 map into a single logit.
        layers.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))

        # Same module order/count as a hand-written Sequential, so
        # state_dict keys are net.0 ... net.14.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
125
+
126
+ netD = Discriminator(CHANNELS, FEATURES_D).to(device)
127
+
128
def weights_init(m):
    """DCGAN weight initialization (Radford et al.).

    Conv/ConvTranspose weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02)
    with zero bias. Intended to be applied with ``net.apply(weights_init)``.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
135
+
136
# Apply DCGAN-style weight initialization to both networks.
netG.apply(weights_init)
netD.apply(weights_init)

# Discriminator emits raw logits, so use the logit-aware BCE loss.
criterion = nn.BCEWithLogitsLoss()

optG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))
optD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))

# Fixed latent batch reused every epoch so sample grids are comparable.
fixed_noise = torch.randn(64, Z_DIM, 1, 1, device=device)

# Gradient scaler for mixed-precision (AMP) training.
scaler = torch.amp.GradScaler('cuda')

print(f"Model size G: {sum(p.numel() for p in netG.parameters())/1e6:.2f}M parameters")
print(f"Model size D: {sum(p.numel() for p in netD.parameters())/1e6:.2f}M parameters")

# One-sided label smoothing: targets 0.9 / 0.1 instead of hard 1 / 0,
# which softens the discriminator and stabilizes GAN training.
real_label_val = 0.9
fake_label_val = 0.1
153
+
154
# --- Main training loop ---
for epoch in range(EPOCHS):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to(device)
        b_size = real_images.size(0)

        # --- Discriminator Update ---
        # Push D's scores toward real_label_val on real images and
        # fake_label_val on generated images.
        optD.zero_grad()
        with torch.amp.autocast('cuda'):
            output_real = netD(real_images).view(-1)
            lossD_real = criterion(output_real, torch.full((b_size,), real_label_val, device=device))

            noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
            fake_images = netG(noise)
            # detach() blocks gradients from flowing into G during D's step.
            output_fake = netD(fake_images.detach()).view(-1)
            lossD_fake = criterion(output_fake, torch.full((b_size,), fake_label_val, device=device))
            lossD = lossD_real + lossD_fake

        scaler.scale(lossD).backward()
        scaler.step(optD)

        # --- Generator Update ---
        # Train G so that D scores its (non-detached) fakes as real.
        optG.zero_grad()
        with torch.amp.autocast('cuda'):
            output_fake_G = netD(fake_images).view(-1)
            lossG = criterion(output_fake_G, torch.full((b_size,), real_label_val, device=device))

        scaler.scale(lossG).backward()
        scaler.step(optG)
        # One scaler.update() per iteration covers both optimizer steps.
        scaler.update()

        if i % 10 == 0:
            print(f"E[{epoch}] I[{i}/{len(dataloader)}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")

    # Save a single preview sample on the first epoch and every 10th after.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        netG.eval()
        with torch.no_grad():
            with torch.amp.autocast('cuda'):
                sample = netG(fixed_noise[0:1]).detach().cpu().float()

        vutils.save_image(sample, f"face_sample_epoch_{epoch}.png", normalize=True)
        print(f"--> Sample saved: face_sample_epoch_{epoch}.png")

        netG.train()

    # Periodic safety checkpoint with full training state every 50 epochs.
    if (epoch + 1) % 50 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': netG.state_dict(),
            'optimizer_state_dict': optG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            'optD_state_dict': optD.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
        }, f'facegen_v1_checkpoint_epoch_{epoch+1}.ckpt')
        print(f"--> Sicherheits-Checkpoint gespeichert: Epoche {epoch+1}")
208
+
209
# Final full checkpoint: G + D + both optimizers + AMP scaler state, so
# training could be resumed exactly from here.
torch.save({
    'epoch': EPOCHS,
    'model_state_dict': netG.state_dict(),
    'optimizer_state_dict': optG.state_dict(),
    'netD_state_dict': netD.state_dict(),
    'optD_state_dict': optD.state_dict(),
    'scaler_state_dict': scaler.state_dict(),
}, 'facegen_v1_full_checkpoint.ckpt')

# Generator-only weights for lightweight inference/deployment.
torch.save(netG.state_dict(), 'facegen_v1_generator_only.pth')

print("Files saved: Training finished.")

print("Doing professionell gallery export...")

# --- FaceGen v2: Professional Gallery Export (Fix) ---
netG.eval()

# Generate the full 64-sample fixed-noise batch for the gallery grid.
with torch.no_grad():
    with torch.amp.autocast('cuda'):
        fake_faces = netG(fixed_noise).detach().cpu().float()

# Arrange samples into one padded grid; transpose CHW -> HWC for pyplot.
grid = vutils.make_grid(fake_faces, padding=4, normalize=True)
grid_np = grid.numpy().transpose((1, 2, 0))

plt.figure(figsize=(12, 12), facecolor='#111111')
plt.imshow(grid_np, interpolation='bilinear')
plt.axis("off")

plt.title(f"FaceGen v1 | Training Complete | {FEATURES_G}x{FEATURES_D} Filters",
          color='white', fontsize=16, fontweight='bold', pad=20)

plt.tight_layout()

# Agg backend (set at import time) renders straight to file, no display.
plt.savefig("facegen_v2_results.png", facecolor='#111111', bbox_inches='tight')