Financial_bot / src /evaluator.py
Pushkya's picture
Upload 30 files
8299003 verified
Raw
History Blame Contribute Delete
33.9 kB
"""
evaluator.py
============
Phase 8 – RAG Pipeline Evaluation
Measures pipeline quality with four RAGAS-inspired metrics, all computed
locally without any external API or LLM judge.
Metrics
-------
Faithfulness Are answer sentences entailed by the retrieved context?
(NLI via AnswerVerifier — reuses Phase 7)
Answer Relevancy Is the answer on-topic for the question?
(cosine sim between question embedding and answer embedding)
Context Precision What fraction of retrieved chunks are actually useful?
(NLI: does the chunk entail any reference-answer sentence?)
Context Recall What fraction of reference-answer claims are covered by
the retrieved context?
(NLI: is each reference sentence supported by any chunk?)
Reference-based metrics (require ground-truth answers):
Answer F1 Token-overlap F1 between generated and reference answer
(SQuAD-style, normalised)
Exact Match 1 if normalised strings match, else 0
Retrieval Hit Rate Fraction of known-relevant chunk IDs actually retrieved
Grade thresholds
----------------
PASS : metric ≥ upper threshold
REVIEW : metric ≥ lower threshold
FAIL : metric < lower threshold
Usage
-----
from src.evaluator import RAGEvaluator, AppleFinanceTestSet
from src.retriever import FinancialRetriever
from src.rag_chain import build_rag_chain, get_llm
retriever = FinancialRetriever(vectorstore_dir=..., rerank=True)
chain = build_rag_chain(llm=get_llm(), rerank=True)
evaluator = RAGEvaluator(retriever=retriever, chain=chain)
samples = AppleFinanceTestSet.get_samples()
results = evaluator.evaluate_dataset(samples)
print(evaluator.report(results))
"""
import re
import time
import logging
from dataclasses import dataclass, field
from typing import Optional
from src.retriever import _table_to_labelled_text
log = logging.getLogger(__name__)
# ── Grade thresholds ─────────────────────────────────────────────────────────
# Each metric maps to an (upper, lower) pair:
#   score >= upper  → PASS
#   score >= lower  → REVIEW
#   otherwise       → FAIL
GRADE_THRESHOLDS: dict[str, tuple[float, float]] = {
    "faithfulness": (0.80, 0.60),
    "answer_relevancy": (0.75, 0.50),
    "context_precision": (0.70, 0.50),
    "context_recall": (0.70, 0.50),
    "answer_f1": (0.50, 0.30),
    "retrieval_hit_rate": (0.80, 0.50),
}

# Minimum NLI entailment probability for a chunk/sentence to count as
# "supported"; kept slightly lower than the runtime gate.
_PRECISION_RECALL_ENTAIL_THRESHOLD = 0.40
def _chunk_text_for_nli(chunk: dict) -> str:
    """
    Prepare a chunk's text for NLI premise/recall scoring.

    Non-table chunks are returned unchanged. Table chunks get two
    transformations:

    1. Table markdown → labelled text (e.g. "Total net sales: 2024=$391,035")
       so the cross-encoder can match numbers without being confused by pipe
       characters (same logic used in retriever.build_context).
    2. Company name prefix (e.g. "Apple Inc. — ").
       SEC HTML tables have no company name in their cell text; the NLI model
       requires "Apple" in the premise to entail hypotheses about "Apple's net
       sales". Without this prefix, entailment scores are ~0 even when the
       table contains the exact figure.

    Parameters
    ----------
    chunk : dict with a "text" entry and an optional "metadata" dict that may
            carry "chunk_type" and "company".

    Returns
    -------
    The text to use as the NLI premise for this chunk.
    """
    # ``metadata`` may be missing or explicitly None; treat both as empty.
    meta = chunk.get("metadata") or {}
    if meta.get("chunk_type", "text") != "table":
        return chunk["text"]

    labelled = _table_to_labelled_text(chunk["text"])
    company = meta.get("company", "")
    return f"{company} — {labelled}" if company else labelled
