WRU-Net / model /metric.py
HirraA's picture
Upload 19 files
006869b verified
raw
history blame contribute delete
745 Bytes
import torch
from sklearn.metrics import f1_score as f1
from sklearn.metrics import precision_score, recall_score
def precision(output, target):
    """Return the precision of the arg-max predictions against ``target``.

    Args:
        output: Tensor of per-class scores. The class axis is dimension 1
            (presumably ``(batch, classes, ...)`` -- confirm against caller).
        target: Tensor of ground-truth labels whose first dimension must
            have the same length as the first dimension of ``output``.

    Returns:
        The value of ``sklearn.metrics.precision_score`` computed on the
        flattened labels and predictions with its default arguments.

    Raises:
        ValueError: If ``output`` and ``target`` disagree on batch size.
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        # Explicit raise rather than ``assert`` so that the check is not
        # stripped when Python runs with -O.
        if pred.shape[0] != len(target):
            raise ValueError(
                "batch size mismatch: output has "
                f"{pred.shape[0]} samples, target has {len(target)}"
            )
        # reshape(-1) instead of view(-1): view() raises on non-contiguous
        # tensors, reshape() copies only when it has to.
        y_true = target.reshape(-1).cpu()
        y_pred = pred.reshape(-1).cpu()
        return precision_score(y_true, y_pred)
def recall(output, target):
    """Return the recall of the arg-max predictions against ``target``.

    Args:
        output: Tensor of per-class scores. The class axis is dimension 1
            (presumably ``(batch, classes, ...)`` -- confirm against caller).
        target: Tensor of ground-truth labels whose first dimension must
            have the same length as the first dimension of ``output``.

    Returns:
        The value of ``sklearn.metrics.recall_score`` computed on the
        flattened labels and predictions with its default arguments.

    Raises:
        ValueError: If ``output`` and ``target`` disagree on batch size.
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        # Explicit raise rather than ``assert`` so that the check is not
        # stripped when Python runs with -O.
        if pred.shape[0] != len(target):
            raise ValueError(
                "batch size mismatch: output has "
                f"{pred.shape[0]} samples, target has {len(target)}"
            )
        # reshape(-1) instead of view(-1): view() raises on non-contiguous
        # tensors, reshape() copies only when it has to.
        y_true = target.reshape(-1).cpu()
        y_pred = pred.reshape(-1).cpu()
        return recall_score(y_true, y_pred)
def f1_score(output, target):
    """Return the F1 score of the arg-max predictions against ``target``.

    Args:
        output: Tensor of per-class scores. The class axis is dimension 1
            (presumably ``(batch, classes, ...)`` -- confirm against caller).
        target: Tensor of ground-truth labels whose first dimension must
            have the same length as the first dimension of ``output``.

    Returns:
        The value of ``sklearn.metrics.f1_score`` (imported in this module
        as ``f1``) computed on the flattened labels and predictions with
        its default arguments.

    Raises:
        ValueError: If ``output`` and ``target`` disagree on batch size.
    """
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        # Explicit raise rather than ``assert`` so that the check is not
        # stripped when Python runs with -O.
        if pred.shape[0] != len(target):
            raise ValueError(
                "batch size mismatch: output has "
                f"{pred.shape[0]} samples, target has {len(target)}"
            )
        # reshape(-1) instead of view(-1): view() raises on non-contiguous
        # tensors, reshape() copies only when it has to.
        y_true = target.reshape(-1).cpu()
        y_pred = pred.reshape(-1).cpu()
        return f1(y_true, y_pred)