# File size: 21,090 Bytes
# 5a22808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
"""
Inference Script for Thermal Grid RL Agent Environment
========================================================
MANDATORY
- API_BASE_URL, MODEL_NAME, HF_TOKEN must be set in environment / .env
- Use OpenAI client for all LLM calls
- Emit [START], [STEP], [END] to stdout exactly as specified

Environment Variables:
    HF_TOKEN       - Hugging Face / API key (checked first)
    API_KEY        - Alternative API key (fallback)
    API_BASE_URL   - The API endpoint for the LLM
    MODEL_NAME     - The model identifier to use for inference
    ENV_URL        - URL of the thermal grid environment server

STDOUT FORMAT
    [START] task=<task_name> env=<benchmark> model=<model_name>
    [STEP]  step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
    [END]   success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""

import asyncio
import os
import re
import json
import logging
from typing import List, Optional

from openai import OpenAI
from dotenv import load_dotenv

load_dotenv()

from client import ThermalGridRlAgentEnv
from models import ThermalGridRlAgentAction, ThermalGridRlAgentObservation
from server.thermal_grid_rl_agent_environment import ThermalGridTaskID
from server.grader import ThermalGridGrader

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)


API_KEY      = os.getenv("HF_TOKEN") 
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
DEFAULT_MODEL = os.getenv("MODEL_NAME") or "meta-llama/Llama-3.2-1B-Instruct"
ENV_URL      = os.getenv("ENV_URL", "http://localhost:8000")

print("\n--- LOADED ENVIRONMENT VARIABLES ---")
print(f"API_BASE_URL : {API_BASE_URL}")
print(f"MODEL_NAME   : {DEFAULT_MODEL}")
print(f"API_KEY      : {API_KEY[:4]}...{API_KEY[-4:] if len(API_KEY)>8 else ''}")
print(f"ENV_URL      : {ENV_URL}")
print("------------------------------------\n")

# Benchmark name reported in the env= field of the [START] line.
BENCHMARK               = "thermal_grid_rl_multi_agent"
# Upper bound on environment steps per episode (also shown in the LLM prompt).
MAX_STEPS               = 30
# Minimum final score for an episode to be reported as success=true.
SUCCESS_SCORE_THRESHOLD = 0.1
# Early stopping (inference mode only): the episode ends once the last
# EARLY_STOP_CONSEC rewards are all >= EARLY_STOP_REWARD.
EARLY_STOP_REWARD = 0.9   # not 1.0 (unless rewards are definitely capped at 1.0)
EARLY_STOP_CONSEC = 5      # 5 steps for more stability

# Tasks executed by main(), one episode each, in this order.
TASKS = [
    ThermalGridTaskID.BASELINE,
    ThermalGridTaskID.LOAD_SHIFT,
    ThermalGridTaskID.GRID_STRESS,
]


def log_start(task: str, env: str, model: str) -> None:
    """Write the ``[START]`` record that opens an episode to stdout.

    Args:
        task: Task name placed in the ``task=`` field.
        env: Benchmark name placed in the ``env=`` field.
        model: Model identifier placed in the ``model=`` field.
    """
    fields = (f"task={task}", f"env={env}", f"model={model}")
    # print() joins its arguments with single spaces, giving the required
    # "[START] task=... env=... model=..." layout.
    print("[START]", *fields, flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Write one ``[STEP]`` record to stdout.

    Args:
        step: Step number for the ``step=`` field.
        action: Already-serialized action string for the ``action=`` field.
        reward: Step reward, printed with two decimals.
        done: Episode-finished flag, printed lower-cased (``true``/``false``).
        error: Error message, or a falsy value to print ``null``.
    """
    done_text = str(done).lower()
    error_text = error or "null"
    record = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_text} error={error_text}"
    )
    print(record, flush=True)


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Write the closing ``[END]`` record for an episode to stdout.

    Args:
        success: Success flag, printed lower-cased (``true``/``false``).
        steps: Number of steps taken.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    success_text = str(success).lower()
    reward_list = ",".join([format(value, ".2f") for value in rewards])
    record = (
        f"[END] success={success_text} steps={steps} "
        f"score={score:.3f} rewards={reward_list}"
    )
    print(record, flush=True)


