Spaces:
Runtime error
Runtime error
| import argparse | |
| from pathlib import Path | |
| import torch | |
| from torch import nn | |
| from torch import optim | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from torchvision import datasets | |
| from utils import get_network, epoch | |
# Seed torch's global RNG once at import time so torch-driven randomness
# (e.g. the shuffled training DataLoader) repeats identically across runs.
torch.manual_seed(seed=0)
def train_nn_network(args):
    """Train the network named ``args.net`` on MNIST, keeping the best weights.

    The model comes from ``get_network(args.net)`` and is optimised with Adam
    and an MSE criterion. After each training epoch the model is run over the
    MNIST test split. Whenever that test loss is the best seen so far (always on
    the first epoch), ``model.state_dict()`` is saved to
    ``<directory of this file>/weights/<args.net>.pth``.

    Args:
        args: Namespace-like object providing ``net`` (network name passed to
            ``get_network``), ``b`` (batch size), ``lr`` (Adam learning rate),
            ``wd`` (Adam weight decay) and ``epochs`` (number of epochs).
    """
    weights_dir = Path(__file__).parent / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)
    weights_file = weights_dir / f"{args.net}.pth"

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = get_network(args.net)
    model.to(device)

    # MNIST is downloaded into the current working directory (not next to
    # this file) if it is not already there.
    to_tensor = transforms.ToTensor()
    mnist_train = datasets.MNIST(".", train=True, download=True, transform=to_tensor)
    mnist_test = datasets.MNIST(".", train=False, download=True, transform=to_tensor)

    # Pinned host memory only helps host->GPU copies; on a CPU-only machine
    # it brings no benefit and PyTorch warns about it.
    loader_kwargs = dict(
        batch_size=args.b, num_workers=4, pin_memory=device.type == "cuda"
    )
    train_loader = DataLoader(mnist_train, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(mnist_test, shuffle=False, **loader_kwargs)

    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    # NOTE(review): MSE on MNIST implies utils.epoch builds MSE-compatible
    # targets (e.g. one-hot or reconstruction) — confirm in utils.
    criterion = nn.MSELoss()

    best_loss = None
    for i in range(1, args.epochs + 1):
        train_loss = epoch(train_loader, model, device, criterion, opt)
        # No optimizer is passed here; presumably epoch() then only evaluates
        # without updating weights — TODO confirm in utils.epoch.
        test_loss = epoch(test_loader, model, device, criterion)
        if best_loss is None or test_loss < best_loss:
            best_loss = test_loss
            torch.save(model.state_dict(), weights_file)
        print(f"Epoch: {i} | Train Loss: {train_loss:.4f} | Test Loss: {test_loss:.4f}")
def parse_args(argv=None):
    """Parse the command-line options consumed by ``train_nn_network``.

    The returned namespace exposes exactly the attributes the trainer reads:
    ``net``, ``b``, ``lr``, ``wd`` and ``epochs``. The default values are
    choices made here, not taken from elsewhere in the project.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Train a network on MNIST.")
    parser.add_argument("--net", type=str, required=True,
                        help="network name passed to utils.get_network")
    parser.add_argument("-b", dest="b", type=int, default=128, help="batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="Adam learning rate")
    parser.add_argument("--wd", type=float, default=0.0, help="Adam weight decay")
    parser.add_argument("--epochs", type=int, default=10, help="number of training epochs")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # train_nn_network requires an args namespace; calling it bare raises TypeError.
    train_nn_network(parse_args())