from __future__ import annotations import json import sys from argparse import ArgumentParser from pathlib import Path REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from env.environment import GPUInferenceSchedulingEnv from env.models import GPUState, JobState, SchedulingAction, SchedulingObservation def choose_action(observation: SchedulingObservation) -> SchedulingAction: """Minimal example policy that users can replace with their own agent logic.""" idle_gpus = [gpu for gpu in observation.gpus if not gpu.busy] pending_jobs = sorted( observation.pending_jobs, key=lambda job: ( 0 if job.priority.value == "high" else 1 if job.priority.value == "medium" else 2, job.deadline, job.job_id, ), ) best_pair: tuple[JobState, GPUState] | None = None for job in pending_jobs: fitting = [gpu for gpu in idle_gpus if gpu.vram_capacity >= job.vram_requirement] if not fitting: continue best_gpu = sorted( fitting, key=lambda gpu: ( gpu.vram_capacity - job.vram_requirement, gpu.cost_per_step, -gpu.speed_multiplier, gpu.gpu_id, ), )[0] best_pair = (job, best_gpu) break if best_pair is None: return SchedulingAction( action="defer", rationale="No pending job currently fits an available GPU.", ) job, gpu = best_pair return SchedulingAction( action="place", job_id=job.job_id, gpu_id=gpu.gpu_id, rationale=( f"Place {job.job_id} on {gpu.gpu_id} because it fits, " f"priority={job.priority.value}, deadline={job.deadline}." )[:240], ) def run_scenario(scenario_id: str, seed: int | None = None) -> dict[str, object]: env = GPUInferenceSchedulingEnv(scenario_id=scenario_id, seed=seed) observation = env.reset(seed=seed, scenario_id=scenario_id if seed is None else None) done = False while not done: action = choose_action(observation) step_result = env.step(action) observation = step_result.observation done = step_result.done return { "scenario_id": observation.scenario_id, "metrics": env.metrics().model_dump(mode="json"), "action_log": [entry.model_dump(mode="json") for entry in observation.action_log], } def main() -> None: parser = ArgumentParser(description="Run a minimal custom agent against the GPU scheduling environment.") parser.add_argument("--scenario-id", default="deadline_crunch") parser.add_argument("--seed", type=int, default=None) args = parser.parse_args() report = run_scenario(scenario_id=args.scenario_id, seed=args.seed) print(json.dumps(report, indent=2)) if __name__ == "__main__": main()