File size: 2,644 Bytes
5f69b60
 
fa65b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
5f69b60
 
 
 
 
 
 
fa65b6c
 
 
 
 
 
5f69b60
fa65b6c
 
 
5f69b60
fa65b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f69b60
fa65b6c
5f69b60
fa65b6c
 
 
 
 
 
 
 
 
 
 
5f69b60
 
 
fa65b6c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# pyright: reportMissingImports=false

"""Train a PPO agent on the Cloud GPU+CPU Resource Management Environment.

Supports all 3 tasks:
  - gpu_cpu_allocation
  - thermal_management
  - heuristic_fragmentation

Usage:
    python train.py                              # default task
    python train.py --task thermal_management    # specific task
    python train.py --task all                   # train on all tasks
"""

import argparse
from pathlib import Path

from stable_baselines3 import PPO

from cloud_env import CloudResourceEnv


# Task names this script can train on; kept as a list so it can be
# concatenated with extra CLI choices.
ALL_TASKS = [
    "gpu_cpu_allocation",
    "thermal_management",
    "heuristic_fragmentation",
]


def train_task(task: str, timesteps: int = 2000, project_root: Path = None) -> float:
    """Train a PPO agent on one task, save it, and run one evaluation episode.

    Args:
        task: Task name forwarded to ``CloudResourceEnv(task=...)``.
        timesteps: Passed to ``PPO.learn`` as ``total_timesteps``.
        project_root: Directory the model file is written under. When
            ``None``, the directory containing this script is used.

    Returns:
        The sum of the rewards collected during the evaluation episode.
    """
    if project_root is None:
        project_root = Path(__file__).resolve().parent

    print(f"\n{'='*60}")
    print(f"Training on task: {task}")
    print(f"{'='*60}")

    env = CloudResourceEnv(task=task)
    try:
        model = PPO("MlpPolicy", env, verbose=1)

        model.learn(total_timesteps=timesteps)
        model_path = project_root / f"cloud_rl_{task}"
        model.save(model_path)
        print(f"Model saved to {model_path}")

        # Quick evaluation: one deterministic episode, capped at env.max_steps.
        obs, _ = env.reset()
        total_reward = 0.0
        steps = 0
        # Pre-bind so the score line below still works if the loop body never
        # runs (e.g. env.max_steps == 0).
        info = {}
        for _ in range(env.max_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, info = env.step(int(action))
            total_reward += reward
            steps += 1
            env.render()
            if terminated or truncated:
                break

        print(f"\nEvaluation — Task: {task} | Steps: {steps} | Total reward: {total_reward:.2f}")
        print(f"  Score: {info.get('score', 0.0):.4f}")
        return total_reward
    finally:
        # Release the environment even if training or evaluation raised.
        # NOTE(review): assumes CloudResourceEnv provides the Gymnasium
        # ``close()`` method — confirm against cloud_env.
        env.close()


def main():
    """Parse command-line options and train on one task or on every task.

    ``--task`` selects a single entry of ``ALL_TASKS`` or ``all``;
    ``--timesteps`` sets the training length handed to ``train_task``.
    When every task is trained, a per-task reward summary is printed.
    """
    cli = argparse.ArgumentParser(description="Train PPO on Cloud GPU+CPU env")
    cli.add_argument(
        "--task",
        default="gpu_cpu_allocation",
        choices=[*ALL_TASKS, "all"],
        help="Task to train on (default: gpu_cpu_allocation)",
    )
    cli.add_argument(
        "--timesteps",
        type=int,
        default=2000,
        help="Total timesteps for training (default: 2000)",
    )
    opts = cli.parse_args()

    script_dir = Path(__file__).resolve().parent

    # Single-task run: train it and finish.
    if opts.task != "all":
        train_task(opts.task, opts.timesteps, script_dir)
        return

    # Train each task in list order, remembering its evaluation reward.
    rewards = {
        name: train_task(name, opts.timesteps, script_dir) for name in ALL_TASKS
    }
    print(f"\n{'='*60}")
    print("All tasks trained!")
    for name, total in rewards.items():
        print(f"  {name}: {total:.2f}")


# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()