File size: 4,694 Bytes
28f1212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
# [Shared: Track Utilities]
"""
Cross-Track Comparison β loads results from all tracks and produces
comparative tables and charts.
Usage:
python -m tracks.shared.compare --tracks A B C D --dataset medqa
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
from typing import Dict, List, Optional
# Put the backend root (three directory levels above this file, i.e. the
# directory that contains the `tracks` package) on sys.path so that the
# project's absolute imports resolve when the module runs as a script.
BACKEND_DIR = Path(__file__).resolve().parent.parent.parent
if str(BACKEND_DIR) not in sys.path:
    sys.path.insert(0, str(BACKEND_DIR))

# ──────────────────────────────────────────────
# Result loading
# ──────────────────────────────────────────────

# Track ID -> directory that holds that track's result JSON files.
TRACK_DIRS = dict(
    A=BACKEND_DIR.joinpath("validation", "results"),
    B=BACKEND_DIR.joinpath("tracks", "rag_variants", "results"),
    C=BACKEND_DIR.joinpath("tracks", "iterative", "results"),
    D=BACKEND_DIR.joinpath("tracks", "arbitrated", "results"),
)
def load_latest_result(track_id: str, dataset: str = "medqa") -> Optional[dict]:
    """Return the parsed contents of the newest result file for a track/dataset.

    File selection:
      * track ``"A"`` matches ``*<dataset>*.json``;
      * every other track matches ``track<ID>_<dataset>*.json``.
    Among the matches, the path that sorts last is taken as the newest.
    NOTE(review): this relies on file names ending in a sortable timestamp
    suffix, as the original comment states; confirm the naming scheme.

    Returns:
        The decoded JSON object, or ``None`` when the track ID is unknown,
        its result directory does not exist, or no file matches.

    Raises:
        OSError / json.JSONDecodeError if the chosen file cannot be read
        or parsed.
    """
    directory = TRACK_DIRS.get(track_id)
    if not (directory and directory.exists()):
        return None

    pattern = (
        f"*{dataset}*.json"
        if track_id == "A"
        else f"track{track_id}_{dataset}*.json"
    )
    candidates = list(directory.glob(pattern))
    if not candidates:
        return None

    # Largest path in sort order == first element of a descending sort.
    newest = max(candidates)
    return json.loads(newest.read_text(encoding="utf-8"))
def load_all_results(dataset: str = "medqa") -> Dict[str, Optional[dict]]:
    """Return the newest result for every configured track, keyed by track ID.

    Keys follow the order of ``TRACK_DIRS``. A track with no result
    directory or no matching file maps to ``None``.
    """
    latest: Dict[str, Optional[dict]] = {}
    for track in TRACK_DIRS:
        latest[track] = load_latest_result(track, dataset)
    return latest
# ──────────────────────────────────────────────
# Comparison table
# ──────────────────────────────────────────────
def compare_tracks(dataset: str = "medqa") -> str:
"""
Generate a comparison table across all tracks.
Returns a formatted text table suitable for console or markdown.
"""
results = load_all_results(dataset)
header = f"{'Track':<22} {'Top-1':>7} {'Top-3':>7} {'Mentioned':>10} {'Pipeline':>9} {'Cost':>10}"
sep = "-" * len(header)
lines = [f"\nCross-Track Comparison: {dataset.upper()}", sep, header, sep]
for tid, data in results.items():
name = {
"A": "A: Baseline",
"B": "B: RAG Variants",
"C": "C: Iterative",
"D": "D: Arbitrated",
}.get(tid, tid)
if data is None:
lines.append(f"{name:<22} {'--':>7} {'--':>7} {'--':>10} {'--':>9} {'--':>10}")
continue
metrics = data.get("metrics", data.get("summary", {}).get("metrics", {}))
top1 = metrics.get("top1_accuracy", -1)
top3 = metrics.get("top3_accuracy", -1)
mentioned = metrics.get("mentioned_accuracy", -1)
pipeline = metrics.get("parse_success", metrics.get("pipeline_success", -1))
cost = data.get("total_cost_usd", data.get("cost", {}).get("total_cost_usd", -1))
def fmt(v: float) -> str:
return f"{v:.1%}" if v >= 0 else "--"
cost_str = f"${cost:.4f}" if cost >= 0 else "--"
lines.append(
f"{name:<22} {fmt(top1):>7} {fmt(top3):>7} {fmt(mentioned):>10} "
f"{fmt(pipeline):>9} {cost_str:>10}"
)
lines.append(sep)
return "\n".join(lines)
# ──────────────────────────────────────────────
# CLI
# ──────────────────────────────────────────────
def main():
    """Command-line entry point.

    Prints the cross-track comparison table for ``--dataset``, or, with
    ``--json``, the raw loaded results for the tracks that have any.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Compare results across experimental tracks")
    cli.add_argument("--dataset", default="medqa", help="Dataset to compare (default: medqa)")
    cli.add_argument("--json", action="store_true", help="Output as JSON instead of table")
    opts = cli.parse_args()

    if not opts.json:
        print(compare_tracks(opts.dataset))
        return

    # Drop tracks with no results so the JSON holds only real data.
    loaded = load_all_results(opts.dataset)
    present = {tid: payload for tid, payload in loaded.items() if payload is not None}
    print(json.dumps(present, indent=2))


if __name__ == "__main__":
    main()
|