Spaces:
Runtime error
Runtime error
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import statistics | |
| import sys | |
| from collections import Counter, defaultdict | |
| from copy import deepcopy | |
| from datetime import datetime | |
| from difflib import SequenceMatcher | |
| from itertools import combinations | |
| from pathlib import Path | |
| from time import perf_counter | |
| from typing import Any | |
# Resolve the project root (the parent of this script's directory) so the
# top-level project modules imported below can be found when this file is
# run directly as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    # Must happen before the project imports below; inserting at index 0 makes
    # these modules win over any same-named modules later on sys.path.
    sys.path.insert(0, str(PROJECT_ROOT))
from nlu_engine import NLUEngine
from state_manager import GameState
from story_engine import StoryEngine
# Datasets are read from DATASET_DIR as "<name>.json" files; result payloads
# are written under RESULTS_DIR by default.
DATASET_DIR = PROJECT_ROOT / "evaluation" / "datasets"
RESULTS_DIR = PROJECT_ROOT / "evaluation" / "results"
| def _json_safe(value: Any) -> Any: | |
| if value is None or isinstance(value, (str, int, float, bool)): | |
| return value | |
| if isinstance(value, dict): | |
| return {str(key): _json_safe(val) for key, val in value.items()} | |
| if isinstance(value, (list, tuple, set)): | |
| return [_json_safe(item) for item in value] | |
| if hasattr(value, "model_dump"): | |
| return _json_safe(value.model_dump()) | |
| return str(value) | |
| def _normalize_text(value: Any) -> str: | |
| return str(value or "").strip().lower() | |
| def _load_dataset(name: str) -> Any: | |
| with (DATASET_DIR / f"{name}.json").open("r", encoding="utf-8") as fh: | |
| return json.load(fh) | |
| def _apply_setup(game_state: GameState, setup: dict[str, Any] | None) -> GameState: | |
| if not setup: | |
| game_state.player.location = game_state.world.current_scene | |
| return game_state | |
| player_setup = setup.get("player", {}) | |
| world_setup = setup.get("world", {}) | |
| for key, value in player_setup.items(): | |
| if key == "inventory": | |
| game_state.player.inventory = list(value) | |
| elif key == "skills": | |
| game_state.player.skills = list(value) | |
| elif key == "equipment": | |
| updated = dict(game_state.player.equipment) | |
| updated.update(dict(value)) | |
| game_state.player.equipment = updated | |
| else: | |
| setattr(game_state.player, key, deepcopy(value)) | |
| for key, value in world_setup.items(): | |
| if key == "discovered_locations": | |
| game_state.world.discovered_locations = list(value) | |
| elif key == "global_flags": | |
| game_state.world.global_flags.update(dict(value)) | |
| else: | |
| setattr(game_state.world, key, deepcopy(value)) | |
| for npc_name, overrides in setup.get("npc_overrides", {}).items(): | |
| npc = game_state.world.npcs.get(npc_name) | |
| if npc is None: | |
| continue | |
| for key, value in overrides.items(): | |
| setattr(npc, key, deepcopy(value)) | |
| if "turn" in setup: | |
| game_state.turn = int(setup["turn"]) | |
| if "location" not in player_setup and "current_scene" in world_setup: | |
| game_state.player.location = game_state.world.current_scene | |
| elif "location" in player_setup and "current_scene" not in world_setup: | |
| game_state.world.current_scene = game_state.player.location | |
| elif not player_setup and not world_setup: | |
| game_state.player.location = game_state.world.current_scene | |
| return game_state | |
def _build_game_state(setup: dict[str, Any] | None = None) -> GameState:
    """Create a fresh ``GameState`` for player "Evaluator" with *setup* applied."""
    return _apply_setup(GameState(player_name="Evaluator"), setup)
