# varshu23's picture
# Clean commit without images
# 5a22808
"""
Inference Script for Thermal Grid RL Agent Environment
========================================================
MANDATORY
- API_BASE_URL, MODEL_NAME, HF_TOKEN must be set in environment / .env
- Use OpenAI client for all LLM calls
- Emit [START], [STEP], [END] to stdout exactly as specified
Environment Variables:
HF_TOKEN - Hugging Face / API key (checked first)
API_KEY - Alternative API key (fallback)
API_BASE_URL - The API endpoint for the LLM
MODEL_NAME - The model identifier to use for inference
ENV_URL - URL of the thermal grid environment server
STDOUT FORMAT
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""
import asyncio
import os
import re
import json
import logging
from typing import List, Optional
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
from client import ThermalGridRlAgentEnv
from models import ThermalGridRlAgentAction, ThermalGridRlAgentObservation
from server.thermal_grid_rl_agent_environment import ThermalGridTaskID
from server.grader import ThermalGridGrader
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


def _mask_secret(secret: Optional[str]) -> str:
    """Return a display-safe rendering of *secret* for the startup banner.

    A missing or empty secret is shown as ``<not set>`` instead of being
    sliced, which would raise ``TypeError`` on ``None``. Otherwise the first
    four characters are shown, plus the last four when the secret is longer
    than eight characters.
    """
    if not secret:
        return "<not set>"
    tail = secret[-4:] if len(secret) > 8 else ""
    return f"{secret[:4]}...{tail}"


# HF_TOKEN is checked first; API_KEY is the fallback described in the module docstring.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
DEFAULT_MODEL = os.getenv("MODEL_NAME") or "meta-llama/Llama-3.2-1B-Instruct"
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
print("\n--- LOADED ENVIRONMENT VARIABLES ---")
print(f"API_BASE_URL : {API_BASE_URL}")
print(f"MODEL_NAME : {DEFAULT_MODEL}")
print(f"API_KEY : {_mask_secret(API_KEY)}")
print(f"ENV_URL : {ENV_URL}")
print("------------------------------------\n")
# Benchmark name reported as ``env=`` on the [START] line.
BENCHMARK = "thermal_grid_rl_multi_agent"
# Upper bound on environment steps per episode; also shown in each user prompt.
MAX_STEPS = 30
# An episode is reported as successful when its clamped final score reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.1
# Inference-mode early stop: end the episode once the last EARLY_STOP_CONSEC
# rewards are all at or above EARLY_STOP_REWARD.
EARLY_STOP_REWARD = 0.9  # below 1.0 on purpose, unless rewards are known to be capped at 1.0
EARLY_STOP_CONSEC = 5  # a window of 5 steps, for more stability
# Tasks executed by main(), in this order.
TASKS = [
    ThermalGridTaskID.BASELINE,
    ThermalGridTaskID.LOAD_SHIFT,
    ThermalGridTaskID.GRID_STRESS,
]
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode and flush stdout."""
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line and flush stdout.

    The reward is rendered with two decimals, ``done`` as its lower-cased
    string form, and a missing or empty ``error`` as the literal ``null``.
    """
    done_text = str(done).lower()
    error_text = error or "null"
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_text} error={error_text}"
    )
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line and flush stdout.

    The score is rendered with three decimals; rewards are rendered with two
    decimals each and joined by commas (an empty list yields an empty field).
    """
    reward_fields = [format(value, ".2f") for value in rewards]
    joined_rewards = ",".join(reward_fields)
    success_text = str(success).lower()
    print(
        f"[END] success={success_text} steps={steps} score={score:.3f} rewards={joined_rewards}",
        flush=True,
    )
class CoolingAgent:
    """Advisory agent that argues for thermal safety and equipment longevity."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return a cooling recommendation for the given observation.

        Checks are applied in order of severity: hottest CPU above 75 or
        ambient above 38, then hottest CPU above 65, then a thermal-mass lag
        above 0.3. If none applies, a "stable" status is returned. An empty
        CPU temperature list is treated as 0.0.
        """
        cpu_temps = obs.max_cpu_temps_c
        hottest = max(cpu_temps) if cpu_temps else 0.0
        outside = obs.ambient_temp_c

        if hottest > 75.0 or outside > 38.0:
            return (
                "CRITICAL: Thermal emergency. Suggest CRAC at 12°C, 100% fans, "
                "and all 4 chillers. Priorities: Safety over cost."
            )
        if hottest > 65.0:
            return (
                "WARNING: High temperatures. Recommend CRAC at 15°C and 85% fans. "
                "Ensure at least 3 chillers are active."
            )
        # The lag attribute is only consulted once the temperature checks have passed.
        if obs.thermal_mass_lag_c_per_min > 0.3:
            return (
                "PREDICTIVE: Room heating up. Proactively lower CRAC setpoint "
                "and increase fans by 10%."
            )
        return "STATUS: Thermal state stable. Maintain current cooling setpoints."
