Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from collections import Counter | |
| from pathlib import Path | |
| from typing import Any | |
| from district_llm.metrics import aggregate_target_metrics, compute_target_metrics, safe_ratio, target_failure_buckets | |
| from district_llm.repair import RepairConfig, extract_visible_candidate_ids, sanitize_action_payload | |
| from district_llm.schema import DistrictAction | |
| from env.utils import build_topology | |
| try: | |
| from tqdm.auto import tqdm | |
| except ImportError: # pragma: no cover | |
| tqdm = None | |
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the offline district-LLM evaluation."""
    parser = argparse.ArgumentParser(
        description="Offline evaluation for district-LLM outputs."
    )
    parser.add_argument("--model-path", required=True)
    parser.add_argument("--val-jsonl", required=True)
    parser.add_argument("--max-examples", type=int, default=200)
    parser.add_argument("--debug-examples", type=int, default=10)
    parser.add_argument("--max-new-tokens", type=int, default=128)
    parser.add_argument("--device", default=None)
    parser.add_argument("--generated-root", default="data/generated")
    parser.add_argument("--restrict-targets-to-visible-summary", action="store_true")
    # BooleanOptionalAction gives paired --flag / --no-flag spellings.
    parser.add_argument(
        "--allow-only-visible-candidates",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument("--max-target-intersections", type=int, default=3)
    parser.add_argument(
        "--fallback-on-empty-targets",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    parser.add_argument(
        "--fallback-mode",
        choices=("heuristic", "hold", "none"),
        default="heuristic",
    )
    parser.add_argument(
        "--report-before-after-repair",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    return parser.parse_args()
| def load_rows(path: str | Path, max_examples: int | None = None) -> list[dict[str, Any]]: | |
| rows = [] | |
| with Path(path).open("r", encoding="utf-8") as handle: | |
| for line in handle: | |
| if not line.strip(): | |
| continue | |
| rows.append(json.loads(line)) | |
| if max_examples is not None and len(rows) >= max_examples: | |
| break | |
| return rows | |
def extract_json_object(payload: str) -> str:
    """Return the substring spanning the first '{' through the last '}'.

    Raises ValueError when no well-ordered brace pair is present.
    """
    opening = payload.find("{")
    closing = payload.rfind("}")
    if opening == -1 or closing == -1 or closing <= opening:
        raise ValueError("No JSON object found.")
    return payload[opening : closing + 1]
def load_model_and_tokenizer(model_path: str, device: str | None = None):
    """Load a causal LM (plain checkpoint or PEFT/LoRA adapter) and tokenizer.

    Args:
        model_path: Directory containing either a full model or a PEFT adapter
            (detected via the presence of ``adapter_config.json``).
        device: Explicit device string; defaults to CUDA when available.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode on the target device.

    Raises:
        ImportError: when an adapter is found but 'peft' is not installed.

    Bug fix: previously the *device* argument was silently ignored in the
    adapter branch (the model was never moved); both branches now place the
    model on the resolved target device.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_dir = Path(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Some tokenizers ship without a pad token; reuse EOS so generation works.
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token
    target_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    if (model_dir / "adapter_config.json").exists():
        try:
            from peft import AutoPeftModelForCausalLM
        except ImportError as exc:
            raise ImportError(
                "Evaluating a LoRA adapter requires the 'peft' package."
            ) from exc
        model = AutoPeftModelForCausalLM.from_pretrained(model_path).to(target_device)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path).to(target_device)
    model.eval()
    return model, tokenizer
def build_generation_prompt(tokenizer, messages: list[dict[str, str]]) -> str:
    """Render chat *messages* into a single generation prompt.

    Prefers the tokenizer's chat template when one is configured; otherwise
    falls back to a plain "role: content" transcript ending with an
    assistant cue.
    """
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    transcript = [f"{turn['role']}: {turn['content']}" for turn in messages]
    return "\n".join(transcript) + "\nassistant:"
