rishabh16196 committed on
Commit
3e960de
·
verified ·
1 Parent(s): bbdf6a5

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +227 -67
inference.py CHANGED
@@ -38,7 +38,13 @@ from typing import Any, Dict, List, Optional
38
  from openai import OpenAI
39
 
40
  from traffic_light_env import TrafficLightAction, TrafficLightEnv
41
- from traffic_light_env.models import DIRECTION_NAMES, NUM_PHASES, TASK_NAMES
 
 
 
 
 
 
42
 
43
  IMAGE_NAME = os.getenv("IMAGE_NAME")
44
  API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
@@ -47,51 +53,45 @@ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
47
  MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
48
  BENCHMARK = "traffic_light_env"
49
  MAX_STEPS = 200
50
- TEMPERATURE = 0.3
51
- MAX_TOKENS = 64
 
 
 
 
 
 
52
 
53
  # Tasks to run. Override with TRAFFIC_LIGHT_TASKS env var (comma-separated).
54
  TASKS = os.getenv("TRAFFIC_LIGHT_TASKS", ",".join(TASK_NAMES)).split(",")
55
 
56
  SYSTEM_PROMPT = textwrap.dedent(
57
  """
58
- You are controlling a traffic light at a 4-way intersection with 4 directions
59
- (NS=north-to-south, SN=south-to-north, EW=east-to-west, WE=west-to-east),
60
- each with 2 lanes (8 lanes total).
61
-
62
- Your goal: minimize total vehicle waiting time by choosing the optimal phase.
63
-
64
- Available phases (pick one number 0-5):
65
- 0 = NS+SN corridor (both north-south directions green)
66
- 1 = EW+WE corridor (both east-west directions green)
67
- 2 = NS only green
68
- 3 = SN only green
69
- 4 = EW only green
70
- 5 = WE only green
71
-
72
- Phase switching incurs a 2-step yellow transition (no departures) and a -2.0
73
- reward penalty. Avoid unnecessary switching.
74
-
75
- SAFETY: Each lane has a mix of vehicle types (car, suv, bus, truck, motorcycle)
76
- with different stopping distances based on real physics. When you switch phases,
77
- vehicles in the 100m zone that can't stop in time are in the "dilemma zone":
78
- - Trucks: 37m stopping distance (37% of 100m zone at risk)
79
- - SUVs: 33m (33% at risk)
80
- - Buses: 30m (30% at risk)
81
- - Cars: 28m (28% at risk)
82
- - Motorcycles: 28m (28% at risk)
83
- Each dilemma-zone vehicle incurs a -1.5 reward penalty. Avoid switching when
84
- many heavy vehicles (trucks, buses) are in the green lanes' 100m zones.
85
-
86
- Strategy tips:
87
- - Corridor phases (0, 1) green 4 lanes at once — high throughput.
88
- - Single-direction phases (2-5) useful when one direction is much busier.
89
- - Consider 500m vehicles: they migrate to 100m soon.
90
- - For emergency vehicles, prioritize the direction containing the emergency.
91
- - Avoid switching when trucks/buses are in the 100m zone (high dilemma risk).
92
- - Minimize total switches — each costs yellow time + dilemma risk + penalty.
93
-
94
- Respond with ONLY a single digit: 0, 1, 2, 3, 4, or 5
95
  """
96
  ).strip()
97
 
@@ -121,6 +121,127 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
121
  )
122
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  # ---------------------------------------------------------------------------
125
  # Observation → LLM prompt
126
  # ---------------------------------------------------------------------------
@@ -142,15 +263,10 @@ def obs_to_summary(obs: Any) -> str:
142
  f"Total waiting: {obs.total_waiting}",
143
  f"Throughput so far: {obs.total_throughput}",
144
  ]
