| """ |
| Inference Script for Thermal Grid RL Agent Environment |
| ======================================================== |
| MANDATORY |
| - API_BASE_URL, MODEL_NAME, HF_TOKEN must be set in environment / .env |
| - Use OpenAI client for all LLM calls |
| - Emit [START], [STEP], [END] to stdout exactly as specified |
| |
| Environment Variables: |
| HF_TOKEN - Hugging Face / API key (checked first) |
| API_KEY - Alternative API key (fallback) |
| API_BASE_URL - The API endpoint for the LLM |
| MODEL_NAME - The model identifier to use for inference |
| ENV_URL - URL of the thermal grid environment server |
| |
| STDOUT FORMAT |
| [START] task=<task_name> env=<benchmark> model=<model_name> |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> |
| [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn> |
| """ |
|
|
import asyncio
import json
import logging
import os
import re
import sys
from typing import List, Optional
|
|
| from openai import OpenAI |
| from dotenv import load_dotenv |
|
|
| load_dotenv() |
|
|
| from client import ThermalGridRlAgentEnv |
| from models import ThermalGridRlAgentAction, ThermalGridRlAgentObservation |
| from server.thermal_grid_rl_agent_environment import ThermalGridTaskID |
| from server.grader import ThermalGridGrader |
|
|
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


def _mask_secret(secret: Optional[str]) -> str:
    """Return a redacted rendering of *secret* that is safe to print.

    Secrets longer than 8 characters show only their first and last 4
    characters; shorter ones are fully hidden, and a missing/empty value is
    reported as ``<unset>`` instead of raising.
    """
    if not secret:
        return "<unset>"
    if len(secret) <= 8:
        return "****"
    return f"{secret[:4]}...{secret[-4:]}"


# Credentials and endpoints. HF_TOKEN takes precedence; API_KEY is the
# fallback listed in the module docstring.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
DEFAULT_MODEL = os.getenv("MODEL_NAME") or "meta-llama/Llama-3.2-1B-Instruct"
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")

# Startup banner goes to stderr: stdout is reserved for the
# [START]/[STEP]/[END] protocol lines.
print("\n--- LOADED ENVIRONMENT VARIABLES ---", file=sys.stderr)
print(f"API_BASE_URL : {API_BASE_URL}", file=sys.stderr)
print(f"MODEL_NAME : {DEFAULT_MODEL}", file=sys.stderr)
print(f"API_KEY : {_mask_secret(API_KEY)}", file=sys.stderr)
print(f"ENV_URL : {ENV_URL}", file=sys.stderr)
print("------------------------------------\n", file=sys.stderr)
|
|
# Benchmark name reported as `env=` in every [START] line.
BENCHMARK = "thermal_grid_rl_multi_agent"
# Hard cap on environment steps per task episode.
MAX_STEPS = 30
# A task counts as successful when the clamped grader score is >= this value.
SUCCESS_SCORE_THRESHOLD = 0.1
# Inference mode (no --train) ends an episode early once the last
# EARLY_STOP_CONSEC step rewards are all >= EARLY_STOP_REWARD.
EARLY_STOP_REWARD = 0.9
EARLY_STOP_CONSEC = 5


# Tasks executed sequentially, in this order, by main().
TASKS = [
    ThermalGridTaskID.BASELINE,
    ThermalGridTaskID.LOAD_SHIFT,
    ThermalGridTaskID.GRID_STRESS,
]
|
|
|
|
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` protocol line that opens a task episode."""
    fields = f"task={task} env={env} model={model}"
    print("[START] " + fields, flush=True)
|
|
|
|
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` protocol line to stdout.

    Args:
        step: 1-based step index.
        action: Pre-serialised action string (emitted verbatim).
        reward: Step reward, formatted with two decimals.
        done: Episode-termination flag, emitted as ``true``/``false``.
        error: Optional error message; ``null`` is emitted when it is falsy.
            Whitespace (including newlines from multi-line exception text) is
            collapsed to single spaces so the record always stays on one line.
    """
    error_text = " ".join(str(error).split()) if error else ""
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_text or 'null'}",
        flush=True,
    )
|
|
|
|
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` protocol line summarising a finished episode.

    Rewards are listed comma-separated with two decimals; the score uses three.
    """
    parts = [
        "[END]",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        "rewards=" + ",".join(format(r, ".2f") for r in rewards),
    ]
    print(" ".join(parts), flush=True)
