"""
Inference Script for Thermal Grid RL Agent Environment
======================================================

MANDATORY
- API_BASE_URL, MODEL_NAME, HF_TOKEN must be set in environment / .env
- Use OpenAI client for all LLM calls
- Emit [START], [STEP], [END] to stdout exactly as specified

Environment Variables:
    HF_TOKEN      Hugging Face / API key (checked first)
    API_KEY       Alternative API key (fallback when HF_TOKEN is unset)
    API_BASE_URL  The API endpoint for the LLM
    MODEL_NAME    The model identifier to use for inference
    ENV_URL       URL of the thermal grid environment server

STDOUT FORMAT (one line per event)
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP] step=<n> action=<action_json> reward=<0.00> done=<true|false> error=<msg|null>
    [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>

All other output (configuration banner, run mode, oversight overrides,
retry errors, tracebacks) is written to stderr so that stdout carries only
the lines above.
"""

import argparse
import asyncio
import json
import logging
import os
import re
import sys
import traceback
from typing import Any, Dict, List, Optional

from openai import OpenAI
from dotenv import load_dotenv

# Load .env before the project imports below (kept in this order in case
# those modules read the environment at import time).
load_dotenv()

from client import ThermalGridRlAgentEnv  # noqa: E402
from models import ThermalGridRlAgentAction, ThermalGridRlAgentObservation  # noqa: E402
from server.thermal_grid_rl_agent_environment import ThermalGridTaskID  # noqa: E402
from server.grader import ThermalGridGrader  # noqa: E402

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


# =============================================================================
# CONFIGURATION
# =============================================================================

# HF_TOKEN takes precedence; API_KEY is the documented fallback.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
DEFAULT_MODEL = os.getenv("MODEL_NAME") or "meta-llama/Llama-3.2-1B-Instruct"
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")


def _diag(message: str) -> None:
    """Write a diagnostic line to stderr, keeping stdout for [START]/[STEP]/[END]."""
    print(message, file=sys.stderr, flush=True)


def _mask_secret(secret: Optional[str]) -> str:
    """Return a redacted preview of *secret* that is safe to print.

    Short secrets are fully hidden, because showing 4 leading characters
    of an 8-character key would reveal half of it.
    """
    if not secret:
        return "<unset>"
    if len(secret) <= 8:
        return "***"
    return f"{secret[:4]}...{secret[-4:]}"


_diag("\n--- LOADED ENVIRONMENT VARIABLES ---")
_diag(f"API_BASE_URL : {API_BASE_URL}")
_diag(f"MODEL_NAME   : {DEFAULT_MODEL}")
_diag(f"API_KEY      : {_mask_secret(API_KEY)}")
_diag(f"ENV_URL      : {ENV_URL}")
_diag("------------------------------------\n")

BENCHMARK = "thermal_grid_rl_multi_agent"
MAX_STEPS = 30
SUCCESS_SCORE_THRESHOLD = 0.1
EARLY_STOP_REWARD = 0.9   # not 1.0 (unless rewards are definitely capped at 1.0)
EARLY_STOP_CONSEC = 5     # 5 steps for more stability
LLM_MAX_ATTEMPTS = 3

# Action bounds; these mirror the constraints stated in SYSTEM_PROMPT.
CRAC_MIN_C, CRAC_MAX_C = 12.0, 27.0
FAN_MIN_PCT, FAN_MAX_PCT = 20.0, 100.0
MIN_CHILLERS, MAX_CHILLERS = 1, 4
NUM_FANS = 10

# Training-data outputs (appended per task, cleared at the start of a --train run).
PROMPTS_PATH = "all_prompts.jsonl"
TRAJECTORIES_PATH = "expert_trajectories.jsonl"
_TRAJECTORY_KEYS = ("step", "prompt", "response", "reward", "oversight_triggered")

TASKS = [
    ThermalGridTaskID.BASELINE,
    ThermalGridTaskID.LOAD_SHIFT,
    ThermalGridTaskID.GRID_STRESS,
]


# =============================================================================
# STDOUT PROTOCOL
# =============================================================================

def _single_line(text: str) -> str:
    """Collapse all whitespace (including newlines) so *text* fits on one log line."""
    return " ".join(text.split())


def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] line for an episode."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] line; a multi-line *error* is flattened to keep one line per step."""
    error_str = _single_line(error) if error else "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_str}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the [END] line with the per-step rewards as a comma-separated list."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