class CoolingAgent:
    """Advisor that argues for thermal safety and equipment longevity."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return a cooling recommendation for the observation.

        Rules are checked in order and the first match wins:
        hottest CPU above 75.0 or ambient above 38.0 -> CRITICAL;
        hottest CPU above 65.0 -> WARNING;
        ``thermal_mass_lag_c_per_min`` above 0.3 -> PREDICTIVE;
        otherwise a "stable" status message.
        An empty ``max_cpu_temps_c`` is treated as a hottest CPU of 0.0.
        """
        cpu_peaks = obs.max_cpu_temps_c
        hottest = max(cpu_peaks) if cpu_peaks else 0.0
        outside_temp = obs.ambient_temp_c

        if hottest > 75.0 or outside_temp > 38.0:
            return (
                "CRITICAL: Thermal emergency. Suggest CRAC at 12°C, 100% fans, "
                "and all 4 chillers. Priorities: Safety over cost."
            )

        if hottest > 65.0:
            return (
                "WARNING: High temperatures. Recommend CRAC at 15°C and 85% fans. "
                "Ensure at least 3 chillers are active."
            )

        if obs.thermal_mass_lag_c_per_min > 0.3:
            return (
                "PREDICTIVE: Room heating up. Proactively lower CRAC setpoint "
                "and increase fans by 10%."
            )

        return "STATUS: Thermal state stable. Maintain current cooling setpoints."

class EnergyAgent:
    """Advisor that argues for low PUE, low electricity cost and grid compliance."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return an energy recommendation for the observation.

        Rules are checked in order and the first match wins:
        demand-response signal equal to 1 or price above 0.15 -> CRITICAL;
        PUE above 1.30 -> INEFFICIENCY; otherwise a "within bounds" status.
        """
        kwh_price = obs.energy_price_per_kwh
        demand_response = obs.demand_response_signal
        current_pue = obs.pue

        if demand_response == 1 or kwh_price > 0.15:
            return (
                "CRITICAL: DR signal active or high pricing. Suggest raising CRAC "
                "to 25°C and reducing fans to 40% to shed load."
            )

        if current_pue > 1.30:
            return (
                "INEFFICIENCY: High PUE. Recommend optimizing chiller count "
                "and raising CRAC setpoint to improve efficiency."
            )

        return "STATUS: Energy usage within bounds. Optimize for efficiency if safety allows."

class WorkloadAgent:
    """Advisor that argues for throughput and timely batch-job scheduling."""

    @staticmethod
    def get_recommendation(obs: ThermalGridRlAgentObservation) -> str:
        """Return a workload recommendation for the observation.

        Rules are checked in order and the first match wins:
        more than 100 pending jobs while ``off_peak_window`` equals 1 ->
        OPPORTUNITY; more than 50 pending jobs -> URGENT; otherwise a
        "manageable" status message.
        """
        backlog = obs.pending_batch_jobs
        off_peak_flag = obs.off_peak_window

        if backlog > 100 and off_peak_flag == 1:
            return (
                "OPPORTUNITY: Off-peak window and high backlog. Suggest running "
                "all pending batch jobs now."
            )

        if backlog > 50:
            return "URGENT: Batch backlog growing. Suggest increasing throughput."

        return "STATUS: Workload manageable. Schedule batch jobs normally."

class OversightAgent:
    """Rule-based safety monitor that may overwrite the coordinator's action."""

    def __init__(self, thermal_limit_c: float = 80.0):
        # Hottest-CPU reading above which emergency cooling is forced.
        self.thermal_limit_c = thermal_limit_c

    def enforce(self, obs: ThermalGridRlAgentObservation, action: ThermalGridRlAgentAction) -> ThermalGridRlAgentAction:
        """Apply the safety override to *action* in place and return it.

        The ``oversight_triggered`` / ``oversight_reason`` metadata entries are
        always reset first. When the hottest CPU in ``obs.max_cpu_temps_c``
        exceeds ``self.thermal_limit_c``, the action is rewritten to CRAC
        setpoint 12.0, every fan at 100.0 and 4 chillers, the metadata records
        the override, and an ``[OVERSIGHT]`` line is printed.
        An empty temperature list is treated as 0 and never triggers.
        """
        readings = obs.max_cpu_temps_c
        hottest = max(readings) if readings else 0

        # Start from a clean slate so a previous override never leaks through.
        action.metadata.update(oversight_triggered=False, oversight_reason=None)

        if not hottest > self.thermal_limit_c:
            return action

        # Emergency cooling: coldest setpoint, all fans flat out, all chillers.
        action.crac_setpoint_c = 12.0
        action.fan_speeds_pct = [100.0 for _ in action.fan_speeds_pct]
        action.num_active_chillers = 4

        reason = f"CPU Temp CRITICAL ({hottest:.1f}°C). Emergency cooling forced."
        action.metadata.update(oversight_triggered=True, oversight_reason=reason)
        print(f"[OVERSIGHT] OVERRIDE: {reason}")

        return action

