Spaces:
Runtime error
Runtime error
| import torch | |
def train(
    model,
    device,
    train_loader,
    criterion,
    optimizer,
    epoch,
    train_loss,
    train_acc,
    mse=None,
):
    """Run one training epoch and record loss/accuracy.

    Every 10th batch a progress line is printed and that batch's loss and
    accuracy are appended to ``train_loss`` / ``train_acc``. After the last
    batch, the epoch-level loss and accuracy are appended to the same lists.

    Args:
        model: Network to train; switched to training mode here.
        device: Device each batch of images and targets is moved to.
        train_loader: Iterable of ``(images, targets)`` batches that exposes
            ``.dataset`` (used for the total sample count).
        criterion: Loss called as ``criterion(output, targets)``; must return
            a scalar because ``backward()`` is called on it.
        optimizer: Optimizer zeroed and stepped once per batch.
        epoch: Epoch number, used only in printed messages.
        train_loss: List that loss values are appended to (mutated in place).
        train_acc: List that accuracy values are appended to (mutated in place).
        mse: Unused; kept for backward compatibility with existing callers.

    Returns:
        ``(train_loss, train_acc, epoch_loss)``: the two mutated lists and the
        accumulated loss divided by the dataset size.
    """
    model.train()
    n_samples = len(train_loader.dataset)
    curr_loss = 0.0
    t_pred = 0
    for batch_idx, (images, targets) in enumerate(train_loader):
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(images)
        # Same as .squeeze() but never drops the batch dimension. Otherwise
        # a final batch of one sample turns (1, C) logits into (C,), and both
        # the criterion and torch.max(output, 1) fail.
        output = output.reshape(
            output.shape[0], *(size for size in output.shape[1:] if size != 1)
        )
        loss = criterion(output, targets)
        loss.backward()
        optimizer.step()
        # NOTE(review): the running total and the per-batch entry below are
        # divided by sample counts. That assumes criterion uses
        # reduction="sum"; with the default "mean" they are off by a factor
        # of the batch size. Confirm against the caller.
        curr_loss += loss.sum().item()
        _, preds = torch.max(output, 1)
        batch_correct = (preds == targets).sum().item()
        t_pred += batch_correct
        if batch_idx % 10 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(images),
                    n_samples,
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )
            train_loss.append(loss.sum().item() / len(images))
            # Fraction of correctly classified samples in this batch.
            train_acc.append(batch_correct / len(images))
    epoch_loss = curr_loss / n_samples
    epoch_acc = t_pred / n_samples
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    print(
        "\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            epoch_loss,
            t_pred,
            n_samples,
            100.0 * t_pred / n_samples,
        )
    )
    return train_loss, train_acc, epoch_loss
def valid(
    model, device, test_loader, criterion, epoch, valid_loss, valid_acc, mse=None
):
    """Evaluate ``model`` on ``test_loader`` without gradients and record metrics.

    Every 10th batch a progress line is printed and that batch's loss and
    accuracy are appended to ``valid_loss`` / ``valid_acc``. After the last
    batch, the epoch-level loss and accuracy are appended to the same lists.

    Args:
        model: Network to evaluate; switched to eval mode here.
        device: Device each batch of images and targets is moved to.
        test_loader: Iterable of ``(images, targets)`` batches that exposes
            ``.dataset`` (used for the total sample count).
        criterion: Loss called as ``criterion(output, targets)``.
        epoch: Epoch number, used only in printed messages.
        valid_loss: List that loss values are appended to (mutated in place).
        valid_acc: List that accuracy values are appended to (mutated in place).
        mse: Unused; kept for backward compatibility with existing callers.

    Returns:
        ``(valid_loss, valid_acc)``: the two mutated lists.
    """
    model.eval()
    n_samples = len(test_loader.dataset)
    test_loss = 0.0
    # Kept as a Python int (not a tensor) so the summary prints a plain number.
    correct = 0
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(test_loader):
            images, targets = images.to(device), targets.to(device)
            output = model(images)
            # Same as .squeeze() but never drops the batch dimension, so a
            # final batch of one sample still yields (1, C) logits.
            output = output.reshape(
                output.shape[0], *(size for size in output.shape[1:] if size != 1)
            )
            loss = criterion(output, targets)
            # NOTE(review): dividing summed losses by sample counts assumes
            # criterion uses reduction="sum". Confirm against the caller.
            test_loss += loss.sum().item()
            _, preds = torch.max(output, 1)
            batch_correct = (preds == targets).sum().item()
            correct += batch_correct
            if batch_idx % 10 == 0:
                print(
                    "Valid Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(images),
                        n_samples,
                        100.0 * batch_idx / len(test_loader),
                        loss.item(),
                    )
                )
                valid_loss.append(loss.sum().item() / len(images))
                # Fraction of correctly classified samples in this batch.
                valid_acc.append(batch_correct / len(images))
    epoch_loss = test_loss / n_samples
    epoch_acc = correct / n_samples
    valid_loss.append(epoch_loss)
    valid_acc.append(epoch_acc)
    print(
        "Valid Set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            epoch_loss,
            correct,
            n_samples,
            100.0 * correct / n_samples,
        )
    )
    return valid_loss, valid_acc