| from __future__ import annotations |
|
|
| import argparse |
| from collections import Counter |
| from dataclasses import asdict |
| from datetime import datetime, timezone |
| import difflib |
| import json |
| from pathlib import Path |
| from statistics import mean, median |
| import sys |
| from typing import Any |
|
|
| import numpy as np |
|
|
# Resolve the parent of this file's directory (``parents[1]``) and put it at the
# front of ``sys.path`` once, before the project imports that follow.
# NOTE(review): presumably this is the repository root that holds the
# ``district_llm``, ``env`` and ``training`` packages — confirm against the layout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
|
|
| from district_llm.eval import load_rows |
| from district_llm.inference import DistrictLLMInference |
| from district_llm.prompting import ( |
| build_chat_messages, |
| build_system_prompt, |
| build_user_prompt, |
| format_district_prompt, |
| format_district_prompt_from_user_content, |
| render_chat_prompt, |
| ) |
| from district_llm.repair import ( |
| RepairConfig, |
| parse_candidate_intersections_from_text, |
| ) |
| from district_llm.rl_guidance_wrapper import FixedRLPolicyAdapter |
| from district_llm.schema import ( |
| DISTRICT_STRATEGIES, |
| PHASE_BIASES, |
| PRIORITY_CORRIDORS, |
| CandidateIntersection, |
| CongestedIntersection, |
| DistrictAction, |
| DistrictStateSummary, |
| ) |
| from district_llm.summary_builder import DistrictStateSummaryBuilder |
| from env.observation_builder import ObservationConfig |
| from env.reward import RewardConfig |
| from env.traffic_env import EnvConfig |
| from training.cityflow_dataset import CityFlowDataset, ScenarioSpec |
| from training.train_local_policy import build_env |
|
|
|
|
# Exact key set a parsed action payload is compared against in
# ``analyze_generation_result``: keys missing from the payload are reported as
# "missing_required_field", keys outside this set as "extra_field_present".
REQUIRED_ACTION_KEYS = {
    "strategy",
    "priority_corridor",
    "target_intersections",
    "phase_bias",
    "duration_steps",
}
# Scalar field names that ``parse_summary_text`` compares against the fields it
# actually observed, to report "missing_scalar_fields" and "extra_scalar_fields".
# NOTE(review): the ordering presumably mirrors the rendered summary layout — confirm
# against the prompt builder; only membership is used in this file's visible code.
SUMMARY_SCALAR_ORDER = [
    "city_id",
    "district_id",
    "district_type",
    "scenario",
    "scenario_type",
    "decision_step",
    "sim_time",
    "intersection_count",
    "avg_queue",
    "max_queue",
    "total_queue",
    "avg_wait",
    "max_wait",
    "total_wait",
    "avg_outgoing_load",
    "max_outgoing_load",
    "total_outgoing_load",
    "recent_throughput",
    "queue_change",
    "wait_change",
    "throughput_change",
    "ns_queue",
    "ew_queue",
    "ns_wait",
    "ew_wait",
    "dominant_flow",
    "boundary_queue_total",
    "boundary_wait_total",
    "spillback_risk",
    "incident_flag",
    "construction_flag",
    "overload_flag",
    "event_flag",
]
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the runtime-diagnosis script.

    ``--model-path`` and ``--rl-checkpoint`` are mandatory; every other option
    has a default. The two boolean options use ``BooleanOptionalAction`` and so
    also accept a ``--no-...`` form.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    description = "Diagnose why district-LLM runtime guidance fails even when offline validation looks strong."
    parser = argparse.ArgumentParser(description=description)
    add = parser.add_argument
    on_off = argparse.BooleanOptionalAction

    # Model and policy checkpoints.
    add("--model-path", required=True)
    add("--rl-checkpoint", required=True)

    # Data locations and scenario selection.
    add("--val-jsonl", default="data/district_llm_dataset_v3/val.jsonl")
    add("--generated-root", default="data/generated")
    add("--splits-root", default="data/splits")
    add("--split", default="val", choices=("train", "val", "test"))
    add("--cities", nargs="+", default=None)
    add("--scenarios", nargs="+", default=None)

    # Budgets for the diagnostic run.
    add("--max-diagnostic-calls", type=int, default=20)
    add("--max-offline-examples", type=int, default=20)
    add("--max-episode-seconds", type=int, default=300)
    add("--max-new-tokens", type=int, default=128)

    add("--device", default=None)
    add("--output-dir", default="artifacts/llm_runtime_diagnosis")

    # Settings forwarded to RepairConfig by main().
    add("--allow-only-visible-candidates", action=on_off, default=True)
    add("--max-target-intersections", type=int, default=3)
    add("--fallback-on-empty-targets", action=on_off, default=True)
    add("--fallback-mode", choices=("heuristic", "hold", "none"), default="heuristic")

    return parser.parse_args()