# =============================================================================
# OBSERVATION HELPERS
# =============================================================================

def _max_cpu(obs: ThermalGridRlAgentObservation) -> float:
    """Hottest value in ``obs.max_cpu_temps_c``, or 0.0 when that list is empty."""
    return max(obs.max_cpu_temps_c) if obs.max_cpu_temps_c else 0.0


def _avg_cpu(obs: ThermalGridRlAgentObservation) -> float:
    """Arithmetic mean of ``obs.mean_cpu_temps_c``, or 0.0 when that list is empty."""
    temps = obs.mean_cpu_temps_c
    return sum(temps) / len(temps) if temps else 0.0


# =============================================================================
# SPECIALIST AGENTS (rule-based)
# =============================================================================

class CoolingAgent:
    """Focuses on thermal safety and equipment longevity."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return a cooling recommendation; thresholds are checked hottest-first."""
        max_cpu = _max_cpu(obs)
        ambient = obs.ambient_temp_c

        if max_cpu > 75.0 or ambient > 38.0:
            return (
                "CRITICAL: Thermal emergency. Suggest CRAC at 12°C, 100% fans, "
                "and all 4 chillers. Priorities: Safety over cost."
            )
        if max_cpu > 65.0:
            return (
                "WARNING: High temperatures. Recommend CRAC at 15°C and 85% fans. "
                "Ensure at least 3 chillers are active."
            )
        if obs.thermal_mass_lag_c_per_min > 0.3:
            return (
                "PREDICTIVE: Room heating up. Proactively lower CRAC setpoint "
                "and increase fans by 10%."
            )
        return "STATUS: Thermal state stable. Maintain current cooling setpoints."


class EnergyAgent:
    """Focuses on PUE, electricity cost, and grid signals."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return an energy recommendation; grid signals outrank PUE."""
        if obs.demand_response_signal == 1 or obs.energy_price_per_kwh > 0.15:
            return (
                "CRITICAL: DR signal active or high pricing. Suggest raising CRAC "
                "to 25°C and reducing fans to 40% to shed load."
            )
        if obs.pue > 1.30:
            return (
                "INEFFICIENCY: High PUE. Recommend optimizing chiller count "
                "and raising CRAC setpoint to improve efficiency."
            )
        return "STATUS: Energy usage within bounds. Optimize for efficiency if safety allows."


class WorkloadAgent:
    """Focuses on throughput and job scheduling."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return a scheduling recommendation based on backlog size and off-peak status."""
        pending = obs.pending_batch_jobs
        if pending > 100 and obs.off_peak_window == 1:
            return (
                "OPPORTUNITY: Off-peak window and high backlog. Suggest running "
                "all pending batch jobs now."
            )
        if pending > 50:
            return "URGENT: Batch backlog growing. Suggest increasing throughput."
        return "STATUS: Workload manageable. Schedule batch jobs normally."


class OversightAgent:
    """Deterministic safety monitor that can override the Coordinator."""

    def __init__(self, thermal_limit_c: float = 80.0):
        self.thermal_limit_c = thermal_limit_c

    def enforce(
        self,
        obs: ThermalGridRlAgentObservation,
        action: ThermalGridRlAgentAction,
    ) -> ThermalGridRlAgentAction:
        """Force emergency cooling when the hottest CPU exceeds ``thermal_limit_c``.

        Mutates and returns *action*. ``metadata["oversight_triggered"]`` and
        ``metadata["oversight_reason"]`` are always (re)set, so callers can read
        them unconditionally. NOTE(review): assumes ``action.metadata`` is an
        existing mutable mapping on the action model — confirm in ``models``.
        """
        max_cpu = _max_cpu(obs)

        action.metadata["oversight_triggered"] = False
        action.metadata["oversight_reason"] = None

        if max_cpu > self.thermal_limit_c:
            # Emergency cooling: coldest setpoint, full fans, every chiller.
            action.crac_setpoint_c = 12.0
            action.fan_speeds_pct = [100.0] * len(action.fan_speeds_pct)
            action.num_active_chillers = 4
            action.metadata["oversight_triggered"] = True
            action.metadata["oversight_reason"] = (
                f"CPU Temp CRITICAL ({max_cpu:.1f}°C). Emergency cooling forced."
            )
            _diag(f"[OVERSIGHT] OVERRIDE: {action.metadata['oversight_reason']}")

        return action