145
- # Show heavy vehicle counts in 100m zone (dilemma risk factors)
146
- v100 = obs.vehicles_100m
147
- heavy = {d: 0 for d in range(4)}
148
- dir_labels = ["NS", "SN", "EW", "WE"]
149
- for vt in ("truck", "bus", "suv"):
150
- for d in range(4):
151
- heavy[d] += v100.get(vt, [0, 0, 0, 0])[d]
152
- heavy_str = " ".join(f"{dir_labels[d]}:{heavy[d]}" for d in range(4))
153
- lines.append(f"Heavy vehicles (truck+bus+suv) at 100m — {heavy_str}")
154
  lines.append(f"Cumulative dilemma-zone vehicles: {obs.total_dilemma_vehicles:.1f}")
155
 
156
  if obs.emergency_direction >= 0:
@@ -163,6 +279,11 @@ def obs_to_summary(obs: Any) -> str:
163
  f"EMERGENCY vehicle in {dir_name} direction (use {phases_help}), "
164
  f"waiting {obs.emergency_wait} steps"
165
  )
 
 
 
 
 
166
  return "\n".join(lines)
167
 
168
 
@@ -178,7 +299,7 @@ def get_phase_from_llm(
178
  """Ask the LLM which phase to choose. Falls back to heuristic on failure."""
179
  user_prompt = obs_to_summary(obs)
180
  if history:
181
- user_prompt += "\n\nRecent history:\n" + "\n".join(history[-5:])
182
  user_prompt += "\n\nChoose phase (0-5):"
183
 
184
  try:
@@ -199,24 +320,41 @@ def get_phase_from_llm(
199
  except Exception as exc:
200
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
201
 
202
- return heuristic_phase(obs)
203
 
204
 
205
- def heuristic_phase(obs: Any) -> int:
206
- """Heuristic baseline: corridor for the busier axis, or target emergency."""
207
- # Emergency: green the direction containing the emergency vehicle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  if obs.emergency_direction >= 0:
209
- d = obs.emergency_direction
210
- # Use corridor if possible (0 for NS/SN, 1 for EW/WE)
211
- if d <= 1:
212
- return 0 # NS+SN corridor
213
- else:
214
- return 1 # EW+WE corridor
215
 
216
- # Compare NS+SN axis vs EW+WE axis (100m weighted 1.0, 500m weighted 0.3)
217
- ns_sn = (obs.ns_100m + obs.sn_100m) + 0.3 * (obs.ns_500m + obs.sn_500m)
218
- ew_we = (obs.ew_100m + obs.we_100m) + 0.3 * (obs.ew_500m + obs.we_500m)
219
- return 0 if ns_sn >= ew_we else 1
220
 
221
 
222
  # ---------------------------------------------------------------------------
@@ -236,14 +374,26 @@ async def run_task(client: OpenAI, env: TrafficLightEnv, task: str) -> Dict[str,
236
  try:
237
  result = await env.reset(task=task)
238
  obs = result.observation
 
 
239
 
240
  for step in range(1, MAX_STEPS + 1):
241
  if result.done:
242
  break
243
 
244
- phase = get_phase_from_llm(client, obs, history)
245
- action = TrafficLightAction(phase=phase)
 
 
 
 
 
 
 
 
 
246
 
 
247
  result = await env.step(action)
248
  obs = result.observation
249
 
@@ -263,7 +413,8 @@ async def run_task(client: OpenAI, env: TrafficLightEnv, task: str) -> Dict[str,
263
  )
264
 
265
  history.append(
266
- f"Step {step}: phase={phase}, waiting={obs.total_waiting}, reward={reward:+.2f}"
 
267
  )
268
 
269
  if done:
@@ -315,6 +466,15 @@ async def main() -> None:
315
  f" [{status}] {r['task']:22s} score={r['score']:.4f} steps={r['steps']}",
316
  flush=True,
317
  )
 
 
 
 
 
 
 
 
 
318
  avg_score = (
319
  sum(r["score"] for r in all_results) / len(all_results)
320
  if all_results else 0.0
 
38
  from openai import OpenAI
39
 
40
  from traffic_light_env import TrafficLightAction, TrafficLightEnv
41
+ from traffic_light_env.models import (
42
+ DILEMMA_FRACTIONS,
43
+ DIRECTION_NAMES,
44
+ NUM_PHASES,
45
+ TASK_NAMES,
46
+ VEHICLE_TYPE_NAMES,
47
+ )
48
 
49
  IMAGE_NAME = os.getenv("IMAGE_NAME")
50
  API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("HF_TOKEN") or os.getenv("API_KEY")
 
53
  MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
54
  BENCHMARK = "traffic_light_env"
55
  MAX_STEPS = 200
56
+ TEMPERATURE = 0.2
57
+ MAX_TOKENS = 128
58
+
59
+ # Strategy parameters
60
+ MIN_HOLD_TIME = 8 # Minimum steps to hold a phase before considering switch
61
+ SWITCH_THRESHOLD = 1.8 # Opposing axis must be this many times busier to switch
62
+ LLM_CONSULT_INTERVAL = 10 # Ask LLM every N steps for strategic guidance
63
+ EMERGENCY_OVERRIDE = True # Immediately switch for emergency vehicles
64
 
65
  # Tasks to run. Override with TRAFFIC_LIGHT_TASKS env var (comma-separated).
66
  TASKS = os.getenv("TRAFFIC_LIGHT_TASKS", ",".join(TASK_NAMES)).split(",")
67
 
68
  SYSTEM_PROMPT = textwrap.dedent(
69
  """
70
+ You are a traffic light controller at a 4-way intersection. 4 directions
71
+ (NS, SN, EW, WE) with 2 lanes each (8 total). You pick one of 6 phases:
72
+
73
+ 0 = NS+SN corridor (4 lanes green — best throughput for N-S axis)
74
+ 1 = EW+WE corridor (4 lanes green — best throughput for E-W axis)
75
+ 2 = NS only 3 = SN only 4 = EW only 5 = WE only
76
+
77
+ CRITICAL RULES — switching phases costs 2 dead steps (yellow) + dilemma-zone
78
+ risk (vehicles that can't stop safely). Every unnecessary switch HURTS your score.
79
+
80
+ DECISION FRAMEWORK:
81
+ 1. If currently in yellow transition → keep the pending phase (no choice).
82
+ 2. If emergency vehicle present → switch to its corridor ONCE, then hold.
83
+ 3. If held current phase < 8 steps → KEEP current phase (too early to switch).
84
+ 4. Only switch if opposing axis queue is >1.8× current axis queue.
85
+ 5. Prefer corridor phases (0 or 1) for maximum throughput.
86
+ 6. Use single-direction phases (2-5) ONLY if one direction has >3× its opposite.
87
+
88
+ Scoring: 40% waiting (lower=better), 40% throughput (higher=better), 20% safety
89
+ (fewer dilemma vehicles=better). The fixed-timer baseline scores 0.81 by switching
90
+ every 10 steps. You should switch LESS often than that on balanced traffic.
91
+
92
+ Respond: one line with the phase digit (0-5), then a brief reason.
93
+ Format: <digit> <reason>
94
+ Example: 0 NS+SN corridor has more vehicles, hold current phase
 
 
 
 
 
 
 
 
 
 
 
 
95
  """
96
  ).strip()
97
 
 
121
  )
122
 
123
 
124
+ # ---------------------------------------------------------------------------
125
+ # Dilemma risk estimation
126
+ # ---------------------------------------------------------------------------
127
+
128
def estimate_dilemma_risk(obs: Any, green_dirs: List[int]) -> float:
    """Estimate how many vehicles would be in the dilemma zone if we switch now.

    For every direction in ``green_dirs`` and every vehicle type in
    ``VEHICLE_TYPE_NAMES``, the 100m-zone count for that type/direction is
    multiplied by the type's entry in ``DILEMMA_FRACTIONS`` and accumulated.

    Args:
        obs: Observation exposing ``vehicles_100m``, a mapping from vehicle
            type name to a per-direction sequence of counts.
        green_dirs: Direction indices whose lanes are currently green.

    Returns:
        The weighted vehicle count as a float; ``0.0`` when ``green_dirs``
        is empty or no vehicles are present.
    """
    v100 = obs.vehicles_100m
    risk = 0.0
    for d in green_dirs:
        for vt in VEHICLE_TYPE_NAMES:
            # Types missing from the observation count as zero in every direction.
            count = v100.get(vt, [0, 0, 0, 0])[d]
            # Only positive counts contribute; this also avoids looking up a
            # fraction for vehicle types that are not actually present.
            if count > 0:
                risk += count * DILEMMA_FRACTIONS[vt]
    return risk
139
+
140
+
141
def get_green_dirs(phase: int) -> List[int]:
    """Map a phase index to the direction indices that are green in it.

    Phases 0 and 1 are the two-direction corridors (directions 0+1 and 2+3);
    phases 2 through 5 each green a single direction (0 through 3). Any other
    phase value yields an empty list.
    """
    # Corridor phases first, then one single-direction phase per direction.
    table = {0: [0, 1], 1: [2, 3]}
    for single_dir in range(4):
        table[single_dir + 2] = [single_dir]

    if phase in table:
        return table[phase]
    return []
145
+
146
+
147
+ # ---------------------------------------------------------------------------
148
+ # Smart heuristic (primary decision maker)
149
+ # ---------------------------------------------------------------------------
150
+
151
def _pick_dominant_phase(
    first_queue: float,
    second_queue: float,
    first_phase: int,
    second_phase: int,
    fallback_phase: int,
    min_queue: float,
) -> int:
    """Return the single-direction phase whose queue dominates, else the fallback.

    A direction dominates when its queue is more than 3x its sibling's and
    also exceeds ``min_queue``.
    """
    if first_queue > 3 * second_queue and first_queue > min_queue:
        return first_phase
    if second_queue > 3 * first_queue and second_queue > min_queue:
        return second_phase
    return fallback_phase


def smart_heuristic(obs: Any, current_phase: int, time_in_phase: int) -> int:
    """Choose the next phase while keeping phase switches to a minimum.

    Decision order:
      1. During a yellow transition, return ``obs.active_phase`` (or
         ``current_phase`` when the active phase is negative).
      2. With an emergency vehicle present, hold the current phase if it is
         the single-direction phase for the emergency direction; otherwise
         return that direction's corridor phase (0 for directions 0-1,
         1 for directions 2-3).
      3. Hold the current phase until it has been held ``MIN_HOLD_TIME`` steps.
      4. Switch axis when the opposing axis load reaches the effective
         threshold (``SWITCH_THRESHOLD`` raised by 0.1 per unit of estimated
         dilemma risk), preferring a single-direction phase when one
         direction dominates (>3x sibling and >10 vehicles at 100m).
      5. After ``MIN_HOLD_TIME + 4`` steps in a corridor phase, narrow to a
         single direction when it dominates (>3x sibling and >15 at 100m).
      6. Otherwise keep the current phase.

    Per the original author's note, the fixed-timer baseline (switching every
    10 steps) scores 0.81; this heuristic aims to beat it by switching less.

    Args:
        obs: Observation exposing ``yellow_remaining``, ``active_phase``,
            ``emergency_direction``, the per-direction ``*_100m`` / ``*_500m``
            counts and ``vehicles_100m``.
        current_phase: Phase the caller considers currently selected.
        time_in_phase: Number of steps the current phase has been held.

    Returns:
        The phase index to request next.
    """
    # During yellow nothing can change: report the active phase if it is valid.
    if obs.yellow_remaining > 0:
        return obs.active_phase if obs.active_phase >= 0 else current_phase

    # Emergency override takes precedence over the minimum hold time.
    emergency_dir = obs.emergency_direction
    if emergency_dir >= 0:
        # Phases 2-5 green directions 0-3 respectively. If that single-direction
        # phase is already active the emergency direction is green, so switching
        # to the corridor would only add a transition.
        if emergency_dir <= 3 and current_phase == emergency_dir + 2:
            return current_phase
        return 0 if emergency_dir <= 1 else 1

    # Too early to reconsider; skip the load computation entirely.
    if time_in_phase < MIN_HOLD_TIME:
        return current_phase

    # Axis loads: 100m queues count fully, 500m queues as future pressure (0.3).
    ns_sn_load = (obs.ns_100m + obs.sn_100m) + 0.3 * (obs.ns_500m + obs.sn_500m)
    ew_we_load = (obs.ew_100m + obs.we_100m) + 0.3 * (obs.ew_500m + obs.we_500m)

    current_green_dirs = get_green_dirs(current_phase)
    serves_ns = any(d in (0, 1) for d in current_green_dirs)
    serves_ew = any(d in (2, 3) for d in current_green_dirs)

    # An unrecognised phase (serving neither axis) is treated as serving N-S.
    if serves_ew and not serves_ns:
        current_load, opposing_load = ew_we_load, ns_sn_load
    else:
        current_load, opposing_load = ns_sn_load, ew_we_load

    if opposing_load <= 0:
        ratio = 0.0  # Opposing axis is empty
    elif current_load > 0:
        # Clamp the denominator so tiny current loads do not inflate the ratio.
        ratio = opposing_load / max(current_load, 1.0)
    else:
        ratio = 10.0  # Current axis is empty

    # Require a higher ratio when switching would strand many vehicles.
    dilemma_risk = estimate_dilemma_risk(obs, current_green_dirs)
    effective_threshold = SWITCH_THRESHOLD + (dilemma_risk * 0.1)

    if ratio >= effective_threshold:
        if serves_ns or (not serves_ew and ns_sn_load < ew_we_load):
            # Move to the E-W axis, narrowing to one direction if it dominates.
            return _pick_dominant_phase(obs.ew_100m, obs.we_100m, 4, 5, 1, 10)
        # Move to the N-S axis, narrowing to one direction if it dominates.
        return _pick_dominant_phase(obs.ns_100m, obs.sn_100m, 2, 3, 0, 10)

    # Within the current axis, narrow a long-held corridor to one direction
    # when that direction is very unbalanced.
    # NOTE(review): there is no path back from a single-direction phase to its
    # corridor other than an axis switch — confirm the sibling direction cannot
    # starve.
    if time_in_phase >= MIN_HOLD_TIME + 4:
        if current_phase == 0:
            return _pick_dominant_phase(
                obs.ns_100m, obs.sn_100m, 2, 3, current_phase, 15
            )
        if current_phase == 1:
            return _pick_dominant_phase(
                obs.ew_100m, obs.we_100m, 4, 5, current_phase, 15
            )

    return current_phase
243
+
244
+
245
  # ---------------------------------------------------------------------------
246
  # Observation → LLM prompt
247
  # ---------------------------------------------------------------------------
 
263
  f"Total waiting: {obs.total_waiting}",
264
  f"Throughput so far: {obs.total_throughput}",
265
  ]
266
+ # Dilemma risk info
267
+ green_dirs = get_green_dirs(obs.active_phase)
268
+ dilemma = estimate_dilemma_risk(obs, green_dirs)
269
+ lines.append(f"Dilemma risk if switching now: {dilemma:.1f} vehicles")
 
 
 
 
 
270
  lines.append(f"Cumulative dilemma-zone vehicles: {obs.total_dilemma_vehicles:.1f}")
271
 
272
  if obs.emergency_direction >= 0:
 
279
  f"EMERGENCY vehicle in {dir_name} direction (use {phases_help}), "
280
  f"waiting {obs.emergency_wait} steps"
281
  )