|
|
|
|
class CoolingAgent:
    """Thermal-safety specialist.

    Turns the hottest CPU reading, the ambient temperature and the room's
    thermal-mass lag into one textual cooling recommendation, checking the
    most severe condition first.
    """

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return cooling advice for *obs*; the first matching rule wins."""
        # An empty/missing temperature list is treated as 0 °C.
        hottest = max(obs.max_cpu_temps_c or [0.0])

        if hottest > 75.0 or obs.ambient_temp_c > 38.0:
            return (
                "CRITICAL: Thermal emergency. Suggest CRAC at 12°C, 100% fans, "
                "and all 4 chillers. Priorities: Safety over cost."
            )
        if hottest > 65.0:
            return (
                "WARNING: High temperatures. Recommend CRAC at 15°C and 85% fans. "
                "Ensure at least 3 chillers are active."
            )
        if obs.thermal_mass_lag_c_per_min > 0.3:
            return (
                "PREDICTIVE: Room heating up. Proactively lower CRAC setpoint "
                "and increase fans by 10%."
            )
        return "STATUS: Thermal state stable. Maintain current cooling setpoints."
|
|
class EnergyAgent:
    """Energy/cost specialist: reacts to demand-response signals, price and PUE."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return energy advice for *obs*; grid pressure outranks PUE."""
        grid_pressure = obs.demand_response_signal == 1 or obs.energy_price_per_kwh > 0.15
        if grid_pressure:
            return (
                "CRITICAL: DR signal active or high pricing. Suggest raising CRAC "
                "to 25°C and reducing fans to 40% to shed load."
            )
        return (
            "INEFFICIENCY: High PUE. Recommend optimizing chiller count "
            "and raising CRAC setpoint to improve efficiency."
            if obs.pue > 1.30
            else "STATUS: Energy usage within bounds. Optimize for efficiency if safety allows."
        )
|
|
class WorkloadAgent:
    """Throughput specialist: watches the batch backlog and off-peak window."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return scheduling advice for *obs* based on the pending job count."""
        backlog = obs.pending_batch_jobs
        if backlog > 100 and obs.off_peak_window == 1:
            return (
                "OPPORTUNITY: Off-peak window and high backlog. Suggest running "
                "all pending batch jobs now."
            )
        if backlog > 50:
            return "URGENT: Batch backlog growing. Suggest increasing throughput."
        return "STATUS: Workload manageable. Schedule batch jobs normally."
|
|
class OversightAgent:
    """Deterministic safety monitor that can override the Coordinator.

    Runs after the LLM has produced an action. If any CPU is hotter than
    ``thermal_limit_c``, the action is overwritten in place with maximum
    cooling, whatever the LLM proposed.
    """

    def __init__(self, thermal_limit_c: float = 80.0):
        # CPU temperature (°C) strictly above which the override fires.
        self.thermal_limit_c = thermal_limit_c

    def enforce(self, obs: ThermalGridRlAgentObservation, action: ThermalGridRlAgentAction) -> ThermalGridRlAgentAction:
        """Apply deterministic safety overrides to *action* and return it.

        The action is mutated in place. Its ``metadata`` mapping always gets
        ``oversight_triggered`` (bool) and ``oversight_reason`` (str or None),
        so downstream logging can rely on both keys being present.

        Args:
            obs: Current observation; only ``max_cpu_temps_c`` is read.
            action: Coordinator action; ``metadata`` must support item assignment.

        Returns:
            The same *action* object, possibly overridden with emergency cooling.
        """
        max_cpu = max(obs.max_cpu_temps_c) if obs.max_cpu_temps_c else 0.0

        # Reset the flags so a reused action never carries a stale override.
        action.metadata["oversight_triggered"] = False
        action.metadata["oversight_reason"] = None

        if max_cpu > self.thermal_limit_c:
            action.crac_setpoint_c = 12.0
            action.fan_speeds_pct = [100.0] * len(action.fan_speeds_pct)
            action.num_active_chillers = 4

            reason = f"CPU Temp CRITICAL ({max_cpu:.1f}°C). Emergency cooling forced."
            action.metadata["oversight_triggered"] = True
            action.metadata["oversight_reason"] = reason
            # stderr: stdout is reserved for the [START]/[STEP]/[END] protocol.
            print(f"[OVERSIGHT] OVERRIDE: {reason}", file=sys.stderr)

        return action