def simulate_negotiation(obs: ThermalGridRlAgentObservation, step: int) -> str:
    """Simulate a structured one-round debate between the three specialist agents.

    Each agent gives an initial position (round 1), then replies to the others
    (round 2). Returns a formatted transcript for the coordinator LLM to resolve.
    *step* is currently unused and kept for interface compatibility.
    """
    cooling_pos = CoolingAgent.get_recommendation(obs)
    energy_pos = EnergyAgent.get_recommendation(obs)
    workload_pos = WorkloadAgent.get_recommendation(obs)

    max_cpu = _max_cpu(obs)
    price = obs.energy_price_per_kwh
    pue = obs.pue
    pending = obs.pending_batch_jobs
    dr = obs.demand_response_signal
    off_peak = obs.off_peak_window

    # --- Cooling Agent's reply to Energy Agent ---
    if dr == 1 or price > 0.15:
        cooling_reply = (
            "[COOLING→ENERGY] I understand the DR signal, but thermal safety "
            "cannot be compromised. If we raise the setpoint above 22°C now, "
            "CPU temps will spike within 5 steps. Propose: CRAC at 18°C max."
        )
    elif max_cpu > 65.0:
        cooling_reply = (
            "[COOLING→ENERGY] Current CPU temps are dangerously high. "
            "Any further energy saving must wait. Safety override required."
        )
    else:
        cooling_reply = (
            "[COOLING→ENERGY] Thermal state is manageable. I can accept "
            "a moderate setpoint increase if PUE improvement is significant."
        )

    # --- Energy Agent's reply to Cooling Agent ---
    if pue > 1.30:
        energy_reply = (
            "[ENERGY→COOLING] PUE is above 1.30 — we're burning money. "
            "Raising CRAC by 2°C and reducing fans by 15% will save ~8% energy "
            "with minimal thermal impact at current load levels."
        )
    elif dr == 1:
        energy_reply = (
            "[ENERGY→COOLING] Grid is requesting load shed. We MUST reduce "
            "facility power by at least 10%. I support minimal cooling for now "
            "if you can keep CPUs below 70°C."
        )
    else:
        energy_reply = (
            "[ENERGY→COOLING] Energy costs are within acceptable bounds. "
            "No immediate conflict — support your cooling recommendation."
        )

    # --- Workload Agent's reply ---
    if pending > 50 and off_peak == 1 and dr == 0:
        workload_reply = (
            "[WORKLOAD→BOTH] Off-peak window is active and backlog is growing. "
            "If either of you can spare 10% headroom, I recommend running "
            "batch jobs now to clear the queue before peak pricing resumes."
        )
    elif pending > 100:
        workload_reply = (
            "[WORKLOAD→BOTH] Batch backlog is critical. Throughput must improve "
            "or SLAs will be breached. Request at least 2 chillers remain active."
        )
    else:
        workload_reply = (
            "[WORKLOAD→BOTH] Workload is stable. No conflict from my side — "
            "please optimize for energy efficiency and safety as you see fit."
        )

    transcript = f"""
--- ROUND 1: AGENT POSITIONS ---
[COOLING AGENT] : {cooling_pos}
[ENERGY AGENT] : {energy_pos}
[WORKLOAD AGENT]: {workload_pos}

--- ROUND 2: AGENT REPLIES ---
{cooling_reply}
{energy_reply}
{workload_reply}
""".strip()

    return transcript


# =============================================================================
# COORDINATOR (LLM)
# =============================================================================

SYSTEM_PROMPT = """You are the Facility Coordinator for a datacenter.
Analyze the state and agent recommendations, then output a JSON object with your final control actions.

Required JSON format:
```json
{
  "reasoning": "Step-by-step logic resolving conflicts",
  "crac_setpoint_c": 16.0,
  "fan_speeds_pct": [75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0],
  "num_active_chillers": 3
}
```

Constraints:
- crac_setpoint_c: Float between 12.0 and 27.0
- fan_speeds_pct: List of exactly 10 floats between 20.0 and 100.0
- num_active_chillers: Integer between 1 and 4

Priority: 1. Safety, 2. Grid, 3. Throughput, 4. PUE.
Output ONLY the JSON object. Do not add conversational text."""