def _cooling_reply_to_energy(dr, price, max_cpu) -> str:
    """Pick the Cooling agent's round-two reply to the Energy agent."""
    if dr == 1 or price > 0.15:
        return (
            "[COOLING→ENERGY] I understand the DR signal, but thermal safety "
            "cannot be compromised. If we raise the setpoint above 22°C now, "
            "CPU temps will spike within 5 steps. Propose: CRAC at 18°C max."
        )
    if max_cpu > 65.0:
        return (
            "[COOLING→ENERGY] Current CPU temps are dangerously high. "
            "Any further energy saving must wait. Safety override required."
        )
    return (
        "[COOLING→ENERGY] Thermal state is manageable. I can accept "
        "a moderate setpoint increase if PUE improvement is significant."
    )


def _energy_reply_to_cooling(pue, dr) -> str:
    """Pick the Energy agent's round-two reply to the Cooling agent."""
    if pue > 1.30:
        return (
            "[ENERGY→COOLING] PUE is above 1.30 — we're burning money. "
            "Raising CRAC by 2°C and reducing fans by 15% will save ~8% energy "
            "with minimal thermal impact at current load levels."
        )
    if dr == 1:
        return (
            "[ENERGY→COOLING] Grid is requesting load shed. We MUST reduce "
            "facility power by at least 10%. I support minimal cooling for now "
            "if you can keep CPUs below 70°C."
        )
    return (
        "[ENERGY→COOLING] Energy costs are within acceptable bounds. "
        "No immediate conflict — support your cooling recommendation."
    )


def _workload_reply_to_both(pending, off_peak, dr) -> str:
    """Pick the Workload agent's round-two reply to both other agents."""
    if pending > 50 and off_peak == 1 and dr == 0:
        return (
            "[WORKLOAD→BOTH] Off-peak window is active and backlog is growing. "
            "If either of you can spare 10% headroom, I recommend running "
            "batch jobs now to clear the queue before peak pricing resumes."
        )
    if pending > 100:
        return (
            "[WORKLOAD→BOTH] Batch backlog is critical. Throughput must improve "
            "or SLAs will be breached. Request at least 2 chillers remain active."
        )
    return (
        "[WORKLOAD→BOTH] Workload is stable. No conflict from my side — "
        "please optimize for energy efficiency and safety as you see fit."
    )


def simulate_negotiation(obs: ThermalGridRlAgentObservation, step: int) -> str:
    """Build a scripted two-round debate transcript between the three advisors.

    Round 1 holds each advisor's recommendation for *obs*; round 2 holds a
    rule-selected reply from each advisor. Nothing here calls an LLM: every
    line is chosen by fixed thresholds on the observation.

    Args:
        obs: Current environment observation.
        step: Current step number (accepted for the caller's convenience;
            not used when building the transcript).

    Returns:
        The transcript text, with no leading or trailing whitespace.
    """
    # Round 1: opening positions, in the fixed Cooling / Energy / Workload order.
    opening_cooling = CoolingAgent.get_recommendation(obs)
    opening_energy = EnergyAgent.get_recommendation(obs)
    opening_workload = WorkloadAgent.get_recommendation(obs)

    hottest_cpu = max(obs.max_cpu_temps_c) if obs.max_cpu_temps_c else 0.0
    kwh_price = obs.energy_price_per_kwh
    current_pue = obs.pue
    backlog = obs.pending_batch_jobs
    dr_signal = obs.demand_response_signal
    off_peak_flag = obs.off_peak_window

    # Round 2: each advisor answers the others.
    transcript_lines = [
        "--- ROUND 1: AGENT POSITIONS ---",
        f"[COOLING AGENT] : {opening_cooling}",
        f"[ENERGY AGENT]  : {opening_energy}",
        f"[WORKLOAD AGENT]: {opening_workload}",
        "",
        "--- ROUND 2: AGENT REPLIES ---",
        _cooling_reply_to_energy(dr_signal, kwh_price, hottest_cpu),
        _energy_reply_to_cooling(current_pue, dr_signal),
        _workload_reply_to_both(backlog, off_peak_flag, dr_signal),
    ]

    return "\n".join(transcript_lines).strip()


