Spaces:
Runtime error
Runtime error
update
Browse files- connectfour/app.py +40 -19
- connectfour/checkpoint/policies/always_same/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/always_same/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/beat_last/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/learned/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/learned/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/learned_v1/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/learned_v2/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/learned_v3/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/learned_v4/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/learned_v5/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/linear/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/linear/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/policies/random/policy_state.pkl +0 -3
- connectfour/checkpoint/policies/random/rllib_checkpoint.json +0 -1
- connectfour/checkpoint/rllib_checkpoint.json +0 -1
- connectfour/training/__pycache__/callbacks.cpython-38.pyc +0 -0
- connectfour/training/__pycache__/dummy_policies.cpython-38.pyc +0 -0
- connectfour/training/__pycache__/wrappers.cpython-38.pyc +0 -0
- connectfour/training/callbacks.py +93 -51
- connectfour/training/dummy_policies.py +4 -0
- connectfour/training/train.py +117 -45
- connectfour/training/wrappers.py +7 -96
- models/__init__.py +3 -0
- connectfour/checkpoint/algorithm_state.pkl → models/model.onnx +2 -2
- poetry.lock +416 -33
- pyproject.toml +4 -1
connectfour/app.py
CHANGED
|
@@ -2,25 +2,25 @@ import time
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
| 5 |
-
import ray
|
| 6 |
-
import ray.rllib.algorithms.ppo as ppo
|
| 7 |
from pettingzoo.classic import connect_four_v3
|
|
|
|
| 8 |
from ray.tune import register_env
|
| 9 |
|
| 10 |
-
from connectfour.checkpoint import CHECKPOINT
|
| 11 |
-
from connectfour.training.models import Connect4MaskModel
|
| 12 |
from connectfour.training.wrappers import Connect4Env
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
# poetry export -f requirements.txt --output requirements.txt --without-hashes
|
|
|
|
| 17 |
# gradio connectfour/app.py
|
| 18 |
|
| 19 |
|
| 20 |
class Connect4:
|
| 21 |
def __init__(self, who_plays_first) -> None:
|
| 22 |
-
ray.init(include_dashboard=False, ignore_reinit_error=True)
|
| 23 |
-
|
| 24 |
# define how to make the environment
|
| 25 |
env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
|
| 26 |
|
|
@@ -44,25 +44,46 @@ class Connect4:
|
|
| 44 |
|
| 45 |
return self.render_and_state
|
| 46 |
|
| 47 |
-
def get_algo(self
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
)
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
self.algo.restore(checkpoint)
|
| 57 |
|
| 58 |
def play(self, action=None):
|
| 59 |
if self.has_erroneous_state():
|
| 60 |
return self.blue_screen()
|
| 61 |
|
| 62 |
if self.human != self.player_id:
|
| 63 |
-
action = self.algo.compute_single_action(
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
)
|
|
|
|
| 66 |
|
| 67 |
if action not in self.legal_moves:
|
| 68 |
action = np.random.choice(self.legal_moves)
|
|
@@ -114,7 +135,7 @@ demo = gr.Blocks()
|
|
| 114 |
|
| 115 |
with demo:
|
| 116 |
connect4 = Connect4("You")
|
| 117 |
-
connect4.get_algo(
|
| 118 |
|
| 119 |
with gr.Row():
|
| 120 |
with gr.Column(scale=1):
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
import numpy as np
|
|
|
|
|
|
|
| 5 |
from pettingzoo.classic import connect_four_v3
|
| 6 |
+
from ray.rllib.utils.framework import try_import_torch
|
| 7 |
from ray.tune import register_env
|
| 8 |
|
|
|
|
|
|
|
| 9 |
from connectfour.training.wrappers import Connect4Env
|
| 10 |
+
from models import MODEL_PATH
|
| 11 |
+
import onnxruntime as ort
|
| 12 |
+
from ray.rllib.algorithms.algorithm import Algorithm
|
| 13 |
+
from connectfour.checkpoint import CHECKPOINT
|
| 14 |
|
| 15 |
+
torch, nn = try_import_torch()
|
| 16 |
|
| 17 |
# poetry export -f requirements.txt --output requirements.txt --without-hashes
|
| 18 |
+
# tensorboard --logdir ~/ray_results/
|
| 19 |
# gradio connectfour/app.py
|
| 20 |
|
| 21 |
|
| 22 |
class Connect4:
|
| 23 |
def __init__(self, who_plays_first) -> None:
|
|
|
|
|
|
|
| 24 |
# define how to make the environment
|
| 25 |
env_creator = lambda config: connect_four_v3.env(render_mode="rgb_array")
|
| 26 |
|
|
|
|
| 44 |
|
| 45 |
return self.render_and_state
|
| 46 |
|
| 47 |
+
def get_algo(self):
|
| 48 |
+
# self.pytorch_model = torch.load(MODEL_PATH / "model.pt")
|
| 49 |
+
# self.algo = Algorithm.from_checkpoint(checkpoint=CHECKPOINT)
|
| 50 |
+
self.session = ort.InferenceSession(str(MODEL_PATH / "model.onnx"), None)
|
| 51 |
+
|
| 52 |
+
def compute_action(self, obs):
|
| 53 |
+
return self.pytorch_model(
|
| 54 |
+
input_dict={"obs": self.flatten_obs(obs)},
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def flatten_obs(self, obs):
|
| 58 |
+
flatten_action_mask = torch.from_numpy(obs["action_mask"])
|
| 59 |
+
flatten_observation = torch.flatten(
|
| 60 |
+
torch.from_numpy(obs["observation"]), end_dim=2
|
| 61 |
)
|
| 62 |
+
flatten_obs = torch.concat([flatten_action_mask, flatten_observation])
|
| 63 |
+
return flatten_obs[None, :]
|
|
|
|
| 64 |
|
| 65 |
def play(self, action=None):
|
| 66 |
if self.has_erroneous_state():
|
| 67 |
return self.blue_screen()
|
| 68 |
|
| 69 |
if self.human != self.player_id:
|
| 70 |
+
# action = self.algo.compute_single_action(
|
| 71 |
+
# self.obs[self.player_id], policy_id="learned_v9"
|
| 72 |
+
# )
|
| 73 |
+
# Torch
|
| 74 |
+
# action = self.compute_action(self.obs[self.player_id])
|
| 75 |
+
# action = int(torch.argmax(action[0]))
|
| 76 |
+
# ONNX
|
| 77 |
+
action = self.session.run(
|
| 78 |
+
["output"],
|
| 79 |
+
{
|
| 80 |
+
"obs": self.flatten_obs(self.obs[self.player_id])
|
| 81 |
+
.numpy()
|
| 82 |
+
.astype(np.float32),
|
| 83 |
+
"state_ins": [],
|
| 84 |
+
},
|
| 85 |
)
|
| 86 |
+
action = int(np.argmax(action[0]))
|
| 87 |
|
| 88 |
if action not in self.legal_moves:
|
| 89 |
action = np.random.choice(self.legal_moves)
|
|
|
|
| 135 |
|
| 136 |
with demo:
|
| 137 |
connect4 = Connect4("You")
|
| 138 |
+
connect4.get_algo()
|
| 139 |
|
| 140 |
with gr.Row():
|
| 141 |
with gr.Column(scale=1):
|
connectfour/checkpoint/policies/always_same/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:d278413093ad1bc4f227279e3dab7be04ebd70ca1ed156a1363515c69d0a858e
|
| 3 |
-
size 10992
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/always_same/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/beat_last/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:cd422258c16de0866599730a5a5b2b48e2ee81cbae69f9d5471deeae76c42b47
|
| 3 |
-
size 10992
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/beat_last/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/learned/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:2a517583e5fcad7e483bca619723583cc6928499390c1fcfc25d907e109cd4b4
|
| 3 |
-
size 2139442
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/learned/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v1/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:276c26007c2419a688c27f9dfa70c20fecb468a0aa07d28d6a9e8099bbc849be
|
| 3 |
-
size 2139439
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v1/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v2/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:e37a485d3a54f7a8b194693e7a61f790e67071358130178fa01cdbd840c4a4da
|
| 3 |
-
size 2139439
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v2/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v3/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:9f90899ae98a387e312333b234041c68b9c50da4af92ee5250686087a39eebb3
|
| 3 |
-
size 2139439
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v3/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v4/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:3af3b3fe41bac489cb693af387b1ccc4437a532a78d539b3abb4cc5f77929592
|
| 3 |
-
size 2139439
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v4/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v5/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f2b28b979e2f4411d196e03ca75ea7f25f7601bb997aa8bcdcf1d49c9ea30754
|
| 3 |
-
size 2139439
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/learned_v5/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/linear/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:4f70d44ac661632dc0557204abe34308dfb25b800a668b49c2efd9a2a73a7bc0
|
| 3 |
-
size 10992
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/linear/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/policies/random/policy_state.pkl
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f3b1ab86bada035779feedb2b92ae0a64f6d9474bb4f0ae44324e17d65659764
|
| 3 |
-
size 10992
|
|
|
|
|
|
|
|
|
|
|
|
connectfour/checkpoint/policies/random/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Policy", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/checkpoint/rllib_checkpoint.json
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
{"type": "Algorithm", "checkpoint_version": "1.0", "ray_version": "2.3.1", "ray_commit": "5f14cee8dfc6d61ec4fd3bc2c440f9944e92b33a"}
|
|
|
|
|
|
connectfour/training/__pycache__/callbacks.cpython-38.pyc
CHANGED
|
Binary files a/connectfour/training/__pycache__/callbacks.cpython-38.pyc and b/connectfour/training/__pycache__/callbacks.cpython-38.pyc differ
|
|
|
connectfour/training/__pycache__/dummy_policies.cpython-38.pyc
CHANGED
|
Binary files a/connectfour/training/__pycache__/dummy_policies.cpython-38.pyc and b/connectfour/training/__pycache__/dummy_policies.cpython-38.pyc differ
|
|
|
connectfour/training/__pycache__/wrappers.cpython-38.pyc
CHANGED
|
Binary files a/connectfour/training/__pycache__/wrappers.cpython-38.pyc and b/connectfour/training/__pycache__/wrappers.cpython-38.pyc differ
|
|
|
connectfour/training/callbacks.py
CHANGED
|
@@ -1,33 +1,44 @@
|
|
| 1 |
-
from
|
|
|
|
| 2 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
-
def create_self_play_callback(win_rate_thr, opponent_policies):
|
| 6 |
class SelfPlayCallback(DefaultCallbacks):
|
| 7 |
win_rate_threshold = win_rate_thr
|
| 8 |
|
| 9 |
def __init__(self):
|
| 10 |
super().__init__()
|
| 11 |
self.current_opponent = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def on_train_result(self, *, algorithm, result, **kwargs):
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
main_rew = result["hist_stats"].pop("policy_learned_reward")
|
| 19 |
opponent_rew = result["hist_stats"].pop("episode_reward")
|
| 20 |
|
| 21 |
-
if len(main_rew) != len(opponent_rew):
|
| 22 |
-
raise Exception(
|
| 23 |
-
"len(main_rew) != len(opponent_rew)",
|
| 24 |
-
len(main_rew),
|
| 25 |
-
len(opponent_rew),
|
| 26 |
-
result["hist_stats"].keys(),
|
| 27 |
-
"episode len",
|
| 28 |
-
len(opponent_rew),
|
| 29 |
-
)
|
| 30 |
-
|
| 31 |
won = 0
|
| 32 |
for r_main, r_opponent in zip(main_rew, opponent_rew):
|
| 33 |
if r_main > r_opponent:
|
|
@@ -35,54 +46,85 @@ def create_self_play_callback(win_rate_thr, opponent_policies):
|
|
| 35 |
win_rate = won / len(main_rew)
|
| 36 |
|
| 37 |
result["win_rate"] = win_rate
|
| 38 |
-
print(f"Iter={algorithm.iteration} win-rate={win_rate}
|
| 39 |
|
| 40 |
-
# If win rate is good -> Snapshot current policy and play against
|
| 41 |
-
# it next, keeping the snapshot fixed and only improving the "learned"
|
| 42 |
-
# policy.
|
| 43 |
if win_rate > self.win_rate_threshold:
|
| 44 |
-
self.
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
# Re-define the mapping function, such that "learned" is forced
|
| 51 |
-
# to play against any of the previously played policies
|
| 52 |
-
# (excluding "random").
|
| 53 |
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
|
| 54 |
-
# agent_id = [0|1] -> policy depends on episode ID
|
| 55 |
-
# This way, we make sure that both policies sometimes play
|
| 56 |
-
# (start player) and sometimes agent1 (player to move 2nd).
|
| 57 |
return (
|
| 58 |
"learned"
|
| 59 |
if episode.episode_id % 2 == int(agent_id[-1:])
|
| 60 |
-
else np.random.choice(
|
| 61 |
-
opponent_policies
|
| 62 |
-
+ [
|
| 63 |
-
f"learned_v{i}"
|
| 64 |
-
for i in range(1, self.current_opponent + 1)
|
| 65 |
-
]
|
| 66 |
-
)
|
| 67 |
)
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
policy_cls=type(algorithm.get_policy("learned")),
|
| 72 |
-
policy_mapping_fn=policy_mapping_fn,
|
| 73 |
)
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
algorithm.workers.sync_weights()
|
|
|
|
| 83 |
else:
|
| 84 |
-
print("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
return SelfPlayCallback
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
|
| 3 |
import numpy as np
|
| 4 |
+
from ray.rllib.algorithms.callbacks import DefaultCallbacks
|
| 5 |
+
|
| 6 |
+
from connectfour.training.dummy_policies import (
|
| 7 |
+
AlwaysSameHeuristic,
|
| 8 |
+
BeatLastHeuristic,
|
| 9 |
+
LinearHeuristic,
|
| 10 |
+
RandomHeuristic,
|
| 11 |
+
)
|
| 12 |
|
| 13 |
|
| 14 |
+
def create_self_play_callback(win_rate_thr, opponent_policies, opponent_count=10):
|
| 15 |
class SelfPlayCallback(DefaultCallbacks):
|
| 16 |
win_rate_threshold = win_rate_thr
|
| 17 |
|
| 18 |
def __init__(self):
|
| 19 |
super().__init__()
|
| 20 |
self.current_opponent = 0
|
| 21 |
+
self.opponent_policies = deque(opponent_policies, maxlen=opponent_count)
|
| 22 |
+
self.policy_to_remove = None
|
| 23 |
+
self.frozen_policies = {
|
| 24 |
+
"always_same": AlwaysSameHeuristic,
|
| 25 |
+
"linear": LinearHeuristic,
|
| 26 |
+
"beat_last": BeatLastHeuristic,
|
| 27 |
+
"random": RandomHeuristic,
|
| 28 |
+
}
|
| 29 |
|
| 30 |
def on_train_result(self, *, algorithm, result, **kwargs):
|
| 31 |
+
"""Called at the end of Algorithm.train().
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
algorithm: Current Algorithm instance.
|
| 35 |
+
result: Dict of results returned from Algorithm.train() call.
|
| 36 |
+
You can mutate this object to add additional metrics.
|
| 37 |
+
kwargs: Forward compatibility placeholder.
|
| 38 |
+
"""
|
| 39 |
main_rew = result["hist_stats"].pop("policy_learned_reward")
|
| 40 |
opponent_rew = result["hist_stats"].pop("episode_reward")
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
won = 0
|
| 43 |
for r_main, r_opponent in zip(main_rew, opponent_rew):
|
| 44 |
if r_main > r_opponent:
|
|
|
|
| 46 |
win_rate = won / len(main_rew)
|
| 47 |
|
| 48 |
result["win_rate"] = win_rate
|
| 49 |
+
print(f"Iter={algorithm.iteration} win-rate={win_rate}")
|
| 50 |
|
|
|
|
|
|
|
|
|
|
| 51 |
if win_rate > self.win_rate_threshold:
|
| 52 |
+
if len(self.opponent_policies) == self.opponent_policies.maxlen:
|
| 53 |
+
self.policy_to_remove = self.opponent_policies[0]
|
| 54 |
+
|
| 55 |
+
new_pol_id = None
|
| 56 |
+
while new_pol_id is None:
|
| 57 |
+
if np.random.choice(range(6)) == 0:
|
| 58 |
+
new_pol_id = np.random.choice(list(self.frozen_policies.keys()))
|
| 59 |
+
else:
|
| 60 |
+
self.current_opponent += 1
|
| 61 |
+
new_pol_id = f"learned_v{self.current_opponent}"
|
| 62 |
+
|
| 63 |
+
if new_pol_id in self.opponent_policies:
|
| 64 |
+
new_pol_id = None
|
| 65 |
+
else:
|
| 66 |
+
self.opponent_policies.append(new_pol_id)
|
| 67 |
+
|
| 68 |
+
print("Non trainable policies", list(self.opponent_policies))
|
| 69 |
|
|
|
|
|
|
|
|
|
|
| 70 |
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
|
|
|
|
|
|
|
|
|
|
| 71 |
return (
|
| 72 |
"learned"
|
| 73 |
if episode.episode_id % 2 == int(agent_id[-1:])
|
| 74 |
+
else np.random.choice(list(self.opponent_policies))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
)
|
| 76 |
|
| 77 |
+
print(
|
| 78 |
+
f"Iter={algorithm.iteration} Adding new opponent to the mix ({new_pol_id}). League size {len(self.opponent_policies) + 1}"
|
|
|
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
+
if new_pol_id in list(self.frozen_policies.keys()):
|
| 82 |
+
new_policy = algorithm.add_policy(
|
| 83 |
+
policy_id=new_pol_id,
|
| 84 |
+
policy_cls=self.frozen_policies[new_pol_id],
|
| 85 |
+
policy_mapping_fn=policy_mapping_fn,
|
| 86 |
+
)
|
| 87 |
+
else:
|
| 88 |
+
new_policy = algorithm.add_policy(
|
| 89 |
+
policy_id=new_pol_id,
|
| 90 |
+
policy_cls=type(algorithm.get_policy("learned")),
|
| 91 |
+
policy_mapping_fn=policy_mapping_fn,
|
| 92 |
+
)
|
| 93 |
+
learned_state = algorithm.get_policy("learned").get_state()
|
| 94 |
+
new_policy.set_state(learned_state)
|
| 95 |
algorithm.workers.sync_weights()
|
| 96 |
+
|
| 97 |
else:
|
| 98 |
+
print("Not good enough... Keep learning ...")
|
| 99 |
+
|
| 100 |
+
result["league_size"] = len(self.opponent_policies) + 1
|
| 101 |
+
|
| 102 |
+
def on_evaluate_end(self, *, algorithm, evaluation_metrics, **kwargs):
|
| 103 |
+
"""Runs when the evaluation is done.
|
| 104 |
+
|
| 105 |
+
Runs at the end of Algorithm.evaluate().
|
| 106 |
|
| 107 |
+
Args:
|
| 108 |
+
algorithm: Reference to the algorithm instance.
|
| 109 |
+
evaluation_metrics: Results dict to be returned from algorithm.evaluate().
|
| 110 |
+
You can mutate this object to add additional metrics.
|
| 111 |
+
kwargs: Forward compatibility placeholder.
|
| 112 |
+
"""
|
| 113 |
+
|
| 114 |
+
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
|
| 115 |
+
return (
|
| 116 |
+
"learned"
|
| 117 |
+
if episode.episode_id % 2 == int(agent_id[-1:])
|
| 118 |
+
else np.random.choice(list(self.opponent_policies))
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
if self.policy_to_remove is not None:
|
| 122 |
+
print("Remove ", self.policy_to_remove, "from opponent policies")
|
| 123 |
+
algorithm.remove_policy(
|
| 124 |
+
self.policy_to_remove,
|
| 125 |
+
policy_mapping_fn=policy_mapping_fn,
|
| 126 |
+
)
|
| 127 |
+
self.policy_to_remove = None
|
| 128 |
+
algorithm.workers.sync_weights()
|
| 129 |
|
| 130 |
return SelfPlayCallback
|
connectfour/training/dummy_policies.py
CHANGED
|
@@ -23,6 +23,10 @@ class HeuristicBase(Policy):
|
|
| 23 |
"""No weights to set."""
|
| 24 |
pass
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
@override(Policy)
|
| 27 |
def compute_actions(
|
| 28 |
self,
|
|
|
|
| 23 |
"""No weights to set."""
|
| 24 |
pass
|
| 25 |
|
| 26 |
+
@override(Policy)
|
| 27 |
+
def export_model(self, export_dir: str, onnx=None) -> None:
|
| 28 |
+
pass
|
| 29 |
+
|
| 30 |
@override(Policy)
|
| 31 |
def compute_actions(
|
| 32 |
self,
|
connectfour/training/train.py
CHANGED
|
@@ -8,6 +8,7 @@ from ray import air, tune
|
|
| 8 |
from ray.rllib.policy.policy import PolicySpec
|
| 9 |
from ray.rllib.utils.framework import try_import_torch
|
| 10 |
from ray.tune import CLIReporter, register_env
|
|
|
|
| 11 |
|
| 12 |
from connectfour.training.callbacks import create_self_play_callback
|
| 13 |
from connectfour.training.dummy_policies import (
|
|
@@ -29,12 +30,21 @@ def get_cli_args():
|
|
| 29 |
python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
|
| 30 |
python connectfour/training/train.py --num-gpus 1 --stop-iters 1 --win-rate-threshold 0.50
|
| 31 |
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
|
|
|
|
|
|
|
|
|
|
| 32 |
"""
|
| 33 |
parser = argparse.ArgumentParser()
|
| 34 |
parser.add_argument("--num-cpus", type=int, default=0)
|
| 35 |
parser.add_argument("--num-gpus", type=int, default=0)
|
| 36 |
parser.add_argument("--num-workers", type=int, default=2)
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
parser.add_argument(
|
| 39 |
"--stop-iters", type=int, default=200, help="Number of iterations to train."
|
| 40 |
)
|
|
@@ -57,13 +67,6 @@ def get_cli_args():
|
|
| 57 |
return args
|
| 58 |
|
| 59 |
|
| 60 |
-
def select_policy(agent_id, episode, **kwargs):
|
| 61 |
-
if episode.episode_id % 2 == int(agent_id[-1:]):
|
| 62 |
-
return "learned"
|
| 63 |
-
else:
|
| 64 |
-
return random.choice(["always_same", "beat_last", "random", "linear"])
|
| 65 |
-
|
| 66 |
-
|
| 67 |
if __name__ == "__main__":
|
| 68 |
args = get_cli_args()
|
| 69 |
|
|
@@ -80,20 +83,23 @@ if __name__ == "__main__":
|
|
| 80 |
# register that way to make the environment under an rllib name
|
| 81 |
register_env("connect4", lambda config: Connect4Env(env_creator(config)))
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
config = (
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
)
|
| 93 |
-
|
| 94 |
-
.rollouts(
|
| 95 |
-
num_rollout_workers=args.num_workers,
|
| 96 |
-
num_envs_per_worker=5,
|
| 97 |
)
|
| 98 |
.multi_agent(
|
| 99 |
policies={
|
|
@@ -106,19 +112,88 @@ if __name__ == "__main__":
|
|
| 106 |
policy_mapping_fn=select_policy,
|
| 107 |
policies_to_train=["learned"],
|
| 108 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
)
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
progress_reporter=CLIReporter(
|
| 123 |
metric_columns={
|
| 124 |
"training_iteration": "iter",
|
|
@@ -128,21 +203,18 @@ if __name__ == "__main__":
|
|
| 128 |
"policy_reward_mean/learned": "reward",
|
| 129 |
"win_rate": "win_rate",
|
| 130 |
"league_size": "league_size",
|
| 131 |
-
}
|
| 132 |
-
mode="max",
|
| 133 |
-
metric="win_rate",
|
| 134 |
-
sort_by_metric=True,
|
| 135 |
-
),
|
| 136 |
-
checkpoint_config=air.CheckpointConfig(
|
| 137 |
-
checkpoint_at_end=True,
|
| 138 |
-
checkpoint_frequency=10,
|
| 139 |
),
|
| 140 |
-
)
|
| 141 |
-
).fit()
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
ray.shutdown()
|
|
|
|
| 8 |
from ray.rllib.policy.policy import PolicySpec
|
| 9 |
from ray.rllib.utils.framework import try_import_torch
|
| 10 |
from ray.tune import CLIReporter, register_env
|
| 11 |
+
from ray.rllib.algorithms.algorithm import Algorithm
|
| 12 |
|
| 13 |
from connectfour.training.callbacks import create_self_play_callback
|
| 14 |
from connectfour.training.dummy_policies import (
|
|
|
|
| 30 |
python connectfour/training/train.py --num-cpus 4 --num-gpus 1 --stop-iters 10 --win-rate-threshold 0.50
|
| 31 |
python connectfour/training/train.py --num-gpus 1 --stop-iters 1 --win-rate-threshold 0.50
|
| 32 |
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --stop-iters 200
|
| 33 |
+
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --win-rate-threshold 0.95 --stop-iters 2000 > training.log 2>&1
|
| 34 |
+
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --win-rate-threshold 0.96 --stop-iters 10000 > training.log 2>&1
|
| 35 |
+
python connectfour/training/train.py --num-cpus 5 --num-gpus 1 --win-rate-threshold 0.99 --stop-iters 5000 --from-checkpoint ~/ray_results/PPO/PPO_connect4_8414a_00000_0_2023-04-03_12-44-31/checkpoint_004000 > training.log 2>&1
|
| 36 |
"""
|
| 37 |
parser = argparse.ArgumentParser()
|
| 38 |
parser.add_argument("--num-cpus", type=int, default=0)
|
| 39 |
parser.add_argument("--num-gpus", type=int, default=0)
|
| 40 |
parser.add_argument("--num-workers", type=int, default=2)
|
| 41 |
+
parser.add_argument(
|
| 42 |
+
"--from-checkpoint",
|
| 43 |
+
type=str,
|
| 44 |
+
default=None,
|
| 45 |
+
help="Full path to a experiment directory to resume tuning from "
|
| 46 |
+
"a previously saved Algorithm state.",
|
| 47 |
+
)
|
| 48 |
parser.add_argument(
|
| 49 |
"--stop-iters", type=int, default=200, help="Number of iterations to train."
|
| 50 |
)
|
|
|
|
| 67 |
return args
|
| 68 |
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
if __name__ == "__main__":
|
| 71 |
args = get_cli_args()
|
| 72 |
|
|
|
|
| 83 |
# register that way to make the environment under an rllib name
|
| 84 |
register_env("connect4", lambda config: Connect4Env(env_creator(config)))
|
| 85 |
|
| 86 |
+
def select_policy(agent_id, episode, **kwargs):
|
| 87 |
+
if episode.episode_id % 2 == int(agent_id[-1:]):
|
| 88 |
+
return "learned"
|
| 89 |
+
else:
|
| 90 |
+
return random.choice(["always_same", "beat_last", "random", "linear"])
|
| 91 |
+
|
| 92 |
config = (
|
| 93 |
+
(
|
| 94 |
+
ppo.PPOConfig()
|
| 95 |
+
.environment("connect4")
|
| 96 |
+
.framework("torch")
|
| 97 |
+
.training(model={"custom_model": Connect4MaskModel})
|
| 98 |
+
.rollouts(
|
| 99 |
+
num_rollout_workers=args.num_workers,
|
| 100 |
+
num_envs_per_worker=5,
|
| 101 |
)
|
| 102 |
+
# .checkpointing(checkpoint_trainable_policies_only=True)
|
|
|
|
|
|
|
|
|
|
| 103 |
)
|
| 104 |
.multi_agent(
|
| 105 |
policies={
|
|
|
|
| 112 |
policy_mapping_fn=select_policy,
|
| 113 |
policies_to_train=["learned"],
|
| 114 |
)
|
| 115 |
+
.callbacks(
|
| 116 |
+
create_self_play_callback(
|
| 117 |
+
win_rate_thr=args.win_rate_threshold,
|
| 118 |
+
opponent_policies=["always_same", "beat_last", "random", "linear"],
|
| 119 |
+
opponent_count=15,
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
.evaluation(evaluation_interval=1)
|
| 123 |
)
|
| 124 |
|
| 125 |
+
if args.from_checkpoint is None:
|
| 126 |
+
stop = {
|
| 127 |
+
"timesteps_total": args.stop_timesteps,
|
| 128 |
+
"training_iteration": args.stop_iters,
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
results = tune.Tuner(
|
| 132 |
+
"PPO",
|
| 133 |
+
param_space=config.to_dict(),
|
| 134 |
+
run_config=air.RunConfig(
|
| 135 |
+
stop=stop,
|
| 136 |
+
verbose=2,
|
| 137 |
+
progress_reporter=CLIReporter(
|
| 138 |
+
metric_columns={
|
| 139 |
+
"training_iteration": "iter",
|
| 140 |
+
"time_total_s": "time_total_s",
|
| 141 |
+
"timesteps_total": "ts",
|
| 142 |
+
"episodes_this_iter": "train_episodes",
|
| 143 |
+
"policy_reward_mean/learned": "reward",
|
| 144 |
+
"win_rate": "win_rate",
|
| 145 |
+
"league_size": "league_size",
|
| 146 |
+
},
|
| 147 |
+
mode="max",
|
| 148 |
+
metric="win_rate",
|
| 149 |
+
sort_by_metric=True,
|
| 150 |
+
),
|
| 151 |
+
checkpoint_config=air.CheckpointConfig(
|
| 152 |
+
num_to_keep=10,
|
| 153 |
+
checkpoint_at_end=True,
|
| 154 |
+
checkpoint_frequency=10,
|
| 155 |
+
checkpoint_score_order="max",
|
| 156 |
+
),
|
| 157 |
+
),
|
| 158 |
+
).fit()
|
| 159 |
+
|
| 160 |
+
best_checkpoint = results.get_best_result(
|
| 161 |
+
metric="win_rate", mode="max"
|
| 162 |
+
).checkpoint
|
| 163 |
+
print("Best checkpoint", best_checkpoint)
|
| 164 |
+
|
| 165 |
+
else:
|
| 166 |
+
algo = Algorithm.from_checkpoint(checkpoint=args.from_checkpoint)
|
| 167 |
+
|
| 168 |
+
config = algo.config.copy(False)
|
| 169 |
+
config.checkpointing(export_native_model_files=True)
|
| 170 |
+
|
| 171 |
+
opponent_policies = list(algo.workers.local_worker().policy_map.keys())
|
| 172 |
+
opponent_policies.remove("learned")
|
| 173 |
+
opponent_policies.sort()
|
| 174 |
+
|
| 175 |
+
config.callbacks(
|
| 176 |
+
create_self_play_callback(
|
| 177 |
+
win_rate_thr=args.win_rate_threshold,
|
| 178 |
+
opponent_policies=opponent_policies,
|
| 179 |
+
opponent_count=len(opponent_policies),
|
| 180 |
+
)
|
| 181 |
+
)
|
| 182 |
+
config.evaluation(evaluation_interval=None)
|
| 183 |
+
|
| 184 |
+
analysis = tune.run(
|
| 185 |
+
"PPO",
|
| 186 |
+
config=config.to_dict(),
|
| 187 |
+
restore=args.from_checkpoint,
|
| 188 |
+
checkpoint_freq=10,
|
| 189 |
+
checkpoint_at_end=True,
|
| 190 |
+
keep_checkpoints_num=10,
|
| 191 |
+
mode="max",
|
| 192 |
+
metric="win_rate",
|
| 193 |
+
stop={
|
| 194 |
+
"win_rate": args.win_rate_threshold,
|
| 195 |
+
"training_iteration": args.stop_iters,
|
| 196 |
+
},
|
| 197 |
progress_reporter=CLIReporter(
|
| 198 |
metric_columns={
|
| 199 |
"training_iteration": "iter",
|
|
|
|
| 203 |
"policy_reward_mean/learned": "reward",
|
| 204 |
"win_rate": "win_rate",
|
| 205 |
"league_size": "league_size",
|
| 206 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
),
|
| 208 |
+
)
|
|
|
|
| 209 |
|
| 210 |
+
algo = Algorithm.from_checkpoint(analysis.best_checkpoint)
|
| 211 |
+
ppo_policy = algo.get_policy("learned")
|
| 212 |
+
|
| 213 |
+
# Save as torch model
|
| 214 |
+
ppo_policy.export_model("models")
|
| 215 |
+
# Save as ONNX model
|
| 216 |
+
ppo_policy.export_model("models", onnx=11)
|
| 217 |
+
|
| 218 |
+
print("Best checkpoint", analysis.best_checkpoint)
|
| 219 |
|
| 220 |
ray.shutdown()
|
connectfour/training/wrappers.py
CHANGED
|
@@ -1,112 +1,23 @@
|
|
| 1 |
from typing import Optional
|
| 2 |
|
| 3 |
-
from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
| 4 |
from ray.rllib.utils.annotations import PublicAPI
|
| 5 |
-
from ray.rllib.
|
| 6 |
|
| 7 |
|
| 8 |
@PublicAPI
|
| 9 |
-
class Connect4Env(
|
| 10 |
-
"""An interface to the PettingZoo MARL environment library
|
| 11 |
-
|
| 12 |
-
See: https://github.com/Farama-Foundation/PettingZoo
|
| 13 |
-
|
| 14 |
-
Inherits from MultiAgentEnv and exposes a given AEC
|
| 15 |
-
(actor-environment-cycle) game from the PettingZoo project via the
|
| 16 |
-
MultiAgentEnv public API.
|
| 17 |
-
|
| 18 |
-
Note that the wrapper has some important limitations:
|
| 19 |
-
|
| 20 |
-
1. All agents have the same action_spaces and observation_spaces.
|
| 21 |
-
Note: If, within your aec game, agents do not have homogeneous action /
|
| 22 |
-
observation spaces, apply SuperSuit wrappers
|
| 23 |
-
to apply padding functionality: https://github.com/Farama-Foundation/
|
| 24 |
-
SuperSuit#built-in-multi-agent-only-functions
|
| 25 |
-
2. Environments are positive sum games (-> Agents are expected to cooperate
|
| 26 |
-
to maximize reward). This isn't a hard restriction, it just that
|
| 27 |
-
standard algorithms aren't expected to work well in highly competitive
|
| 28 |
-
games."""
|
| 29 |
-
|
| 30 |
-
def __init__(self, env):
|
| 31 |
-
super().__init__()
|
| 32 |
-
self.env = env
|
| 33 |
-
env.reset()
|
| 34 |
-
|
| 35 |
-
# Since all agents have the same spaces, do not provide full observation-
|
| 36 |
-
# and action-spaces as Dicts, mapping agent IDs to the individual
|
| 37 |
-
# agents' spaces. Instead, `self.[action|observation]_space` are the single
|
| 38 |
-
# agent spaces.
|
| 39 |
-
self._obs_space_in_preferred_format = False
|
| 40 |
-
self._action_space_in_preferred_format = False
|
| 41 |
-
|
| 42 |
-
# Collect the individual agents' spaces (they should all be the same):
|
| 43 |
-
first_obs_space = self.env.observation_space(self.env.agents[0])
|
| 44 |
-
first_action_space = self.env.action_space(self.env.agents[0])
|
| 45 |
-
|
| 46 |
-
for agent in self.env.agents:
|
| 47 |
-
if self.env.observation_space(agent) != first_obs_space:
|
| 48 |
-
raise ValueError(
|
| 49 |
-
"Observation spaces for all agents must be identical. Perhaps "
|
| 50 |
-
"SuperSuit's pad_observations wrapper can help (useage: "
|
| 51 |
-
"`supersuit.aec_wrappers.pad_observations(env)`"
|
| 52 |
-
)
|
| 53 |
-
if self.env.action_space(agent) != first_action_space:
|
| 54 |
-
raise ValueError(
|
| 55 |
-
"Action spaces for all agents must be identical. Perhaps "
|
| 56 |
-
"SuperSuit's pad_action_space wrapper can help (usage: "
|
| 57 |
-
"`supersuit.aec_wrappers.pad_action_space(env)`"
|
| 58 |
-
)
|
| 59 |
-
|
| 60 |
-
# Convert from gym to gymnasium, if necessary.
|
| 61 |
-
self.observation_space = convert_old_gym_space_to_gymnasium_space(
|
| 62 |
-
first_obs_space
|
| 63 |
-
)
|
| 64 |
-
self.action_space = convert_old_gym_space_to_gymnasium_space(first_action_space)
|
| 65 |
-
|
| 66 |
-
self._agent_ids = set(self.env.agents)
|
| 67 |
|
| 68 |
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
|
|
|
|
|
|
| 69 |
info = self.env.reset(seed=seed, options=options)
|
| 70 |
return (
|
| 71 |
{self.env.agent_selection: self.env.observe(self.env.agent_selection)},
|
| 72 |
info or {},
|
| 73 |
)
|
| 74 |
|
| 75 |
-
def step(self, action):
|
| 76 |
-
self.env.step(action[self.env.agent_selection])
|
| 77 |
-
obs_d = {}
|
| 78 |
-
rew_d = {}
|
| 79 |
-
terminated_d = {}
|
| 80 |
-
truncated_d = {}
|
| 81 |
-
info_d = {}
|
| 82 |
-
while self.env.agents:
|
| 83 |
-
obs, rew, terminated, truncated, info = self.env.last()
|
| 84 |
-
agent_id = self.env.agent_selection
|
| 85 |
-
obs_d[agent_id] = obs
|
| 86 |
-
rew_d[agent_id] = rew
|
| 87 |
-
terminated_d[agent_id] = terminated
|
| 88 |
-
truncated_d[agent_id] = truncated
|
| 89 |
-
info_d[agent_id] = info
|
| 90 |
-
if (
|
| 91 |
-
self.env.terminations[self.env.agent_selection]
|
| 92 |
-
or self.env.truncations[self.env.agent_selection]
|
| 93 |
-
):
|
| 94 |
-
self.env.step(None)
|
| 95 |
-
else:
|
| 96 |
-
break
|
| 97 |
-
|
| 98 |
-
all_gone = not self.env.agents
|
| 99 |
-
terminated_d["__all__"] = all_gone and all(terminated_d.values())
|
| 100 |
-
truncated_d["__all__"] = all_gone and all(truncated_d.values())
|
| 101 |
-
|
| 102 |
-
return obs_d, rew_d, terminated_d, truncated_d, info_d
|
| 103 |
-
|
| 104 |
-
def close(self):
|
| 105 |
-
self.env.close()
|
| 106 |
-
|
| 107 |
def render(self):
|
|
|
|
|
|
|
| 108 |
return self.env.render()
|
| 109 |
-
|
| 110 |
-
@property
|
| 111 |
-
def get_sub_environments(self):
|
| 112 |
-
return self.env.unwrapped
|
|
|
|
| 1 |
from typing import Optional
|
| 2 |
|
|
|
|
| 3 |
from ray.rllib.utils.annotations import PublicAPI
|
| 4 |
+
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
|
| 5 |
|
| 6 |
|
| 7 |
@PublicAPI
|
| 8 |
+
class Connect4Env(PettingZooEnv):
|
| 9 |
+
"""An interface to the PettingZoo MARL environment library"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
| 12 |
+
# In base class =>
|
| 13 |
+
# info = self.env.reset(seed=seed, return_info=True, options=options)
|
| 14 |
info = self.env.reset(seed=seed, options=options)
|
| 15 |
return (
|
| 16 |
{self.env.agent_selection: self.env.observe(self.env.agent_selection)},
|
| 17 |
info or {},
|
| 18 |
)
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def render(self):
|
| 21 |
+
# In base class =>
|
| 22 |
+
# return self.env.render(self.render_mode)
|
| 23 |
return self.env.render()
|
|
|
|
|
|
|
|
|
|
|
|
models/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
MODEL_PATH = Path(__file__).parent.absolute()
|
connectfour/checkpoint/algorithm_state.pkl → models/model.onnx
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae8bf3e7080eca6ba6c8e68dbceffa03ca8d05f65249c474544a66e039352d2a
|
| 3 |
+
size 361882
|
poetry.lock
CHANGED
|
@@ -290,6 +290,18 @@ d = ["aiohttp (>=3.7.4)"]
|
|
| 290 |
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
| 291 |
uvloop = ["uvloop (>=0.15.2)"]
|
| 292 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
[[package]]
|
| 294 |
name = "certifi"
|
| 295 |
version = "2022.12.7"
|
|
@@ -456,6 +468,24 @@ files = [
|
|
| 456 |
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
| 457 |
]
|
| 458 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
[[package]]
|
| 460 |
name = "contourpy"
|
| 461 |
version = "1.0.7"
|
|
@@ -692,6 +722,18 @@ files = [
|
|
| 692 |
docs = ["furo (>=2022.12.7)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
|
| 693 |
testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
|
| 694 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
[[package]]
|
| 696 |
name = "fonttools"
|
| 697 |
version = "4.39.3"
|
|
@@ -848,16 +890,60 @@ files = [
|
|
| 848 |
{file = "gast-0.5.3.tar.gz", hash = "sha256:cfbea25820e653af9c7d1807f659ce0a0a9c64f2439421a7bba4f0983f532dea"},
|
| 849 |
]
|
| 850 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 851 |
[[package]]
|
| 852 |
name = "gradio"
|
| 853 |
-
version = "3.
|
| 854 |
description = "Python library for easily interacting with trained machine learning models"
|
| 855 |
category = "main"
|
| 856 |
optional = false
|
| 857 |
python-versions = ">=3.7"
|
| 858 |
files = [
|
| 859 |
-
{file = "gradio-3.
|
| 860 |
-
{file = "gradio-3.
|
| 861 |
]
|
| 862 |
|
| 863 |
[package.dependencies]
|
|
@@ -866,7 +952,7 @@ aiohttp = "*"
|
|
| 866 |
altair = ">=4.2.0"
|
| 867 |
fastapi = "*"
|
| 868 |
ffmpy = "*"
|
| 869 |
-
|
| 870 |
httpx = "*"
|
| 871 |
huggingface-hub = ">=0.13.0"
|
| 872 |
jinja2 = "*"
|
|
@@ -888,6 +974,25 @@ typing-extensions = "*"
|
|
| 888 |
uvicorn = "*"
|
| 889 |
websockets = ">=10.0"
|
| 890 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 891 |
[[package]]
|
| 892 |
name = "grpcio"
|
| 893 |
version = "1.49.1"
|
|
@@ -1138,6 +1243,21 @@ testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "jedi", "pytest", "pytest
|
|
| 1138 |
torch = ["torch"]
|
| 1139 |
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
| 1140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1141 |
[[package]]
|
| 1142 |
name = "idna"
|
| 1143 |
version = "3.4"
|
|
@@ -1531,6 +1651,24 @@ docs = ["sphinx (>=1.6.0)", "sphinx-bootstrap-theme"]
|
|
| 1531 |
flake8 = ["flake8"]
|
| 1532 |
tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"]
|
| 1533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1534 |
[[package]]
|
| 1535 |
name = "markdown-it-py"
|
| 1536 |
version = "2.2.0"
|
|
@@ -2184,6 +2322,111 @@ files = [
|
|
| 2184 |
setuptools = "*"
|
| 2185 |
wheel = "*"
|
| 2186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2187 |
[[package]]
|
| 2188 |
name = "orjson"
|
| 2189 |
version = "3.8.8"
|
|
@@ -2471,39 +2714,65 @@ testing = ["pytest", "pytest-benchmark"]
|
|
| 2471 |
|
| 2472 |
[[package]]
|
| 2473 |
name = "protobuf"
|
| 2474 |
-
version = "3.
|
| 2475 |
description = "Protocol Buffers"
|
| 2476 |
category = "main"
|
| 2477 |
optional = false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2478 |
python-versions = "*"
|
| 2479 |
files = [
|
| 2480 |
-
{file = "
|
| 2481 |
-
{file = "
|
| 2482 |
-
|
| 2483 |
-
|
| 2484 |
-
|
| 2485 |
-
|
| 2486 |
-
{file = "protobuf-3.17.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:850f429bd2399525d339d05bc809f090f16d3d88737bed637d355a5ee8d3b81a"},
|
| 2487 |
-
{file = "protobuf-3.17.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:809a96d5a1a74538728710f9104f43ae77f5e48bde274ee321b10a324ba52e4f"},
|
| 2488 |
-
{file = "protobuf-3.17.0-cp36-cp36m-win32.whl", hash = "sha256:8a3ac375539055164f31a330770f137875307e6f04c21e2647f2e7139c501295"},
|
| 2489 |
-
{file = "protobuf-3.17.0-cp36-cp36m-win_amd64.whl", hash = "sha256:3d338910b10b88b18581cf6877b3938b2e262e8fdc2c1057f5a291787de63183"},
|
| 2490 |
-
{file = "protobuf-3.17.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:1488f786bd1912f97796cf5def8cacf433735616896cf7ed9dc786cee693dfc8"},
|
| 2491 |
-
{file = "protobuf-3.17.0-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bcaff977db178f0bfde10bab0d23a5f5adf5964adba70c315e45922a1c55eb90"},
|
| 2492 |
-
{file = "protobuf-3.17.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:939ce06846ddfec99c0bff510510b3ee45778e7a3aec6544d1f36526e5fecb67"},
|
| 2493 |
-
{file = "protobuf-3.17.0-cp37-cp37m-win32.whl", hash = "sha256:3237acce5b666c7b0f45785cc2d0809796d4df3593bd68338aebf25408139188"},
|
| 2494 |
-
{file = "protobuf-3.17.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2f77afe33bb86c7d34221a86193256d69aa10818620fe4a7513d98211d67d672"},
|
| 2495 |
-
{file = "protobuf-3.17.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:acc9f2091ace3de429eee424ab7ba0bc52a6aa9ffc9909e5c4de259a3f71db46"},
|
| 2496 |
-
{file = "protobuf-3.17.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a29631f4f8bcf79b12a59e83d238d888de5034871461d788c74c68218ad75049"},
|
| 2497 |
-
{file = "protobuf-3.17.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:05c304396e309661c45e3a97bd2d8da1fc2bab743ed2ca880bcb757271c40c0e"},
|
| 2498 |
-
{file = "protobuf-3.17.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:baea44967071e6a51e705e4e88aebf35f530a14004cc69f60a185e5d7e13de7e"},
|
| 2499 |
-
{file = "protobuf-3.17.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:3b5c461af5a3cebd796c73370db929b7e24cbaba655eefdc044226bc8a843d6b"},
|
| 2500 |
-
{file = "protobuf-3.17.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:44399393c3a8cc04a4cfbdc721dd7f2114497efda582e946a91b8c4290ae5ff5"},
|
| 2501 |
-
{file = "protobuf-3.17.0-py2.py3-none-any.whl", hash = "sha256:e32ef0c9f4b548c80d94dfff8b4130ca2ff3d50caaf2455889e3f5b8a01e8038"},
|
| 2502 |
-
{file = "protobuf-3.17.0.tar.gz", hash = "sha256:05dfe9319939a8473c21b469f34f6486646e54fb8542637cf7ed8e2fbfe21538"},
|
| 2503 |
-
]
|
| 2504 |
-
|
| 2505 |
-
[package.dependencies]
|
| 2506 |
-
six = ">=1.9"
|
| 2507 |
|
| 2508 |
[[package]]
|
| 2509 |
name = "pydantic"
|
|
@@ -2705,6 +2974,18 @@ files = [
|
|
| 2705 |
[package.extras]
|
| 2706 |
diagrams = ["jinja2", "railroad-diagrams"]
|
| 2707 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2708 |
[[package]]
|
| 2709 |
name = "pyrsistent"
|
| 2710 |
version = "0.19.3"
|
|
@@ -2997,6 +3278,25 @@ urllib3 = ">=1.21.1,<1.27"
|
|
| 2997 |
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
| 2998 |
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
| 2999 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3000 |
[[package]]
|
| 3001 |
name = "rfc3986"
|
| 3002 |
version = "1.5.0"
|
|
@@ -3035,6 +3335,21 @@ typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9
|
|
| 3035 |
[package.extras]
|
| 3036 |
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
| 3037 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3038 |
[[package]]
|
| 3039 |
name = "scikit-image"
|
| 3040 |
version = "0.20.0"
|
|
@@ -3231,6 +3546,56 @@ files = [
|
|
| 3231 |
[package.extras]
|
| 3232 |
widechars = ["wcwidth"]
|
| 3233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3234 |
[[package]]
|
| 3235 |
name = "tensorboardx"
|
| 3236 |
version = "2.6"
|
|
@@ -3625,6 +3990,24 @@ files = [
|
|
| 3625 |
{file = "websockets-10.4.tar.gz", hash = "sha256:eef610b23933c54d5d921c92578ae5f89813438fded840c2e9809d378dc765d3"},
|
| 3626 |
]
|
| 3627 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3628 |
[[package]]
|
| 3629 |
name = "wheel"
|
| 3630 |
version = "0.40.0"
|
|
@@ -3832,4 +4215,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
|
| 3832 |
[metadata]
|
| 3833 |
lock-version = "2.0"
|
| 3834 |
python-versions = ">=3.8,<3.11"
|
| 3835 |
-
content-hash = "
|
|
|
|
| 290 |
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
| 291 |
uvloop = ["uvloop (>=0.15.2)"]
|
| 292 |
|
| 293 |
+
[[package]]
|
| 294 |
+
name = "cachetools"
|
| 295 |
+
version = "5.3.0"
|
| 296 |
+
description = "Extensible memoizing collections and decorators"
|
| 297 |
+
category = "main"
|
| 298 |
+
optional = false
|
| 299 |
+
python-versions = "~=3.7"
|
| 300 |
+
files = [
|
| 301 |
+
{file = "cachetools-5.3.0-py3-none-any.whl", hash = "sha256:429e1a1e845c008ea6c85aa35d4b98b65d6a9763eeef3e37e92728a12d1de9d4"},
|
| 302 |
+
{file = "cachetools-5.3.0.tar.gz", hash = "sha256:13dfddc7b8df938c21a940dfa6557ce6e94a2f1cdfa58eb90c805721d58f2c14"},
|
| 303 |
+
]
|
| 304 |
+
|
| 305 |
[[package]]
|
| 306 |
name = "certifi"
|
| 307 |
version = "2022.12.7"
|
|
|
|
| 468 |
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
| 469 |
]
|
| 470 |
|
| 471 |
+
[[package]]
|
| 472 |
+
name = "coloredlogs"
|
| 473 |
+
version = "15.0.1"
|
| 474 |
+
description = "Colored terminal output for Python's logging module"
|
| 475 |
+
category = "main"
|
| 476 |
+
optional = false
|
| 477 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
| 478 |
+
files = [
|
| 479 |
+
{file = "coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934"},
|
| 480 |
+
{file = "coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0"},
|
| 481 |
+
]
|
| 482 |
+
|
| 483 |
+
[package.dependencies]
|
| 484 |
+
humanfriendly = ">=9.1"
|
| 485 |
+
|
| 486 |
+
[package.extras]
|
| 487 |
+
cron = ["capturer (>=2.4)"]
|
| 488 |
+
|
| 489 |
[[package]]
|
| 490 |
name = "contourpy"
|
| 491 |
version = "1.0.7"
|
|
|
|
| 722 |
docs = ["furo (>=2022.12.7)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"]
|
| 723 |
testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest (>=7.2.2)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
|
| 724 |
|
| 725 |
+
[[package]]
|
| 726 |
+
name = "flatbuffers"
|
| 727 |
+
version = "23.3.3"
|
| 728 |
+
description = "The FlatBuffers serialization format for Python"
|
| 729 |
+
category = "main"
|
| 730 |
+
optional = false
|
| 731 |
+
python-versions = "*"
|
| 732 |
+
files = [
|
| 733 |
+
{file = "flatbuffers-23.3.3-py2.py3-none-any.whl", hash = "sha256:5ad36d376240090757e8f0a2cfaf6abcc81c6536c0dc988060375fd0899121f8"},
|
| 734 |
+
{file = "flatbuffers-23.3.3.tar.gz", hash = "sha256:cabd87c4882f37840f6081f094b2c5bc28cefc2a6357732746936d055ab45c3d"},
|
| 735 |
+
]
|
| 736 |
+
|
| 737 |
[[package]]
|
| 738 |
name = "fonttools"
|
| 739 |
version = "4.39.3"
|
|
|
|
| 890 |
{file = "gast-0.5.3.tar.gz", hash = "sha256:cfbea25820e653af9c7d1807f659ce0a0a9c64f2439421a7bba4f0983f532dea"},
|
| 891 |
]
|
| 892 |
|
| 893 |
+
[[package]]
|
| 894 |
+
name = "google-auth"
|
| 895 |
+
version = "2.17.1"
|
| 896 |
+
description = "Google Authentication Library"
|
| 897 |
+
category = "main"
|
| 898 |
+
optional = false
|
| 899 |
+
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
|
| 900 |
+
files = [
|
| 901 |
+
{file = "google-auth-2.17.1.tar.gz", hash = "sha256:8f379b46bad381ad2a0b989dfb0c13ad28d3c2a79f27348213f8946a1d15d55a"},
|
| 902 |
+
{file = "google_auth-2.17.1-py2.py3-none-any.whl", hash = "sha256:357ff22a75b4c0f6093470f21816a825d2adee398177569824e37b6c10069e19"},
|
| 903 |
+
]
|
| 904 |
+
|
| 905 |
+
[package.dependencies]
|
| 906 |
+
cachetools = ">=2.0.0,<6.0"
|
| 907 |
+
pyasn1-modules = ">=0.2.1"
|
| 908 |
+
rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""}
|
| 909 |
+
six = ">=1.9.0"
|
| 910 |
+
|
| 911 |
+
[package.extras]
|
| 912 |
+
aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)", "requests (>=2.20.0,<3.0.0dev)"]
|
| 913 |
+
enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
|
| 914 |
+
pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"]
|
| 915 |
+
reauth = ["pyu2f (>=0.1.5)"]
|
| 916 |
+
requests = ["requests (>=2.20.0,<3.0.0dev)"]
|
| 917 |
+
|
| 918 |
+
[[package]]
|
| 919 |
+
name = "google-auth-oauthlib"
|
| 920 |
+
version = "0.4.6"
|
| 921 |
+
description = "Google Authentication Library"
|
| 922 |
+
category = "main"
|
| 923 |
+
optional = false
|
| 924 |
+
python-versions = ">=3.6"
|
| 925 |
+
files = [
|
| 926 |
+
{file = "google-auth-oauthlib-0.4.6.tar.gz", hash = "sha256:a90a072f6993f2c327067bf65270046384cda5a8ecb20b94ea9a687f1f233a7a"},
|
| 927 |
+
{file = "google_auth_oauthlib-0.4.6-py2.py3-none-any.whl", hash = "sha256:3f2a6e802eebbb6fb736a370fbf3b055edcb6b52878bf2f26330b5e041316c73"},
|
| 928 |
+
]
|
| 929 |
+
|
| 930 |
+
[package.dependencies]
|
| 931 |
+
google-auth = ">=1.0.0"
|
| 932 |
+
requests-oauthlib = ">=0.7.0"
|
| 933 |
+
|
| 934 |
+
[package.extras]
|
| 935 |
+
tool = ["click (>=6.0.0)"]
|
| 936 |
+
|
| 937 |
[[package]]
|
| 938 |
name = "gradio"
|
| 939 |
+
version = "3.24.0"
|
| 940 |
description = "Python library for easily interacting with trained machine learning models"
|
| 941 |
category = "main"
|
| 942 |
optional = false
|
| 943 |
python-versions = ">=3.7"
|
| 944 |
files = [
|
| 945 |
+
{file = "gradio-3.24.0-py3-none-any.whl", hash = "sha256:cedd67f7cbd17764b3613fb4df274a7c450c74e31a2e3229097d43cb4ffa50c7"},
|
| 946 |
+
{file = "gradio-3.24.0.tar.gz", hash = "sha256:4ac2bf531b3c0ff5ec9e93959f2d1dbc49eac1767bafa2d80f8950a3bc40c4ed"},
|
| 947 |
]
|
| 948 |
|
| 949 |
[package.dependencies]
|
|
|
|
| 952 |
altair = ">=4.2.0"
|
| 953 |
fastapi = "*"
|
| 954 |
ffmpy = "*"
|
| 955 |
+
gradio-client = ">=0.0.5"
|
| 956 |
httpx = "*"
|
| 957 |
huggingface-hub = ">=0.13.0"
|
| 958 |
jinja2 = "*"
|
|
|
|
| 974 |
uvicorn = "*"
|
| 975 |
websockets = ">=10.0"
|
| 976 |
|
| 977 |
+
[[package]]
|
| 978 |
+
name = "gradio-client"
|
| 979 |
+
version = "0.0.5"
|
| 980 |
+
description = "Python library for easily interacting with trained machine learning models"
|
| 981 |
+
category = "main"
|
| 982 |
+
optional = false
|
| 983 |
+
python-versions = ">=3.7"
|
| 984 |
+
files = [
|
| 985 |
+
{file = "gradio_client-0.0.5-py3-none-any.whl", hash = "sha256:ca4167ebae72d920ebec2be47010cf60e31e0296ad9baac771befb17b87f0eef"},
|
| 986 |
+
{file = "gradio_client-0.0.5.tar.gz", hash = "sha256:dc6479a119314aac0bbf6821da6e946df17f048cc571559379a89590618f7b5d"},
|
| 987 |
+
]
|
| 988 |
+
|
| 989 |
+
[package.dependencies]
|
| 990 |
+
fsspec = "*"
|
| 991 |
+
huggingface-hub = ">=0.13.0"
|
| 992 |
+
packaging = "*"
|
| 993 |
+
requests = "*"
|
| 994 |
+
websockets = "*"
|
| 995 |
+
|
| 996 |
[[package]]
|
| 997 |
name = "grpcio"
|
| 998 |
version = "1.49.1"
|
|
|
|
| 1243 |
torch = ["torch"]
|
| 1244 |
typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
| 1245 |
|
| 1246 |
+
[[package]]
|
| 1247 |
+
name = "humanfriendly"
|
| 1248 |
+
version = "10.0"
|
| 1249 |
+
description = "Human friendly output for text interfaces using Python"
|
| 1250 |
+
category = "main"
|
| 1251 |
+
optional = false
|
| 1252 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
| 1253 |
+
files = [
|
| 1254 |
+
{file = "humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477"},
|
| 1255 |
+
{file = "humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc"},
|
| 1256 |
+
]
|
| 1257 |
+
|
| 1258 |
+
[package.dependencies]
|
| 1259 |
+
pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""}
|
| 1260 |
+
|
| 1261 |
[[package]]
|
| 1262 |
name = "idna"
|
| 1263 |
version = "3.4"
|
|
|
|
| 1651 |
flake8 = ["flake8"]
|
| 1652 |
tests = ["psutil", "pytest (!=3.3.0)", "pytest-cov"]
|
| 1653 |
|
| 1654 |
+
[[package]]
|
| 1655 |
+
name = "markdown"
|
| 1656 |
+
version = "3.4.3"
|
| 1657 |
+
description = "Python implementation of John Gruber's Markdown."
|
| 1658 |
+
category = "main"
|
| 1659 |
+
optional = false
|
| 1660 |
+
python-versions = ">=3.7"
|
| 1661 |
+
files = [
|
| 1662 |
+
{file = "Markdown-3.4.3-py3-none-any.whl", hash = "sha256:065fd4df22da73a625f14890dd77eb8040edcbd68794bcd35943be14490608b2"},
|
| 1663 |
+
{file = "Markdown-3.4.3.tar.gz", hash = "sha256:8bf101198e004dc93e84a12a7395e31aac6a9c9942848ae1d99b9d72cf9b3520"},
|
| 1664 |
+
]
|
| 1665 |
+
|
| 1666 |
+
[package.dependencies]
|
| 1667 |
+
importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""}
|
| 1668 |
+
|
| 1669 |
+
[package.extras]
|
| 1670 |
+
testing = ["coverage", "pyyaml"]
|
| 1671 |
+
|
| 1672 |
[[package]]
|
| 1673 |
name = "markdown-it-py"
|
| 1674 |
version = "2.2.0"
|
|
|
|
| 2322 |
setuptools = "*"
|
| 2323 |
wheel = "*"
|
| 2324 |
|
| 2325 |
+
[[package]]
|
| 2326 |
+
name = "oauthlib"
|
| 2327 |
+
version = "3.2.2"
|
| 2328 |
+
description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
|
| 2329 |
+
category = "main"
|
| 2330 |
+
optional = false
|
| 2331 |
+
python-versions = ">=3.6"
|
| 2332 |
+
files = [
|
| 2333 |
+
{file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"},
|
| 2334 |
+
{file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"},
|
| 2335 |
+
]
|
| 2336 |
+
|
| 2337 |
+
[package.extras]
|
| 2338 |
+
rsa = ["cryptography (>=3.0.0)"]
|
| 2339 |
+
signals = ["blinker (>=1.4.0)"]
|
| 2340 |
+
signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
|
| 2341 |
+
|
| 2342 |
+
[[package]]
|
| 2343 |
+
name = "onnx"
|
| 2344 |
+
version = "1.12.0"
|
| 2345 |
+
description = "Open Neural Network Exchange"
|
| 2346 |
+
category = "main"
|
| 2347 |
+
optional = false
|
| 2348 |
+
python-versions = "*"
|
| 2349 |
+
files = [
|
| 2350 |
+
{file = "onnx-1.12.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:bdbd2578424c70836f4d0f9dda16c21868ddb07cc8192f9e8a176908b43d694b"},
|
| 2351 |
+
{file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213e73610173f6b2e99f99a4b0636f80b379c417312079d603806e48ada4ca8b"},
|
| 2352 |
+
{file = "onnx-1.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fd2f4e23078df197bb76a59b9cd8f5a43a6ad2edc035edb3ecfb9042093e05a"},
|
| 2353 |
+
{file = "onnx-1.12.0-cp310-cp310-win32.whl", hash = "sha256:23781594bb8b7ee985de1005b3c601648d5b0568a81e01365c48f91d1f5648e4"},
|
| 2354 |
+
{file = "onnx-1.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:81a3555fd67be2518bf86096299b48fb9154652596219890abfe90bd43a9ec13"},
|
| 2355 |
+
{file = "onnx-1.12.0-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:5578b93dc6c918cec4dee7fb7d9dd3b09d338301ee64ca8b4f28bc217ed42dca"},
|
| 2356 |
+
{file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c11162ffc487167da140f1112f49c4f82d815824f06e58bc3095407699f05863"},
|
| 2357 |
+
{file = "onnx-1.12.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341c7016e23273e9ffa9b6e301eee95b8c37d0f04df7cedbdb169d2c39524c96"},
|
| 2358 |
+
{file = "onnx-1.12.0-cp37-cp37m-win32.whl", hash = "sha256:3c6e6bcffc3f5c1e148df3837dc667fa4c51999788c1b76b0b8fbba607e02da8"},
|
| 2359 |
+
{file = "onnx-1.12.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8a7aa61aea339bd28f310f4af4f52ce6c4b876386228760b16308efd58f95059"},
|
| 2360 |
+
{file = "onnx-1.12.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:56ceb7e094c43882b723cfaa107d85ad673cfdf91faeb28d7dcadacca4f43a07"},
|
| 2361 |
+
{file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b3629e8258db15d4e2c9b7f1be91a3186719dd94661c218c6f5fde3cc7de3d4d"},
|
| 2362 |
+
{file = "onnx-1.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d9a7db54e75529160337232282a4816cc50667dc7dc34be178fd6f6b79d4705"},
|
| 2363 |
+
{file = "onnx-1.12.0-cp38-cp38-win32.whl", hash = "sha256:fea5156a03398fe0e23248042d8651c1eaac5f6637d4dd683b4c1f1320b9f7b4"},
|
| 2364 |
+
{file = "onnx-1.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:f66d2996e65f490a57b3ae952e4e9189b53cc9fe3f75e601d50d4db2dc1b1cd9"},
|
| 2365 |
+
{file = "onnx-1.12.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:c39a7a0352c856f1df30dccf527eb6cb4909052e5eaf6fa2772a637324c526aa"},
|
| 2366 |
+
{file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fab13feb4d94342aae6d357d480f2e47d41b9f4e584367542b21ca6defda9e0a"},
|
| 2367 |
+
{file = "onnx-1.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7a9b3ea02c30efc1d2662337e280266aca491a8e86be0d8a657f874b7cccd1e"},
|
| 2368 |
+
{file = "onnx-1.12.0-cp39-cp39-win32.whl", hash = "sha256:f8800f28c746ab06e51ef8449fd1215621f4ddba91be3ffc264658937d38a2af"},
|
| 2369 |
+
{file = "onnx-1.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:af90427ca04c6b7b8107c2021e1273227a3ef1a7a01f3073039cae7855a59833"},
|
| 2370 |
+
{file = "onnx-1.12.0.tar.gz", hash = "sha256:13b3e77d27523b9dbf4f30dfc9c959455859d5e34e921c44f712d69b8369eff9"},
|
| 2371 |
+
]
|
| 2372 |
+
|
| 2373 |
+
[package.dependencies]
|
| 2374 |
+
numpy = ">=1.16.6"
|
| 2375 |
+
protobuf = ">=3.12.2,<=3.20.1"
|
| 2376 |
+
typing-extensions = ">=3.6.2.1"
|
| 2377 |
+
|
| 2378 |
+
[package.extras]
|
| 2379 |
+
lint = ["clang-format (==13.0.0)", "flake8", "mypy (==0.782)", "types-protobuf (==3.18.4)"]
|
| 2380 |
+
|
| 2381 |
+
[[package]]
|
| 2382 |
+
name = "onnxruntime"
|
| 2383 |
+
version = "1.14.1"
|
| 2384 |
+
description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
|
| 2385 |
+
category = "main"
|
| 2386 |
+
optional = false
|
| 2387 |
+
python-versions = "*"
|
| 2388 |
+
files = [
|
| 2389 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:193ef1ac512e530c6e6e259c26e67212e2cd3f2bfaad6ff935ed3f4281053056"},
|
| 2390 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d2853bbb36cb272d99f6c225e5040eb0ddb37a667fce20d186ecdf0a6fac8af8"},
|
| 2391 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e1b173365c6894616b8207e23cbb891da9638c5373668d6653e4081ef5f04d0"},
|
| 2392 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24bf0401c5f92be7230ac660ff07ba06f7c175e99e225d5d48ff09062a3b76e9"},
|
| 2393 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_27_aarch64.whl", hash = "sha256:0a2d09260bbdbe1df678e0a237a5f7b1a44fd11a2f52688d8b6a53a9d03a26db"},
|
| 2394 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-manylinux_2_27_x86_64.whl", hash = "sha256:d99d35b9d5c3f46cad1673a39cc753fb57d60784369b59e6f8cd3dfb77df1885"},
|
| 2395 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-win32.whl", hash = "sha256:f400356df1b27d9adc5513319e8a89753e48ef0d6c5084caf5db8e132f46e7e8"},
|
| 2396 |
+
{file = "onnxruntime-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:96a4059dbab162fe5cdb6750f8c70b2106ef2de5d49a7f72085171937d0e36d3"},
|
| 2397 |
+
{file = "onnxruntime-1.14.1-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:fa23df6a349218636290f9fe56d7baaceb1a50cf92255234d495198b47d92327"},
|
| 2398 |
+
{file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bc70e44d9e123d126648da24ffb39e56464272a1660a3eb91f4f5b74263be3ba"},
|
| 2399 |
+
{file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:deff8138045a3affb6be064b598e3ec69a88e4d445359c50464ee5379b8eaf19"},
|
| 2400 |
+
{file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_27_aarch64.whl", hash = "sha256:7c02acdc1107cbf698dcbf6dadc6f5b6aa179e7fa9a026251e99cf8613bd3129"},
|
| 2401 |
+
{file = "onnxruntime-1.14.1-cp37-cp37m-manylinux_2_27_x86_64.whl", hash = "sha256:6efa3b2f4b1eaa6c714c07861993bfd9bb33bd73cdbcaf5b4aadcf1ec13fcaf7"},
|
| 2402 |
+
{file = "onnxruntime-1.14.1-cp37-cp37m-win32.whl", hash = "sha256:72fc0acc82c54bf03eba065ad9025baa438c00c54a2ee0beb8ae4b6085cd3a0d"},
|
| 2403 |
+
{file = "onnxruntime-1.14.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4d6f08ea40d63ccf90f203f4a2a498f4e590737dcaf16867075cc8e0a86c5554"},
|
| 2404 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:c2d9e8f1bc6037f14d8aaa480492792c262fc914936153e40b06b3667bb25549"},
|
| 2405 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e7424d3befdd95b537c90787bbfaa053b2bb19eb60135abb898cb0e099d7d7ad"},
|
| 2406 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9066d275e6e41d0597e234d2d88c074d4325e650c74a9527a52cadbcf42a0fe2"},
|
| 2407 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8224d3c1f2cd0b899cea7b5a39f28b971debe0da30fcbc61382801d97d6f5740"},
|
| 2408 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_27_aarch64.whl", hash = "sha256:f4ac52ff4ac793683ebd1fbd1ee24197e3b4ca825ee68ff739296a820867debe"},
|
| 2409 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-manylinux_2_27_x86_64.whl", hash = "sha256:b1dd8cdd3be36c32ddd8f5763841ed571c3e81da59439a622947bd97efee6e77"},
|
| 2410 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-win32.whl", hash = "sha256:95d0f0cd95360c07f1c3ba20962b9bb813627df4bfc1b4b274e1d40044df5ad1"},
|
| 2411 |
+
{file = "onnxruntime-1.14.1-cp38-cp38-win_amd64.whl", hash = "sha256:de40a558e00fc00f92e298d5be99eb8075dba51368dabcb259670a00f4670e56"},
|
| 2412 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:c65b587a42a89fceceaad367bd69d071ee5c9c7010b76e2adac5e9efd9356fb5"},
|
| 2413 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6e47ef6a2c6e6dd6ff48bc13f2331d124dff00e1d76627624bb3268c8058f19c"},
|
| 2414 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0afd0f671d068dd99b9d071d88e93a9a57a5ed59af440c0f4d65319ee791603f"},
|
| 2415 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc65e9061349cdf98ce16b37722b557109f16076632fbfed9a3151895cfd3bb7"},
|
| 2416 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_27_aarch64.whl", hash = "sha256:2ff17c71187391a71e6ccc78ca89aed83bcaed1c085c95267ab1a70897868bdd"},
|
| 2417 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-manylinux_2_27_x86_64.whl", hash = "sha256:9b795189916942ce848192200dde5b1f32799ee6c84fc600969a44d88e8a5404"},
|
| 2418 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-win32.whl", hash = "sha256:17ca3100112af045118750d24643a01ed4e6d86071a8efaef75cc1d434ea64aa"},
|
| 2419 |
+
{file = "onnxruntime-1.14.1-cp39-cp39-win_amd64.whl", hash = "sha256:b5e8c489329ba0fa0639dfd7ec02d6b07cece1bab52ef83884b537247efbda74"},
|
| 2420 |
+
]
|
| 2421 |
+
|
| 2422 |
+
[package.dependencies]
|
| 2423 |
+
coloredlogs = "*"
|
| 2424 |
+
flatbuffers = "*"
|
| 2425 |
+
numpy = ">=1.21.6"
|
| 2426 |
+
packaging = "*"
|
| 2427 |
+
protobuf = "*"
|
| 2428 |
+
sympy = "*"
|
| 2429 |
+
|
| 2430 |
[[package]]
|
| 2431 |
name = "orjson"
|
| 2432 |
version = "3.8.8"
|
|
|
|
| 2714 |
|
| 2715 |
[[package]]
|
| 2716 |
name = "protobuf"
|
| 2717 |
+
version = "3.19.6"
|
| 2718 |
description = "Protocol Buffers"
|
| 2719 |
category = "main"
|
| 2720 |
optional = false
|
| 2721 |
+
python-versions = ">=3.5"
|
| 2722 |
+
files = [
|
| 2723 |
+
{file = "protobuf-3.19.6-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:010be24d5a44be7b0613750ab40bc8b8cedc796db468eae6c779b395f50d1fa1"},
|
| 2724 |
+
{file = "protobuf-3.19.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11478547958c2dfea921920617eb457bc26867b0d1aa065ab05f35080c5d9eb6"},
|
| 2725 |
+
{file = "protobuf-3.19.6-cp310-cp310-win32.whl", hash = "sha256:559670e006e3173308c9254d63facb2c03865818f22204037ab76f7a0ff70b5f"},
|
| 2726 |
+
{file = "protobuf-3.19.6-cp310-cp310-win_amd64.whl", hash = "sha256:347b393d4dd06fb93a77620781e11c058b3b0a5289262f094379ada2920a3730"},
|
| 2727 |
+
{file = "protobuf-3.19.6-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a8ce5ae0de28b51dff886fb922012dad885e66176663950cb2344c0439ecb473"},
|
| 2728 |
+
{file = "protobuf-3.19.6-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90b0d02163c4e67279ddb6dc25e063db0130fc299aefabb5d481053509fae5c8"},
|
| 2729 |
+
{file = "protobuf-3.19.6-cp36-cp36m-win32.whl", hash = "sha256:30f5370d50295b246eaa0296533403961f7e64b03ea12265d6dfce3a391d8992"},
|
| 2730 |
+
{file = "protobuf-3.19.6-cp36-cp36m-win_amd64.whl", hash = "sha256:0c0714b025ec057b5a7600cb66ce7c693815f897cfda6d6efb58201c472e3437"},
|
| 2731 |
+
{file = "protobuf-3.19.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:5057c64052a1f1dd7d4450e9aac25af6bf36cfbfb3a1cd89d16393a036c49157"},
|
| 2732 |
+
{file = "protobuf-3.19.6-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:bb6776bd18f01ffe9920e78e03a8676530a5d6c5911934c6a1ac6eb78973ecb6"},
|
| 2733 |
+
{file = "protobuf-3.19.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84a04134866861b11556a82dd91ea6daf1f4925746b992f277b84013a7cc1229"},
|
| 2734 |
+
{file = "protobuf-3.19.6-cp37-cp37m-win32.whl", hash = "sha256:4bc98de3cdccfb5cd769620d5785b92c662b6bfad03a202b83799b6ed3fa1fa7"},
|
| 2735 |
+
{file = "protobuf-3.19.6-cp37-cp37m-win_amd64.whl", hash = "sha256:aa3b82ca1f24ab5326dcf4ea00fcbda703e986b22f3d27541654f749564d778b"},
|
| 2736 |
+
{file = "protobuf-3.19.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2b2d2913bcda0e0ec9a784d194bc490f5dc3d9d71d322d070b11a0ade32ff6ba"},
|
| 2737 |
+
{file = "protobuf-3.19.6-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:d0b635cefebd7a8a0f92020562dead912f81f401af7e71f16bf9506ff3bdbb38"},
|
| 2738 |
+
{file = "protobuf-3.19.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a552af4dc34793803f4e735aabe97ffc45962dfd3a237bdde242bff5a3de684"},
|
| 2739 |
+
{file = "protobuf-3.19.6-cp38-cp38-win32.whl", hash = "sha256:0469bc66160180165e4e29de7f445e57a34ab68f49357392c5b2f54c656ab25e"},
|
| 2740 |
+
{file = "protobuf-3.19.6-cp38-cp38-win_amd64.whl", hash = "sha256:91d5f1e139ff92c37e0ff07f391101df77e55ebb97f46bbc1535298d72019462"},
|
| 2741 |
+
{file = "protobuf-3.19.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c0ccd3f940fe7f3b35a261b1dd1b4fc850c8fde9f74207015431f174be5976b3"},
|
| 2742 |
+
{file = "protobuf-3.19.6-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:30a15015d86b9c3b8d6bf78d5b8c7749f2512c29f168ca259c9d7727604d0e39"},
|
| 2743 |
+
{file = "protobuf-3.19.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:878b4cd080a21ddda6ac6d1e163403ec6eea2e206cf225982ae04567d39be7b0"},
|
| 2744 |
+
{file = "protobuf-3.19.6-cp39-cp39-win32.whl", hash = "sha256:5a0d7539a1b1fb7e76bf5faa0b44b30f812758e989e59c40f77a7dab320e79b9"},
|
| 2745 |
+
{file = "protobuf-3.19.6-cp39-cp39-win_amd64.whl", hash = "sha256:bbf5cea5048272e1c60d235c7bd12ce1b14b8a16e76917f371c718bd3005f045"},
|
| 2746 |
+
{file = "protobuf-3.19.6-py2.py3-none-any.whl", hash = "sha256:14082457dc02be946f60b15aad35e9f5c69e738f80ebbc0900a19bc83734a5a4"},
|
| 2747 |
+
{file = "protobuf-3.19.6.tar.gz", hash = "sha256:5f5540d57a43042389e87661c6eaa50f47c19c6176e8cf1c4f287aeefeccb5c4"},
|
| 2748 |
+
]
|
| 2749 |
+
|
| 2750 |
+
[[package]]
|
| 2751 |
+
name = "pyasn1"
|
| 2752 |
+
version = "0.4.8"
|
| 2753 |
+
description = "ASN.1 types and codecs"
|
| 2754 |
+
category = "main"
|
| 2755 |
+
optional = false
|
| 2756 |
+
python-versions = "*"
|
| 2757 |
+
files = [
|
| 2758 |
+
{file = "pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d"},
|
| 2759 |
+
{file = "pyasn1-0.4.8.tar.gz", hash = "sha256:aef77c9fb94a3ac588e87841208bdec464471d9871bd5050a287cc9a475cd0ba"},
|
| 2760 |
+
]
|
| 2761 |
+
|
| 2762 |
+
[[package]]
|
| 2763 |
+
name = "pyasn1-modules"
|
| 2764 |
+
version = "0.2.8"
|
| 2765 |
+
description = "A collection of ASN.1-based protocols modules."
|
| 2766 |
+
category = "main"
|
| 2767 |
+
optional = false
|
| 2768 |
python-versions = "*"
|
| 2769 |
files = [
|
| 2770 |
+
{file = "pyasn1-modules-0.2.8.tar.gz", hash = "sha256:905f84c712230b2c592c19470d3ca8d552de726050d1d1716282a1f6146be65e"},
|
| 2771 |
+
{file = "pyasn1_modules-0.2.8-py2.py3-none-any.whl", hash = "sha256:a50b808ffeb97cb3601dd25981f6b016cbb3d31fbf57a8b8a87428e6158d0c74"},
|
| 2772 |
+
]
|
| 2773 |
+
|
| 2774 |
+
[package.dependencies]
|
| 2775 |
+
pyasn1 = ">=0.4.6,<0.5.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2776 |
|
| 2777 |
[[package]]
|
| 2778 |
name = "pydantic"
|
|
|
|
| 2974 |
[package.extras]
|
| 2975 |
diagrams = ["jinja2", "railroad-diagrams"]
|
| 2976 |
|
| 2977 |
+
[[package]]
|
| 2978 |
+
name = "pyreadline3"
|
| 2979 |
+
version = "3.4.1"
|
| 2980 |
+
description = "A python implementation of GNU readline."
|
| 2981 |
+
category = "main"
|
| 2982 |
+
optional = false
|
| 2983 |
+
python-versions = "*"
|
| 2984 |
+
files = [
|
| 2985 |
+
{file = "pyreadline3-3.4.1-py3-none-any.whl", hash = "sha256:b0efb6516fd4fb07b45949053826a62fa4cb353db5be2bbb4a7aa1fdd1e345fb"},
|
| 2986 |
+
{file = "pyreadline3-3.4.1.tar.gz", hash = "sha256:6f3d1f7b8a31ba32b73917cefc1f28cc660562f39aea8646d30bd6eff21f7bae"},
|
| 2987 |
+
]
|
| 2988 |
+
|
| 2989 |
[[package]]
|
| 2990 |
name = "pyrsistent"
|
| 2991 |
version = "0.19.3"
|
|
|
|
| 3278 |
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
| 3279 |
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
| 3280 |
|
| 3281 |
+
[[package]]
|
| 3282 |
+
name = "requests-oauthlib"
|
| 3283 |
+
version = "1.3.1"
|
| 3284 |
+
description = "OAuthlib authentication support for Requests."
|
| 3285 |
+
category = "main"
|
| 3286 |
+
optional = false
|
| 3287 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
| 3288 |
+
files = [
|
| 3289 |
+
{file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"},
|
| 3290 |
+
{file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"},
|
| 3291 |
+
]
|
| 3292 |
+
|
| 3293 |
+
[package.dependencies]
|
| 3294 |
+
oauthlib = ">=3.0.0"
|
| 3295 |
+
requests = ">=2.0.0"
|
| 3296 |
+
|
| 3297 |
+
[package.extras]
|
| 3298 |
+
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
|
| 3299 |
+
|
| 3300 |
[[package]]
|
| 3301 |
name = "rfc3986"
|
| 3302 |
version = "1.5.0"
|
|
|
|
| 3335 |
[package.extras]
|
| 3336 |
jupyter = ["ipywidgets (>=7.5.1,<9)"]
|
| 3337 |
|
| 3338 |
+
[[package]]
|
| 3339 |
+
name = "rsa"
|
| 3340 |
+
version = "4.9"
|
| 3341 |
+
description = "Pure-Python RSA implementation"
|
| 3342 |
+
category = "main"
|
| 3343 |
+
optional = false
|
| 3344 |
+
python-versions = ">=3.6,<4"
|
| 3345 |
+
files = [
|
| 3346 |
+
{file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"},
|
| 3347 |
+
{file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"},
|
| 3348 |
+
]
|
| 3349 |
+
|
| 3350 |
+
[package.dependencies]
|
| 3351 |
+
pyasn1 = ">=0.1.3"
|
| 3352 |
+
|
| 3353 |
[[package]]
|
| 3354 |
name = "scikit-image"
|
| 3355 |
version = "0.20.0"
|
|
|
|
| 3546 |
[package.extras]
|
| 3547 |
widechars = ["wcwidth"]
|
| 3548 |
|
| 3549 |
+
[[package]]
|
| 3550 |
+
name = "tensorboard"
|
| 3551 |
+
version = "2.12.0"
|
| 3552 |
+
description = "TensorBoard lets you watch Tensors Flow"
|
| 3553 |
+
category = "main"
|
| 3554 |
+
optional = false
|
| 3555 |
+
python-versions = ">=3.8"
|
| 3556 |
+
files = [
|
| 3557 |
+
{file = "tensorboard-2.12.0-py3-none-any.whl", hash = "sha256:3cbdc32448d7a28dc1bf0b1754760c08b8e0e2e37c451027ebd5ff4896613012"},
|
| 3558 |
+
]
|
| 3559 |
+
|
| 3560 |
+
[package.dependencies]
|
| 3561 |
+
absl-py = ">=0.4"
|
| 3562 |
+
google-auth = ">=1.6.3,<3"
|
| 3563 |
+
google-auth-oauthlib = ">=0.4.1,<0.5"
|
| 3564 |
+
grpcio = ">=1.48.2"
|
| 3565 |
+
markdown = ">=2.6.8"
|
| 3566 |
+
numpy = ">=1.12.0"
|
| 3567 |
+
protobuf = ">=3.19.6"
|
| 3568 |
+
requests = ">=2.21.0,<3"
|
| 3569 |
+
setuptools = ">=41.0.0"
|
| 3570 |
+
tensorboard-data-server = ">=0.7.0,<0.8.0"
|
| 3571 |
+
tensorboard-plugin-wit = ">=1.6.0"
|
| 3572 |
+
werkzeug = ">=1.0.1"
|
| 3573 |
+
wheel = ">=0.26"
|
| 3574 |
+
|
| 3575 |
+
[[package]]
|
| 3576 |
+
name = "tensorboard-data-server"
|
| 3577 |
+
version = "0.7.0"
|
| 3578 |
+
description = "Fast data loading for TensorBoard"
|
| 3579 |
+
category = "main"
|
| 3580 |
+
optional = false
|
| 3581 |
+
python-versions = ">=3.7"
|
| 3582 |
+
files = [
|
| 3583 |
+
{file = "tensorboard_data_server-0.7.0-py3-none-any.whl", hash = "sha256:753d4214799b31da7b6d93837959abebbc6afa86e69eacf1e9a317a48daa31eb"},
|
| 3584 |
+
{file = "tensorboard_data_server-0.7.0-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:eb7fa518737944dbf4f0cf83c2e40a7ac346bf91be2e6a0215de98be74e85454"},
|
| 3585 |
+
{file = "tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64aa1be7c23e80b1a42c13b686eb0875bb70f5e755f4d2b8de5c1d880cf2267f"},
|
| 3586 |
+
]
|
| 3587 |
+
|
| 3588 |
+
[[package]]
|
| 3589 |
+
name = "tensorboard-plugin-wit"
|
| 3590 |
+
version = "1.8.1"
|
| 3591 |
+
description = "What-If Tool TensorBoard plugin."
|
| 3592 |
+
category = "main"
|
| 3593 |
+
optional = false
|
| 3594 |
+
python-versions = "*"
|
| 3595 |
+
files = [
|
| 3596 |
+
{file = "tensorboard_plugin_wit-1.8.1-py3-none-any.whl", hash = "sha256:ff26bdd583d155aa951ee3b152b3d0cffae8005dc697f72b44a8e8c2a77a8cbe"},
|
| 3597 |
+
]
|
| 3598 |
+
|
| 3599 |
[[package]]
|
| 3600 |
name = "tensorboardx"
|
| 3601 |
version = "2.6"
|
|
|
|
| 3990 |
{file = "websockets-10.4.tar.gz", hash = "sha256:eef610b23933c54d5d921c92578ae5f89813438fded840c2e9809d378dc765d3"},
|
| 3991 |
]
|
| 3992 |
|
| 3993 |
+
[[package]]
|
| 3994 |
+
name = "werkzeug"
|
| 3995 |
+
version = "2.2.3"
|
| 3996 |
+
description = "The comprehensive WSGI web application library."
|
| 3997 |
+
category = "main"
|
| 3998 |
+
optional = false
|
| 3999 |
+
python-versions = ">=3.7"
|
| 4000 |
+
files = [
|
| 4001 |
+
{file = "Werkzeug-2.2.3-py3-none-any.whl", hash = "sha256:56433961bc1f12533306c624f3be5e744389ac61d722175d543e1751285da612"},
|
| 4002 |
+
{file = "Werkzeug-2.2.3.tar.gz", hash = "sha256:2e1ccc9417d4da358b9de6f174e3ac094391ea1d4fbef2d667865d819dfd0afe"},
|
| 4003 |
+
]
|
| 4004 |
+
|
| 4005 |
+
[package.dependencies]
|
| 4006 |
+
MarkupSafe = ">=2.1.1"
|
| 4007 |
+
|
| 4008 |
+
[package.extras]
|
| 4009 |
+
watchdog = ["watchdog"]
|
| 4010 |
+
|
| 4011 |
[[package]]
|
| 4012 |
name = "wheel"
|
| 4013 |
version = "0.40.0"
|
|
|
|
| 4215 |
[metadata]
|
| 4216 |
lock-version = "2.0"
|
| 4217 |
python-versions = ">=3.8,<3.11"
|
| 4218 |
+
content-hash = "81eac0c68b289dd9d22d17dd34ae6bcd31f53c8201f229d31ff96d7094ad2392"
|
pyproject.toml
CHANGED
|
@@ -23,8 +23,11 @@ pygame = "^2.3.0"
|
|
| 23 |
torch = "^2.0.0"
|
| 24 |
libclang = "15.0.6.1"
|
| 25 |
tensorflow-probability = "^0.19.0"
|
| 26 |
-
protobuf = "3.
|
| 27 |
scipy = ">=1.8,<1.9.2"
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
[tool.poetry.dev-dependencies]
|
| 30 |
pylint = "*"
|
|
|
|
| 23 |
torch = "^2.0.0"
|
| 24 |
libclang = "15.0.6.1"
|
| 25 |
tensorflow-probability = "^0.19.0"
|
| 26 |
+
protobuf = "3.19.6"
|
| 27 |
scipy = ">=1.8,<1.9.2"
|
| 28 |
+
onnx = "1.12.0"
|
| 29 |
+
tensorboard = "^2.12.0"
|
| 30 |
+
onnxruntime = "^1.14.1"
|
| 31 |
|
| 32 |
[tool.poetry.dev-dependencies]
|
| 33 |
pylint = "*"
|