ShreeshantXD committed on
Commit
2ede269
·
1 Parent(s): f316664

fix llm based reward output

Browse files
Files changed (1) hide show
  1. inference.py +41 -32
inference.py CHANGED
@@ -46,27 +46,16 @@ except ImportError:
46
  # ── Constants ──────────────────────────────────────────────────────────────
47
 
48
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
49
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
50
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
51
-
52
- # ── Environment Variable Handling ─────────────────────────────────────────
53
- # The LLM API credential is read from HF_TOKEN or OPENAI_API_KEY environment variables
54
- # and passed directly to the OpenAI client for initialization.
55
- # Primary: HF_TOKEN
56
- # Fallback: OPENAI_API_KEY (for local testing/development)
57
  HF_TOKEN = os.getenv("HF_TOKEN")
58
  OPENAI_API_KEY = HF_TOKEN
 
 
59
  DEFAULT_EPISODES = 1
60
  DEFAULT_SEED_BASE = 1000
61
  MAX_RETRIES = 3
62
- # 96 steps × 15 min = 24 h (must match env.EpisodeSteps)
63
  EPISODE_STEPS = 96
64
  LAST_STEP_INDEX = EPISODE_STEPS - 1
65
- SCORE_EPSILON = 1e-6
66
-
67
- REW_MIN = -8.0
68
- REW_MAX = 6.0
69
- REW_RANGE = REW_MAX - REW_MIN
70
 