| def _state_snapshot(game_state: GameState) -> dict[str, Any]: | |
| return { | |
| "turn": game_state.turn, | |
| "game_mode": game_state.game_mode, | |
| "location": game_state.player.location, | |
| "scene": game_state.world.current_scene, | |
| "day": game_state.world.day_count, | |
| "time_of_day": game_state.world.time_of_day, | |
| "weather": game_state.world.weather, | |
| "hp": game_state.player.hp, | |
| "mp": game_state.player.mp, | |
| "gold": game_state.player.gold, | |
| "morale": game_state.player.morale, | |
| "sanity": game_state.player.sanity, | |
| "hunger": game_state.player.hunger, | |
| "inventory": list(game_state.player.inventory), | |
| "equipment": dict(game_state.player.equipment), | |
| "skills": list(game_state.player.skills), | |
| "active_quests": { | |
| quest_id: { | |
| "status": quest.status, | |
| "objectives": dict(quest.objectives), | |
| } | |
| for quest_id, quest in game_state.world.quests.items() | |
| if quest.status == "active" | |
| }, | |
| } | |
| def _flatten(value: Any, prefix: str = "") -> set[str]: | |
| flattened: set[str] = set() | |
| if isinstance(value, dict): | |
| for key, child in value.items(): | |
| child_prefix = f"{prefix}.{key}" if prefix else str(key) | |
| flattened.update(_flatten(child, child_prefix)) | |
| elif isinstance(value, list): | |
| list_prefix = prefix or "list" | |
| for index, child in enumerate(value): | |
| flattened.update(_flatten(child, f"{list_prefix}[{index}]")) | |
| if not value: | |
| flattened.add(f"{list_prefix}=[]") | |
| else: | |
| flattened.add(f"{prefix}={value}") | |
| return flattened | |
| def _jaccard_distance(left: set[str], right: set[str]) -> float: | |
| union = left | right | |
| if not union: | |
| return 0.0 | |
| intersection = left & right | |
| return 1.0 - (len(intersection) / len(union)) | |
| def _option_texts(options: list[dict[str, Any]]) -> set[str]: | |
| texts = set() | |
| for option in options or []: | |
| if isinstance(option, dict): | |
| texts.add(str(option.get("text", ""))) | |
| else: | |
| texts.add(str(option)) | |
| return texts | |
| def _consume_story_stream(story_engine: StoryEngine, intent: dict[str, Any]) -> tuple[dict[str, Any], float]: | |
| story_chunks: list[str] = [] | |
| final_result: dict[str, Any] | None = None | |
| started = perf_counter() | |
| for update in story_engine.generate_story_stream(intent): | |
| if update["type"] == "story_chunk": | |
| story_chunks.append(update["text"]) | |
| elif update["type"] == "final": | |
| final_result = update | |
| latency_ms = (perf_counter() - started) * 1000 | |
| if final_result is None: | |
| final_result = { | |
| "story_text": story_chunks[-1] if story_chunks else "", | |
| "options": [], | |
| "state_changes": {}, | |
| "change_log": [], | |
| "consistency_issues": [], | |
| "telemetry": { | |
| "engine_mode": "evaluation_fallback", | |
| "used_fallback": True, | |
| "fallback_reason": "missing_final_event", | |
| }, | |
| } | |
| return final_result, latency_ms | |
def _run_text_turn(user_input: str, setup: dict[str, Any] | None = None) -> dict[str, Any]:
    """Run one complete turn (NLU parse, then streamed story) on a fresh state.

    Both engines are built on the same freshly created game state. The
    ``state_snapshot`` is taken after the story stream has been consumed.
    NLU and story latencies are timed separately (milliseconds) and also
    reported as their sum.
    """
    state = _build_game_state(setup)
    nlu_engine = NLUEngine(state)
    story_engine = StoryEngine(state)

    parse_start = perf_counter()
    intent = nlu_engine.parse_intent(user_input)
    nlu_ms = (perf_counter() - parse_start) * 1000

    final_result, story_ms = _consume_story_stream(story_engine, intent)
    snapshot = _state_snapshot(state)
    return {
        "user_input": user_input,
        "intent": intent,
        "nlu_latency_ms": nlu_ms,
        "story_latency_ms": story_ms,
        "total_latency_ms": nlu_ms + story_ms,
        "final_result": final_result,
        "state_snapshot": snapshot,
    }