# ══════════════════════════════════════════════════════════════════════════════
# DATA CLASSES
# ══════════════════════════════════════════════════════════════════════════════
@dataclass
class EvalSample:
    """A single evaluation case: a question, its reference answer, and optional metadata."""

    # Natural-language question posed to the pipeline.
    question: str
    # Ground-truth answer used by the reference-based metrics.
    reference_answer: str
    # Retrieval filters forwarded with the query; empty means "no filter".
    filters: dict = field(default_factory=dict)
    # IDs of chunks known to be relevant (used for retrieval hit rate).
    relevant_chunk_ids: list[str] = field(default_factory=list)
    # Grouping label, e.g. revenue, risk, segment.
    category: str = "general"
@dataclass
class EvalMetrics:
    """Container for every metric computed on one evaluation sample."""

    # NLI grounding of the generated answer against the context.
    faithfulness: float = 0.0
    # Cosine similarity between question and answer embeddings.
    answer_relevancy: float = 0.0
    # Fraction of retrieved chunks that are useful for answering.
    context_precision: float = 0.0
    # Fraction of reference claims found in the context.
    context_recall: float = 0.0
    # Token F1 against the reference answer.
    answer_f1: float = 0.0
    # 1.0 when the normalised strings match.
    exact_match: float = 0.0
    # Fraction of known-relevant chunk IDs that were retrieved.
    retrieval_hit_rate: float = 0.0

    def aggregate_score(self) -> float:
        """
        Return the weighted average of the core metrics.

        Weights: faithfulness 30 %, answer_relevancy 25 %,
        context_precision 20 %, context_recall 15 %, answer_f1 10 %.
        exact_match and retrieval_hit_rate do not contribute.
        """
        weighted_terms = (
            (0.30, self.faithfulness),
            (0.25, self.answer_relevancy),
            (0.20, self.context_precision),
            (0.15, self.context_recall),
            (0.10, self.answer_f1),
        )
        # Accumulate left to right so the float result is order-stable.
        total = 0.0
        for weight, value in weighted_terms:
            total += weight * value
        return total
@dataclass
class EvalResult:
    """Outcome of evaluating a single sample."""

    # The sample that was evaluated.
    sample: EvalSample
    # Answer produced by the chain ("" when evaluation failed).
    generated_answer: str
    # Chunks returned by the retriever for this sample.
    retrieved_chunks: list[dict]
    # Metric values computed for this sample.
    metrics: EvalMetrics
    # Elapsed time for the sample, in milliseconds.
    latency_ms: float = 0.0
    # Error message; empty string means the sample evaluated cleanly.
    error: str = ""