282
+
283
+ # Heuristic recommendation
284
+ heuristic_rec = smart_heuristic(obs, obs.active_phase, obs.time_in_phase)
285
+ lines.append(f"\nHeuristic recommends: phase {heuristic_rec} ({phase_desc.get(heuristic_rec, '?')})")
286
+
287
  return "\n".join(lines)
288
 
289
 
 
299
  """Ask the LLM which phase to choose. Falls back to heuristic on failure."""
300
  user_prompt = obs_to_summary(obs)
301
  if history:
302
+ user_prompt += "\n\nRecent actions:\n" + "\n".join(history[-5:])
303
  user_prompt += "\n\nChoose phase (0-5):"
304
 
305
  try:
 
320
  except Exception as exc:
321
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
322
 
323
+ return smart_heuristic(obs, obs.active_phase, obs.time_in_phase)
324
 
325
 
326
+ # ---------------------------------------------------------------------------
327
+ # Hybrid decision: heuristic + periodic LLM consultation
328
+ # ---------------------------------------------------------------------------
329
+
330
def decide_phase(
    client: OpenAI,
    obs: Any,
    history: List[str],
    step: int,
    current_phase: int,
    time_in_phase: int,
) -> int:
    """Pick the next phase using the heuristic, with periodic LLM consultation.

    - While a yellow transition is in progress the current phase is held.
    - Emergencies are always resolved by the heuristic.
    - On every ``LLM_CONSULT_INTERVAL``-th step, once the phase has been held
      for at least ``MIN_HOLD_TIME`` steps, the LLM is asked instead.
    - All remaining steps use the heuristic.
    """
    # Nothing to decide mid-transition.
    if obs.yellow_remaining > 0:
        return current_phase

    emergency_active = obs.emergency_direction >= 0
    llm_turn = (
        not emergency_active
        and step % LLM_CONSULT_INTERVAL == 0
        and time_in_phase >= MIN_HOLD_TIME
    )
    if llm_turn:
        # Strategic checkpoint: a switch is permitted, so let the LLM weigh in.
        return get_phase_from_llm(client, obs, history)

    # Emergencies and all ordinary steps go through the heuristic.
    return smart_heuristic(obs, current_phase, time_in_phase)
 
 