|
|
|
|
def main() -> None:
    """Run the full diagnosis and write its artifacts.

    Steps, in order: parse the CLI arguments, create the output directory,
    build the inference wrapper with a ``RepairConfig`` taken from the CLI,
    collect runtime rows and offline rows, derive the failure examples, the
    summary report and the prompt comparison, write four files into the
    output directory (``diagnostic_rows.jsonl``, ``summary_report.json``,
    ``prompt_comparison.md``, ``validator_failure_examples.jsonl``) and print
    the summary report as JSON.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    inference = DistrictLLMInference(
        model_name_or_path=args.model_path,
        device=args.device,
        repair_config=RepairConfig(
            allow_only_visible_candidates=args.allow_only_visible_candidates,
            max_target_intersections=args.max_target_intersections,
            fallback_on_empty_targets=args.fallback_on_empty_targets,
            fallback_mode=args.fallback_mode,
        ),
    )

    runtime_rows = collect_runtime_rows(args=args, inference=inference)
    offline_rows = collect_offline_rows(args=args, inference=inference)
    all_rows = runtime_rows + offline_rows

    # One failure example per (row, prompt style) pair whose result is flagged
    # as a wrapper fallback.
    failure_examples = []
    for row in all_rows:
        for style_key in ("runtime_flat", "training_chat"):
            if row[style_key]["wrapper_would_fallback"]:
                failure_examples.append(flatten_failure_example(row, style_key))

    summary_report = build_summary_report(
        args=args,
        inference=inference,
        runtime_rows=runtime_rows,
        offline_rows=offline_rows,
    )
    prompt_comparison = render_prompt_comparison(
        runtime_rows=runtime_rows,
        offline_rows=offline_rows,
        summary_report=summary_report,
    )

    write_jsonl(output_dir / "diagnostic_rows.jsonl", all_rows)
    write_json(output_dir / "summary_report.json", summary_report)
    write_text(output_dir / "prompt_comparison.md", prompt_comparison)
    write_jsonl(output_dir / "validator_failure_examples.jsonl", failure_examples)

    print(json.dumps(summary_report, indent=2, sort_keys=True))
|
|
|
|
def collect_runtime_rows(
    args: argparse.Namespace,
    inference: DistrictLLMInference,
) -> list[dict[str, Any]]:
    """Collect diagnostic rows from live environment rollouts.

    For each resolved scenario an environment is built and stepped with the
    fixed RL policy; at every step a summary is built for each district (in
    sorted district-id order) and passed to ``diagnose_summary_call``.
    Collection stops as soon as ``args.max_diagnostic_calls`` rows exist.

    Args:
        args: Parsed CLI namespace (dataset roots, checkpoint, budgets).
        inference: Inference wrapper handed to ``diagnose_summary_call``.

    Returns:
        At most ``args.max_diagnostic_calls`` rows, in collection order.
    """
    call_budget = args.max_diagnostic_calls

    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    dataset.generate_default_splits()
    scenario_specs = resolve_scenario_specs(dataset=dataset, args=args)

    policy = FixedRLPolicyAdapter(
        checkpoint_path=args.rl_checkpoint,
        device=args.device,
    )
    # Start from the policy's env config (or the default) and override only the
    # episode length with the CLI value.
    base_config = policy.env_config or default_env_config()
    env_config = EnvConfig(
        simulator_interval=base_config.simulator_interval,
        decision_interval=base_config.decision_interval,
        min_green_time=base_config.min_green_time,
        thread_num=base_config.thread_num,
        max_episode_seconds=int(args.max_episode_seconds),
        observation=base_config.observation,
        reward=base_config.reward,
    )

    summary_builder = DistrictStateSummaryBuilder(
        top_k=3,
        candidate_limit=max(6, args.max_target_intersections),
    )

    collected: list[dict[str, Any]] = []
    for spec in scenario_specs:
        if len(collected) >= call_budget:
            break
        env = build_env(env_config, spec)
        observation_batch = env.reset()
        summary_builder.reset()
        done = False
        while not done and len(collected) < call_budget:
            summaries = summary_builder.build_all(env, observation_batch)
            # The loop guard guarantees at least one call is still available
            # here, so the slice bound is always positive.
            still_allowed = call_budget - len(collected)
            for district_id in sorted(summaries)[:still_allowed]:
                district_summary = summaries[district_id]
                collected.append(
                    diagnose_summary_call(
                        inference=inference,
                        summary=district_summary,
                        source="runtime_live",
                        city_id=spec.city_id,
                        scenario=spec.scenario_name,
                        district_id=district_id,
                        decision_step=int(district_summary.decision_step),
                        wrapper_mode="diagnose_llm_runtime",
                        max_new_tokens=args.max_new_tokens,
                    )
                )
            if len(collected) >= call_budget:
                # Budget spent: do not advance the simulation any further.
                break
            actions = policy.decide(observation_batch).actions
            observation_batch, _, done, _ = env.step(actions)
    return collected
|
|
|
|
def collect_offline_rows(
    args: argparse.Namespace,
    inference: DistrictLLMInference,
) -> list[dict[str, Any]]:
    """Run the diagnostic over held-out validation examples.

    Each loaded row must carry a ``messages`` list: the first two entries are
    reused as the training messages, the second entry's content is parsed back
    into a summary, and the third entry's content is JSON-decoded as the ground
    truth payload.

    Args:
        args: Parsed CLI namespace (``val_jsonl``, ``max_offline_examples``,
            ``max_new_tokens``).
        inference: Inference wrapper handed to ``diagnose_summary_call``.

    Returns:
        One diagnostic row per loaded example, in file order.
    """
    examples = load_rows(args.val_jsonl, max_examples=args.max_offline_examples)
    diagnosed: list[dict[str, Any]] = []
    for example_index, example in enumerate(examples):
        messages = example["messages"]
        prompt_messages = messages[:2]
        user_content = messages[1]["content"]
        summary, parse_stats = parse_summary_text(user_content)
        # Row-level identifiers win; the parsed summary only fills the gaps.
        diagnosed.append(
            diagnose_summary_call(
                inference=inference,
                summary=summary,
                source="offline_validation_runtime_codepath",
                city_id=str(example.get("city_id", summary.city_id)),
                scenario=str(example.get("scenario", summary.scenario_name)),
                district_id=str(example.get("district_id", summary.district_id)),
                decision_step=int(summary.decision_step),
                wrapper_mode="diagnose_llm_runtime",
                max_new_tokens=args.max_new_tokens,
                training_messages=prompt_messages,
                ground_truth_payload=json.loads(messages[2]["content"]),
                original_user_prompt=user_content,
                summary_parse=parse_stats,
                example_index=example_index,
            )
        )
    return diagnosed
|
|
|
|
def diagnose_summary_call(
    inference: DistrictLLMInference,
    summary: DistrictStateSummary,
    source: str,
    city_id: str,
    scenario: str,
    district_id: str,
    decision_step: int,
    wrapper_mode: str,
    max_new_tokens: int,
    training_messages: list[dict[str, str]] | None = None,
    ground_truth_payload: dict[str, Any] | None = None,
    original_user_prompt: str | None = None,
    summary_parse: dict[str, Any] | None = None,
    example_index: int | None = None,
) -> dict[str, Any]:
    """Diagnose one district summary under both prompt styles.

    Builds the flat runtime prompt and the chat-rendered training prompt for
    the same summary, runs ``run_prompt_diagnostic`` on each (flat first, then
    chat), and bundles prompts, comparison, summary features and results into
    one row.

    Args:
        inference: Inference wrapper; supplies the repair config and tokenizer.
        summary: District state being diagnosed.
        source: Label stored in the row's ``source`` field.
        city_id, scenario, district_id, decision_step, wrapper_mode: Stored
            verbatim in the row (``decision_step`` coerced with ``int``).
        max_new_tokens: Generation budget forwarded to each diagnostic run.
        training_messages: Chat messages to use; built from ``summary`` when
            missing or empty.
        ground_truth_payload: Stored in the row unchanged.
        original_user_prompt: User prompt text to use; built from ``summary``
            when missing or empty.
        summary_parse: Extra parse statistics merged into ``summary_features``.
        example_index: Stored in the row unchanged.

    Returns:
        The diagnostic row as a plain dict.
    """
    repair_config = inference.repair_config
    constraint_kwargs = {
        "max_target_intersections": repair_config.max_target_intersections,
        "allow_only_visible_candidates": repair_config.allow_only_visible_candidates,
    }

    if not training_messages:
        training_messages = build_chat_messages(summary, **constraint_kwargs)
    if original_user_prompt:
        runtime_user_prompt = original_user_prompt
    else:
        runtime_user_prompt = build_user_prompt(summary)

    runtime_flat_prompt = format_district_prompt(summary, **constraint_kwargs)
    training_chat_prompt = render_chat_prompt(
        training_messages,
        tokenizer=inference.tokenizer,
        add_generation_prompt=True,
    )
    runtime_flat_from_user_prompt = format_district_prompt_from_user_content(
        runtime_user_prompt,
        **constraint_kwargs,
    )

    # Generation order is kept: flat runtime prompt first, then training chat.
    style_results = {}
    for style_name, prompt_text in (
        ("runtime_flat", runtime_flat_prompt),
        ("training_chat", training_chat_prompt),
    ):
        style_results[style_name] = run_prompt_diagnostic(
            inference=inference,
            prompt_text=prompt_text,
            summary=summary,
            max_new_tokens=max_new_tokens,
            prompt_style=style_name,
        )

    if training_messages:
        training_system_prompt = training_messages[0]["content"]
    else:
        training_system_prompt = build_system_prompt(**constraint_kwargs)
    if len(training_messages) > 1:
        training_user_prompt = training_messages[1]["content"]
    else:
        training_user_prompt = runtime_user_prompt

    return {
        "source": source,
        "example_index": example_index,
        "city_id": city_id,
        "scenario": scenario,
        "district_id": district_id,
        "decision_step": int(decision_step),
        "wrapper_mode": wrapper_mode,
        "training_system_prompt": training_system_prompt,
        "training_user_prompt": training_user_prompt,
        "runtime_flat_prompt": runtime_flat_prompt,
        "runtime_flat_prompt_from_user_prompt": runtime_flat_from_user_prompt,
        "training_chat_prompt": training_chat_prompt,
        "prompt_comparison": compare_prompt_shapes(
            training_system_prompt=training_system_prompt,
            training_user_prompt=training_user_prompt,
            runtime_flat_prompt=runtime_flat_prompt,
            runtime_flat_from_user_prompt=runtime_flat_from_user_prompt,
            training_chat_prompt=training_chat_prompt,
        ),
        "summary_features": summary_features(runtime_user_prompt, summary_parse=summary_parse),
        "summary_state": summary.to_dict(),
        "runtime_flat": style_results["runtime_flat"],
        "training_chat": style_results["training_chat"],
        "ground_truth_payload": ground_truth_payload,
    }
|
|
|
|
def run_prompt_diagnostic(
    inference: DistrictLLMInference,
    prompt_text: str,
    summary: DistrictStateSummary,
    max_new_tokens: int,
    prompt_style: str,
) -> dict[str, Any]:
    """Generate from one prompt, parse the output and record diagnostics.

    Args:
        inference: Wrapper providing ``generate_raw`` and ``parse_action``.
        prompt_text: Fully rendered prompt passed to generation.
        summary: District state passed to parsing and to the analysis.
        max_new_tokens: Generation budget, also used for truncation detection.
        prompt_style: Label stored in the result (e.g. ``"runtime_flat"``).

    Returns:
        A dict with token lengths, the raw text, the payload before repair,
        the repaired action, the repair report and the derived failure
        diagnostics.
    """
    raw_text = inference.generate_raw(prompt=prompt_text, max_new_tokens=max_new_tokens)
    parse_outcome = inference.parse_action(raw_text, summary=summary)
    action, repair_report, parsed_payload, json_valid, schema_valid_before_repair = parse_outcome

    prompt_tokens = token_length(inference, prompt_text)
    output_tokens = token_length(inference, raw_text)

    analysis = analyze_generation_result(
        raw_text=raw_text,
        parsed_payload=parsed_payload,
        summary=summary,
        repair_report=repair_report,
        max_new_tokens=max_new_tokens,
        output_token_length=output_tokens,
        json_valid=json_valid,
        schema_valid_before_repair=schema_valid_before_repair,
    )

    result: dict[str, Any] = {
        "prompt_style": prompt_style,
        "prompt_token_length": prompt_tokens,
        "output_token_length": output_tokens,
        "prompt_near_model_limit": prompt_near_model_limit(inference, prompt_tokens),
        "raw_text": raw_text,
        "extracted_json_text": analysis["extracted_json_text"],
        "parsed_payload_before_repair": parsed_payload,
        "action_after_repair": action.to_dict(),
        "repair_report": repair_report.to_dict(),
        "json_valid": bool(json_valid),
        "schema_valid_before_repair": bool(schema_valid_before_repair),
        "wrapper_would_fallback": bool(analysis["wrapper_would_fallback"]),
        "failure_reasons": analysis["failure_reasons"],
        "candidate_diagnostics": analysis["candidate_diagnostics"],
        "possible_truncation": bool(analysis["possible_truncation"]),
    }
    return result
|
|
|
|
def analyze_generation_result(
    raw_text: str,
    parsed_payload: dict[str, Any] | None,
    summary: DistrictStateSummary,
    repair_report,
    max_new_tokens: int,
    output_token_length: int | None,
    json_valid: bool,
    schema_valid_before_repair: bool,
) -> dict[str, Any]:
    """Label everything that looks wrong with one raw generation.

    The checks cover the surrounding text (code fences, prefix/suffix around
    the outermost JSON object), the parsed payload's keys and field values,
    the requested target intersections versus the summary's candidate ids,
    and what the repair report says was changed.

    Args:
        raw_text: Text produced by the model.
        parsed_payload: Payload parsed from ``raw_text`` before repair, or
            ``None`` when nothing was parsed.
        summary: District state; only ``candidate_ids()`` is used here.
        repair_report: Object exposing ``invalid_ids_removed``,
            ``non_visible_ids_removed``, ``empty_after_filtering``,
            ``fallback_used`` and ``fallback_mode``.
        max_new_tokens: Generation budget used for the truncation heuristic.
        output_token_length: Token count of ``raw_text`` or ``None`` if unknown.
        json_valid: Whether the output parsed as JSON.
        schema_valid_before_repair: Whether the payload passed schema
            validation before repair.

    Returns:
        A dict with ``extracted_json_text``, sorted de-duplicated
        ``failure_reasons``, per-target ``candidate_diagnostics``,
        ``possible_truncation`` and ``wrapper_would_fallback``.
    """
    failure_reasons: list[str] = []
    extracted_json_text, prefix_text, suffix_text = extract_json_details(raw_text)
    if "```" in raw_text:
        failure_reasons.append("markdown_code_fence_present")
    if prefix_text.strip():
        failure_reasons.append("extra_prefix_text")
    if suffix_text.strip():
        failure_reasons.append("extra_suffix_text")
    if extracted_json_text is None:
        failure_reasons.append("no_json_object_found")
    if not json_valid:
        failure_reasons.append("json_parse_error")

    raw_targets: list[str] = []
    if parsed_payload is not None:
        payload_keys = set(parsed_payload)
        if REQUIRED_ACTION_KEYS - payload_keys:
            failure_reasons.append("missing_required_field")
        if payload_keys - REQUIRED_ACTION_KEYS:
            failure_reasons.append("extra_field_present")

        if parsed_payload.get("strategy") not in DISTRICT_STRATEGIES:
            failure_reasons.append("unknown_strategy")
        priority_corridor = parsed_payload.get("priority_corridor")
        # A missing/null corridor is tolerated; only a non-null unknown value is flagged.
        if priority_corridor is not None and priority_corridor not in PRIORITY_CORRIDORS:
            failure_reasons.append("unknown_priority_corridor")
        if parsed_payload.get("phase_bias") not in PHASE_BIASES:
            failure_reasons.append("unknown_phase_bias")

        duration_steps = parsed_payload.get("duration_steps")
        # bool is a subclass of int, so JSON true/false must be rejected
        # explicitly or it would be reported as a valid 0/1 duration.
        if isinstance(duration_steps, bool) or not isinstance(duration_steps, int):
            failure_reasons.append("invalid_duration_type")
        elif not 1 <= duration_steps <= 20:
            failure_reasons.append("invalid_duration_range")

        raw_target_payload = parsed_payload.get("target_intersections", [])
        if isinstance(raw_target_payload, list):
            raw_targets = [str(item) for item in raw_target_payload]
        elif isinstance(raw_target_payload, str):
            # A bare string is still inspected as a single target.
            failure_reasons.append("target_intersections_not_json_array")
            raw_targets = [raw_target_payload]
        else:
            failure_reasons.append("target_intersections_wrong_type")

    if not schema_valid_before_repair:
        failure_reasons.append("schema_validation_failed")
    if raw_text.strip() and not raw_text.rstrip().endswith("}"):
        failure_reasons.append("output_does_not_end_with_json")

    # Heuristic: an output that used the whole token budget may have been cut off.
    possible_truncation = output_token_length is not None and output_token_length >= max_new_tokens
    if possible_truncation:
        failure_reasons.append("possible_generation_truncation")

    candidate_ids = set(summary.candidate_ids())
    candidate_diagnostics = []
    for target in raw_targets:
        visible = target in candidate_ids
        well_formed = target.startswith("i_")
        candidate_diagnostics.append(
            {
                "target_intersection": target,
                "visible_candidate": visible,
                "valid_id_format": well_formed,
            }
        )
        if not well_formed:
            failure_reasons.append("invalid_target_id_format")
        # With no candidates listed there is no constraint to violate.
        if candidate_ids and not visible:
            failure_reasons.append("candidate_intersections_constraint_violation")

    if not raw_targets:
        failure_reasons.append("empty_target_intersections")
    if repair_report.invalid_ids_removed:
        failure_reasons.append("repair_removed_invalid_ids")
    if repair_report.non_visible_ids_removed:
        failure_reasons.append("repair_removed_non_visible_ids")
    if repair_report.empty_after_filtering:
        failure_reasons.append("repair_emptied_targets")
    if repair_report.fallback_used:
        failure_reasons.append(f"repair_used_fallback:{repair_report.fallback_mode}")

    wrapper_would_fallback = (
        not json_valid
        or not schema_valid_before_repair
        or bool(repair_report.fallback_used)
        or bool(repair_report.empty_after_filtering)
    )
    return {
        "extracted_json_text": extracted_json_text,
        "failure_reasons": sorted(set(failure_reasons)),
        "candidate_diagnostics": candidate_diagnostics,
        "possible_truncation": bool(possible_truncation),
        "wrapper_would_fallback": wrapper_would_fallback,
    }
|
|
|
|
def extract_json_details(raw_text: str) -> tuple[str | None, str, str]:
    """Split ``raw_text`` around its outermost ``{...}`` span.

    The span runs from the first ``{`` to the last ``}``; no JSON parsing is
    attempted.

    Returns:
        ``(json_text, prefix, suffix)``. When no such span exists (no opening
        brace, or no closing brace after it) the result is
        ``(None, raw_text, "")``.
    """
    open_at = raw_text.find("{")
    close_at = raw_text.rfind("}")
    # close_at > open_at >= 0 also rules out a missing closing brace (-1).
    span_found = open_at != -1 and close_at > open_at
    if not span_found:
        return None, raw_text, ""
    after_close = close_at + 1
    json_text = raw_text[open_at:after_close]
    prefix = raw_text[:open_at]
    suffix = raw_text[after_close:]
    return json_text, prefix, suffix
|
|
|
|
def parse_summary_text(summary_text: str) -> tuple[DistrictStateSummary, dict[str, Any]]:
    """Rebuild a ``DistrictStateSummary`` from rendered summary text.

    A leading ``### DISTRICT STATE`` header line is dropped. ``key: value``
    lines seen before the first section header are parsed as scalars; ``- ``
    bullet lines are collected under whichever of the two section headers
    (``top_congested_intersections:`` / ``candidate_intersections:``) was seen
    last. The ``scenario`` scalar is stored under ``scenario_name``.

    Returns:
        ``(summary, stats)`` where ``stats`` records the observed scalar order,
        missing and extra scalar fields relative to ``SUMMARY_SCALAR_ORDER``,
        the two section counts, the input length and the non-blank line count.
    """
    body = summary_text.strip()
    if body.startswith("### DISTRICT STATE"):
        # Keep only what follows the header line (empty if nothing follows).
        body = body.partition("\n")[2]
    lines = [raw.rstrip() for raw in body.splitlines() if raw.strip()]

    section_for_header = {
        "top_congested_intersections:": "top",
        "candidate_intersections:": "candidate",
    }
    bullets: dict[str, list[str]] = {"top": [], "candidate": []}
    payload: dict[str, Any] = {}
    observed_order: list[str] = []
    section = "scalars"
    for line in lines:
        if line in section_for_header:
            section = section_for_header[line]
            continue
        if line.startswith("- "):
            # Bullets seen before any section header are ignored.
            if section in bullets:
                bullets[section].append(line)
            continue
        if section == "scalars" and ": " in line:
            key, _, value = line.partition(": ")
            observed_order.append(key)
            payload[key] = parse_summary_scalar(key, value)

    scenario_fallback = payload.get("scenario_name", "")
    payload["scenario_name"] = payload.pop("scenario", scenario_fallback)

    payload["top_congested_intersections"] = [
        parse_top_congested_line(entry) for entry in bullets["top"] if entry != "- none"
    ]
    candidate_block = "\n".join(bullets["candidate"] or ["- none"])
    payload["candidate_intersections"] = parse_candidate_intersections_from_text(
        "candidate_intersections:\n" + candidate_block
    )

    summary = DistrictStateSummary.from_dict(payload)
    stats = {
        "observed_field_order": observed_order,
        "missing_scalar_fields": [name for name in SUMMARY_SCALAR_ORDER if name not in observed_order],
        "extra_scalar_fields": [name for name in observed_order if name not in SUMMARY_SCALAR_ORDER],
        "top_congested_count": len(payload["top_congested_intersections"]),
        "candidate_intersections_count": len(payload["candidate_intersections"]),
        "summary_text_length": len(summary_text),
        "line_count": len(lines),
    }
    return summary, stats
|
|
|
|
_SUMMARY_INT_KEYS = frozenset({"decision_step", "sim_time", "intersection_count"})
_SUMMARY_FLAG_KEYS = frozenset(
    {
        "spillback_risk",
        "incident_flag",
        "construction_flag",
        "overload_flag",
        "event_flag",
    }
)
_SUMMARY_TEXT_KEYS = frozenset(
    {
        "city_id",
        "district_id",
        "district_type",
        "scenario",
        "scenario_type",
        "dominant_flow",
    }
)
_SUMMARY_FLAG_WORDS = {"true": True, "false": False}


def parse_summary_scalar(key: str, value: str) -> Any:
    """Convert one ``key: value`` summary field to its Python value.

    Args:
        key: Field name as it appears in the summary text.
        value: Raw text after ``": "``.

    Returns:
        ``int`` for the step/time/count fields, ``bool`` for the flag fields,
        the unchanged string for the identifier-like fields, and ``float`` for
        every other key.

    Raises:
        ValueError: If a numeric or flag field holds text that cannot be
            converted.
    """
    if key in _SUMMARY_INT_KEYS:
        return int(value)
    if key in _SUMMARY_FLAG_KEYS:
        try:
            return bool(int(value))
        except ValueError:
            # Flags rendered as True/False (any case) instead of 1/0 would
            # otherwise abort the whole parse; unknown words still raise.
            word = value.strip().lower()
            if word in _SUMMARY_FLAG_WORDS:
                return _SUMMARY_FLAG_WORDS[word]
            raise
    if key in _SUMMARY_TEXT_KEYS:
        return value
    return float(value)
|
|
|
|
def parse_top_congested_line(line: str) -> dict[str, Any]:
    """Parse one ``- <id> key=value ...`` bullet of the top-congested section.

    The first whitespace-separated token after the two-character bullet prefix
    is the intersection id. Recognised keys are ``q``, ``w``, ``out``,
    ``phase`` and ``boundary``; tokens without ``=`` and unknown keys are
    ignored, and fields that are not present keep their zero/False defaults.

    Returns:
        The field dict (not the ``CongestedIntersection`` instance).
    """
    intersection_id, *attribute_tokens = line[2:].split()
    readers = {
        "q": ("queue_total", float),
        "w": ("wait_total", float),
        "out": ("outgoing_load", float),
        "phase": ("current_phase", int),
        "boundary": ("is_boundary", lambda text: text == "1"),
    }
    payload: dict[str, Any] = {
        "intersection_id": intersection_id,
        "queue_total": 0.0,
        "wait_total": 0.0,
        "outgoing_load": 0.0,
        "current_phase": 0,
        "is_boundary": False,
    }
    for token in attribute_tokens:
        short_key, separator, raw_value = token.partition("=")
        if not separator or short_key not in readers:
            continue
        field_name, convert = readers[short_key]
        payload[field_name] = convert(raw_value)
    # The instance is built and discarded.
    # NOTE(review): presumably a validation step for the field set — confirm.
    CongestedIntersection(**payload)
    return payload
|
|
|
|
def summary_features(summary_text: str, summary_parse: dict[str, Any] | None = None) -> dict[str, Any]:
    """Describe the surface shape of a rendered summary.

    Non-bullet lines contribute to ``field_order``: a ``key: value`` line adds
    ``key``, and a line ending in ``:`` adds the text before that colon.

    Args:
        summary_text: Rendered summary text.
        summary_parse: Optional extra statistics; its entries are merged last
            and override computed entries with the same key.

    Returns:
        A dict with length, line count, field order/count and the two section
        presence flags, plus everything from ``summary_parse``.
    """
    nonblank_lines = [raw.rstrip() for raw in summary_text.splitlines() if raw.strip()]

    field_order = []
    for line in nonblank_lines:
        if line.startswith("- "):
            continue
        name, separator, _ = line.partition(": ")
        if separator:
            field_order.append(name)
        elif line.endswith(":"):
            field_order.append(line[:-1])

    features: dict[str, Any] = {
        "summary_length_chars": len(summary_text),
        "summary_line_count": len(nonblank_lines),
        "field_order": field_order,
        "field_count": len(field_order),
        "has_candidate_intersections": "candidate_intersections:" in summary_text,
        "has_top_congested_intersections": "top_congested_intersections:" in summary_text,
    }
    features.update(summary_parse or {})
    return features
|
|
|
|
def compare_prompt_shapes(
    training_system_prompt: str,
    training_user_prompt: str,
    runtime_flat_prompt: str,
    runtime_flat_from_user_prompt: str,
    training_chat_prompt: str,
) -> dict[str, Any]:
    """List the structural differences between training and runtime prompts.

    Each difference is detected with a plain substring or prefix test on the
    given prompt texts. The two ``*_has_system_role`` entries are fixed values
    and are not derived from the inputs.

    Returns:
        A dict with the ``differences`` messages (in a fixed order) and
        boolean shape flags.
    """
    schema_header = "### DISTRICT ACTION SCHEMA"
    decision_header = "### DECISION"

    runtime_has_schema = schema_header in runtime_flat_prompt
    training_user_has_schema = schema_header in training_user_prompt
    runtime_has_decision = decision_header in runtime_flat_prompt

    # (condition, message) pairs; messages are emitted in this order.
    rules = (
        (
            training_system_prompt and "You are a district traffic coordinator" in training_system_prompt,
            "training uses an explicit system prompt with JSON rules",
        ),
        (
            runtime_flat_prompt.startswith(schema_header),
            "runtime flat prompt injects schema text into the single prompt body",
        ),
        (
            training_user_prompt.startswith("### DISTRICT STATE"),
            "training user message contains only district state",
        ),
        (
            runtime_has_decision,
            "runtime flat prompt appends an explicit decision header",
        ),
        (
            not training_user_has_schema,
            "training user message omits the schema header entirely",
        ),
        (
            runtime_flat_prompt != runtime_flat_from_user_prompt,
            "runtime flat prompt reconstructed from summary object differs from user prompt reconstruction",
        ),
    )
    differences = [message for applies, message in rules if applies]

    return {
        "differences": differences,
        "runtime_has_system_role": False,
        "training_has_system_role": True,
        "runtime_has_schema_header": runtime_has_schema,
        "training_user_has_schema_header": training_user_has_schema,
        "runtime_has_decision_header": runtime_has_decision,
        "training_chat_prompt_startswith_system": training_chat_prompt.startswith(("system:", "<")),
    }
|
|
|
|
def aggregate_prompt_results(rows: list[dict[str, Any]], prompt_style_key: str) -> dict[str, Any]:
    """Aggregate the per-row results of one prompt style into rates and means.

    Args:
        rows: Diagnostic rows as produced by ``diagnose_summary_call``.
        prompt_style_key: Row key holding the per-style result
            (``"runtime_flat"`` or ``"training_chat"``).

    Returns:
        An empty dict when ``rows`` is empty; otherwise validity, fallback,
        repair and truncation rates, mean token lengths, the 20 most common
        failure reasons and two per-target rates.
    """
    if not rows:
        return {}

    def share(items, predicate):
        # Fraction of items for which the predicate holds, via safe_ratio.
        return safe_ratio(sum(int(predicate(item)) for item in items), len(items))

    results = [row[prompt_style_key] for row in rows]
    reports = [result["repair_report"] for result in results]
    # Flatten the per-result target diagnostics into one list of target records.
    targets = [record for result in results for record in result["candidate_diagnostics"]]
    reason_counts = Counter(reason for result in results for reason in result["failure_reasons"])

    return {
        "num_examples": len(results),
        "json_valid_rate": share(results, lambda result: result["json_valid"]),
        "schema_valid_before_repair_rate": share(
            results, lambda result: result["schema_valid_before_repair"]
        ),
        "wrapper_would_fallback_rate": share(results, lambda result: result["wrapper_would_fallback"]),
        "repair_fallback_rate": share(reports, lambda report: report["fallback_used"]),
        "repair_changed_target_list_rate": share(
            reports, lambda report: report["raw_targets"] != report["repaired_targets"]
        ),
        "repair_emptied_targets_rate": share(reports, lambda report: report["empty_after_filtering"]),
        "mean_prompt_token_length": average([result["prompt_token_length"] for result in results]),
        "mean_output_token_length": average([result["output_token_length"] for result in results]),
        "possible_truncation_rate": share(results, lambda result: result["possible_truncation"]),
        "top_failure_reasons": dict(reason_counts.most_common(20)),
        "targets_outside_visible_candidate_rate": share(
            targets, lambda record: not record["visible_candidate"]
        ),
        "invalid_target_id_format_rate": share(targets, lambda record: not record["valid_id_format"]),
    }
|
|
|
|
def aggregate_summary_features(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarise the ``summary_features`` entries of a set of diagnostic rows.

    Returns:
        An empty dict when ``rows`` is empty; otherwise the example count,
        mean/median summary length, mean candidate and top-congested counts
        (missing counts are treated as 0) and the five most common field-order
        signatures with their counts.
    """
    if not rows:
        return {}

    features = [row["summary_features"] for row in rows]
    lengths = [feature["summary_length_chars"] for feature in features]
    candidate_counts = [feature.get("candidate_intersections_count", 0) for feature in features]
    top_counts = [feature.get("top_congested_count", 0) for feature in features]

    # Field orders are turned into tuples so they can be counted.
    signature_counts = Counter(tuple(feature["field_order"]) for feature in features)
    most_common_signatures = []
    for signature, count in signature_counts.most_common(5):
        most_common_signatures.append({"count": count, "field_order": list(signature)})

    return {
        "num_examples": len(features),
        "mean_summary_length_chars": average(lengths),
        "median_summary_length_chars": median_or_zero(lengths),
        "mean_candidate_intersections_count": average(candidate_counts),
        "mean_top_congested_count": average(top_counts),
        "field_order_signatures": {"most_common": most_common_signatures},
    }
