Spaces:
Build error
Build error
File size: 10,174 Bytes
412f263 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import argparse
import sys
sys.path.append(".")
from model.models import UNet, PatchGANDiscriminator
from model.losses import GeneratorLoss, DiscriminatorLoss
from utils.dataloader import CustomDataset, transform
def _set_requires_grad(model, requires_grad):
    """Switch gradient tracking on or off for every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(requires_grad)


def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d, num_workers,
                val_freq, batch_size, accum_iter, lr, lr_d, wandb_tracking, desc):
    """Train the U-Net generator against the PatchGAN discriminator.

    A ``CustomDataset`` is built from ``root_dir`` and split 80/20 into
    training and validation subsets. Each training batch runs a discriminator
    pass followed by a generator pass; optimizer steps are taken every
    ``accum_iter`` batches (and on the last batch of an epoch). The network
    output is treated as a residual that is added to the first three channels
    of the source image. State dicts are written to ``recent_*.pth`` at the
    end of every epoch and to ``best_*.pth`` whenever the average validation
    loss improves.

    Args:
        root_dir: Path to the training data handed to ``CustomDataset``.
        start_epoch: 1-based epoch to start from (values below 1 are treated
            as 1), used when training is resumed.
        num_epochs: Last epoch (inclusive, 1-based) to train.
        load_model_g: Path to a generator state dict, or a falsy value to
            train from scratch.
        load_model_d: Path to a discriminator state dict, or a falsy value to
            train from scratch.
        num_workers: Number of ``DataLoader`` worker processes.
        val_freq: Validation runs on every epoch whose 0-based index is a
            multiple of this value.
        batch_size: Batch size for both data loaders.
        accum_iter: Number of batches whose gradients are accumulated before
            an optimizer step.
        lr: Learning rate of the generator's Adam optimizer.
        lr_d: Learning rate of the discriminator's Adam optimizer.
        wandb_tracking: Truthy to log the run to Weights & Biases.
        desc: Free-text run description stored in the WandB config.

    Raises:
        ValueError: If ``accum_iter`` or ``val_freq`` is smaller than 1.
    """
    # Fail fast instead of hitting a ZeroDivisionError in the middle of a run.
    if accum_iter < 1:
        raise ValueError(f"accum_iter must be a positive integer, got {accum_iter!r}")
    if val_freq < 1:
        raise ValueError(f"val_freq must be a positive integer, got {val_freq!r}")

    if wandb_tracking:
        # Imported lazily so wandb is only required when tracking is enabled.
        import wandb

        wandb.init(
            project="FRAN",
            # track hyperparameters and run metadata
            config={
                "lr": lr,
                "lr_d": lr_d,
                "dataset": root_dir,
                "epochs": num_epochs,
                "batch_size": batch_size,
                "description": desc,
            },
        )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    gpu_count = torch.cuda.device_count()
    if gpu_count > 0:
        print(f"{gpu_count} GPU(s)")
    if gpu_count > 1:
        print("multi-GPU training is currently not supported.")

    # Build the dataset and split it: 80% for training, 20% for validation.
    dataset = CustomDataset(root_dir=root_dir, transform=transform)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Data loaders for training and validation.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                  num_workers=num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers)
    num_train_batches = len(train_dataloader)

    # Generator and discriminator, optionally restored from checkpoints.
    unet_model = UNet()
    discriminator_model = PatchGANDiscriminator(input_channels=4)

    if load_model_g:
        unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
        print(f'loaded {load_model_g} for unet_model')

    if load_model_d:
        discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
        print(f'loaded {load_model_d} for discriminator_model')

    unet_model = unet_model.to(device)
    discriminator_model = discriminator_model.to(device)

    # Loss functions.
    generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
                                        adversarial_weight=0.05, device=device)
    discriminator_loss_func = DiscriminatorLoss(discriminator_model)

    # Adam optimizers, one per network.
    optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)

    best_val_loss = float('inf')

    # start_epoch is 1-based; clamp so that 0 (or less) does not produce a
    # bogus epoch index of -1 and one extra epoch.
    first_epoch = max(start_epoch - 1, 0)

    for epoch in range(first_epoch, num_epochs):
        # ---- Training ----
        unet_model.train()
        discriminator_model.train()

        for batch_idx, (source_images, target_images) in enumerate(train_dataloader, start=1):
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            # Step the optimizers every accum_iter batches and on the final
            # (possibly partial) accumulation window of the epoch.
            is_update_step = (batch_idx % accum_iter == 0) or (batch_idx == num_train_batches)

            # Forward pass: the network predicts a residual on top of the
            # first three channels of the source. Out-of-place add so the
            # tensor autograd saved for the backward pass is not modified.
            output_images = unet_model(source_images)
            output_images = output_images + source_images[:, :3, :, :]

            # Discriminator pass (generator output detached so no gradient
            # reaches the U-Net). This loss is not scaled by accum_iter.
            discriminator_loss = discriminator_loss_func(output_images.detach(), target_images,
                                                         source_images)
            discriminator_loss.backward()

            if is_update_step:
                optimizer_d.step()
                optimizer_d.zero_grad()

            # Generator pass. GeneratorLoss was constructed around
            # discriminator_model, so backpropagating it would presumably also
            # accumulate gradients on the discriminator's parameters, which
            # the next optimizer_d.step() would then apply. Freeze them for
            # the duration of this pass; gradients still flow through the
            # discriminator to the generator.
            _set_requires_grad(discriminator_model, False)
            generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(
                output_images, target_images, source_images)
            # Scale so gradients accumulated over accum_iter batches average
            # rather than sum. The values logged below are the scaled ones.
            generator_loss, l1_loss, per_loss, adv_loss = [
                term / accum_iter for term in (generator_loss, l1_loss, per_loss, adv_loss)
            ]
            generator_loss.backward()
            _set_requires_grad(discriminator_model, True)

            if is_update_step:
                optimizer_g.step()
                optimizer_g.zero_grad()

            # Per-batch training information.
            print(
                f'Training Epoch [{epoch + 1}/{num_epochs}], Gen Loss: {generator_loss.item()}, '
                f'L1: {l1_loss.item()}, P: {per_loss.item()}, A: {adv_loss.item()}, '
                f'Dis Loss: {discriminator_loss.item()}')

            if wandb_tracking:
                wandb.log({
                    'Training Epoch': epoch + 1,
                    'Gen Loss': generator_loss.item(),
                    'L1': l1_loss.item(),
                    'P': per_loss.item(),
                    'A': adv_loss.item(),
                    'Dis Loss': discriminator_loss.item()
                })

        # Checkpoint the most recent weights at the end of every epoch.
        torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
        torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')

        # ---- Validation ----
        if epoch % val_freq == 0:
            unet_model.eval()
            total_val_loss = 0.0

            with torch.no_grad():
                for val_source_images, val_target_images in val_dataloader:
                    val_source_images = val_source_images.to(device)
                    val_target_images = val_target_images.to(device)

                    # Same residual formulation as in training; without it the
                    # loss would compare the raw residual against the target.
                    val_output_images = unet_model(val_source_images)
                    val_output_images = val_output_images + val_source_images[:, :3, :, :]

                    val_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
                                                            val_source_images)
                    total_val_loss += val_loss.item()

            average_val_loss = total_val_loss / len(val_dataloader)

            print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss}')

            if wandb_tracking:
                wandb.log({
                    'Training Epoch': epoch + 1,
                    'Val Loss': average_val_loss,
                })

            # Keep a copy of the weights with the best validation loss so far.
            if average_val_loss < best_val_loss:
                best_val_loss = average_val_loss
                torch.save(unet_model.state_dict(), 'best_unet_model.pth')
                torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')

    if wandb_tracking:
        wandb.finish()
def str_to_bool(value):
    """Convert a command-line token such as ``"True"`` or ``"False"`` to a bool.

    Without a converter argparse stores the raw string, and any non-empty
    string (including ``"False"``) is truthy.

    Args:
        value: The raw argument; a ``bool`` is returned unchanged.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (True/False), got {value!r}")