def generate_response(model, tokenizer, messages: list[dict[str, str]], max_new_tokens: int) -> str:
    """Greedily generate a completion for *messages*; return only the new text.

    Args:
        model: A HF-style model exposing ``generate`` and optionally ``device``.
        tokenizer: Tokenizer used for both encoding and decoding.
        messages: Chat turns forwarded to ``build_generation_prompt``.
        max_new_tokens: Generation budget.

    Bug fix: ``tokenizer.pad_token_id or tokenizer.eos_token_id`` treated a
    perfectly valid pad id of 0 as missing (0 is falsy); an explicit ``None``
    check is used instead.
    """
    import torch

    prompt = build_generation_prompt(tokenizer, messages)
    device = getattr(model, "device", None)
    inputs = tokenizer(prompt, return_tensors="pt")
    if device is not None:
        inputs = {key: value.to(device) for key, value in inputs.items()}
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # deterministic/greedy decoding for evaluation
            pad_token_id=pad_token_id,
        )
    # Slice off the prompt tokens; decode only the generated suffix.
    generated = outputs[0][inputs["input_ids"].shape[1] :]
    return tokenizer.decode(generated, skip_special_tokens=True)
def parse_prediction(payload: str) -> tuple[bool, bool, dict[str, Any] | None]:
    """Parse raw model output into ``(json_valid, schema_valid, parsed)``.

    ``parsed`` is ``None`` when no JSON object could be extracted, the raw
    dict when JSON parses but schema validation fails, and the normalized
    action dict when both succeed.
    """
    try:
        candidate = json.loads(extract_json_object(payload))
    except Exception:
        # Not even parseable JSON: neither validity flag holds.
        return False, False, None
    try:
        action = DistrictAction.from_dict(candidate)
    except Exception:
        # Valid JSON but fails the action schema; expose the raw dict.
        return True, False, candidate
    return True, True, action.to_dict()
class DistrictTopologyIndex:
    """Lazy, per-city cache mapping district ids to intersection-id sets."""

    def __init__(self, generated_root: str | Path):
        self.generated_root = Path(generated_root)
        # city_id -> {district_id -> set of intersection ids}
        self._cache: dict[str, dict[str, set[str]]] = {}

    def district_intersections(self, city_id: str, district_id: str) -> set[str]:
        """Return intersection ids for *district_id*, loading the city's
        topology files on first access and caching the result."""
        cached = self._cache.get(city_id)
        if cached is None:
            city_dir = self.generated_root / city_id
            _, districts = build_topology(
                roadnet_path=city_dir / "roadnet.json",
                district_map_path=city_dir / "district_map.json",
                metadata_path=city_dir / "metadata.json",
            )
            cached = {
                name: set(info.intersection_ids)
                for name, info in districts.items()
            }
            self._cache[city_id] = cached
        return cached.get(district_id, set())
| def field_accuracy(pred: dict[str, Any] | None, gt: dict[str, Any], field: str) -> float: | |
| if pred is None: | |
| return 0.0 | |
| return float(pred.get(field) == gt.get(field)) | |
def invalid_target_fraction(pred_targets: list[str], district_candidates: set[str]) -> float:
    """Fraction of predicted target ids not found among the district's
    intersections; 0.0 for an empty prediction list."""
    if not pred_targets:
        return 0.0
    bad = len([item for item in pred_targets if item not in district_candidates])
    return safe_ratio(bad, len(pred_targets))
