from __future__ import annotations
import argparse
import json
from collections import Counter
from pathlib import Path
from typing import Any
from district_llm.metrics import aggregate_target_metrics, compute_target_metrics, safe_ratio, target_failure_buckets
from district_llm.repair import RepairConfig, extract_visible_candidate_ids, sanitize_action_payload
from district_llm.schema import DistrictAction
from env.utils import build_topology
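# tqdm is optional: without it, evaluation runs with no progress bar.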
try:
from tqdm.auto import tqdm
except ImportError: # pragma: no cover
tqdm = None
def parse_args() -> argparse.Namespace:
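    """Parse command-line flags for model paths, decoding, repair, and reporting."""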
parser = argparse.ArgumentParser(
description="Offline evaluation for district-LLM outputs."
)
parser.add_argument("--model-path", required=True)
parser.add_argument("--val-jsonl", required=True)
parser.add_argument("--max-examples", type=int, default=200)
parser.add_argument("--debug-examples", type=int, default=10)
parser.add_argument("--max-new-tokens", type=int, default=128)
parser.add_argument("--device", default=None)
parser.add_argument("--generated-root", default="data/generated")
parser.add_argument("--restrict-targets-to-visible-summary", action="store_true")
parser.add_argument(
"--allow-only-visible-candidates",
action=argparse.BooleanOptionalAction,
default=True,
)
parser.add_argument("--max-target-intersections", type=int, default=3)
parser.add_argument(
"--fallback-on-empty-targets",
action=argparse.BooleanOptionalAction,
default=True,
)
parser.add_argument(
"--fallback-mode",
choices=("heuristic", "hold", "none"),
default="heuristic",
)
parser.add_argument(
"--report-before-after-repair",
action=argparse.BooleanOptionalAction,
default=True,
)
return parser.parse_args()
def load_rows(path: str | Path, max_examples: int | None = None) -> list[dict[str, Any]]:
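    """Read up to ``max_examples`` JSON objects from a JSONL file, skipping blank lines."""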
rows = []
with Path(path).open("r", encoding="utf-8") as handle:
for line in handle:
if not line.strip():
continue
rows.append(json.loads(line))
if max_examples is not None and len(rows) >= max_examples:
break
return rows
def extract_json_object(payload: str) -> str:
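    """Return the substring spanning the first '{' through the last '}' of ``payload``."""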
start = payload.find("{")
end = payload.rfind("}")
if start == -1 or end == -1 or end <= start:
raise ValueError("No JSON object found.")
return payload[start : end + 1]
def load_model_and_tokenizer(model_path: str, device: str | None = None):
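    """Load a causal LM (full model or LoRA adapter) and its tokenizer in eval mode."""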
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_dir = Path(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
tokenizer.pad_token = tokenizer.eos_token
    target_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if (model_dir / "adapter_config.json").exists():
        try:
            from peft import AutoPeftModelForCausalLM
        except ImportError as exc:
            raise ImportError(
                "Evaluating a LoRA adapter requires the 'peft' package."
            ) from exc
        # Adapter-wrapped models are moved to the target device as well, so
        # --device is honoured on both branches.
        model = AutoPeftModelForCausalLM.from_pretrained(model_path).to(target_device)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path).to(target_device)
model.eval()
return model, tokenizer
def build_generation_prompt(tokenizer, messages: list[dict[str, str]]) -> str:
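    """Render chat messages via the tokenizer's chat template, with a plain-text fallback."""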
if getattr(tokenizer, "chat_template", None):
return tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
    return (
        "\n".join(f"{message['role']}: {message['content']}" for message in messages)
        + "\nassistant:"
    )
def generate_response(model, tokenizer, messages: list[dict[str, str]], max_new_tokens: int) -> str:
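    """Greedily decode a completion for ``messages`` and return only the newly generated text."""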
import torch
prompt = build_generation_prompt(tokenizer, messages)
device = getattr(model, "device", None)
inputs = tokenizer(prompt, return_tensors="pt")
if device is not None:
inputs = {key: value.to(device) for key, value in inputs.items()}
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
            # Explicit None check: "or" would wrongly skip a legitimate pad token id of 0.
            pad_token_id=(
                tokenizer.pad_token_id
                if tokenizer.pad_token_id is not None
                else tokenizer.eos_token_id
            ),
)
generated = outputs[0][inputs["input_ids"].shape[1] :]
return tokenizer.decode(generated, skip_special_tokens=True)
def parse_prediction(payload: str) -> tuple[bool, bool, dict[str, Any] | None]:
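    """Parse a raw completion into ``(json_valid, schema_valid, payload)``.

    ``payload`` is the schema-normalised dict when validation succeeds, the raw
    JSON object when only parsing succeeds, and ``None`` when parsing fails.
    """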
try:
json_payload = json.loads(extract_json_object(payload))
except Exception:
return False, False, None
try:
action = DistrictAction.from_dict(json_payload)
except Exception:
return True, False, json_payload
return True, True, action.to_dict()
class DistrictTopologyIndex:
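    """Lazily loads and caches each district's intersection ids, keyed by city."""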
def __init__(self, generated_root: str | Path):
self.generated_root = Path(generated_root)
self._cache: dict[str, dict[str, set[str]]] = {}
def district_intersections(self, city_id: str, district_id: str) -> set[str]:
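        """Return the set of intersection ids belonging to ``district_id`` in ``city_id``."""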
if city_id not in self._cache:
roadnet_path = self.generated_root / city_id / "roadnet.json"
district_map_path = self.generated_root / city_id / "district_map.json"
metadata_path = self.generated_root / city_id / "metadata.json"
_, districts = build_topology(
roadnet_path=roadnet_path,
district_map_path=district_map_path,
metadata_path=metadata_path,
)
self._cache[city_id] = {
key: set(value.intersection_ids)
for key, value in districts.items()
}
return self._cache[city_id].get(district_id, set())
def field_accuracy(pred: dict[str, Any] | None, gt: dict[str, Any], field: str) -> float:
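    """Return 1.0 when ``field`` matches the ground truth exactly, else 0.0."""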
if pred is None:
return 0.0
return float(pred.get(field) == gt.get(field))
def invalid_target_fraction(pred_targets: list[str], district_candidates: set[str]) -> float:
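    """Fraction of predicted targets that are not valid intersections of the district."""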
if not pred_targets:
return 0.0
invalid_count = sum(1 for item in pred_targets if item not in district_candidates)
return safe_ratio(invalid_count, len(pred_targets))
def evaluate_rows(
rows: list[dict[str, Any]],
model,
tokenizer,
max_new_tokens: int,
topology_index: DistrictTopologyIndex,
restrict_targets_to_visible_summary: bool,
debug_examples: int,
repair_config: RepairConfig,
report_before_after_repair: bool,
) -> dict[str, Any]:
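    """Generate, repair, and score one action per row; return aggregate metrics."""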
json_valid_count = 0
schema_valid_count = 0
field_totals_before = Counter()
field_totals_after = Counter()
full_object_correct_before = 0
full_object_correct_after = 0
target_rows_before: list[dict[str, float]] = []
target_rows_after: list[dict[str, float]] = []
restricted_target_rows_before: list[dict[str, float]] = []
restricted_target_rows_after: list[dict[str, float]] = []
invalid_rates_before: list[float] = []
invalid_rates_after: list[float] = []
fallback_used_count = 0
failure_buckets = Counter()
debug_rows = []
progress = (
tqdm(total=len(rows), desc="eval", dynamic_ncols=True)
if tqdm is not None
else None
)
try:
for row in rows:
messages = row["messages"]
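            # messages = [system, user, assistant]; the assistant turn holds the gold action JSON.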
ground_truth = json.loads(messages[2]["content"])
raw_prediction = generate_response(
model=model,
tokenizer=tokenizer,
messages=messages[:2],
max_new_tokens=max_new_tokens,
)
json_valid, schema_valid, prediction_before = parse_prediction(raw_prediction)
repaired_action, repair_report = sanitize_action_payload(
payload=prediction_before if json_valid else None,
summary=row,
prompt_text=messages[1]["content"],
config=repair_config,
)
prediction_after = repaired_action.to_dict()
json_valid_count += int(json_valid)
schema_valid_count += int(schema_valid)
fallback_used_count += int(repair_report.fallback_used)
field_totals_before["strategy"] += field_accuracy(prediction_before, ground_truth, "strategy")
field_totals_before["priority_corridor"] += field_accuracy(prediction_before, ground_truth, "priority_corridor")
field_totals_before["phase_bias"] += field_accuracy(prediction_before, ground_truth, "phase_bias")
field_totals_before["duration_steps"] += field_accuracy(prediction_before, ground_truth, "duration_steps")
field_totals_after["strategy"] += field_accuracy(prediction_after, ground_truth, "strategy")
field_totals_after["priority_corridor"] += field_accuracy(prediction_after, ground_truth, "priority_corridor")
field_totals_after["phase_bias"] += field_accuracy(prediction_after, ground_truth, "phase_bias")
field_totals_after["duration_steps"] += field_accuracy(prediction_after, ground_truth, "duration_steps")
if prediction_before == ground_truth:
full_object_correct_before += 1
if prediction_after == ground_truth:
full_object_correct_after += 1
pred_targets_before = [] if prediction_before is None else list(prediction_before.get("target_intersections", []))
pred_targets_after = list(prediction_after.get("target_intersections", []))
gt_targets = list(ground_truth.get("target_intersections", []))
visible_candidates = set(
extract_visible_candidate_ids(summary=row, prompt_text=messages[1]["content"])
)
district_candidates = topology_index.district_intersections(
city_id=row["city_id"],
district_id=row["district_id"],
)
invalid_before = [item for item in pred_targets_before if item not in district_candidates]
invalid_after = [item for item in pred_targets_after if item not in district_candidates]
non_visible_before = [
item for item in pred_targets_before
if visible_candidates and item not in visible_candidates
]
metrics_before = compute_target_metrics(pred_targets_before, gt_targets)
metrics_after = compute_target_metrics(pred_targets_after, gt_targets)
target_rows_before.append(metrics_before)
target_rows_after.append(metrics_after)
invalid_rates_before.append(invalid_target_fraction(pred_targets_before, district_candidates))
invalid_rates_after.append(invalid_target_fraction(pred_targets_after, district_candidates))
if restrict_targets_to_visible_summary:
filtered_pred_before = [item for item in pred_targets_before if item in visible_candidates]
filtered_pred_after = [item for item in pred_targets_after if item in visible_candidates]
filtered_gt = [item for item in gt_targets if item in visible_candidates]
restricted_target_rows_before.append(
compute_target_metrics(filtered_pred_before, filtered_gt)
)
restricted_target_rows_after.append(
compute_target_metrics(filtered_pred_after, filtered_gt)
)
            # Compute the per-row failure buckets once; they are reused in the debug dump below.
            row_failure_buckets = sorted(
                set(
                    target_failure_buckets(
                        pred_list=pred_targets_before,
                        gt_list=gt_targets,
                        visible_candidates=visible_candidates,
                        invalid_ids=invalid_before,
                        non_visible_ids=non_visible_before,
                        repaired_targets=pred_targets_after,
                        fallback_used=repair_report.fallback_used,
                    )
                )
            )
            for failure_bucket in row_failure_buckets:
                failure_buckets[failure_bucket] += 1
if len(debug_rows) < debug_examples:
debug_rows.append(
{
"district_summary": messages[1]["content"],
"predicted_json_raw": raw_prediction,
"predicted_json_parsed_before_repair": prediction_before,
"predicted_json_parsed_after_repair": prediction_after,
"ground_truth_json": ground_truth,
"target_intersections_metrics_before_repair": metrics_before,
"target_intersections_metrics_after_repair": metrics_after,
"repair_report": repair_report.to_dict(),
"visible_candidate_ids": sorted(visible_candidates),
"failure_buckets": sorted(
set(
target_failure_buckets(
pred_list=pred_targets_before,
gt_list=gt_targets,
visible_candidates=visible_candidates,
invalid_ids=invalid_before,
non_visible_ids=non_visible_before,
repaired_targets=pred_targets_after,
fallback_used=repair_report.fallback_used,
)
)
),
}
)
if progress is not None:
progress.update(1)
finally:
if progress is not None:
progress.close()
    total_rows = max(1, len(rows))
    aggregated_targets_before = aggregate_target_metrics(target_rows_before)
    aggregated_targets_after = aggregate_target_metrics(target_rows_after)
    results = {
"num_examples": len(rows),
"json_validity_rate": float(json_valid_count) / total_rows,
"schema_validity_rate": float(schema_valid_count) / total_rows,
"field_accuracy": {
"strategy": float(field_totals_before["strategy"]) / total_rows,
"priority_corridor": float(field_totals_before["priority_corridor"]) / total_rows,
"phase_bias": float(field_totals_before["phase_bias"]) / total_rows,
"duration_steps": float(field_totals_before["duration_steps"]) / total_rows,
},
"field_accuracy_after_repair": {
"strategy": float(field_totals_after["strategy"]) / total_rows,
"priority_corridor": float(field_totals_after["priority_corridor"]) / total_rows,
"phase_bias": float(field_totals_after["phase_bias"]) / total_rows,
"duration_steps": float(field_totals_after["duration_steps"]) / total_rows,
},
"target_intersections_before_repair": aggregate_target_metrics(target_rows_before),
"target_intersections_after_repair": aggregate_target_metrics(target_rows_after),
"target_intersections": aggregate_target_metrics(target_rows_after),
"target_intersections_failure_buckets": dict(sorted(failure_buckets.items())),
"exact_full_object_accuracy": float(full_object_correct_before) / total_rows,
"exact_full_object_accuracy_after_repair": float(full_object_correct_after) / total_rows,
"debug_examples": debug_rows,
}
    if restrict_targets_to_visible_summary:
        restricted_before = aggregate_target_metrics(restricted_target_rows_before)
        restricted_after = aggregate_target_metrics(restricted_target_rows_after)
        results["target_intersections_restricted_to_visible_summary_before_repair"] = restricted_before
        results["target_intersections_restricted_to_visible_summary_after_repair"] = restricted_after
        # Unsuffixed key mirrors the post-repair restricted metrics.
        results["target_intersections_restricted_to_visible_summary"] = restricted_after
if report_before_after_repair:
results["target_intersections_before_after_repair"] = {
"invalid_id_rate_before_repair": float(sum(invalid_rates_before) / total_rows),
"invalid_id_rate_after_repair": float(sum(invalid_rates_after) / total_rows),
"exact_set_match_before_repair": aggregate_target_metrics(target_rows_before).get("exact_set_match", 0.0),
"exact_set_match_after_repair": aggregate_target_metrics(target_rows_after).get("exact_set_match", 0.0),
"jaccard_before_repair": aggregate_target_metrics(target_rows_before).get("jaccard", 0.0),
"jaccard_after_repair": aggregate_target_metrics(target_rows_after).get("jaccard", 0.0),
"fallback_used_rate": float(fallback_used_count) / total_rows,
}
return results
def print_debug_examples(debug_rows: list[dict[str, Any]]) -> None:
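    """Print each stored debug example as labelled, line-oriented records."""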
    # Structured fields are printed as canonical JSON so runs can be diffed.
    json_fields = (
        "predicted_json_parsed_before_repair",
        "predicted_json_parsed_after_repair",
        "ground_truth_json",
        "target_intersections_metrics_before_repair",
        "target_intersections_metrics_after_repair",
        "repair_report",
        "visible_candidate_ids",
        "failure_buckets",
    )
    for index, item in enumerate(debug_rows, start=1):
        print(f"[debug {index}] district_summary:")
        print(item["district_summary"])
        print(f"[debug {index}] predicted_json_raw={item['predicted_json_raw']}")
        for field in json_fields:
            print(f"[debug {index}] {field}={json.dumps(item[field], sort_keys=True)}")
def main() -> None:
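    """Entry point: load data and model, run evaluation, print metrics then debug rows."""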
args = parse_args()
rows = load_rows(args.val_jsonl, max_examples=args.max_examples)
model, tokenizer = load_model_and_tokenizer(args.model_path, device=args.device)
topology_index = DistrictTopologyIndex(args.generated_root)
results = evaluate_rows(
rows=rows,
model=model,
tokenizer=tokenizer,
max_new_tokens=args.max_new_tokens,
topology_index=topology_index,
restrict_targets_to_visible_summary=args.restrict_targets_to_visible_summary,
debug_examples=args.debug_examples,
repair_config=RepairConfig(
allow_only_visible_candidates=args.allow_only_visible_candidates,
max_target_intersections=args.max_target_intersections,
fallback_on_empty_targets=args.fallback_on_empty_targets,
fallback_mode=args.fallback_mode,
),
report_before_after_repair=args.report_before_after_repair,
)
print(json.dumps({k: v for k, v in results.items() if k != "debug_examples"}, indent=2, sort_keys=True))
print_debug_examples(results["debug_examples"])
if __name__ == "__main__":
main()