|
|
|
|
def build_summary_report(
    args: argparse.Namespace,
    inference: DistrictLLMInference,
    runtime_rows: list[dict[str, Any]],
    offline_rows: list[dict[str, Any]],
) -> dict[str, Any]:
    """Assemble the top-level summary report.

    Aggregates both prompt styles over the runtime rows and the offline rows,
    then combines those aggregates with the run settings, the key answers and
    the ranked root causes.

    Args:
        args: Parsed CLI namespace (paths and ``max_new_tokens``).
        inference: Inference wrapper; its ``repair_config`` is serialised with
            ``dataclasses.asdict``.
        runtime_rows: Rows collected from live rollouts.
        offline_rows: Rows collected from validation examples.

    Returns:
        The report as a nested dict, stamped with the current UTC time.
    """
    # Keyed by the keyword names that rank_root_causes and build_key_answers expect.
    aggregates = {
        "runtime_flat_runtime": aggregate_prompt_results(runtime_rows, "runtime_flat"),
        "training_chat_runtime": aggregate_prompt_results(runtime_rows, "training_chat"),
        "runtime_flat_offline": aggregate_prompt_results(offline_rows, "runtime_flat"),
        "training_chat_offline": aggregate_prompt_results(offline_rows, "training_chat"),
    }
    root_causes = rank_root_causes(**aggregates)
    token_budget = int(args.max_new_tokens)

    report: dict[str, Any] = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "model_path": args.model_path,
        "rl_checkpoint": args.rl_checkpoint,
        "repair_config": asdict(inference.repair_config),
    }
    report["generation_settings"] = {
        "runtime_inference": {
            "max_new_tokens": token_budget,
            "do_sample": False,
            "prompt_style": "flat single prompt from format_district_prompt",
        },
        "offline_eval_style": {
            "max_new_tokens": token_budget,
            "do_sample": False,
            "prompt_style": "chat messages rendered via build_generation_prompt",
        },
    }
    report["runtime_live"] = {
        "summary_distribution": aggregate_summary_features(runtime_rows),
        "runtime_flat": aggregates["runtime_flat_runtime"],
        "training_chat": aggregates["training_chat_runtime"],
    }
    report["offline_validation_runtime_codepath"] = {
        "summary_distribution": aggregate_summary_features(offline_rows),
        "runtime_flat": aggregates["runtime_flat_offline"],
        "training_chat": aggregates["training_chat_offline"],
    }
    report["key_answers"] = build_key_answers(**aggregates)
    report["likely_root_causes_ranked"] = root_causes
    return report