def evaluate_rows(
    rows: list[dict[str, Any]],
    model,
    tokenizer,
    max_new_tokens: int,
    topology_index: DistrictTopologyIndex,
    restrict_targets_to_visible_summary: bool,
    debug_examples: int,
    repair_config: RepairConfig,
    report_before_after_repair: bool,
) -> dict[str, Any]:
    """Generate a prediction for every row and score it before/after repair.

    Each row is expected to carry a ``messages`` list (system, user, gold
    assistant turn as a JSON string) plus ``city_id`` / ``district_id``.
    Returns a results dict with validity rates, per-field accuracies,
    target-intersection metrics, failure buckets, and captured debug rows.
    """
    json_valid_count = 0
    schema_valid_count = 0
    # Per-field accuracy sums, before and after the repair pass.
    field_totals_before = Counter()
    field_totals_after = Counter()
    full_object_correct_before = 0
    full_object_correct_after = 0
    # Per-example target metric dicts, aggregated at the end.
    target_rows_before: list[dict[str, float]] = []
    target_rows_after: list[dict[str, float]] = []
    restricted_target_rows_before: list[dict[str, float]] = []
    restricted_target_rows_after: list[dict[str, float]] = []
    invalid_rates_before: list[float] = []
    invalid_rates_after: list[float] = []
    fallback_used_count = 0
    failure_buckets = Counter()
    debug_rows = []
    # Optional progress bar; tqdm may be absent (guarded import at module top).
    progress = (
        tqdm(total=len(rows), desc="eval", dynamic_ncols=True)
        if tqdm is not None
        else None
    )
    try:
        for row in rows:
            messages = row["messages"]
            # messages[2] is the gold assistant turn, serialized as JSON.
            ground_truth = json.loads(messages[2]["content"])
            raw_prediction = generate_response(
                model=model,
                tokenizer=tokenizer,
                messages=messages[:2],
                max_new_tokens=max_new_tokens,
            )
            json_valid, schema_valid, prediction_before = parse_prediction(raw_prediction)
            # Repair/sanitize the parsed payload; when JSON parsing failed,
            # pass None so the repair layer can fall back.
            repaired_action, repair_report = sanitize_action_payload(
                payload=prediction_before if json_valid else None,
                summary=row,
                prompt_text=messages[1]["content"],
                config=repair_config,
            )
            prediction_after = repaired_action.to_dict()
            json_valid_count += int(json_valid)
            schema_valid_count += int(schema_valid)
            fallback_used_count += int(repair_report.fallback_used)
            # Scalar-field accuracy, pre-repair.
            field_totals_before["strategy"] += field_accuracy(prediction_before, ground_truth, "strategy")
            field_totals_before["priority_corridor"] += field_accuracy(prediction_before, ground_truth, "priority_corridor")
            field_totals_before["phase_bias"] += field_accuracy(prediction_before, ground_truth, "phase_bias")
            field_totals_before["duration_steps"] += field_accuracy(prediction_before, ground_truth, "duration_steps")
            # Scalar-field accuracy, post-repair.
            field_totals_after["strategy"] += field_accuracy(prediction_after, ground_truth, "strategy")
            field_totals_after["priority_corridor"] += field_accuracy(prediction_after, ground_truth, "priority_corridor")
            field_totals_after["phase_bias"] += field_accuracy(prediction_after, ground_truth, "phase_bias")
            field_totals_after["duration_steps"] += field_accuracy(prediction_after, ground_truth, "duration_steps")
            if prediction_before == ground_truth:
                full_object_correct_before += 1
            if prediction_after == ground_truth:
                full_object_correct_after += 1
            pred_targets_before = [] if prediction_before is None else list(prediction_before.get("target_intersections", []))
            pred_targets_after = list(prediction_after.get("target_intersections", []))
            gt_targets = list(ground_truth.get("target_intersections", []))
            # Candidate ids actually shown in the prompt/summary.
            visible_candidates = set(
                extract_visible_candidate_ids(summary=row, prompt_text=messages[1]["content"])
            )
            # All intersection ids that legitimately belong to this district.
            district_candidates = topology_index.district_intersections(
                city_id=row["city_id"],
                district_id=row["district_id"],
            )
            invalid_before = [item for item in pred_targets_before if item not in district_candidates]
            invalid_after = [item for item in pred_targets_after if item not in district_candidates]
            # Only meaningful when visible_candidates is non-empty.
            non_visible_before = [
                item for item in pred_targets_before
                if visible_candidates and item not in visible_candidates
            ]
            metrics_before = compute_target_metrics(pred_targets_before, gt_targets)
            metrics_after = compute_target_metrics(pred_targets_after, gt_targets)
            target_rows_before.append(metrics_before)
            target_rows_after.append(metrics_after)
            invalid_rates_before.append(invalid_target_fraction(pred_targets_before, district_candidates))
            invalid_rates_after.append(invalid_target_fraction(pred_targets_after, district_candidates))
            if restrict_targets_to_visible_summary:
                # Re-score after dropping ids not surfaced in the summary
                # (from both prediction and ground truth, for a fair match).
                filtered_pred_before = [item for item in pred_targets_before if item in visible_candidates]
                filtered_pred_after = [item for item in pred_targets_after if item in visible_candidates]
                filtered_gt = [item for item in gt_targets if item in visible_candidates]
                restricted_target_rows_before.append(
                    compute_target_metrics(filtered_pred_before, filtered_gt)
                )
                restricted_target_rows_after.append(
                    compute_target_metrics(filtered_pred_after, filtered_gt)
                )
            # set(...) so each bucket is counted at most once per example.
            for failure_bucket in set(
                target_failure_buckets(
                    pred_list=pred_targets_before,
                    gt_list=gt_targets,
                    visible_candidates=visible_candidates,
                    invalid_ids=invalid_before,
                    non_visible_ids=non_visible_before,
                    repaired_targets=pred_targets_after,
                    fallback_used=repair_report.fallback_used,
                )
            ):
                failure_buckets[failure_bucket] += 1
            if len(debug_rows) < debug_examples:
                # Capture the first N examples verbatim for human inspection.
                debug_rows.append(
                    {
                        "district_summary": messages[1]["content"],
                        "predicted_json_raw": raw_prediction,
                        "predicted_json_parsed_before_repair": prediction_before,
                        "predicted_json_parsed_after_repair": prediction_after,
                        "ground_truth_json": ground_truth,
                        "target_intersections_metrics_before_repair": metrics_before,
                        "target_intersections_metrics_after_repair": metrics_after,
                        "repair_report": repair_report.to_dict(),
                        "visible_candidate_ids": sorted(visible_candidates),
                        "failure_buckets": sorted(
                            set(
                                target_failure_buckets(
                                    pred_list=pred_targets_before,
                                    gt_list=gt_targets,
                                    visible_candidates=visible_candidates,
                                    invalid_ids=invalid_before,
                                    non_visible_ids=non_visible_before,
                                    repaired_targets=pred_targets_after,
                                    fallback_used=repair_report.fallback_used,
                                )
                            )
                        ),
                    }
                )
            if progress is not None:
                progress.update(1)
    finally:
        if progress is not None:
            progress.close()
    # Guard against division by zero on an empty row set.
    total_rows = max(1, len(rows))
    results = {
        "num_examples": len(rows),
        "json_validity_rate": float(json_valid_count) / total_rows,
        "schema_validity_rate": float(schema_valid_count) / total_rows,
        "field_accuracy": {
            "strategy": float(field_totals_before["strategy"]) / total_rows,
            "priority_corridor": float(field_totals_before["priority_corridor"]) / total_rows,
            "phase_bias": float(field_totals_before["phase_bias"]) / total_rows,
            "duration_steps": float(field_totals_before["duration_steps"]) / total_rows,
        },
        "field_accuracy_after_repair": {
            "strategy": float(field_totals_after["strategy"]) / total_rows,
            "priority_corridor": float(field_totals_after["priority_corridor"]) / total_rows,
            "phase_bias": float(field_totals_after["phase_bias"]) / total_rows,
            "duration_steps": float(field_totals_after["duration_steps"]) / total_rows,
        },
        "target_intersections_before_repair": aggregate_target_metrics(target_rows_before),
        "target_intersections_after_repair": aggregate_target_metrics(target_rows_after),
        # Unsuffixed key mirrors the after-repair metrics (legacy consumers).
        "target_intersections": aggregate_target_metrics(target_rows_after),
        "target_intersections_failure_buckets": dict(sorted(failure_buckets.items())),
        "exact_full_object_accuracy": float(full_object_correct_before) / total_rows,
        "exact_full_object_accuracy_after_repair": float(full_object_correct_after) / total_rows,
        "debug_examples": debug_rows,
    }
    if restrict_targets_to_visible_summary:
        results["target_intersections_restricted_to_visible_summary_before_repair"] = aggregate_target_metrics(
            restricted_target_rows_before
        )
        results["target_intersections_restricted_to_visible_summary_after_repair"] = aggregate_target_metrics(
            restricted_target_rows_after
        )
        # Unsuffixed key again mirrors the after-repair aggregation.
        results["target_intersections_restricted_to_visible_summary"] = aggregate_target_metrics(
            restricted_target_rows_after
        )
    if report_before_after_repair:
        results["target_intersections_before_after_repair"] = {
            "invalid_id_rate_before_repair": float(sum(invalid_rates_before) / total_rows),
            "invalid_id_rate_after_repair": float(sum(invalid_rates_after) / total_rows),
            "exact_set_match_before_repair": aggregate_target_metrics(target_rows_before).get("exact_set_match", 0.0),
            "exact_set_match_after_repair": aggregate_target_metrics(target_rows_after).get("exact_set_match", 0.0),
            "jaccard_before_repair": aggregate_target_metrics(target_rows_before).get("jaccard", 0.0),
            "jaccard_after_repair": aggregate_target_metrics(target_rows_after).get("jaccard", 0.0),
            "fallback_used_rate": float(fallback_used_count) / total_rows,
        }
    return results
