meet4150's picture
download
raw
3.01 kB
import json
import re
from pathlib import Path
from typing import Any
from rag.pipeline import query_rag
_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")
def _resolve_dataset_path(dataset_path: str, base_dir: Path) -> Path:
    """Locate the dataset file named by ``dataset_path``.

    An absolute path that points at an existing file is returned unchanged.
    Otherwise the path is tried, in order, relative to the current working
    directory, ``base_dir`` and ``base_dir.parent``; the first match is
    returned fully resolved.

    Only regular files count as a match. A directory that happens to share
    the dataset's name is skipped, so a later candidate can still be found
    and the caller never receives a path it cannot read as text.

    Args:
        dataset_path: Absolute or relative location of the dataset file.
        base_dir: Directory used as an anchor for relative lookups.

    Returns:
        Path to the existing dataset file.

    Raises:
        FileNotFoundError: If none of the candidate locations holds a file.
    """
    candidate = Path(dataset_path)
    if candidate.is_absolute() and candidate.is_file():
        return candidate

    # Search order matters: the working directory wins over base_dir, which
    # in turn wins over its parent.
    options = [
        Path.cwd() / dataset_path,
        base_dir / dataset_path,
        base_dir.parent / dataset_path,
    ]
    for option in options:
        # is_file() rather than exists(): a same-named directory is not a
        # usable dataset and must not shadow a real file further down.
        if option.is_file():
            return option.resolve()

    raise FileNotFoundError(
        f"Dataset file not found: {dataset_path}. "
        f"Tried cwd/base_dir/base_dir.parent resolution."
    )
def _token_set(text: str) -> set[str]:
    """Collect the distinct lowercase tokens of ``text``.

    Tokens are the runs matched by ``_TOKEN_RE``; runs of one or two
    characters are ignored.
    """
    tokens: set[str] = set()
    for match in _TOKEN_RE.finditer(text):
        word = match.group(0)
        if len(word) <= 2:
            continue
        tokens.add(word.lower())
    return tokens
def _overlap(expected: str, predicted: str) -> float:
    """Return the share of expected tokens that also occur in ``predicted``.

    The result lies in ``[0.0, 1.0]``. It is ``0.0`` when ``expected``
    yields no tokens at all, which also avoids dividing by zero.
    """
    reference = _token_set(expected)
    if len(reference) == 0:
        return 0.0

    shared = reference.intersection(_token_set(predicted))
    return len(shared) / len(reference)
def _as_clean_text(value: Any) -> str:
    """Return ``value`` as stripped text, mapping ``None`` to ``""``."""
    # str(None) would yield the literal text "None", which is never what a
    # JSON ``null`` field is meant to contribute.
    return "" if value is None else str(value).strip()


def _expected_answer(item: dict[str, Any]) -> str:
    """Extract the reference answer from a dataset row, or ``""`` if absent."""
    answer_obj = item.get("answer")
    if isinstance(answer_obj, dict):
        # Nested form: {"answer": {"answer": "..."}}
        return _as_clean_text(answer_obj.get("answer"))
    if isinstance(answer_obj, str):
        return answer_obj.strip()
    return ""


def evaluate_rag_json(
    dataset_path: str,
    base_dir: Path,
    max_questions: int = 10,
    use_reranker: bool = True,
) -> dict[str, Any]:
    """Run the RAG pipeline over a JSON question set and score the answers.

    The dataset must be a JSON list. Only the first ``max_questions`` entries
    are considered; among those, entries that are not objects or that have no
    non-empty ``"question"`` are skipped. The reference answer is read from
    ``"answer"``, either as a plain string or as an object with its own
    ``"answer"`` field. Each remaining question is passed to ``query_rag``
    and the returned answer is compared with the reference by lexical token
    overlap.

    Args:
        dataset_path: Location of the dataset, resolved through
            ``_resolve_dataset_path``.
        base_dir: Anchor directory for resolving a relative ``dataset_path``.
        max_questions: Upper bound on the number of leading dataset entries
            to look at. Values below zero are treated as zero.
        use_reranker: Forwarded unchanged to ``query_rag``.

    Returns:
        A dict with the resolved ``dataset_path``, ``questions_evaluated``,
        ``likely_relevant_count``, ``likely_relevant_rate``,
        ``avg_lexical_overlap`` and the per-question ``results`` rows.

    Raises:
        FileNotFoundError: If the dataset file cannot be located.
        ValueError: If the top-level JSON value is not a list.
    """
    path = _resolve_dataset_path(dataset_path, base_dir=base_dir)
    raw = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError("Dataset JSON must be a list of question objects.")

    # A negative limit would otherwise slice from the end of the list
    # (``raw[:-1]`` drops only the last entry); treat it as "no questions".
    # ``None`` is outside the annotated contract but slicing always accepted
    # it as "no limit", so that tolerance is kept.
    limit = max_questions if max_questions is None else max(max_questions, 0)

    rows: list[dict[str, Any]] = []
    for item in raw[:limit]:
        if not isinstance(item, dict):
            continue
        question = _as_clean_text(item.get("question"))
        if not question:
            continue
        expected = _expected_answer(item)

        # NOTE(review): query_rag is defined outside this file; it is used
        # here as returning a mapping with "answer" and "sources" - confirm.
        result = query_rag(question, use_reranker=use_reranker)
        predicted = _as_clean_text(result.get("answer"))
        overlap = _overlap(expected, predicted) if expected else 0.0
        rows.append(
            {
                "id": item.get("id"),
                "question": question,
                "expected_answer": expected,
                "predicted_answer": predicted,
                "lexical_overlap": round(overlap, 4),
                # Threshold is applied to the unrounded score.
                "likely_relevant": overlap >= 0.25,
                "sources": result.get("sources", []),
            }
        )

    total = len(rows)
    relevant = sum(1 for row in rows if row["likely_relevant"])
    avg_overlap = (
        round(sum(row["lexical_overlap"] for row in rows) / total, 4)
        if total
        else 0.0
    )
    return {
        "dataset_path": str(path),
        "questions_evaluated": total,
        "likely_relevant_count": relevant,
        "likely_relevant_rate": round(relevant / total, 4) if total else 0.0,
        "avg_lexical_overlap": avg_overlap,
        "results": rows,
    }

Xet Storage Details

Size:
3.01 kB
·
Xet hash:
c1395c34216180cf739e5d820b933f2516fb1e0e8512d5830a79a3d860d0891c

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.