|
|
|
|
def build_key_answers(
    runtime_flat_runtime: dict[str, Any],
    training_chat_runtime: dict[str, Any],
    runtime_flat_offline: dict[str, Any],
    training_chat_offline: dict[str, Any],
) -> dict[str, Any]:
    """Turn the four aggregate dicts into headline answers.

    Missing rates default to 0.0, except the offline fallback rate used for
    the success flag, which defaults to 1.0 so that missing data never counts
    as success. ``training_chat_runtime`` is accepted but not read.

    Returns:
        A dict with the prompt-style verdict (``"different"``/``"similar"``),
        a fixed note on summary structure, the runtime flat-prompt rates, and
        three boolean flags based on fixed thresholds (0.2, 0.05, 0.2).
    """
    fallback_key = "wrapper_would_fallback_rate"
    outside_key = "targets_outside_visible_candidate_rate"
    truncation_key = "possible_truncation_rate"

    offline_flat_fallback = runtime_flat_offline.get(fallback_key, 0.0)
    offline_chat_fallback = training_chat_offline.get(fallback_key, 0.0)
    if offline_flat_fallback != offline_chat_fallback:
        prompt_verdict = "different"
    else:
        prompt_verdict = "similar"

    flat_aggregates = (runtime_flat_runtime, runtime_flat_offline)
    # any() over a list keeps both lookups evaluated in the original order.
    candidate_problem = any([stats.get(outside_key, 0.0) > 0.2 for stats in flat_aggregates])
    truncating = any([stats.get(truncation_key, 0.0) > 0.05 for stats in flat_aggregates])
    heldout_ok = runtime_flat_offline.get(fallback_key, 1.0) < 0.2

    return {
        "runtime_prompt_vs_training_prompt": prompt_verdict,
        "runtime_summary_structure_vs_training_distribution": (
            "requires inspection of summary_distribution stats and prompt comparison"
        ),
        "raw_outputs_malformed_or_rejected": {
            "runtime_flat_runtime_fallback_rate": runtime_flat_runtime.get(fallback_key),
            "runtime_flat_runtime_json_valid_rate": runtime_flat_runtime.get("json_valid_rate"),
            "runtime_flat_runtime_schema_valid_rate": runtime_flat_runtime.get(
                "schema_valid_before_repair_rate"
            ),
        },
        "candidate_constraints_main_problem": bool(candidate_problem),
        "truncation_happening": bool(truncating),
        "runtime_codepath_succeeds_on_heldout_validation": bool(heldout_ok),
    }
