Piraloco commited on
Commit
54ed43c
·
1 Parent(s): 34b899d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -1,8 +1,24 @@
1
  # Import necessary libraries
2
  import torch
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  # Load the model
5
- model = torch.load("model.pth")
 
6
 
7
  agents = [
8
  'Brimstone',
@@ -69,6 +85,7 @@ ranks = [
69
  'Radiant',
70
  ]
71
 
 
72
  def preprocess_data(data):
73
  # Preprocess the data (replace this with your specific preprocessing steps)
74
  data[0] = ranks.index(data[0])
@@ -87,10 +104,10 @@ def make_prediction(rank,map,agent_picks):
87
  processed_data = preprocess_data(data)
88
 
89
  # Feed the data to the model
90
- output = model(processed_data)
91
 
92
  # Post-process the output (replace this with your specific post-processing steps)
93
- prediction = model(data)
94
  prediction = prediction.item()
95
  prediction = 0 if prediction < 0 else prediction
96
 
 
1
  # Import necessary libraries
2
  import torch
3
 
4
+ class Net(neural_network_module.Module):
5
+ def __init__(self):
6
+ super(Net, self).__init__()
7
+ self.fc1 = neural_network_module.Linear(len(data[0][0]), 128)
8
+ self.fc2 = neural_network_module.Linear(128, 64)
9
+ self.fc3 = neural_network_module.Linear(64, 1)
10
+
11
+ def forward(self, x):
12
+ x = torch.relu(self.fc1(x))
13
+ x = torch.relu(self.fc2(x))
14
+ x = self.fc3(x)
15
+ return x
16
+
17
+ # Create an instance of the network
18
+
19
  # Load the model
20
+ model = Net()
21
+ model.save_state_dict(torch.load('model.pth'))
22
 
23
  agents = [
24
  'Brimstone',
 
85
  'Radiant',
86
  ]
87
 
88
+
89
  def preprocess_data(data):
90
  # Preprocess the data (replace this with your specific preprocessing steps)
91
  data[0] = ranks.index(data[0])
 
104
  processed_data = preprocess_data(data)
105
 
106
  # Feed the data to the model
107
+ prediction = model(processed_data)
108
 
109
  # Post-process the output (replace this with your specific post-processing steps)
110
+ #prediction = model(data)
111
  prediction = prediction.item()
112
  prediction = 0 if prediction < 0 else prediction
113