|
|
def _cooling_rebuttal(max_cpu: float, price: float, dr: int) -> str:
    """Cooling agent's round-2 reply to the energy agent (grid pressure checked first)."""
    if dr == 1 or price > 0.15:
        return (
            "[COOLING→ENERGY] I understand the DR signal, but thermal safety "
            "cannot be compromised. If we raise the setpoint above 22°C now, "
            "CPU temps will spike within 5 steps. Propose: CRAC at 18°C max."
        )
    if max_cpu > 65.0:
        return (
            "[COOLING→ENERGY] Current CPU temps are dangerously high. "
            "Any further energy saving must wait. Safety override required."
        )
    return (
        "[COOLING→ENERGY] Thermal state is manageable. I can accept "
        "a moderate setpoint increase if PUE improvement is significant."
    )


def _energy_rebuttal(pue: float, dr: int) -> str:
    """Energy agent's round-2 reply to the cooling agent (PUE checked before DR)."""
    if pue > 1.30:
        return (
            "[ENERGY→COOLING] PUE is above 1.30 — we're burning money. "
            "Raising CRAC by 2°C and reducing fans by 15% will save ~8% energy "
            "with minimal thermal impact at current load levels."
        )
    if dr == 1:
        return (
            "[ENERGY→COOLING] Grid is requesting load shed. We MUST reduce "
            "facility power by at least 10%. I support minimal cooling for now "
            "if you can keep CPUs below 70°C."
        )
    return (
        "[ENERGY→COOLING] Energy costs are within acceptable bounds. "
        "No immediate conflict — support your cooling recommendation."
    )


def _workload_rebuttal(pending: int, off_peak: int, dr: int) -> str:
    """Workload agent's round-2 reply addressed to both other agents."""
    if pending > 50 and off_peak == 1 and dr == 0:
        return (
            "[WORKLOAD→BOTH] Off-peak window is active and backlog is growing. "
            "If either of you can spare 10% headroom, I recommend running "
            "batch jobs now to clear the queue before peak pricing resumes."
        )
    if pending > 100:
        return (
            "[WORKLOAD→BOTH] Batch backlog is critical. Throughput must improve "
            "or SLAs will be breached. Request at least 2 chillers remain active."
        )
    return (
        "[WORKLOAD→BOTH] Workload is stable. No conflict from my side — "
        "please optimize for energy efficiency and safety as you see fit."
    )


def simulate_negotiation(obs: ThermalGridRlAgentObservation, step: int) -> str:
    """Build a scripted two-round debate between the three specialist agents.

    Round 1 lists each agent's position (from its ``get_recommendation``);
    round 2 lists each agent's rule-based reply to the others. The result is
    a plain-text transcript for the coordinator LLM to resolve.

    Args:
        obs: Current environment observation.
        step: Current step index (accepted for interface symmetry; unused).

    Returns:
        The transcript, without leading or trailing blank lines.
    """
    positions = (
        CoolingAgent.get_recommendation(obs),
        EnergyAgent.get_recommendation(obs),
        WorkloadAgent.get_recommendation(obs),
    )

    hottest = max(obs.max_cpu_temps_c) if obs.max_cpu_temps_c else 0.0
    dr_flag = obs.demand_response_signal
    replies = (
        _cooling_rebuttal(hottest, obs.energy_price_per_kwh, dr_flag),
        _energy_rebuttal(obs.pue, dr_flag),
        _workload_rebuttal(obs.pending_batch_jobs, obs.off_peak_window, dr_flag),
    )

    transcript_lines = [
        "--- ROUND 1: AGENT POSITIONS ---",
        f"[COOLING AGENT] : {positions[0]}",
        f"[ENERGY AGENT] : {positions[1]}",
        f"[WORKLOAD AGENT]: {positions[2]}",
        "",
        "--- ROUND 2: AGENT REPLIES ---",
        *replies,
    ]
    return "\n".join(transcript_lines)