| def _percentile(values: list[float], percentile: float) -> float: | |
| if not values: | |
| return 0.0 | |
| ordered = sorted(values) | |
| index = max(0, min(len(ordered) - 1, round((percentile / 100) * (len(ordered) - 1)))) | |
| return ordered[index] | |
| def _summarize_fallback_records(records: list[dict[str, Any]]) -> dict[str, Any]: | |
| fallback_count = 0 | |
| reason_counter = Counter() | |
| engine_counter = Counter() | |
| for record in records: | |
| if record.get("used_fallback"): | |
| fallback_count += 1 | |
| reason_counter[str(record.get("fallback_reason") or "unknown")] += 1 | |
| engine_counter[str(record.get("engine_mode") or "unknown")] += 1 | |
| total = len(records) | |
| return { | |
| "fallback_count": fallback_count, | |
| "fallback_rate": round(fallback_count / total, 4) if total else 0.0, | |
| "fallback_reason_breakdown": dict(reason_counter), | |
| "engine_mode_breakdown": dict(engine_counter), | |
| } | |
| def _limit_cases(cases: list[dict[str, Any]], limit: int = 5) -> list[dict[str, Any]]: | |
| return cases[:limit] | |
def evaluate_intent_accuracy() -> dict[str, Any]:
    """Score ``NLUEngine.parse_intent`` against the ``intent_accuracy`` dataset.

    Each example provides ``id``, ``input`` and the expected ``intent``, plus
    an optional ``target`` and an optional ``setup`` (applied to a fresh game
    state before parsing). Intents are compared exactly. Targets are compared
    after ``_normalize_text``, and only for examples that declare one.

    The returned report contains:

    * ``intent_accuracy``;
    * ``target_accuracy``, which is ``None`` when no example declares a
      target;
    * the mean parse latency;
    * a ``parser_source`` breakdown;
    * an expected -> predicted ``confusion`` table;
    * per-example ``details``.
    """
    examples = _load_dataset("intent_accuracy")
    rows = []
    source_counts = Counter()
    confusion = defaultdict(Counter)
    timings = []
    intent_hits = 0
    target_hits = 0
    targets_checked = 0

    for example in examples:
        engine = NLUEngine(_build_game_state(example.get("setup")))
        t_start = perf_counter()
        parsed = engine.parse_intent(example["input"])
        elapsed_ms = (perf_counter() - t_start) * 1000

        wanted_intent = example["intent"]
        got_intent = parsed.get("intent")
        intent_ok = got_intent == wanted_intent
        intent_hits += int(intent_ok)
        timings.append(elapsed_ms)
        source_counts[parsed.get("parser_source", "unknown")] += 1
        confusion[wanted_intent][str(got_intent)] += 1

        wanted_target = example.get("target")
        got_target = parsed.get("target")
        target_ok = None
        if wanted_target is not None:
            targets_checked += 1
            target_ok = _normalize_text(got_target) == _normalize_text(wanted_target)
            target_hits += 1 if target_ok else 0

        rows.append(
            {
                "id": example["id"],
                "input": example["input"],
                "expected_intent": wanted_intent,
                "predicted_intent": got_intent,
                "intent_correct": intent_ok,
                "expected_target": wanted_target,
                "predicted_target": got_target,
                "target_correct": target_ok,
                "parser_source": parsed.get("parser_source"),
                "latency_ms": round(elapsed_ms, 2),
            }
        )

    example_count = len(examples)
    return {
        "task": "intent_accuracy",
        "dataset_size": example_count,
        "intent_accuracy": round(intent_hits / example_count, 4) if examples else 0.0,
        "target_accuracy": round(target_hits / targets_checked, 4) if targets_checked else None,
        "avg_latency_ms": round(statistics.mean(timings), 2) if timings else 0.0,
        "parser_source_breakdown": dict(source_counts),
        "confusion": {expected: dict(counts) for expected, counts in confusion.items()},
        "details": rows,
    }
