File size: 10,174 Bytes
412f263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import argparse

import sys
sys.path.append(".")

from model.models import UNet, PatchGANDiscriminator
from model.losses import GeneratorLoss, DiscriminatorLoss
from utils.dataloader import CustomDataset, transform


def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d, num_workers,
                val_freq, batch_size, accum_iter, lr, lr_d, wandb_tracking, desc):
    """Train the FRAN U-Net generator adversarially against a PatchGAN discriminator.

    The generator is residual: the U-Net predicts a delta that is added to the
    first three channels (RGB) of the source image to form the aged output.
    Gradients are accumulated over ``accum_iter`` batches before each optimizer
    step. The most recent weights are saved every epoch; the weights with the
    best validation generator loss are saved separately.

    Args:
        root_dir: Dataset root; see CustomDataset for the expected layout.
        start_epoch: 1-based epoch to resume from (1 = train from scratch).
        num_epochs: Total number of epochs to train to.
        load_model_g: Path to a generator state dict, or '' to start fresh.
        load_model_d: Path to a discriminator state dict, or '' to start fresh.
        num_workers: DataLoader worker processes.
        val_freq: Run validation every ``val_freq`` epochs.
        batch_size: Per-step batch size.
        accum_iter: Number of batches per optimizer update (gradient accumulation).
        lr: Generator learning rate.
        lr_d: Discriminator learning rate.
        wandb_tracking: Truthy to log metrics to Weights & Biases.
        desc: Free-text run description for W&B.

    Side effects:
        Writes 'recent_*.pth' every epoch and 'best_*.pth' on validation
        improvement, into the current working directory.
    """
    if wandb_tracking:
        # Imported lazily so the dependency is only required when tracking is on.
        import wandb

        wandb.init(project="FRAN",
                   # track hyperparameters and run metadata
                   config={
                       "lr": lr,
                       "lr_d": lr_d,
                       "dataset": root_dir,
                       "epochs": num_epochs,
                       "batch_size": batch_size,
                       "description": desc
                   }
                   )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    if torch.cuda.device_count() > 0:
        print(f"{torch.cuda.device_count()} GPU(s)")
        if torch.cuda.device_count() > 1:
            print("multi-GPU training is currently not supported.")

    # 80/20 random train/validation split of the full dataset.
    dataset = CustomDataset(root_dir=root_dir, transform=transform)
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Generator (U-Net) and conditional PatchGAN discriminator.
    # NOTE(review): input_channels=4 presumably means RGB output + an age/condition
    # channel from the source — confirm against PatchGANDiscriminator.
    unet_model = UNet()
    discriminator_model = PatchGANDiscriminator(input_channels=4)

    # Optionally resume from checkpoints; load on CPU/target device directly.
    if load_model_g:
        unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
        print(f'loaded {load_model_g} for unet_model')
    if load_model_d:
        discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
        print(f'loaded {load_model_d} for discriminator_model')

    unet_model = unet_model.to(device)
    discriminator_model = discriminator_model.to(device)

    # Loss modules: the generator loss combines L1, perceptual, and adversarial
    # terms; the discriminator loss owns the real/fake classification objective.
    generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
                                        adversarial_weight=0.05, device=device)
    discriminator_loss_func = DiscriminatorLoss(discriminator_model)

    optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)

    best_val_loss = float('inf')

    # start_epoch is 1-based; internal epoch index is 0-based.
    for epoch in range(start_epoch - 1, num_epochs):
        # ---- Training ----
        unet_model.train()
        discriminator_model.train()
        for batch_idx, batch in enumerate(train_dataloader, start=1):
            source_images, target_images = batch
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            # Forward pass: the U-Net predicts a residual which is added to the
            # source RGB channels to produce the final output image.
            output_images = unet_model(source_images)
            output_images += source_images[:, :3, :, :]

            # Discriminator pass on detached generator output so generator
            # gradients are not polluted by the discriminator objective.
            # NOTE(review): unlike the generator losses below, this loss is NOT
            # divided by accum_iter, so its effective scale grows with
            # accumulation — confirm this is intentional.
            discriminator_loss = discriminator_loss_func(output_images.detach(), target_images, source_images)
            discriminator_loss.backward()

            # Step only every accum_iter batches (and on the final, possibly
            # short, batch of the epoch).
            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
                optimizer_d.step()
                optimizer_d.zero_grad()

            # Generator pass: scale each loss term by 1/accum_iter so the
            # accumulated gradient matches a single large-batch update.
            generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(output_images, target_images,
                                                                              source_images)
            generator_loss, l1_loss, per_loss, adv_loss = [i / accum_iter for i in
                                                           [generator_loss, l1_loss, per_loss, adv_loss]]
            generator_loss.backward()

            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
                optimizer_g.step()
                optimizer_g.zero_grad()

            print(
                f'Training Epoch [{epoch + 1}/{num_epochs}], Gen Loss: {generator_loss.item()}, L1: {l1_loss.item()}, P: {per_loss.item()}, A: {adv_loss.item()}, Dis Loss: {discriminator_loss.item()}')
            if wandb_tracking:
                wandb.log({
                    'Training Epoch': epoch + 1,
                    'Gen Loss': generator_loss.item(),
                    'L1': l1_loss.item(),
                    'P': per_loss.item(),
                    'A': adv_loss.item(),
                    'Dis Loss': discriminator_loss.item()
                })

        # Always keep the most recent weights so a crashed run can resume.
        torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
        torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')

        # ---- Validation ----
        if epoch % val_freq == 0:
            unet_model.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for val_batch in val_dataloader:
                    val_source_images, val_target_images = val_batch
                    val_source_images = val_source_images.to(device)
                    val_target_images = val_target_images.to(device)

                    val_output_images = unet_model(val_source_images)
                    # BUGFIX: apply the same residual connection as in training;
                    # previously the raw U-Net delta was scored, so validation
                    # loss (and hence best-model selection) measured the wrong
                    # quantity.
                    val_output_images += val_source_images[:, :3, :, :]

                    generator_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
                                                                  val_source_images)
                    total_val_loss += generator_loss.item()

            average_val_loss = total_val_loss / len(val_dataloader)

            print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss}')
            if wandb_tracking:
                wandb.log({
                    'Training Epoch': epoch + 1,
                    'Val Loss': average_val_loss,
                })

            # Track and save the best-so-far model by validation generator loss.
            if average_val_loss < best_val_loss:
                best_val_loss = average_val_loss
                torch.save(unet_model.state_dict(), 'best_unet_model.pth')
                torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')

    if wandb_tracking:
        wandb.finish()