# System message sent with every chat-completion request. It pins the JSON
# schema the coordinator must answer with; the numeric ranges listed under
# "Constraints" are the same ones get_llm_action() clamps the reply to.
SYSTEM_PROMPT = """You are the Facility Coordinator for a datacenter.
Analyze the state and agent recommendations, then output a JSON object with your final control actions.

Required JSON format:
```json
{
  "reasoning": "Step-by-step logic resolving conflicts",
  "crac_setpoint_c": 16.0, 
  "fan_speeds_pct": [75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0, 75.0], 
  "num_active_chillers": 3
}
```

Constraints:
- crac_setpoint_c: Float between 12.0 and 27.0
- fan_speeds_pct: List of exactly 10 floats between 20.0 and 100.0
- num_active_chillers: Integer between 1 and 4

Priority: 1. Safety, 2. Grid, 3. Throughput, 4. PUE.
Output ONLY the JSON object. Do not add conversational text."""


def build_user_message(obs: ThermalGridRlAgentObservation, step: int, task: str) -> str:
    """Compose the user prompt for one step.

    The prompt consists of a step/task header, the scripted advisor debate
    from ``simulate_negotiation`` and a formatted snapshot of the observation,
    followed by the instruction to answer with JSON only.

    Args:
        obs: Current environment observation.
        step: Current step number, shown as ``STEP <step>/<MAX_STEPS>``.
        task: Task name shown in the header.

    Returns:
        The complete prompt text.
    """
    # Empty temperature lists fall back to 0.0 instead of failing.
    if obs.max_cpu_temps_c:
        hottest = max(obs.max_cpu_temps_c)
    else:
        hottest = 0.0

    if obs.mean_cpu_temps_c:
        mean_temp = sum(obs.mean_cpu_temps_c) / len(obs.mean_cpu_temps_c)
    else:
        mean_temp = 0.0

    debate = simulate_negotiation(obs, step)

    sections = [
        f"STEP {step}/{MAX_STEPS} | Task: {task}",
        "",
        f"{debate}",
        "",
        "--- CURRENT ENVIRONMENT STATE ---",
        f"- Thermal : max_cpu={hottest:.1f}°C  avg_cpu={mean_temp:.1f}°C",
        f"- Cooling : PUE={obs.pue:.3f}  setpoint={obs.crac_supply_temp_c:.1f}°C  fans={obs.avg_fan_speed_pct:.0f}%  chillers={obs.num_active_chillers}",
        f"- Grid    : price=${obs.energy_price_per_kwh:.3f}/kWh  DR={obs.demand_response_signal}  ambient={obs.ambient_temp_c:.1f}°C",
        f"- Batch   : {obs.pending_batch_jobs} jobs pending  off_peak={obs.off_peak_window}",
        "",
        "You have read both rounds of debate above. Resolve the conflict and respond with JSON only.",
    ]
    return "\n".join(sections)


def _extract_json(raw: str) -> dict:
    """Robustly parse JSON from LLM, handling markdown and truncation."""
    if not raw: return {}
    
    # 1. Strip markdown code blocks
    cleaned = re.sub(r"```(?:json)?\s*(.*?)\s*```", r"\1", raw, flags=re.DOTALL)
    cleaned = cleaned.strip()
    
    # 2. Try direct parse
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass
    
    # 3. Attempt to fix truncated JSON (adding missing closing braces)
    # This handles the case where the LLM repeats itself and gets cut off
    for _ in range(5): 
        cleaned += "\n}"
        try:
            return json.loads(cleaned)
        except:
            continue
            
    # 4. Fallback: search for first { and last }
    m = re.search(r'\{.*\}', raw, re.DOTALL)
    if m:
        try:
            return json.loads(m.group())
        except:
            pass
            
    return {}


