| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import torchvision.transforms as transforms |
| from torch.utils.data import DataLoader, Dataset |
| from datasets import load_dataset |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from torch.cuda.amp import autocast, GradScaler |
| import torchvision.utils as vutils |
| from IPython.display import display |
|
|
| |
# --- Hyperparameters ------------------------------------------------------
BATCH_SIZE = 128
IMAGE_SIZE = 128   # output resolution (square images)
CHANNELS = 3       # RGB
Z_DIM = 128        # latent vector length
FEATURES_G = 256   # base channel width of the generator
FEATURES_D = 128   # base channel width of the discriminator
EPOCHS = 250
LR = 0.0002        # DCGAN-recommended Adam learning rate
BETA1 = 0.5        # DCGAN-recommended Adam beta1


# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training will run on: {device}")


print("Loading face dataset...")
# Hugging Face dataset; each row carries an 'image' entry (see FaceDataset below).
hf_dataset = load_dataset("SDbiaseval/faces", split="train")


# Resize/crop to IMAGE_SIZE, random horizontal flip as augmentation, then
# normalize each channel from [0, 1] to [-1, 1] to match the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
|
|
class FaceDataset(Dataset):
    """Adapter exposing a Hugging Face image dataset as a torch Dataset.

    Each item is the dataset row's 'image' converted to RGB and passed
    through the supplied transform.
    """

    def __init__(self, hf_ds, transform):
        self.hf_ds = hf_ds          # underlying HF dataset; rows hold an 'image' entry
        self.transform = transform  # preprocessing applied to every RGB image

    def __len__(self):
        return len(self.hf_ds)

    def __getitem__(self, idx):
        # Force 3-channel RGB before transforming (source images may use other modes).
        rgb = self.hf_ds[idx]['image'].convert("RGB")
        return self.transform(rgb)
|
|
dataset = FaceDataset(hf_dataset, transform)


# drop_last discards the final partial batch so every batch has exactly BATCH_SIZE items.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,    # parallel workers for image decoding / transforms
    pin_memory=True   # page-locked host memory for faster host-to-GPU copies
)
print(f"Dataset ready with {len(dataset)} faces.")
|
|
class Generator(nn.Module):
    """DCGAN generator.

    Maps a latent batch of shape (N, z_dim, 1, 1) to images of shape
    (N, channels, 128, 128) with values in [-1, 1].
    """

    def __init__(self, z_dim, channels, features_g):
        super().__init__()
        # Project the latent vector onto a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(z_dim, features_g * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(True),
        ]
        # Four intermediate stride-2 upsampling stages (4 -> 8 -> 16 -> 32 -> 64),
        # halving the channel count at each stage.
        in_ch = features_g * 16
        for out_ch in (features_g * 8, features_g * 4, features_g * 2, features_g):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
            in_ch = out_ch
        # Final upsample to 128x128 and squash into [-1, 1] to match the
        # normalization applied to the training images.
        layers += [
            nn.ConvTranspose2d(in_ch, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the latent batch through the upsampling stack."""
        return self.net(x)
|
|
netG = Generator(Z_DIM, CHANNELS, FEATURES_G).to(device)  # generator instance on the training device
|
|
class Discriminator(nn.Module):
    """DCGAN discriminator.

    Maps an image batch of shape (N, channels, 128, 128) to raw real/fake
    logits of shape (N, 1, 1, 1). There is deliberately no final Sigmoid:
    the training loss is BCEWithLogitsLoss.
    """

    def __init__(self, channels, features_d):
        super().__init__()
        # First stage has no BatchNorm, per the DCGAN recipe.
        layers = [
            nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Four further stride-2 downsampling stages (64 -> 32 -> 16 -> 8 -> 4),
        # doubling the channel count at each stage.
        in_ch = features_d
        for mult in (2, 4, 8, 16):
            out_ch = features_d * mult
            layers += [
                nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            in_ch = out_ch
        # Final 4x4 convolution collapses the feature map to a single logit.
        layers.append(nn.Conv2d(in_ch, 1, 4, 1, 0, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Score the image batch; returns raw (pre-sigmoid) logits."""
        return self.net(x)
|
|
netD = Discriminator(CHANNELS, FEATURES_D).to(device)  # discriminator instance on the training device
|
|
def weights_init(m):
    """DCGAN weight initialization (Radford et al., 2016); apply via ``net.apply``.

    Conv / ConvTranspose weights are drawn from N(0, 0.02); BatchNorm scale
    from N(1, 0.02) with bias fixed at 0. Other module types are untouched.
    """
    classname = m.__class__.__name__
    # Matches both Conv2d and ConvTranspose2d (class names contain 'Conv').
    # The deprecated `.data` access was removed: nn.init.* functions already
    # operate in-place under no_grad, so the behavior is identical.
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
|
|
netG.apply(weights_init)
netD.apply(weights_init)


# Loss over raw logits (the discriminator has no final Sigmoid layer).
criterion = nn.BCEWithLogitsLoss()


# Separate Adam optimizers for G and D with the DCGAN betas.
optG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))
optD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))


# Fixed latent batch reused for every progress sample so images are comparable across epochs.
fixed_noise = torch.randn(64, Z_DIM, 1, 1, device=device)


