Flamekizer11's picture
Upload 27 files
64d0ccc verified
raw
history blame contribute delete
391 Bytes
# Utility functions for training machine learning models using PyTorch and calculating accuracy.
import torch
def get_device():
    """Pick the torch device to run on.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` reports a
        usable CUDA runtime, otherwise ``torch.device("cpu")``.
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
def accuracy(outputs, labels):
    """Return the fraction of argmax predictions in ``outputs`` that equal ``labels``.

    Predictions are taken as ``outputs.argmax(dim=1)``, so ``outputs`` is
    expected to carry class scores along dimension 1 (e.g. ``(N, C)`` logits,
    or ``(N, C, ...)`` for dense per-element predictions).

    Args:
        outputs: Score tensor with classes along dim 1.
        labels: Integer class labels with the same number of elements as the
            predictions, e.g. ``(N,)`` or ``(N, 1)`` for ``(N, C)`` outputs.

    Returns:
        Accuracy as a Python float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: If there are no predictions (empty batch).
        RuntimeError: If ``labels`` cannot be reshaped to the prediction shape.
    """
    preds = outputs.argmax(dim=1)
    # Align labels with the predictions: an (N, 1) label tensor would otherwise
    # broadcast against (N,) predictions into an (N, N) comparison.
    targets = labels.reshape(preds.shape)
    correct = (preds == targets).sum().item()
    # Divide by the number of compared elements, not the batch size, so dense
    # (e.g. per-pixel) predictions still yield a value in [0, 1]. For the usual
    # (N, C) outputs this equals labels.size(0), as before.
    return correct / preds.numel()