def print_debug_examples(debug_rows: list[dict[str, Any]]) -> None:
    """Pretty-print the per-example debug records captured during evaluation."""
    # All JSON fields are dumped with sort_keys=True except failure_buckets,
    # which is already a sorted list.
    _SORTED_FIELDS = (
        "predicted_json_parsed_before_repair",
        "predicted_json_parsed_after_repair",
        "ground_truth_json",
        "target_intersections_metrics_before_repair",
        "target_intersections_metrics_after_repair",
        "repair_report",
        "visible_candidate_ids",
    )
    for ordinal, record in enumerate(debug_rows, start=1):
        prefix = f"[debug {ordinal}]"
        print(f"{prefix} district_summary:")
        print(record["district_summary"])
        print(f"{prefix} predicted_json_raw={record['predicted_json_raw']}")
        for label in _SORTED_FIELDS:
            print(f"{prefix} {label}={json.dumps(record[label], sort_keys=True)}")
        print(f"{prefix} failure_buckets={json.dumps(record['failure_buckets'])}")
def main() -> None:
    """CLI entry point: load rows and model, evaluate, print JSON results."""
    args = parse_args()
    eval_rows = load_rows(args.val_jsonl, max_examples=args.max_examples)
    model, tokenizer = load_model_and_tokenizer(args.model_path, device=args.device)
    repair_config = RepairConfig(
        allow_only_visible_candidates=args.allow_only_visible_candidates,
        max_target_intersections=args.max_target_intersections,
        fallback_on_empty_targets=args.fallback_on_empty_targets,
        fallback_mode=args.fallback_mode,
    )
    results = evaluate_rows(
        rows=eval_rows,
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=args.max_new_tokens,
        topology_index=DistrictTopologyIndex(args.generated_root),
        restrict_targets_to_visible_summary=args.restrict_targets_to_visible_summary,
        debug_examples=args.debug_examples,
        repair_config=repair_config,
        report_before_after_repair=args.report_before_after_repair,
    )
    # Keep the bulky debug examples out of the machine-readable summary.
    summary = {key: value for key, value in results.items() if key != "debug_examples"}
    print(json.dumps(summary, indent=2, sort_keys=True))
    print_debug_examples(results["debug_examples"])


if __name__ == "__main__":
    main()