File size: 401 Bytes
6733993 2eebda5 6733993 2eebda5 6733993 2eebda5 893cb78 6733993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch
from device import device
# Load the complete pickled model object (not just a state_dict) once at import time.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, which can
# execute code -- "model.pth" must come from a trusted source.
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint opened on
# a CPU-only machine) onto `device`, so loading no longer depends on where it was saved.
model = torch.load("model.pth", map_location=device, weights_only=False).to(device)
# This module only runs inference: put the model in evaluation mode so layers with
# train/eval differences (e.g. dropout, batch norm), if the model has any, behave
# deterministically instead of in training mode.
model.eval()
def run(test):
    """Run the module-level model on a single sample and return its output as a number.

    Args:
        test: The feature values of one sample (e.g. a list of numbers). It is wrapped
            in an outer list, so the model receives a batch containing one row.

    Returns:
        The squeezed model output converted with ``Tensor.item()``. This requires the
        output to contain exactly one element; otherwise ``item()`` raises.
    """
    with torch.no_grad():
        # Build a float32 batch of size 1 and move it next to the model.
        batch = torch.tensor([test], dtype=torch.float)
        batch = batch.to(device)

        output: torch.Tensor = model(batch)

        # Drop the size-1 dimensions, then unwrap the remaining single value.
        single = output.squeeze()
        return single.item()
if __name__ == '__main__':
    # Read exactly two whitespace-separated feature values from stdin and print the
    # model's prediction. Values are parsed as floats because run() feeds them to the
    # model as a float32 tensor; integer input such as "3 4" is still accepted, while
    # fractional input such as "1.5 2" no longer raises ValueError.
    x, y = map(float, input().split())
    print(run([x, y]))
|