|
|
|
|
# Instructions for the coordinator LLM. The requested keys are exactly the ones
# get_llm_action() reads, and the stated ranges match the clamping it applies
# (CRAC 12-27, exactly 10 fan values in 20-100, 1-4 chillers). The example is
# wrapped in a ```json fence even though the last line asks for bare JSON;
# _extract_json() unwraps such fences, so either reply form is accepted.
SYSTEM_PROMPT = """You are the Facility Coordinator for a datacenter.
Analyze the state and agent recommendations, then output a JSON object with your final control actions.

Required JSON format:
```json
{
"reasoning": "Step-by-step logic resolving conflicts",
"crac_setpoint_c": 16.0,
"fan_speeds_pct": [75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0],
"num_active_chillers": 3
}
```

Constraints:
- crac_setpoint_c: Float between 12.0 and 27.0
- fan_speeds_pct: List of exactly 10 floats between 20.0 and 100.0
- num_active_chillers: Integer between 1 and 4

Priority: 1. Safety, 2. Grid, 3. Throughput, 4. PUE.
Output ONLY the JSON object. Do not add conversational text."""
|
|
|
|
def build_user_message(obs: ThermalGridRlAgentObservation, step: int, task: str) -> str:
    """Compose the per-step user prompt for the coordinator LLM.

    The prompt contains a step/task header, the simulated agent debate, and
    a compact summary of the current thermal, cooling, grid and batch state.

    Args:
        obs: Current environment observation.
        step: 1-based step index (shown as ``step/MAX_STEPS``).
        task: Task identifier string.
    """
    peak_temp = max(obs.max_cpu_temps_c) if obs.max_cpu_temps_c else 0.0
    mean_temps = obs.mean_cpu_temps_c
    mean_temp = sum(mean_temps) / len(mean_temps) if mean_temps else 0.0

    debate = simulate_negotiation(obs, step)

    sections = [
        f"STEP {step}/{MAX_STEPS} | Task: {task}",
        "",
        debate,
        "",
        "--- CURRENT ENVIRONMENT STATE ---",
        f"- Thermal : max_cpu={peak_temp:.1f}°C avg_cpu={mean_temp:.1f}°C",
        (
            f"- Cooling : PUE={obs.pue:.3f} setpoint={obs.crac_supply_temp_c:.1f}°C "
            f"fans={obs.avg_fan_speed_pct:.0f}% chillers={obs.num_active_chillers}"
        ),
        (
            f"- Grid : price=${obs.energy_price_per_kwh:.3f}/kWh "
            f"DR={obs.demand_response_signal} ambient={obs.ambient_temp_c:.1f}°C"
        ),
        f"- Batch : {obs.pending_batch_jobs} jobs pending off_peak={obs.off_peak_window}",
        "",
        "You have read both rounds of debate above. Resolve the conflict and respond with JSON only.",
    ]
    return "\n".join(sections)
|
|
|
|
_FENCED_BLOCK_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


def _close_truncated(text: str) -> str:
    """Return *text* with the closers that a truncated JSON document is missing.

    String and escape state are tracked so brackets inside string values are
    ignored. An unterminated string is closed (dropping a dangling backslash);
    otherwise a trailing comma is removed. Every still-open ``{``/``[`` is
    then closed in reverse order.
    """
    closers: List[str] = []
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            closers.append("}")
        elif ch == "[":
            closers.append("]")
        elif ch in "}]" and closers:
            closers.pop()

    if in_string:
        repaired = (text[:-1] if escaped else text) + '"'
    else:
        repaired = text.rstrip().rstrip(",")
    return repaired + "".join(reversed(closers))