class EnergyAgent:
    """Advisory agent that argues for low PUE, low electricity cost and grid compliance."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return an energy recommendation for the given observation.

        An active demand-response signal (``== 1``) or a price above 0.15
        takes precedence; otherwise a PUE above 1.30 is flagged; otherwise a
        "within bounds" status is returned.
        """
        tariff = obs.energy_price_per_kwh
        demand_response = obs.demand_response_signal
        efficiency = obs.pue

        shed_load = demand_response == 1 or tariff > 0.15
        if shed_load:
            return (
                "CRITICAL: DR signal active or high pricing. Suggest raising CRAC "
                "to 25°C and reducing fans to 40% to shed load."
            )
        if efficiency > 1.30:
            return (
                "INEFFICIENCY: High PUE. Recommend optimizing chiller count "
                "and raising CRAC setpoint to improve efficiency."
            )
        return "STATUS: Energy usage within bounds. Optimize for efficiency if safety allows."
class WorkloadAgent:
    """Advisory agent that argues for throughput and batch-job scheduling."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return a scheduling recommendation for the given observation.

        More than 100 pending jobs during an off-peak window (``== 1``) is an
        opportunity to drain the queue; more than 50 pending jobs is urgent;
        anything else is reported as manageable.
        """
        backlog = obs.pending_batch_jobs
        in_off_peak = obs.off_peak_window

        if backlog > 100 and in_off_peak == 1:
            return (
                "OPPORTUNITY: Off-peak window and high backlog. Suggest running "
                "all pending batch jobs now."
            )
        if backlog > 50:
            return "URGENT: Batch backlog growing. Suggest increasing throughput."
        return "STATUS: Workload manageable. Schedule batch jobs normally."
class OversightAgent:
    """Rule-based safety monitor that may overwrite the Coordinator's action."""

    def __init__(self, thermal_limit_c: float = 80.0):
        # Hottest-CPU temperature above which emergency cooling is forced.
        self.thermal_limit_c = thermal_limit_c

    def enforce(self, obs: ThermalGridRlAgentObservation, action: ThermalGridRlAgentAction) -> ThermalGridRlAgentAction:
        """Apply the deterministic safety override to *action* and return it.

        The action is modified in place. Its ``metadata`` always receives
        ``oversight_triggered`` and ``oversight_reason``; when the hottest CPU
        exceeds the limit, the setpoint, every fan speed and the chiller
        count are overwritten and the reason is printed.
        """
        cpu_temps = obs.max_cpu_temps_c
        peak = max(cpu_temps) if cpu_temps else 0

        # Start from a "no override" record so both keys always exist.
        meta = action.metadata
        meta["oversight_triggered"] = False
        meta["oversight_reason"] = None

        if not peak > self.thermal_limit_c:
            return action

        # Emergency cooling: coldest setpoint, every fan at full speed, all chillers.
        reason = f"CPU Temp CRITICAL ({peak:.1f}°C). Emergency cooling forced."
        action.crac_setpoint_c = 12.0
        action.fan_speeds_pct = [100.0] * len(action.fan_speeds_pct)
        action.num_active_chillers = 4
        meta["oversight_triggered"] = True
        meta["oversight_reason"] = reason
        print(f"[OVERSIGHT] OVERRIDE: {reason}")
        return action
def _cooling_rebuttal(dr, price, max_cpu) -> str:
    """Pick the Cooling agent's reply to the Energy agent."""
    if dr == 1 or price > 0.15:
        return (
            "[COOLING→ENERGY] I understand the DR signal, but thermal safety "
            "cannot be compromised. If we raise the setpoint above 22°C now, "
            "CPU temps will spike within 5 steps. Propose: CRAC at 18°C max."
        )
    if max_cpu > 65.0:
        return (
            "[COOLING→ENERGY] Current CPU temps are dangerously high. "
            "Any further energy saving must wait. Safety override required."
        )
    return (
        "[COOLING→ENERGY] Thermal state is manageable. I can accept "
        "a moderate setpoint increase if PUE improvement is significant."
    )


