File size: 496 Bytes
87acbba
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
def top_k_accuracy(inp, targ, k=3):
    """Compute top-``k`` accuracy for a batch of class scores.

    A sample counts as correct when its target class is among the ``k``
    highest-scoring classes in ``inp``.

    Args:
        inp: Logits (or any per-class scores) of shape ``(batch_size, n_classes)``.
        targ: Class indices of shape ``(batch_size,)``; a ``(batch_size, 1)``
            column is accepted as well.
        k: Number of highest-scoring classes that count as a hit. Values larger
            than ``n_classes`` are clamped to ``n_classes`` (every class is then
            a hit, so the accuracy is 1.0) instead of making ``topk`` raise.

    Returns:
        A 0-dim float tensor with the fraction of correct samples. An empty
        batch yields NaN (the mean of an empty tensor).
    """
    # topk raises if k exceeds the class dimension; with k >= n_classes every
    # class is "in the top k", so clamping yields the mathematically correct answer.
    k = min(k, inp.size(1))

    # Order inside the top k is irrelevant for a membership test, so skip sorting.
    _, pred = inp.topk(k=k, dim=1, sorted=False)  # (batch_size, k)

    # A (batch_size, 1) column broadcasts against (batch_size, k); reshape
    # (unlike view) also copes with non-contiguous targets.
    hits = (pred == targ.reshape(-1, 1)).any(dim=1)  # (batch_size,)

    return hits.float().mean()