def _extract_json(raw: str) -> dict:
    """Robustly parse the first JSON object out of an LLM reply.

    Handles:
      * Markdown code fences (```json ... ```), which are unwrapped;
      * prose before or after the object (decoding starts at a ``{`` and
        ignores whatever follows the matching ``}``);
      * output cut off mid-object (e.g. by ``max_tokens``), repaired by
        closing the open string and brackets, even with no closing fence;
      * a stray ``{`` in leading prose, by falling back to later braces.

    Returns:
        The decoded dict, or ``{}`` when the reply is empty, contains no JSON
        object, or cannot be decoded or repaired. Never returns a non-dict.
    """
    if not raw:
        return {}

    cleaned = _FENCED_BLOCK_RE.sub(r"\1", raw).strip()
    starts = [i for i, ch in enumerate(cleaned) if ch == "{"]
    if not starts:
        return {}

    decoder = json.JSONDecoder()

    def _decode_at(index: int) -> Optional[dict]:
        # raw_decode parses one value and ignores any trailing text.
        try:
            obj, _end = decoder.raw_decode(cleaned[index:])
        except json.JSONDecodeError:
            return None
        return obj if isinstance(obj, dict) else None

    # 1. Well-formed object starting at the first brace (the common case).
    first = _decode_at(starts[0])
    if first is not None:
        return first

    # 2. Truncated output: close whatever is still open and retry.
    try:
        repaired = json.loads(_close_truncated(cleaned[starts[0]:]))
    except json.JSONDecodeError:
        repaired = None
    if isinstance(repaired, dict):
        return repaired

    # 3. The first brace was stray prose; try each later brace in turn.
    for index in starts[1:]:
        candidate = _decode_at(index)
        if candidate is not None:
            return candidate

    return {}
|
|
|
|
def _action_from_payload(data: dict) -> ThermalGridRlAgentAction:
    """Convert a parsed LLM payload into an in-range ThermalGridRlAgentAction.

    Missing keys default to CRAC 18.0, fans 70.0 and 2 chillers. Values are
    clipped to the ranges stated in SYSTEM_PROMPT. A single number for
    ``fan_speeds_pct`` is broadcast to all 10 fans; a list whose length is not
    10 is replaced by 10 copies of its first entry (70.0 if it is empty).

    Raises:
        TypeError, ValueError: if a value cannot be converted to a number.
    """
    crac = max(12.0, min(27.0, float(data.get("crac_setpoint_c", 18.0))))

    fan_values = data.get("fan_speeds_pct", [70.0] * 10)
    if isinstance(fan_values, (int, float)):
        # Accept one shared fan speed as well as a per-fan list.
        fan_values = [fan_values] * 10
    fans = [max(20.0, min(100.0, float(f))) for f in fan_values]
    if len(fans) != 10:
        fans = [fans[0] if fans else 70.0] * 10

    chillers = max(1, min(4, int(data.get("num_active_chillers", 2))))

    return ThermalGridRlAgentAction(
        crac_setpoint_c=crac,
        fan_speeds_pct=fans,
        num_active_chillers=chillers,
    )