358
 
359
 
360
  # ---------------------------------------------------------------------------
 
374
  try:
375
  result = await env.reset(task=task)
376
  obs = result.observation
377
+ current_phase = 0 # Start at NS+SN corridor
378
+ time_in_phase = 0
379
 
380
  for step in range(1, MAX_STEPS + 1):
381
  if result.done:
382
  break
383
 
384
+ phase = decide_phase(
385
+ client, obs, history, step,
386
+ current_phase, time_in_phase,
387
+ )
388
+
389
+ # Track phase timing locally
390
+ if phase != current_phase:
391
+ time_in_phase = 0
392
+ current_phase = phase
393
+ else:
394
+ time_in_phase += 1
395
 
396
+ action = TrafficLightAction(phase=phase)
397
  result = await env.step(action)
398
  obs = result.observation
399
 
 
413
  )
414
 
415
  history.append(
416
+ f"Step {step}: phase={phase}, waiting={obs.total_waiting}, "
417
+ f"throughput={obs.total_throughput}, reward={reward:+.2f}"
418
  )
419
 
420
  if done:
 
466
  f" [{status}] {r['task']:22s} score={r['score']:.4f} steps={r['steps']}",
467
  flush=True,
468
  )
469
+ if r.get("grade_details"):
470
+ d = r["grade_details"]
471
+ print(
472
+ f" waiting={d.get('waiting_score', 0):.3f} "
473
+ f"throughput={d.get('throughput_score', 0):.3f} "
474
+ f"safety={d.get('safety_score', 0):.3f} "
475
+ f"dilemma={d.get('total_dilemma_vehicles', 0):.1f}",
476
+ flush=True,
477
+ )
478
  avg_score = (
479
  sum(r["score"] for r in all_results) / len(all_results)
480
  if all_results else 0.0