def _energy_rebuttal(pue, dr) -> str:
    """Pick the Energy agent's reply to the Cooling agent."""
    if pue > 1.30:
        return (
            "[ENERGY→COOLING] PUE is above 1.30 — we're burning money. "
            "Raising CRAC by 2°C and reducing fans by 15% will save ~8% energy "
            "with minimal thermal impact at current load levels."
        )
    if dr == 1:
        return (
            "[ENERGY→COOLING] Grid is requesting load shed. We MUST reduce "
            "facility power by at least 10%. I support minimal cooling for now "
            "if you can keep CPUs below 70°C."
        )
    return (
        "[ENERGY→COOLING] Energy costs are within acceptable bounds. "
        "No immediate conflict — support your cooling recommendation."
    )


def _workload_rebuttal(pending, off_peak, dr) -> str:
    """Pick the Workload agent's reply to both other agents."""
    if pending > 50 and off_peak == 1 and dr == 0:
        return (
            "[WORKLOAD→BOTH] Off-peak window is active and backlog is growing. "
            "If either of you can spare 10% headroom, I recommend running "
            "batch jobs now to clear the queue before peak pricing resumes."
        )
    if pending > 100:
        return (
            "[WORKLOAD→BOTH] Batch backlog is critical. Throughput must improve "
            "or SLAs will be breached. Request at least 2 chillers remain active."
        )
    return (
        "[WORKLOAD→BOTH] Workload is stable. No conflict from my side — "
        "please optimize for energy efficiency and safety as you see fit."
    )


def simulate_negotiation(obs: ThermalGridRlAgentObservation, step: int) -> str:
    """Build the scripted two-round debate transcript shown to the coordinator.

    Round 1 holds each specialist agent's recommendation for *obs*; round 2
    holds a rule-selected reply from each agent. ``step`` is accepted for
    interface compatibility and is not used. The transcript is returned as a
    single newline-separated string.
    """
    # Round 1: opening positions from the three specialists.
    opening_cooling = CoolingAgent.get_recommendation(obs)
    opening_energy = EnergyAgent.get_recommendation(obs)
    opening_workload = WorkloadAgent.get_recommendation(obs)

    # Signals that drive the round-2 replies.
    cpu_temps = obs.max_cpu_temps_c
    hottest = max(cpu_temps) if cpu_temps else 0.0
    tariff = obs.energy_price_per_kwh
    efficiency = obs.pue
    backlog = obs.pending_batch_jobs
    demand_response = obs.demand_response_signal
    in_off_peak = obs.off_peak_window

    # Round 2: replies, chosen in the same order as the transcript lists them.
    reply_cooling = _cooling_rebuttal(demand_response, tariff, hottest)
    reply_energy = _energy_rebuttal(efficiency, demand_response)
    reply_workload = _workload_rebuttal(backlog, in_off_peak, demand_response)

    transcript_lines = [
        "--- ROUND 1: AGENT POSITIONS ---",
        f"[COOLING AGENT] : {opening_cooling}",
        f"[ENERGY AGENT] : {opening_energy}",
        f"[WORKLOAD AGENT]: {opening_workload}",
        "--- ROUND 2: AGENT REPLIES ---",
        f"{reply_cooling}",
        f"{reply_energy}",
        f"{reply_workload}",
    ]
    return "\n".join(transcript_lines).strip()