def get_llm_action(
    client: OpenAI,
    obs: ThermalGridRlAgentObservation,
    step: int,
    task_id: ThermalGridTaskID,
    model: str
) -> tuple:
    """Ask the LLM for a control action, retrying up to three times.

    Each attempt sends the system prompt plus the per-step user prompt
    (temperature 0.3 on the first attempt, 0.7 afterwards), extracts JSON from
    the reply and, if it contains ``crac_setpoint_c``, clamps the values into
    their allowed ranges, builds the action and passes it through
    ``OversightAgent.enforce``. Any exception during an attempt is reported
    and the next attempt is made.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current environment observation.
        step: Current step number.
        task_id: Task being run; its ``value`` is shown in the prompt.
        model: Model identifier passed to the API.

    Returns:
        ``(action, user_prompt, raw_reply, raw_attempts, parsed_json)``.

    Raises:
        ValueError: If no attempt produced a usable reply.
    """
    user_prompt = build_user_message(obs, step, task_id.value)
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user",   "content": user_prompt},
    ]

    raw_attempts = []

    for attempt_no in (1, 2, 3):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=chat_messages,
                max_tokens=1024,
                temperature=0.7 if attempt_no > 1 else 0.3,
            )

            reply = completion.choices[0].message.content or "{}"
            raw_attempts.append(reply)

            payload = _extract_json(reply)
            if not payload or "crac_setpoint_c" not in payload:
                # Nothing usable in this reply; try again.
                continue

            # Clamp every field into the range promised by the system prompt.
            setpoint = max(12.0, min(27.0, float(payload.get("crac_setpoint_c", 18.0))))
            requested_fans = payload.get("fan_speeds_pct", [70.0] * 10)
            fan_speeds = [max(20.0, min(100.0, float(speed))) for speed in requested_fans]
            chiller_count = max(1, min(4, int(payload.get("num_active_chillers", 2))))

            # Wrong list length: broadcast the first value (or 70.0) to 10 fans.
            if len(fan_speeds) != 10:
                fill_value = fan_speeds[0] if fan_speeds else 70.0
                fan_speeds = [fill_value] * 10

            proposed = ThermalGridRlAgentAction(
                crac_setpoint_c=setpoint,
                fan_speeds_pct=fan_speeds,
                num_active_chillers=chiller_count,
            )
            checked = OversightAgent().enforce(obs, proposed)

            return checked, user_prompt, reply, raw_attempts, payload

        except Exception as e:
            print(f"[ERROR] LLM Attempt {attempt_no} failed: {e}")
            logger.warning(f"Step {step} Attempt {attempt_no} failed: {e}")

    raise ValueError(f"Unparseable LLM response after 3 attempts: {raw_attempts}")