71
  SYSPROMPT = """You are GridMind, an expert industrial energy management controller.
72
  You control a building's HVAC, thermal storage, batch job scheduling, and load shedding.
@@ -109,7 +98,7 @@ def extract_json_object(text: str) -> dict[str, Any] | None:
109
 
110
 
111
  def clamp_open_score(score: float) -> float:
112
- """Clamp score into strict open interval (0, 1)."""
113
  if score <= 0.0:
114
  return SCORE_EPSILON
115
  if score >= 1.0:
@@ -117,10 +106,16 @@ def clamp_open_score(score: float) -> float:
117
  return score
118
 
119
 
120
- def normalize_reward(raw_reward: float) -> float:
121
- """Normalize raw reward to (0, 1) using fixed theoretical range."""
122
- normalized = (raw_reward - REW_MIN) / REW_RANGE
123
- return clamp_open_score(normalized)
 
 
 
 
 
 
124
 
125
 
126
  # ── Environment client ───────────────────────────────────────────────────────
@@ -189,8 +184,6 @@ class LLMAgent:
189
  """OpenAI-compatible LLM agent that chooses actions given observations."""
190
 
191
  def __init__(self):
192
- # Initialize OpenAI client with credentials from HF_TOKEN (per hackathon spec)
193
- # The OPENAI_API_KEY variable contains the HF_TOKEN value passed by evaluators
194
  self.client = OpenAI(base_url=API_BASE_URL, api_key=OPENAI_API_KEY)
195
  self.model = MODEL_NAME
196
  self.fallback_mode = False
@@ -208,17 +201,23 @@ Current observation:
208
  - Thermal storage level: {obs.get('thermal_storage_level', 0.5):.2f} (0=empty, 1=full)
209
  - Process demand: {obs.get('process_demand', 15):.1f} kW
210
  - Current electricity price: ${obs.get('current_price', 0.10):.4f}/kWh
211
- - Grid stress signal: {obs.get('grid_stress_signal', 0):.3f} (>0.7 = critical, shed load!)
212
  - Carbon intensity: {obs.get('carbon_intensity', 300):.0f} gCO2/kWh
213
  - Hour of day: {obs.get('hour_of_day', 12)} (0=midnight, peak prices 8-12 and 17-21)
214
  - Pending batch job deadlines: {obs.get('batch_queue', [])}
215
  - Cumulative cost so far: ${obs.get('cumulative_cost', 0):.4f}
216
  - Episode step: {obs.get('step', 0)}/{LAST_STEP_INDEX}
217
 
 
 
 
 
 
218
  Strategy hints:
219
- - Charge thermal storage when price < $0.08/kWh, discharge when price > $0.15/kWh
 
 
220
  - Set HVAC low during peak prices (0.3-0.4) and use storage for temperature control
221
- - Shed 30-50% load if grid_stress_signal > 0.7
222
  - Schedule batch jobs early if deadline is close (slot 0 or 1)
223
 
224
  Respond with ONLY a JSON action:
@@ -347,6 +346,8 @@ def run_episode(
347
  cached_action = agent._default_action()
348
 
349
  step_rewards: list[float] = []
 
 
350
  success = False
351
  obs: dict[str, Any] = {}
352
 
@@ -383,18 +384,29 @@ def run_episode(
383
 
384
  obs = step_resp["observation"]
385
  raw_reward = float(step_resp["reward"])
386
- reward = normalize_reward(raw_reward)
387
  total_reward += raw_reward
388
  step_rewards.append(raw_reward)
 
 
 
 
 
 
389
  total_steps += 1
390
  done = bool(step_resp.get("done", False))
391
 
 
 
 
 
 
 
392
  action_json = json.dumps(action, separators=(',', ':'))
393
  last_action_error = step_resp.get("last_action_error")
394
  error_field = "null" if last_action_error is None else str(last_action_error)
395
  print(
396
  f"[STEP] step={total_steps} action={action_json} "
397
- f"reward={reward:.2f} done={'true' if done else 'false'} error={error_field}",
398
  flush=True
399
  )
400
 
@@ -429,10 +441,7 @@ def run_episode(
429
  elapsed = time.time() - start_time
430
  grade = env_client.grade()
431
 
432
- if step_rewards:
433
- normalized_rewards = [normalize_reward(r) for r in step_rewards]
434
- else:
435
- normalized_rewards = []
436
 
437
  episode_score = sum(normalized_rewards) / len(normalized_rewards) if normalized_rewards else SCORE_EPSILON
438
  episode_score = clamp_open_score(episode_score)
@@ -563,9 +572,9 @@ def main() -> None:
563
  parser.add_argument(
564
  "--llm-every",
565
  type=int,
566
- default=4,
567
  metavar="N",
568
- help="Reuse the same LLM action for N consecutive steps (default: 4).",
569
  )
570
  parser.add_argument(
571
  "--max-steps",
 
46
  # ── Constants ──────────────────────────────────────────────────────────────
47
 
48
  ENV_URL = os.getenv("ENV_URL", "http://localhost:7860")
 
 
 
 
 
 
 
 
49
  HF_TOKEN = os.getenv("HF_TOKEN")
50
  OPENAI_API_KEY = HF_TOKEN
51
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-7B-Instruct")
52
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
53
  DEFAULT_EPISODES = 1
54
  DEFAULT_SEED_BASE = 1000
55
  MAX_RETRIES = 3
 
56
  EPISODE_STEPS = 96
57
  LAST_STEP_INDEX = EPISODE_STEPS - 1
58
+ SCORE_EPSILON = 0.01
 
 
 
 
59
 
60
  SYSPROMPT = """You are GridMind, an expert industrial energy management controller.
61
  You control a building's HVAC, thermal storage, batch job scheduling, and load shedding.
 
98
 
99
 
100
  def clamp_open_score(score: float) -> float:
101
+ """Clamp score to strictly between 0 and 1 (never 0.0 or 1.0)."""
102
  if score <= 0.0:
103
  return SCORE_EPSILON
104
  if score >= 1.0:
 
106
  return score
107
 
108
 
109
def normalize_rewards(raw_rewards: list[float]) -> list[float]:
    """Min-max scale raw rewards into the open interval (0, 1).

    An empty input yields an empty list. A constant sequence (zero spread)
    carries no ordering information, so every entry maps to the neutral
    midpoint 0.5.
    """
    if not raw_rewards:
        return []
    lo, hi = min(raw_rewards), max(raw_rewards)
    spread = hi - lo
    if spread <= 0:
        # All rewards identical: nothing to rank, return the midpoint.
        return [0.5] * len(raw_rewards)
    scaled: list[float] = []
    for reward in raw_rewards:
        # clamp_open_score keeps the endpoints strictly inside (0, 1).
        scaled.append(clamp_open_score((reward - lo) / spread))
    return scaled
119
 
120
 
121
  # ── Environment client ───────────────────────────────────────────────────────
 