# System message sent with every coordinator LLM call. It fixes the JSON reply
# schema; the numeric ranges listed under "Constraints" are the same bounds
# get_llm_action() clamps the parsed values to.
SYSTEM_PROMPT = """You are the Facility Coordinator for a datacenter.
Analyze the state and agent recommendations, then output a JSON object with your final control actions.
Required JSON format:
```json
{
"reasoning": "Step-by-step logic resolving conflicts",
"crac_setpoint_c": 16.0,
"fan_speeds_pct": [75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0],
"num_active_chillers": 3
}
```
Constraints:
- crac_setpoint_c: Float between 12.0 and 27.0
- fan_speeds_pct: List of exactly 10 floats between 20.0 and 100.0
- num_active_chillers: Integer between 1 and 4
Priority: 1. Safety, 2. Grid, 3. Throughput, 4. PUE.
Output ONLY the JSON object. Do not add conversational text."""
def build_user_message(obs: ThermalGridRlAgentObservation, step: int, task: str) -> str:
    """Compose the user prompt for one coordinator decision.

    The prompt holds a step/task header, the simulated agent debate for
    *obs*, a summary of the current environment state, and the closing
    instruction to answer with JSON only.
    """
    peak_temps = obs.max_cpu_temps_c
    mean_temps = obs.mean_cpu_temps_c
    hottest = max(peak_temps) if peak_temps else 0.0
    average = sum(mean_temps) / len(mean_temps) if mean_temps else 0.0

    # Scripted debate between the specialist agents for this observation.
    debate = simulate_negotiation(obs, step)

    prompt_lines = [
        f"STEP {step}/{MAX_STEPS} | Task: {task}",
        f"{debate}",
        "--- CURRENT ENVIRONMENT STATE ---",
        f"- Thermal : max_cpu={hottest:.1f}°C avg_cpu={average:.1f}°C",
        f"- Cooling : PUE={obs.pue:.3f} setpoint={obs.crac_supply_temp_c:.1f}°C fans={obs.avg_fan_speed_pct:.0f}% chillers={obs.num_active_chillers}",
        f"- Grid : price=${obs.energy_price_per_kwh:.3f}/kWh DR={obs.demand_response_signal} ambient={obs.ambient_temp_c:.1f}°C",
        f"- Batch : {obs.pending_batch_jobs} jobs pending off_peak={obs.off_peak_window}",
        "You have read both rounds of debate above. Resolve the conflict and respond with JSON only.",
    ]
    return "\n".join(prompt_lines)
def _parse_json_object(text: str) -> Optional[dict]:
    """Return *text* decoded as a JSON object, or ``None`` if it is not one."""
    try:
        value = json.loads(text)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; pathological nesting raises RecursionError.
        return None
    # Valid JSON that is not an object (list, string, number) is useless to callers.
    return value if isinstance(value, dict) else None


def _extract_json(raw: str) -> dict:
    """Best-effort extraction of a JSON object from an LLM reply.

    Strategies, tried in order:
      1. strip Markdown code fences and parse the remainder directly;
      2. repair truncated output by appending up to five closing braces,
         starting at the first ``{`` so leading prose or an unclosed code
         fence does not defeat the repair;
      3. parse the span from the first ``{`` to the last ``}`` of the raw text.

    Args:
        raw: Raw text returned by the model; may be empty or ``None``.

    Returns:
        The decoded object, always a ``dict``. An empty dict is returned when
        nothing parseable is found or the JSON is not an object.
    """
    if not raw:
        return {}

    # 1. Strip complete markdown code fences, then try a direct parse.
    cleaned = re.sub(r"```(?:json)?\s*(.*?)\s*```", r"\1", raw, flags=re.DOTALL).strip()
    parsed = _parse_json_object(cleaned)
    if parsed is not None:
        return parsed

    # 2. Truncation repair: the model was cut off before closing its braces.
    start = cleaned.find("{")
    if start != -1:
        candidate = cleaned[start:]
        for _ in range(5):
            candidate += "\n}"
            parsed = _parse_json_object(candidate)
            if parsed is not None:
                return parsed

    # 3. Fallback: widest brace-delimited span in the original text.
    match = re.search(r"\{.*\}", raw, re.DOTALL)
    if match:
        parsed = _parse_json_object(match.group())
        if parsed is not None:
            return parsed
    return {}
