JAMM032's picture
Upload github repo files
97fcc90 verified
raw
history blame contribute delete
252 Bytes
import torch
def accuracy(logits, y):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        logits: score tensor; the prediction for each row is the argmax along dim 1.
        y: labels compared element-wise against those predictions
            (presumably one integer class index per row — confirm with callers).

    Returns:
        Python float: the mean of the per-row match indicators.
    """
    predicted = torch.argmax(logits, dim=1)
    hits = predicted.eq(y)
    # Cast the boolean hits to float so mean() yields a fraction, then unwrap to a Python scalar.
    return hits.float().mean().item()
def topk_accuracy(logits, y, k=5):
    """Return the fraction of rows whose label is among the k highest-scoring columns.

    Args:
        logits: score tensor; the top-k candidates are taken along dim 1.
        y: labels, one per row of ``logits`` (presumably integer class
            indices — confirm with callers).
        k: how many top-scoring columns count as a hit. Values larger than the
            number of columns are clamped to that number, in which case every
            in-range label counts as a hit.

    Returns:
        Python float: the mean over rows of whether ``y`` appears in that row's top-k.
    """
    # torch.topk raises when k exceeds the size of the dimension; clamp so that,
    # e.g., the default top-5 on a 3-class problem is well defined instead of crashing.
    k = min(k, logits.size(1))
    topk = logits.topk(k, dim=1).indices
    # Broadcast each label against its row's k candidates; a row is a hit if any match.
    return (topk == y.unsqueeze(1)).any(dim=1).float().mean().item()