Spaces:
Sleeping
Sleeping
def top_k_accuracy(inp, targ, k=3):
    """Return the fraction of samples whose target class is among the top-``k`` scores.

    Args:
        inp: Per-class scores (e.g. logits) of shape ``(batch_size, n_classes)``.
            Only their ranking along dim 1 matters.
        targ: Integer class indices of shape ``(batch_size,)``.
        k: Number of highest-scoring classes that count as a hit. Values larger
            than ``n_classes`` are clamped, so every sample then counts as correct.

    Returns:
        A 0-dim float tensor in ``[0, 1]``. An empty batch yields ``nan``
        (mean over zero elements).
    """
    # torch.topk raises if k exceeds the size of the dimension. With k >= n_classes
    # every class is "in the top k", so clamping yields the correct answer (1.0).
    k = min(k, inp.size(1))
    _, pred = inp.topk(k=k, dim=1)  # (batch_size, k) class indices, highest score first
    # A column of targets broadcasts against all k predictions per row.
    # reshape (unlike view) also accepts non-contiguous targets.
    hit = (pred == targ.reshape(-1, 1)).any(dim=1)  # (batch_size,) bool
    return hit.float().mean()