def get_llm_action(
    client: OpenAI,
    obs: ThermalGridRlAgentObservation,
    step: int,
    task_id: ThermalGridTaskID,
    model: str
) -> tuple:
    """Ask the coordinator LLM for an action, retrying up to three times.

    Each attempt sends the system prompt and the user prompt built from
    *obs*, parses the reply as JSON, clamps the values to the ranges stated
    in the system prompt, and passes the resulting action through
    ``OversightAgent.enforce``. The first attempt uses temperature 0.3;
    retries use 0.7.

    Returns:
        ``(action, user_prompt, raw_reply, raw_attempts, parsed_payload)``,
        where ``raw_attempts`` lists every raw reply received so far.

    Raises:
        ValueError: If no attempt yields a usable action. The last underlying
            exception, if any, is attached as ``__cause__`` so API failures
            can be told apart from replies that merely failed to parse.
    """
    user_prompt = build_user_message(obs, step, task_id.value)
    raw_attempts = []
    last_error: Optional[Exception] = None

    for attempt in range(3):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=1024,
                temperature=0.3 if attempt == 0 else 0.7,
            )
            raw = response.choices[0].message.content or "{}"
            raw_attempts.append(raw)
            data = _extract_json(raw)

            # A reply without the mandatory key (or a non-object payload) is
            # unusable: try again.
            if not isinstance(data, dict) or "crac_setpoint_c" not in data:
                continue

            crac = max(12.0, min(27.0, float(data.get("crac_setpoint_c", 18.0))))

            fans_raw = data.get("fan_speeds_pct", [70.0] * 10)
            if isinstance(fans_raw, (int, float, str)):
                # A single value means "the same speed for every fan";
                # iterating it would fail (number) or split it into characters (string).
                fans_raw = [fans_raw] * 10
            fans = [max(20.0, min(100.0, float(f))) for f in fans_raw]
            if len(fans) != 10:
                # Wrong length: replicate the first value across all ten fans.
                fans = [fans[0] if fans else 70.0] * 10

            chillers = max(1, min(4, int(data.get("num_active_chillers", 2))))

            action = ThermalGridRlAgentAction(
                crac_setpoint_c=crac,
                fan_speeds_pct=fans,
                num_active_chillers=chillers,
            )
            final_action = OversightAgent().enforce(obs, action)
            return final_action, user_prompt, raw, raw_attempts, data
        except Exception as e:
            last_error = e
            print(f"[ERROR] LLM Attempt {attempt+1} failed: {e}")
            logger.warning("Step %s Attempt %s failed: %s", step, attempt + 1, e)

    raise ValueError(
        f"Unparseable LLM response after 3 attempts: {raw_attempts}"
    ) from last_error