# ══════════════════════════════════════════════════════════════════════════════
# EVALUATOR
# ══════════════════════════════════════════════════════════════════════════════
class RAGEvaluator:
    """
    Evaluates the full RAG pipeline with RAGAS-inspired metrics, all local.

    Parameters
    ----------
    retriever : FinancialRetriever instance
    chain : LangChain LCEL chain built by build_rag_chain()
    nli_model : NLI CrossEncoder model name
    embed_model : HuggingFaceEmbeddings instance (for answer relevancy).
        Pass None to skip answer-relevancy computation.
    entail_threshold : Threshold for faithfulness (SUPPORTED verdict)
    """

    def __init__(
        self,
        retriever,
        chain,
        nli_model: str = "cross-encoder/nli-deberta-v3-small",
        embed_model=None,
        entail_threshold: float = 0.50,
    ):
        # Imported here rather than at module level, as in the original code.
        from sentence_transformers import CrossEncoder
        from src.verifier import AnswerVerifier

        self.retriever = retriever
        self.chain = chain
        self._embed = embed_model

        log.info("Loading NLI model: %s", nli_model)
        self._nli = CrossEncoder(nli_model, max_length=512)
        self._verifier = AnswerVerifier(
            nli_model=nli_model,
            entail_threshold=entail_threshold,
        )

        # Read the label indices from the model config instead of assuming
        # a fixed ordering of the NLI output classes.
        id2label = self._nli.model.config.id2label
        self._entail_idx = [k for k, v in id2label.items() if "entail" in v.lower()][0]
        self._contra_idx = [k for k, v in id2label.items() if "contra" in v.lower()][0]
        log.info("RAGEvaluator ready.")

    # ── Faithfulness ─────────────────────────────────────────────────────────
    def compute_faithfulness(self, answer: str, chunks: list[dict]) -> float:
        """
        Fraction of answer sentences entailed by the retrieved context.

        Delegates to AnswerVerifier (Phase 7 reuse) and returns its
        "grounding_score". Returns 0.0 if answer or chunks are empty.
        """
        if not answer.strip() or not chunks:
            return 0.0
        result = self._verifier.verify(answer, chunks)
        return result["grounding_score"]

    # ── Answer Relevancy ─────────────────────────────────────────────────────
    def compute_answer_relevancy(self, question: str, answer: str) -> float:
        """
        Cosine similarity between question and answer embeddings.

        Proxy for 'does the answer actually address the question?'.
        The similarity is clipped to [0, 1]. Returns 0.0 if embed_model was
        not provided, the answer is empty, or either embedding has zero norm.
        """
        if not self._embed or not answer.strip():
            return 0.0

        import numpy as np

        q_vec = np.array(self._embed.embed_query(question))
        a_vec = np.array(self._embed.embed_query(answer))
        denom = np.linalg.norm(q_vec) * np.linalg.norm(a_vec)
        if denom == 0:
            return 0.0
        return float(np.clip(np.dot(q_vec, a_vec) / denom, 0.0, 1.0))

    # ── NLI helper ───────────────────────────────────────────────────────────
    def _best_entailment(self, pairs: list[tuple[str, str]]) -> float:
        """Return the highest entailment probability over (premise, hypothesis) pairs."""
        scores = self._nli.predict(pairs, apply_softmax=True)
        return max(float(row[self._entail_idx]) for row in scores)

    # ── Context Precision ────────────────────────────────────────────────────
    def compute_context_precision(
        self,
        chunks: list[dict],
        reference_answer: str,
    ) -> float:
        """
        Fraction of retrieved chunks that are useful for answering the question.

        A chunk is 'useful' if it entails at least one sentence from the
        reference answer (NLI entailment ≥ threshold). Returns 0.0 when there
        are no chunks or no reference sentences.
        """
        if not chunks or not reference_answer.strip():
            return 0.0
        ref_sentences = self._verifier.split_sentences(reference_answer)
        if not ref_sentences:
            return 0.0

        useful = 0
        for chunk in chunks:
            premise = _chunk_text_for_nli(chunk)
            pairs = [(premise, sent) for sent in ref_sentences]
            if self._best_entailment(pairs) >= _PRECISION_RECALL_ENTAIL_THRESHOLD:
                useful += 1
        return useful / len(chunks)

    # ── Context Recall ───────────────────────────────────────────────────────
    def compute_context_recall(
        self,
        chunks: list[dict],
        reference_answer: str,
    ) -> float:
        """
        Fraction of reference-answer claims covered by the retrieved context.

        A claim is 'covered' if at least one retrieved chunk entails it
        (NLI entailment ≥ threshold). Returns 0.0 when there are no chunks
        or no reference sentences.
        """
        if not chunks or not reference_answer.strip():
            return 0.0
        ref_sentences = self._verifier.split_sentences(reference_answer)
        if not ref_sentences:
            return 0.0

        context_texts = [_chunk_text_for_nli(c) for c in chunks]
        covered = 0
        for sent in ref_sentences:
            pairs = [(ctx, sent) for ctx in context_texts]
            if self._best_entailment(pairs) >= _PRECISION_RECALL_ENTAIL_THRESHOLD:
                covered += 1
        return covered / len(ref_sentences)

    # ── Answer F1 ────────────────────────────────────────────────────────────
    @staticmethod
    def _normalise(text: str) -> str:
        """Lowercase, replace non-alphanumeric characters with spaces, collapse whitespace."""
        text = text.lower()
        text = re.sub(r"[^a-z0-9\s]", " ", text)
        return " ".join(text.split())

    def compute_answer_f1(self, generated: str, reference: str) -> float:
        """
        Token-level F1 between generated and reference answer (SQuAD-style).

        F1 = 2 * precision * recall / (precision + recall)
        where precision = |common| / |generated tokens|
              recall    = |common| / |reference tokens|

        Tokens are counted with multiplicity (multiset overlap), so a token
        repeated in one answer is only matched as many times as it occurs in
        the other. Returns 0.0 if either side has no tokens or nothing overlaps.
        """
        from collections import Counter

        gen_tokens = self._normalise(generated).split()
        ref_tokens = self._normalise(reference).split()
        if not gen_tokens or not ref_tokens:
            return 0.0

        # Counter & Counter keeps the minimum count per token.
        overlap = sum((Counter(gen_tokens) & Counter(ref_tokens)).values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(gen_tokens)
        recall = overlap / len(ref_tokens)
        return 2 * precision * recall / (precision + recall)

    def compute_exact_match(self, generated: str, reference: str) -> float:
        """1.0 if normalised strings are identical, else 0.0."""
        return float(self._normalise(generated) == self._normalise(reference))

    # ── Retrieval Hit Rate ───────────────────────────────────────────────────
    def compute_retrieval_hit_rate(
        self,
        retrieved_chunks: list[dict],
        relevant_chunk_ids: list[str],
    ) -> float:
        """
        Fraction of known-relevant chunk IDs that appear in retrieved chunks.

        Returns 1.0 if relevant_chunk_ids is empty (undefined → treat as pass).
        """
        if not relevant_chunk_ids:
            return 1.0
        retrieved_ids = {c["id"] for c in retrieved_chunks}
        hits = sum(1 for rid in relevant_chunk_ids if rid in retrieved_ids)
        return hits / len(relevant_chunk_ids)

    # ── Single-sample evaluation ─────────────────────────────────────────────
    def evaluate_sample(
        self,
        sample: "EvalSample",
        n_results: int = 8,
    ) -> "EvalResult":
        """
        Run the full pipeline on one sample and compute all metrics.

        Steps:
          1. Retrieve chunks
          2. Generate answer via chain
          3. Compute all seven metrics

        Any exception is logged and converted into an EvalResult with zeroed
        metrics and the message in ``error``, so one bad sample does not abort
        a dataset run.
        """
        # perf_counter is monotonic, so latency is unaffected by clock changes.
        t0 = time.perf_counter()
        try:
            # 1. Retrieve
            chunks = self.retriever.retrieve(
                sample.question,
                n_results=n_results,
                filters=sample.filters or None,
            )
            # 2. Generate
            answer = self.chain.invoke({
                "query": sample.question,
                "filters": sample.filters or None,
            })
            # 3. Metrics
            metrics = EvalMetrics(
                faithfulness=self.compute_faithfulness(answer, chunks),
                answer_relevancy=self.compute_answer_relevancy(
                    sample.question, answer),
                context_precision=self.compute_context_precision(
                    chunks, sample.reference_answer),
                context_recall=self.compute_context_recall(
                    chunks, sample.reference_answer),
                answer_f1=self.compute_answer_f1(
                    answer, sample.reference_answer),
                exact_match=self.compute_exact_match(
                    answer, sample.reference_answer),
                retrieval_hit_rate=self.compute_retrieval_hit_rate(
                    chunks, sample.relevant_chunk_ids),
            )
            return EvalResult(
                sample=sample,
                generated_answer=answer,
                retrieved_chunks=chunks,
                metrics=metrics,
                latency_ms=(time.perf_counter() - t0) * 1000,
            )
        except Exception as exc:
            # Deliberate best-effort boundary: record the failure and move on.
            log.error("Error evaluating '%s': %s", sample.question[:60], exc)
            return EvalResult(
                sample=sample,
                generated_answer="",
                retrieved_chunks=[],
                metrics=EvalMetrics(),
                latency_ms=(time.perf_counter() - t0) * 1000,
                error=str(exc),
            )

    # ── Dataset evaluation ───────────────────────────────────────────────────
    def evaluate_dataset(
        self,
        samples: list["EvalSample"],
        n_results: int = 8,
        verbose: bool = True,
    ) -> list["EvalResult"]:
        """Evaluate all samples in order. Logs progress when verbose; returns a list of EvalResult."""
        results = []
        total = len(samples)
        for i, sample in enumerate(samples, 1):
            if verbose:
                log.info("[%2d/%d] %s", i, total, sample.question[:70])
            result = self.evaluate_sample(sample, n_results=n_results)
            results.append(result)
            if verbose and result.error:
                log.warning(" → Error: %s", result.error)
        return results

    # ── Aggregation ──────────────────────────────────────────────────────────
    def scorecard(self, results: list["EvalResult"]) -> dict:
        """
        Compute mean scores across all valid (error-free) samples.

        Returns a dict with metric averages, aggregate_score, avg_latency_ms,
        n_samples, and n_errors. Returns {} for an empty result list, and
        {"error": ..., "n_errors": ...} when every sample failed.
        """
        if not results:
            return {}
        valid = [r for r in results if not r.error]
        if not valid:
            return {"error": "All samples failed", "n_errors": len(results)}

        def avg(metric: str) -> float:
            return sum(getattr(r.metrics, metric) for r in valid) / len(valid)

        metric_keys = [
            "faithfulness", "answer_relevancy", "context_precision",
            "context_recall", "answer_f1", "exact_match", "retrieval_hit_rate",
        ]
        scores = {m: round(avg(m), 3) for m in metric_keys}
        scores["aggregate_score"] = round(
            sum(r.metrics.aggregate_score() for r in valid) / len(valid), 3
        )
        scores["avg_latency_ms"] = round(
            sum(r.latency_ms for r in valid) / len(valid), 1
        )
        scores["n_samples"] = len(valid)
        scores["n_errors"] = len(results) - len(valid)
        return scores

    # ── Category breakdown ───────────────────────────────────────────────────
    def category_breakdown(self, results: list["EvalResult"]) -> dict[str, dict]:
        """
        Scorecard split by sample category (revenue, risk, segment, etc.).

        Returns {category: scorecard_dict}, with categories in sorted order.
        """
        from collections import defaultdict

        by_cat: dict[str, list["EvalResult"]] = defaultdict(list)
        for r in results:
            by_cat[r.sample.category].append(r)
        return {cat: self.scorecard(rs) for cat, rs in sorted(by_cat.items())}

    # ── Report ───────────────────────────────────────────────────────────────
    def report(self, results: list["EvalResult"]) -> str:
        """Formatted evaluation report with per-metric grades and per-sample table."""
        scores = self.scorecard(results)
        lines = [
            "=" * 75,
            " RAG Pipeline Evaluation Report",
            "=" * 75,
            f" Samples : {scores.get('n_samples', 0)} "
            f"(errors: {scores.get('n_errors', 0)})",
            f" Avg latency : {scores.get('avg_latency_ms', 0):.0f} ms/query",
            "-" * 75,
            f" {'Metric':<28} {'Score':>7} Grade",
            "-" * 75,
        ]
        metric_labels = {
            "faithfulness": "Faithfulness (NLI)",
            "answer_relevancy": "Answer Relevancy (cosine)",
            "context_precision": "Context Precision",
            "context_recall": "Context Recall",
            "answer_f1": "Answer F1 (vs reference)",
            "retrieval_hit_rate": "Retrieval Hit Rate",
        }
        for key, label in metric_labels.items():
            score = scores.get(key, 0.0)
            hi, lo = GRADE_THRESHOLDS.get(key, (0.75, 0.50))
            if score >= hi:
                grade, icon = "PASS", "✓"
            elif score >= lo:
                grade, icon = "REVIEW", "⚠"
            else:
                grade, icon = "FAIL", "✗"
            lines.append(f" {icon} {label:<28} {score:>6.1%} {grade}")

        agg = scores.get("aggregate_score", 0.0)
        lines += [
            "-" * 75,
            f" {'Aggregate Score':<28} {agg:>6.1%}",
            "=" * 75,
            "",
            " Per-sample breakdown:",
            "-" * 75,
            f" {'#':>3} {'Faith':>5} {'Relev':>5} {'CPrec':>5} "
            f"{'CRec':>5} {'F1':>5} {'Agg':>5} Question",
            "-" * 75,
        ]
        # Failed samples are listed too (with zeroed metrics) and tagged [ERR].
        for i, r in enumerate(results, 1):
            m = r.metrics
            agg_s = m.aggregate_score()
            err = " [ERR]" if r.error else ""
            lines.append(
                f" {i:>3} {m.faithfulness:>5.2f} {m.answer_relevancy:>5.2f} "
                f"{m.context_precision:>5.2f} {m.context_recall:>5.2f} "
                f"{m.answer_f1:>5.2f} {agg_s:>5.2f} "
                f"{r.sample.question[:40]}{err}"
            )
        lines.append("=" * 75)
        return "\n".join(lines)

    # ── Configuration comparison ─────────────────────────────────────────────
    def compare_configs(
        self,
        configs: list[dict],
        samples: list["EvalSample"],
        n_results: int = 8,
    ) -> str:
        """
        Compare multiple pipeline configurations on the same test set.

        Each config dict must have:
          label     : str        — display name
          chain     : LCEL chain — built by build_rag_chain()
          retriever : retriever  — FinancialRetriever instance

        The evaluator's own ``chain`` and ``retriever`` are swapped in
        temporarily for each config and restored afterwards, even if a config
        raises. An empty ``configs`` list yields a table with no score columns.

        Returns a formatted comparison table.
        """
        config_scores: dict[str, dict] = {}
        original_chain, original_retriever = self.chain, self.retriever
        try:
            for cfg in configs:
                cfg_label = cfg["label"]
                self.chain = cfg["chain"]
                self.retriever = cfg["retriever"]
                log.info("Evaluating config: %s", cfg_label)
                results = self.evaluate_dataset(
                    samples, n_results=n_results, verbose=False
                )
                config_scores[cfg_label] = self.scorecard(results)
        finally:
            # Leave the evaluator as the caller configured it.
            self.chain, self.retriever = original_chain, original_retriever

        metrics = [
            "faithfulness", "answer_relevancy", "context_precision",
            "context_recall", "answer_f1", "aggregate_score",
        ]
        # default=0 keeps an empty config list from raising in max().
        col_w = max((len(k) for k in config_scores), default=0) + 4
        width = 28 + col_w * len(config_scores)
        lines = [
            "=" * width,
            " Configuration Comparison",
            "=" * width,
            f" {'Metric':<26}" + "".join(f" {k:>{col_w-2}}" for k in config_scores),
            "-" * width,
        ]
        for m in metrics:
            metric_label = m.replace("_", " ").title()
            row = f" {metric_label:<26}"
            for k in config_scores:
                v = config_scores[k].get(m, 0.0)
                row += f" {v:>{col_w-2}.3f}"
            lines.append(row)
        lines.append("=" * width)
        return "\n".join(lines)
# ══════════════════════════════════════════════════════════════════════════════
# BUILT-IN TEST SET β€” Apple Finance
# ══════════════════════════════════════════════════════════════════════════════
class AppleFinanceTestSet:
    """
    15 evaluation questions covering Apple SEC filings (FY2024 10-K) and
    Morningstar research reports.

    Reference answers are based on published FY2024 10-K figures (fiscal year
    ending September 2024) and Apple's Q1 FY2025 10-Q.
    All dollar amounts are in millions unless otherwise noted.
    """

    SAMPLES: list[EvalSample] = [
        # ── Revenue ──────────────────────────────────────────────────────────
        EvalSample(
            question="What were Apple's total net sales for fiscal year 2024?",
            reference_answer=(
                "Apple's total net sales for fiscal year 2024 were $391,035 million."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="revenue",
        ),
        # ── Profitability ────────────────────────────────────────────────────
        EvalSample(
            question="What was Apple's total gross margin in fiscal year 2024?",
            reference_answer=(
                "Apple's total gross margin for fiscal year 2024 was $180,683 million, "
                "compared to $169,148 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="profitability",
        ),
        EvalSample(
            question="What was Apple's net income for fiscal year 2024?",
            reference_answer=(
                "Apple's net income for fiscal year 2024 was $93,736 million."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="profitability",
        ),
        EvalSample(
            question="What was Apple's diluted earnings per share for FY2024?",
            reference_answer=(
                "Apple's diluted earnings per share for fiscal year 2024 was $6.11, "
                "compared to $6.13 in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="profitability",
        ),
        EvalSample(
            question="What was Apple's operating income for fiscal year 2024?",
            reference_answer=(
                "Apple's operating income was $123,216 million for fiscal year 2024."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="profitability",
        ),
        # ── Product segments ─────────────────────────────────────────────────
        EvalSample(
            question="What was iPhone revenue for fiscal year 2024?",
            reference_answer=(
                "iPhone net sales were $201,183 million in fiscal year 2024, "
                "a slight increase from $200,583 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="segment",
        ),
        EvalSample(
            question="How much did Apple's Services segment generate in FY2024?",
            reference_answer=(
                "Apple's Services segment generated net sales of $96,169 million "
                "in fiscal year 2024, compared to $85,200 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="segment",
        ),
        EvalSample(
            question="What was Apple's Mac revenue for fiscal year 2024?",
            reference_answer=(
                "Mac net sales were $29,984 million in fiscal year 2024, "
                "up from $29,357 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="segment",
        ),
        EvalSample(
            question="What were Apple's Wearables, Home and Accessories net sales in FY2024?",
            reference_answer=(
                "Wearables, Home and Accessories net sales were $37,005 million "
                "in fiscal year 2024, a decrease from $39,845 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="segment",
        ),
        EvalSample(
            question="What was Apple's iPad revenue for fiscal year 2024?",
            reference_answer=(
                "iPad net sales were $26,694 million in fiscal year 2024, "
                "compared to $28,300 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="segment",
        ),
        # ── Geographic ───────────────────────────────────────────────────────
        EvalSample(
            question="What was Apple's revenue in the Americas geographic segment in FY2024?",
            reference_answer=(
                "Apple's Americas net sales were $167,045 million in fiscal year 2024, "
                "up from $162,560 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="geographic",
        ),
        EvalSample(
            question="What was Apple's Greater China revenue in FY2024?",
            reference_answer=(
                "Apple's Greater China net sales were $66,952 million in fiscal year 2024, "
                "compared to $72,559 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="geographic",
        ),
        # ── Expenses ─────────────────────────────────────────────────────────
        EvalSample(
            question="What were Apple's research and development expenses in FY2024?",
            reference_answer=(
                "Apple's research and development expenses were $31,370 million "
                "in fiscal year 2024, up from $29,915 million in fiscal year 2023."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="expense",
        ),
        # ── Balance sheet ────────────────────────────────────────────────────
        # The question is worded to match the reference figure, which covers
        # cash and cash equivalents only (marketable securities excluded).
        EvalSample(
            question="How much cash and cash equivalents did Apple hold at the end of FY2024?",
            reference_answer=(
                "At the end of fiscal year 2024, Apple held $29,943 million in cash and "
                "cash equivalents."
            ),
            filters={"$and": [{"doc_type": "10-K"}, {"fiscal_year": "2024"}]},
            category="balance_sheet",
        ),
        # ── Risk ─────────────────────────────────────────────────────────────
        EvalSample(
            question="What supply chain risk factors does Apple cite in its 10-K?",
            reference_answer=(
                "Apple cites several supply chain risks in its 10-K: dependence on "
                "single-source or limited-source suppliers for certain components, "
                "concentration of manufacturing in Asia particularly China, exposure "
                "to geopolitical tensions and trade disputes, potential disruptions "
                "from natural disasters or public health events, and risks related "
                "to component shortages and price volatility."
            ),
            filters={"doc_type": "10-K"},
            category="risk",
        ),
    ]

    @classmethod
    def get_samples(cls, category: Optional[str] = None) -> list[EvalSample]:
        """
        Return all samples, or only those in the given category.

        Categories: revenue, profitability, segment, geographic, expense,
        balance_sheet, risk.

        A new list is returned each time, so callers may reorder or slice it
        without affecting ``SAMPLES`` (the EvalSample objects are shared).
        """
        if category:
            return [s for s in cls.SAMPLES if s.category == category]
        return list(cls.SAMPLES)

    @classmethod
    def categories(cls) -> list[str]:
        """List all available categories in sorted order."""
        return sorted({s.category for s in cls.SAMPLES})

    @classmethod
    def summary(cls) -> str:
        """Human-readable summary of the test set: total count and per-category counts."""
        from collections import Counter

        counts = Counter(s.category for s in cls.SAMPLES)
        lines = [
            f"Apple Finance Test Set — {len(cls.SAMPLES)} samples",
            "-" * 40,
        ]
        for cat, n in sorted(counts.items()):
            lines.append(f" {cat:<15} {n} sample{'s' if n > 1 else ''}")
        return "\n".join(lines)