# FaceGen-v1 / train.py
# Author: LH-Tech-AI
# Commit 2e4ca0d: "Create train.py" (verified)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from torch.cuda.amp import autocast, GradScaler
import torchvision.utils as vutils
from IPython.display import display
# --- FaceGen v1 hyperparameters ---
# Data
BATCH_SIZE = 128
IMAGE_SIZE = 128   # side length (pixels) of the square training crops
CHANNELS = 3       # images are converted to RGB
# Model
Z_DIM = 128        # length of the latent noise vector fed to the generator
FEATURES_G = 256   # base channel width of the generator (widest stage = 16x)
FEATURES_D = 128   # base channel width of the discriminator (widest stage = 16x)
# Optimisation
EPOCHS = 250
LR = 2e-4
BETA1 = 0.5        # Adam beta1 (second beta stays at 0.999)

# Prefer the GPU; fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Training will run on: {device}")
print("Loading face dataset...")
hf_dataset = load_dataset("SDbiaseval/faces", split="train")

# Preprocessing applied to every image, in this order.
preprocessing_steps = [
    transforms.Resize(IMAGE_SIZE),        # shorter side -> IMAGE_SIZE, aspect ratio kept
    transforms.CenterCrop(IMAGE_SIZE),    # square IMAGE_SIZE x IMAGE_SIZE crop
    transforms.RandomHorizontalFlip(),    # random mirroring as augmentation
    transforms.ToTensor(),                # PIL image -> float CHW tensor in [0, 1]
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),  # [0, 1] -> [-1, 1], the generator's Tanh range
]
transform = transforms.Compose(preprocessing_steps)
class FaceDataset(Dataset):
    """Map-style dataset yielding transformed RGB face images.

    Wraps an indexable collection (here a Hugging Face split) whose records
    hold the picture under the ``'image'`` key. Each item is converted to
    RGB and passed through ``transform``.
    """

    def __init__(self, hf_ds, transform):
        self.hf_ds = hf_ds          # indexable source of records
        self.transform = transform  # callable applied to each RGB image

    def __len__(self):
        return len(self.hf_ds)

    def __getitem__(self, idx):
        record = self.hf_ds[idx]
        # Normalise the image mode to 3-channel RGB before transforming.
        rgb_image = record['image'].convert("RGB")
        return self.transform(rgb_image)
