Update neochessppo.py
neochessppo.py · CHANGED · +9 -7
@@ -1,3 +1,5 @@
+import os
+os.system("mv NeoChess/san_moves.txt /usr/local/python/3.12.1/lib/python3.12/site-packages/torchrl/envs/custom/")
 import torchrl
 import torch
 import chess
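The new `os.system("mv ...")` call pins an absolute interpreter path (`/usr/local/python/3.12.1/...`), so it only works in that one environment. A minimal sketch of a more portable variant, assuming the intent is simply to place `san_moves.txt` where torchrl's custom chess env can find it (the destination is derived from the installed package instead of being hardcoded):

```python
import os
import shutil

import torchrl.envs.custom

# Resolve torchrl's custom-envs directory from the installed package
# rather than hardcoding /usr/local/python/3.12.1/....
dest_dir = os.path.dirname(torchrl.envs.custom.__file__)
src = "NeoChess/san_moves.txt"

# copy2 keeps the source tree intact; skip silently if the file is absent.
if os.path.exists(src):
    shutil.copy2(src, os.path.join(dest_dir, "san_moves.txt"))
```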
@@ -145,9 +147,9 @@ with set_gym_backend("gymnasium"):
 
 policy = Policy().to(device)
 value = Value().to(device)
-valweight = torch.load("NeoChess/
+valweight = torch.load("NeoChess-Community/chessy_modelt-1.pth",map_location=device,weights_only=False)
 value.load_state_dict(valweight)
-polweight = torch.load("NeoChess/chessy_policy.pth")
+polweight = torch.load("NeoChess-Community/chessy_policy.pth",map_location=device,weights_only=False)
 policy.load_state_dict(polweight)
 
 def sample_masked_action(logits, mask):
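Both new `torch.load` calls pass `weights_only=False`, which runs the full pickle machinery and can execute arbitrary code from an untrusted checkpoint file. If these `.pth` files hold plain `state_dict`s (tensors only), the stricter mode should work and is safer; a sketch against the same paths:

```python
# weights_only=True uses a restricted unpickler that only rebuilds tensors
# and primitive containers, so a plain state_dict loads fine.
valweight = torch.load("NeoChess-Community/chessy_modelt-1.pth",
                       map_location=device, weights_only=True)
value.load_state_dict(valweight)

polweight = torch.load("NeoChess-Community/chessy_policy.pth",
                       map_location=device, weights_only=True)
policy.load_state_dict(polweight)
```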
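`sample_masked_action` itself is untouched by this commit and its body is not in the diff. For orientation, the usual way such a helper is written (a sketch of the standard masked-sampling technique, not necessarily this repo's implementation) is to push illegal moves to `-inf` before sampling:

```python
import torch

def sample_masked_action(logits, mask):
    # mask: boolean tensor, True where the move is legal. Setting illegal
    # logits to -inf gives those moves zero probability under softmax.
    masked_logits = logits.masked_fill(~mask, float("-inf"))
    dist = torch.distributions.Categorical(logits=masked_logits)
    action = dist.sample()
    return action, dist.log_prob(action)
```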
@@ -253,9 +255,10 @@ print(actor(obs))
 rollout = env.rollout(3)
 
 from torchrl.record.loggers import generate_exp_name, get_logger
-def train_ppo_chess(chess_env, num_iterations=1, frames_per_batch=100,
-                    num_epochs=
+def train_ppo_chess(chess_env, num_iterations=1, frames_per_batch=1000,
+                    num_epochs=100, lr=3e-4, gamma=0.99, lmbda=0.95,
                     clip_epsilon=0.2, device="cpu"):
+    global actor_module, value_module, loss_module
     """
     Main PPO training loop for Chess
 
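The filled-in defaults (`gamma=0.99`, `lmbda=0.95`, `clip_epsilon=0.2`, `lr=3e-4`) are the conventional PPO/GAE settings. The function body is outside this diff, but in torchrl these values typically wire into the advantage estimator and the clipped objective roughly as follows (a sketch against torchrl's public API, reusing `actor_module`/`value_module` from this file):

```python
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

# GAE: gamma discounts future rewards, lmbda trades bias for variance
# in the advantage estimate.
advantage_module = GAE(gamma=0.99, lmbda=0.95,
                       value_network=value_module, average_gae=True)

# Clipped surrogate objective: policy ratios outside
# [1 - clip_epsilon, 1 + clip_epsilon] stop contributing gradient.
loss_module = ClipPPOLoss(actor_network=actor_module,
                          critic_network=value_module,
                          clip_epsilon=0.2)
```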
@@ -275,7 +278,6 @@ def train_ppo_chess(chess_env, num_iterations=1, frames_per_batch=100,
     env = chess_env
     # Create actor and value modules
     actor_module = actor
-    global actor_module, value_module, loss_module
 
     collector = SyncDataCollector(
         env,
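Dropping the `global` statement here is the real fix: Python rejects a `global` declaration that appears after an assignment to the same name in the same scope (`SyntaxError: name 'actor_module' is assigned to before global declaration`), which is why the commit moves it to the top of the function instead. The `SyncDataCollector` arguments are not shown in the diff; a typical construction using the signature's parameters would look like this (a sketch, not the file's actual call):

```python
from torchrl.collectors import SyncDataCollector

# Runs the policy in the env and yields batches of frames_per_batch
# transitions until total_frames have been collected.
collector = SyncDataCollector(
    env,
    actor_module,
    frames_per_batch=1000,
    total_frames=1000 * 1,  # frames_per_batch * num_iterations
    device="cpu",
)
```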
@@ -414,5 +416,5 @@ def train_ppo_chess(chess_env, num_iterations=1, frames_per_batch=100,
     print("\nTraining completed!")
 
 train_ppo_chess(env)
-torch.save(value.state_dict(),"chessy_model.pth")
-torch.save(policy.state_dict(),"chessy_policy.pth")
+torch.save(value.state_dict(),"NeoChess-Community/chessy_model.pth")
+torch.save(policy.state_dict(),"NeoChess-Community/chessy_policy.pth")
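`torch.save` raises `FileNotFoundError` when the target directory is missing, so the new `NeoChess-Community/` destination assumes that folder already exists. A cheap guard (a sketch):

```python
import os

os.makedirs("NeoChess-Community", exist_ok=True)  # no-op if already present
torch.save(value.state_dict(), "NeoChess-Community/chessy_model.pth")
torch.save(policy.state_dict(), "NeoChess-Community/chessy_policy.pth")
```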