"""FaceGen v1: train a DCGAN on a Hugging Face face dataset and export samples.

Running this file as a script loads the dataset, trains the generator and
discriminator with label-smoothed BCE-with-logits losses, periodically writes
sample images and checkpoints, and finally renders a gallery PNG from a fixed
noise batch.  Importing the module only defines the configuration constants,
the dataset wrapper, the two networks and the helpers below.
"""

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import matplotlib
matplotlib.use('Agg')  # the backend must be selected before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
from torch.cuda.amp import autocast, GradScaler
import torchvision.utils as vutils
from IPython.display import display

# --- FaceGen v1 Config ---
BATCH_SIZE = 128
IMAGE_SIZE = 128
CHANNELS = 3
Z_DIM = 128
FEATURES_G = 256
FEATURES_D = 128
EPOCHS = 250
LR = 0.0002
BETA1 = 0.5


class FaceDataset(Dataset):
    """Map-style dataset that serves transformed RGB images from ``hf_ds``.

    Args:
        hf_ds: Indexable dataset whose items expose a PIL image under the
            ``'image'`` key.
        transform: Callable applied to each RGB-converted image; its return
            value is what ``__getitem__`` yields.
    """

    def __init__(self, hf_ds, transform):
        self.hf_ds = hf_ds
        self.transform = transform

    def __len__(self):
        return len(self.hf_ds)

    def __getitem__(self, idx):
        img = self.hf_ds[idx]['image'].convert("RGB")
        return self.transform(img)


