# agentic-traffic / scripts / diagnose_llm_runtime.py
# Uploaded by Aditya2162 ("Upload folder using huggingface_hub").
# Revision 3d2dbcf (verified).
from __future__ import annotations
import argparse
from collections import Counter
from dataclasses import asdict
from datetime import datetime, timezone
import difflib
import json
from pathlib import Path
from statistics import mean, median
import sys
from typing import Any
import numpy as np
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from district_llm.eval import load_rows
from district_llm.inference import DistrictLLMInference
from district_llm.prompting import (
build_chat_messages,
build_system_prompt,
build_user_prompt,
format_district_prompt,
format_district_prompt_from_user_content,
render_chat_prompt,
)
from district_llm.repair import (
RepairConfig,
parse_candidate_intersections_from_text,
)
from district_llm.rl_guidance_wrapper import FixedRLPolicyAdapter
from district_llm.schema import (
DISTRICT_STRATEGIES,
PHASE_BIASES,
PRIORITY_CORRIDORS,
CandidateIntersection,
CongestedIntersection,
DistrictAction,
DistrictStateSummary,
)
from district_llm.summary_builder import DistrictStateSummaryBuilder
from env.observation_builder import ObservationConfig
from env.reward import RewardConfig
from env.traffic_env import EnvConfig
from training.cityflow_dataset import CityFlowDataset, ScenarioSpec
from training.train_local_policy import build_env
# Keys a parsed district-action payload is compared against in
# analyze_generation_result: keys absent from the payload are reported as
# "missing_required_field", keys outside this set as "extra_field_present".
REQUIRED_ACTION_KEYS = {
    "strategy",
    "priority_corridor",
    "target_intersections",
    "phase_bias",
    "duration_steps",
}
# Scalar field names that parse_summary_text checks a district-state summary
# against when it reports "missing_scalar_fields" and "extra_scalar_fields".
# NOTE(review): the list order presumably mirrors the order in which the
# summary text is rendered — confirm against the prompt builder.
SUMMARY_SCALAR_ORDER = [
    "city_id",
    "district_id",
    "district_type",
    "scenario",
    "scenario_type",
    "decision_step",
    "sim_time",
    "intersection_count",
    "avg_queue",
    "max_queue",
    "total_queue",
    "avg_wait",
    "max_wait",
    "total_wait",
    "avg_outgoing_load",
    "max_outgoing_load",
    "total_outgoing_load",
    "recent_throughput",
    "queue_change",
    "wait_change",
    "throughput_change",
    "ns_queue",
    "ew_queue",
    "ns_wait",
    "ew_wait",
    "dominant_flow",
    "boundary_queue_total",
    "boundary_wait_total",
    "spillback_risk",
    "incident_flag",
    "construction_flag",
    "overload_flag",
    "event_flag",
]
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    ``--model-path`` and ``--rl-checkpoint`` are mandatory; every other option
    has a default. The two boolean switches use
    ``argparse.BooleanOptionalAction``, so each also accepts a ``--no-...``
    form.

    Returns:
        The populated ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(
        description="Diagnose why district-LLM runtime guidance fails even when offline validation looks strong."
    )
    add = cli.add_argument

    # Model and policy checkpoints.
    add("--model-path", required=True)
    add("--rl-checkpoint", required=True)

    # Data locations and scenario selection.
    add("--val-jsonl", default="data/district_llm_dataset_v3/val.jsonl")
    add("--generated-root", default="data/generated")
    add("--splits-root", default="data/splits")
    add("--split", default="val", choices=("train", "val", "test"))
    add("--cities", nargs="+", default=None)
    add("--scenarios", nargs="+", default=None)

    # Budget limits for the diagnosis run.
    add("--max-diagnostic-calls", type=int, default=20)
    add("--max-offline-examples", type=int, default=20)
    add("--max-episode-seconds", type=int, default=300)
    add("--max-new-tokens", type=int, default=128)
    add("--device", default=None)
    add("--output-dir", default="artifacts/llm_runtime_diagnosis")

    # Options forwarded to RepairConfig by main().
    add(
        "--allow-only-visible-candidates",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    add("--max-target-intersections", type=int, default=3)
    add(
        "--fallback-on-empty-targets",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    add(
        "--fallback-mode",
        choices=("heuristic", "hold", "none"),
        default="heuristic",
    )
    return cli.parse_args()
def main() -> None:
    """Run the diagnosis end to end and write its artifacts.

    Collects diagnostic rows from live rollouts and from the offline
    validation file, builds the summary report and the Markdown prompt
    comparison, writes four files into ``--output-dir`` and prints the summary
    report as JSON.
    """
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    inference = DistrictLLMInference(
        model_name_or_path=args.model_path,
        device=args.device,
        repair_config=RepairConfig(
            allow_only_visible_candidates=args.allow_only_visible_candidates,
            max_target_intersections=args.max_target_intersections,
            fallback_on_empty_targets=args.fallback_on_empty_targets,
            fallback_mode=args.fallback_mode,
        ),
    )

    runtime_rows = collect_runtime_rows(args=args, inference=inference)
    offline_rows = collect_offline_rows(args=args, inference=inference)
    all_rows = runtime_rows + offline_rows

    # One flattened record per (row, prompt style) pair flagged as a fallback.
    failure_examples = []
    for row in all_rows:
        for style_key in ("runtime_flat", "training_chat"):
            if row[style_key]["wrapper_would_fallback"]:
                failure_examples.append(flatten_failure_example(row, style_key))

    summary_report = build_summary_report(
        args=args,
        inference=inference,
        runtime_rows=runtime_rows,
        offline_rows=offline_rows,
    )
    prompt_comparison = render_prompt_comparison(
        runtime_rows=runtime_rows,
        offline_rows=offline_rows,
        summary_report=summary_report,
    )

    write_jsonl(output_dir / "diagnostic_rows.jsonl", all_rows)
    write_json(output_dir / "summary_report.json", summary_report)
    write_text(output_dir / "prompt_comparison.md", prompt_comparison)
    write_jsonl(output_dir / "validator_failure_examples.jsonl", failure_examples)
    print(json.dumps(summary_report, indent=2, sort_keys=True))
def collect_runtime_rows(
    args: argparse.Namespace,
    inference: DistrictLLMInference,
) -> list[dict[str, Any]]:
    """Diagnose LLM calls on summaries built while stepping live environments.

    For each scenario returned by ``resolve_scenario_specs`` an environment is
    built and stepped with the actions of the fixed RL policy loaded from
    ``args.rl_checkpoint``. At every step, each district summary (visited in
    sorted district-id order) is passed to ``diagnose_summary_call``.
    Collection stops once ``args.max_diagnostic_calls`` rows exist.

    Args:
        args: Parsed CLI namespace.
        inference: LLM inference wrapper used for every diagnostic call.

    Returns:
        Diagnostic rows tagged with ``source="runtime_live"``.
    """
    call_budget = args.max_diagnostic_calls

    dataset = CityFlowDataset(
        generated_root=args.generated_root,
        splits_root=args.splits_root,
    )
    dataset.generate_default_splits()
    scenario_specs = resolve_scenario_specs(dataset=dataset, args=args)

    policy = FixedRLPolicyAdapter(
        checkpoint_path=args.rl_checkpoint,
        device=args.device,
    )
    # Start from the policy's env config when it has one, then rebuild it so
    # the episode length always comes from the CLI.
    base_config = policy.env_config or default_env_config()
    env_config = EnvConfig(
        simulator_interval=base_config.simulator_interval,
        decision_interval=base_config.decision_interval,
        min_green_time=base_config.min_green_time,
        thread_num=base_config.thread_num,
        max_episode_seconds=int(args.max_episode_seconds),
        observation=base_config.observation,
        reward=base_config.reward,
    )
    summary_builder = DistrictStateSummaryBuilder(
        top_k=3,
        candidate_limit=max(6, args.max_target_intersections),
    )

    rows: list[dict[str, Any]] = []
    for spec in scenario_specs:
        if len(rows) >= call_budget:
            break
        env = build_env(env_config, spec)
        observations = env.reset()
        summary_builder.reset()
        episode_over = False
        while not episode_over and len(rows) < call_budget:
            district_summaries = summary_builder.build_all(env, observations)
            for district_id in sorted(district_summaries):
                if len(rows) >= call_budget:
                    break
                district_summary = district_summaries[district_id]
                diagnostic_row = diagnose_summary_call(
                    inference=inference,
                    summary=district_summary,
                    source="runtime_live",
                    city_id=spec.city_id,
                    scenario=spec.scenario_name,
                    district_id=district_id,
                    decision_step=int(district_summary.decision_step),
                    wrapper_mode="diagnose_llm_runtime",
                    max_new_tokens=args.max_new_tokens,
                )
                rows.append(diagnostic_row)
            # Budget exhausted: skip the now-pointless simulator step.
            if len(rows) >= call_budget:
                break
            decision = policy.decide(observations)
            observations, _, episode_over, _ = env.step(decision.actions)
    return rows
def collect_offline_rows(
    args: argparse.Namespace,
    inference: DistrictLLMInference,
) -> list[dict[str, Any]]:
    """Diagnose LLM calls on held-out validation examples.

    Each example loaded from ``args.val_jsonl`` is read as a ``messages`` list
    whose first two entries are used as the training prompt, whose second
    entry's content is the district summary text, and whose third entry's
    content is the JSON ground-truth action.

    Args:
        args: Parsed CLI namespace.
        inference: LLM inference wrapper used for every diagnostic call.

    Returns:
        Diagnostic rows tagged with
        ``source="offline_validation_runtime_codepath"``.
    """
    examples = load_rows(args.val_jsonl, max_examples=args.max_offline_examples)
    diagnosed: list[dict[str, Any]] = []
    for example_index, example in enumerate(examples):
        messages = example["messages"]
        prompt_messages = messages[:2]
        user_text = messages[1]["content"]
        # Rebuild a summary object from the rendered text so the runtime
        # prompt builders can be exercised on the same state.
        summary, parse_info = parse_summary_text(user_text)
        diagnostic_row = diagnose_summary_call(
            inference=inference,
            summary=summary,
            source="offline_validation_runtime_codepath",
            city_id=str(example.get("city_id", summary.city_id)),
            scenario=str(example.get("scenario", summary.scenario_name)),
            district_id=str(example.get("district_id", summary.district_id)),
            decision_step=int(summary.decision_step),
            wrapper_mode="diagnose_llm_runtime",
            max_new_tokens=args.max_new_tokens,
            training_messages=prompt_messages,
            ground_truth_payload=json.loads(messages[2]["content"]),
            original_user_prompt=user_text,
            summary_parse=parse_info,
            example_index=example_index,
        )
        diagnosed.append(diagnostic_row)
    return diagnosed
def diagnose_summary_call(
    inference: DistrictLLMInference,
    summary: DistrictStateSummary,
    source: str,
    city_id: str,
    scenario: str,
    district_id: str,
    decision_step: int,
    wrapper_mode: str,
    max_new_tokens: int,
    training_messages: list[dict[str, str]] | None = None,
    ground_truth_payload: dict[str, Any] | None = None,
    original_user_prompt: str | None = None,
    summary_parse: dict[str, Any] | None = None,
    example_index: int | None = None,
) -> dict[str, Any]:
    """Run one summary through both prompt styles and record the outcome.

    The flat runtime prompt and the rendered chat prompt are each generated
    from and diagnosed via ``run_prompt_diagnostic``; the prompts, their
    structural comparison, summary features and both results are gathered into
    a single row.

    Args:
        inference: LLM inference wrapper (provides tokenizer and repair config).
        summary: District state the prompts are built from.
        source: Label for where the summary came from.
        city_id: City identifier stored on the row.
        scenario: Scenario name stored on the row.
        district_id: District identifier stored on the row.
        decision_step: Decision step stored on the row.
        wrapper_mode: Free-form label stored on the row.
        max_new_tokens: Generation budget for each prompt.
        training_messages: Chat messages to use; built from ``summary`` when
            omitted or empty.
        ground_truth_payload: Reference action, stored verbatim if given.
        original_user_prompt: User prompt text to use; built from ``summary``
            when omitted or empty.
        summary_parse: Extra parse metadata merged into ``summary_features``.
        example_index: Index of the offline example, if any.

    Returns:
        The diagnostic row as a plain dict.
    """
    repair_config = inference.repair_config
    prompt_options = {
        "max_target_intersections": repair_config.max_target_intersections,
        "allow_only_visible_candidates": repair_config.allow_only_visible_candidates,
    }

    if not training_messages:
        training_messages = build_chat_messages(summary, **prompt_options)
    runtime_user_prompt = original_user_prompt or build_user_prompt(summary)

    # Render the prompt variants under comparison.
    runtime_flat_prompt = format_district_prompt(summary, **prompt_options)
    training_chat_prompt = render_chat_prompt(
        training_messages,
        tokenizer=inference.tokenizer,
        add_generation_prompt=True,
    )
    runtime_flat_from_user_prompt = format_district_prompt_from_user_content(
        runtime_user_prompt,
        **prompt_options,
    )

    # Generate and analyse with each prompt style on the same summary.
    style_results = {}
    for style_name, prompt_text in (
        ("runtime_flat", runtime_flat_prompt),
        ("training_chat", training_chat_prompt),
    ):
        style_results[style_name] = run_prompt_diagnostic(
            inference=inference,
            prompt_text=prompt_text,
            summary=summary,
            max_new_tokens=max_new_tokens,
            prompt_style=style_name,
        )

    if training_messages:
        training_system_prompt = training_messages[0]["content"]
    else:
        training_system_prompt = build_system_prompt(**prompt_options)
    if len(training_messages) > 1:
        training_user_prompt = training_messages[1]["content"]
    else:
        training_user_prompt = runtime_user_prompt

    prompt_compare = compare_prompt_shapes(
        training_system_prompt=training_system_prompt,
        training_user_prompt=training_user_prompt,
        runtime_flat_prompt=runtime_flat_prompt,
        runtime_flat_from_user_prompt=runtime_flat_from_user_prompt,
        training_chat_prompt=training_chat_prompt,
    )
    return {
        "source": source,
        "example_index": example_index,
        "city_id": city_id,
        "scenario": scenario,
        "district_id": district_id,
        "decision_step": int(decision_step),
        "wrapper_mode": wrapper_mode,
        "training_system_prompt": training_system_prompt,
        "training_user_prompt": training_user_prompt,
        "runtime_flat_prompt": runtime_flat_prompt,
        "runtime_flat_prompt_from_user_prompt": runtime_flat_from_user_prompt,
        "training_chat_prompt": training_chat_prompt,
        "prompt_comparison": prompt_compare,
        "summary_features": summary_features(runtime_user_prompt, summary_parse=summary_parse),
        "summary_state": summary.to_dict(),
        "runtime_flat": style_results["runtime_flat"],
        "training_chat": style_results["training_chat"],
        "ground_truth_payload": ground_truth_payload,
    }
def run_prompt_diagnostic(
    inference: DistrictLLMInference,
    prompt_text: str,
    summary: DistrictStateSummary,
    max_new_tokens: int,
    prompt_style: str,
) -> dict[str, Any]:
    """Generate from one prompt, parse the output and collect diagnostics.

    Args:
        inference: LLM inference wrapper used to generate and parse.
        prompt_text: Fully rendered prompt passed to the model.
        summary: District state passed to ``parse_action`` and to the analysis.
        max_new_tokens: Generation budget; also used for truncation detection.
        prompt_style: Label copied onto the result.

    Returns:
        A dict with the raw text, the payload before repair, the repaired
        action, the repair report, token lengths and the failure analysis from
        ``analyze_generation_result``.
    """
    raw_text = inference.generate_raw(prompt=prompt_text, max_new_tokens=max_new_tokens)
    parse_result = inference.parse_action(raw_text, summary=summary)
    action, repair_report, parsed_payload, json_valid, schema_valid_before_repair = parse_result

    prompt_tokens = token_length(inference, prompt_text)
    output_tokens = token_length(inference, raw_text)
    analysis = analyze_generation_result(
        raw_text=raw_text,
        parsed_payload=parsed_payload,
        summary=summary,
        repair_report=repair_report,
        max_new_tokens=max_new_tokens,
        output_token_length=output_tokens,
        json_valid=json_valid,
        schema_valid_before_repair=schema_valid_before_repair,
    )
    near_limit = prompt_near_model_limit(inference, prompt_tokens)

    result: dict[str, Any] = {
        "prompt_style": prompt_style,
        "prompt_token_length": prompt_tokens,
        "output_token_length": output_tokens,
        "prompt_near_model_limit": near_limit,
        "raw_text": raw_text,
        "extracted_json_text": analysis["extracted_json_text"],
        "parsed_payload_before_repair": parsed_payload,
        "action_after_repair": action.to_dict(),
        "repair_report": repair_report.to_dict(),
        "json_valid": bool(json_valid),
        "schema_valid_before_repair": bool(schema_valid_before_repair),
        "wrapper_would_fallback": bool(analysis["wrapper_would_fallback"]),
        "failure_reasons": analysis["failure_reasons"],
        "candidate_diagnostics": analysis["candidate_diagnostics"],
        "possible_truncation": bool(analysis["possible_truncation"]),
    }
    return result
def analyze_generation_result(
    raw_text: str,
    parsed_payload: dict[str, Any] | None,
    summary: DistrictStateSummary,
    repair_report,
    max_new_tokens: int,
    output_token_length: int | None,
    json_valid: bool,
    schema_valid_before_repair: bool,
) -> dict[str, Any]:
    """Explain why a raw generation would (or would not) be accepted.

    Inspects the raw text, the payload parsed before repair, and the repair
    report, and collects a set of failure-reason tags.

    Args:
        raw_text: Text produced by the model.
        parsed_payload: JSON object parsed from ``raw_text`` before repair, or
            ``None`` when nothing was parsed.
        summary: District state; its ``candidate_ids()`` define which target
            intersections count as visible.
        repair_report: Object exposing ``invalid_ids_removed``,
            ``non_visible_ids_removed``, ``empty_after_filtering``,
            ``fallback_used`` and ``fallback_mode``.
        max_new_tokens: Generation budget used for the call.
        output_token_length: Token count of ``raw_text``, or ``None`` if unknown.
        json_valid: Whether the output parsed as JSON.
        schema_valid_before_repair: Whether the parsed payload passed schema
            validation before repair.

    Returns:
        Dict with ``extracted_json_text``, sorted unique ``failure_reasons``,
        per-target ``candidate_diagnostics``, ``possible_truncation`` and
        ``wrapper_would_fallback``.
    """
    reasons: set[str] = set()

    # --- Shape of the raw text around the JSON object ---------------------
    extracted_json_text, prefix_text, suffix_text = extract_json_details(raw_text)
    if "```" in raw_text:
        reasons.add("markdown_code_fence_present")
    if prefix_text.strip():
        reasons.add("extra_prefix_text")
    if suffix_text.strip():
        reasons.add("extra_suffix_text")
    if extracted_json_text is None:
        reasons.add("no_json_object_found")
    if not json_valid:
        reasons.add("json_parse_error")

    # --- Field-level checks on the payload before repair ------------------
    raw_targets: list[str] = []
    if parsed_payload is not None:
        payload_keys = set(parsed_payload)
        if REQUIRED_ACTION_KEYS - payload_keys:
            reasons.add("missing_required_field")
        if payload_keys - REQUIRED_ACTION_KEYS:
            reasons.add("extra_field_present")
        if parsed_payload.get("strategy") not in DISTRICT_STRATEGIES:
            reasons.add("unknown_strategy")
        priority_corridor = parsed_payload.get("priority_corridor")
        if priority_corridor is not None and priority_corridor not in PRIORITY_CORRIDORS:
            reasons.add("unknown_priority_corridor")
        if parsed_payload.get("phase_bias") not in PHASE_BIASES:
            reasons.add("unknown_phase_bias")

        duration_steps = parsed_payload.get("duration_steps")
        # bool is a subclass of int, so a JSON true/false would otherwise pass
        # both the type check and (for true == 1) the range check.
        if isinstance(duration_steps, bool) or not isinstance(duration_steps, int):
            reasons.add("invalid_duration_type")
        elif not 1 <= duration_steps <= 20:
            # NOTE(review): the 1..20 bounds are hard-coded here; presumably
            # they mirror the action schema — confirm.
            reasons.add("invalid_duration_range")

        target_payload = parsed_payload.get("target_intersections", [])
        if isinstance(target_payload, list):
            raw_targets = [str(item) for item in target_payload]
        elif isinstance(target_payload, str):
            reasons.add("target_intersections_not_json_array")
            raw_targets = [target_payload]
        else:
            reasons.add("target_intersections_wrong_type")

    if not schema_valid_before_repair:
        reasons.add("schema_validation_failed")
    if raw_text.strip() and not raw_text.rstrip().endswith("}"):
        reasons.add("output_does_not_end_with_json")

    # Hitting the token budget exactly is treated as a truncation signal.
    possible_truncation = bool(
        output_token_length is not None and output_token_length >= max_new_tokens
    )
    if possible_truncation:
        reasons.add("possible_generation_truncation")

    # --- Per-target checks against the visible candidates -----------------
    candidate_ids = set(summary.candidate_ids())
    candidate_diagnostics = []
    for target in raw_targets:
        visible = target in candidate_ids
        valid_format = target.startswith("i_")
        candidate_diagnostics.append(
            {
                "target_intersection": target,
                "visible_candidate": visible,
                "valid_id_format": valid_format,
            }
        )
        if not valid_format:
            reasons.add("invalid_target_id_format")
        # Only a violation when the summary actually listed candidates.
        if candidate_ids and not visible:
            reasons.add("candidate_intersections_constraint_violation")
    if not raw_targets:
        reasons.add("empty_target_intersections")

    # --- What the repair step did -----------------------------------------
    if repair_report.invalid_ids_removed:
        reasons.add("repair_removed_invalid_ids")
    if repair_report.non_visible_ids_removed:
        reasons.add("repair_removed_non_visible_ids")
    if repair_report.empty_after_filtering:
        reasons.add("repair_emptied_targets")
    if repair_report.fallback_used:
        reasons.add(f"repair_used_fallback:{repair_report.fallback_mode}")

    wrapper_would_fallback = (
        not json_valid
        or not schema_valid_before_repair
        or bool(repair_report.fallback_used)
        or bool(repair_report.empty_after_filtering)
    )
    return {
        "extracted_json_text": extracted_json_text,
        "failure_reasons": sorted(reasons),
        "candidate_diagnostics": candidate_diagnostics,
        "possible_truncation": possible_truncation,
        "wrapper_would_fallback": wrapper_would_fallback,
    }
def extract_json_details(raw_text: str) -> tuple[str | None, str, str]:
    """Split ``raw_text`` around its outermost ``{...}`` span.

    The span runs from the first ``{`` to the last ``}``.

    Returns:
        ``(json_text, prefix, suffix)``. When no such span exists (no opening
        brace, or no closing brace after it) the result is
        ``(None, raw_text, "")``.
    """
    open_idx = raw_text.find("{")
    close_idx = raw_text.rfind("}")
    if open_idx != -1 and close_idx > open_idx:
        stop = close_idx + 1
        return raw_text[open_idx:stop], raw_text[:open_idx], raw_text[stop:]
    return None, raw_text, ""
def parse_summary_text(summary_text: str) -> tuple[DistrictStateSummary, dict[str, Any]]:
    """Rebuild a ``DistrictStateSummary`` from its rendered text form.

    A leading ``### DISTRICT STATE`` line is dropped. ``key: value`` lines
    before the first section header are parsed as scalars; ``- `` bullet lines
    under ``top_congested_intersections:`` and ``candidate_intersections:``
    are collected and parsed separately.

    Returns:
        ``(summary, parse_info)`` where ``parse_info`` records the observed
        scalar order, missing/extra scalar fields relative to
        ``SUMMARY_SCALAR_ORDER``, list sizes, text length and line count.
    """
    body = summary_text.strip()
    if body.startswith("### DISTRICT STATE"):
        # Keep everything after the header line (empty if there is none).
        body = body.partition("\n")[2]
    lines = [line.rstrip() for line in body.splitlines() if line.strip()]

    section_by_header = {
        "top_congested_intersections:": "top",
        "candidate_intersections:": "candidate",
    }
    bullets: dict[str, list[str]] = {"top": [], "candidate": []}
    payload: dict[str, Any] = {}
    observed_order: list[str] = []
    section = "scalars"
    for line in lines:
        if line in section_by_header:
            section = section_by_header[line]
            continue
        if line.startswith("- "):
            # Bullets seen before any section header are ignored.
            if section in bullets:
                bullets[section].append(line)
            continue
        if section != "scalars" or ": " not in line:
            continue
        key, value = line.split(": ", 1)
        observed_order.append(key)
        payload[key] = parse_summary_scalar(key, value)

    # The text uses "scenario"; the summary payload is keyed "scenario_name".
    fallback_name = payload.get("scenario_name", "")
    payload["scenario_name"] = payload.pop("scenario", fallback_name)

    payload["top_congested_intersections"] = [
        parse_top_congested_line(line) for line in bullets["top"] if line != "- none"
    ]
    candidate_block = "candidate_intersections:\n" + "\n".join(bullets["candidate"] or ["- none"])
    payload["candidate_intersections"] = parse_candidate_intersections_from_text(candidate_block)

    summary = DistrictStateSummary.from_dict(payload)
    parse_info = {
        "observed_field_order": observed_order,
        "missing_scalar_fields": [name for name in SUMMARY_SCALAR_ORDER if name not in observed_order],
        "extra_scalar_fields": [name for name in observed_order if name not in SUMMARY_SCALAR_ORDER],
        "top_congested_count": len(payload["top_congested_intersections"]),
        "candidate_intersections_count": len(payload["candidate_intersections"]),
        "summary_text_length": len(summary_text),
        "line_count": len(lines),
    }
    return summary, parse_info
def parse_summary_scalar(key: str, value: str) -> Any:
    """Convert the text of one summary scalar to its Python value.

    Counter-like fields become ``int``, flag fields become ``bool`` (via
    ``int``), identifier/label fields stay ``str`` and every other key is
    parsed as ``float``.

    Raises:
        ValueError: If ``value`` cannot be converted to the expected number.
    """
    int_fields = ("decision_step", "sim_time", "intersection_count")
    flag_fields = (
        "spillback_risk",
        "incident_flag",
        "construction_flag",
        "overload_flag",
        "event_flag",
    )
    text_fields = (
        "city_id",
        "district_id",
        "district_type",
        "scenario",
        "scenario_type",
        "dominant_flow",
    )
    converters = {
        **dict.fromkeys(int_fields, int),
        **dict.fromkeys(flag_fields, lambda raw: bool(int(raw))),
        **dict.fromkeys(text_fields, lambda raw: raw),
    }
    # Any key not listed above is treated as a floating-point metric.
    convert = converters.get(key, float)
    return convert(value)
def parse_top_congested_line(line: str) -> dict[str, Any]:
    """Parse one ``- <id> key=value ...`` bullet into a congestion record.

    The first two characters of ``line`` are skipped, the first remaining
    token is the intersection id, and the recognised ``key=value`` tokens
    (``q``, ``w``, ``out``, ``phase``, ``boundary``) override the defaults.
    Unknown keys and tokens without ``=`` are ignored.

    Returns:
        A dict of ``CongestedIntersection`` keyword arguments.
    """
    field_parsers = {
        "q": ("queue_total", float),
        "w": ("wait_total", float),
        "out": ("outgoing_load", float),
        "phase": ("current_phase", int),
        "boundary": ("is_boundary", lambda raw: raw == "1"),
    }
    tokens = line[2:].split()
    record: dict[str, Any] = {
        "intersection_id": tokens[0],
        "queue_total": 0.0,
        "wait_total": 0.0,
        "outgoing_load": 0.0,
        "current_phase": 0,
        "is_boundary": False,
    }
    for token in tokens[1:]:
        short_key, separator, raw_value = token.partition("=")
        if not separator or short_key not in field_parsers:
            continue
        field_name, convert = field_parsers[short_key]
        record[field_name] = convert(raw_value)
    # The instance is built and discarded; only the dict is returned.
    # NOTE(review): presumably this is a validation step — confirm.
    CongestedIntersection(**record)
    return record
def summary_features(summary_text: str, summary_parse: dict[str, Any] | None = None) -> dict[str, Any]:
    """Compute simple structural features of a rendered summary text.

    Field names are taken from non-bullet lines: the part before ``": "`` when
    present, otherwise the line without its trailing ``:``.

    Args:
        summary_text: The rendered summary.
        summary_parse: Optional extra entries merged last, so they override
            computed features that share a key.

    Returns:
        Feature dict with length, line count, field order/count and section
        presence flags, plus the ``summary_parse`` entries.
    """
    non_blank = [line.rstrip() for line in summary_text.splitlines() if line.strip()]

    field_names: list[str] = []
    for line in non_blank:
        if line.startswith("- "):
            continue
        if ": " in line:
            field_names.append(line.partition(": ")[0])
        elif line.endswith(":"):
            field_names.append(line[:-1])

    features: dict[str, Any] = {
        "summary_length_chars": len(summary_text),
        "summary_line_count": len(non_blank),
        "field_order": field_names,
        "field_count": len(field_names),
        "has_candidate_intersections": "candidate_intersections:" in summary_text,
        "has_top_congested_intersections": "top_congested_intersections:" in summary_text,
    }
    features.update(summary_parse or {})
    return features
def compare_prompt_shapes(
    training_system_prompt: str,
    training_user_prompt: str,
    runtime_flat_prompt: str,
    runtime_flat_from_user_prompt: str,
    training_chat_prompt: str,
) -> dict[str, Any]:
    """Describe structural differences between training and runtime prompts.

    Each check is a substring or prefix test on the given prompt texts. The
    two ``*_has_system_role`` entries are constants, not derived from the
    inputs.

    Returns:
        Dict with a ``differences`` list of human-readable findings plus
        boolean flags about schema/decision headers and the chat prompt prefix.
    """
    schema_header = "### DISTRICT ACTION SCHEMA"
    decision_header = "### DECISION"
    runtime_has_schema = schema_header in runtime_flat_prompt
    training_user_has_schema = schema_header in training_user_prompt
    runtime_has_decision = decision_header in runtime_flat_prompt

    findings = (
        (
            bool(training_system_prompt)
            and "You are a district traffic coordinator" in training_system_prompt,
            "training uses an explicit system prompt with JSON rules",
        ),
        (
            runtime_flat_prompt.startswith(schema_header),
            "runtime flat prompt injects schema text into the single prompt body",
        ),
        (
            training_user_prompt.startswith("### DISTRICT STATE"),
            "training user message contains only district state",
        ),
        (
            runtime_has_decision,
            "runtime flat prompt appends an explicit decision header",
        ),
        (
            not training_user_has_schema,
            "training user message omits the schema header entirely",
        ),
        (
            runtime_flat_prompt != runtime_flat_from_user_prompt,
            "runtime flat prompt reconstructed from summary object differs from user prompt reconstruction",
        ),
    )
    differences = [message for applies, message in findings if applies]

    return {
        "differences": differences,
        "runtime_has_system_role": False,
        "training_has_system_role": True,
        "runtime_has_schema_header": runtime_has_schema,
        "training_user_has_schema_header": training_user_has_schema,
        "runtime_has_decision_header": runtime_has_decision,
        "training_chat_prompt_startswith_system": training_chat_prompt.startswith("system:")
        or training_chat_prompt.startswith("<"),
    }
def aggregate_prompt_results(rows: list[dict[str, Any]], prompt_style_key: str) -> dict[str, Any]:
    """Aggregate per-call diagnostics for one prompt style into rates.

    Args:
        rows: Diagnostic rows, each holding a result dict under
            ``prompt_style_key``.
        prompt_style_key: Which per-row result to aggregate.

    Returns:
        Rates, mean token lengths and the 20 most common failure reasons, or
        an empty dict when ``rows`` is empty.
    """
    if not rows:
        return {}

    results = [row[prompt_style_key] for row in rows]
    reports = [result["repair_report"] for result in results]
    # All target records across all calls, flattened.
    target_records = [
        record
        for result in results
        for record in result["candidate_diagnostics"]
    ]
    reason_counts: Counter = Counter()
    for result in results:
        reason_counts.update(result["failure_reasons"])

    def rate_of(items: list[dict[str, Any]], flag: str, negate: bool = False) -> Any:
        """Return the share of ``items`` whose ``flag`` is truthy (or falsy)."""
        if negate:
            hits = sum(int(not item[flag]) for item in items)
        else:
            hits = sum(int(item[flag]) for item in items)
        return safe_ratio(hits, len(items))

    changed_target_lists = sum(
        int(report["raw_targets"] != report["repaired_targets"])
        for report in reports
    )
    return {
        "num_examples": len(results),
        "json_valid_rate": rate_of(results, "json_valid"),
        "schema_valid_before_repair_rate": rate_of(results, "schema_valid_before_repair"),
        "wrapper_would_fallback_rate": rate_of(results, "wrapper_would_fallback"),
        "repair_fallback_rate": rate_of(reports, "fallback_used"),
        "repair_changed_target_list_rate": safe_ratio(changed_target_lists, len(reports)),
        "repair_emptied_targets_rate": rate_of(reports, "empty_after_filtering"),
        "mean_prompt_token_length": average([result["prompt_token_length"] for result in results]),
        "mean_output_token_length": average([result["output_token_length"] for result in results]),
        "possible_truncation_rate": rate_of(results, "possible_truncation"),
        "top_failure_reasons": dict(reason_counts.most_common(20)),
        "targets_outside_visible_candidate_rate": rate_of(
            target_records, "visible_candidate", negate=True
        ),
        "invalid_target_id_format_rate": rate_of(
            target_records, "valid_id_format", negate=True
        ),
    }
def aggregate_summary_features(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarise the ``summary_features`` of a set of diagnostic rows.

    Returns:
        Mean/median summary length, mean candidate and top-congested counts
        (rows lacking those counts contribute 0) and the five most common
        field-order signatures; an empty dict when ``rows`` is empty.
    """
    if not rows:
        return {}

    features = [row["summary_features"] for row in rows]
    lengths = [item["summary_length_chars"] for item in features]
    candidate_counts = [item.get("candidate_intersections_count", 0) for item in features]
    top_counts = [item.get("top_congested_count", 0) for item in features]

    # A signature is the exact ordered sequence of field names in one summary.
    signature_counts = Counter(tuple(item["field_order"]) for item in features)
    most_common_signatures = [
        {"count": occurrences, "field_order": list(signature)}
        for signature, occurrences in signature_counts.most_common(5)
    ]

    return {
        "num_examples": len(features),
        "mean_summary_length_chars": average(lengths),
        "median_summary_length_chars": median_or_zero(lengths),
        "mean_candidate_intersections_count": average(candidate_counts),
        "mean_top_congested_count": average(top_counts),
        "field_order_signatures": {"most_common": most_common_signatures},
    }
def build_summary_report(
    args: argparse.Namespace,
    inference: DistrictLLMInference,
    runtime_rows: list[dict[str, Any]],
    offline_rows: list[dict[str, Any]],
) -> dict[str, Any]:
    """Assemble the top-level JSON report for a diagnosis run.

    Aggregates both prompt styles over both row sources, derives the key
    answers and the ranked root causes from those aggregates, and records the
    run configuration alongside them.

    Returns:
        The report as a JSON-serialisable dict.
    """
    aggregates = {
        "runtime_flat_runtime": aggregate_prompt_results(runtime_rows, "runtime_flat"),
        "training_chat_runtime": aggregate_prompt_results(runtime_rows, "training_chat"),
        "runtime_flat_offline": aggregate_prompt_results(offline_rows, "runtime_flat"),
        "training_chat_offline": aggregate_prompt_results(offline_rows, "training_chat"),
    }
    root_causes = rank_root_causes(**aggregates)

    token_budget = int(args.max_new_tokens)
    generation_settings = {
        "runtime_inference": {
            "max_new_tokens": token_budget,
            "do_sample": False,
            "prompt_style": "flat single prompt from format_district_prompt",
        },
        "offline_eval_style": {
            "max_new_tokens": token_budget,
            "do_sample": False,
            "prompt_style": "chat messages rendered via build_generation_prompt",
        },
    }

    report: dict[str, Any] = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "model_path": args.model_path,
        "rl_checkpoint": args.rl_checkpoint,
        "repair_config": asdict(inference.repair_config),
        "generation_settings": generation_settings,
    }
    report["runtime_live"] = {
        "summary_distribution": aggregate_summary_features(runtime_rows),
        "runtime_flat": aggregates["runtime_flat_runtime"],
        "training_chat": aggregates["training_chat_runtime"],
    }
    report["offline_validation_runtime_codepath"] = {
        "summary_distribution": aggregate_summary_features(offline_rows),
        "runtime_flat": aggregates["runtime_flat_offline"],
        "training_chat": aggregates["training_chat_offline"],
    }
    report["key_answers"] = build_key_answers(**aggregates)
    report["likely_root_causes_ranked"] = root_causes
    return report
def build_key_answers(
    runtime_flat_runtime: dict[str, Any],
    training_chat_runtime: dict[str, Any],
    runtime_flat_offline: dict[str, Any],
    training_chat_offline: dict[str, Any],
) -> dict[str, Any]:
    """Condense the aggregates into headline answers.

    Each argument is an aggregate dict for one (prompt style, row source)
    pair; missing metrics fall back to the defaults used below.
    ``training_chat_runtime`` is accepted for signature symmetry but not read.

    Returns:
        Dict of headline answers (strings, booleans and a small metrics dict).
    """
    fallback_key = "wrapper_would_fallback_rate"

    def flat_metric_exceeds(metric: str, threshold: float) -> bool:
        """Return True if the flat-prompt metric tops ``threshold`` on either source."""
        return bool(
            runtime_flat_runtime.get(metric, 0.0) > threshold
            or runtime_flat_offline.get(metric, 0.0) > threshold
        )

    # Same offline examples, two prompt styles: unequal fallback rates mean
    # the prompt format itself changes the outcome.
    if runtime_flat_offline.get(fallback_key, 0.0) != training_chat_offline.get(fallback_key, 0.0):
        prompt_verdict = "different"
    else:
        prompt_verdict = "similar"

    malformed_or_rejected = {
        "runtime_flat_runtime_fallback_rate": runtime_flat_runtime.get(fallback_key),
        "runtime_flat_runtime_json_valid_rate": runtime_flat_runtime.get("json_valid_rate"),
        "runtime_flat_runtime_schema_valid_rate": runtime_flat_runtime.get(
            "schema_valid_before_repair_rate"
        ),
    }

    return {
        "runtime_prompt_vs_training_prompt": prompt_verdict,
        "runtime_summary_structure_vs_training_distribution": (
            "requires inspection of summary_distribution stats and prompt comparison"
        ),
        "raw_outputs_malformed_or_rejected": malformed_or_rejected,
        "candidate_constraints_main_problem": flat_metric_exceeds(
            "targets_outside_visible_candidate_rate", 0.2
        ),
        "truncation_happening": flat_metric_exceeds("possible_truncation_rate", 0.05),
        # A missing rate defaults to 1.0, so "no data" never counts as success.
        "runtime_codepath_succeeds_on_heldout_validation": bool(
            runtime_flat_offline.get(fallback_key, 1.0) < 0.2
        ),
    }
def _fallback_rate_beyond_json_errors(stats: dict[str, Any]) -> float:
    """Return the fallback rate that JSON parse failures do not account for.

    Computed as ``wrapper_would_fallback_rate - (1 - json_valid_rate)``.
    Missing metrics default so that an empty ``stats`` dict yields ``0.0``.
    """
    fallback_rate = stats.get("wrapper_would_fallback_rate", 0.0)
    json_invalid_rate = 1.0 - stats.get("json_valid_rate", 1.0)
    return fallback_rate - json_invalid_rate


def rank_root_causes(
    runtime_flat_runtime: dict[str, Any],
    training_chat_runtime: dict[str, Any],
    runtime_flat_offline: dict[str, Any],
    training_chat_offline: dict[str, Any],
) -> list[dict[str, Any]]:
    """Score candidate explanations for runtime failures and rank them.

    Each argument is an aggregate dict for one (prompt style, row source)
    pair; missing metrics are treated as contributing no evidence.

    Scores:
        * prompt mismatch: offline fallback rate of the flat prompt minus that
          of the chat prompt.
        * summary distribution shift: chat-prompt fallback rate on runtime
          summaries minus that on offline summaries.
        * candidate constraint mismatch: the largest of the flat prompt's
          out-of-candidate target rates and repair-emptied rates.
        * validator stricter than generation quality: the largest flat-prompt
          fallback rate not accounted for by JSON parse failures.
        * truncation: the largest flat-prompt possible-truncation rate.

    Returns:
        Cause records (``cause``, ``score``, ``evidence``) sorted by descending
        score; ties keep the order listed above.
    """
    fallback_key = "wrapper_would_fallback_rate"
    outside_key = "targets_outside_visible_candidate_rate"
    emptied_key = "repair_emptied_targets_rate"
    truncation_key = "possible_truncation_rate"

    prompt_gap = (
        runtime_flat_offline.get(fallback_key, 0.0)
        - training_chat_offline.get(fallback_key, 0.0)
    )
    runtime_summary_gap = (
        training_chat_runtime.get(fallback_key, 0.0)
        - training_chat_offline.get(fallback_key, 0.0)
    )
    candidate_score = max(
        runtime_flat_runtime.get(outside_key, 0.0),
        runtime_flat_offline.get(outside_key, 0.0),
        runtime_flat_runtime.get(emptied_key, 0.0),
        runtime_flat_offline.get(emptied_key, 0.0),
    )
    # The validator/repair step is "too strict" when outputs fall back even
    # though their JSON parsed. The score must therefore grow with the JSON
    # valid rate; subtracting the valid rate itself would rank this cause
    # highest exactly when all JSON is broken.
    validator_score = max(
        _fallback_rate_beyond_json_errors(runtime_flat_runtime),
        _fallback_rate_beyond_json_errors(runtime_flat_offline),
    )
    truncation_score = max(
        runtime_flat_runtime.get(truncation_key, 0.0),
        runtime_flat_offline.get(truncation_key, 0.0),
    )

    causes: list[dict[str, Any]] = [
        {
            "cause": "prompt_mismatch_between_runtime_and_training_offline_chat_path",
            "score": float(prompt_gap),
            "evidence": {
                "offline_runtime_flat_fallback_rate": runtime_flat_offline.get(fallback_key),
                "offline_training_chat_fallback_rate": training_chat_offline.get(fallback_key),
            },
        },
        {
            "cause": "runtime_summary_distribution_shift",
            "score": float(runtime_summary_gap),
            "evidence": {
                "runtime_training_chat_fallback_rate": training_chat_runtime.get(fallback_key),
                "offline_training_chat_fallback_rate": training_chat_offline.get(fallback_key),
            },
        },
        {
            "cause": "candidate_intersections_or_visible_target_constraint_mismatch",
            "score": float(candidate_score),
            "evidence": {
                "runtime_targets_outside_visible_rate": runtime_flat_runtime.get(outside_key),
                "offline_targets_outside_visible_rate": runtime_flat_offline.get(outside_key),
                "runtime_repair_emptied_targets_rate": runtime_flat_runtime.get(emptied_key),
            },
        },
        {
            "cause": "validator_or_repair_stricter_than_raw_generation_quality",
            "score": float(validator_score),
            "evidence": {
                "runtime_json_valid_rate": runtime_flat_runtime.get("json_valid_rate"),
                "runtime_wrapper_fallback_rate": runtime_flat_runtime.get(fallback_key),
                "offline_json_valid_rate": runtime_flat_offline.get("json_valid_rate"),
                "offline_wrapper_fallback_rate": runtime_flat_offline.get(fallback_key),
            },
        },
        {
            "cause": "generation_truncation",
            "score": float(truncation_score),
            "evidence": {
                "runtime_possible_truncation_rate": runtime_flat_runtime.get(truncation_key),
                "offline_possible_truncation_rate": runtime_flat_offline.get(truncation_key),
            },
        },
    ]
    # list.sort is stable, so equal scores keep their construction order.
    causes.sort(key=lambda item: item["score"], reverse=True)
    return causes
def render_prompt_comparison(
    runtime_rows: list[dict[str, Any]],
    offline_rows: list[dict[str, Any]],
    summary_report: dict[str, Any],
) -> str:
    """Render the Markdown report comparing chat-style and flat prompts.

    The report always contains the title, the key finding and the
    ``key_answers`` entry of ``summary_report`` dumped as JSON. When
    ``offline_rows`` is non-empty its first row contributes the prompt texts
    and a unified diff between the chat prompt and the flat prompt. When
    ``runtime_rows`` is non-empty its first row contributes the
    ``runtime_flat`` and ``training_chat`` payloads dumped as JSON.

    Returns:
        The report as a single newline-terminated string.
    """

    def fenced(heading: str, language: str, body: list[str]) -> list[str]:
        # One report section: heading, blank line, fenced block, blank line.
        return [heading, "", f"```{language}", *body, "```", ""]

    def as_json(payload: Any) -> str:
        return json.dumps(payload, indent=2, sort_keys=True)

    # Only the first row of each list is used as the representative example.
    runtime_example = runtime_rows[0] if runtime_rows else None
    offline_example = offline_rows[0] if offline_rows else None

    report: list[str] = [
        "# Runtime Prompt Diagnosis",
        "",
        "## Key Finding",
        "",
        "Training/offline evaluation uses a chat-style prompt with separate `system` and `user` messages.",
        "This report compares the chat-style prompt path against the older flattened prompt path.",
        "",
    ]
    report += fenced(
        "## Aggregate Answers",
        "json",
        [as_json(summary_report.get("key_answers", {}))],
    )

    if offline_example is not None:
        report += ["## Representative Offline Validation Example", ""]
        prompt_sections = (
            ("### Training System Prompt", "training_system_prompt"),
            ("### Training User Prompt", "training_user_prompt"),
            ("### Runtime Flat Prompt", "runtime_flat_prompt"),
            ("### Training Chat Rendered Prompt", "training_chat_prompt"),
        )
        for heading, key in prompt_sections:
            report += fenced(heading, "text", [offline_example[key]])
        # Diff direction: from the chat prompt to the flat prompt.
        prompt_diff = difflib.unified_diff(
            offline_example["training_chat_prompt"].splitlines(),
            offline_example["runtime_flat_prompt"].splitlines(),
            fromfile="training_chat_prompt",
            tofile="runtime_flat_prompt",
            lineterm="",
        )
        report += fenced("### Prompt Diff", "diff", list(prompt_diff))

    if runtime_example is not None:
        report += ["## Representative Runtime Summary Example", ""]
        output_sections = (
            ("### Runtime Flat Output", "runtime_flat"),
            ("### Training Chat Output On Same Summary", "training_chat"),
        )
        for heading, key in output_sections:
            report += fenced(heading, "json", [as_json(runtime_example[key])])

    return "\n".join(report) + "\n"
def flatten_failure_example(row: dict[str, Any], prompt_style_key: str) -> dict[str, Any]:
    """Flatten one diagnostic row into a single-level failure record.

    Args:
        row: Diagnostic row holding identification fields, both rendered
            prompts, and one nested payload per prompt style.
        prompt_style_key: Key of the nested payload to flatten. The value
            ``"runtime_flat"`` selects ``row["runtime_flat_prompt"]`` as the
            prompt text; any other value selects
            ``row["training_chat_prompt"]``.

    Returns:
        A new dict with the row's identification fields, the selected
        payload's diagnostic fields, and the matching prompt text.

    Raises:
        KeyError: If ``row`` or the selected payload lacks a required key.
    """
    payload = row[prompt_style_key]

    flattened: dict[str, Any] = {
        "source": row["source"],
        "prompt_style": prompt_style_key,
    }
    # Identification fields copied straight from the row.
    for key in ("city_id", "scenario", "district_id", "decision_step"):
        flattened[key] = row[key]
    # Diagnostic fields copied from the payload of the selected prompt style.
    for key in (
        "failure_reasons",
        "json_valid",
        "schema_valid_before_repair",
        "wrapper_would_fallback",
        "repair_report",
        "raw_text",
        "parsed_payload_before_repair",
        "action_after_repair",
        "candidate_diagnostics",
    ):
        flattened[key] = payload[key]

    if prompt_style_key == "runtime_flat":
        prompt_key = "runtime_flat_prompt"
    else:
        prompt_key = "training_chat_prompt"
    flattened["prompt_text"] = row[prompt_key]
    return flattened
def resolve_scenario_specs(dataset: CityFlowDataset, args: argparse.Namespace) -> list[ScenarioSpec]:
    """Expand the CLI city/scenario selection into scenario specs.

    Cities come from ``args.cities`` when it is non-empty, otherwise from
    ``dataset.load_split(args.split)``. For each city the scenarios come from
    ``args.scenarios`` when it is non-empty, otherwise from
    ``dataset.scenarios_for_city(city_id)``.

    Returns:
        One spec per (city, scenario) pair, ordered city-major.
    """
    if args.cities:
        selected_cities = list(args.cities)
    else:
        selected_cities = dataset.load_split(args.split)

    specs: list[ScenarioSpec] = []
    for city_id in selected_cities:
        # Looked up for every city, even when explicit scenarios override it.
        city_scenarios = dataset.scenarios_for_city(city_id)
        scenario_names = list(args.scenarios) if args.scenarios else city_scenarios
        specs.extend(
            dataset.build_scenario_spec(city_id, scenario_name)
            for scenario_name in scenario_names
        )
    return specs
def default_env_config() -> EnvConfig:
    """Return an ``EnvConfig`` populated with this script's fixed settings.

    The observation config uses ``ObservationConfig()`` defaults and the
    reward config selects the ``"wait_queue_throughput"`` variant.
    """
    observation_config = ObservationConfig()
    reward_config = RewardConfig(variant="wait_queue_throughput")
    return EnvConfig(
        simulator_interval=1,
        decision_interval=5,
        min_green_time=10,
        thread_num=1,
        max_episode_seconds=300,
        observation=observation_config,
        reward=reward_config,
    )
def token_length(inference: DistrictLLMInference, text: str) -> int | None:
    """Count the ``input_ids`` the inference tokenizer produces for ``text``.

    Returns:
        The number of entries under ``"input_ids"`` in the tokenizer output,
        or ``None`` when ``inference.tokenizer`` is ``None``.
    """
    tokenize = inference.tokenizer
    if tokenize is None:
        return None

    try:
        encoding = tokenize(text, add_special_tokens=False)
    except TypeError:
        # Retry without the keyword when the first call raises TypeError.
        encoding = tokenize(text)

    input_ids = encoding["input_ids"]
    return int(len(input_ids))
def prompt_near_model_limit(inference: DistrictLLMInference, prompt_token_length: int | None) -> bool | None:
    """Report whether a prompt reaches 90% of the tokenizer's ``model_max_length``.

    Returns:
        ``True`` when ``prompt_token_length`` is at least
        ``int(0.9 * model_max_length)`` and ``False`` when it is below.
        ``None`` when the length or the tokenizer is missing, or when
        ``model_max_length`` is absent, non-positive, or above 1,000,000.
    """
    if prompt_token_length is None:
        return None
    tokenizer = inference.tokenizer
    if tokenizer is None:
        return None

    limit = getattr(tokenizer, "model_max_length", None)
    if limit is None:
        return None
    # Values outside (0, 1_000_000] are treated as "no usable limit".
    if limit <= 0 or limit > 1_000_000:
        return None

    threshold = int(0.9 * limit)
    return bool(prompt_token_length >= threshold)
def average(values: list[int | float | None]) -> float:
    """Return the arithmetic mean of the non-``None`` values, or 0.0 if none remain."""
    present = [float(value) for value in values if value is not None]
    if not present:
        return 0.0
    return float(mean(present))
def median_or_zero(values: list[int | float | None]) -> float:
    """Return the median of the non-``None`` values, or 0.0 if none remain."""
    present = [float(value) for value in values if value is not None]
    if not present:
        return 0.0
    return float(median(present))
def safe_ratio(numerator: int | float, denominator: int | float) -> float:
    """Return ``numerator / denominator`` as a float, or 0.0 for a zero denominator."""
    divisor = float(denominator)
    # The numerator is only converted when a division actually happens.
    return float(numerator) / divisor if divisor != 0.0 else 0.0
def write_json(path: Path, payload: Any) -> None:
    """Write ``payload`` to ``path`` as indented, key-sorted JSON plus a trailing newline.

    Missing parent directories are created first. The file is written as
    UTF-8 and replaces any existing content.
    """
    output_dir = path.parent
    output_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(f"{serialized}\n", encoding="utf-8")
def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines: one key-sorted object per line.

    Missing parent directories are created first. The file is written as
    UTF-8 and replaces any existing content; an empty ``rows`` produces an
    empty file.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(
            f"{json.dumps(row, sort_keys=True)}\n" for row in rows
        )
def write_text(path: Path, payload: str) -> None:
    """Write ``payload`` to ``path`` as UTF-8, creating missing parent directories first."""
    output_dir = path.parent
    output_dir.mkdir(parents=True, exist_ok=True)
    path.write_text(payload, encoding="utf-8")
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()