Spaces:
Running
Running
| import os | |
| import gdown | |
| import zipfile | |
| import shutil | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.datasets as datasets | |
| import torchvision.transforms as transforms | |
| from torch.utils.data import DataLoader | |
| import time | |
| import modules.model as model | |
# Fetch and unpack the data archive from Google Drive unless a 'celeba/'
# folder already exists in the working directory (the archive is presumably
# what provides that folder — confirm against the archive contents).
if not os.path.exists('celeba/'):
    url = 'https://drive.google.com/file/d/13vkq4tFCPE8O78KTj84HHM6kBnYkt8gP/view?usp=sharing'
    output = 'download.zip'
    gdown.download(url, output, fuzzy=True)
    with zipfile.ZipFile(output, 'r') as zip_ref:
        zip_ref.extractall()
    os.remove(output)
    # Archives zipped on macOS may carry a '__MACOSX' metadata folder. Remove
    # it when present, but do not crash when the archive had none.
    if os.path.isdir('__MACOSX'):
        shutil.rmtree('__MACOSX')
# Use the first available accelerator in priority order (Apple MPS, then
# CUDA), fall back to the CPU, and make it the default tensor device.
_device_candidates = (
    ('mps', 'Apple Silicon GPU', torch.backends.mps.is_available),
    ('cuda', 'CUDA', torch.cuda.is_available),
)
_device_type, device_name = next(
    ((kind, label) for kind, label, is_available in _device_candidates if is_available()),
    ('cpu', 'CPU'),
)
device = torch.device(_device_type)
torch.set_default_device(device)
print(f'\nDevice: {device_name}')
# Each image is resized to imsize x imsize and converted to one grayscale
# channel. It is then cut into five crops (FiveCrop: four corners + centre)
# whose side is 80% of imsize, i.e. 128 x 128. A sample is therefore a stacked
# tensor of its five crops.
imsize = int(128/0.8)
batch_size = 10


def _stack_crops(crops):
    """Convert every PIL crop to a tensor and stack them along a new first axis."""
    to_tensor = transforms.ToTensor()
    return torch.stack([to_tensor(crop) for crop in crops])


_crop_size = int(imsize*0.8)
fivecrop_transform = transforms.Compose([
    transforms.Resize([imsize, imsize]),
    transforms.Grayscale(1),
    transforms.FiveCrop(_crop_size),
    transforms.Lambda(_stack_crops),
    # mean 0 / std 1: subtracting 0 and dividing by 1 leaves values unchanged
    transforms.Normalize(0, 1),
])

# split='all' uses every CelebA image; target_type='attr' yields the
# attribute vector as the target.
train_dataset = datasets.CelebA(
    root='',
    split='all',
    target_type='attr',
    transform=fivecrop_transform,
    download=True,
)

# The shuffling generator is created on `device`, the device that was made
# the global default for new tensors.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device),
)
# Column of the CelebA attribute vector used as the binary label ("Male").
factor = 20

# Seed the global RNG, then build the network, loss, optimiser and LR schedule.
torch.manual_seed(2687)
resnet = model.resnetModel_128()
criterion = nn.CrossEntropyLoss()
_sgd_hparams = {'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0.001}
optimizer = torch.optim.SGD(resnet.parameters(), **_sgd_hparams)
# StepLR(step_size=1, gamma=0.1): every scheduler.step() call multiplies the
# learning rate by 0.1.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1, gamma=0.1)
def mins_to_hours(mins):
    """Split a duration given in minutes into whole hours and leftover minutes.

    Args:
        mins: Duration in minutes (int or float).

    Returns:
        A ``(hours, rem_mins)`` tuple with ``hours`` an int and
        ``hours * 60 + rem_mins == mins``.
    """
    # divmod floors both parts consistently. The previous int(mins / 60)
    # truncated toward zero (wrong for negatives) and went through float
    # division (lossy for very large ints).
    hours, rem_mins = divmod(mins, 60)
    return int(hours), rem_mins
# Train for `epochs` epochs. Every 10 batches, log progress and a rough time
# estimate; after every epoch, step the LR scheduler and checkpoint the weights.
epochs = 2
train_losses = []    # per-batch training loss, across all epochs
train_accuracy = []  # per-batch training accuracy, across all epochs
n_batches = len(train_loader)
for i in range(epochs):
    # Measure wall-clock time from the start of the epoch, so data loading
    # (done by the loader between iterations) is included in the estimates.
    epoch_start = time.perf_counter()
    for j, (X_train, y_train) in enumerate(train_loader):
        X_train = X_train.to(device)
        # Keep only the target attribute column, on the same device as the model.
        y_train = y_train[:, factor].to(device)
        # X_train is (batch, crops, channels, height, width). Fold the crops into
        # the batch dimension for the forward pass, then average each image's
        # logits over its crops.
        bs, ncrops, c, h, w = X_train.size()
        # Call the module (not .forward) so registered hooks run.
        y_pred_crops = resnet(X_train.view(-1, c, h, w))
        y_pred = y_pred_crops.view(bs, ncrops, -1).mean(1)
        loss = criterion(y_pred, y_train)
        predicted = y_pred.argmax(dim=1)
        train_batch_accuracy = (predicted == y_train).sum()/len(X_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        train_accuracy.append(train_batch_accuracy.item())
        if (j+1) % 10 == 0:
            avg_batch_time = (time.perf_counter() - epoch_start)/(j+1)
            batches_remaining = n_batches-(j+1)
            epoch_mins_remaining = round(batches_remaining*avg_batch_time/60)
            epoch_time_remaining = mins_to_hours(epoch_mins_remaining)
            full_epoch = avg_batch_time*n_batches
            epochs_remaining = epochs-(i+1)
            # Time left in this epoch plus all the epochs after it.
            rem_epoch_mins_remaining = epoch_mins_remaining+round(full_epoch*epochs_remaining/60)
            rem_epoch_time_remaining = mins_to_hours(rem_epoch_mins_remaining)
            print(f'\nEpoch: {i+1}/{epochs} | Train Batch: {j+1}/{n_batches}')
            print(f'Current epoch: {epoch_time_remaining[0]} hours {epoch_time_remaining[1]} minutes')
            print(f'Remaining epochs: {rem_epoch_time_remaining[0]} hours {rem_epoch_time_remaining[1]} minutes')
            # Print plain Python floats rather than tensor reprs.
            print(f'Train Loss: {train_losses[-1]}')
            print(f'Train Accuracy: {train_accuracy[-1]}')
    scheduler.step()
    trained_model_name = resnet.model_name + '_epoch_' + str(i+1) + '.pt'
    torch.save(
        resnet.state_dict(),
        trained_model_name
    )