def get_llm_action(
    client: OpenAI,
    obs: ThermalGridRlAgentObservation,
    step: int,
    task_id: ThermalGridTaskID,
    model: str
) -> tuple:
    """Ask the coordinator LLM for the next action, making up to 3 attempts.

    The first attempt uses temperature 0.3 and retries use 0.7. A reply is
    usable once it parses to a dict containing ``crac_setpoint_c``. The
    resulting action is clamped and then passed through OversightAgent.

    Returns:
        ``(action, user_prompt, raw_response, raw_attempts, parsed_payload)``.

    Raises:
        ValueError: if no attempt produced a usable payload. The last
            underlying exception, if any, is chained as ``__cause__``.
    """
    user_prompt = build_user_message(obs, step, task_id.value)
    raw_attempts: List[str] = []
    last_error: Optional[Exception] = None

    for attempt in range(3):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=1024,
                # Near-deterministic first try; retries sample more freely.
                temperature=0.3 if attempt == 0 else 0.7,
            )
            raw = response.choices[0].message.content or "{}"
            raw_attempts.append(raw)

            data = _extract_json(raw)
            if not isinstance(data, dict) or "crac_setpoint_c" not in data:
                print(
                    f"[ERROR] LLM Attempt {attempt+1}: response has no 'crac_setpoint_c'",
                    file=sys.stderr,
                )
                continue

            action = _action_from_payload(data)
            final_action = OversightAgent().enforce(obs, action)
            return final_action, user_prompt, raw, raw_attempts, data

        except Exception as e:
            last_error = e
            # stderr: stdout is reserved for the [START]/[STEP]/[END] protocol.
            print(f"[ERROR] LLM Attempt {attempt+1} failed: {e}", file=sys.stderr)
            logger.warning("Step %s Attempt %s failed: %s", step, attempt + 1, e)

    raise ValueError(f"Unparseable LLM response after 3 attempts: {raw_attempts}") from last_error
|
|
def _observation_summary(observation: ThermalGridRlAgentObservation) -> dict:
    """Return the compact pre-action state snapshot stored with each step record."""
    max_temps = observation.max_cpu_temps_c
    mean_temps = observation.mean_cpu_temps_c
    return {
        "max_cpu": max(max_temps) if max_temps else 0.0,
        "avg_cpu": sum(mean_temps) / len(mean_temps) if mean_temps else 0.0,
        "pue": observation.pue,
        "energy_price": observation.energy_price_per_kwh,
        "ambient": observation.ambient_temp_c,
        "pending_jobs": observation.pending_batch_jobs,
        "off_peak": observation.off_peak_window,
    }


def _should_stop_early(rewards: List[float]) -> bool:
    """Return True once the last EARLY_STOP_CONSEC rewards are all >= EARLY_STOP_REWARD."""
    recent = rewards[-EARLY_STOP_CONSEC:] if len(rewards) >= EARLY_STOP_CONSEC else []
    return bool(recent) and all(r >= EARLY_STOP_REWARD for r in recent)


def _write_training_files(episode_buffer: List[dict]) -> None:
    """Append one episode's records to all_prompts.jsonl and expert_trajectories.jsonl."""
    with open("all_prompts.jsonl", "a", encoding="utf-8") as prompts_file:
        for record in episode_buffer:
            prompts_file.write(json.dumps({"prompt": record["prompt"]}) + "\n")

    with open("expert_trajectories.jsonl", "a", encoding="utf-8") as trajectories_file:
        for record in episode_buffer:
            trajectories_file.write(json.dumps({
                "step": record["step"],
                "prompt": record["prompt"],
                "response": record["response"],
                "reward": record["reward"],
                "oversight_triggered": record["oversight_triggered"],
            }) + "\n")


