Rom89823974978 committed on
Commit
12409b1
·
1 Parent(s): 549e0c8

Updated codebase
README.md CHANGED
@@ -22,19 +22,16 @@ Hugging Face spaces setup
 ## 1 Quick start

 ```bash
-# ❶ Clone and set up the dev env
-git clone https://github.com/<your-org>/rag-eval-framework.git
-cd rag-eval-framework
+git clone https://github.com/Romainkul/rag_evaluation.git
+cd rag_evaluation
 python -m venv .venv && source .venv/bin/activate
 pip install -r requirements.txt
 pre-commit install

-# ❷ Fetch a toy corpus (≈200 docs)
 bash scripts/download_data.sh

-# ❸ First single-config run (indexes auto-build)
-python scripts/run_experiments.py \
-    --config configs/pipeline_hybrid_ce.yaml \
+python scripts/analysis.py \
+    --config configs/kilt_hybrid_ce.yaml \
     --queries data/sample_queries.jsonl
 ```
@@ -54,10 +51,10 @@ evaluation/ ← Core library
 ├─ metrics/                 • Retrieval, generation, composite RAG score
 └─ stats/                   • Correlation, significance, robustness utilities
 scripts/                    ← CLI tools
-├─ run_experiments.py       • Single-config runner (logs, metrics, plots)
-├─ run_grid_experiments.py  • **Grid runner** – all configs × datasets, RQ1-RQ4 analysis
+├─ prep_annotations.py      • Runs RAG and logs all outputs for expert annotation
+├─ analysis.py              • **Grid runner** – all configs × datasets, RQ1-RQ4 analysis
 ├─ dashboard.py             • **Streamlit dashboard** for interactive exploration
-tests/                      ← PyTest smoke tests
+tests/                      ← PyTest tests
 configs/                    ← YAML templates for pipelines & stats
 .github/workflows/          ← Lint + tests CI
 Dockerfile                  ← Slim reproducible image
@@ -69,10 +66,10 @@ Dockerfile ← Slim reproducible image

 | Research-proposal element | Code artefact | Purpose |
 | ------------------------- | ------------- | ------- |
-| **RQ1** Classical retrieval ↔ factual correctness | `evaluation/retrievers/`, `run_grid_experiments.py` | Computes Spearman / Kendall ρ with CIs for MRR, MAP, P@k vs *human_correct*. |
-| **RQ2** Faithfulness metrics vs expert judgements | `evaluation/metrics/`, `evaluation/stats/`, grid script | Correlates QAGS, FactScore, RAGAS-F etc. with *human_faithful*; Wilcoxon + Holm. |
-| **RQ3** Error propagation → hallucination | `evaluation/stats.robustness`, grid script | χ² test, conditional failure rates across corpora / document styles. |
-| **RQ4** Robustness to adversarial evidence | Perturbed datasets (`*_pert.jsonl`) + grid script | Δ-metrics & Cohen's *d* between clean and perturbed runs. |
+| **RQ1** Classical retrieval ↔ factual correctness | `evaluation/retrievers/`, `analysis.py` | Computes Spearman / Kendall ρ with CIs for MRR, MAP, P@k vs *human_correct*. |
+| **RQ2** Faithfulness metrics vs expert judgements | `evaluation/metrics/`, `evaluation/stats/`, `analysis.py` | Correlates QAGS, FactScore, RAGAS-F etc. with *human_faithful*; Wilcoxon + Holm. |
+| **RQ3** Error propagation → hallucination | `evaluation/stats.robustness`, `analysis.py` | χ² test, conditional failure rates across corpora / document styles. |
+| **RQ4** Robustness to adversarial evidence | Perturbed datasets (`*_pert.jsonl`) + `analysis.py` | Δ-metrics & Cohen's *d* between clean and perturbed runs. |
 | Interactive analysis / decision-making | `scripts/dashboard.py` | Select dataset + configs, explore tables & plots instantly. |
 | EU AI-Act traceability (Art. 14-15) | Rotating file logging (`evaluation/utils/logger.py`), Docker, CI | Full run provenance (config + log + results + stats) stored under `outputs/`. |
@@ -82,7 +79,7 @@ Dockerfile ← Slim reproducible image

 ```bash
 # Evaluate three configs on two datasets, save everything under outputs/grid
-python scripts/run_grid_experiments.py \
+python scripts/analysis.py \
     --configs configs/*.yaml \
     --datasets data/legal.jsonl data/finance.jsonl \
     --plots
@@ -104,7 +101,7 @@ outputs/grid/<dataset>/wilcoxon_rag_holm.yaml ← pairwise p-values
 Run a *single* new config and automatically compare it to all previous ones:

 ```bash
-python scripts/run_grid_experiments.py \
+python scripts/analysis.py \
     --configs configs/my_new.yaml \
     --datasets data/legal.jsonl \
     --outdir outputs/grid \
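The `--queries` file is plain JSONL: per the CLI help in the removed `run_experiments.py`, each line must contain at least `{'question': ...}`, and any extra keys (ids, gold labels) are merged into the corresponding result row. A minimal sketch for building such a file (the questions themselves are made up):

```python
import json
from pathlib import Path

# Hypothetical queries -- only the "question" key is required;
# extra keys like "id" are carried through to the results.
queries = [
    {"id": 0, "question": "What is a hybrid retriever?"},
    {"id": 1, "question": "Which court decided the case?"},
]

path = Path("data/sample_queries.jsonl")
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w") as f:
    for q in queries:
        f.write(json.dumps(q) + "\n")
```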
configs/kilt_hybrid_ce.yaml ADDED
@@ -0,0 +1,42 @@
+# This configuration file sets up a hybrid pipeline using a retriever, generator, and reranker.
+# It is designed to work with the KILT dataset and uses FAISS for retrieval.
+
+logging:
+  log_dir: logs
+  level: INFO
+  max_mb: 5
+  backups: 5
+
+retriever:
+  # using FAISS (dense) retrieval over KILT's Wikipedia passages
+  name: dense
+  faiss_index: /path/to/kilt_wiki_faiss.index
+  top_k: 5
+  model_name: sentence-transformers/all-MiniLM-L6-v2
+  device: cpu
+
+generator:
+  model_name: facebook/bart-large
+  device: cpu
+  max_new_tokens: 256
+  temperature: 0.0
+
+reranker:
+  enable: true
+  model_name: cross-encoder/ms-marco-MiniLM-L-6-v2
+  device: cpu
+  max_length: 512
+  first_stage_k: 5
+  final_k: 5
+
+stats:
+  correlation_method: spearman
+  n_boot: 1000
+  ci: 0.95
+  wilcoxon_alternative: two-sided
+  multiple_correction: holm-bonferroni
+  alpha: 0.05
+  compute_effect_size: true
+  n_permutations: 1000
+  failure_threshold: 0.0
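For reference, this YAML is parsed into nested dataclasses by the `load_pipeline_config` helper defined in `scripts/prep_annotations.py`. A minimal usage sketch, assuming the repository root is on `sys.path`, `scripts/` is importable as a package, and `faiss_index` points at a real index file (the placeholder path above will not resolve):

```python
from pathlib import Path

from scripts.prep_annotations import load_pipeline_config
from evaluation import RAGPipeline

cfg = load_pipeline_config(Path("configs/kilt_hybrid_ce.yaml"))
print(cfg.retriever.top_k, cfg.reranker.enable)  # 5 True

# Builds retriever -> (reranker) -> generator from the config.
pipe = RAGPipeline(cfg)
```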
data/load_datasets.py ADDED
@@ -0,0 +1,29 @@
+from datasets import load_dataset
+
+# Load datasets for evaluation.
+# This script loads various datasets for evaluation purposes, including
+# finance, legal, KILT, and Natural Questions (NQ).
+
+# Finance dataset
+ds_finance = load_dataset("PatronusAI/financebench")
+
+# Legal dataset
+ds_legal = load_dataset("nguha/legalbench", "canada_tax_court_outcomes")
+# Possible datasets in LegalBench:
+# ['abercrombie', 'canada_tax_court_outcomes', 'citation_prediction_classification', 'citation_prediction_open', 'consumer_contracts_qa', 'contract_nli_confidentiality_of_agreement', 'contract_nli_explicit_identification', 'contract_nli_inclusion_of_verbally_conveyed_information', 'contract_nli_limited_use', 'contract_nli_no_licensing', 'contract_nli_notice_on_compelled_disclosure', 'contract_nli_permissible_acquirement_of_similar_information', 'contract_nli_permissible_copy', 'contract_nli_permissible_development_of_similar_information', 'contract_nli_permissible_post-agreement_possession', 'contract_nli_return_of_confidential_information', 'contract_nli_sharing_with_employees', 'contract_nli_sharing_with_third-parties', 'contract_nli_survival_of_obligations', 'contract_qa', 'corporate_lobbying', 'cuad_affiliate_license-licensee', 'cuad_affiliate_license-licensor', 'cuad_anti-assignment', 'cuad_audit_rights', 'cuad_cap_on_liability', 'cuad_change_of_control', 'cuad_competitive_restriction_exception', 'cuad_covenant_not_to_sue', 'cuad_effective_date', 'cuad_exclusivity', 'cuad_expiration_date', 'cuad_governing_law', 'cuad_insurance', 'cuad_ip_ownership_assignment', 'cuad_irrevocable_or_perpetual_license', 'cuad_joint_ip_ownership', 'cuad_license_grant', 'cuad_liquidated_damages', 'cuad_minimum_commitment', 'cuad_most_favored_nation', 'cuad_no-solicit_of_customers', 'cuad_no-solicit_of_employees', 'cuad_non-compete', 'cuad_non-disparagement', 'cuad_non-transferable_license', 'cuad_notice_period_to_terminate_renewal', 'cuad_post-termination_services', 'cuad_price_restrictions', 'cuad_renewal_term', 'cuad_revenue-profit_sharing', 'cuad_rofr-rofo-rofn', 'cuad_source_code_escrow', 'cuad_termination_for_convenience', 'cuad_third_party_beneficiary', 'cuad_uncapped_liability', 'cuad_unlimited-all-you-can-eat-license', 'cuad_volume_restriction', 'cuad_warranty_duration', 'definition_classification', 'definition_extraction', 'diversity_1', 'diversity_2', 'diversity_3', 'diversity_4', 'diversity_5', 'diversity_6', 'function_of_decision_section', 'hearsay', 'insurance_policy_interpretation', 'international_citizenship_questions', 'jcrew_blocker', 'learned_hands_benefits', 'learned_hands_business', 'learned_hands_consumer', 'learned_hands_courts', 'learned_hands_crime', 'learned_hands_divorce', 'learned_hands_domestic_violence', 'learned_hands_education', 'learned_hands_employment', 'learned_hands_estates', 'learned_hands_family', 'learned_hands_health', 'learned_hands_housing', 'learned_hands_immigration', 'learned_hands_torts', 'learned_hands_traffic', 'legal_reasoning_causality', 'maud_ability_to_consummate_concept_is_subject_to_mae_carveouts', 'maud_accuracy_of_fundamental_target_rws_bringdown_standard', 'maud_accuracy_of_target_capitalization_rw_(outstanding_shares)_bringdown_standard_answer', 'maud_accuracy_of_target_general_rw_bringdown_timing_answer', 'maud_additional_matching_rights_period_for_modifications_(cor)', 'maud_application_of_buyer_consent_requirement_(negative_interim_covenant)', 'maud_buyer_consent_requirement_(ordinary_course)', 'maud_change_in_law__subject_to_disproportionate_impact_modifier', 'maud_changes_in_gaap_or_other_accounting_principles__subject_to_disproportionate_impact_modifier', 'maud_cor_permitted_in_response_to_intervening_event', 'maud_cor_permitted_with_board_fiduciary_determination_only', 'maud_cor_standard_(intervening_event)', 'maud_cor_standard_(superior_offer)', 'maud_definition_contains_knowledge_requirement_-_answer', 'maud_definition_includes_asset_deals', 'maud_definition_includes_stock_deals', 'maud_fiduciary_exception__board_determination_standard', 'maud_fiduciary_exception_board_determination_trigger_(no_shop)', 'maud_financial_point_of_view_is_the_sole_consideration', 'maud_fls_(mae)_standard', 'maud_general_economic_and_financial_conditions_subject_to_disproportionate_impact_modifier', 'maud_includes_consistent_with_past_practice', 'maud_initial_matching_rights_period_(cor)', 'maud_initial_matching_rights_period_(ftr)', 'maud_intervening_event_-_required_to_occur_after_signing_-_answer', 'maud_knowledge_definition', 'maud_liability_standard_for_no-shop_breach_by_target_non-do_representatives', 'maud_ordinary_course_efforts_standard', 'maud_pandemic_or_other_public_health_event__subject_to_disproportionate_impact_modifier', 'maud_pandemic_or_other_public_health_event_specific_reference_to_pandemic-related_governmental_responses_or_measures', 'maud_relational_language_(mae)_applies_to', 'maud_specific_performance', 'maud_tail_period_length', 'maud_type_of_consideration', 'nys_judicial_ethics', 'opp115_data_retention', 'opp115_data_security', 'opp115_do_not_track', 'opp115_first_party_collection_use', 'opp115_international_and_specific_audiences', 'opp115_policy_change', 'opp115_third_party_sharing_collection', 'opp115_user_access,_edit_and_deletion', 'opp115_user_choice_control', 'oral_argument_question_purpose', 'overruling', 'personal_jurisdiction', 'privacy_policy_entailment', 'privacy_policy_qa', 'proa', 'rule_qa', 'sara_entailment', 'sara_numeric', 'scalr', 'ssla_company_defendants', 'ssla_individual_defendants', 'ssla_plaintiff', 'successor_liability', 'supply_chain_disclosure_best_practice_accountability', 'supply_chain_disclosure_best_practice_audits', 'supply_chain_disclosure_best_practice_certification', 'supply_chain_disclosure_best_practice_training', 'supply_chain_disclosure_best_practice_verification', 'supply_chain_disclosure_disclosed_accountability', 'supply_chain_disclosure_disclosed_audits', 'supply_chain_disclosure_disclosed_certification', 'supply_chain_disclosure_disclosed_training', 'supply_chain_disclosure_disclosed_verification', 'telemarketing_sales_rule', 'textualism_tool_dictionaries', 'textualism_tool_plain', 'ucc_v_common_law', 'unfair_tos']
+
+# KILT dataset
+ds_kilt = load_dataset("facebook/kilt_tasks", "nq")
+
+# Natural Questions dataset
+ds_nq = load_dataset("sentence-transformers/natural-questions")
+
+
+def load_datasets():
+    """Load and return the datasets."""
+    return {
+        "finance": ds_finance,
+        "legal": ds_legal,
+        "kilt": ds_kilt,
+        "nq": ds_nq,
+    }
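Note that the `load_dataset` calls run at import time, so importing this module downloads all four datasets. A quick sanity check of what `load_datasets()` returns (each value is a Hugging Face `DatasetDict`; split and column names differ per dataset, so inspect them before converting anything into the JSONL query format the scripts expect):

```python
from data.load_datasets import load_datasets  # assumes repo root on sys.path

for name, ds in load_datasets().items():
    first_split = next(iter(ds))  # e.g. "train" or "test"
    print(name, list(ds.keys()), ds[first_split].column_names)
```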
evaluation/pipeline.py CHANGED
@@ -35,14 +35,49 @@ class RAGPipeline:
     # Public API
     # ---------------------------------------------------------------------
     def run(self, question: str) -> Dict[str, Any]:
-        """Retrieve context and generate answer."""
         logger.info("Question: %s", question)
-        contexts = self._retrieve(question)
-        answer = self._generate(question, contexts)
+
+        # 1. raw retrieval
+        k_first = self.cfg.reranker.first_stage_k if self.reranker else self.cfg.retriever.top_k
+        initial: List[Context] = self.retriever.retrieve(question, top_k=k_first)
+
+        raw_hits = [
+            {"text": c.text, "id": c.id, "score": getattr(c, "retrieval_score", None)}
+            for c in initial
+        ]
+
+        # 2. reranking (if enabled)
+        if self.reranker:
+            final_k = self.cfg.reranker.final_k or self.cfg.retriever.top_k
+            reranked: List[Context] = self.reranker.rerank(question, initial, k=final_k)
+
+            reranked_hits = [
+                {
+                    "text": c.text,
+                    "id": c.id,
+                    "score": getattr(c, "cross_encoder_score", None),
+                }
+                for c in reranked
+            ]
+            contexts_for_gen = reranked
+        else:
+            reranked_hits = []
+            contexts_for_gen = initial
+
+        # 3. generation
+        answer = self.generator.generate(
+            question,
+            [c.text for c in contexts_for_gen],
+            max_new_tokens=self.cfg.generator.max_new_tokens,
+            temperature=self.cfg.generator.temperature,
+        )
+
         return {
             "question": question,
+            "raw_retrieval": raw_hits,
+            "reranked": reranked_hits,
+            "contexts": [c.text for c in contexts_for_gen],
             "answer": answer,
-            "contexts": [c.text for c in contexts],
         }

     __call__ = run  # alias
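After this change, `run()` exposes both retrieval stages instead of only the final contexts. A sketch of the returned dict (values are illustrative; `score` is `None` when a retriever or reranker does not attach one):

```python
result = {
    "question": "…",
    "raw_retrieval": [  # first-stage hits (first_stage_k of them when reranking)
        {"text": "…passage…", "id": "doc42", "score": 0.71},
    ],
    "reranked": [       # empty list when the reranker is disabled
        {"text": "…passage…", "id": "doc42", "score": 4.2},
    ],
    "contexts": ["…passage…"],  # the texts actually fed to the generator
    "answer": "…generated answer…",
}
```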
pyserini/search.py CHANGED
@@ -2,7 +2,6 @@

 class SimpleSearcher:
     def __init__(self, index_path):
-        # no-op
         pass
     def set_bm25(self):
         pass
requirements.txt CHANGED
@@ -9,6 +9,7 @@ langchain>=0.1.0
 ragas>=0.1.0
 trulens-eval>=0.21.0
 evaluate
+datasets

 # Data & science
 pandas>=2.2
scripts/analysis.py ADDED
@@ -0,0 +1,160 @@
+"""
+Runs evaluation (RQ1–RQ4, statistical tests, plots) on previously annotated
+pipeline outputs that include `human_correct` and `human_faithful`.
+
+Assumes outputs were generated using `prep_annotations.py` and
+subsequently annotated.
+"""
+
+import argparse
+import json
+import logging
+import itertools
+from pathlib import Path
+
+import numpy as np
+import yaml
+import matplotlib.pyplot as plt
+
+from evaluation.stats import (
+    corr_ci,
+    wilcoxon_signed_rank,
+    holm_bonferroni,
+    conditional_failure_rate,
+    chi2_error_propagation,
+    delta_metric,
+)
+from evaluation.utils.logger import init_logging
+
+
+def read_jsonl(path: Path):
+    with path.open() as f:
+        return [json.loads(line) for line in f]
+
+
+def save_yaml(path: Path, obj: dict):
+    path.parent.mkdir(parents=True, exist_ok=True)
+    path.write_text(yaml.safe_dump(obj, sort_keys=False))
+
+
+def agg_mean(rows: list[dict]) -> dict:
+    keys = rows[0]["metrics"].keys()
+    return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
+
+
+def rq1_correlation(rows):
+    if "human_correct" not in rows[0] or rows[0]["human_correct"] is None:
+        return {}
+    retrieval_keys = [k for k in rows[0]["metrics"] if k in {"mrr", "map", "precision@10"}]
+    gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
+    out = {}
+    for k in retrieval_keys:
+        vec = [r["metrics"][k] for r in rows]
+        r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
+        out[k] = dict(r=r, ci=[lo, hi], p=p)
+    return out
+
+
+def rq2_faithfulness(rows):
+    if "human_faithful" not in rows[0] or rows[0]["human_faithful"] is None:
+        return {}
+    faith_keys = [k for k in rows[0]["metrics"] if k.lower().startswith(("faith", "qags", "fact", "ragas"))]
+    gold = [r["human_faithful"] for r in rows]
+    out = {}
+    for k in faith_keys:
+        vec = [r["metrics"][k] for r in rows]
+        r, (lo, hi), p = corr_ci(vec, gold, method="pearson", n_boot=1000, ci=0.95)
+        out[k] = dict(r=r, ci=[lo, hi], p=p)
+    return out
+
+
+def rq3_error_propagation(rows):
+    if "retrieval_error" not in rows[0] or "hallucination" not in rows[0]:
+        return {}
+    ret_err = [r["retrieval_error"] for r in rows]
+    halluc = [r["hallucination"] for r in rows]
+    return {
+        "conditional": conditional_failure_rate(ret_err, halluc),
+        "chi2": chi2_error_propagation(ret_err, halluc),
+    }
+
+
+def rq4_robustness(orig_rows, pert_rows):
+    if pert_rows is None:
+        return {}
+    metrics = orig_rows[0]["metrics"].keys()
+    out = {}
+    for m in metrics:
+        d, eff = delta_metric(
+            [r["metrics"][m] for r in orig_rows],
+            [r["metrics"][m] for r in pert_rows],
+        )
+        out[m] = dict(delta=d, cohen_d=eff)
+    return out
+
+
+def scatter_mrr_vs_correct(rows, path: Path):
+    x = [r["metrics"].get("mrr", np.nan) for r in rows]
+    y = [1 if r.get("human_correct") else 0 for r in rows]
+    plt.figure()
+    plt.scatter(x, y, alpha=0.5)
+    plt.xlabel("MRR"); plt.ylabel("Correct (1)")
+    plt.title("MRR vs. Human Correctness")
+    plt.tight_layout(); plt.savefig(path); plt.close()
+
+
+def main(argv=None):
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--results", nargs="+", type=Path, required=True,
+                    help="One or more annotated results.jsonl files.")
+    ap.add_argument("--outdir", type=Path, default=Path("outputs/grid"))
+    ap.add_argument("--perturbed-suffix", default="_pert.jsonl",
+                    help="Looks for this perturbed variant for RQ4.")
+    ap.add_argument("--plots", action="store_true")
+    args = ap.parse_args(argv)
+
+    init_logging(log_dir=args.outdir / "logs", level="INFO")
+    log = logging.getLogger("resume")
+
+    historical = {}
+
+    for res_path in args.results:
+        cfg_name = res_path.parent.name
+        dataset_name = res_path.parent.parent.name
+        log.info("Processing %s on %s", cfg_name, dataset_name)
+
+        rows = read_jsonl(res_path)
+        pert_path = res_path.with_name(res_path.stem.replace("unlabeled", "pert") + args.perturbed_suffix)
+        pert_rows = read_jsonl(pert_path) if pert_path.exists() else None
+
+        run_dir = args.outdir / dataset_name / cfg_name
+        run_dir.mkdir(parents=True, exist_ok=True)
+
+        save_yaml(run_dir / "aggregates.yaml", agg_mean(rows))
+        save_yaml(run_dir / "rq1.yaml", rq1_correlation(rows))
+        save_yaml(run_dir / "rq2.yaml", rq2_faithfulness(rows))
+        save_yaml(run_dir / "rq3.yaml", rq3_error_propagation(rows))
+        if pert_rows:
+            save_yaml(run_dir / "rq4.yaml", rq4_robustness(rows, pert_rows))
+        if args.plots:
+            scatter_mrr_vs_correct(rows, run_dir / "mrr_vs_correct.png")
+
+        historical[cfg_name] = rows
+
+    # Pairwise Wilcoxon + Holm correction
+    if len(historical) > 1:
+        names = list(historical)
+        pairs = {}
+        for a, b in itertools.combinations(names, 2):
+            x = [r["metrics"]["rag_score"] for r in historical[a]]
+            y = [r["metrics"]["rag_score"] for r in historical[b]]
+            _, p = wilcoxon_signed_rank(x, y)
+            pairs[f"{a}~{b}"] = p
+        dataset_name = args.results[0].parent.parent.name
+        save_yaml(args.outdir / dataset_name / "wilcoxon_rag_raw.yaml", pairs)
+        save_yaml(args.outdir / dataset_name / "wilcoxon_rag_holm.yaml", holm_bonferroni(pairs))
+        log.info("Pairwise significance testing complete (rag_score).")
+
+
+if __name__ == "__main__":
+    main()
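Based on the fields the script reads, an annotated line in a `--results` file needs a `metrics` dict plus the human labels, and optionally the RQ3 error flags; the metric names below are illustrative (RQ1 looks for `mrr`/`map`/`precision@10`, RQ2 for keys starting with `faith`/`qags`/`fact`/`ragas`, and the pairwise Wilcoxon needs `rag_score`):

```python
import json

# One hypothetical annotated row, as analysis.py expects it.
row = {
    "question": "…",
    "metrics": {
        "mrr": 1.0, "map": 0.8, "precision@10": 0.3,
        "ragas_faithfulness": 0.9, "rag_score": 0.85,
    },
    "human_correct": True,      # RQ1
    "human_faithful": 0.9,      # RQ2
    "retrieval_error": False,   # RQ3
    "hallucination": False,     # RQ3
}
print(json.dumps(row))  # one such line per query in results.jsonl
```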
scripts/dashboard.py CHANGED
@@ -1,12 +1,8 @@
-#!/usr/bin/env python
 """
-dashboard.py
-============
-
 Launch with:
     streamlit run scripts/dashboard.py

-Relies on the directory structure produced by run_grid_experiments.py:
+Relies on the directory structure produced by analysis.py:
     outputs/grid/<dataset>/<config>/{aggregates.yaml, rq1.yaml, ...}
 """
 from __future__ import annotations
@@ -19,8 +15,8 @@ import pandas as pd
 import streamlit as st
 import matplotlib.pyplot as plt

-BASE_DIR = Path("outputs/grid")   # change if you store runs elsewhere
-METRIC_KEY = "rag_score"          # bar/box plots focus on this
+BASE_DIR = Path("outputs/grid")
+METRIC_KEY = "rag_score"

 # --------------------------------------------------------------------- Sidebar
 st.sidebar.title("RAG-Eval Dashboard")
scripts/prep_annotations.py ADDED
@@ -0,0 +1,86 @@
+"""
+Runs the RAG pipeline over dataset(s) and saves partial results
+for manual annotation.
+"""
+
+import argparse
+import json
+from pathlib import Path
+from typing import Any, Dict
+
+from evaluation import (
+    PipelineConfig,
+    RetrieverConfig,
+    GeneratorConfig,
+    CrossEncoderConfig,
+    StatsConfig,
+    LoggingConfig,
+    RAGPipeline,
+)
+from evaluation.utils.logger import init_logging
+
+import yaml
+
+
+def merge_dataclass(dc_cls, override: Dict[str, Any]):
+    from dataclasses import asdict
+    base = asdict(dc_cls())
+    base.update({k: v for k, v in override.items() if v is not None})
+    return dc_cls(**base)
+
+
+def load_pipeline_config(yaml_path: Path) -> PipelineConfig:
+    data = yaml.safe_load(yaml_path.read_text())
+    return PipelineConfig(
+        retriever=merge_dataclass(RetrieverConfig, data.get("retriever", {})),
+        generator=merge_dataclass(GeneratorConfig, data.get("generator", {})),
+        reranker=merge_dataclass(CrossEncoderConfig, data.get("reranker", {})),
+        stats=merge_dataclass(StatsConfig, data.get("stats", {})),
+        logging=merge_dataclass(LoggingConfig, data.get("logging", {})),
+    )
+
+
+def read_jsonl(path: Path) -> list[dict]:
+    with path.open() as f:
+        return [json.loads(line) for line in f]
+
+
+def write_jsonl(path: Path, rows: list[dict]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w") as f:
+        for row in rows:
+            f.write(json.dumps(row) + "\n")
+
+
+def main(argv=None):
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--config", type=Path, required=True)
+    ap.add_argument("--datasets", nargs="+", type=Path, required=True)
+    ap.add_argument("--outdir", type=Path, default=Path("outputs/for_annotation"))
+    args = ap.parse_args(argv)
+
+    init_logging(log_dir=args.outdir / "logs")
+    cfg = load_pipeline_config(args.config)
+    pipe = RAGPipeline(cfg)
+
+    for dataset in args.datasets:
+        queries = read_jsonl(dataset)
+        output_dir = args.outdir / dataset.stem / args.config.stem
+        output_path = output_dir / "unlabeled_results.jsonl"
+
+        if output_path.exists():
+            print(f"Skipping {dataset.name} – already exists.")
+            continue
+
+        rows = []
+        for q in queries:
+            result = pipe.run(q["question"])
+            entry = {
+                "question": q["question"],
+                # pipeline.run() returns "raw_retrieval" and "answer"
+                "retrieved_docs": result.get("raw_retrieval", []),
+                "generated_answer": result.get("answer", ""),
+                "metrics": result.get("metrics", {}),
+                # Human annotators will add these
+                "human_correct": None,
+                "human_faithful": None,
+            }
+            rows.append(entry)
+
+        write_jsonl(output_path, rows)
+        print(f"Wrote {len(rows)} results to {output_path}")
+
+
+if __name__ == "__main__":
+    main()
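The resulting annotation round-trip, sketched in Python (`main()` takes an argv list, so it can be driven without a shell; the config and dataset names are examples, and paths follow the defaults in the code):

```python
from scripts.prep_annotations import main  # assumes repo root on sys.path

main([
    "--config", "configs/kilt_hybrid_ce.yaml",
    "--datasets", "data/legal.jsonl",
])
# -> outputs/for_annotation/legal/kilt_hybrid_ce/unlabeled_results.jsonl
# Annotators fill in "human_correct" / "human_faithful" on each line;
# the labeled file is then passed to scripts/analysis.py via --results.
```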
scripts/run_experiments.py DELETED
@@ -1,251 +0,0 @@
-#!/usr/bin/env python
-"""
-run_experiments.py
-==================
-
-High-level driver that wires together:
-
-1. YAML / CLI → `PipelineConfig` + `LoggingConfig`
-2. Initialises dual-sink logging (console + rotating file)
-3. Builds a `RAGPipeline`
-4. Streams a list of questions through the pipeline
-5. Logs progress, writes per-query JSONL results, and
-   (optionally) prints aggregate statistics.
-
-You can keep it minimal – or expand the marked TODO sections to:
-    * compute metrics immediately
-    * push results to a tracker (W&B, MLflow, etc.)
-    * spawn multiple configs in parallel.
-"""
-from __future__ import annotations
-
-import argparse
-import json
-import sys
-from pathlib import Path
-from typing import Any, Dict, Iterable, List, Mapping
-
-import yaml
-
-from evaluation import (
-    PipelineConfig,
-    RetrieverConfig,
-    GeneratorConfig,
-    CrossEncoderConfig,
-    StatsConfig,
-    LoggingConfig,
-    RAGPipeline,
-)
-from evaluation.utils.logger import init_logging
-
-from evaluation.stats import (
-    corr_ci,
-    wilcoxon_signed_rank,
-    holm_bonferroni,
-)
-
-import matplotlib.pyplot as plt
-
-# ──────────────────────────────────────────────────────────────────────────────
-# Helpers
-# ──────────────────────────────────────────────────────────────────────────────
-
-
-def _merge_dataclass(dc_cls, default, override: Mapping[str, Any]):
-    """Return a new *dc_cls* where fields from *override* overwrite *default*."""
-    from dataclasses import asdict
-
-    merged = asdict(default)
-    merged.update({k: v for k, v in override.items() if v is not None})
-    return dc_cls(**merged)
-
-
-def _load_pipeline_config(yaml_path: Path | None) -> PipelineConfig:
-    """Parse YAML into nested dataclasses; fall back to defaults."""
-    if yaml_path is None:
-        return PipelineConfig()  # all defaults
-
-    data = yaml.safe_load(yaml_path.read_text())
-
-    retr_cfg = _merge_dataclass(
-        RetrieverConfig(), RetrieverConfig(), data.get("retriever", {})
-    )
-    gen_cfg = _merge_dataclass(
-        GeneratorConfig(), GeneratorConfig(), data.get("generator", {})
-    )
-    rr_cfg = _merge_dataclass(
-        CrossEncoderConfig(), CrossEncoderConfig(), data.get("reranker", {})
-    )
-    stats_cfg = _merge_dataclass(StatsConfig(), StatsConfig(), data.get("stats", {}))
-    log_cfg = _merge_dataclass(LoggingConfig(), LoggingConfig(), data.get("logging", {}))
-
-    return PipelineConfig(
-        retriever=retr_cfg,
-        generator=gen_cfg,
-        reranker=rr_cfg,
-        stats=stats_cfg,
-        logging=log_cfg,
-    )
-
-
-def _read_jsonl(path: Path) -> List[Dict[str, Any]]:
-    with path.open() as f:
-        return [json.loads(line) for line in f]
-
-
-def _write_jsonl(path: Path, rows: Iterable[Mapping[str, Any]]):
-    path.parent.mkdir(parents=True, exist_ok=True)
-    with path.open("w") as f:
-        for row in rows:
-            f.write(json.dumps(row) + "\n")
-
-# Stats Helper
-def aggregate_metrics(rows: list[dict[str, Any]]) -> dict[str, float]:
-    """Return mean of every numeric metric found under row['metrics']."""
-    import numpy as np
-    keys = rows[0]["metrics"].keys()
-    return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
-
-
-def correlation_with_gold(rows: list[dict[str, Any]], cfg: StatsConfig):
-    """Spearman/Kendall correlation between retrieval scores and correctness flag."""
-    if "human_correct" not in rows[0]:
-        return None  # nothing to correlate
-    mrr = [r["metrics"].get("mrr", float("nan")) for r in rows]
-    gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
-    r, (lo, hi), p = corr_ci(
-        mrr, gold, method=cfg.correlation_method, n_boot=cfg.n_boot, ci=cfg.ci
-    )
-    return dict(r=r, ci_low=lo, ci_high=hi, p=p)
-
-
-def wilcoxon_against_baseline(
-    cur: list[dict[str, Any]],
-    base: list[dict[str, Any]],
-    cfg: StatsConfig,
-):
-    """Paired Wilcoxon + Holm-Bonferroni across all metric keys."""
-    from evaluation.stats import wilcoxon_signed_rank, holm_bonferroni
-
-    assert len(cur) == len(base), "Runs must have same #queries"
-    metrics = cur[0]["metrics"].keys()
-    p_raw = {}
-    for m in metrics:
-        cur_m = [r["metrics"][m] for r in cur]
-        base_m = [r["metrics"][m] for r in base]
-        _, p = wilcoxon_signed_rank(cur_m, base_m, alternative=cfg.wilcoxon_alternative)
-        p_raw[m] = p
-    return holm_bonferroni(p_raw)
-
-# Plot helper
-def save_scatter(rows, out_dir: Path):
-    out_dir.mkdir(parents=True, exist_ok=True)
-    x = [r["metrics"]["mrr"] for r in rows if "mrr" in r["metrics"]]
-    y = [1.0 if r.get("human_correct") else 0.0 for r in rows]
-    plt.figure()
-    plt.scatter(x, y, alpha=0.6)
-    plt.xlabel("MRR")
-    plt.ylabel("Correct (1=yes)")
-    plt.title("MRR vs. Human Correctness")
-    path = out_dir / "mrr_vs_correct.png"
-    plt.savefig(path, bbox_inches="tight")
-    plt.close()
-    return path
-
-# ──────────────────────────────────────────────────────────────────────────────
-# Main
-# ──────────────────────────────────────────────────────────────────────────────
-def main(argv: list[str] | None = None) -> None:
-    ap = argparse.ArgumentParser(description="Run RAG evaluation experiments.")
-    ap.add_argument("--config", type=Path, help="YAML config with pipeline settings")
-    ap.add_argument(
-        "--queries",
-        type=Path,
-        required=True,
-        help="JSONL file – each line must contain at least {'question': ...}",
-    )
-    ap.add_argument(
-        "--output",
-        type=Path,
-        default=Path("outputs/results.jsonl"),
-        help="Where to write JSONL results",
-    )
-    ap.add_argument("--dry-run", action="store_true", help="Do not execute pipeline")
-    ap.add_argument(
-        "--baseline",
-        type=Path,
-        help="Optional: JSONL with baseline run for significance tests",
-    )
-    ap.add_argument(
-        "--plots",
-        action="store_true",
-        help="Save diagnostic plots (PNG) alongside results",
-    )
-    args = ap.parse_args(argv)
-
-    # 1. Parse configuration
-    cfg = _load_pipeline_config(args.config)
-
-    # 2. Initialise logging (file + stderr)
-    init_logging(
-        log_dir=cfg.logging.log_dir,
-        level=cfg.logging.level,
-        max_mb=cfg.logging.max_mb,
-        backups=cfg.logging.backups,
-    )
-
-    import logging
-
-    logger = logging.getLogger(__name__)
-    logger.info("Loaded PipelineConfig:\n%s", cfg)
-
-    # 3. Build pipeline (retrieval → (rerank) → generation)
-    pipeline = RAGPipeline(cfg)
-
-    # 4. Load queries
-    rows = _read_jsonl(args.queries)
-    logger.info("Loaded %d queries from %s", len(rows), args.queries)
-
-    if args.dry_run:
-        logger.warning("Dry-run flag active – exiting before execution.")
-        sys.exit(0)
-
-    # 5. Execute pipeline
-    results: List[Dict[str, Any]] = []
-    for i, row in enumerate(rows, 1):
-        q = row["question"]
-        logger.info("[%d/%d] Q: %s", i, len(rows), q)
-        out = pipeline.run(q)
-        merged = {**row, **out}  # keep any gold labels or metadata
-        results.append(merged)
-
-    # 6. Persist results
-    _write_jsonl(args.output, results)
-    logger.info("Wrote %d results to %s", len(results), args.output)
-
-    # 7. Aggregate statistics, significance tests, plots
-    agg = aggregate_metrics(results)
-    logger.info("Mean metrics: %s", json.dumps(agg, indent=2))
-
-    corr = correlation_with_gold(results, cfg.stats)
-    if corr:
-        logger.info(
-            "Correlation MRR↔gold %s=%.3f 95%%CI=[%.3f, %.3f] p=%.3g",
-            cfg.stats.correlation_method,
-            corr["r"],
-            corr["ci_low"],
-            corr["ci_high"],
-            corr["p"],
-        )
-
-    if args.baseline:
-        baseline_rows = _read_jsonl(args.baseline)
-        p_adj = wilcoxon_against_baseline(results, baseline_rows, cfg.stats)
-        logger.info("Wilcoxon vs baseline (Holm-Bonferroni α=%s): %s", cfg.stats.alpha, p_adj)
-
-    if args.plots:
-        plot_path = save_scatter(results, args.output.parent)
-        logger.info("Saved plot → %s", plot_path)
-
-if __name__ == "__main__":
-    main()
scripts/run_grid_experiments.py DELETED
@@ -1,239 +0,0 @@
-#!/usr/bin/env python
-"""
-run_grid_experiments.py
-=======================
-Batch driver for *config × dataset* evaluation, including:
-
-* RQ1 – Correlation of classical retrieval metrics with factual-correctness
-* RQ2 – Correlation of faithfulness metrics with expert judgements
-* RQ3 – Retrieval-error ➜ hallucination propagation (χ² + conditional rates)
-* RQ4 – Robustness under adversarial perturbations (Δ-metrics, Cohen d)
-
-Features
---------
-* Incremental mode – pass **one** new --config, it is compared to all
-  previous runs already found under --outdir/<dataset>/.
-* Saves:
-    - `results.jsonl`
-    - `aggregates.yaml`
-    - `rq1.yaml`, `rq2.yaml`, `rq3.yaml`, `rq4.yaml`
-    - pairwise Wilcoxon/ Holm tables
-    - bar-, box-, scatter-plots (if --plots flag)
-"""
-
-from __future__ import annotations
-
-import argparse
-import itertools
-import json
-import logging
-import os
-from pathlib import Path
-from typing import Any, Dict, Iterable, List, Mapping
-
-import matplotlib.pyplot as plt
-import numpy as np
-import yaml
-
-from evaluation import (
-    PipelineConfig,
-    RetrieverConfig,
-    GeneratorConfig,
-    CrossEncoderConfig,
-    StatsConfig,
-    LoggingConfig,
-    RAGPipeline,
-)
-from evaluation.stats import (
-    corr_ci,
-    wilcoxon_signed_rank,
-    holm_bonferroni,
-    conditional_failure_rate,
-    chi2_error_propagation,
-    delta_metric,
-)
-from evaluation.utils.logger import init_logging
-
-# ─────────────────────────────── I/O helpers ────────────────────────────────
-
-
-def read_jsonl(path: Path) -> List[Dict[str, Any]]:
-    with path.open() as f:
-        return [json.loads(line) for line in f]
-
-
-def write_jsonl(path: Path, rows: Iterable[Mapping[str, Any]]) -> None:
-    path.parent.mkdir(parents=True, exist_ok=True)
-    with path.open("w") as f:
-        for row in rows:
-            f.write(json.dumps(row) + "\n")
-
-
-def save_yaml(path: Path, obj: Mapping[str, Any]) -> None:
-    path.parent.mkdir(parents=True, exist_ok=True)
-    path.write_text(yaml.safe_dump(obj, sort_keys=False))
-
-
-# ─────────────────────── config merge (same as earlier) ─────────────────────
-
-
-def merge_dataclass(dc_cls, override: Mapping[str, Any]):
-    from dataclasses import asdict
-
-    base = asdict(dc_cls())
-    base.update({k: v for k, v in override.items() if v is not None})
-    return dc_cls(**base)
-
-
-def load_pipeline_config(yaml_path: Path) -> PipelineConfig:
-    data = yaml.safe_load(yaml_path.read_text())
-    return PipelineConfig(
-        retriever=merge_dataclass(RetrieverConfig, data.get("retriever", {})),
-        generator=merge_dataclass(GeneratorConfig, data.get("generator", {})),
-        reranker=merge_dataclass(CrossEncoderConfig, data.get("reranker", {})),
-        stats=merge_dataclass(StatsConfig, data.get("stats", {})),
-        logging=merge_dataclass(LoggingConfig, data.get("logging", {})),
-    )
-
-
-# ───────────────────────────── stats helpers ────────────────────────────────
-def agg_mean(rows: List[dict[str, Any]]) -> dict[str, float]:
-    keys = rows[0]["metrics"].keys()
-    return {k: float(np.mean([r["metrics"][k] for r in rows])) for k in keys}
-
-
-def rq1_correlation(rows, cfg: StatsConfig):
-    if "human_correct" not in rows[0]:
-        return {}
-    retrieval_keys = [k for k in rows[0]["metrics"] if k in {"mrr", "map", "precision@10"}]
-    gold = [1.0 if r["human_correct"] else 0.0 for r in rows]
-    out = {}
-    for k in retrieval_keys:
-        vec = [r["metrics"][k] for r in rows]
-        r, (lo, hi), p = corr_ci(vec, gold, method=cfg.correlation_method,
-                                 n_boot=cfg.n_boot, ci=cfg.ci)
-        out[k] = dict(r=r, ci=[lo, hi], p=p)
-    return out
-
-
-def rq2_faithfulness(rows, cfg: StatsConfig):
-    if "human_faithful" not in rows[0]:
-        return {}
-    faith_keys = [k for k in rows[0]["metrics"] if k.lower().startswith(("faith", "qags", "fact", "ragas"))]
-    gold = [r["human_faithful"] for r in rows]
-    out = {}
-    for k in faith_keys:
-        vec = [r["metrics"][k] for r in rows]
-        r, (lo, hi), p = corr_ci(vec, gold, method=cfg.correlation_method,
-                                 n_boot=cfg.n_boot, ci=cfg.ci)
-        out[k] = dict(r=r, ci=[lo, hi], p=p)
-    return out
-
-
-def rq3_error_propagation(rows):
-    if "retrieval_error" not in rows[0] or "hallucination" not in rows[0]:
-        return {}
-    ret_err = [r["retrieval_error"] for r in rows]
-    halluc = [r["hallucination"] for r in rows]
-    cond = conditional_failure_rate(ret_err, halluc)
-    chi2 = chi2_error_propagation(ret_err, halluc)
-    return {"conditional": cond, "chi2": chi2}
-
-
-def rq4_robustness(orig_rows, pert_rows):
-    if pert_rows is None:
-        return {}
-    metrics = orig_rows[0]["metrics"].keys()
-    out = {}
-    for m in metrics:
-        d, eff = delta_metric(
-            [r["metrics"][m] for r in orig_rows],
-            [r["metrics"][m] for r in pert_rows],
-        )
-        out[m] = dict(delta=d, cohen_d=eff)
-    return out
-
-
-# ─────────────────────────── plotting helpers ───────────────────────────────
-def scatter_mrr_vs_correct(rows, path: Path):
-    x = [r["metrics"].get("mrr", np.nan) for r in rows]
-    y = [1 if r.get("human_correct") else 0 for r in rows]
-    plt.figure()
-    plt.scatter(x, y, alpha=0.5)
-    plt.xlabel("MRR"); plt.ylabel("Correct (1)")
-    plt.title("MRR vs. Human Correctness")
-    plt.tight_layout(); plt.savefig(path); plt.close()
-
-
-# ────────────────────────────────── main ────────────────────────────────────
-def main(argv: list[str] | None = None) -> None:
-    ap = argparse.ArgumentParser()
-    ap.add_argument("--configs", nargs="+", type=Path, required=True,
-                    help="One or more YAML configs; if one, compared against prior runs.")
-    ap.add_argument("--datasets", nargs="+", type=Path, required=True)
-    ap.add_argument("--outdir", type=Path, default=Path("outputs/grid"))
-    ap.add_argument("--plots", action="store_true")
-    ap.add_argument("--perturbed-suffix", default="_pert",
-                    help="If dataset perturbed version exists (name+suffix.jsonl) it's used for RQ4.")
-    args = ap.parse_args(argv)
-
-    init_logging(log_dir=args.outdir / "logs", level="INFO")
-    log = logging.getLogger("grid")
-
-    for dataset in args.datasets:
-        log.info("Dataset: %s", dataset.name)
-        queries = read_jsonl(dataset)
-        pert_path = dataset.with_stem(dataset.stem + args.perturbed_suffix)
-        pert_rows = read_jsonl(pert_path) if pert_path.exists() else None
-
-        # discover historical configs to compare against if incremental mode
-        hist_dirs = (args.outdir / dataset.stem).glob("*") if len(args.configs) == 1 else []
-        historical = {d.name: read_jsonl(d / "results.jsonl") for d in hist_dirs if d.is_dir()}
-
-        for cfg_yaml in args.configs:
-            cfg_name = cfg_yaml.stem
-            log.info("  Config: %s", cfg_name)
-            cfg = load_pipeline_config(cfg_yaml)
-            pipe = RAGPipeline(cfg)
-
-            # skip if results already exist
-            run_dir = args.outdir / dataset.stem / cfg_name
-            if (run_dir / "results.jsonl").exists():
-                log.info("    results already present – loading.")
-                rows = read_jsonl(run_dir / "results.jsonl")
-            else:
-                rows = [pipe.run(q["question"]) | q for q in queries]
-                write_jsonl(run_dir / "results.jsonl", rows)
-
-            # aggregates & RQ1–4
-            save_yaml(run_dir / "aggregates.yaml", agg_mean(rows))
-            save_yaml(run_dir / "rq1.yaml", rq1_correlation(rows, cfg.stats))
-            save_yaml(run_dir / "rq2.yaml", rq2_faithfulness(rows, cfg.stats))
-            save_yaml(run_dir / "rq3.yaml", rq3_error_propagation(rows))
-
-            if pert_rows:
-                save_yaml(run_dir / "rq4.yaml", rq4_robustness(rows, pert_rows))
-
-            if args.plots:
-                scatter_mrr_vs_correct(rows, run_dir / "mrr_vs_correct.png")
-
-            historical[cfg_name] = rows  # include current for pairwise tests
-
-        # pairwise Wilcoxon on rag_score
-        if len(historical) > 1:
-            pairs = {}
-            names = list(historical)
-            for a, b in itertools.combinations(names, 2):
-                x = [r["metrics"]["rag_score"] for r in historical[a]]
-                y = [r["metrics"]["rag_score"] for r in historical[b]]
-                _, p = wilcoxon_signed_rank(x, y)
-                pairs[f"{a}~{b}"] = p
-            save_yaml(args.outdir / dataset.stem / "wilcoxon_rag_raw.yaml", pairs)
-            save_yaml(args.outdir / dataset.stem / "wilcoxon_rag_holm.yaml",
-                      holm_bonferroni(pairs))
-
-            log.info("  Pairwise rag_score significance stored (Holm adjusted).")
-
-
-if __name__ == "__main__":
-    main()
tests/test_pipeline_end_to_end.py CHANGED
@@ -34,7 +34,6 @@ def tmp_doc_store(tmp_path_factory):


 def test_pipeline_with_dense(tmp_doc_store, monkeypatch, tmp_path):
-    # Monkey-patch HFGenerator so no actual HF download happens
     import evaluation.generators.hf_generator as hf_module

     monkeypatch.setattr(hf_module, "HFGenerator", _DummyGenerator)
@@ -46,13 +45,12 @@ def test_pipeline_with_dense(tmp_doc_store, monkeypatch, tmp_path):
         faiss_index=tmp_path / "dense.idx",
         doc_store=tmp_doc_store,
         device="cpu",
-        model_name="dummy/ignored",  # the DummyGenerator bypasses HF
+        model_name="dummy/ignored",
     ),
         generator=GeneratorConfig(model_name="dummy"),
     )
     pipeline = RAGPipeline(cfg)

-    # Should not raise, and produce no errors
     results = pipeline.run_queries([{"question": "Q?", "id": 0}])
     assert isinstance(results, list)
    assert all("answer" in r for r in results)