if __name__ == "__main__":
    # Define command-line arguments
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("--root_dir", type=str, default='data/processed/train',
                        help="Path to the training data. Note the format: To use the dataloader, the directory should be filled with folders containing image files of various ages, where the file name is the age.")
    # Epochs are 1-based on the command line (train_model iterates from
    # start_epoch - 1), so the first epoch is 1, not 0.
    parser.add_argument("--start_epoch", type=int, default=1, help="Start epoch, if scripts is resumed")
    parser.add_argument("--num_epochs", type=int, default=2000, help="End epoch")
    parser.add_argument("--load_model_g", type=str, default='',
                        help="Path to pretrained generator model. Leave blank to train from scratch")
    parser.add_argument("--load_model_d", type=str, default='',
                        help="Path to pretrained discriminator model. Leave blank to train from scratch")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
    parser.add_argument("--accum_iter", type=int, default=3, help="Number of batches after which weights are updated")
    parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
    parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for generator")
    parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for discriminator")
    # Parsed into a real bool so that "--wandb_tracking False" disables
    # tracking; a bare "--wandb_tracking" enables it.
    parser.add_argument("--wandb_tracking", type=str_to_bool, nargs="?", const=True, default=False,
                        help="A binary (True/False) argument for using WandB tracking or not")
    parser.add_argument("--desc", type=str, default='', help="Description for WandB")

    # Parse command-line arguments
    args = parser.parse_args()

    # Call the training function with parsed arguments
    train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
                args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d,
                args.wandb_tracking, args.desc)
|