|
|
|
|
def rank_root_causes(
    runtime_flat_runtime: dict[str, Any],
    training_chat_runtime: dict[str, Any],
    runtime_flat_offline: dict[str, Any],
    training_chat_offline: dict[str, Any],
) -> list[dict[str, Any]]:
    """Score candidate explanations for wrapper fallbacks and rank them.

    Args:
        runtime_flat_runtime: Aggregate metrics keyed by metric name.
        training_chat_runtime: Aggregate metrics keyed by metric name.
        runtime_flat_offline: Aggregate metrics keyed by metric name.
        training_chat_offline: Aggregate metrics keyed by metric name.

    Metrics that are missing, or present but ``None``, are treated as "no
    evidence": rates default to ``0.0`` and ``json_valid_rate`` defaults to
    ``1.0`` so that an empty metric set yields a score of ``0.0`` everywhere.

    Returns:
        A list of ``{"cause", "score", "evidence"}`` dicts sorted by
        descending score. Causes with equal scores keep the order in which
        they are defined below. ``evidence`` holds the raw (possibly ``None``)
        metric values the score was derived from.
    """
    fallback_key = "wrapper_would_fallback_rate"
    outside_key = "targets_outside_visible_candidate_rate"
    emptied_key = "repair_emptied_targets_rate"
    truncation_key = "possible_truncation_rate"
    json_key = "json_valid_rate"

    def rate(metrics: dict[str, Any], key: str, default: float = 0.0) -> float:
        # A key that is absent or explicitly None carries no evidence.
        value = metrics.get(key)
        return default if value is None else float(value)

    def unexplained_fallback(metrics: dict[str, Any]) -> float:
        # Share of fallbacks that invalid JSON alone cannot account for.
        # Subtracting the *invalid* rate (not the valid rate) makes the score
        # high exactly when output parses fine but is still rejected.
        invalid_json_rate = 1.0 - rate(metrics, json_key, default=1.0)
        return rate(metrics, fallback_key) - invalid_json_rate

    # Same samples, different prompt style: isolates the prompt format.
    prompt_gap = rate(runtime_flat_offline, fallback_key) - rate(training_chat_offline, fallback_key)
    # Same prompt style, different samples: isolates the summary distribution.
    runtime_summary_gap = rate(training_chat_runtime, fallback_key) - rate(training_chat_offline, fallback_key)
    candidate_score = max(
        rate(runtime_flat_runtime, outside_key),
        rate(runtime_flat_offline, outside_key),
        rate(runtime_flat_runtime, emptied_key),
        rate(runtime_flat_offline, emptied_key),
    )
    validator_score = max(
        unexplained_fallback(runtime_flat_runtime),
        unexplained_fallback(runtime_flat_offline),
    )
    truncation_score = max(
        rate(runtime_flat_runtime, truncation_key),
        rate(runtime_flat_offline, truncation_key),
    )

    causes: list[dict[str, Any]] = [
        {
            "cause": "prompt_mismatch_between_runtime_and_training_offline_chat_path",
            "score": float(prompt_gap),
            "evidence": {
                "offline_runtime_flat_fallback_rate": runtime_flat_offline.get(fallback_key),
                "offline_training_chat_fallback_rate": training_chat_offline.get(fallback_key),
            },
        },
        {
            "cause": "runtime_summary_distribution_shift",
            "score": float(runtime_summary_gap),
            "evidence": {
                "runtime_training_chat_fallback_rate": training_chat_runtime.get(fallback_key),
                "offline_training_chat_fallback_rate": training_chat_offline.get(fallback_key),
            },
        },
        {
            "cause": "candidate_intersections_or_visible_target_constraint_mismatch",
            "score": float(candidate_score),
            "evidence": {
                "runtime_targets_outside_visible_rate": runtime_flat_runtime.get(outside_key),
                "offline_targets_outside_visible_rate": runtime_flat_offline.get(outside_key),
                "runtime_repair_emptied_targets_rate": runtime_flat_runtime.get(emptied_key),
            },
        },
        {
            "cause": "validator_or_repair_stricter_than_raw_generation_quality",
            "score": float(validator_score),
            "evidence": {
                "runtime_json_valid_rate": runtime_flat_runtime.get(json_key),
                "runtime_wrapper_fallback_rate": runtime_flat_runtime.get(fallback_key),
                "offline_json_valid_rate": runtime_flat_offline.get(json_key),
                "offline_wrapper_fallback_rate": runtime_flat_offline.get(fallback_key),
            },
        },
        {
            "cause": "generation_truncation",
            "score": float(truncation_score),
            "evidence": {
                "runtime_possible_truncation_rate": runtime_flat_runtime.get(truncation_key),
                "offline_possible_truncation_rate": runtime_flat_offline.get(truncation_key),
            },
        },
    ]

    # list.sort is stable, so equally scored causes keep their definition order.
    causes.sort(key=lambda item: item["score"], reverse=True)
    return causes
