Spaces:
Sleeping
Sleeping
File size: 1,533 Bytes
fa71ce1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# Import our environment
from ai.gym_env import LoveLiveCardGameEnv
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.wrappers import ActionMasker
def make_env():
    """Build the card-game environment wrapped for masked-action training.

    Returns:
        A ``LoveLiveCardGameEnv`` instance wrapped in ``ActionMasker``; the
        wrapper's mask function delegates to the environment's own
        ``action_masks()`` method.
    """

    def _mask_fn(wrapped_env):
        # ActionMasker passes the wrapped environment to this callable.
        return wrapped_env.action_masks()

    # MaskablePPO relies on the ActionMasker wrapper to obtain action masks.
    return ActionMasker(LoveLiveCardGameEnv(), _mask_fn)
def main():
    """Train a MaskablePPO agent, save it, then play one evaluation episode.

    Steps:
        1. Build the masked environment via ``make_env()``.
        2. Train an ``MlpPolicy`` MaskablePPO model for 100,000 timesteps,
           logging to ``./logs/ppo_tensorboard/``.
        3. Save the model to ``checkpoints/lovelive_ppo_agent``.
        4. Run a single deterministic episode, rendering every step, and
           print the accumulated reward.
    """
    # Create Environment
    env = make_env()

    # Define Model (MaskablePPO)
    model = MaskablePPO(
        "MlpPolicy",
        env,
        verbose=1,
        gamma=0.99,
        learning_rate=3e-4,
        tensorboard_log="./logs/ppo_tensorboard/",
    )

    print("Starting Training...")
    # Train for 100k steps
    model.learn(total_timesteps=100_000, progress_bar=True)

    # Save Model
    model.save("checkpoints/lovelive_ppo_agent")
    print("Training Complete. Model Saved.")

    # Test Run
    obs, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        # Predict using masks so only currently legal actions can be chosen.
        action_masks = get_action_masks(env)
        action, _states = model.predict(obs, action_masks=action_masks, deterministic=True)
        obs, reward, terminated, truncated, _info = env.step(action)
        # An episode ends on termination OR truncation (e.g. a step limit);
        # checking only the first flag would keep stepping a finished
        # environment after a truncation.
        done = terminated or truncated
        total_reward += reward
        env.render()

    print(f"Test Run Reward: {total_reward}")
if __name__ == "__main__":
    try:
        main()
    except ImportError as e:
        # Only ImportErrors raised while main() runs are caught here; a failure
        # in the module-level imports at the top of the file happens before
        # this guard is reached.
        print(f"Import Error: {e}")
        print("Please install: pip install -r requirements_rl.txt")
        # Exit non-zero so shells and CI do not treat the failed run as success.
        raise SystemExit(1) from e
|