184
  """OpenAI-compatible LLM agent that chooses actions given observations."""
185
 
186
  def __init__(self):
 
 
187
  self.client = OpenAI(base_url=API_BASE_URL, api_key=OPENAI_API_KEY)
188
  self.model = MODEL_NAME
189
  self.fallback_mode = False
 
201
  - Thermal storage level: {obs.get('thermal_storage_level', 0.5):.2f} (0=empty, 1=full)
202
  - Process demand: {obs.get('process_demand', 15):.1f} kW
203
  - Current electricity price: ${obs.get('current_price', 0.10):.4f}/kWh
204
+ - Grid stress signal: {obs.get('grid_stress_signal', 0):.3f} (>0.7 = critical, MUST shed 0.2-0.5 load!)
205
  - Carbon intensity: {obs.get('carbon_intensity', 300):.0f} gCO2/kWh
206
  - Hour of day: {obs.get('hour_of_day', 12)} (0=midnight, peak prices 8-12 and 17-21)
207
  - Pending batch job deadlines: {obs.get('batch_queue', [])}
208
  - Cumulative cost so far: ${obs.get('cumulative_cost', 0):.4f}
209
  - Episode step: {obs.get('step', 0)}/{LAST_STEP_INDEX}
210
 
211
+ IMPORTANT RULES:
212
+ - thermal_charge_rate: use NEGATIVE (-0.5) to DISCHARGE storage, POSITIVE (+0.5) to CHARGE
213
+ - load_shed_fraction: MUST be 0.2-0.5 when grid_stress_signal > 0.7, otherwise 0.0
214
+ - shed load during grid stress to earn rewards, else keep at 0.0
215
+
216
  Strategy hints:
217
+ - Charge thermal storage (positive) when price < $0.08/kWh
218
+ - Discharge thermal storage (negative) when price > $0.15/kWh
219
+ - MUST shed load (0.2-0.5) when grid_stress_signal > 0.7
220
  - Set HVAC low during peak prices (0.3-0.4) and use storage for temperature control
 
221
  - Schedule batch jobs early if deadline is close (slot 0 or 1)
222
 
223
  Respond with ONLY a JSON action:
 
346
  cached_action = agent._default_action()
347
 
348
  step_rewards: list[float] = []
349
+ reward_min = float('inf')
350
+ reward_max = float('-inf')
351
  success = False
352
  obs: dict[str, Any] = {}
353
 
 
384
 
385
  obs = step_resp["observation"]
386
  raw_reward = float(step_resp["reward"])
 
387
  total_reward += raw_reward
388
  step_rewards.append(raw_reward)
389
+
390
+ if raw_reward < reward_min:
391
+ reward_min = raw_reward
392
+ if raw_reward > reward_max:
393
+ reward_max = raw_reward
394
+
395
  total_steps += 1
396
  done = bool(step_resp.get("done", False))
397
 
398
+ reward_range = reward_max - reward_min
399
+ if reward_range > 0:
400
+ normalized_reward = clamp_open_score((raw_reward - reward_min) / reward_range)
401
+ else:
402
+ normalized_reward = 0.5
403
+
404
  action_json = json.dumps(action, separators=(',', ':'))
405
  last_action_error = step_resp.get("last_action_error")
406
  error_field = "null" if last_action_error is None else str(last_action_error)
407
  print(
408
  f"[STEP] step={total_steps} action={action_json} "
409
+ f"reward={normalized_reward:.2f} done={'true' if done else 'false'} error={error_field}",
410
  flush=True
411
  )
412
 
 
441
  elapsed = time.time() - start_time
442
  grade = env_client.grade()
443
 
444
+ normalized_rewards = normalize_rewards(step_rewards)
 
 
 
445
 
446
  episode_score = sum(normalized_rewards) / len(normalized_rewards) if normalized_rewards else SCORE_EPSILON
447
  episode_score = clamp_open_score(episode_score)
 
572
  parser.add_argument(
573
  "--llm-every",
574
  type=int,
575
+ default=8,
576
  metavar="N",
577
+ help="Reuse the same LLM action for N consecutive steps (default: 8).",
578
  )
579
  parser.add_argument(
580
  "--max-steps",