|
|
|
|
def render_prompt_comparison(
    runtime_rows: list[dict[str, Any]],
    offline_rows: list[dict[str, Any]],
    summary_report: dict[str, Any],
) -> str:
    """Render a Markdown report contrasting the two prompt paths.

    The report always contains the fixed introduction and the
    ``summary_report["key_answers"]`` mapping as JSON. When ``offline_rows``
    is non-empty, its first row contributes the prompt texts and a unified
    diff between the chat and flat prompts. When ``runtime_rows`` is
    non-empty, its first row contributes both model outputs as JSON.

    Returns:
        The Markdown document, terminated by a newline.
    """

    def block(heading: str, fence_open: str, body: list[str]) -> list[str]:
        # One titled, fenced section followed by a blank separator line.
        return [heading, "", fence_open, *body, "```", ""]

    def as_json(value: Any) -> str:
        return json.dumps(value, indent=2, sort_keys=True)

    first_runtime = runtime_rows[0] if runtime_rows else None
    first_offline = offline_rows[0] if offline_rows else None

    report: list[str] = [
        "# Runtime Prompt Diagnosis",
        "",
        "## Key Finding",
        "",
        "Training/offline evaluation uses a chat-style prompt with separate `system` and `user` messages.",
        "This report compares the chat-style prompt path against the older flattened prompt path.",
        "",
    ]
    report += block(
        "## Aggregate Answers",
        "```json",
        [as_json(summary_report.get("key_answers", {}))],
    )

    if first_offline is not None:
        report += ["## Representative Offline Validation Example", ""]
        prompt_sections = (
            ("### Training System Prompt", "training_system_prompt"),
            ("### Training User Prompt", "training_user_prompt"),
            ("### Runtime Flat Prompt", "runtime_flat_prompt"),
            ("### Training Chat Rendered Prompt", "training_chat_prompt"),
        )
        for heading, field in prompt_sections:
            report += block(heading, "```text", [first_offline[field]])
        diff_lines = list(
            difflib.unified_diff(
                first_offline["training_chat_prompt"].splitlines(),
                first_offline["runtime_flat_prompt"].splitlines(),
                fromfile="training_chat_prompt",
                tofile="runtime_flat_prompt",
                lineterm="",
            )
        )
        report += block("### Prompt Diff", "```diff", diff_lines)

    if first_runtime is not None:
        report += ["## Representative Runtime Summary Example", ""]
        report += block(
            "### Runtime Flat Output",
            "```json",
            [as_json(first_runtime["runtime_flat"])],
        )
        report += block(
            "### Training Chat Output On Same Summary",
            "```json",
            [as_json(first_runtime["training_chat"])],
        )

    return "\n".join(report) + "\n"
