import torch


def get_label_from_output(output_tensor):
    """Map a model's class-score tensor for one sample to a Thai label string.

    Args:
        output_tensor: Tensor of per-class scores. Dimension 1 indexes the
            classes, e.g. shape ``(1, C)``. A 1-D tensor of shape ``(C,)`` is
            treated as a single un-batched sample.
            NOTE(review): presumably these are logits/probabilities from a
            classifier — confirm with the caller.

    Returns:
        ``"บวก"`` ("positive") when the highest-scoring class index is 1, and
        ``"ลบ"`` ("negative") for any other index.

    Raises:
        RuntimeError: propagated from ``Tensor.item()`` when the argmax yields
            more than one element, e.g. a batch with more than one sample.
    """
    # Promote a bare score vector to a batch of one, so that dim=1 is the
    # class axis. 2-D inputs are left untouched, which preserves the original
    # behaviour.
    if output_tensor.dim() == 1:
        output_tensor = output_tensor.unsqueeze(0)
    pred = torch.argmax(output_tensor, dim=1).item()
    return "บวก" if pred == 1 else "ลบ"