# Mixed-precision gradient scaler.
# NOTE(review): the device type is hard-coded to 'cuda' even though `device`
# may be CPU; on a CPU-only run this warns and disables scaling rather than
# failing — confirm whether it should mirror `device.type` instead.
scaler = torch.amp.GradScaler('cuda')


print(f"Model size G: {sum(p.numel() for p in netG.parameters())/1e6:.2f}M parameters")
print(f"Model size D: {sum(p.numel() for p in netD.parameters())/1e6:.2f}M parameters")


# Two-sided label smoothing (0.9 / 0.1 instead of 1 / 0).
# NOTE(review): the common recommendation is one-sided smoothing (fake = 0);
# confirm the 0.1 fake target is intentional.
real_label_val = 0.9
fake_label_val = 0.1
|
|
# --- Training loop --------------------------------------------------------
for epoch in range(EPOCHS):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to(device)
        b_size = real_images.size(0)

        # ---- Discriminator step: push D(real) toward the real label and
        # D(fake) toward the fake label ----
        optD.zero_grad()
        with torch.amp.autocast('cuda'):
            output_real = netD(real_images).view(-1)
            lossD_real = criterion(output_real, torch.full((b_size,), real_label_val, device=device))

            noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
            fake_images = netG(noise)
            # detach(): the D loss must not backpropagate into the generator.
            output_fake = netD(fake_images.detach()).view(-1)
            lossD_fake = criterion(output_fake, torch.full((b_size,), fake_label_val, device=device))
            lossD = lossD_real + lossD_fake


        scaler.scale(lossD).backward()
        scaler.step(optD)

        # ---- Generator step: make D score the same fakes as real.
        # `fake_images` still carries the graph through netG, so this
        # backward updates the generator. ----
        optG.zero_grad()
        with torch.amp.autocast('cuda'):
            output_fake_G = netD(fake_images).view(-1)
            lossG = criterion(output_fake_G, torch.full((b_size,), real_label_val, device=device))


        scaler.scale(lossG).backward()
        scaler.step(optG)
        # One scale update per iteration covers both optimizer steps.
        scaler.update()


        if i % 10 == 0:
            print(f"E[{epoch}] I[{i}/{len(dataloader)}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")


    # Periodic progress sample from the first fixed-noise vector.
    # NOTE(review): the sample filename uses the 0-based `epoch` while
    # checkpoints below use `epoch+1` — confirm the off-by-one naming is intended.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        netG.eval()  # use BatchNorm running stats for sampling
        with torch.no_grad():
            with torch.amp.autocast('cuda'):
                sample = netG(fixed_noise[0:1]).detach().cpu().float()

        vutils.save_image(sample, f"face_sample_epoch_{epoch}.png", normalize=True)
        print(f"--> Sample saved: face_sample_epoch_{epoch}.png")

        netG.train()


    # Safety checkpoint every 50 epochs with the full training state
    # (both networks, both optimizers, and the AMP scaler).
    if (epoch + 1) % 50 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': netG.state_dict(),
            'optimizer_state_dict': optG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            'optD_state_dict': optD.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
        }, f'facegen_v1_checkpoint_epoch_{epoch+1}.ckpt')
        print(f"--> Sicherheits-Checkpoint gespeichert: Epoche {epoch+1}")
|
|
# Final checkpoint with the complete training state (G, D, both optimizers, AMP scaler).
torch.save({
    'epoch': EPOCHS,
    'model_state_dict': netG.state_dict(),
    'optimizer_state_dict': optG.state_dict(),
    'netD_state_dict': netD.state_dict(),
    'optD_state_dict': optD.state_dict(),
    'scaler_state_dict': scaler.state_dict(),
}, 'facegen_v1_full_checkpoint.ckpt')


# Generator weights alone, for loading without the rest of the training state.
torch.save(netG.state_dict(), 'facegen_v1_generator_only.pth')


print("Files saved: Training finished.")
|
|
# --- Final gallery export -------------------------------------------------
print("Doing professional gallery export...")  # fixed typo: "professionell"


# Inference mode so BatchNorm uses its running statistics.
netG.eval()


# Generate the full 64-image fixed-noise batch; cast back to float32
# after autocast so matplotlib gets a standard dtype.
with torch.no_grad():
    with torch.amp.autocast('cuda'):
        fake_faces = netG(fixed_noise).detach().cpu().float()


# Arrange all samples in one grid, rescaled from [-1, 1] into [0, 1].
grid = vutils.make_grid(fake_faces, padding=4, normalize=True)
grid_np = grid.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib


plt.figure(figsize=(12, 12), facecolor='#111111')
plt.imshow(grid_np, interpolation='bilinear')
plt.axis("off")


plt.title(f"FaceGen v1 | Training Complete | {FEATURES_G}x{FEATURES_D} Filters",
          color='white', fontsize=16, fontweight='bold', pad=20)


plt.tight_layout()


# NOTE(review): filename says "v2" while every other artifact is "v1" —
# confirm whether this mismatch is intended before renaming.
plt.savefig("facegen_v2_results.png", facecolor='#111111', bbox_inches='tight')