|
|
|
|
def flatten_failure_example(row: dict[str, Any], prompt_style_key: str) -> dict[str, Any]:
    """Flatten one evaluated row into a single-level failure record.

    Args:
        row: Evaluated row holding identity fields, both prompt texts and a
            nested result mapping under ``prompt_style_key``.
        prompt_style_key: Key of the nested result to flatten. The value
            ``"runtime_flat"`` selects ``row["runtime_flat_prompt"]`` as the
            prompt text; any other value selects
            ``row["training_chat_prompt"]``.

    Returns:
        A new dict with the row identity, the selected result's diagnostic
        fields and the matching prompt under ``"prompt_text"``.

    Raises:
        KeyError: If ``row`` or the nested result lacks an expected field.
    """
    identity_fields = ("city_id", "scenario", "district_id", "decision_step")
    result_fields = (
        "failure_reasons",
        "json_valid",
        "schema_valid_before_repair",
        "wrapper_would_fallback",
        "repair_report",
        "raw_text",
        "parsed_payload_before_repair",
        "action_after_repair",
        "candidate_diagnostics",
    )

    result = row[prompt_style_key]
    flattened: dict[str, Any] = {
        "source": row["source"],
        "prompt_style": prompt_style_key,
    }
    flattened.update((field, row[field]) for field in identity_fields)
    flattened.update((field, result[field]) for field in result_fields)

    if prompt_style_key == "runtime_flat":
        flattened["prompt_text"] = row["runtime_flat_prompt"]
    else:
        flattened["prompt_text"] = row["training_chat_prompt"]
    return flattened
