# smart_grid_env / inference.py
# Author: Puneet Gopinath
# Last commit: a34efea (unverified) "fix: small changes for temperature"
# File size: 5.67 kB
import asyncio
import os
import textwrap
from typing import List, Optional
from openai import OpenAI
from client import SmartGridEnv
from models import SmartGridAction
# --- Runtime configuration, read from environment variables ---
# Docker image handed to SmartGridEnv.from_docker_image() in main().
IMAGE_NAME = os.environ.get("IMAGE_NAME")
# HF_TOKEN wins when set (and non-empty); otherwise fall back to API_KEY.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
# `or` (not a get() default) so that an empty variable also selects the fallback.
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
TASK_NAME = os.environ.get("TASK", "balanced_grid_easy")
BENCHMARK = os.environ.get("BENCHMARK", "smart_grid")

# --- Episode and sampling settings ---
MAX_STEPS = 24  # upper bound on env.step() calls per episode
TEMPERATURE = 0.1
MAX_TOKENS = 50
SUCCESS_SCORE_THRESHOLD = 0.6  # minimum normalized score (in [0, 1]) counted as success

# Used to normalize the summed episode reward into a score.
# NOTE(review): 0.1 is presumably the env's per-step reward ceiling — confirm.
_MAX_REWARD_PER_STEP = 0.1
MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP
# System prompt sent as the "system" message of every chat completion in
# get_action(). It asks for a bare reply of four comma-separated numbers
# (supply_r1, supply_r2, supply_r3, charge_battery); get_action() then parses
# the reply by splitting it on commas.
# NOTE(review): this text is model input, so any edit changes agent behavior.
# "Use battery with brains" is vague and "represents" should read "represent";
# consider rewording deliberately and re-running the evaluation afterwards.
SYSTEM_PROMPT = textwrap.dedent(
"""
You are controlling a smart grid environment.
You MUST output EXACTLY FOUR numbers separated by commas.
supply_r1, supply_r2, supply_r3, charge_battery
Even if charge_battery is zero, you must include it in the output. Do not omit any value.
STRICT RULES:
- Output ONLY numbers
- NO text, NO explanation
- EXACTLY 4 values
- Example: 10.5,20.0,15.0,-5.0
At each step:
1. You receive demand for three regions
2. You receive how much solar and wind power has been generated
3. You have a battery that can be charged or discharged (up to its capacity)
Your goal:
- Minimize unmet demand
- Avoid wasting energy
- Use battery with brains
Respond only with 4 numbers separated by commas, with no extra text:
supply_r1, supply_r2, supply_r3, charge_battery
If you do not follow format, the system will FAIL.
Rules:
- supply values are supposed to be non-negative and represents how much energy you allocate to a region
- charge_battery can be positive (to charge) or negative (to discharge) the battery within limits
- Do not output anything else
Example:
Demand: 20,30,25
Output: 20,30,25,0
"""
).strip()
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens an episode log (flushed immediately)."""
    fields = f"task={task} env={env} model={model}"
    print("[START] " + fields, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line (flushed immediately).

    ``reward`` is shown with two decimals, ``done`` is lower-cased
    (``true``/``false`` for bools), and a falsy ``error`` (``None`` or ``""``)
    is shown as ``null``.
    """
    parts = [
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the closing ``[END]`` line (flushed immediately).

    ``score`` is shown with three decimals and each per-step reward with two,
    comma-joined; an empty ``rewards`` list yields ``rewards=`` with no values.
    """
    reward_list = ",".join(map("{:.2f}".format, rewards))
    outcome = str(success).lower()
    print(f"[END] success={outcome} steps={steps} score={score:.3f} rewards={reward_list}", flush=True)
def build_user_prompt(obs) -> str:
    """Render observation ``obs`` as the per-step user message for the model.

    Reads ``obs.hour``, the three regional demands, solar and wind generation,
    and the battery level and capacity, then closes with the question asking
    for an action. Values are interpolated with their default ``str`` form.
    """
    sections = [
        f"Hour: {obs.hour}",
        "Demand:",
        f"- R1={obs.demand_r1}",
        f"- R2={obs.demand_r2}",
        f"- R3={obs.demand_r3}",
        "Generation:",
        f"- Solar={obs.solar_generation}",
        f"- Wind={obs.wind_generation}",
        "Battery:",
        f"- Level={obs.battery_level}",
        f"- Capacity={obs.battery_capacity}",
        "What action do you take? How much do you supply and what should be the battery action?",
    ]
    return textwrap.dedent("\n".join(sections)).strip()
def _parse_action_values(text: str) -> List[float]:
    """Parse a model reply into ``[supply_r1, supply_r2, supply_r3, charge_battery]``.

    Commas and newlines both separate values; surrounding whitespace and empty
    fields (e.g. from a trailing comma) are ignored. A reply with only three
    values gets ``0.0`` appended as the battery action.

    Raises:
        ValueError: if the reply does not contain three or four fields, or a
            field is not a number.
    """
    # Newlines must act as separators: deleting them would merge "30\n0" into "300".
    fields = [field.strip() for field in text.replace("\n", ",").split(",")]
    fields = [field for field in fields if field]
    # The model tends to omit the 4th value when it is zero, so default it to 0.0.
    if len(fields) == 3:
        fields.append("0.0")
    if len(fields) != 4:
        raise ValueError(f"Expected 4 comma-separated values, got {len(fields)}\nResponse: {text}")
    return [float(field) for field in fields]


def get_action(client: OpenAI, obs) -> SmartGridAction:
    """Ask the LLM for the next action given observation ``obs``.

    Sends ``SYSTEM_PROMPT`` plus ``build_user_prompt(obs)`` to ``MODEL_NAME``
    and parses the reply as ``supply_r1, supply_r2, supply_r3, charge_battery``.

    Failures of the API call, or of turning the reply into a
    ``SmartGridAction``, are not raised: a ``[DEBUG]`` line is printed and an
    all-zero action is returned so the episode can continue.
    """
    user_prompt = build_user_prompt(obs)
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        text = (completion.choices[0].message.content or "").strip()
    except Exception as exc:  # any client/network/auth error: fall back instead of aborting the run
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
    else:
        try:
            r1, r2, r3, bt = _parse_action_values(text)
            return SmartGridAction(
                supply_r1=r1,
                supply_r2=r2,
                supply_r3=r3,
                charge_battery=bt,
            )
        # Broad on purpose: besides parse ValueErrors, SmartGridAction itself
        # may reject the values with an exception type not visible here.
        except Exception as exc:
            print(f"[DEBUG] Could not turn model output into an action: {exc}", flush=True)
    # Fallback: take no action this step.
    return SmartGridAction(supply_r1=0.0, supply_r2=0.0, supply_r3=0.0, charge_battery=0.0)
async def main() -> None:
    """Run one smart-grid episode driven by the LLM and log it to stdout.

    Prints ``[START]``, then one ``[STEP]`` line per ``env.step()`` call (at
    most ``MAX_STEPS``), and finally ``[END]``. ``env.close()`` and the
    ``[END]`` line run in a ``finally`` clause, so they happen even if the
    episode raises; a failing ``close()`` only produces a ``[DEBUG]`` line.
    If the episode raises, ``[END]`` reports the default score ``0.0`` and
    ``success=false``.
    """
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env = await SmartGridEnv.from_docker_image(IMAGE_NAME)

    rewards: List[float] = []
    steps_taken = 0
    score, success = 0.0, False

    log_start(TASK_NAME, BENCHMARK, MODEL_NAME)
    try:
        result = await env.reset()  # OpenENV.reset()
        observation = result.observation
        for step_index in range(1, MAX_STEPS + 1):
            # The reset (or the previous step) may already have ended the episode.
            if result.done:
                break
            action = get_action(llm_client, observation)
            result = await env.step(action)
            observation = result.observation
            step_reward = result.reward or 0.0  # a missing (None) reward counts as 0.0
            finished = result.done
            rewards.append(step_reward)
            steps_taken = step_index
            log_step(step_index, str(action), step_reward, finished, None)
            if finished:
                break

        # Normalize by the assumed maximum total reward, then clamp into [0, 1].
        normalized = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
        score = min(max(normalized, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
        log_end(success, steps_taken, score, rewards)
if __name__ == "__main__":
    # Script entry point: run main() (a single episode) on a new asyncio event loop.
    asyncio.run(main())