class Generator(nn.Module):
    """Transposed-convolution generator: ``z_dim x 1 x 1`` noise to a 128x128 image.

    Args:
        z_dim: Number of channels of the latent input.
        channels: Number of channels of the generated image.
        features_g: Base feature width; the first layer uses 16x this value.
    """

    def __init__(self, z_dim, channels, features_g):
        super().__init__()
        # NOTE: layer order is kept exactly so state_dict keys stay compatible
        # with previously written checkpoints.
        self.net = nn.Sequential(
            # Input: Z_DIM x 1 x 1 -> 4x4
            nn.ConvTranspose2d(z_dim, features_g * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(True),
            # 4x4 -> 8x8
            nn.ConvTranspose2d(features_g * 16, features_g * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(True),
            # 8x8 -> 16x16
            nn.ConvTranspose2d(features_g * 8, features_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(True),
            # 16x16 -> 32x32
            nn.ConvTranspose2d(features_g * 4, features_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(True),
            # 32x32 -> 64x64
            nn.ConvTranspose2d(features_g * 2, features_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_g),
            nn.ReLU(True),
            # 64x64 -> 128x128
            nn.ConvTranspose2d(features_g, channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.net(x)


class Discriminator(nn.Module):
    """Strided-convolution discriminator: 128x128 image to a single raw logit.

    The network has no final sigmoid; the script pairs it with
    ``nn.BCEWithLogitsLoss``.

    Args:
        channels: Number of channels of the input image.
        features_d: Base feature width; the last hidden layer uses 16x this value.
    """

    def __init__(self, channels, features_d):
        super().__init__()
        # NOTE: layer order is kept exactly so state_dict keys stay compatible
        # with previously written checkpoints.
        self.net = nn.Sequential(
            # 128x128 -> 64x64
            nn.Conv2d(channels, features_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 64x64 -> 32x32
            nn.Conv2d(features_d, features_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 32x32 -> 16x16
            nn.Conv2d(features_d * 2, features_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(features_d * 4, features_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 4x4
            nn.Conv2d(features_d * 8, features_d * 16, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features_d * 16),
            nn.LeakyReLU(0.2, inplace=True),
            # 4x4 -> 1x1
            nn.Conv2d(features_d * 16, 1, 4, 1, 0, bias=False),
        )

    def forward(self, x):
        return self.net(x)


def weights_init(m):
    """Initialise conv and batch-norm layers; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get N(0, 0.02) weights; modules
    whose class name contains ``BatchNorm`` get N(1, 0.02) weights and zero
    bias.  Every other module is left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


def _checkpoint_state(epoch, netG, netD, optG, optD, scaler):
    """Return the dict written to every training checkpoint.

    ``epoch`` is stored verbatim under the ``'epoch'`` key; the remaining
    entries are the ``state_dict()`` of each argument.
    """
    return {
        'epoch': epoch,
        'model_state_dict': netG.state_dict(),
        'optimizer_state_dict': optG.state_dict(),
        'netD_state_dict': netD.state_dict(),
        'optD_state_dict': optD.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
    }


# Everything with side effects lives under the guard: DataLoader worker
# processes started with the "spawn" method re-import this module, and must
# not reload the dataset or start training themselves.  A plain ``if`` block
# (rather than a function) keeps netG, netD, ... available as globals after a
# script or notebook run.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision and pinned host memory are only useful on CUDA.
    use_amp = device.type == "cuda"
    print(f"Training will run on: {device}")

    print("Loading face dataset...")
    hf_dataset = load_dataset("SDbiaseval/faces", split="train")

    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Map [0, 1] to [-1, 1] to match the generator's Tanh output range.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = FaceDataset(hf_dataset, transform)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=4,
        pin_memory=use_amp,
    )
    print(f"Dataset ready with {len(dataset)} faces.")

    netG = Generator(Z_DIM, CHANNELS, FEATURES_G).to(device)
    netD = Discriminator(CHANNELS, FEATURES_D).to(device)
    netG.apply(weights_init)
    netD.apply(weights_init)

    criterion = nn.BCEWithLogitsLoss()
    optG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, 0.999))
    optD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, 0.999))

    # The same latent batch is reused for every sample so images are comparable.
    fixed_noise = torch.randn(64, Z_DIM, 1, 1, device=device)
    # One scaler serves both optimizers; it is updated once per iteration.
    scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

    print(f"Model size G: {sum(p.numel() for p in netG.parameters())/1e6:.2f}M parameters")
    print(f"Model size D: {sum(p.numel() for p in netD.parameters())/1e6:.2f}M parameters")

    # Smoothed targets, applied to both the real and the fake side.
    real_label_val = 0.9
    fake_label_val = 0.1

    for epoch in range(EPOCHS):
        for i, real_images in enumerate(dataloader):
            real_images = real_images.to(device, non_blocking=True)
            b_size = real_images.size(0)
            real_targets = torch.full((b_size,), real_label_val, device=device)
            fake_targets = torch.full((b_size,), fake_label_val, device=device)

            # --- Discriminator Update ---
            optD.zero_grad()
            with torch.amp.autocast(device.type, enabled=use_amp):
                output_real = netD(real_images).view(-1)
                lossD_real = criterion(output_real, real_targets)

                noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
                fake_images = netG(noise)
                # detach(): the discriminator loss must not reach the generator.
                output_fake = netD(fake_images.detach()).view(-1)
                lossD_fake = criterion(output_fake, fake_targets)
                lossD = lossD_real + lossD_fake

            scaler.scale(lossD).backward()
            scaler.step(optD)

            # --- Generator Update ---
            optG.zero_grad()
            with torch.amp.autocast(device.type, enabled=use_amp):
                # Re-score the same fakes with the just-updated discriminator;
                # the generator is trained towards the (smoothed) real target.
                output_fake_G = netD(fake_images).view(-1)
                lossG = criterion(output_fake_G, real_targets)

            scaler.scale(lossG).backward()
            scaler.step(optG)
            scaler.update()

            if i % 10 == 0:
                print(f"E[{epoch}] I[{i}/{len(dataloader)}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")

        if (epoch + 1) % 10 == 0 or epoch == 0:
            # eval() makes BatchNorm use running statistics for the single sample.
            netG.eval()
            with torch.no_grad(), torch.amp.autocast(device.type, enabled=use_amp):
                sample = netG(fixed_noise[0:1]).cpu().float()
            vutils.save_image(sample, f"face_sample_epoch_{epoch}.png", normalize=True)
            print(f"--> Sample saved: face_sample_epoch_{epoch}.png")
            netG.train()

        if (epoch + 1) % 50 == 0:
            # NOTE: the stored 'epoch' is the zero-based index while the file
            # name uses epoch + 1; the final checkpoint below stores EPOCHS.
            torch.save(
                _checkpoint_state(epoch, netG, netD, optG, optD, scaler),
                f'facegen_v1_checkpoint_epoch_{epoch+1}.ckpt',
            )
            print(f"--> Sicherheits-Checkpoint gespeichert: Epoche {epoch+1}")

    torch.save(
        _checkpoint_state(EPOCHS, netG, netD, optG, optD, scaler),
        'facegen_v1_full_checkpoint.ckpt',
    )
    torch.save(netG.state_dict(), 'facegen_v1_generator_only.pth')

    print("Files saved: Training finished.")
    print("Doing professionell gallery export...")

    # --- FaceGen v2: Professional Gallery Export (Fix) ---
    netG.eval()
    with torch.no_grad(), torch.amp.autocast(device.type, enabled=use_amp):
        fake_faces = netG(fixed_noise).cpu().float()

    grid = vutils.make_grid(fake_faces, padding=4, normalize=True)
    # make_grid returns channels-first; imshow expects channels-last.
    grid_np = grid.numpy().transpose((1, 2, 0))

    fig = plt.figure(figsize=(12, 12), facecolor='#111111')
    plt.imshow(grid_np, interpolation='bilinear')
    plt.axis("off")
    plt.title(f"FaceGen v1 | Training Complete | {FEATURES_G}x{FEATURES_D} Filters",
              color='white', fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig("facegen_v2_results.png", facecolor='#111111', bbox_inches='tight')
    plt.close(fig)  # release the figure once it has been written