def evaluate_consistency() -> dict[str, Any]:
    """Score the game-state guard rails against the ``consistency`` dataset.

    Two case families are evaluated, each on a fresh game state:

    * ``action_guard_cases``: ``pre_validate_action(intent)`` must return the
      expected validity flag.
    * ``state_check_cases``: ``check_consistency(proposed_changes)`` must
      report issues exactly when ``expected_contradiction`` is true. When
      ``expected_contains`` is given, every listed fragment must also appear
      as a substring of at least one reported issue.

    Returns per-family accuracy, the combined accuracy and per-case details.
    """
    dataset = _load_dataset("consistency")
    guard_cases = dataset["action_guard_cases"]
    state_cases = dataset["state_check_cases"]

    guard_details = []
    for case in guard_cases:
        state = _build_game_state(case.get("setup"))
        predicted_valid, reason = state.pre_validate_action(case["intent"])
        guard_details.append(
            {
                "id": case["id"],
                "expected_valid": case["expected_valid"],
                "predicted_valid": predicted_valid,
                "correct": predicted_valid == case["expected_valid"],
                "rejection_reason": reason,
                "intent": case["intent"],
            }
        )
    guard_correct = sum(int(detail["correct"]) for detail in guard_details)

    state_details = []
    for case in state_cases:
        state = _build_game_state(case.get("setup"))
        issues = state.check_consistency(case["proposed_changes"])
        flagged = bool(issues)
        correct = flagged == case["expected_contradiction"]
        fragments = case.get("expected_contains", [])
        if fragments and correct:
            # Every expected fragment must be found inside some reported issue.
            correct = all(any(fragment in issue for issue in issues) for fragment in fragments)
        state_details.append(
            {
                "id": case["id"],
                "expected_contradiction": case["expected_contradiction"],
                "predicted_contradiction": flagged,
                "correct": correct,
                "contradictions": issues,
                "proposed_changes": case["proposed_changes"],
            }
        )
    state_correct = sum(int(detail["correct"]) for detail in state_details)

    case_total = len(guard_cases) + len(state_cases)
    correct_total = guard_correct + state_correct
    return {
        "task": "consistency",
        "guard_accuracy": round(guard_correct / len(guard_cases), 4) if guard_cases else 0.0,
        "state_check_accuracy": round(state_correct / len(state_cases), 4) if state_cases else 0.0,
        "overall_accuracy": round(correct_total / case_total, 4) if case_total else 0.0,
        "action_guard_details": guard_details,
        "state_check_details": state_details,
    }
| def evaluate_latency(repeats: int) -> dict[str, Any]: | |
| dataset = _load_dataset("latency") | |
| scenario_summaries = [] | |
| all_nlu = [] | |
| all_story = [] | |
| all_total = [] | |
| fallback_total = 0 | |
| total_runs = 0 | |
| fallback_records = [] | |
| failure_cases = [] | |
| for scenario in dataset: | |
| runs = [] | |
| for _ in range(repeats): | |
| run_result = _run_text_turn(scenario["input"], scenario.get("setup")) | |
| final_result = run_result["final_result"] | |
| telemetry = final_result.get("telemetry", {}) | |
| used_fallback = bool(telemetry.get("used_fallback", False)) | |
| total_runs += 1 | |
| fallback_total += int(used_fallback) | |
| all_nlu.append(run_result["nlu_latency_ms"]) | |
| all_story.append(run_result["story_latency_ms"]) | |
| all_total.append(run_result["total_latency_ms"]) | |
| runs.append( | |
| { | |
| "nlu_latency_ms": round(run_result["nlu_latency_ms"], 2), | |
| "story_latency_ms": round(run_result["story_latency_ms"], 2), | |
| "total_latency_ms": round(run_result["total_latency_ms"], 2), | |
| "used_fallback": used_fallback, | |
| "fallback_reason": telemetry.get("fallback_reason"), | |
| "engine_mode": telemetry.get("engine_mode"), | |
| } | |
| ) | |
| fallback_records.append(runs[-1]) | |
| total_values = [item["total_latency_ms"] for item in runs] | |
| scenario_fallback_rate = sum(1 for item in runs if item["used_fallback"]) / len(runs) | |
| if scenario_fallback_rate > 0: | |
| failure_cases.append( | |
| { | |
| "scenario_id": scenario["id"], | |
| "input": scenario["input"], | |
| "fallback_rate": round(scenario_fallback_rate, 4), | |
| "fallback_reasons": dict( | |
| Counter( | |
| str(item.get("fallback_reason") or "unknown") | |
| for item in runs | |
| if item["used_fallback"] | |
| ) | |
| ), | |
| } | |
| ) | |
| scenario_summaries.append( | |
| { | |
| "id": scenario["id"], | |
| "input": scenario["input"], | |
| "repeats": repeats, | |
| "avg_total_latency_ms": round(statistics.mean(total_values), 2), | |
| "p95_total_latency_ms": round(_percentile(total_values, 95), 2), | |
| "fallback_rate": round(scenario_fallback_rate, 4), | |
| "fallback_reason_breakdown": dict( | |
| Counter( | |
| str(item.get("fallback_reason") or "unknown") | |
| for item in runs | |
| if item["used_fallback"] | |
| ) | |
| ), | |
| "runs": runs, | |
| } | |
| ) | |
| fallback_summary = _summarize_fallback_records(fallback_records) | |
| return { | |
| "task": "latency", | |
| "scenario_count": len(dataset), | |
| "repeats": repeats, | |
| "avg_nlu_latency_ms": round(statistics.mean(all_nlu), 2) if all_nlu else 0.0, | |
| "avg_story_latency_ms": round(statistics.mean(all_story), 2) if all_story else 0.0, | |
| "avg_total_latency_ms": round(statistics.mean(all_total), 2) if all_total else 0.0, | |
| "p95_total_latency_ms": round(_percentile(all_total, 95), 2) if all_total else 0.0, | |
| "fallback_rate": round(fallback_total / total_runs, 4) if total_runs else 0.0, | |
| "fallback_count": fallback_summary["fallback_count"], | |
| "fallback_reason_breakdown": fallback_summary["fallback_reason_breakdown"], | |
| "engine_mode_breakdown": fallback_summary["engine_mode_breakdown"], | |
| "failure_cases": _limit_cases(failure_cases), | |
| "scenarios": scenario_summaries, | |
| } | |
def _score_branch_pair(
    left: dict[str, Any],
    right: dict[str, Any],
    threshold: float,
) -> dict[str, Any]:
    """Score how differently two branch results played out (0 = identical).

    The pair score is the unweighted mean of three divergences:

    * text: ``1 - SequenceMatcher.ratio()`` over the story texts;
    * state: the Jaccard distance between the flattened state snapshots;
    * option: the Jaccard distance between the sets of option texts.

    The score is computed from the unrounded components and then rounded to
    4 decimals. The pair is flagged as meaningfully divergent when the score
    is at least *threshold*.
    """
    text_divergence = 1.0 - SequenceMatcher(None, left["story_text"], right["story_text"]).ratio()
    state_divergence = _jaccard_distance(
        _flatten(left["state_snapshot"]),
        _flatten(right["state_snapshot"]),
    )
    option_divergence = _jaccard_distance(
        _option_texts(left["options"]),
        _option_texts(right["options"]),
    )
    pair_score = round((text_divergence + state_divergence + option_divergence) / 3, 4)
    return {
        "left": left["label"],
        "right": right["label"],
        "text_divergence": round(text_divergence, 4),
        "state_divergence": round(state_divergence, 4),
        "option_divergence": round(option_divergence, 4),
        "pair_divergence_score": pair_score,
        "meaningfully_divergent": pair_score >= threshold,
    }


