Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -58,18 +58,15 @@ def infer(model, replay_file,
|
|
| 58 |
ot_timer = ot_time_remaining[0] - ot_time_remaining
|
| 59 |
timer[is_ot] = -ot_timer # Negate to indicate overtime
|
| 60 |
|
| 61 |
-
goal_diff =
|
| 62 |
goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
|
| 63 |
|
| 64 |
-
# if nullify_goal_difference:
|
| 65 |
-
# features.scoreboard[..., 2] = 0
|
| 66 |
-
|
| 67 |
bs = 512
|
| 68 |
predictions = []
|
| 69 |
it = trange(len(serialized_states), desc="Running model")
|
| 70 |
for i in range(0, len(serialized_states), bs):
|
| 71 |
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
| 72 |
-
|
| 73 |
if nullify_goal_difference:
|
| 74 |
batch[:, SB_BLUE_SCORE] = 0
|
| 75 |
batch[:, SB_ORANGE_SCORE] = 0
|
|
|
|
| 58 |
ot_timer = ot_time_remaining[0] - ot_time_remaining
|
| 59 |
timer[is_ot] = -ot_timer # Negate to indicate overtime
|
| 60 |
|
| 61 |
+
goal_diff = serialized_scoreboards[:, SB_BLUE_SCORE] - serialized_scoreboards[:, SB_ORANGE_SCORE]
|
| 62 |
goal_diff_diff = goal_diff.diff(prepend=torch.Tensor([0]))
|
| 63 |
|
|
|
|
|
|
|
|
|
|
| 64 |
bs = 512
|
| 65 |
predictions = []
|
| 66 |
it = trange(len(serialized_states), desc="Running model")
|
| 67 |
for i in range(0, len(serialized_states), bs):
|
| 68 |
batch = (serialized_states[i:i + bs].clone().to(DEVICE),
|
| 69 |
+
serialized_scoreboards[i:i + bs].clone().to(DEVICE))
|
| 70 |
if nullify_goal_difference:
|
| 71 |
batch[:, SB_BLUE_SCORE] = 0
|
| 72 |
batch[:, SB_ORANGE_SCORE] = 0
|