async def run_inference(task_id: ThermalGridTaskID, model: str, train_mode: bool = False) -> None:
    """Run one episode of *task_id* and emit the [START]/[STEP]/[END] lines.

    Each step asks the LLM coordinator for an action (in a worker thread, as
    the OpenAI client is blocking), applies it to the environment and records
    the transition. In inference mode the episode stops early once the last
    ``EARLY_STOP_CONSEC`` rewards are all at least ``EARLY_STOP_REWARD``. In
    train mode the recorded steps are appended to ``all_prompts.jsonl`` and
    ``expert_trajectories.jsonl``.

    The [END] line is always printed, including when the environment reset
    fails; in that case the score is reported as 0.0.

    Raises:
        RuntimeError: If ``env.reset()`` fails (original error chained).
    """
    import traceback

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    if train_mode:
        print("[MODE] TRAIN — collecting trajectories")
    else:
        print("[MODE] INFERENCE — early stopping enabled")
    log_start(task=task_id.value, env=BENCHMARK, model=model)
    env = ThermalGridRlAgentEnv(base_url=ENV_URL)
    # NOTE(review): nothing in this function passes step results to the grader
    # before get_thermal_grid_score() is called — confirm it obtains the
    # episode data on its own.
    grader = ThermalGridGrader(task_id=task_id.value)
    rewards = []
    steps_taken = 0
    success = False
    # Must be bound before the try block: the finally clause reports it even
    # when the episode aborts before a score is computed.
    score = 0.0
    episode_buffer = []
    try:
        # ===============================
        # RESET ENV
        # ===============================
        try:
            reset_result = await env.reset(task_id=task_id.value)
        except Exception as e:
            raise RuntimeError(f"env.reset() failed: {e}") from e
        observation = reset_result.observation
        # ===============================
        # MAIN LOOP
        # ===============================
        for step in range(1, MAX_STEPS + 1):
            try:
                # Get the action from the LLM without blocking the event loop.
                action, prompt, raw, raw_attempts, parsed_data = await asyncio.to_thread(
                    get_llm_action, client, observation, step, task_id, model
                )
                # Final action dict (after clamping and oversight)
                action_dict = {
                    "crac_setpoint_c": action.crac_setpoint_c,
                    "fan_speeds_pct": action.fan_speeds_pct,
                    "num_active_chillers": action.num_active_chillers,
                }
                action_str = json.dumps(action_dict, separators=(",", ":"))
                # ===============================
                # ENV STEP
                # ===============================
                step_result = await env.step(action)
                reward = float(step_result.reward or 0.0)
                done = step_result.done
                rewards.append(reward)
                steps_taken = step
                # ===============================
                # SAVE DATASET ENTRY
                # ===============================
                episode_buffer.append({
                    "step": step,
                    # Input
                    "prompt": prompt,
                    # Final action
                    "response": action_dict,
                    # Reward
                    "reward": reward,
                    # Oversight info
                    "oversight_triggered": action.metadata.get("oversight_triggered", False),
                    # Extra fields for debugging and replay
                    "raw_response": raw,
                    "raw_attempts": raw_attempts,
                    "parsed_action": parsed_data,
                    # Structured state the action was chosen for
                    "state": {
                        "max_cpu": max(observation.max_cpu_temps_c) if observation.max_cpu_temps_c else 0.0,
                        "avg_cpu": sum(observation.mean_cpu_temps_c)/len(observation.mean_cpu_temps_c) if observation.mean_cpu_temps_c else 0.0,
                        "pue": observation.pue,
                        "energy_price": observation.energy_price_per_kwh,
                        "ambient": observation.ambient_temp_c,
                        "pending_jobs": observation.pending_batch_jobs,
                        "off_peak": observation.off_peak_window
                    }
                })
                # Logging
                log_step(step=step, action=action_str, reward=reward, done=done, error=None)
                # Move to next state
                observation = step_result.observation
                # ===============================
                # EARLY STOP
                # ===============================
                if not train_mode:
                    recent = rewards[-EARLY_STOP_CONSEC:] if len(rewards) >= EARLY_STOP_CONSEC else []
                    if recent and all(r >= EARLY_STOP_REWARD for r in recent):
                        print("[EARLY STOP] Stable high reward achieved")
                        break
                if done:
                    break
            except Exception as e:
                # A failed step ends the episode but still produces a score and [END].
                print("\n[ERROR] Exception in loop:")
                traceback.print_exc()
                log_step(step=step, action="{}", reward=0.0, done=False, error=str(e))
                break
        # ===============================
        # FINAL SCORE
        # ===============================
        score = min(max(grader.get_thermal_grid_score(), 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
        # ===============================
        # SAVE DATASET
        # ===============================
        if train_mode:
            # 1. Save all prompts for RL exploration
            with open("all_prompts.jsonl", "a", encoding="utf-8") as f:
                for step_data in episode_buffer:
                    f.write(json.dumps({"prompt": step_data["prompt"]}) + "\n")
            # 2. Save full trajectories
            with open("expert_trajectories.jsonl", "a", encoding="utf-8") as f:
                for step_data in episode_buffer:
                    f.write(json.dumps({
                        "step": step_data["step"],
                        "prompt": step_data["prompt"],
                        "response": step_data["response"],
                        "reward": step_data["reward"],
                        "oversight_triggered": step_data["oversight_triggered"]
                    }) + "\n")
    finally:
        try:
            await env.close()
        except Exception as e:
            # Closing is best-effort, but cancellation and interrupts must still propagate.
            logger.warning("env.close() failed: %s", e)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
async def main() -> None:
    """Parse the command line and run every task in ``TASKS`` sequentially.

    ``--model`` selects the model (default: ``DEFAULT_MODEL``); ``--train``
    enables trajectory collection, in which case trajectory files left over
    from a previous run are deleted first.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Thermal Grid RL Agent Inference")
    cli.add_argument("--model", type=str, default=DEFAULT_MODEL)
    cli.add_argument("--train", action="store_true", default=False)
    options = cli.parse_args()

    chosen_model = options.model
    collecting = options.train

    # A new training collection run starts from empty trajectory files.
    if collecting:
        for stale_path in ("all_prompts.jsonl", "expert_trajectories.jsonl"):
            if not os.path.exists(stale_path):
                continue
            os.remove(stale_path)

    for task in TASKS:
        await run_inference(task, chosen_model, train_mode=collecting)


if __name__ == "__main__":
    asyncio.run(main())