def build_user_message(obs: ThermalGridRlAgentObservation, step: int, task: str) -> str:
    """Build the coordinator's user prompt: negotiation transcript plus a state summary."""
    max_cpu = _max_cpu(obs)
    avg_cpu = _avg_cpu(obs)

    # Run multi-turn agent negotiation
    negotiation_transcript = simulate_negotiation(obs, step)

    return f"""STEP {step}/{MAX_STEPS} | Task: {task}

{negotiation_transcript}

--- CURRENT ENVIRONMENT STATE ---
- Thermal : max_cpu={max_cpu:.1f}°C avg_cpu={avg_cpu:.1f}°C
- Cooling : PUE={obs.pue:.3f} setpoint={obs.crac_supply_temp_c:.1f}°C fans={obs.avg_fan_speed_pct:.0f}% chillers={obs.num_active_chillers}
- Grid : price=${obs.energy_price_per_kwh:.3f}/kWh DR={obs.demand_response_signal} ambient={obs.ambient_temp_c:.1f}°C
- Batch : {obs.pending_batch_jobs} jobs pending off_peak={obs.off_peak_window}

You have read both rounds of debate above. Resolve the conflict and respond with JSON only."""


_CODE_FENCE_RE = re.compile(r"```(?:json)?\s*(.*?)\s*```", re.DOTALL)


def _loads_dict(text: str) -> Optional[dict]:
    """Parse *text* as JSON and return it only if it is an object (dict)."""
    try:
        value = json.loads(text)
    except json.JSONDecodeError:
        return None
    return value if isinstance(value, dict) else None


def _close_truncated_json(text: str) -> str:
    """Append the quote/brackets needed to balance a JSON prefix that was cut off.

    Tracks open ``{``/``[`` outside string literals, closes a dangling string,
    drops a trailing comma, then appends the missing closers in nesting order.
    The result is only a best-effort candidate; callers must still parse it.
    """
    closers: List[str] = []
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            closers.append("}")
        elif ch == "[":
            closers.append("]")
        elif ch in "}]" and closers:
            closers.pop()

    repaired = text
    if in_string:
        if escaped:
            # A lone trailing backslash would escape the quote we are adding.
            repaired = repaired[:-1]
        repaired += '"'
    repaired = repaired.rstrip().rstrip(",")
    return repaired + "".join(reversed(closers))


def _extract_json(raw: str) -> dict:
    """Robustly parse a JSON object from an LLM reply.

    Tries, in order:
    1. the whole reply (after stripping markdown code fences);
    2. the first complete object starting at the first ``{``, ignoring any
       trailing prose or a repeated second object;
    3. that same text with dangling strings/brackets closed (truncated reply).

    Returns ``{}`` when nothing parses to a JSON object.
    """
    if not raw:
        return {}

    # 1. Strip markdown code blocks and try a direct parse.
    cleaned = _CODE_FENCE_RE.sub(r"\1", raw).strip()
    parsed = _loads_dict(cleaned)
    if parsed is not None:
        return parsed

    start = cleaned.find("{")
    if start == -1:
        return {}
    candidate = cleaned[start:]

    # 2. First complete object; raw_decode stops at its closing brace.
    try:
        value, _ = json.JSONDecoder().raw_decode(candidate)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(value, dict):
            return value

    # 3. Truncated output (e.g. the model hit max_tokens mid-object).
    parsed = _loads_dict(_close_truncated_json(candidate))
    return parsed if parsed is not None else {}


def _clamp(value: float, low: float, high: float) -> float:
    """Clamp *value* into ``[low, high]``."""
    return max(low, min(high, value))


def _action_from_llm_json(data: dict) -> ThermalGridRlAgentAction:
    """Convert parsed coordinator JSON into a bounds-checked action.

    Fan speeds are padded with the last provided value (or 70.0 when none)
    and truncated to NUM_FANS entries, and a scalar is accepted as a single fan
    value. Raises TypeError/ValueError when a field cannot be converted to a number.
    """
    crac = _clamp(float(data.get("crac_setpoint_c", 18.0)), CRAC_MIN_C, CRAC_MAX_C)

    raw_fans = data.get("fan_speeds_pct", [70.0] * NUM_FANS)
    if isinstance(raw_fans, (int, float)):
        raw_fans = [raw_fans]
    fans = [_clamp(float(f), FAN_MIN_PCT, FAN_MAX_PCT) for f in raw_fans]
    filler = fans[-1] if fans else 70.0
    fans = (fans + [filler] * NUM_FANS)[:NUM_FANS]

    chillers = max(MIN_CHILLERS, min(MAX_CHILLERS, int(data.get("num_active_chillers", 2))))

    return ThermalGridRlAgentAction(
        crac_setpoint_c=crac,
        fan_speeds_pct=fans,
        num_active_chillers=chillers,
    )