def evaluate_branch_divergence(divergence_threshold: float = 0.2) -> dict[str, Any]:
    """Measure how much alternative player inputs change the story outcome.

    Each dataset group has an ``id``, an optional shared ``setup`` and a list
    of ``branches`` (``label``, ``input``). Every branch runs one turn on a
    fresh state built from the group's setup, and every pair of branches
    within a group is scored with ``_score_branch_pair``.

    A pair is "meaningfully divergent" when its score is at least
    *divergence_threshold*. A group is reported in ``failure_cases`` (at most
    five; ``failure_case_count`` has the full number) when its average pair
    score falls below that threshold. Fallback statistics come from each
    branch's telemetry.
    """
    dataset = _load_dataset("branch_divergence")
    group_summaries: list[dict[str, Any]] = []
    pair_scores: list[float] = []
    fallback_records: list[dict[str, Any]] = []
    low_divergence_groups: list[dict[str, Any]] = []

    for group in dataset:
        branch_results: list[dict[str, Any]] = []
        for branch in group["branches"]:
            run_result = _run_text_turn(branch["input"], group.get("setup"))
            final_result = run_result["final_result"]
            # ``or {}`` also covers an explicit ``"telemetry": None``.
            telemetry = final_result.get("telemetry") or {}
            branch_results.append(
                {
                    "label": branch["label"],
                    "input": branch["input"],
                    # ``or ""`` keeps SequenceMatcher working if story_text is None.
                    "story_text": final_result.get("story_text") or "",
                    "options": final_result.get("options", []),
                    "state_snapshot": run_result["state_snapshot"],
                    "state_changes": final_result.get("state_changes", {}),
                    "telemetry": telemetry,
                }
            )
            fallback_records.append(
                {
                    "used_fallback": bool(telemetry.get("used_fallback", False)),
                    "fallback_reason": telemetry.get("fallback_reason"),
                    "engine_mode": telemetry.get("engine_mode"),
                }
            )

        group_pairs = [
            _score_branch_pair(left, right, divergence_threshold)
            for left, right in combinations(branch_results, 2)
        ]
        group_scores = [pair["pair_divergence_score"] for pair in group_pairs]
        pair_scores.extend(group_scores)
        avg_pair_divergence = round(statistics.mean(group_scores), 4) if group_scores else 0.0

        if avg_pair_divergence < divergence_threshold:
            low_divergence_groups.append(
                {
                    "group_id": group["id"],
                    "avg_pair_divergence": avg_pair_divergence,
                    "branch_labels": [branch["label"] for branch in branch_results],
                }
            )
        group_summaries.append(
            {
                "id": group["id"],
                "avg_pair_divergence": avg_pair_divergence,
                "branches": [
                    {
                        "label": branch["label"],
                        "input": branch["input"],
                        "telemetry": _json_safe(branch["telemetry"]),
                        "state_changes": _json_safe(branch["state_changes"]),
                    }
                    for branch in branch_results
                ],
                "pair_details": group_pairs,
            }
        )

    meaningful_pairs = sum(1 for score in pair_scores if score >= divergence_threshold)
    fallback_summary = _summarize_fallback_records(fallback_records)
    return {
        "task": "branch_divergence",
        "group_count": len(dataset),
        "avg_pair_divergence": round(statistics.mean(pair_scores), 4) if pair_scores else 0.0,
        "meaningfully_divergent_pair_rate": round(
            meaningful_pairs / len(pair_scores),
            4,
        ) if pair_scores else 0.0,
        "fallback_count": fallback_summary["fallback_count"],
        "fallback_rate": fallback_summary["fallback_rate"],
        "fallback_reason_breakdown": fallback_summary["fallback_reason_breakdown"],
        "engine_mode_breakdown": fallback_summary["engine_mode_breakdown"],
        "failure_case_count": len(low_divergence_groups),
        "failure_cases": _limit_cases(low_divergence_groups),
        "groups": group_summaries,
    }