|
|
|
|
def resolve_scenario_specs(dataset: CityFlowDataset, args: argparse.Namespace) -> list[ScenarioSpec]:
    """Expand the CLI city/scenario selection into scenario specs.

    Cities come from ``args.cities`` when it is non-empty, otherwise from
    ``dataset.load_split(args.split)``. For each city, the scenarios come
    from ``args.scenarios`` when it is non-empty, otherwise from
    ``dataset.scenarios_for_city(city_id)``.

    Returns:
        One ``dataset.build_scenario_spec(city_id, scenario_name)`` result
        per (city, scenario) pair, cities in the outer order.
    """
    selected_cities = list(args.cities) if args.cities else dataset.load_split(args.split)
    specs: list[ScenarioSpec] = []
    for city in selected_cities:
        # Queried for every city, even when explicit scenarios override it.
        discovered = dataset.scenarios_for_city(city)
        scenario_names = list(args.scenarios) if args.scenarios else discovered
        specs.extend(dataset.build_scenario_spec(city, name) for name in scenario_names)
    return specs
|
|
|
|
def default_env_config() -> EnvConfig:
    """Build the fixed ``EnvConfig`` used by this script.

    The observation config uses its own defaults and the reward config
    selects the ``"wait_queue_throughput"`` variant.
    """
    observation_config = ObservationConfig()
    reward_config = RewardConfig(variant="wait_queue_throughput")
    settings = {
        "simulator_interval": 1,
        "decision_interval": 5,
        "min_green_time": 10,
        "thread_num": 1,
        "max_episode_seconds": 300,
        "observation": observation_config,
        "reward": reward_config,
    }
    return EnvConfig(**settings)
|
|
|
|
def token_length(inference: DistrictLLMInference, text: str) -> int | None:
    """Count the tokens ``inference.tokenizer`` produces for ``text``.

    The tokenizer is first called with ``add_special_tokens=False``; if that
    call raises ``TypeError`` it is retried with the text alone.

    Returns:
        The length of the encoding's ``"input_ids"`` entry, or ``None`` when
        ``inference.tokenizer`` is ``None``.
    """
    encode = inference.tokenizer
    if encode is not None:
        try:
            encoding = encode(text, add_special_tokens=False)
        except TypeError:
            # Retry without the keyword for callables that reject it.
            encoding = encode(text)
        token_ids = encoding["input_ids"]
        return len(token_ids)
    return None
|
|
|
|
def prompt_near_model_limit(inference: DistrictLLMInference, prompt_token_length: int | None) -> bool | None:
    """Tell whether a prompt uses at least 90% of the tokenizer's max length.

    Returns:
        ``None`` when the answer cannot be determined: no token count, no
        tokenizer, or a ``model_max_length`` that is missing, non-positive or
        above 1,000,000. Otherwise ``True`` when ``prompt_token_length`` is at
        least ``int(0.9 * model_max_length)``, else ``False``.
    """
    if prompt_token_length is None:
        return None
    tokenizer = inference.tokenizer
    if tokenizer is None:
        return None

    limit = getattr(tokenizer, "model_max_length", None)
    if limit is None:
        return None
    if limit <= 0:
        return None
    if limit > 1_000_000:
        return None

    threshold = int(0.9 * limit)
    return bool(prompt_token_length >= threshold)
|
|
|
|
def average(values: list[int | float | None]) -> float:
    """Return the arithmetic mean of the non-``None`` values.

    Returns ``0.0`` when no value remains after dropping ``None`` entries.
    """
    present = [float(item) for item in values if item is not None]
    if not present:
        return 0.0
    return float(mean(present))
|
|
|
|
def median_or_zero(values: list[int | float | None]) -> float:
    """Return the median of the non-``None`` values.

    Returns ``0.0`` when no value remains after dropping ``None`` entries.
    """
    present = [float(item) for item in values if item is not None]
    if not present:
        return 0.0
    return float(median(present))
|
|
|
|
def safe_ratio(numerator: int | float, denominator: int | float) -> float:
    """Divide ``numerator`` by ``denominator`` as floats.

    Returns ``0.0`` instead of raising when the denominator equals zero.
    """
    divisor = float(denominator)
    # The numerator is only converted once the division is known to be safe.
    return float(numerator) / divisor if divisor != 0.0 else 0.0
|
|
|
|
def write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented, key-sorted JSON.

    Missing parent directories are created. The file is UTF-8 encoded and
    ends with a single trailing newline.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # Serialize before opening so a failed dump never touches the file.
    document = json.dumps(payload, indent=2, sort_keys=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.write(document + "\n")
|
|
|
|
def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines.

    Missing parent directories are created. Each row becomes one key-sorted
    JSON object on its own newline-terminated line, UTF-8 encoded.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(record, sort_keys=True) + "\n" for record in rows)
|
|
|
|
def write_text(path: Path, payload: str) -> None:
    """Write ``payload`` to ``path`` as UTF-8 text.

    Missing parent directories are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    path.write_text(payload, encoding="utf-8")
|
|
|
|
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|