Spaces:
Sleeping
Sleeping
File size: 796 Bytes
21ff05c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from torch import nn
import torch
# Layer dimensions (input, hidden, output); presumably passed to SimpleNN
# by the caller that builds the model — confirm at the call site.
input_size, hidden_size, output_size = 4, 64, 5
class SimpleNN(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features per input vector (last dimension of x).
        hidden_size: Width of the hidden layer.
        output_size: Number of raw output values produced per input vector.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        # Keep these attribute names and this registration order: they define
        # the state_dict keys (fc1.*, fc2.*) and the order in which parameter
        # initialisation draws from the RNG.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map x (last dim ``input_size``) to unnormalised outputs (last dim ``output_size``)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
def predict(model, input):
    """Return the index of the highest-scoring output for a single sample.

    Args:
        model: An ``nn.Module``. It is assumed to map a ``(1, n_features)``
            batch to ``(1, n_outputs)`` scores; TODO confirm for other models.
        input: One sample, e.g. a list of floats, a 1-D numpy array or a 1-D
            tensor. A batch dimension is added here, so batched (2-D) input is
            not supported.

    Returns:
        The predicted class index as a Python ``int``.

    Note:
        This switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    # Build the input on the same device as the model's parameters so that
    # GPU-resident models work. A parameter-free module falls back to the
    # default device.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    # as_tensor accepts tensors/arrays without the copy-construct warning that
    # torch.tensor() raises for tensor inputs; unsqueeze(0) adds the batch dim.
    batch = torch.as_tensor(input, dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # Softmax is monotonic, so argmax over the raw logits gives the same class.
    # It also avoids float32 saturation turning distinct logits into ties.
    return torch.argmax(logits, dim=1).item()
|