# Task name -> runner. Every runner accepts the ``--repeats`` value so that
# ``main`` can invoke them uniformly; only the latency task uses it. The
# insertion order defines the order in which ``--task all`` runs the tasks.
TASK_RUNNERS = dict(
    intent=lambda repeats: evaluate_intent_accuracy(),
    consistency=lambda repeats: evaluate_consistency(),
    latency=lambda repeats: evaluate_latency(repeats),
    branch=lambda repeats: evaluate_branch_divergence(),
)
def _build_failure_summary(results: dict[str, Any]) -> dict[str, Any]:
    """Collect failing cases from each task's result into a compact summary.

    For every task present in *results*, the summary holds ``{"count": N,
    "cases": [...]}`` with ``cases`` trimmed by ``_limit_cases``:

    * ``intent_failures``: examples whose predicted intent was wrong.
    * ``consistency_failures``: incorrect action-guard cases, followed by
      incorrect state-check cases.
    * ``latency_failures`` / ``branch_failures``: the ``failure_cases``
      recorded by those tasks.

    The latency and branch tasks may already have truncated their
    ``failure_cases`` list. For them, ``count`` prefers the task's
    untruncated ``failure_case_count`` when it is reported, and otherwise
    falls back to the length of the list.
    """

    def _bucket(cases: list[dict[str, Any]], total: int | None = None) -> dict[str, Any]:
        # Pair the (possibly larger) total with a truncated sample of cases.
        return {
            "count": len(cases) if total is None else total,
            "cases": _limit_cases(cases),
        }

    def _from_task(task_result: dict[str, Any]) -> dict[str, Any]:
        # Use the untruncated count when the task reports one.
        cases = task_result.get("failure_cases", [])
        return _bucket(cases, task_result.get("failure_case_count", len(cases)))

    summary: dict[str, Any] = {}

    if "intent" in results:
        intent_misses = [
            {
                "id": detail["id"],
                "input": detail["input"],
                "expected_intent": detail["expected_intent"],
                "predicted_intent": detail["predicted_intent"],
                "parser_source": detail["parser_source"],
            }
            for detail in results["intent"]["details"]
            if not detail["intent_correct"]
        ]
        summary["intent_failures"] = _bucket(intent_misses)

    if "consistency" in results:
        consistency = results["consistency"]
        consistency_misses = [
            {
                "id": detail["id"],
                "type": "action_guard",
                "expected_valid": detail["expected_valid"],
                "predicted_valid": detail["predicted_valid"],
                "rejection_reason": detail["rejection_reason"],
            }
            for detail in consistency["action_guard_details"]
            if not detail["correct"]
        ]
        consistency_misses.extend(
            {
                "id": detail["id"],
                "type": "state_check",
                "expected_contradiction": detail["expected_contradiction"],
                "predicted_contradiction": detail["predicted_contradiction"],
                "contradictions": detail["contradictions"],
            }
            for detail in consistency["state_check_details"]
            if not detail["correct"]
        )
        summary["consistency_failures"] = _bucket(consistency_misses)

    if "latency" in results:
        summary["latency_failures"] = _from_task(results["latency"])

    if "branch" in results:
        summary["branch_failures"] = _from_task(results["branch"])

    return summary
