gpu-scheduler-openenv / examples /custom_agent.py
SulmanK's picture
Restore full project documentation and add Hugging Face Space integration
09f204f
Raw
History Blame Contribute Delete
2.98 kB
from __future__ import annotations
import json
import sys
from argparse import ArgumentParser
from pathlib import Path
# Repository root: the directory two levels above this file
# (parents[0] is the file's own directory, parents[1] its parent).
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    # Insert at the front so modules under REPO_ROOT (e.g. the `env` package
    # imported below) take precedence during import resolution. The membership
    # check above avoids adding a duplicate entry.
    sys.path.insert(0, str(REPO_ROOT))
from env.environment import GPUInferenceSchedulingEnv
from env.models import GPUState, JobState, SchedulingAction, SchedulingObservation
# Sort rank for job priorities (lower sorts first). Any priority value not
# listed here falls back to _DEFAULT_PRIORITY_RANK, i.e. after "medium".
_PRIORITY_RANK = {"high": 0, "medium": 1}
_DEFAULT_PRIORITY_RANK = 2


def choose_action(observation: SchedulingObservation) -> SchedulingAction:
    """Minimal example policy that users can replace with their own agent logic.

    Greedy strategy:

    1. Order pending jobs by priority ("high", then "medium", then any other
       value), then by earliest ``deadline``, then by ``job_id``.
    2. Take the first job in that order that fits on at least one idle GPU.
       A GPU is idle when ``not gpu.busy``. A job fits when
       ``gpu.vram_capacity >= job.vram_requirement``.
    3. Among the fitting GPUs, pick the one with the least leftover VRAM
       (best fit), then the lowest ``cost_per_step``, then the highest
       ``speed_multiplier``, then the smallest ``gpu_id``.

    Args:
        observation: Current environment observation. Only ``gpus`` and
            ``pending_jobs`` are read.

    Returns:
        A ``"place"`` action for the chosen job/GPU pair, or a ``"defer"``
        action when no pending job fits any idle GPU. The rationale is
        truncated to at most 240 characters.
    """
    idle_gpus = [gpu for gpu in observation.gpus if not gpu.busy]
    pending_jobs = sorted(
        observation.pending_jobs,
        key=lambda job: (
            _PRIORITY_RANK.get(job.priority.value, _DEFAULT_PRIORITY_RANK),
            job.deadline,
            job.job_id,
        ),
    )
    for job in pending_jobs:
        fitting = [gpu for gpu in idle_gpus if gpu.vram_capacity >= job.vram_requirement]
        if not fitting:
            continue
        # min() returns the first minimal element, exactly what sorted(...)[0]
        # would yield, but in a single O(n) pass instead of a full sort.
        gpu = min(
            fitting,
            key=lambda candidate: (
                candidate.vram_capacity - job.vram_requirement,
                candidate.cost_per_step,
                -candidate.speed_multiplier,  # negate so faster GPUs win ties
                candidate.gpu_id,
            ),
        )
        return SchedulingAction(
            action="place",
            job_id=job.job_id,
            gpu_id=gpu.gpu_id,
            rationale=(
                f"Place {job.job_id} on {gpu.gpu_id} because it fits, "
                f"priority={job.priority.value}, deadline={job.deadline}."
            )[:240],
        )
    return SchedulingAction(
        action="defer",
        rationale="No pending job currently fits an available GPU.",
    )
def run_scenario(scenario_id: str, seed: int | None = None) -> dict[str, object]:
    """Play one full episode with ``choose_action`` and return a JSON-ready report.

    The environment is constructed with both ``scenario_id`` and ``seed``.
    ``reset`` receives ``scenario_id`` only when ``seed`` is ``None``. When a
    seed is given, ``reset`` is called with ``scenario_id=None``.
    NOTE(review): presumably the env then falls back to the constructor's
    scenario or derives one from the seed. Confirm in
    ``GPUInferenceSchedulingEnv.reset``.

    The loop calls ``env.step`` until the step result reports ``done``. There
    is no step cap, so termination depends on the environment.

    Args:
        scenario_id: Scenario to run.
        seed: Optional seed, forwarded to both the constructor and ``reset``.

    Returns:
        A dict with:

        - ``"scenario_id"``: taken from the final observation.
        - ``"metrics"``: ``env.metrics()`` dumped with ``mode="json"``.
        - ``"action_log"``: each entry of the final observation's action log,
          dumped with ``mode="json"``.
    """
    env = GPUInferenceSchedulingEnv(scenario_id=scenario_id, seed=seed)
    reset_scenario_id = None if seed is not None else scenario_id
    observation = env.reset(seed=seed, scenario_id=reset_scenario_id)

    # Always take at least one step, then stop as soon as the env says done.
    while True:
        outcome = env.step(choose_action(observation))
        observation = outcome.observation
        if outcome.done:
            break

    final_scenario_id = observation.scenario_id
    final_metrics = env.metrics().model_dump(mode="json")
    log_entries = [record.model_dump(mode="json") for record in observation.action_log]
    return {
        "scenario_id": final_scenario_id,
        "metrics": final_metrics,
        "action_log": log_entries,
    }
def main() -> None:
    """Command-line entry point: run one scenario and print its report as JSON.

    Options:
        --scenario-id  Scenario to run (default ``"deadline_crunch"``).
        --seed         Optional integer seed (default ``None``).

    The report returned by ``run_scenario`` is printed to stdout,
    pretty-printed with an indent of 2.
    """
    cli = ArgumentParser(
        description="Run a minimal custom agent against the GPU scheduling environment."
    )
    cli.add_argument("--scenario-id", default="deadline_crunch")
    cli.add_argument("--seed", type=int, default=None)
    options = cli.parse_args()

    summary = run_scenario(scenario_id=options.scenario_id, seed=options.seed)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()