Spaces:
Runtime error
Runtime error
add error screen
Browse files
- connectfour/app.py +19 -1
- connectfour/training/train.py +11 -3
- error-screen.npy +3 -0
connectfour/app.py
CHANGED
|
@@ -14,6 +14,7 @@ from connectfour.training.wrappers import Connect4Env
|
|
| 14 |
POLICY_ID = "learned_v5"
|
| 15 |
|
| 16 |
# poetry export -f requirements.txt --output requirements.txt --without-hashes
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
class Connect4:
|
|
@@ -55,6 +56,9 @@ class Connect4:
|
|
| 55 |
self.algo.restore(checkpoint)
|
| 56 |
|
| 57 |
def play(self, action=None):
|
|
|
|
|
|
|
|
|
|
| 58 |
if self.human != self.player_id:
|
| 59 |
action = self.algo.compute_single_action(
|
| 60 |
self.obs[self.player_id], policy_id=POLICY_ID
|
|
@@ -73,8 +77,11 @@ class Connect4:
|
|
| 73 |
|
| 74 |
@property
|
| 75 |
def render_and_state(self):
|
| 76 |
-
end_message = "End of the game"
|
| 77 |
if self.done:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
if self.reward[self.human] > 0:
|
| 79 |
end_message += ": You WIN !!"
|
| 80 |
elif self.reward[self.human] < 0:
|
|
@@ -83,6 +90,12 @@ class Connect4:
|
|
| 83 |
|
| 84 |
return self.env.render(), "Game On"
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
@property
|
| 87 |
def player_id(self):
|
| 88 |
return list(self.obs.keys())[0]
|
|
@@ -91,6 +104,11 @@ class Connect4:
|
|
| 91 |
def legal_moves(self):
|
| 92 |
return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
demo = gr.Blocks()
|
| 96 |
|
|
|
|
| 14 |
POLICY_ID = "learned_v5"
|
| 15 |
|
| 16 |
# poetry export -f requirements.txt --output requirements.txt --without-hashes
|
| 17 |
+
# gradio connectfour/app.py
|
| 18 |
|
| 19 |
|
| 20 |
class Connect4:
|
|
|
|
| 56 |
self.algo.restore(checkpoint)
|
| 57 |
|
| 58 |
def play(self, action=None):
|
| 59 |
+
if self.has_erroneous_state():
|
| 60 |
+
return self.blue_screen()
|
| 61 |
+
|
| 62 |
if self.human != self.player_id:
|
| 63 |
action = self.algo.compute_single_action(
|
| 64 |
self.obs[self.player_id], policy_id=POLICY_ID
|
|
|
|
| 77 |
|
| 78 |
@property
|
| 79 |
def render_and_state(self):
|
|
|
|
| 80 |
if self.done:
|
| 81 |
+
if hasattr(self, "reward") and self.human not in self.reward:
|
| 82 |
+
return self.blue_screen()
|
| 83 |
+
|
| 84 |
+
end_message = "End of the game"
|
| 85 |
if self.reward[self.human] > 0:
|
| 86 |
end_message += ": You WIN !!"
|
| 87 |
elif self.reward[self.human] < 0:
|
|
|
|
| 90 |
|
| 91 |
return self.env.render(), "Game On"
|
| 92 |
|
| 93 |
+
def blue_screen(self):
|
| 94 |
+
with open("error-screen.npy", "rb") as f:
|
| 95 |
+
error_screen = np.load(f)
|
| 96 |
+
|
| 97 |
+
return (error_screen, "Restart the Game")
|
| 98 |
+
|
| 99 |
@property
|
| 100 |
def player_id(self):
|
| 101 |
return list(self.obs.keys())[0]
|
|
|
|
| 104 |
def legal_moves(self):
|
| 105 |
return np.arange(7)[self.obs[self.player_id]["action_mask"] == 1]
|
| 106 |
|
| 107 |
+
def has_erroneous_state(self):
|
| 108 |
+
if len(list(self.obs.keys())) == 0:
|
| 109 |
+
return True
|
| 110 |
+
return False
|
| 111 |
+
|
| 112 |
|
| 113 |
demo = gr.Blocks()
|
| 114 |
|
connectfour/training/train.py
CHANGED
|
@@ -27,7 +27,7 @@ def get_cli_args():
|
|
| 27 |
Create CLI parser and return parsed arguments
|
| 28 |
|
| 29 |
python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
|
| 30 |
-
python connectfour/training/train.py --num-gpus 1 --stop-iters
|
| 31 |
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
|
| 32 |
"""
|
| 33 |
parser = argparse.ArgumentParser()
|
|
@@ -68,7 +68,10 @@ if __name__ == "__main__":
|
|
| 68 |
args = get_cli_args()
|
| 69 |
|
| 70 |
ray.init(
|
| 71 |
-
num_cpus=args.num_cpus or None,
|
|
|
|
|
|
|
|
|
|
| 72 |
)
|
| 73 |
|
| 74 |
# define how to make the environment
|
|
@@ -126,6 +129,8 @@ if __name__ == "__main__":
|
|
| 126 |
"win_rate": "win_rate",
|
| 127 |
"league_size": "league_size",
|
| 128 |
},
|
|
|
|
|
|
|
| 129 |
sort_by_metric=True,
|
| 130 |
),
|
| 131 |
checkpoint_config=air.CheckpointConfig(
|
|
@@ -135,6 +140,9 @@ if __name__ == "__main__":
|
|
| 135 |
),
|
| 136 |
).fit()
|
| 137 |
|
| 138 |
-
print(
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
ray.shutdown()
|
|
|
|
| 27 |
Create CLI parser and return parsed arguments
|
| 28 |
|
| 29 |
python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
|
| 30 |
+
python connectfour/training/train.py --num-gpus 1 --stop-iters 1 --win-rate-threshold 0.50
|
| 31 |
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
|
| 32 |
"""
|
| 33 |
parser = argparse.ArgumentParser()
|
|
|
|
| 68 |
args = get_cli_args()
|
| 69 |
|
| 70 |
ray.init(
|
| 71 |
+
num_cpus=args.num_cpus or None,
|
| 72 |
+
num_gpus=args.num_gpus,
|
| 73 |
+
include_dashboard=False,
|
| 74 |
+
resources={"accelerator_type:RTX": 1},
|
| 75 |
)
|
| 76 |
|
| 77 |
# define how to make the environment
|
|
|
|
| 129 |
"win_rate": "win_rate",
|
| 130 |
"league_size": "league_size",
|
| 131 |
},
|
| 132 |
+
mode="max",
|
| 133 |
+
metric="win_rate",
|
| 134 |
sort_by_metric=True,
|
| 135 |
),
|
| 136 |
checkpoint_config=air.CheckpointConfig(
|
|
|
|
| 140 |
),
|
| 141 |
).fit()
|
| 142 |
|
| 143 |
+
print(
|
| 144 |
+
"Best checkpoint",
|
| 145 |
+
results.get_best_result(metric="win_rate", mode="max").checkpoint,
|
| 146 |
+
)
|
| 147 |
|
| 148 |
ray.shutdown()
|
error-screen.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6b72c5a148f41583927cd127d1d2b51073adec2ebd33ace7d4c074142d16d992
|
| 3 |
+
size 4316726
|