if __name__ == "__main__":
    # Command-line interface for the training script.
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("--root_dir", type=str, default='data/processed/train',
                        help="Path to the training data. Note the format: To use the dataloader, the directory should be filled with folders containing image files of various ages, where the file name is the age.")
    # BUGFIX: default was 0, but train_model iterates range(start_epoch - 1, ...),
    # so 0 produced range(-1, ...) — one extra epoch labelled 'Epoch [0/N]'.
    # 1 is the correct 1-based first epoch.
    parser.add_argument("--start_epoch", type=int, default=1, help="Start epoch, if scripts is resumed")
    parser.add_argument("--num_epochs", type=int, default=2000, help="End epoch")
    parser.add_argument("--load_model_g", type=str, default='',
                        help="Path to pretrained generator model. Leave blank to train from scratch")
    parser.add_argument("--load_model_d", type=str, default='',
                        help="Path to pretrained discriminator model. Leave blank to train from scratch")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
    parser.add_argument("--accum_iter", type=int, default=3, help="Number of batches after which weights are updated")
    parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
    parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for generator")
    parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for discriminator")
    # BUGFIX: without a converter, argparse stores the raw string, so
    # '--wandb_tracking False' was truthy and silently enabled tracking.
    # The argument still takes a value, so existing invocations keep working.
    parser.add_argument("--wandb_tracking",
                        type=lambda s: s.lower() in ("1", "true", "yes"),
                        default=False,
                        help="A binary (True/False) argument for using WandB tracking or not")
    parser.add_argument("--desc", type=str, default='', help="Description for WandB")

    # Parse command-line arguments
    args = parser.parse_args()

    # Call the training function with parsed arguments
    train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
                args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d,
                args.wandb_tracking, args.desc)