def top_k_accuracy(inp, targ, k=3):
    """Return the fraction of samples whose target class is among the top-``k`` scores.

    Args:
        inp: Score/logit tensor of shape ``(batch_size, n_classes)``.
        targ: Integer class-index tensor with ``batch_size`` elements
            (e.g. shape ``(batch_size,)``).
        k: Number of highest-scoring classes to consider per sample. Values
            larger than ``n_classes`` are clamped to ``n_classes`` (every
            class then counts as a hit).

    Returns:
        A 0-dim float tensor in ``[0, 1]``. An empty batch yields ``nan``
        (mean over zero elements).
    """
    # Clamp so that k > n_classes counts every class instead of raising in topk.
    k = min(k, inp.size(1))
    # Only membership matters, so the top-k indices need not be sorted.
    _, pred = inp.topk(k=k, dim=1, sorted=False)  # (batch_size, k)
    # reshape works regardless of memory layout; broadcasting the (batch_size, 1)
    # targets against (batch_size, k) predictions replaces an explicit expand_as.
    correct = (pred == targ.reshape(-1, 1)).any(dim=1)  # (batch_size,)
    return correct.float().mean()