def get_llm_action(
    client: OpenAI,
    obs: ThermalGridRlAgentObservation,
    step: int,
    task_id: ThermalGridTaskID,
    model: str,
) -> tuple:
    """Ask the coordinator LLM for an action, then apply the oversight safety check.

    Up to LLM_MAX_ATTEMPTS requests are made. The first uses temperature 0.3
    and retries use 0.7. A reply counts as usable only when it parses to a JSON
    object containing ``crac_setpoint_c``.

    Returns:
        ``(final_action, user_prompt, final_raw, raw_attempts, parsed_data)``.

    Raises:
        ValueError: when no attempt yields a usable action. The last request or
            conversion error, if any, is chained as the cause.
    """
    user_prompt = build_user_message(obs, step, task_id.value)
    raw_attempts: List[str] = []
    last_error: Optional[Exception] = None

    for attempt in range(LLM_MAX_ATTEMPTS):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=1024,
                temperature=0.3 if attempt == 0 else 0.7,
            )
            raw = response.choices[0].message.content or "{}"
            raw_attempts.append(raw)

            data = _extract_json(raw)
            if "crac_setpoint_c" not in data:
                logger.warning("Step %d Attempt %d: no crac_setpoint_c in reply", step, attempt + 1)
                continue

            action = _action_from_llm_json(data)
        except Exception as e:
            last_error = e
            _diag(f"[ERROR] LLM Attempt {attempt + 1} failed: {e}")
            logger.warning("Step %d Attempt %d failed: %s", step, attempt + 1, e)
            continue

        # Oversight is deterministic; run it outside the retry handler so a
        # failure there is not answered with another LLM request.
        final_action = OversightAgent().enforce(obs, action)
        return final_action, user_prompt, raw, raw_attempts, data

    raise ValueError(
        f"Unparseable LLM response after {LLM_MAX_ATTEMPTS} attempts: {raw_attempts}"
    ) from last_error


# =============================================================================
# EPISODE RUNNER
# =============================================================================

def _should_early_stop(rewards: List[float]) -> bool:
    """True once the last EARLY_STOP_CONSEC rewards are all >= EARLY_STOP_REWARD."""
    if len(rewards) < EARLY_STOP_CONSEC:
        return False
    return all(r >= EARLY_STOP_REWARD for r in rewards[-EARLY_STOP_CONSEC:])


def _episode_record(
    step: int,
    prompt: str,
    action: ThermalGridRlAgentAction,
    action_dict: Dict[str, Any],
    reward: float,
    raw: str,
    raw_attempts: List[str],
    parsed_data: Optional[dict],
    observation: ThermalGridRlAgentObservation,
) -> Dict[str, Any]:
    """Build one dataset entry. ``state`` describes *observation*, the state the action was chosen in."""
    return {
        "step": step,
        # Input
        "prompt": prompt,
        # Final action (post-oversight)
        "response": action_dict,
        "reward": reward,
        "oversight_triggered": action.metadata.get("oversight_triggered", False),
        # Extra (debugging + replay)
        "raw_response": raw,
        "raw_attempts": raw_attempts,
        "parsed_action": parsed_data,
        # Structured state
        "state": {
            "max_cpu": _max_cpu(observation),
            "avg_cpu": _avg_cpu(observation),
            "pue": observation.pue,
            "energy_price": observation.energy_price_per_kwh,
            "ambient": observation.ambient_temp_c,
            "pending_jobs": observation.pending_batch_jobs,
            "off_peak": observation.off_peak_window,
        },
    }


def _save_training_data(episode_buffer: List[Dict[str, Any]]) -> None:
    """Append this episode's prompts and trajectories to the JSONL dataset files."""
    # 1. All prompts, for RL exploration.
    with open(PROMPTS_PATH, "a", encoding="utf-8") as f:
        for record in episode_buffer:
            f.write(json.dumps({"prompt": record["prompt"]}) + "\n")

    # 2. Full trajectories (subset of the record keys).
    with open(TRAJECTORIES_PATH, "a", encoding="utf-8") as f:
        for record in episode_buffer:
            f.write(json.dumps({key: record[key] for key in _TRAJECTORY_KEYS}) + "\n")