async def run_inference(task_id: ThermalGridTaskID, model: str, train_mode: bool = False) -> None:
    """Run one episode of *task_id* and emit the [START]/[STEP]/[END] records.

    The environment is reset, then for up to ``MAX_STEPS`` steps an action is
    requested from the LLM (in a worker thread), applied with ``env.step`` and
    logged. The loop ends when the environment reports ``done``, when a step
    raises, or (inference mode only) when the last ``EARLY_STOP_CONSEC``
    rewards are all at least ``EARLY_STOP_REWARD``.

    Args:
        task_id: Task to run; its ``value`` is passed to the environment,
            the grader and the log lines.
        model: Model identifier used for every LLM call.
        train_mode: When true, early stopping is disabled and the collected
            steps are appended to ``all_prompts.jsonl`` and
            ``expert_trajectories.jsonl``.

    Raises:
        RuntimeError: If ``env.reset()`` fails (the original exception is
            chained). The [END] record is still printed in that case.
    """
    import traceback

    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    if train_mode:
        print("[MODE] TRAIN — collecting trajectories")
    else:
        print("[MODE] INFERENCE — early stopping enabled")

    log_start(task=task_id.value, env=BENCHMARK, model=model)

    env    = ThermalGridRlAgentEnv(base_url=ENV_URL)
    # NOTE(review): the grader is never handed any step result in this
    # function; presumably get_thermal_grid_score() obtains episode state some
    # other way — confirm against ThermalGridGrader.
    grader = ThermalGridGrader(task_id=task_id.value)

    rewards: List[float] = []
    steps_taken = 0
    success = False
    # Must exist before the try block: the `finally` clause below reports it,
    # and a failure in reset used to hit an unbound `score` there, replacing
    # the real error with UnboundLocalError and dropping the [END] line.
    score = 0.0

    episode_buffer = []

    try:
        # ---- Reset the environment ----
        try:
            reset_result = await env.reset(task_id=task_id.value)
        except Exception as e:
            raise RuntimeError(f"env.reset() failed: {e}") from e

        observation = reset_result.observation

        # ---- Main loop ----
        for step in range(1, MAX_STEPS + 1):
            try:
                # The OpenAI client call is blocking, so run it off the event loop.
                action, prompt, raw, raw_attempts, parsed_data = await asyncio.to_thread(
                    get_llm_action, client, observation, step, task_id, model
                )

                # Final (post-oversight) action, as logged and stored.
                action_dict = {
                    "crac_setpoint_c": action.crac_setpoint_c,
                    "fan_speeds_pct": action.fan_speeds_pct,
                    "num_active_chillers": action.num_active_chillers,
                }
                action_str = json.dumps(action_dict, separators=(",", ":"))

                step_result = await env.step(action)

                reward = float(step_result.reward or 0.0)
                done   = step_result.done

                rewards.append(reward)
                steps_taken = step

                # Dataset entry. `observation` is still the state the action
                # was chosen for; it is advanced only after this record.
                episode_buffer.append({
                    "step": step,
                    "prompt": prompt,
                    "response": action_dict,
                    "reward": reward,
                    "oversight_triggered": action.metadata.get("oversight_triggered", False),
                    "raw_response": raw,
                    "raw_attempts": raw_attempts,
                    "parsed_action": parsed_data,
                    "state": {
                        "max_cpu": max(observation.max_cpu_temps_c) if observation.max_cpu_temps_c else 0.0,
                        "avg_cpu": sum(observation.mean_cpu_temps_c)/len(observation.mean_cpu_temps_c) if observation.mean_cpu_temps_c else 0.0,
                        "pue": observation.pue,
                        "energy_price": observation.energy_price_per_kwh,
                        "ambient": observation.ambient_temp_c,
                        "pending_jobs": observation.pending_batch_jobs,
                        "off_peak": observation.off_peak_window
                    }
                })

                log_step(step=step, action=action_str, reward=reward, done=done, error=None)

                # Move to the next state.
                observation = step_result.observation

                # Early stop (inference only): last N rewards all high enough.
                if not train_mode:
                    recent = rewards[-EARLY_STOP_CONSEC:] if len(rewards) >= EARLY_STOP_CONSEC else []
                    if recent and all(r >= EARLY_STOP_REWARD for r in recent):
                        print("[EARLY STOP] Stable high reward achieved")
                        break

                if done:
                    break

            except Exception as e:
                # A failed step ends the episode but still gets a [STEP] record.
                print("\n[ERROR] Exception in loop:")
                traceback.print_exc()
                log_step(step=step, action="{}", reward=0.0, done=False, error=str(e))
                break

        # ---- Final score, clamped to [0, 1] ----
        score = min(max(grader.get_thermal_grid_score(), 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

        # ---- Save dataset (train mode only) ----
        if train_mode:
            # 1. All prompts, for RL exploration.
            with open("all_prompts.jsonl", "a", encoding="utf-8") as f:
                for step_data in episode_buffer:
                    f.write(json.dumps({"prompt": step_data["prompt"]}) + "\n")

            # 2. Full trajectories.
            with open("expert_trajectories.jsonl", "a", encoding="utf-8") as f:
                for step_data in episode_buffer:
                    f.write(json.dumps({
                        "step": step_data["step"],
                        "prompt": step_data["prompt"],
                        "response": step_data["response"],
                        "reward": step_data["reward"],
                        "oversight_triggered": step_data["oversight_triggered"]
                    }) + "\n")

    finally:
        try:
            await env.close()
        except Exception as e:
            # Closing is best-effort, but no longer swallows cancellation or
            # KeyboardInterrupt the way a bare `except:` did.
            logger.warning("env.close() failed: %s", e)
        finally:
            # Always emit [END], whatever happened above.
            log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
async def main() -> None:
    """Parse the command line and run one episode per entry in ``TASKS``.

    Options:
        --model: Model identifier (defaults to ``DEFAULT_MODEL``).
        --train: Collect trajectories; any existing ``all_prompts.jsonl`` and
            ``expert_trajectories.jsonl`` are deleted first so the new
            collection run starts from empty files.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Thermal Grid RL Agent Inference")
    cli.add_argument("--model", type=str, default=DEFAULT_MODEL)
    cli.add_argument("--train", action="store_true", default=False)
    options = cli.parse_args()

    # Old trajectory data is cleared only when a new training run begins.
    if options.train:
        dataset_files = ("all_prompts.jsonl", "expert_trajectories.jsonl")
        for stale_path in filter(os.path.exists, dataset_files):
            os.remove(stale_path)

    for task_id in TASKS:
        await run_inference(task_id, options.model, train_mode=options.train)


# Script entry point: when executed directly, run main() via asyncio.run().
if __name__ == "__main__":
    asyncio.run(main())