trioskosmos committed on
Commit
fa71ce1
·
verified ·
1 Parent(s): 5dfc8e2

Upload ai/training/train_ppo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/training/train_ppo.py +53 -0
ai/training/train_ppo.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import our environment
2
+ from ai.gym_env import LoveLiveCardGameEnv
3
+ from sb3_contrib import MaskablePPO
4
+ from sb3_contrib.common.maskable.utils import get_action_masks
5
+ from sb3_contrib.common.wrappers import ActionMasker
6
+
7
+
8
def make_env():
    """Build the card-game environment wrapped for use with MaskablePPO.

    Returns:
        A ``LoveLiveCardGameEnv`` instance wrapped in ``ActionMasker``; the
        wrapper obtains masks by calling ``action_masks()`` on the
        environment it is handed.
    """

    def _mask_fn(wrapped_env):
        # Delegate to the environment's own mask provider.
        return wrapped_env.action_masks()

    return ActionMasker(LoveLiveCardGameEnv(), _mask_fn)
13
+
14
+
15
def main():
    """Train a MaskablePPO agent, save it, then play one evaluation episode.

    Steps performed:
      1. Build the masked environment via ``make_env()``.
      2. Train an ``MlpPolicy`` MaskablePPO model for 100k timesteps,
         logging to ``./logs/ppo_tensorboard/``.
      3. Save the model to ``checkpoints/lovelive_ppo_agent``.
      4. Run a single deterministic episode with action masking, rendering
         each step, and print the accumulated reward.

    The environment is closed on exit, whether or not training or the test
    run raised.
    """
    # Create Environment
    env = make_env()

    try:
        # Define Model (MaskablePPO)
        model = MaskablePPO(
            "MlpPolicy",
            env,
            verbose=1,
            gamma=0.99,
            learning_rate=3e-4,
            tensorboard_log="./logs/ppo_tensorboard/",
        )

        print("Starting Training...")
        # Train for 100k steps
        model.learn(total_timesteps=100_000, progress_bar=True)

        # Save Model
        model.save("checkpoints/lovelive_ppo_agent")
        print("Training Complete. Model Saved.")

        # Test Run
        obs, _ = env.reset()
        done = False
        total_reward = 0
        while not done:
            # Predict using masks so only currently valid actions are chosen
            action_masks = get_action_masks(env)
            action, _states = model.predict(obs, action_masks=action_masks, deterministic=True)

            obs, reward, terminated, truncated, _info = env.step(action)
            # An episode ends on either flag; checking only the first one
            # would loop forever on an episode that is truncated.
            done = terminated or truncated
            total_reward += reward
            env.render()

        print(f"Test Run Reward: {total_reward}")
    finally:
        # Release any resources held by the environment.
        env.close()
46
+
47
+
48
if __name__ == "__main__":
    # NOTE(review): the imports at the top of this file execute before this
    # try block is entered, so a missing package there is not caught here;
    # this handler only sees ImportErrors raised while main() runs.
    try:
        main()
    except ImportError as e:
        print(f"Import Error: {e}")
        print("Please install: pip install -r requirements_rl.txt")
        # Exit non-zero so shells and CI see the run as failed instead of
        # treating a missing dependency as a successful training run.
        raise SystemExit(1) from e