File size: 160 Bytes
284ba3b
 
 
 
 
 
1
2
3
4
5
6
7

import torch

def get_label_from_output(output_tensor: torch.Tensor) -> str:
    """Map a single sample's class scores to a Thai sentiment label.

    The predicted class is the argmax over the class dimension. Class
    index 1 maps to "บวก" (Thai for "positive"); every other index maps
    to "ลบ" (Thai for "negative").

    Args:
        output_tensor: Class scores for exactly one sample, either batched
            with shape ``(1, num_classes)`` or unbatched with shape
            ``(num_classes,)``.

    Returns:
        "บวก" if the highest-scoring class index is 1, otherwise "ลบ".

    Note:
        The result of the argmax is converted with ``Tensor.item()``, which
        only works on single-element tensors, so passing scores for more
        than one sample raises an error.
    """
    # Promote an unbatched 1-D score vector to a batch of one so that the
    # argmax over dim=1 below is valid for both accepted input shapes.
    if output_tensor.dim() == 1:
        output_tensor = output_tensor.unsqueeze(0)
    pred = torch.argmax(output_tensor, dim=1).item()
    return "บวก" if pred == 1 else "ลบ"