Spaces:
Sleeping
Sleeping
| # pyright: reportMissingImports=false | |
| """Train a PPO agent on the Cloud GPU+CPU Resource Management Environment. | |
| Supports all 3 tasks: | |
| - gpu_cpu_allocation | |
| - thermal_management | |
| - heuristic_fragmentation | |
| Usage: | |
| python train.py # default task | |
| python train.py --task thermal_management # specific task | |
| python train.py --task all # train on all tasks | |
| """ | |
import argparse
from pathlib import Path
from typing import Optional, Tuple

from stable_baselines3 import PPO

from cloud_env import CloudResourceEnv
# Task names this script can train on; `--task all` iterates them in this order.
ALL_TASKS = [
    "gpu_cpu_allocation",
    "thermal_management",
    "heuristic_fragmentation",
]
def _evaluate_episode(model, env) -> Tuple[float, int, dict]:
    """Roll out one deterministic episode of at most ``env.max_steps`` steps.

    The env is rendered after every step. Returns ``(total_reward, steps,
    last_info)``. ``last_info`` is ``{}`` when no step was taken (for example
    ``env.max_steps == 0``), so callers can always call ``.get`` on it.
    """
    obs, _ = env.reset()
    total_reward = 0.0
    steps = 0
    info: dict = {}
    for _ in range(env.max_steps):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(int(action))
        total_reward += reward
        steps += 1
        env.render()
        if terminated or truncated:
            break
    return total_reward, steps, info


def train_task(task: str, timesteps: int = 2000, project_root: Optional[Path] = None) -> float:
    """Train a PPO agent on one task, save it, then run one evaluation episode.

    Args:
        task: Task name passed to ``CloudResourceEnv(task=...)``.
        timesteps: Total timesteps handed to ``PPO.learn``.
        project_root: Directory the model is saved into; defaults to the
            directory containing this script.

    Returns:
        The summed reward of one deterministic evaluation episode.
    """
    if project_root is None:
        project_root = Path(__file__).resolve().parent

    print(f"\n{'='*60}")
    print(f"Training on task: {task}")
    print(f"{'='*60}")

    env = CloudResourceEnv(task=task)
    try:
        model = PPO("MlpPolicy", env, verbose=1)
        model.learn(total_timesteps=timesteps)

        # SB3 appends ".zip" to suffix-less save paths; name it explicitly so
        # the printed path is the file actually written (same location as before).
        model_path = project_root / f"cloud_rl_{task}.zip"
        model.save(model_path)
        print(f"Model saved to {model_path}")

        # Quick evaluation on the same env instance used for training.
        total_reward, steps, info = _evaluate_episode(model, env)
    finally:
        # Release env resources (e.g. render windows) even if training fails.
        env.close()

    print(f"\nEvaluation — Task: {task} | Steps: {steps} | Total reward: {total_reward:.2f}")
    print(f" Score: {info.get('score', 0.0):.4f}")
    return total_reward
def main() -> None:
    """Parse CLI arguments, then train on one task or on every task in ``ALL_TASKS``."""
    parser = argparse.ArgumentParser(description="Train PPO on Cloud GPU+CPU env")
    parser.add_argument(
        "--task",
        default="gpu_cpu_allocation",
        choices=ALL_TASKS + ["all"],
        help="Task to train on (default: gpu_cpu_allocation)",
    )
    parser.add_argument(
        "--timesteps",
        type=int,
        default=2000,
        help="Total timesteps for training (default: 2000)",
    )
    args = parser.parse_args()

    # A non-positive budget cannot train anything meaningful, but train_task
    # would still save, and overwrite, the model file with that policy.
    if args.timesteps <= 0:
        parser.error("--timesteps must be a positive integer")

    project_root = Path(__file__).resolve().parent

    if args.task != "all":
        train_task(args.task, args.timesteps, project_root)
        return

    # Tasks are trained sequentially, in ALL_TASKS order.
    results = {task: train_task(task, args.timesteps, project_root) for task in ALL_TASKS}
    print(f"\n{'='*60}")
    print("All tasks trained!")
    for task, reward in results.items():
        print(f" {task}: {reward:.2f}")


if __name__ == "__main__":
    main()