async def run_inference(task_id: ThermalGridTaskID, model: str, train_mode: bool = False) -> None:
    """Run one episode of *task_id* and emit its [START]/[STEP]/[END] lines.

    [END] is always emitted after [START], even when setup, reset, or grading
    fails. In that case ``score`` is reported as 0.000 and the exception is
    re-raised afterwards. A failure inside a step is logged as an error
    [STEP] and ends the episode early without re-raising.

    In train mode, early stopping is disabled and the episode is appended to
    the JSONL dataset files after grading.
    """
    if train_mode:
        _diag("[MODE] TRAIN — collecting trajectories")
    else:
        _diag("[MODE] INFERENCE — early stopping enabled")

    log_start(task=task_id.value, env=BENCHMARK, model=model)

    rewards: List[float] = []
    episode_buffer: List[Dict[str, Any]] = []
    steps_taken = 0
    score = 0.0  # reported by [END] if the episode aborts before grading
    success = False
    env = None

    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        env = ThermalGridRlAgentEnv(base_url=ENV_URL)
        # NOTE(review): the grader is never handed step data in this file;
        # presumably it reads the episode from shared/server state — confirm.
        grader = ThermalGridGrader(task_id=task_id.value)

        # ---- RESET ENV ----
        try:
            reset_result = await env.reset(task_id=task_id.value)
        except Exception as e:
            raise RuntimeError(f"env.reset() failed: {e}") from e

        observation = reset_result.observation

        # ---- MAIN LOOP ----
        for step in range(1, MAX_STEPS + 1):
            try:
                # The OpenAI client is synchronous; keep it off the event loop.
                action, prompt, raw, raw_attempts, parsed_data = await asyncio.to_thread(
                    get_llm_action, client, observation, step, task_id, model
                )

                action_dict = {
                    "crac_setpoint_c": action.crac_setpoint_c,
                    "fan_speeds_pct": action.fan_speeds_pct,
                    "num_active_chillers": action.num_active_chillers,
                }
                action_str = json.dumps(action_dict, separators=(",", ":"))

                step_result = await env.step(action)
                reward = float(step_result.reward or 0.0)
                done = bool(step_result.done)

                rewards.append(reward)
                steps_taken = step

                episode_buffer.append(
                    _episode_record(
                        step, prompt, action, action_dict, reward,
                        raw, raw_attempts, parsed_data, observation,
                    )
                )

                log_step(step=step, action=action_str, reward=reward, done=done, error=None)

                observation = step_result.observation

                if not train_mode and _should_early_stop(rewards):
                    _diag("[EARLY STOP] Stable high reward achieved")
                    break

                if done:
                    break

            except Exception as e:
                _diag("\n[ERROR] Exception in loop:")
                traceback.print_exc()  # writes to stderr
                log_step(step=step, action="{}", reward=0.0, done=False, error=str(e))
                break

        # ---- FINAL SCORE ----
        score = min(max(grader.get_thermal_grid_score(), 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

        # ---- SAVE DATASET ----
        if train_mode:
            _save_training_data(episode_buffer)

    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as e:
                # Best-effort cleanup: report, but don't mask the episode result.
                _diag(f"[WARN] env.close() failed: {e}")

        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


async def main() -> None:
    """CLI entry point: run every task in TASKS, continuing past per-task failures.

    Exits with status 1 after all tasks have run if any of them raised.
    """
    parser = argparse.ArgumentParser(description="Thermal Grid RL Agent Inference")
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL)
    parser.add_argument("--train", action="store_true", default=False)
    args = parser.parse_args()

    # Clear old trajectory data only when starting a new training collection run.
    if args.train:
        for path in (PROMPTS_PATH, TRAJECTORIES_PATH):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

    failed: List[str] = []
    for task_id in TASKS:
        try:
            await run_inference(task_id, args.model, train_mode=args.train)
        except Exception:
            logger.exception("Task %s failed; continuing with remaining tasks", task_id.value)
            failed.append(task_id.value)

    if failed:
        raise SystemExit(1)


if __name__ == "__main__":
    asyncio.run(main())