Spaces:
Build error
Build error
| import torch | |
| from tqdm import tqdm | |
| from torch.amp import autocast | |
| def train(model, device, train_loader, optimizer, criterion, epoch, accumulation_steps=4): | |
| model.train() | |
| running_loss = 0.0 | |
| correct1 = 0 | |
| correct5 = 0 | |
| total = 0 | |
| pbar = tqdm(train_loader) | |
| for batch_idx, (inputs, targets) in enumerate(pbar): | |
| inputs, targets = inputs.to(device), targets.to(device) | |
| with autocast(device_type='cuda'): | |
| outputs = model(inputs) | |
| loss = criterion(outputs, targets) / accumulation_steps | |
| loss.backward() | |
| if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader): | |
| optimizer.step() | |
| optimizer.zero_grad() | |
| running_loss += loss.item() * accumulation_steps | |
| _, predicted = outputs.topk(5, 1, True, True) | |
| total += targets.size(0) | |
| correct1 += predicted[:, :1].eq(targets.view(-1, 1).expand_as(predicted[:, :1])).sum().item() | |
| correct5 += predicted.eq(targets.view(-1, 1).expand_as(predicted)).sum().item() | |
| pbar.set_description(desc=f'Epoch {epoch} | Loss: {running_loss / (batch_idx + 1):.4f} | Top-1 Acc: {100. * correct1 / total:.2f} | Top-5 Acc: {100. * correct5 / total:.2f}') | |
| if (batch_idx + 1) % 50 == 0: | |
| torch.cuda.empty_cache() | |
| return 100. * correct1 / total, 100. * correct5 / total, running_loss / len(train_loader) | |
| def test(model, device, test_loader, criterion): | |
| model.eval() | |
| test_loss = 0 | |
| correct1 = 0 | |
| correct5 = 0 | |
| total = 0 | |
| misclassified_images = [] | |
| misclassified_labels = [] | |
| misclassified_preds = [] | |
| with torch.no_grad(): | |
| for inputs, targets in test_loader: | |
| inputs, targets = inputs.to(device), targets.to(device) | |
| outputs = model(inputs) | |
| loss = criterion(outputs, targets) | |
| test_loss += loss.item() | |
| _, predicted = outputs.topk(5, 1, True, True) | |
| total += targets.size(0) | |
| correct1 += predicted[:, :1].eq(targets.view(-1, 1).expand_as(predicted[:, :1])).sum().item() | |
| correct5 += predicted.eq(targets.view(-1, 1).expand_as(predicted)).sum().item() | |
| # Collect misclassified samples | |
| ''' | |
| for i in range(inputs.size(0)): | |
| if targets[i] not in predicted[i, :1]: | |
| misclassified_images.append(inputs[i].cpu()) | |
| misclassified_labels.append(targets[i].cpu()) | |
| misclassified_preds.append(predicted[i, :1].cpu()) | |
| ''' | |
| test_accuracy1 = 100. * correct1 / total | |
| test_accuracy5 = 100. * correct5 / total | |
| print(f'Test Loss: {test_loss/len(test_loader):.4f}, Top-1 Accuracy: {test_accuracy1:.2f}, Top-5 Accuracy: {test_accuracy5:.2f}') | |
| return test_accuracy1, test_accuracy5, test_loss / len(test_loader), misclassified_images, misclassified_labels, misclassified_preds | |