Denis Lebedev
Commit the app
87acbba unverified
raw
history blame contribute delete
496 Bytes
def top_k_accuracy(inp, targ, k=3):
    """Return the fraction of samples whose target is among the top-k predictions.

    Args:
        inp: Logits of shape (batch_size, n_classes).
        targ: Class indices of shape (batch_size,).
        k: Number of highest-scoring classes per sample that count as a hit.
            Values larger than n_classes are clamped to n_classes, since
            ``topk`` raises when asked for more entries than the dimension has.

    Returns:
        A scalar float tensor holding the mean top-k accuracy over the batch.
    """
    # Clamp so that an over-large k degrades gracefully instead of raising.
    k = min(k, inp.shape[1])
    # Indices of the k largest logits per sample: (batch_size, k).
    _, top_idx = inp.topk(k=k, dim=1)
    # A (batch_size, 1) column broadcasts against (batch_size, k); a sample is
    # a hit when any of its k predicted indices equals its target.
    hits = (top_idx == targ.reshape(-1, 1)).any(dim=1)
    return hits.float().mean()