"""Inference runner for the smart-grid benchmark.

Drives one ``SmartGridEnv`` episode with an OpenAI-compatible chat model:
each step the observation is rendered into a prompt, the model's reply is
parsed into a ``SmartGridAction``, and progress is printed as single-line
``[START]`` / ``[STEP]`` / ``[END]`` records.
"""

import asyncio
import math
import os
import textwrap
from typing import List, Optional

from openai import OpenAI

from client import SmartGridEnv
from models import SmartGridAction

IMAGE_NAME = os.getenv("IMAGE_NAME")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")

API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
TASK_NAME = os.getenv("TASK", "balanced_grid_easy")
BENCHMARK = os.getenv("BENCHMARK", "smart_grid")

MAX_STEPS = 24
TEMPERATURE = 0.1
MAX_TOKENS = 50
SUCCESS_SCORE_THRESHOLD = 0.6  # normalized score in [0, 1]

# Assumed per-step reward ceiling, used only to normalise the episode total
# into [0, 1]. NOTE(review): confirm this matches the environment's reward.
_MAX_REWARD_PER_STEP = 0.1
MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP

SYSTEM_PROMPT = textwrap.dedent(
    """
    You are controlling a smart grid environment.
    You MUST output EXACTLY FOUR numbers separated by commas.
    supply_r1, supply_r2, supply_r3, charge_battery
    Even if charge_battery is zero, you must include it in the output. Do not omit any value.

    STRICT RULES:
    - Output ONLY numbers
    - NO text, NO explanation
    - EXACTLY 4 values
    - Example: 10.5,20.0,15.0,-5.0

    At each step:
    1. You receive demand for three regions
    2. You receive how much solar and wind power has been generated
    3. You have a battery that can be charged or discharged (up to its capacity)

    Your goal:
    - Minimize unmet demand
    - Avoid wasting energy
    - Use battery with brains

    Respond only with 4 numbers separated by commas, with no extra text:
    supply_r1, supply_r2, supply_r3, charge_battery

    If you do not follow format, the system will FAIL.

    Rules:
    - supply values are supposed to be non-negative and represents how much energy you allocate to a region
    - charge_battery can be positive (to charge) or negative (to discharge) the battery within limits
    - Do not output anything else

    Example:
    Demand: 20,30,25
    Output: 20,30,25,0
    """
).strip()


def log_start(task: str, env: str, model: str) -> None:
    """Print the single-line ``[START]`` record for an episode."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record; a missing error is rendered as ``null``."""
    error_val = error if error else "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` record with the per-step rewards list."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)


def build_user_prompt(obs) -> str:
    """Render an environment observation as the per-step user message.

    ``obs`` must expose ``hour``, ``demand_r1..3``, ``solar_generation``,
    ``wind_generation``, ``battery_level`` and ``battery_capacity``.
    """
    return textwrap.dedent(
        f"""
        Hour: {obs.hour}
        Demand:
        - R1={obs.demand_r1}
        - R2={obs.demand_r2}
        - R3={obs.demand_r3}
        Generation:
        - Solar={obs.solar_generation}
        - Wind={obs.wind_generation}
        Battery:
        - Level={obs.battery_level}
        - Capacity={obs.battery_capacity}
        What action do you take? How much do you supply and what should be the battery action?
        """
    ).strip()


def _parse_action_values(text: str) -> List[float]:
    """Parse the model reply into ``[supply_r1, supply_r2, supply_r3, charge_battery]``.

    Values may be split across lines; line breaks are treated as separators
    rather than deleted, so ``"10\\n20\\n30\\n0"`` is not glued into one number.
    A missing fourth value defaults to ``0.0`` because the model tends to omit
    ``charge_battery`` when it is zero.

    Raises:
        ValueError: if the reply has fewer than 3 or more than 4 values, a
            value is not a number, or a value is NaN/infinite.
    """
    # Drop blank lines and stray commas at line edges, then join lines with
    # commas. Empty values *within* a line (e.g. "10,,20") still fail below.
    lines = (line.strip().strip(",") for line in text.splitlines())
    parts = ",".join(line for line in lines if line).split(",")

    if not 3 <= len(parts) <= 4:
        raise ValueError(f"Expected 4 comma-separated values, got {len(parts)}\nResponse: {text}")
    if len(parts) == 3:
        parts.append("0.0")

    values = [float(part) for part in parts]
    # float() accepts "nan"/"inf"; never forward those to the environment.
    if not all(math.isfinite(value) for value in values):
        raise ValueError(f"Non-finite value in response: {text}")
    return values


def _zero_action() -> SmartGridAction:
    """Fallback action used when no usable model output is available."""
    return SmartGridAction(supply_r1=0.0, supply_r2=0.0, supply_r3=0.0, charge_battery=0.0)


def get_action(client: OpenAI, obs) -> SmartGridAction:
    """Query the model for the next action given observation ``obs``.

    Never raises: if the request fails or the reply cannot be turned into an
    action, a ``[DEBUG]`` line is printed and an all-zero action is returned
    so the episode can continue.
    """
    user_prompt = build_user_prompt(obs)

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
    except Exception as exc:  # API boundary: any client/transport error
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return _zero_action()

    try:
        r1, r2, r3, bt = _parse_action_values(text)
        return SmartGridAction(supply_r1=r1, supply_r2=r2, supply_r3=r3, charge_battery=bt)
    except Exception as exc:  # bad format, or action model rejected the values
        print(f"[DEBUG] Could not parse model response: {exc}", flush=True)
        return _zero_action()


async def main() -> None:
    """Run one episode and always emit a closing ``[END]`` record.

    The environment is created inside the ``try`` so that a container
    start-up failure still produces ``[END]`` before the error propagates.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    loop = asyncio.get_running_loop()

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    env = None

    log_start(TASK_NAME, BENCHMARK, MODEL_NAME)

    try:
        env = await SmartGridEnv.from_docker_image(IMAGE_NAME)
        result = await env.reset()  # OpenENV.reset()
        obs = result.observation

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            # get_action performs a blocking HTTP request; run it in the
            # default executor so the event loop is not blocked meanwhile.
            action = await loop.run_in_executor(None, get_action, client, obs)

            result = await env.step(action)
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done

            rewards.append(reward)
            steps_taken = step
            log_step(step, str(action), reward, done, None)

            if done:
                break

        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
        log_end(success, steps_taken, score, rewards)


if __name__ == "__main__":
    asyncio.run(main())