piralocoplasticone commited on
Commit
9db4f63
·
1 Parent(s): 1795837

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class NeuralNetwork(nn.Module):
7
+ def __init__(self, input_size, hidden_size, output_size):
8
+ """
9
+ Initializes a neural network model.
10
+
11
+ Args:
12
+ input_size (int): The size of the input layer.
13
+ hidden_size (int): The size of the hidden layer.
14
+ output_size (int): The size of the output layer.
15
+ """
16
+ super(NeuralNetwork, self).__init__()
17
+ self.fc1 = nn.Linear(input_size, hidden_size)
18
+ self.relu = nn.ReLU()
19
+ self.fc2 = nn.Linear(hidden_size, output_size)
20
+
21
+ def forward(self, input_image):
22
+ """
23
+ Performs a forward pass through the neural network.
24
+
25
+ Args:
26
+ input_image (torch.Tensor): The input image tensor.
27
+
28
+ Returns:
29
+ torch.Tensor: The output tensor of the neural network.
30
+ """
31
+ input_image = self.relu(self.fc1(input_image))
32
+ input_image = self.fc2(input_image)
33
+ return input_image
34
+
35
+ # Load the pre-trained model
36
+ model = NeuralNetwork(14, 64, 2)
37
+ model.load_state_dict(torch.load("model.pth"))
38
+
39
+ # List of all Valorant agents
40
+ maps = [
41
+ 'Ascent',
42
+ 'Bind',
43
+ 'Breeze',
44
+ 'Fracture',
45
+ 'Haven',
46
+ 'Icebox',
47
+ 'Lotus',
48
+ 'Pearl',
49
+ 'Split',
50
+ 'Sunset',
51
+ ]
52
+
53
+ agents = [
54
+ 'Brimstone',
55
+ 'Viper',
56
+ 'Omen',
57
+ 'Killjoy',
58
+ 'Cypher',
59
+ 'Sova',
60
+ 'Sage',
61
+ 'Phoenix',
62
+ 'Jett',
63
+ 'Reyna',
64
+ 'Raze',
65
+ 'Breach',
66
+ 'Skye',
67
+ 'Yoru',
68
+ 'Astra',
69
+ 'Kayo',
70
+ 'Chamber',
71
+ 'Neon',
72
+ 'Fade',
73
+ 'Harbor',
74
+ 'Gekko',
75
+ 'Deadlock',
76
+ 'Iso',
77
+ ]
78
+
79
+
80
+
81
+ # Define the prediction function
82
+ def predict(*args):
83
+ def test_convert(test):
84
+ test[3] = maps.index(test[3])
85
+ test[4:9] = [agents.index(index) for index in test[4:9]]
86
+ test[9:14] = [agents.index(index) for index in test[9:14]]
87
+
88
+ return test
89
+
90
+ data = list(args)
91
+ data = test_convert(data)
92
+ data = torch.tensor(data, dtype=torch.float32)
93
+
94
+ outputs = model(data)
95
+ highest_score = (torch.max(outputs), torch.argmax(outputs).item())
96
+
97
+ if highest_score[0] < 13:
98
+ outputs[highest_score[1]] = 13
99
+ else:
100
+ if outputs[1-highest_score[1]] < highest_score[0] - 2:
101
+ outputs[1-highest_score[1]] = highest_score[0] - 2
102
+
103
+ score_a = round(outputs[0].item())
104
+ score_b = round(outputs[1].item())
105
+
106
+ return f'Predicted score: {score_a} - {score_b}'
107
+
108
+
109
+
110
+ # Define the output component
111
+ with gr.Blocks() as demo:
112
+ # Frame for date and map
113
+
114
+ with gr.Row():
115
+ with gr.Column(min_width="0px", scale=1):
116
+ year_input = gr.Number(label="Year", value=23)
117
+ with gr.Column(min_width="0px", scale=1):
118
+ month_input = gr.Number(label="Month", value=2)
119
+ with gr.Column(min_width="0px", scale=1):
120
+ day_input = gr.Number(label="Day", value=23)
121
+ with gr.Column(scale=3):
122
+ map_input = gr.Dropdown(maps, label="Map", value='Ascent')
123
+ with gr.Column(scale=3):
124
+ pass
125
+
126
+ # Frames for agents' dropdowns
127
+ with gr.Row():
128
+ with gr.Column():
129
+ # Team 1 agent dropdowns
130
+ team1_agent1_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 1", value='Brimstone')
131
+ team1_agent2_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 2", value='Viper')
132
+ team1_agent3_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 3", value='Omen')
133
+ team1_agent4_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 4", value='Killjoy')
134
+ team1_agent5_input = gr.Dropdown(choices=agents, label="Team 1 - Agent 5", value='Cypher')
135
+
136
+ with gr.Column():
137
+ # Team 2 agent dropdowns
138
+ team2_agent1_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 1", value='Sova')
139
+ team2_agent2_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 2", value='Sage')
140
+ team2_agent3_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 3", value='Phoenix')
141
+ team2_agent4_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 4", value='Jett')
142
+ team2_agent5_input = gr.Dropdown(choices=agents, label="Team 2 - Agent 5", value='Reyna')
143
+ # ... add all dropdowns for Team 2
144
+
145
+ with gr.Column():
146
+
147
+ translate_btn = gr.Button(value="Translate")
148
+ # Add any outputs you have
149
+ score_difference_output = gr.Textbox(label="Score Difference")
150
+ translate_btn.click(fn=predict, inputs=[year_input, month_input, day_input, map_input, team1_agent1_input, team1_agent2_input, team1_agent3_input, team1_agent4_input, team1_agent5_input, team2_agent1_input, team2_agent2_input, team2_agent3_input, team2_agent4_input, team2_agent5_input], outputs=score_difference_output)
151
+
152
+ print('Lauching interface!')
153
+ demo.launch()