File size: 401 Bytes
6733993
2eebda5
6733993
2eebda5
6733993
 
 
2eebda5
893cb78
 
6733993
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
from device import device

# Load the complete pickled model object (not just a state_dict).
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
# which can execute code -- only ever load "model.pth" from a trusted source.
# map_location remaps tensors that were saved on another device (e.g. CUDA)
# so the file also loads on a machine without that device; .to(device) then
# places the module where run() sends its inputs.
model = torch.load("model.pth", map_location=device, weights_only=False).to(device)
# Switch to inference behaviour (dropout disabled, batch-norm uses running
# stats); a loaded module keeps whatever train/eval mode it was saved in.
model.eval()

def run(test):
    """Run the loaded model on one input sample and return its scalar output.

    Args:
        test: The feature values of a single sample (the script passes
            ``[x, y]``). It is wrapped in an outer list so the model sees a
            batch containing exactly one sample.

    Returns:
        The model output converted to a plain Python number via ``.item()``;
        this raises if the squeezed output holds more than one element.
    """
    batch = torch.tensor([test], dtype=torch.float)
    # Gradients are never needed for prediction; skip autograd bookkeeping.
    with torch.no_grad():
        output: torch.Tensor = model(batch.to(device))
    return output.squeeze().item()

if __name__ == '__main__':
    # Read exactly two whitespace-separated numbers from one line, e.g. "3 4".
    # They are parsed as floats because run() feeds them to the model as a
    # float tensor anyway; integer input yields the same tensor as before,
    # and decimal input such as "1.5 2" is no longer rejected.
    fields = input().split()
    try:
        x, y = map(float, fields)
    except ValueError:
        # Covers both a wrong number of fields and non-numeric text.
        raise SystemExit(
            f"expected two numbers separated by whitespace, got {fields!r}"
        ) from None
    print(run([x, y]))