from __future__ import annotations import argparse import json from collections import Counter from pathlib import Path from typing import Any from district_llm.metrics import aggregate_target_metrics, compute_target_metrics, safe_ratio, target_failure_buckets from district_llm.repair import RepairConfig, extract_visible_candidate_ids, sanitize_action_payload from district_llm.schema import DistrictAction from env.utils import build_topology try: from tqdm.auto import tqdm except ImportError: # pragma: no cover tqdm = None def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Offline evaluation for district-LLM outputs." ) parser.add_argument("--model-path", required=True) parser.add_argument("--val-jsonl", required=True) parser.add_argument("--max-examples", type=int, default=200) parser.add_argument("--debug-examples", type=int, default=10) parser.add_argument("--max-new-tokens", type=int, default=128) parser.add_argument("--device", default=None) parser.add_argument("--generated-root", default="data/generated") parser.add_argument("--restrict-targets-to-visible-summary", action="store_true") parser.add_argument( "--allow-only-visible-candidates", action=argparse.BooleanOptionalAction, default=True, ) parser.add_argument("--max-target-intersections", type=int, default=3) parser.add_argument( "--fallback-on-empty-targets", action=argparse.BooleanOptionalAction, default=True, ) parser.add_argument( "--fallback-mode", choices=("heuristic", "hold", "none"), default="heuristic", ) parser.add_argument( "--report-before-after-repair", action=argparse.BooleanOptionalAction, default=True, ) return parser.parse_args() def load_rows(path: str | Path, max_examples: int | None = None) -> list[dict[str, Any]]: rows = [] with Path(path).open("r", encoding="utf-8") as handle: for line in handle: if not line.strip(): continue rows.append(json.loads(line)) if max_examples is not None and len(rows) >= max_examples: break return rows def extract_json_object(payload: str) -> str: start = payload.find("{") end = payload.rfind("}") if start == -1 or end == -1 or end <= start: raise ValueError("No JSON object found.") return payload[start : end + 1] def load_model_and_tokenizer(model_path: str, device: str | None = None): import torch from transformers import AutoModelForCausalLM, AutoTokenizer model_dir = Path(model_path) tokenizer = AutoTokenizer.from_pretrained(model_path) if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: tokenizer.pad_token = tokenizer.eos_token if (model_dir / "adapter_config.json").exists(): try: from peft import AutoPeftModelForCausalLM except ImportError as exc: raise ImportError( "Evaluating a LoRA adapter requires the 'peft' package." ) from exc model = AutoPeftModelForCausalLM.from_pretrained(model_path) else: target_device = device or ("cuda" if torch.cuda.is_available() else "cpu") model = AutoModelForCausalLM.from_pretrained(model_path).to(target_device) model.eval() return model, tokenizer def build_generation_prompt(tokenizer, messages: list[dict[str, str]]) -> str: if getattr(tokenizer, "chat_template", None): return tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) return "\n".join(f"{message['role']}: {message['content']}" for message in messages) + "\nassistant:" def generate_response(model, tokenizer, messages: list[dict[str, str]], max_new_tokens: int) -> str: import torch prompt = build_generation_prompt(tokenizer, messages) device = getattr(model, "device", None) inputs = tokenizer(prompt, return_tensors="pt") if device is not None: inputs = {key: value.to(device) for key, value in inputs.items()} with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id, ) generated = outputs[0][inputs["input_ids"].shape[1] :] return tokenizer.decode(generated, skip_special_tokens=True) def parse_prediction(payload: str) -> tuple[bool, bool, dict[str, Any] | None]: try: json_payload = json.loads(extract_json_object(payload)) except Exception: return False, False, None try: action = DistrictAction.from_dict(json_payload) except Exception: return True, False, json_payload return True, True, action.to_dict() class DistrictTopologyIndex: def __init__(self, generated_root: str | Path): self.generated_root = Path(generated_root) self._cache: dict[str, dict[str, set[str]]] = {} def district_intersections(self, city_id: str, district_id: str) -> set[str]: if city_id not in self._cache: roadnet_path = self.generated_root / city_id / "roadnet.json" district_map_path = self.generated_root / city_id / "district_map.json" metadata_path = self.generated_root / city_id / "metadata.json" _, districts = build_topology( roadnet_path=roadnet_path, district_map_path=district_map_path, metadata_path=metadata_path, ) self._cache[city_id] = { key: set(value.intersection_ids) for key, value in districts.items() } return self._cache[city_id].get(district_id, set()) def field_accuracy(pred: dict[str, Any] | None, gt: dict[str, Any], field: str) -> float: if pred is None: return 0.0 return float(pred.get(field) == gt.get(field)) def invalid_target_fraction(pred_targets: list[str], district_candidates: set[str]) -> float: if not pred_targets: return 0.0 invalid_count = sum(1 for item in pred_targets if item not in district_candidates) return safe_ratio(invalid_count, len(pred_targets)) def evaluate_rows( rows: list[dict[str, Any]], model, tokenizer, max_new_tokens: int, topology_index: DistrictTopologyIndex, restrict_targets_to_visible_summary: bool, debug_examples: int, repair_config: RepairConfig, report_before_after_repair: bool, ) -> dict[str, Any]: json_valid_count = 0 schema_valid_count = 0 field_totals_before = Counter() field_totals_after = Counter() full_object_correct_before = 0 full_object_correct_after = 0 target_rows_before: list[dict[str, float]] = [] target_rows_after: list[dict[str, float]] = [] restricted_target_rows_before: list[dict[str, float]] = [] restricted_target_rows_after: list[dict[str, float]] = [] invalid_rates_before: list[float] = [] invalid_rates_after: list[float] = [] fallback_used_count = 0 failure_buckets = Counter() debug_rows = [] progress = ( tqdm(total=len(rows), desc="eval", dynamic_ncols=True) if tqdm is not None else None ) try: for row in rows: messages = row["messages"] ground_truth = json.loads(messages[2]["content"]) raw_prediction = generate_response( model=model, tokenizer=tokenizer, messages=messages[:2], max_new_tokens=max_new_tokens, ) json_valid, schema_valid, prediction_before = parse_prediction(raw_prediction) repaired_action, repair_report = sanitize_action_payload( payload=prediction_before if json_valid else None, summary=row, prompt_text=messages[1]["content"], config=repair_config, ) prediction_after = repaired_action.to_dict() json_valid_count += int(json_valid) schema_valid_count += int(schema_valid) fallback_used_count += int(repair_report.fallback_used) field_totals_before["strategy"] += field_accuracy(prediction_before, ground_truth, "strategy") field_totals_before["priority_corridor"] += field_accuracy(prediction_before, ground_truth, "priority_corridor") field_totals_before["phase_bias"] += field_accuracy(prediction_before, ground_truth, "phase_bias") field_totals_before["duration_steps"] += field_accuracy(prediction_before, ground_truth, "duration_steps") field_totals_after["strategy"] += field_accuracy(prediction_after, ground_truth, "strategy") field_totals_after["priority_corridor"] += field_accuracy(prediction_after, ground_truth, "priority_corridor") field_totals_after["phase_bias"] += field_accuracy(prediction_after, ground_truth, "phase_bias") field_totals_after["duration_steps"] += field_accuracy(prediction_after, ground_truth, "duration_steps") if prediction_before == ground_truth: full_object_correct_before += 1 if prediction_after == ground_truth: full_object_correct_after += 1 pred_targets_before = [] if prediction_before is None else list(prediction_before.get("target_intersections", [])) pred_targets_after = list(prediction_after.get("target_intersections", [])) gt_targets = list(ground_truth.get("target_intersections", [])) visible_candidates = set( extract_visible_candidate_ids(summary=row, prompt_text=messages[1]["content"]) ) district_candidates = topology_index.district_intersections( city_id=row["city_id"], district_id=row["district_id"], ) invalid_before = [item for item in pred_targets_before if item not in district_candidates] invalid_after = [item for item in pred_targets_after if item not in district_candidates] non_visible_before = [ item for item in pred_targets_before if visible_candidates and item not in visible_candidates ] metrics_before = compute_target_metrics(pred_targets_before, gt_targets) metrics_after = compute_target_metrics(pred_targets_after, gt_targets) target_rows_before.append(metrics_before) target_rows_after.append(metrics_after) invalid_rates_before.append(invalid_target_fraction(pred_targets_before, district_candidates)) invalid_rates_after.append(invalid_target_fraction(pred_targets_after, district_candidates)) if restrict_targets_to_visible_summary: filtered_pred_before = [item for item in pred_targets_before if item in visible_candidates] filtered_pred_after = [item for item in pred_targets_after if item in visible_candidates] filtered_gt = [item for item in gt_targets if item in visible_candidates] restricted_target_rows_before.append( compute_target_metrics(filtered_pred_before, filtered_gt) ) restricted_target_rows_after.append( compute_target_metrics(filtered_pred_after, filtered_gt) ) for failure_bucket in set( target_failure_buckets( pred_list=pred_targets_before, gt_list=gt_targets, visible_candidates=visible_candidates, invalid_ids=invalid_before, non_visible_ids=non_visible_before, repaired_targets=pred_targets_after, fallback_used=repair_report.fallback_used, ) ): failure_buckets[failure_bucket] += 1 if len(debug_rows) < debug_examples: debug_rows.append( { "district_summary": messages[1]["content"], "predicted_json_raw": raw_prediction, "predicted_json_parsed_before_repair": prediction_before, "predicted_json_parsed_after_repair": prediction_after, "ground_truth_json": ground_truth, "target_intersections_metrics_before_repair": metrics_before, "target_intersections_metrics_after_repair": metrics_after, "repair_report": repair_report.to_dict(), "visible_candidate_ids": sorted(visible_candidates), "failure_buckets": sorted( set( target_failure_buckets( pred_list=pred_targets_before, gt_list=gt_targets, visible_candidates=visible_candidates, invalid_ids=invalid_before, non_visible_ids=non_visible_before, repaired_targets=pred_targets_after, fallback_used=repair_report.fallback_used, ) ) ), } ) if progress is not None: progress.update(1) finally: if progress is not None: progress.close() total_rows = max(1, len(rows)) results = { "num_examples": len(rows), "json_validity_rate": float(json_valid_count) / total_rows, "schema_validity_rate": float(schema_valid_count) / total_rows, "field_accuracy": { "strategy": float(field_totals_before["strategy"]) / total_rows, "priority_corridor": float(field_totals_before["priority_corridor"]) / total_rows, "phase_bias": float(field_totals_before["phase_bias"]) / total_rows, "duration_steps": float(field_totals_before["duration_steps"]) / total_rows, }, "field_accuracy_after_repair": { "strategy": float(field_totals_after["strategy"]) / total_rows, "priority_corridor": float(field_totals_after["priority_corridor"]) / total_rows, "phase_bias": float(field_totals_after["phase_bias"]) / total_rows, "duration_steps": float(field_totals_after["duration_steps"]) / total_rows, }, "target_intersections_before_repair": aggregate_target_metrics(target_rows_before), "target_intersections_after_repair": aggregate_target_metrics(target_rows_after), "target_intersections": aggregate_target_metrics(target_rows_after), "target_intersections_failure_buckets": dict(sorted(failure_buckets.items())), "exact_full_object_accuracy": float(full_object_correct_before) / total_rows, "exact_full_object_accuracy_after_repair": float(full_object_correct_after) / total_rows, "debug_examples": debug_rows, } if restrict_targets_to_visible_summary: results["target_intersections_restricted_to_visible_summary_before_repair"] = aggregate_target_metrics( restricted_target_rows_before ) results["target_intersections_restricted_to_visible_summary_after_repair"] = aggregate_target_metrics( restricted_target_rows_after ) results["target_intersections_restricted_to_visible_summary"] = aggregate_target_metrics( restricted_target_rows_after ) if report_before_after_repair: results["target_intersections_before_after_repair"] = { "invalid_id_rate_before_repair": float(sum(invalid_rates_before) / total_rows), "invalid_id_rate_after_repair": float(sum(invalid_rates_after) / total_rows), "exact_set_match_before_repair": aggregate_target_metrics(target_rows_before).get("exact_set_match", 0.0), "exact_set_match_after_repair": aggregate_target_metrics(target_rows_after).get("exact_set_match", 0.0), "jaccard_before_repair": aggregate_target_metrics(target_rows_before).get("jaccard", 0.0), "jaccard_after_repair": aggregate_target_metrics(target_rows_after).get("jaccard", 0.0), "fallback_used_rate": float(fallback_used_count) / total_rows, } return results def print_debug_examples(debug_rows: list[dict[str, Any]]) -> None: for index, item in enumerate(debug_rows, start=1): print(f"[debug {index}] district_summary:") print(item["district_summary"]) print(f"[debug {index}] predicted_json_raw={item['predicted_json_raw']}") print( f"[debug {index}] predicted_json_parsed_before_repair=" f"{json.dumps(item['predicted_json_parsed_before_repair'], sort_keys=True)}" ) print( f"[debug {index}] predicted_json_parsed_after_repair=" f"{json.dumps(item['predicted_json_parsed_after_repair'], sort_keys=True)}" ) print( f"[debug {index}] ground_truth_json=" f"{json.dumps(item['ground_truth_json'], sort_keys=True)}" ) print( f"[debug {index}] target_intersections_metrics_before_repair=" f"{json.dumps(item['target_intersections_metrics_before_repair'], sort_keys=True)}" ) print( f"[debug {index}] target_intersections_metrics_after_repair=" f"{json.dumps(item['target_intersections_metrics_after_repair'], sort_keys=True)}" ) print( f"[debug {index}] repair_report=" f"{json.dumps(item['repair_report'], sort_keys=True)}" ) print( f"[debug {index}] visible_candidate_ids=" f"{json.dumps(item['visible_candidate_ids'], sort_keys=True)}" ) print(f"[debug {index}] failure_buckets={json.dumps(item['failure_buckets'])}") def main() -> None: args = parse_args() rows = load_rows(args.val_jsonl, max_examples=args.max_examples) model, tokenizer = load_model_and_tokenizer(args.model_path, device=args.device) topology_index = DistrictTopologyIndex(args.generated_root) results = evaluate_rows( rows=rows, model=model, tokenizer=tokenizer, max_new_tokens=args.max_new_tokens, topology_index=topology_index, restrict_targets_to_visible_summary=args.restrict_targets_to_visible_summary, debug_examples=args.debug_examples, repair_config=RepairConfig( allow_only_visible_candidates=args.allow_only_visible_candidates, max_target_intersections=args.max_target_intersections, fallback_on_empty_targets=args.fallback_on_empty_targets, fallback_mode=args.fallback_mode, ), report_before_after_repair=args.report_before_after_repair, ) print(json.dumps({k: v for k, v in results.items() if k != "debug_examples"}, indent=2, sort_keys=True)) print_debug_examples(results["debug_examples"]) if __name__ == "__main__": main()