# Wrap the HF split and serve shuffled, fixed-size mini-batches.
# drop_last=True discards the final partial batch, so every batch holds
# exactly BATCH_SIZE images.
# NOTE(review): with num_workers > 0 on platforms whose default multiprocessing
# start method is "spawn" (Windows, macOS), DataLoader workers re-import this
# script; the module-level training code would then need an
# `if __name__ == "__main__":` guard — verify on the target platform.
dataset = FaceDataset(hf_dataset, transform)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    # Page-locked host memory only speeds up host-to-GPU copies. On the CPU
    # fallback chosen by `device` it has no benefit, and PyTorch warns about it.
    pin_memory=(device.type == "cuda"),
)
print(f"Dataset ready with {len(dataset)} faces.")
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to images.

    Input:  noise of shape (N, z_dim, 1, 1).
    Output: images of shape (N, channels, 128, 128) with values in [-1, 1]
            (Tanh output).

    Layout: a 1x1 -> 4x4 transposed convolution, then four stride-2
    upsampling stages (4 -> 8 -> 16 -> 32 -> 64). Each stage is
    ConvTranspose-BatchNorm-ReLU, with the width halving from
    ``features_g * 16`` down to ``features_g``. A final stride-2 transposed
    convolution to ``channels`` (64 -> 128) is followed by Tanh. The layer
    order, and therefore the ``state_dict`` keys ``net.0`` ... ``net.16``, is
    fixed so saved checkpoints keep loading.
    """

    def __init__(self, z_dim, channels, features_g):
        super().__init__()
        # Channel width of each stage: 16f, 8f, 4f, 2f, f.
        widths = [features_g * mult for mult in (16, 8, 4, 2, 1)]

        # Project the 1x1 latent vector to a 4x4 feature map.
        layers = [
            nn.ConvTranspose2d(z_dim, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]
        # Four upsampling stages, each doubling the spatial size (4 -> 64).
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ])
        # Last upsampling (64 -> 128) to image channels; Tanh bounds the output.
        layers.extend([
            nn.ConvTranspose2d(widths[-1], channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# Build the generator from the config and move its parameters to `device`.
netG = Generator(z_dim=Z_DIM, channels=CHANNELS, features_g=FEATURES_G)
netG = netG.to(device)
class Discriminator(nn.Module):
    """DCGAN-style discriminator mapping images to one raw score each.

    Input:  images of shape (N, channels, 128, 128).
    Output: unnormalised logits of shape (N, 1, 1, 1). No sigmoid is applied,
            so the caller's loss must accept logits.

    Layout: a stride-2 convolution + LeakyReLU without BatchNorm
    (128 -> 64). Then come four stride-2 Conv-BatchNorm-LeakyReLU stages
    (64 -> 32 -> 16 -> 8 -> 4), with the width doubling from ``features_d``
    up to ``features_d * 16``. A final 4x4 valid convolution reduces to a
    single 1x1 logit. The layer order, and therefore the ``state_dict`` keys
    ``net.0`` ... ``net.14``, is fixed so saved checkpoints keep loading.
    """

    def __init__(self, channels, features_d):
        super().__init__()
        # Channel width of each stage: f, 2f, 4f, 8f, 16f.
        widths = [features_d * mult for mult in (1, 2, 4, 8, 16)]

        # Input layer (128 -> 64): no BatchNorm here.
        layers = [
            nn.Conv2d(channels, widths[0], 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Four downsampling stages, each halving the spatial size (64 -> 4).
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Collapse the 4x4 map to a single-channel 1x1 logit.
        layers.append(nn.Conv2d(widths[-1], 1, 4, 1, 0, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
# Build the discriminator from the config and move its parameters to `device`.
netD = Discriminator(channels=CHANNELS, features_d=FEATURES_D)
netD = netD.to(device)
def weights_init(m):
    """Apply DCGAN-style initialisation to a single module, in place.

    Meant for ``model.apply(weights_init)``, which calls it on every
    submodule. Modules are matched by class name:

    * name contains ``"Conv"`` (e.g. Conv2d, ConvTranspose2d):
      weight ~ N(0, 0.02); any conv bias is left untouched.
    * name contains ``"BatchNorm"``: weight ~ N(1, 0.02), bias = 0.

    All other modules are ignored. A module that matches by name but has no
    such parameter is skipped instead of raising ``AttributeError``. An
    example is ``BatchNorm2d(affine=False)``, whose ``weight`` and ``bias``
    are ``None``.

    Args:
        m: The ``nn.Module`` to (possibly) re-initialise.
    """
    classname = type(m).__name__
    # nn.init functions already run under torch.no_grad(), so the parameters
    # are passed directly instead of going through the discouraged `.data`.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)
    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)
# DCGAN re-initialisation of every Conv / BatchNorm layer in both networks.
netG.apply(weights_init)
netD.apply(weights_init)

# netD ends in a plain convolution (no sigmoid), so the loss takes raw logits.
criterion = nn.BCEWithLogitsLoss()

# Adam with beta1 = BETA1 (instead of the default 0.9) for both players.
optG = optim.Adam(netG.parameters(), betas=(BETA1, 0.999), lr=LR)
optD = optim.Adam(netD.parameters(), betas=(BETA1, 0.999), lr=LR)

# 64 latent vectors drawn once, so the periodic samples and the final gallery
# always decode the same points in latent space.
fixed_noise = torch.randn(64, Z_DIM, 1, 1, device=device)

# Loss scaler for mixed-precision (autocast) training on CUDA.
scaler = torch.amp.GradScaler('cuda')

for tag, module in (("G", netG), ("D", netD)):
    n_params = sum(param.numel() for param in module.parameters())
    print(f"Model size {tag}: {n_params/1e6:.2f}M parameters")

# Smoothed BCE targets used in place of hard 1 / 0 labels.
# NOTE(review): both real (0.9) and fake (0.1) targets are smoothed; confirm
# two-sided smoothing is intended.
real_label_val = 0.9
fake_label_val = 0.1
# --- Main training loop ---
# Every iteration runs one discriminator step and then one generator step,
# both under CUDA autocast, with a shared GradScaler.
for epoch in range(EPOCHS):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to(device)
        b_size = real_images.size(0)
        # --- Discriminator Update ---
        optD.zero_grad()
        with torch.amp.autocast('cuda'):
            # Score real images against the smoothed real target.
            output_real = netD(real_images).view(-1)
            lossD_real = criterion(output_real, torch.full((b_size,), real_label_val, device=device))
            # Fresh fakes for this iteration; they are reused in the G step below.
            noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
            fake_images = netG(noise)
            # detach(): the D loss must not backpropagate into netG.
            output_fake = netD(fake_images.detach()).view(-1)
            lossD_fake = criterion(output_fake, torch.full((b_size,), fake_label_val, device=device))
            lossD = lossD_real + lossD_fake
        # Backward and step run outside autocast on the scaled loss.
        scaler.scale(lossD).backward()
        scaler.step(optD)
        # --- Generator Update ---
        optG.zero_grad()
        with torch.amp.autocast('cuda'):
            # Re-score the same (non-detached) fakes with the just-updated netD
            # so gradients flow back into netG. This backward also fills netD's
            # .grad, which optD.zero_grad() clears at the next iteration.
            output_fake_G = netD(fake_images).view(-1)
            # NOTE(review): the generator target reuses the smoothed real label
            # (0.9) rather than 1.0; one-sided label smoothing is usually applied
            # only to the discriminator's real targets — confirm this is intended.
            lossG = criterion(output_fake_G, torch.full((b_size,), real_label_val, device=device))
        scaler.scale(lossG).backward()
        scaler.step(optG)
        # One scale-factor update per iteration, after both optimizer steps.
        scaler.update()
        if i % 10 == 0:
            print(f"E[{epoch}] I[{i}/{len(dataloader)}] Loss_D: {lossD.item():.4f} Loss_G: {lossG.item():.4f}")
    # After the first epoch and every 10th epoch: save a preview decoded from
    # the first fixed-noise vector, with BatchNorm in eval mode.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        netG.eval()
        with torch.no_grad():
            with torch.amp.autocast('cuda'):
                sample = netG(fixed_noise[0:1]).detach().cpu().float()
            vutils.save_image(sample, f"face_sample_epoch_{epoch}.png", normalize=True)
            print(f"--> Sample saved: face_sample_epoch_{epoch}.png")
        netG.train()
    # Every 50 epochs: full training-state checkpoint (both nets, both
    # optimizers, scaler). The file name uses the 1-based epoch count, while
    # the stored 'epoch' value is the 0-based loop index.
    if (epoch + 1) % 50 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': netG.state_dict(),
            'optimizer_state_dict': optG.state_dict(),
            'netD_state_dict': netD.state_dict(),
            'optD_state_dict': optD.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
        }, f'facegen_v1_checkpoint_epoch_{epoch+1}.ckpt')
        # Message (German): "Safety checkpoint saved: epoch N".
        print(f"--> Sicherheits-Checkpoint gespeichert: Epoche {epoch+1}")
# --- Final export ---
# Complete training state: both networks, both optimizers and the AMP scaler.
# NOTE(review): 'epoch' here is EPOCHS (number of epochs run), whereas the
# periodic checkpoints store the 0-based index of the last finished epoch.
final_state = {
    'epoch': EPOCHS,
    'model_state_dict': netG.state_dict(),
    'optimizer_state_dict': optG.state_dict(),
    'netD_state_dict': netD.state_dict(),
    'optD_state_dict': optD.state_dict(),
    'scaler_state_dict': scaler.state_dict(),
}
torch.save(final_state, 'facegen_v1_full_checkpoint.ckpt')

# Generator weights alone, for inference-only use.
generator_weights = netG.state_dict()
torch.save(generator_weights, 'facegen_v1_generator_only.pth')
print("Files saved: Training finished.")
print("Doing professionell gallery export...")
# --- Gallery export ---
# Decode every fixed-noise vector with the trained generator (BatchNorm in
# eval mode) and save them as a single dark-themed image grid.
netG.eval()
with torch.no_grad(), torch.amp.autocast('cuda'):
    fake_faces = netG(fixed_noise).detach().cpu().float()

# make_grid places 8 images per row (its default) and min-max rescales the
# batch to [0, 1]. matplotlib wants a channels-last (H, W, C) array.
grid = vutils.make_grid(fake_faces, padding=4, normalize=True)
grid_np = grid.permute(1, 2, 0).numpy()

fig = plt.figure(figsize=(12, 12), facecolor='#111111')
ax = fig.add_subplot()
ax.imshow(grid_np, interpolation='bilinear')
ax.axis("off")
ax.set_title(
    f"FaceGen v1 | Training Complete | {FEATURES_G}x{FEATURES_D} Filters",
    color='white', fontsize=16, fontweight='bold', pad=20,
)
fig.tight_layout()
fig.savefig("facegen_v2_results.png", facecolor='#111111', bbox_inches='tight')