# cloud_resource_env / train.py
# Commit fa65b6c by sunil18p31a0101
# FEAT: "Added support for both GPU and CPU utilization with thermal control and best allocation."
# pyright: reportMissingImports=false
"""Train a PPO agent on the Cloud GPU+CPU Resource Management Environment.
Supports all 3 tasks:
- gpu_cpu_allocation
- thermal_management
- heuristic_fragmentation
Usage:
python train.py # default task
python train.py --task thermal_management # specific task
python train.py --task all # train on all tasks
"""
import argparse
from pathlib import Path
from typing import Optional

from stable_baselines3 import PPO

from cloud_env import CloudResourceEnv
# Every task this script knows how to train; each name is passed verbatim to
# CloudResourceEnv(task=...) and offered as a --task choice (plus "all").
# Kept as a list because main() concatenates it with ["all"].
ALL_TASKS = [
    "gpu_cpu_allocation",
    "thermal_management",
    "heuristic_fragmentation",
]
def train_task(
    task: str,
    timesteps: int = 2000,
    project_root: Optional[Path] = None,
) -> float:
    """Train a PPO agent on one task, save it, then run one greedy evaluation episode.

    Args:
        task: Task name forwarded to ``CloudResourceEnv(task=...)``.
        timesteps: Total timesteps passed to ``PPO.learn``.
        project_root: Directory the model is saved into (as ``cloud_rl_<task>``).
            Defaults to the directory containing this script.

    Returns:
        The summed reward of the single deterministic evaluation episode.
    """
    if project_root is None:
        project_root = Path(__file__).resolve().parent

    print(f"\n{'='*60}")
    print(f"Training on task: {task}")
    print(f"{'='*60}")

    env = CloudResourceEnv(task=task)
    try:
        model = PPO("MlpPolicy", env, verbose=1)
        model.learn(total_timesteps=timesteps)

        model_path = project_root / f"cloud_rl_{task}"
        model.save(model_path)
        print(f"Model saved to {model_path}")

        # Quick evaluation: one deterministic episode, capped at env.max_steps.
        obs, _ = env.reset()
        total_reward = 0.0
        steps = 0
        # Pre-bind so the score lookup below cannot raise NameError when
        # env.max_steps is 0 and the loop body never executes.
        info = {}
        for _ in range(env.max_steps):
            action, _ = model.predict(obs, deterministic=True)
            # NOTE(review): int() assumes a scalar (e.g. Discrete) action space;
            # confirm against CloudResourceEnv.action_space.
            obs, reward, terminated, truncated, info = env.step(int(action))
            total_reward += reward
            steps += 1
            env.render()
            if terminated or truncated:
                break
    finally:
        # Release whatever the environment holds even if training or
        # evaluation raises.
        env.close()

    print(f"\nEvaluation — Task: {task} | Steps: {steps} | Total reward: {total_reward:.2f}")
    print(f" Score: {info.get('score', 0.0):.4f}")
    return total_reward
def main() -> None:
    """Parse command-line arguments and train on one task or on every task.

    ``--task`` selects one entry of ``ALL_TASKS`` (or ``"all"`` to train each in
    turn); ``--timesteps`` sets the per-task PPO training budget and must be a
    positive integer.
    """
    parser = argparse.ArgumentParser(description="Train PPO on Cloud GPU+CPU env")
    parser.add_argument(
        "--task",
        default="gpu_cpu_allocation",
        choices=ALL_TASKS + ["all"],
        help="Task to train on (default: gpu_cpu_allocation)",
    )
    parser.add_argument(
        "--timesteps",
        type=int,
        default=2000,
        help="Total timesteps for training (default: 2000)",
    )
    args = parser.parse_args()

    # A zero or negative budget would silently skip learning yet still save and
    # evaluate an untrained model; fail fast with a usage error instead.
    if args.timesteps <= 0:
        parser.error(f"--timesteps must be a positive integer, got {args.timesteps}")

    project_root = Path(__file__).resolve().parent

    if args.task == "all":
        results = {}
        for task in ALL_TASKS:
            results[task] = train_task(task, args.timesteps, project_root)

        print(f"\n{'='*60}")
        print("All tasks trained!")
        for task, reward in results.items():
            print(f" {task}: {reward:.2f}")
    else:
        train_task(args.task, args.timesteps, project_root)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()