Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
class MyConvBlock(nn.Module):
    """Convolutional block: Conv2d -> BatchNorm2d -> ReLU -> Dropout -> MaxPool2d.

    The convolution uses a 3x3 kernel with stride 1 and padding 1, and the
    final max-pool uses a 2x2 window with stride 2.

    Args:
        in_ch: Number of input channels of the convolution.
        out_ch: Number of output channels of the convolution; also the
            number of features given to the batch-norm layer.
        dropout_p: Probability handed to ``nn.Dropout``.
    """

    def __init__(self, in_ch, out_ch, dropout_p):
        super().__init__()
        kernel_size = 3
        # Layers are listed in application order; their positions define the
        # sub-module names ("model.0", "model.1", ...) used in the state dict.
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.MaxPool2d(2, stride=2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        return self.model(x)
def get_batch_accuracy(output, y, N):
    """Return this batch's count of correct predictions divided by ``N``.

    Args:
        output: Tensor of scores; the predicted class is the arg-max along
            dimension 1.
        y: Tensor of target labels; it is viewed with the same shape as the
            prediction tensor before comparison.
        N: Divisor applied to the number of correct predictions.

    Returns:
        The number of positions where prediction and target match,
        divided by ``N``.
    """
    predicted = output.argmax(dim=1, keepdim=True)
    targets = y.view_as(predicted)
    num_correct = (predicted == targets).sum().item()
    return num_correct / N
def train(model, train_loader, train_N, random_trans, optimizer, loss_function):
    """Run one pass over ``train_loader``, updating ``model``, and print totals.

    For every ``(x, y)`` batch the input is passed through ``random_trans``,
    fed to the model, and one optimizer step is taken on the resulting loss.
    The printed loss is the sum of the per-batch loss values; the printed
    accuracy is the sum of the per-batch ``get_batch_accuracy`` results.

    Args:
        model: Module to train; switched to training mode via ``model.train()``.
        train_loader: Iterable yielding ``(x, y)`` batches.
        train_N: Divisor forwarded to ``get_batch_accuracy`` for each batch.
        random_trans: Callable applied to each input batch before the model.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_function: Callable ``(output, y)`` returning the batch loss.
    """
    model.train()
    total_loss = 0
    total_accuracy = 0

    for inputs, targets in train_loader:
        augmented = random_trans(inputs)
        predictions = model(augmented)

        optimizer.zero_grad()
        step_loss = loss_function(predictions, targets)
        step_loss.backward()
        optimizer.step()

        total_loss += step_loss.item()
        total_accuracy += get_batch_accuracy(predictions, targets, train_N)

    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(total_loss, total_accuracy))
def validate(model, valid_loader, valid_N, loss_function):
    """Evaluate ``model`` over ``valid_loader`` and print summed loss/accuracy.

    The model is put in evaluation mode and every batch is processed with
    gradient tracking disabled. The printed loss is the sum of the per-batch
    loss values; the printed accuracy is the sum of the per-batch
    ``get_batch_accuracy`` results.

    Args:
        model: Module to evaluate; switched to eval mode via ``model.eval()``.
        valid_loader: Iterable yielding ``(x, y)`` batches.
        valid_N: Divisor forwarded to ``get_batch_accuracy`` for each batch.
        loss_function: Callable ``(output, y)`` returning the batch loss.
    """
    model.eval()
    total_loss = 0
    total_accuracy = 0

    with torch.no_grad():
        for inputs, targets in valid_loader:
            predictions = model(inputs)
            step_loss = loss_function(predictions, targets)
            total_loss += step_loss.item()
            total_accuracy += get_batch_accuracy(predictions, targets, valid_N)

    print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(total_loss, total_accuracy))