async def run_inference(task_id: ThermalGridTaskID, model: str, train_mode: bool = False) -> None:
    """Run one task episode against the environment server and report it.

    Emits one [START] line, one [STEP] line per attempted step, and exactly one
    [END] line on stdout. The [END] line is emitted even when ``env.reset()``
    or the grader raises. Diagnostics are written to stderr.

    Args:
        task_id: Task the environment is reset into.
        model: Model identifier sent to the LLM endpoint.
        train_mode: When True, early stopping is disabled and the episode's
            records are appended to the JSONL training files.

    Raises:
        RuntimeError: if ``env.reset()`` fails (raised after [END] is emitted).
    """
    import traceback

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    mode_banner = (
        "[MODE] TRAIN — collecting trajectories"
        if train_mode
        else "[MODE] INFERENCE — early stopping enabled"
    )
    print(mode_banner, file=sys.stderr)

    log_start(task=task_id.value, env=BENCHMARK, model=model)

    env = ThermalGridRlAgentEnv(base_url=ENV_URL)
    # NOTE(review): the grader is never fed step data in this function; confirm
    # get_thermal_grid_score() obtains its inputs elsewhere.
    grader = ThermalGridGrader(task_id=task_id.value)

    rewards: List[float] = []
    steps_taken = 0
    success = False
    # Defined up front so the [END] line in `finally` never hits an unbound name.
    score = 0.0
    episode_buffer: List[dict] = []

    try:
        try:
            reset_result = await env.reset(task_id=task_id.value)
        except Exception as e:
            raise RuntimeError(f"env.reset() failed: {e}") from e

        observation = reset_result.observation

        for step in range(1, MAX_STEPS + 1):
            try:
                # The OpenAI client is synchronous; keep the event loop free.
                action, prompt, raw, raw_attempts, parsed_data = await asyncio.to_thread(
                    get_llm_action, client, observation, step, task_id, model
                )

                action_dict = {
                    "crac_setpoint_c": action.crac_setpoint_c,
                    "fan_speeds_pct": action.fan_speeds_pct,
                    "num_active_chillers": action.num_active_chillers,
                }
                # Compact, newline-free JSON keeps the [STEP] record on one line.
                action_str = json.dumps(action_dict, separators=(",", ":"))

                step_result = await env.step(action)
                reward = float(step_result.reward or 0.0)
                done = step_result.done

                rewards.append(reward)
                steps_taken = step

                episode_buffer.append({
                    "step": step,
                    "prompt": prompt,
                    "response": action_dict,
                    "reward": reward,
                    "oversight_triggered": action.metadata.get("oversight_triggered", False),
                    "raw_response": raw,
                    "raw_attempts": raw_attempts,
                    "parsed_action": parsed_data,
                    # State the action was chosen from (before env.step).
                    "state": _observation_summary(observation),
                })

                log_step(step=step, action=action_str, reward=reward, done=done, error=None)

                observation = step_result.observation

                if not train_mode and _should_stop_early(rewards):
                    print("[EARLY STOP] Stable high reward achieved", file=sys.stderr)
                    break

                if done:
                    break

            except Exception as e:
                print("\n[ERROR] Exception in loop:", file=sys.stderr)
                traceback.print_exc()
                log_step(step=step, action="{}", reward=0.0, done=False, error=str(e))
                break

        score = min(max(grader.get_thermal_grid_score(), 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

        if train_mode:
            _write_training_files(episode_buffer)

    finally:
        try:
            await env.close()
        except Exception as exc:
            # Best-effort cleanup: report, but never mask the episode result.
            print(f"[WARN] env.close() failed: {exc}", file=sys.stderr)
        finally:
            log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


async def main() -> None:
    """Parse CLI options and run every task in TASKS sequentially.

    ``--model`` overrides the MODEL_NAME default. ``--train`` disables early
    stopping and starts the JSONL training datasets from scratch (existing
    files are deleted first; run_inference appends to them).
    """
    import argparse
    import contextlib

    parser = argparse.ArgumentParser(description="Thermal Grid RL Agent Inference")
    parser.add_argument(
        "--model", type=str, default=DEFAULT_MODEL,
        help="model identifier sent to the LLM endpoint",
    )
    parser.add_argument(
        "--train", action="store_true", default=False,
        help="collect trajectories (no early stopping) into fresh JSONL files",
    )
    args = parser.parse_args()

    if args.train:
        for f_path in ("all_prompts.jsonl", "expert_trajectories.jsonl"):
            # EAFP: avoids the exists()/remove() race and handles absent files.
            with contextlib.suppress(FileNotFoundError):
                os.remove(f_path)

    for task_id in TASKS:
        await run_inference(task_id, args.model, train_mode=args.train)
|
|
|
|
# Script entry point: run main() (every task in TASKS) on a fresh asyncio event loop.
if __name__ == "__main__":
    asyncio.run(main())