| def _build_summary(results: dict[str, Any]) -> dict[str, Any]: | |
| summary = {} | |
| if "intent" in results: | |
| summary["intent_accuracy"] = results["intent"]["intent_accuracy"] | |
| if "consistency" in results: | |
| summary["consistency_overall_accuracy"] = results["consistency"]["overall_accuracy"] | |
| if "latency" in results: | |
| summary["avg_total_latency_ms"] = results["latency"]["avg_total_latency_ms"] | |
| summary["latency_fallback_rate"] = results["latency"]["fallback_rate"] | |
| summary["latency_fallback_count"] = results["latency"]["fallback_count"] | |
| if "branch" in results: | |
| summary["avg_pair_divergence"] = results["branch"]["avg_pair_divergence"] | |
| summary["branch_fallback_rate"] = results["branch"]["fallback_rate"] | |
| return summary | |
def main() -> int:
    """Command-line entry point: run the selected evaluation task(s) and save a JSON report.

    Options:
        --task     ``"all"`` or one ``TASK_RUNNERS`` key (default ``"all"``).
        --repeats  repeat count for the latency task (default 3). It must be
                   >= 1 when the latency task is selected.
        --output   output JSON path. Relative paths are resolved against
                   ``PROJECT_ROOT``. Defaults to
                   ``RESULTS_DIR/<timestamp>-<task>.json``.

    The whole payload is passed through ``_json_safe`` before writing, so
    values ``json`` cannot encode are converted instead of raising
    ``TypeError`` after all evaluation work is done. Examples are sets or
    objects with ``model_dump`` that may appear in engine outputs.

    Returns:
        0 on success. Invalid arguments make argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(description="Run reproducible StoryWeaver evaluation tasks.")
    parser.add_argument(
        "--task",
        choices=["all", *TASK_RUNNERS.keys()],
        default="all",
        help="Evaluation task to run.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=3,
        help="Repeat count for latency measurements.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
        help="Optional path for the output JSON file.",
    )
    args = parser.parse_args()

    selected_tasks = list(TASK_RUNNERS.keys()) if args.task == "all" else [args.task]
    if "latency" in selected_tasks and args.repeats < 1:
        # Fail fast with a usage error rather than crashing after earlier tasks ran.
        parser.error("--repeats must be a positive integer")
    task_results = {task: TASK_RUNNERS[task](args.repeats) for task in selected_tasks}

    payload = _json_safe(
        {
            "generated_at": datetime.now().isoformat(timespec="seconds"),
            "task": args.task,
            "summary": _build_summary(task_results),
            "failure_summary": _build_failure_summary(task_results),
            "results": task_results,
        }
    )

    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    if args.output:
        output_path = Path(args.output)
        if not output_path.is_absolute():
            output_path = PROJECT_ROOT / output_path
    else:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_path = RESULTS_DIR / f"{timestamp}-{args.task}.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with output_path.open("w", encoding="utf-8") as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=2)
    print(json.dumps(payload["summary"], ensure_ascii=False, indent=2))
    print(f"Saved evaluation results to: {output_path}")
    return 0
# Script entry point: exit with main()'s return code.
if __name__ == "__main__":
    sys.exit(main())