Ryoya Awano commited on
Commit
19fc84f
·
0 Parent(s):

deploy: fix MedLFQA Marginal mode sample matching

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .DS_Store +0 -0
  2. .env.example +16 -0
  3. .gitignore +11 -0
  4. CLAUDE.md +57 -0
  5. README.md +95 -0
  6. conf/config.yaml +40 -0
  7. conf/dataset_config.yaml +17 -0
  8. conf/path_config.yaml +9 -0
  9. data/.DS_Store +0 -0
  10. data/processed/FactScore/fact_score_queries.json +552 -0
  11. data/processed/HotpotQA/hotpot_qa_queries.json +0 -0
  12. data/processed/PopQA/pop_qa_queries.json +0 -0
  13. data/processed/base_schema.json +16 -0
  14. data/processed/document_schema.json +12 -0
  15. data/processed/wiki_schema.json +26 -0
  16. demo/__init__.py +0 -0
  17. demo/app.py +362 -0
  18. demo/constants.py +23 -0
  19. demo/data/samples.json +322 -0
  20. demo/data/thresholds.csv +145 -0
  21. demo/inference_api.py +302 -0
  22. demo/precompute.py +311 -0
  23. docs/context/01_original_architecture.md +404 -0
  24. docs/context/02_design_patterns_analysis.md +264 -0
  25. docs/context/03_demo_app_design.md +327 -0
  26. docs/context/04_environment_setup.md +373 -0
  27. docs/context/05_medrag_corpus_integration.md +196 -0
  28. docs/context/06_test_strategy.md +180 -0
  29. docs/context/07_huggingface_spaces_deployment.md +259 -0
  30. main.py +376 -0
  31. requirements-dev.txt +13 -0
  32. requirements.txt +13 -0
  33. scripts/build_sif.sh +15 -0
  34. scripts/run_demo.sh +28 -0
  35. scripts/run_main.sh +22 -0
  36. scripts/run_precompute.sh +25 -0
  37. scripts/run_tests.sh +37 -0
  38. singularity/response_quality.def +45 -0
  39. src/calibration/__init__.py +0 -0
  40. src/calibration/base_calibration.py +19 -0
  41. src/calibration/conditional_conformal.py +278 -0
  42. src/calibration/conformal.py +422 -0
  43. src/calibration/utils.py +161 -0
  44. src/common/__init__.py +0 -0
  45. src/common/chunker.py +41 -0
  46. src/common/config_manager.py +171 -0
  47. src/common/faiss_manager.py +309 -0
  48. src/common/file_manager.py +128 -0
  49. src/common/llm/__init__.py +0 -0
  50. src/common/llm/llm_agent.py +14 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.env.example ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy this file to .env and fill in the values for your environment.
2
+ # .env is listed in .gitignore and must never be committed.
3
+
4
+ # Absolute path to the data root on the host.
5
+ # Heavy files (FAISS indices, raw data, main.py output) are stored here
6
+ # and bind-mounted into the container.
7
+ DATA_ROOT=/mnt/data/your_username/ResponseQualityAssessment
8
+
9
+ # Directory where the Singularity .sif file is stored.
10
+ SIF_DIR=/mnt/data/your_username/sif
11
+
12
+ # HuggingFace model cache directory.
13
+ HF_HOME=/mnt/data/your_username/hf_cache
14
+
15
+ # OpenAI API key (required for live inference in the demo).
16
+ OPENAI_API_KEY=sk-...
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/raw/WikiDB/enwiki-20230401.db
2
+ logs/
3
+ index_store/
4
+ data/
5
+ # demo/data/ is generated by precompute.py but committed to the repo
6
+ !demo/data/
7
+ !demo/data/*
8
+ *.pyc
9
+ __pycache__/
10
+ launch.json
11
+ .env
CLAUDE.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 開発ガイドライン
2
+
3
+ ## プロジェクトの目的
4
+
5
+ RAG の回答品質評価システム。共形予測(Conformal Prediction)を用いて、LLM 回答のサブクレームごとに統計的保証付きのスコアリングを行う。既存のバッチ処理ロジック(`src/` や `main.py`)を活かしつつ、インタラクティブに操作できるデモアプリ(`demo/`)をアジャイルに開発している。
6
+
7
+ ---
8
+
9
+ ## ディレクトリ構成
10
+
11
+ ```
12
+ src/
13
+ ├── calibration/ # 共形予測キャリブレーション
14
+ ├── common/ # 共通コンポーネント(設定・ファイル・LLM 管理)
15
+ ├── data_processor/ # データセット処理パイプライン
16
+ ├── dataloader/ # HuggingFace データローダー
17
+ ├── rag/ # RAG(SQLite ベース文書 DB)
18
+ ├── subclaim_processor/ # 回答品質評価の中核パイプライン
19
+ └── utils/ # 汎用ヘルパー関数
20
+
21
+ demo/
22
+ ├── app.py # Streamlit エントリーポイント
23
+ ├── inference_api.py # 単一クエリ推論ラッパー(src/ を呼び出す)
24
+ └── precompute.py # 閾値の事前計算スクリプト(オフライン実行用)
25
+ ```
26
+
27
+ ---
28
+
29
+ ## コンテキストと設計
30
+
31
+ 詳細は `docs/context/` を参照。
32
+
33
+ | ドキュメント | 内容 |
34
+ |------------|------|
35
+ | `01_original_architecture.md` | アーキテクチャ全体像・データフロー・クラス設計 |
36
+ | `02_design_patterns_analysis.md` | 設計パターンの評価と改善案 |
37
+ | `03_demo_app_design.md` | デモアプリの技術選定・UI 設計・開発ステップ |
38
+ | `04_environment_setup.md` | 環境構築・実行手順(Singularity、`.env` 設定、スクリプト一覧) |
39
+ | `06_test_strategy.md` | テスト計画と実行フロー(ユニットテスト〜UI テスト) |
40
+
41
+ ロジックを変更した際、アーキテクチャに大きな変更が生じた場合は `docs/context/` 内のドキュメントも合わせて更新してください。
42
+
43
+ ---
44
+
45
+ ## コーディングルール
46
+
47
+ ### src/(コアロジック)
48
+ - `src/` 配下のコアロジックの修正・リファクタリングは許可されています。
49
+ - ただし、既存の `main.py` を使ったバッチ処理が壊れないよう、後方互換性を意識するか、影響範囲を報告してください。
50
+
51
+ ### demo/(デモアプリ)
52
+ - デモ特有の UI ロジックやエンドポイントは `demo/` ディレクトリ内に隔離し、`src/` のコアロジックと密結合させないでください。
53
+ - `demo/` から `src/` モジュールを `import` するのみとし、コアロジックは変更しないでください。
54
+ - キャリブレーション閾値($\hat{q}$)は `demo/precompute.py` で事前計算し、JSON/CSV として静的保持してください。デモ実行時にキャリブレーションループを回さないでください。
55
+
56
+ ### Git
57
+ - 機能単位でキリが良いところで `git commit` を提案してください。
README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Conformal RAG Quality Assessment Demo
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.43.2
8
+ app_file: demo/app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Response Quality Assessment for Retrieval-Augmented Generation via Conditional Conformal Factuality
13
+
14
+ This repository contains production-ready code and resources for our research for paper "Response Quality Assessment for Retrieval-Augmented
15
+ Generation via Conditional Conformal Factuality" Accepted by SIGIR 2025
16
+
17
+ Arxiv ver: [https://www.arxiv.org/pdf/2506.20978](https://www.arxiv.org/abs/2506.20978)
18
+
19
+ Poster: https://drive.google.com/file/d/1k7fSa4k07dPJUCvsjU0hE0e486T5UNoy/view?usp=sharing
20
+
21
+ ## Table of Contents
22
+ - [Structure](#Structure)
23
+ - [Data](#data)
24
+ - [Query Data](#query-data)
25
+ - [Wikipedia Extraction](#wikipedia-extraction)
26
+ - [Usage](#usage)
27
+ - [References](#references)
28
+ - [More Information](#more-information)
29
+
30
+ ## Structure
31
+
32
+ ```bash
33
+ .
34
+ ├── conf/ # Configuration file location
35
+ ├── data/
36
+ │ ├── out/ # Final subclaims with scores (follows `subclaims_schema`)
37
+ │ ├── processed/ # Standardized test data (follows `base_schema`)
38
+ │ └── raw/ # Original raw data from source (unstructured)
39
+ ├── index_store/ # Chunked documents and embeddings
40
+ ├── logs/ # Config and logs in format `run_{data}_{run_id}`
41
+ ├── src/
42
+ │ ├── calibration/ # Conformal prediction calibration logic
43
+ │ ├── common/ # Reusable components (e.g., config manager, FAISS vector DB manager)
44
+ │ ├── data_processor/ # Converts raw QA data to standardized format (see `data/processed`)
45
+ │ ├── dataloader/ # Loads data from source datasets (e.g., AkariASAI/PopQA, KILT benchmark)
46
+ │ ├── rag/ # RAG system components for document retrieval
47
+ │ ├── subclaim_processor/# Generates, scores, and annotates subclaims for different datasets
48
+ │ └── utils/ # Miscellaneous utilities
49
+ ```
50
+
51
+
52
+ ## Data
53
+ ### Query Data
54
+ This repository includes the following query datasets:
55
+ - [FactScore](https://github.com/shmsw25/FActScore)
56
+ - [PopQA](https://huggingface.co/datasets/akariasai/PopQA)
57
+ - [HotpotQA](https://huggingface.co/datasets/hotpotqa/hotpot_qa)
58
+ - [MedLFQA] (https://github.com/dmis-lab/OLAPH/tree/main/MedLFQA) or (https://github.com/jjcherian/conformal-safety/tree/main/data/MedLFQAv2)
59
+
60
+ ### Wikipedia Extraction
61
+ We utilize Wikipedia dumps for knowledge retrieval:
62
+ - [enwiki-20230401.db](https://drive.google.com/file/d/1mekls6OGOKLmt7gYtHs0WGf5oTamTNat/view?usp=drive_link)
63
+ This file is not included in this github, you could download it through ming's google drive above (source: https://github.com/shmsw25/FActScore) and put it under
64
+ \data\raw folder in order to generate reference doucument for wiki based queries (popqa and hotpotqa)
65
+
66
+ ## Usage
67
+ Project is build on python version 3.11
68
+ First, set up project env using [requirements.txt](requirements.txt).
69
+ To run the pipeline:
70
+ ```python
71
+ python main.py --config conf/config.yaml --dataset fact_score --query_size 500
72
+ ```
73
+ Only 1 dataset at a time in 1 thread.
74
+ avaliable dataset currently are:
75
+ ["fact_score", "hotpot_qa", "pop_qa", "medlf_qa"]
76
+
77
+ ## Conditional Conformal
78
+ This repo only support conditional conformal in medlf_qa dataset. By default the config in /conf/dataset_config.yaml
79
+ The medlf_qa.is_grouped = true while other are set to false
80
+ The factuality result will be put in different csv files under the result/${datetime}_${run_id} folder naming by each different pre-defined group name
81
+
82
+ ## Start with metadata
83
+ In order to not have any OpenAI token comsumption, one can choose use these metadata already produced to just verify conformal prediction part. The result will be stable
84
+ You can get required metadata here: https://drive.google.com/drive/folders/1aLbHxS6V1ipMH8FpVCxKmr8oMYfqmRgb?usp=drive_link
85
+
86
+
87
+ ## More Information
88
+ For further details, please refer to our Paper (link on top)
89
+ The baseline group conditional conformal (https://arxiv.org/abs/2406.09714) result
90
+ for medlfqav2 is produced by their code: github.com/jjcherian/conformal-safety
91
+ and is not in part of this repo.
92
+
93
+ ## License
94
+
95
+ This project is licensed under the [MIT License](https://opensource.org/license/mit).
conf/config.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ name: "pop_qa" # Options: fact_score, hotpot_qa, pop_qa, medlf_qa
3
+ query_size: 500
4
+ wiki_db_file: "enwiki-20230401.db"
5
+
6
+
7
+ # Index configuration
8
+ index:
9
+ delete_existing: false
10
+ embedding_model: "text-embedding-3-large"
11
+ # Truncation strategies by dataset
12
+ truncation_config: # TODO
13
+ strategy: "fixed_length" # false
14
+ truncate_by: null # "\n"
15
+ chunk_size: 2000
16
+ chunk_overlap: 25
17
+
18
+
19
+ # RAG configuration
20
+ rag:
21
+ retrival_topk: 10
22
+ retrival_threshold: 0.3
23
+ response_model: "gpt-4.1-mini"
24
+ response_temperature: 0.7
25
+ fact_generation_model: "gpt-4.1-mini"
26
+
27
+
28
+ # Prediction configuration
29
+ conformal_prediction:
30
+ aggregation_strategy: "mean"
31
+ scoring_strategy: "product"
32
+ claim_verification_model: "gpt-4.1-mini" # "gpt-4o-mini"
33
+ frequency_score_model: "gpt-4.1-mini"
34
+ split_conformal: true
35
+ conformal_alphas:
36
+ start: 0.05
37
+ end: 0.45
38
+ step: 0.05
39
+ a_value: 1.0
40
+
conf/dataset_config.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+ fact_score:
3
+ name: "FactScore"
4
+ index_store: "${DATA_ROOT}/index_store/FactScore"
5
+ is_grouped: false
6
+ hotpot_qa:
7
+ name: "HotpotQA"
8
+ index_store: "${DATA_ROOT}/index_store/HotpotQA"
9
+ is_grouped: false
10
+ pop_qa:
11
+ name: "PopQA"
12
+ index_store: "${DATA_ROOT}/index_store/PopQA"
13
+ is_grouped: false
14
+ medlf_qa:
15
+ name: "MedLFQA"
16
+ index_store: "${DATA_ROOT}/index_store/MedLFQA"
17
+ is_grouped: true
conf/path_config.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Path configuration
2
+ # Paths are relative to DATA_ROOT (set in .env).
3
+ # The loader must call os.path.expandvars() after yaml.safe_load().
4
+ paths:
5
+ raw_data_dir: "${DATA_ROOT}/data/raw"
6
+ processed_data_dir: "${DATA_ROOT}/data/processed"
7
+ response_dir: "${DATA_ROOT}/data/out"
8
+ wiki_db_dir: "${DATA_ROOT}/data/raw/WikiDB"
9
+ result_dir: "${DATA_ROOT}/data/result"
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/processed/FactScore/fact_score_queries.json ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "input": "What is Lanny Flaherty's occupation?",
4
+ "output": {
5
+ "answer": "",
6
+ "provenance": [
7
+ {
8
+ "title": "Lanny Flaherty"
9
+ }
10
+ ]
11
+ }
12
+ },
13
+ {
14
+ "input": "What is Marianne McAndrew's occupation?",
15
+ "output": {
16
+ "answer": "",
17
+ "provenance": [
18
+ {
19
+ "title": "Marianne McAndrew"
20
+ }
21
+ ]
22
+ }
23
+ },
24
+ {
25
+ "input": "What is Chaim Malinowitz's occupation?",
26
+ "output": {
27
+ "answer": "",
28
+ "provenance": [
29
+ {
30
+ "title": "Chaim Malinowitz"
31
+ }
32
+ ]
33
+ }
34
+ },
35
+ {
36
+ "input": "What is Doug Sheehan's occupation?",
37
+ "output": {
38
+ "answer": "",
39
+ "provenance": [
40
+ {
41
+ "title": "Doug Sheehan"
42
+ }
43
+ ]
44
+ }
45
+ },
46
+ {
47
+ "input": "What is Gerhard Fischer (inventor)'s occupation?",
48
+ "output": {
49
+ "answer": "",
50
+ "provenance": [
51
+ {
52
+ "title": "Gerhard Fischer (inventor)"
53
+ }
54
+ ]
55
+ }
56
+ },
57
+ {
58
+ "input": "What is Chief Jones's occupation?",
59
+ "output": {
60
+ "answer": "",
61
+ "provenance": [
62
+ {
63
+ "title": "Chief Jones"
64
+ }
65
+ ]
66
+ }
67
+ },
68
+ {
69
+ "input": "What is Focus...'s occupation?",
70
+ "output": {
71
+ "answer": "",
72
+ "provenance": [
73
+ {
74
+ "title": "Focus..."
75
+ }
76
+ ]
77
+ }
78
+ },
79
+ {
80
+ "input": "What is Joey D. Vieira's occupation?",
81
+ "output": {
82
+ "answer": "",
83
+ "provenance": [
84
+ {
85
+ "title": "Joey D. Vieira"
86
+ }
87
+ ]
88
+ }
89
+ },
90
+ {
91
+ "input": "What is Taral Hicks's occupation?",
92
+ "output": {
93
+ "answer": "",
94
+ "provenance": [
95
+ {
96
+ "title": "Taral Hicks"
97
+ }
98
+ ]
99
+ }
100
+ },
101
+ {
102
+ "input": "What is Lees Knowles's occupation?",
103
+ "output": {
104
+ "answer": "",
105
+ "provenance": [
106
+ {
107
+ "title": "Lees Knowles"
108
+ }
109
+ ]
110
+ }
111
+ },
112
+ {
113
+ "input": "What is Carolina Portesi Peroni's occupation?",
114
+ "output": {
115
+ "answer": "",
116
+ "provenance": [
117
+ {
118
+ "title": "Carolina Portesi Peroni"
119
+ }
120
+ ]
121
+ }
122
+ },
123
+ {
124
+ "input": "What is Quintus Sosius Senecio's occupation?",
125
+ "output": {
126
+ "answer": "",
127
+ "provenance": [
128
+ {
129
+ "title": "Quintus Sosius Senecio"
130
+ }
131
+ ]
132
+ }
133
+ },
134
+ {
135
+ "input": "What is Joeri Adams's occupation?",
136
+ "output": {
137
+ "answer": "",
138
+ "provenance": [
139
+ {
140
+ "title": "Joeri Adams"
141
+ }
142
+ ]
143
+ }
144
+ },
145
+ {
146
+ "input": "What is Dominic King's occupation?",
147
+ "output": {
148
+ "answer": "",
149
+ "provenance": [
150
+ {
151
+ "title": "Dominic King"
152
+ }
153
+ ]
154
+ }
155
+ },
156
+ {
157
+ "input": "What is Vincenzo Tusa's occupation?",
158
+ "output": {
159
+ "answer": "",
160
+ "provenance": [
161
+ {
162
+ "title": "Vincenzo Tusa"
163
+ }
164
+ ]
165
+ }
166
+ },
167
+ {
168
+ "input": "What is Jean Daull\u00e9's occupation?",
169
+ "output": {
170
+ "answer": "",
171
+ "provenance": [
172
+ {
173
+ "title": "Jean Daull\u00e9"
174
+ }
175
+ ]
176
+ }
177
+ },
178
+ {
179
+ "input": "What is Richard Burkewood Welbourn's occupation?",
180
+ "output": {
181
+ "answer": "",
182
+ "provenance": [
183
+ {
184
+ "title": "Richard Burkewood Welbourn"
185
+ }
186
+ ]
187
+ }
188
+ },
189
+ {
190
+ "input": "What is Damir Memovi\u0107's occupation?",
191
+ "output": {
192
+ "answer": "",
193
+ "provenance": [
194
+ {
195
+ "title": "Damir Memovi\u0107"
196
+ }
197
+ ]
198
+ }
199
+ },
200
+ {
201
+ "input": "What is Mar\u00eda Elena Medina-Mora Icaza's occupation?",
202
+ "output": {
203
+ "answer": "",
204
+ "provenance": [
205
+ {
206
+ "title": "Mar\u00eda Elena Medina-Mora Icaza"
207
+ }
208
+ ]
209
+ }
210
+ },
211
+ {
212
+ "input": "What is Hesham Nazih's occupation?",
213
+ "output": {
214
+ "answer": "",
215
+ "provenance": [
216
+ {
217
+ "title": "Hesham Nazih"
218
+ }
219
+ ]
220
+ }
221
+ },
222
+ {
223
+ "input": "What is Luis Guillermo Rivera's occupation?",
224
+ "output": {
225
+ "answer": "",
226
+ "provenance": [
227
+ {
228
+ "title": "Luis Guillermo Rivera"
229
+ }
230
+ ]
231
+ }
232
+ },
233
+ {
234
+ "input": "What is Desmond Luke's occupation?",
235
+ "output": {
236
+ "answer": "",
237
+ "provenance": [
238
+ {
239
+ "title": "Desmond Luke"
240
+ }
241
+ ]
242
+ }
243
+ },
244
+ {
245
+ "input": "What is Miguel D\u00edaz (baseball)'s occupation?",
246
+ "output": {
247
+ "answer": "",
248
+ "provenance": [
249
+ {
250
+ "title": "Miguel D\u00edaz (baseball)"
251
+ }
252
+ ]
253
+ }
254
+ },
255
+ {
256
+ "input": "What is Roxana D\u00edaz (athlete)'s occupation?",
257
+ "output": {
258
+ "answer": "",
259
+ "provenance": [
260
+ {
261
+ "title": "Roxana D\u00edaz (athlete)"
262
+ }
263
+ ]
264
+ }
265
+ },
266
+ {
267
+ "input": "What is Fernando (footballer, born 1984)'s occupation?",
268
+ "output": {
269
+ "answer": "",
270
+ "provenance": [
271
+ {
272
+ "title": "Fernando (footballer, born 1984)"
273
+ }
274
+ ]
275
+ }
276
+ },
277
+ {
278
+ "input": "What is Gerardo Fernandez Fe's occupation?",
279
+ "output": {
280
+ "answer": "",
281
+ "provenance": [
282
+ {
283
+ "title": "Gerardo Fernandez Fe"
284
+ }
285
+ ]
286
+ }
287
+ },
288
+ {
289
+ "input": "What is Augusto Mart\u00ednez S\u00e1nchez's occupation?",
290
+ "output": {
291
+ "answer": "",
292
+ "provenance": [
293
+ {
294
+ "title": "Augusto Mart\u00ednez S\u00e1nchez"
295
+ }
296
+ ]
297
+ }
298
+ },
299
+ {
300
+ "input": "What is Sharad Kumar (athlete)'s occupation?",
301
+ "output": {
302
+ "answer": "",
303
+ "provenance": [
304
+ {
305
+ "title": "Sharad Kumar (athlete)"
306
+ }
307
+ ]
308
+ }
309
+ },
310
+ {
311
+ "input": "What is Ko Itakura's occupation?",
312
+ "output": {
313
+ "answer": "",
314
+ "provenance": [
315
+ {
316
+ "title": "Ko Itakura"
317
+ }
318
+ ]
319
+ }
320
+ },
321
+ {
322
+ "input": "What is Besant Ravi's occupation?",
323
+ "output": {
324
+ "answer": "",
325
+ "provenance": [
326
+ {
327
+ "title": "Besant Ravi"
328
+ }
329
+ ]
330
+ }
331
+ },
332
+ {
333
+ "input": "What is Thomas Houghton (rugby league)'s occupation?",
334
+ "output": {
335
+ "answer": "",
336
+ "provenance": [
337
+ {
338
+ "title": "Thomas Houghton (rugby league)"
339
+ }
340
+ ]
341
+ }
342
+ },
343
+ {
344
+ "input": "What is Katsunosuke Hori's occupation?",
345
+ "output": {
346
+ "answer": "",
347
+ "provenance": [
348
+ {
349
+ "title": "Katsunosuke Hori"
350
+ }
351
+ ]
352
+ }
353
+ },
354
+ {
355
+ "input": "What is Ra Jong-yil's occupation?",
356
+ "output": {
357
+ "answer": "",
358
+ "provenance": [
359
+ {
360
+ "title": "Ra Jong-yil"
361
+ }
362
+ ]
363
+ }
364
+ },
365
+ {
366
+ "input": "What is Song Kang's occupation?",
367
+ "output": {
368
+ "answer": "",
369
+ "provenance": [
370
+ {
371
+ "title": "Song Kang"
372
+ }
373
+ ]
374
+ }
375
+ },
376
+ {
377
+ "input": "What is Rin Iwanaga's occupation?",
378
+ "output": {
379
+ "answer": "",
380
+ "provenance": [
381
+ {
382
+ "title": "Rin Iwanaga"
383
+ }
384
+ ]
385
+ }
386
+ },
387
+ {
388
+ "input": "What is Keita Kadokura's occupation?",
389
+ "output": {
390
+ "answer": "",
391
+ "provenance": [
392
+ {
393
+ "title": "Keita Kadokura"
394
+ }
395
+ ]
396
+ }
397
+ },
398
+ {
399
+ "input": "What is Lily Branscombe's occupation?",
400
+ "output": {
401
+ "answer": "",
402
+ "provenance": [
403
+ {
404
+ "title": "Lily Branscombe"
405
+ }
406
+ ]
407
+ }
408
+ },
409
+ {
410
+ "input": "What is Luis N. Rivera-Pag\u00e1n's occupation?",
411
+ "output": {
412
+ "answer": "",
413
+ "provenance": [
414
+ {
415
+ "title": "Luis N. Rivera-Pag\u00e1n"
416
+ }
417
+ ]
418
+ }
419
+ },
420
+ {
421
+ "input": "What is Jesse Foppert's occupation?",
422
+ "output": {
423
+ "answer": "",
424
+ "provenance": [
425
+ {
426
+ "title": "Jesse Foppert"
427
+ }
428
+ ]
429
+ }
430
+ },
431
+ {
432
+ "input": "What is Mike Trivisonno's occupation?",
433
+ "output": {
434
+ "answer": "",
435
+ "provenance": [
436
+ {
437
+ "title": "Mike Trivisonno"
438
+ }
439
+ ]
440
+ }
441
+ },
442
+ {
443
+ "input": "What is Zelma Wilson's occupation?",
444
+ "output": {
445
+ "answer": "",
446
+ "provenance": [
447
+ {
448
+ "title": "Zelma Wilson"
449
+ }
450
+ ]
451
+ }
452
+ },
453
+ {
454
+ "input": "What is Michael Valpy's occupation?",
455
+ "output": {
456
+ "answer": "",
457
+ "provenance": [
458
+ {
459
+ "title": "Michael Valpy"
460
+ }
461
+ ]
462
+ }
463
+ },
464
+ {
465
+ "input": "What is Prince Rivers's occupation?",
466
+ "output": {
467
+ "answer": "",
468
+ "provenance": [
469
+ {
470
+ "title": "Prince Rivers"
471
+ }
472
+ ]
473
+ }
474
+ },
475
+ {
476
+ "input": "What is John Estes's occupation?",
477
+ "output": {
478
+ "answer": "",
479
+ "provenance": [
480
+ {
481
+ "title": "John Estes"
482
+ }
483
+ ]
484
+ }
485
+ },
486
+ {
487
+ "input": "What is Eric Hacker's occupation?",
488
+ "output": {
489
+ "answer": "",
490
+ "provenance": [
491
+ {
492
+ "title": "Eric Hacker"
493
+ }
494
+ ]
495
+ }
496
+ },
497
+ {
498
+ "input": "What is Alma Katsu's occupation?",
499
+ "output": {
500
+ "answer": "",
501
+ "provenance": [
502
+ {
503
+ "title": "Alma Katsu"
504
+ }
505
+ ]
506
+ }
507
+ },
508
+ {
509
+ "input": "What is Carlos Alfonso's occupation?",
510
+ "output": {
511
+ "answer": "",
512
+ "provenance": [
513
+ {
514
+ "title": "Carlos Alfonso"
515
+ }
516
+ ]
517
+ }
518
+ },
519
+ {
520
+ "input": "What is Rennie Fritchie, Baroness Fritchie's occupation?",
521
+ "output": {
522
+ "answer": "",
523
+ "provenance": [
524
+ {
525
+ "title": "Rennie Fritchie, Baroness Fritchie"
526
+ }
527
+ ]
528
+ }
529
+ },
530
+ {
531
+ "input": "What is Daniel Charles's occupation?",
532
+ "output": {
533
+ "answer": "",
534
+ "provenance": [
535
+ {
536
+ "title": "Daniel Charles"
537
+ }
538
+ ]
539
+ }
540
+ },
541
+ {
542
+ "input": "What is Zamfir Arbore's occupation?",
543
+ "output": {
544
+ "answer": "",
545
+ "provenance": [
546
+ {
547
+ "title": "Zamfir Arbore"
548
+ }
549
+ ]
550
+ }
551
+ }
552
+ ]
data/processed/HotpotQA/hotpot_qa_queries.json ADDED
The diff for this file is too large to render. See raw diff
 
data/processed/PopQA/pop_qa_queries.json ADDED
The diff for this file is too large to render. See raw diff
 
data/processed/base_schema.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "$schema": "https://json-schema.org/draft-07/schema",
3
+ "type": "object",
4
+ "properties": {
5
+ "input": { "type": "string" },
6
+ "output": {
7
+ "type": "object",
8
+ "properties": {
9
+ "answer": { "type": "string" }
10
+ },
11
+ "required": ["answer"]
12
+ }
13
+ },
14
+ "required": ["input", "output"]
15
+ }
16
+
data/processed/document_schema.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "$schema": "https://json-schema.org/draft-07/schema",
3
+ "type": "object",
4
+ "unevaluatedProperties": {
5
+ "type": "array",
6
+ "minItems": 1,
7
+ "items": {
8
+ "type": "string"
9
+ }
10
+ }
11
+ }
12
+
data/processed/wiki_schema.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "$schema": "https://json-schema.org/draft-07/schema",
3
+ "allOf": [
4
+ { "$ref": "base_schema.json" },
5
+ {
6
+ "properties": {
7
+ "output": {
8
+ "properties": {
9
+ "provenance": {
10
+ "type": "array",
11
+ "items": {
12
+ "type": "object",
13
+ "properties": {
14
+ "wikipedia_id": { "type": "integer" },
15
+ "title": { "type": ["string", "null"] }
16
+ },
17
+ "required": ["title"]
18
+ }
19
+ }
20
+ }
21
+ }
22
+ }
23
+ }
24
+ ]
25
+ }
26
+
demo/__init__.py ADDED
File without changes
demo/app.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo/app.py
2
+ # Streamlit entry point. Run with: streamlit run demo/app.py
3
+ #
4
+ # Prerequisites:
5
+ # 1. FAISS indices built by main.py
6
+ # 2. demo/data/thresholds.csv and demo/data/samples.json built by precompute.py
7
+
8
+ import json
9
+ import os
10
+ import re
11
+ import time
12
+ import pandas as pd
13
+ import streamlit as st
14
+
15
+ # When SPACES_DEMO=1, live inference (FAISS + LLM pipeline) is disabled.
16
+ # Only precomputed sample queries are available. Set this in HF Spaces Secrets.
17
+ SPACES_DEMO: bool = os.getenv("SPACES_DEMO", "0") == "1"
18
+
19
+ from demo.constants import ALPHA_LEVELS, DEFAULT_SCORING_METHOD
20
+ from demo.inference_api import (
21
+ SubclaimResult,
22
+ FilteredResult,
23
+ build_faiss_manager,
24
+ build_scorer,
25
+ process_query,
26
+ apply_threshold,
27
+ reintegrate_subclaims,
28
+ )
29
+
30
+ # ── Constants ─────────────────────────────────────────────────────────────────
31
+
32
+ DATASET_LABELS: dict[str, str] = {
33
+ "fact_score": "FactScore",
34
+ "hotpot_qa": "HotpotQA",
35
+ "pop_qa": "PopQA",
36
+ "medlf_qa": "MedLFQA",
37
+ }
38
+
39
+ MEDLF_GROUPS: list[str] = [
40
+ "healthsearch_qa",
41
+ "medication_qa",
42
+ "kqa_silver_wogold",
43
+ "kqa_golden",
44
+ "live_qa",
45
+ ]
46
+
47
+ # Slider options in ascending order (95 → highest factuality at the right)
48
+ FACTUALITY_OPTIONS: list[int] = sorted(
49
+ int(round((1 - a) * 100)) for a in ALPHA_LEVELS
50
+ )
51
+
52
+ THRESHOLDS_PATH = "demo/data/thresholds.csv"
53
+ SAMPLES_PATH = "demo/data/samples.json"
54
+
55
+ # ── Cached resources ──────────────────────────────────────────────────────────
56
+
57
+ @st.cache_resource
58
+ def get_faiss_manager(dataset: str):
59
+ return build_faiss_manager(dataset)
60
+
61
+
62
+ @st.cache_resource
63
+ def get_scorer(dataset: str):
64
+ return build_scorer(dataset)
65
+
66
+
67
+ @st.cache_data
68
+ def load_thresholds() -> pd.DataFrame:
69
+ return pd.read_csv(THRESHOLDS_PATH)
70
+
71
+
72
+ @st.cache_data
73
+ def load_samples() -> list[dict]:
74
+ with open(SAMPLES_PATH, encoding="utf-8") as f:
75
+ return json.load(f)["samples"]
76
+
77
+
78
+ # ── Helpers ───────────────────────────────────────────────────────────────────
79
+
80
+ def _lookup_q_hat(
81
+ df: pd.DataFrame, dataset: str, mode: str, group: str, alpha: float
82
+ ) -> float:
83
+ alpha = round(alpha, 2)
84
+ row = df.query(
85
+ "dataset == @dataset and mode == @mode and "
86
+ "scoring_method == @DEFAULT_SCORING_METHOD and "
87
+ "group == @group and alpha == @alpha"
88
+ )
89
+ if row.empty:
90
+ raise ValueError(
91
+ f"Threshold not found for "
92
+ f"({dataset}, {mode}, {DEFAULT_SCORING_METHOD}, {group}, α={alpha}). "
93
+ "Re-run precompute.py."
94
+ )
95
+ return float(row["q_hat"].iloc[0])
96
+
97
+
98
+ def _render_subclaims(result: SubclaimResult, q_hat: float) -> str:
99
+ """Return HTML with keep/remove highlighting for each subclaim."""
100
+ parts = []
101
+ for sc in result["subclaims"]:
102
+ text = sc["text"]
103
+ if sc["score"] >= q_hat:
104
+ parts.append(
105
+ f'<span style="padding:2px 5px; display:inline-block;">'
106
+ f"{text}</span>"
107
+ )
108
+ else:
109
+ parts.append(
110
+ f'<span style="text-decoration:line-through; color:#dc3545; '
111
+ f'padding:2px 5px; display:inline-block;">'
112
+ f"{text}</span>"
113
+ )
114
+ return " ".join(parts)
115
+
116
+
117
+ def _extract_page_content(doc: str) -> str:
118
+ """Extract page_content text from FAISSIndexManager search result string.
119
+
120
+ Strips the source query prefix that is prepended in the format
121
+ "source_text: document_text".
122
+ """
123
+ m = re.search(r"page_content='(.*?)'\s+metadata=\{", doc, re.DOTALL)
124
+ if not m:
125
+ return doc
126
+ content = m.group(1)
127
+ source_m = re.search(r"'source':\s*'(.*?)'", doc)
128
+ if source_m:
129
+ prefix = source_m.group(1) + ": "
130
+ if content.startswith(prefix):
131
+ content = content[len(prefix):]
132
+ return content
133
+
134
+
135
+ def _params_changed(dataset: str, mode: str, group: str) -> bool:
136
+ prev = st.session_state.get("prev_params")
137
+ return prev != (dataset, mode, group)
138
+
139
+
140
+ def _save_params(dataset: str, mode: str, group: str) -> None:
141
+ st.session_state.prev_params = (dataset, mode, group)
142
+
143
+
144
+ # ── App ───────────────────────────────────────────────────────────────────────
145
+
146
+ def main() -> None:
147
+ st.set_page_config(page_title="Conformal RAG Demo", layout="wide")
148
+ st.title("Conformal RAG — 回答品質評価デモ")
149
+
150
+ # Session state defaults
151
+ st.session_state.setdefault("query_processed", False)
152
+ st.session_state.setdefault("result", None)
153
+ st.session_state.setdefault("elapsed_sec", None)
154
+ st.session_state.setdefault("prev_params", None)
155
+ st.session_state.setdefault("integrated_answer", None)
156
+ st.session_state.setdefault("integrated_alpha", None)
157
+
158
+ # Load static data (cached after first call)
159
+ try:
160
+ thresholds_df = load_thresholds()
161
+ samples = load_samples()
162
+ except FileNotFoundError as e:
163
+ st.error(
164
+ f"データファイルが見つかりません: {e}\n\n"
165
+ "`python -m demo.precompute` を実行してください。"
166
+ )
167
+ return
168
+
169
+ # ── Sidebar ──────────────────────────────────────────────────────────────
170
+
171
+ with st.sidebar:
172
+ st.header("設定")
173
+
174
+ # Dataset
175
+ dataset: str = st.selectbox(
176
+ "データセット",
177
+ options=list(DATASET_LABELS.keys()),
178
+ format_func=lambda k: DATASET_LABELS[k],
179
+ )
180
+ is_medlf = dataset == "medlf_qa"
181
+
182
+ # Mode (MedLFQA only)
183
+ mode_raw: str = st.radio(
184
+ "モード",
185
+ options=["marginal", "conditional"],
186
+ format_func=lambda m: (
187
+ "Marginal(全体)" if m == "marginal" else "Conditional(グループ別)"
188
+ ),
189
+ disabled=not is_medlf,
190
+ )
191
+ mode: str = mode_raw if is_medlf else "marginal"
192
+
193
+ # Group (MedLFQA + Conditional only)
194
+ is_conditional = is_medlf and mode == "conditional"
195
+ if is_conditional:
196
+ group: str = st.selectbox("グループ", options=MEDLF_GROUPS)
197
+ else:
198
+ group = "default"
199
+
200
+ # Clear cached result when key params change
201
+ if _params_changed(dataset, mode, group):
202
+ st.session_state.query_processed = False
203
+ st.session_state.result = None
204
+ st.session_state.elapsed_sec = None
205
+ st.session_state.integrated_answer = None
206
+ st.session_state.integrated_alpha = None
207
+ _save_params(dataset, mode, group)
208
+
209
+ st.divider()
210
+
211
+ # Query input
212
+ dataset_samples = [
213
+ s for s in samples
214
+ if s["dataset"] == dataset
215
+ and (not is_conditional or s.get("group") == group)
216
+ ]
217
+ sample_map: dict[str, dict] = {s["query"]: s for s in dataset_samples}
218
+
219
+ query_input: str = st.selectbox(
220
+ "サンプルクエリ", options=list(sample_map.keys())
221
+ )
222
+
223
+ st.divider()
224
+
225
+ run_btn = st.button("推論実行", use_container_width=True, type="primary")
226
+
227
+ # Factuality slider (active only after inference)
228
+ target_factuality_pct: int = st.select_slider(
229
+ "Target Factuality (1−α)",
230
+ options=FACTUALITY_OPTIONS,
231
+ value=90,
232
+ format_func=lambda v: f"{v}%",
233
+ disabled=not st.session_state.query_processed,
234
+ )
235
+
236
+ alpha = round((100 - target_factuality_pct) / 100, 2)
237
+
238
+ # ── Heavy processing (button click only) ─────────────────────────────────
239
+
240
+ if run_btn and query_input:
241
+ st.session_state.integrated_answer = None
242
+ st.session_state.integrated_alpha = None
243
+ # Use precomputed result if available for this exact (mode, group)
244
+ precomputed = sample_map.get(query_input)
245
+ if (
246
+ precomputed is not None
247
+ and (
248
+ mode == "marginal"
249
+ or (precomputed["mode"] == mode and precomputed["group"] == group)
250
+ )
251
+ ):
252
+ st.session_state.result = precomputed
253
+ st.session_state.elapsed_sec = None
254
+ st.session_state.query_processed = True
255
+ st.rerun()
256
+ elif SPACES_DEMO:
257
+ st.error("このデモではサンプルクエリのみ対応しています。")
258
+ else:
259
+ t0 = time.perf_counter()
260
+ with st.spinner("推論中… (30〜60 秒ほどかかります)"):
261
+ faiss_manager = get_faiss_manager(dataset)
262
+ scorer = get_scorer(dataset)
263
+ st.session_state.result = process_query(
264
+ query=query_input,
265
+ dataset=dataset,
266
+ mode=mode,
267
+ group=group,
268
+ faiss_manager=faiss_manager,
269
+ scorer=scorer,
270
+ )
271
+ st.session_state.elapsed_sec = time.perf_counter() - t0
272
+ st.session_state.query_processed = True
273
+ st.rerun()
274
+
275
+ # ── Results panel ─────────────────────────────────────────────────────────
276
+
277
+ if not st.session_state.query_processed or st.session_state.result is None:
278
+ st.info("サイドバーからクエリを選択して「推論実行」��押してください。")
279
+ return
280
+
281
+ result: SubclaimResult = st.session_state.result
282
+
283
+ try:
284
+ q_hat = _lookup_q_hat(thresholds_df, dataset, mode, group, alpha)
285
+ except ValueError as e:
286
+ st.error(str(e))
287
+ return
288
+
289
+ filtered: FilteredResult = apply_threshold(result, q_hat)
290
+
291
+ # Query
292
+ st.subheader("クエリ")
293
+ st.markdown(result["query"])
294
+
295
+ st.divider()
296
+
297
+ # Original answer (left) | Re-integrated answer (right)
298
+ col_orig, col_integ = st.columns(2)
299
+
300
+ with col_orig:
301
+ st.subheader("元の RAG 回答")
302
+ st.markdown(result["rag_answer"])
303
+
304
+ with col_integ:
305
+ st.subheader("フィルタ後の回答")
306
+
307
+ if st.session_state.integrated_answer is not None:
308
+ if st.session_state.integrated_alpha != alpha:
309
+ st.warning(
310
+ f"α={st.session_state.integrated_alpha:.2f} 時点の結果です。"
311
+ "「回答を生成」を再度押して更新してください。"
312
+ )
313
+ st.markdown(st.session_state.integrated_answer)
314
+
315
+ if st.button("回答を生成", type="secondary"):
316
+ removed = [sc["text"].strip() for sc in result["subclaims"] if sc["score"] < q_hat]
317
+ with st.spinner("回答を生成中…"):
318
+ st.session_state.integrated_answer = reintegrate_subclaims(
319
+ rag_answer=result["rag_answer"],
320
+ removed_subclaims=removed,
321
+ )
322
+ st.session_state.integrated_alpha = alpha
323
+ st.rerun()
324
+
325
+ st.divider()
326
+
327
+ # Subclaims with keep/remove highlighting
328
+ st.subheader("Conformal-RAG による回答")
329
+ html = _render_subclaims(result, q_hat)
330
+ st.markdown(html, unsafe_allow_html=True)
331
+
332
+ st.divider()
333
+
334
+ # Retrieved documents
335
+ with st.expander(
336
+ f"検索されたナレッジ({len(result['retrieved_docs'])} 件)",
337
+ expanded=False,
338
+ ):
339
+ idx = 1
340
+ for doc in result["retrieved_docs"]:
341
+ content = _extract_page_content(doc)
342
+ if len(content) < 15:
343
+ continue
344
+ st.markdown(f"**[{idx}]** {content}")
345
+ idx += 1
346
+
347
+ # Metrics
348
+ st.subheader("分析メトリクス")
349
+ cols = st.columns(4 if st.session_state.elapsed_sec is None else 5)
350
+ cols[0].metric("閾値 q̂", f"{q_hat:.4f}")
351
+ cols[1].metric("サブクレーム数", len(result["subclaims"]))
352
+ cols[2].metric("Keep", filtered["keep_count"])
353
+ cols[3].metric(
354
+ "削除率",
355
+ f"{filtered['remove_count'] / max(len(result['subclaims']), 1) * 100:.0f}%",
356
+ )
357
+ if st.session_state.elapsed_sec is not None:
358
+ cols[4].metric("推論時間", f"{st.session_state.elapsed_sec:.1f} 秒")
359
+
360
+
361
+ if __name__ == "__main__":
362
+ main()
demo/constants.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo/constants.py — shared constants; precompute.py and app.py must import from here.
2
+
3
+ # ── Scoring ──────────────────────────────────────────────────────────────────
4
+
5
+ # To switch scoring methods, change this line and re-run precompute.py.
6
+ DEFAULT_SCORING_METHOD: str = "relevance"
7
+
8
+ # Maps clean scoring method names to the keys used in subclaim["scores"]
9
+ # in data files produced by main.py.
10
+ # "relavance" is a typo in the original codebase — preserved for compatibility.
11
+ SCORING_METHOD_TO_DATA_KEY: dict[str, str] = {
12
+ "relevance": "relavance",
13
+ "cosine_similarity": "query_claim_cosine_similarity",
14
+ "min_log_prob": "min_log_prob",
15
+ }
16
+
17
+ # ── Conformal prediction ─────────────────────────────────────────────────────
18
+
19
+ # alpha = 1 - target_factuality.
20
+ # precompute.py computes thresholds for exactly these values;
21
+ # app.py passes this list to st.select_slider.
22
+ # → Target Factuality: 95%, 90%, 85%, 80%, 75%, 70%, 65%, 60%
23
+ ALPHA_LEVELS: list[float] = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40]
demo/data/samples.json ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "generated_at": "2026-03-29T17:28:52.797800+00:00",
4
+ "model": "gpt-4.1-mini",
5
+ "scoring_method": "relevance"
6
+ },
7
+ "samples": [
8
+ {
9
+ "query": "Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?",
10
+ "dataset": "medlf_qa",
11
+ "mode": "conditional",
12
+ "group": "live_qa",
13
+ "scoring_method": "relevance",
14
+ "rag_answer": "Yes, you can stop using the menopausal hormone therapy patch after 4.5 months, but it is generally recommended that attempts to taper or discontinue the medication be made at intervals of 3 to 6 months. It is important to consider your symptoms and discuss with your healthcare provider the best approach for stopping the therapy safely.",
15
+ "retrieved_docs": [
16
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: For the treatment of vasomotor symptoms, treatment should be initiated with the 6' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=21 fileposition=21 score=0.8116",
17
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: Attempts to taper or discontinue the medication should be made at 3- to 6-month intervals' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=26 fileposition=26 score=0.7919",
18
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: Six (6' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=16 fileposition=16 score=0.7711",
19
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: The dose should be adjusted as necessary to control symptoms' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=24 fileposition=24 score=0.7696",
20
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 5, 9' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=17 fileposition=17 score=0.7678",
21
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: This system is applied to the skin once weekly for the treatment of vasomotor symptoms related to menopause.' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=28 fileposition=28 score=0.7627",
22
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 375, 12' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=18 fileposition=18 score=0.7588",
23
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 5 cm2 (0' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=22 fileposition=22 score=0.7534",
24
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 5, 15, 18' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=19 fileposition=19 score=0.7519",
25
+ "page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: Clinical responses (relief of symptoms) at the lowest effective dose should be the guide for establishing administration of the Climara system, especially in women with an intact uterus' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=25 fileposition=25 score=0.7484"
26
+ ],
27
+ "subclaims": [
28
+ {
29
+ "id": 0,
30
+ "text": "You can stop using the menopausal hormone therapy patch after 4.5 months",
31
+ "score": 0.5595795678160917
32
+ },
33
+ {
34
+ "id": 1,
35
+ "text": " It is generally recommended to attempt tapering or discontinuing the medication at intervals of 3 to 6 months",
36
+ "score": 0.31861125529509804
37
+ },
38
+ {
39
+ "id": 2,
40
+ "text": " It is important to consider your symptoms when stopping menopausal hormone therapy",
41
+ "score": 0.3732023054183166
42
+ },
43
+ {
44
+ "id": 3,
45
+ "text": " You should discuss with your healthcare provider the best approach for safely stopping the therapy",
46
+ "score": 0.2703604207878165
47
+ }
48
+ ]
49
+ },
50
+ {
51
+ "query": "Would you recommend something a bit stronger than meloxicam?",
52
+ "dataset": "medlf_qa",
53
+ "mode": "conditional",
54
+ "group": "kqa_golden",
55
+ "scoring_method": "relevance",
56
+ "rag_answer": "If you are looking for something stronger than meloxicam, options include combining drugs from different classes or considering opioid medications. Opioids such as tramadol (Conzip) and oxycodone (Oxyado, Oxycontin) are more potent pain relievers but come with significant risks, including addiction, toxicity, and adverse cognitive and respiratory effects. It's important to consult with a physician before changing your treatment regimen to carefully weigh the benefits and risks.",
57
+ "retrieved_docs": [
58
+ "page_content='Would you recommend something a bit stronger than meloxicam?: For a stronger pain relief effect, combining two drugs from different classes or medications from the opioid group is possible' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=41 fileposition=41 score=0.7034",
59
+ "page_content='Would you recommend something a bit stronger than meloxicam?: NSAIDs are generally more potent but carry a risk of harming the kidneys and other side effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=38 fileposition=38 score=0.6924",
60
+ "page_content='Would you recommend something a bit stronger than meloxicam?: In case of acute and severe pain, NSAIDs like Meloxicam are often recommended' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=40 fileposition=40 score=0.6871",
61
+ "page_content='Would you recommend something a bit stronger than meloxicam?: Opioids such as Tramadol (Conzip) and Oxycodone (Oxyado, Oxycontin) are the most potent but carry a significant risk of addiction, toxicity, and cognitive and respiratory adverse effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=39 fileposition=39 score=0.6863",
62
+ "page_content='Would you recommend something a bit stronger than meloxicam?: It is essential to consult with a physician before changing treatment regimens due to the risk of addiction and potential side effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=42 fileposition=42 score=0.6812",
63
+ "page_content='Would you recommend something a bit stronger than meloxicam?: There are generally three main types of drugs used for pain management: acetaminophen (paracetamol), NSAIDs (like Meloxicam), and opioids' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=36 fileposition=36 score=0.6810",
64
+ "page_content='Would you recommend something a bit stronger than meloxicam?: NSAIDs like Meloxicam carry a risk of harming the kidneys and other side effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=46 fileposition=46 score=0.6799",
65
+ "page_content='Would you recommend something a bit stronger than meloxicam?: Opioids such as Tramadol (Conzip) and Oxycodone (Oxyado, Oxycontin) are the most potent pain management drugs' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=47 fileposition=47 score=0.6730",
66
+ "page_content='Would you recommend something a bit stronger than meloxicam?: NSAIDs like Meloxicam are generally more potent than acetaminophen' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=45 fileposition=45 score=0.6688",
67
+ "page_content='Would you recommend something a bit stronger than meloxicam?: Acetaminophen (paracetamol) is the least potent but has the safest side effect profile' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=37 fileposition=37 score=0.6409"
68
+ ],
69
+ "subclaims": [
70
+ {
71
+ "id": 0,
72
+ "text": "If you are looking for something stronger than meloxicam, options include combining drugs from different classes",
73
+ "score": 0.38080484418115046
74
+ },
75
+ {
76
+ "id": 1,
77
+ "text": " Opioids such as tramadol (Conzip) and oxycodone (Oxyado, Oxycontin) are more potent pain relievers",
78
+ "score": 0.2872582910028152
79
+ },
80
+ {
81
+ "id": 2,
82
+ "text": " Opioids come with significant risks including addiction, toxicity, and adverse cognitive effects",
83
+ "score": 0.19291751764168302
84
+ },
85
+ {
86
+ "id": 3,
87
+ "text": " Opioids also have adverse respiratory effects",
88
+ "score": 0.1868863743184725
89
+ },
90
+ {
91
+ "id": 4,
92
+ "text": " It is important to consult with a physician before changing your treatment regimen",
93
+ "score": 0.1655506771407161
94
+ },
95
+ {
96
+ "id": 5,
97
+ "text": " Consulting a physician helps to carefully weigh the benefits and risks of the treatment",
98
+ "score": 0.17235453684696078
99
+ }
100
+ ]
101
+ },
102
+ {
103
+ "query": "how soon does losartan afffect blood pressure",
104
+ "dataset": "medlf_qa",
105
+ "mode": "conditional",
106
+ "group": "medication_qa",
107
+ "scoring_method": "relevance",
108
+ "rag_answer": "The clinical effect of losartan on blood pressure typically becomes apparent within two weeks. Peak effects are observed around 6 hours after taking the medication, with these peak effects being moderately larger than the trough (lowest) effects.",
109
+ "retrieved_docs": [
110
+ "page_content='how soon does losartan afffect blood pressure: Hypertension … Peak (6 hour) effects were uniformly, but moderately, larger than trough effects,' metadata={'source': 'how soon does losartan afffect blood pressure', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=51 fileposition=51 score=0.7362",
111
+ "page_content='how soon does losartan afffect blood pressure: The peak effects of losartan were uniformly, but moderately, larger than trough effects, with the trough-to-peak ratio for systolic and diastolic responses between 50-95%.' metadata={'source': 'how soon does losartan afffect blood pressure', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=52 fileposition=52 score=0.7016",
112
+ "page_content='how soon does losartan afffect blood pressure: Clinical effect of losartan typically becomes apparent within two weeks.' metadata={'source': 'how soon does losartan afffect blood pressure', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=53 fileposition=53 score=0.6849"
113
+ ],
114
+ "subclaims": [
115
+ {
116
+ "id": 0,
117
+ "text": "The clinical effect of losartan on blood pressure typically becomes apparent within two weeks",
118
+ "score": 0.46106826938943024
119
+ },
120
+ {
121
+ "id": 1,
122
+ "text": " Peak effects of losartan are observed around 6 hours after taking the medication",
123
+ "score": 0.42004030890343347
124
+ },
125
+ {
126
+ "id": 2,
127
+ "text": " The peak effects of losartan are moderately larger than the trough (lowest) effects",
128
+ "score": 0.4264968324332668
129
+ }
130
+ ]
131
+ },
132
+ {
133
+ "query": "Are there any signs I should look out for if Lyme develops and I should follow up?",
134
+ "dataset": "medlf_qa",
135
+ "mode": "conditional",
136
+ "group": "kqa_silver_wogold",
137
+ "scoring_method": "relevance",
138
+ "rag_answer": "Yes, there are several signs to watch for if Lyme disease develops after a tick bite, and you should follow up with your healthcare provider if you notice any of these symptoms:\n\n- Early symptoms (within a week to a few weeks after the bite) may include:\n - Fever\n - Fatigue\n - Headache\n - Muscle and joint aches\n - Swollen lymph nodes\n - A characteristic skin rash called erythema migrans, which often looks like a bull’s-eye on the skin\n\n- If left untreated, symptoms can progress to more severe issues such as:\n - Severe fatigue\n - Stiff, aching neck\n - Severe headaches\n - Irregularities in heartbeat\n - Changes in vision\n - Inflammatory arthritis, often affecting the knee\n\nIf you experience any of these signs or symptoms after suspected or known exposure to ticks, it’s important to schedule an appointment with your healthcare provider for evaluation and possible treatment. Regular check-ins can be helpful to monitor for any developments.",
139
+ "retrieved_docs": [
140
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: Regular check-ins with a healthcare provider are beneficial for managing and observing the potential developments of Lyme disease following suspected or known tick exposure.' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=68 fileposition=68 score=0.7716",
141
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: If you exhibit any of these signs or symptoms after suspected or known exposure to ticks, schedule an appointment with your healthcare provider' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=66 fileposition=66 score=0.7560",
142
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: If such an infection develops, early signs of Lyme disease may occur within a week to a few weeks of a tick bite' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=62 fileposition=62 score=0.7273",
143
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: If left untreated, symptoms can progress to severe fatigue, a stiff, aching neck, and severe headaches' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=64 fileposition=64 score=0.7108",
144
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: One may experience irregularities in heartbeat, changes in vision, or an inflammatory arthritis often affecting the knee as well' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=65 fileposition=65 score=0.6972",
145
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: These early symptoms can include: fever, fatigue, headache, muscle and joint aches, swollen lymph nodes, and a characteristic skin rash called erythema migrans, which often looks like a bull’s-eye on the skin' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=63 fileposition=63 score=0.6907",
146
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: Lyme disease is a bacterial infection that can be transmitted to humans through the bite of infected black-legged ticks' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=61 fileposition=61 score=0.6902",
147
+ "page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: Lyme disease is a bacterial infection that is transmitted to humans through the bite of infected black-legged ticks.' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=67 fileposition=67 score=0.6879"
148
+ ],
149
+ "subclaims": [
150
+ {
151
+ "id": 0,
152
+ "text": "Lyme disease can develop after a tick bite",
153
+ "score": 0.36037872147634314
154
+ },
155
+ {
156
+ "id": 1,
157
+ "text": " Early symptoms of Lyme disease may appear within a week to a few weeks after the tick bite",
158
+ "score": 0.3590103639552498
159
+ },
160
+ {
161
+ "id": 2,
162
+ "text": " Early symptoms of Lyme disease include fever",
163
+ "score": 0.330641627741826
164
+ },
165
+ {
166
+ "id": 3,
167
+ "text": " Early symptoms of Lyme disease include fatigue",
168
+ "score": 0.3326410987371023
169
+ },
170
+ {
171
+ "id": 4,
172
+ "text": " Early symptoms of Lyme disease include headache",
173
+ "score": 0.3409522381113656
174
+ },
175
+ {
176
+ "id": 5,
177
+ "text": " Early symptoms of Lyme disease include muscle and joint aches",
178
+ "score": 0.31654988182829147
179
+ },
180
+ {
181
+ "id": 6,
182
+ "text": " Early symptoms of Lyme disease include swollen lymph nodes",
183
+ "score": 0.3134207503110557
184
+ },
185
+ {
186
+ "id": 7,
187
+ "text": " Early symptoms of Lyme disease include a characteristic skin rash called erythema migrans",
188
+ "score": 0.31260172708098244
189
+ },
190
+ {
191
+ "id": 8,
192
+ "text": " Erythema migrans often looks like a bull’s-eye on the skin",
193
+ "score": 0.2383296001897222
194
+ },
195
+ {
196
+ "id": 9,
197
+ "text": " If Lyme disease is left untreated, symptoms can progress to more severe issues",
198
+ "score": 0.3353740959350072
199
+ },
200
+ {
201
+ "id": 10,
202
+ "text": " Severe issues from untreated Lyme disease include severe fatigue",
203
+ "score": 0.3008999281508631
204
+ },
205
+ {
206
+ "id": 11,
207
+ "text": " Severe issues from untreated Lyme disease include stiff, aching neck",
208
+ "score": 0.32575460132098455
209
+ },
210
+ {
211
+ "id": 12,
212
+ "text": " Severe issues from untreated Lyme disease include severe headaches",
213
+ "score": 0.31177354749866565
214
+ },
215
+ {
216
+ "id": 13,
217
+ "text": " Severe issues from untreated Lyme disease include irregularities in heartbeat",
218
+ "score": 0.31990244889782615
219
+ },
220
+ {
221
+ "id": 14,
222
+ "text": " Severe issues from untreated Lyme disease include changes in vision",
223
+ "score": 0.30950684605158113
224
+ },
225
+ {
226
+ "id": 15,
227
+ "text": " Severe issues from untreated Lyme disease include inflammatory arthritis",
228
+ "score": 0.3049160752333055
229
+ },
230
+ {
231
+ "id": 16,
232
+ "text": " Inflammatory arthritis from Lyme disease often affects the knee",
233
+ "score": 0.2874386307287703
234
+ },
235
+ {
236
+ "id": 17,
237
+ "text": " If you experience signs or symptoms after tick exposure, you should schedule an appointment with your healthcare provider",
238
+ "score": 0.3763904855114536
239
+ },
240
+ {
241
+ "id": 18,
242
+ "text": " Regular check-ins with a healthcare provider can help monitor for developments in Lyme disease",
243
+ "score": 0.3957411848526053
244
+ }
245
+ ]
246
+ },
247
+ {
248
+ "query": "How do adults usually get botulism?",
249
+ "dataset": "medlf_qa",
250
+ "mode": "conditional",
251
+ "group": "healthsearch_qa",
252
+ "scoring_method": "relevance",
253
+ "rag_answer": "Adults usually get botulism in one of three main ways: \n\n1. **Foodborne botulism**: This is the most common form in adults and occurs when they eat foods contaminated with the botulinum toxin. This often happens with home-canned, home-preserved, or fermented foods prepared under conditions that encourage bacterial growth.\n\n2. **Wound botulism**: This occurs when the bacteria enter a wound, which is more common in people who inject drugs.\n\n3. **Inhalation botulism**: This is very rare and can happen if the pure toxin is released into the air, such as during laboratory accidents or if used as a biological weapon.\n\nOverall, foodborne botulism is the most frequent way adults contract the illness.",
254
+ "retrieved_docs": [
255
+ "page_content='How do adults usually get botulism?: In adults, botulism can be contracted in a few different ways: through food, wound botulism, and inhalation botulism' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=78 fileposition=78 score=0.6993",
256
+ "page_content='How do adults usually get botulism?: Foodborne botulism is the most common form in adults and occurs when adults eat foods that are contaminated with the botulism toxin' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=79 fileposition=79 score=0.6891",
257
+ "page_content='How do adults usually get botulism?: Wound botulism happens when the bacteria enter a wound, usually in people who inject drugs' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=81 fileposition=81 score=0.6884",
258
+ "page_content='How do adults usually get botulism?: Botulism is a rare but serious illness caused by a toxin that attacks the body's nerves' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=77 fileposition=77 score=0.6836",
259
+ "page_content='How do adults usually get botulism?: This typically happens with home-canned, home-preserved, or fermented foods that have been prepared in conditions that encourage bacterial growth' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=80 fileposition=80 score=0.6832",
260
+ "page_content='How do adults usually get botulism?: In very rare instances, inhalation botulism can occur when the pure toxin is released into the air, such as during laboratory accidents or if used as a biological weapon.' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=83 fileposition=83 score=0.6773",
261
+ "page_content='How do adults usually get botulism?: Inhalation botulism is very rare and can occur when the pure toxin is released into the air as it can in laboratory accidents or if used as a biological weapon' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=82 fileposition=82 score=0.6543"
262
+ ],
263
+ "subclaims": [
264
+ {
265
+ "id": 0,
266
+ "text": "Adults usually get botulism in one of three main ways",
267
+ "score": 0.42031683919779217
268
+ },
269
+ {
270
+ "id": 1,
271
+ "text": " Foodborne botulism is the most common form in adults",
272
+ "score": 0.39594970883087005
273
+ },
274
+ {
275
+ "id": 2,
276
+ "text": " Foodborne botulism occurs when adults eat foods contaminated with the botulinum toxin",
277
+ "score": 0.3942247114653798
278
+ },
279
+ {
280
+ "id": 3,
281
+ "text": " Foodborne botulism often happens with home-canned, home-preserved, or fermented foods prepared under conditions that encourage bacterial growth",
282
+ "score": 0.3146562102397841
283
+ },
284
+ {
285
+ "id": 4,
286
+ "text": " Wound botulism occurs when the bacteria enter a wound",
287
+ "score": 0.2937099827086362
288
+ },
289
+ {
290
+ "id": 5,
291
+ "text": " Wound botulism is more common in people who inject drugs",
292
+ "score": 0.29189200145131067
293
+ },
294
+ {
295
+ "id": 6,
296
+ "text": " Inhalation botulism is very rare",
297
+ "score": 0.3344457211496001
298
+ },
299
+ {
300
+ "id": 7,
301
+ "text": " Inhalation botulism can happen if the pure toxin is released into the air",
302
+ "score": 0.3194226107069425
303
+ },
304
+ {
305
+ "id": 8,
306
+ "text": " Inhalation botulism can occur during laboratory accidents",
307
+ "score": 0.3357236089778982
308
+ },
309
+ {
310
+ "id": 9,
311
+ "text": " Inhalation botulism can occur if used as a biological weapon",
312
+ "score": 0.31917884005382047
313
+ },
314
+ {
315
+ "id": 10,
316
+ "text": " Foodborne botulism is the most frequent way adults contract botulism",
317
+ "score": 0.3974304925095534
318
+ }
319
+ ]
320
+ }
321
+ ]
322
+ }
demo/data/thresholds.csv ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset,mode,scoring_method,group,alpha,q_hat
2
+ medlf_qa,marginal,relevance,default,0.05,0.26235458225908953
3
+ medlf_qa,marginal,relevance,default,0.1,0.26235458225908953
4
+ medlf_qa,marginal,relevance,default,0.15,0.26235458225908953
5
+ medlf_qa,marginal,relevance,default,0.2,0.26235458225908953
6
+ medlf_qa,marginal,relevance,default,0.25,0.26235458225908953
7
+ medlf_qa,marginal,relevance,default,0.3,0.26235458225908953
8
+ medlf_qa,marginal,relevance,default,0.35,-1
9
+ medlf_qa,marginal,relevance,default,0.4,-1
10
+ medlf_qa,conditional,relevance,live_qa,0.05,0.26235458225908953
11
+ medlf_qa,conditional,relevance,live_qa,0.1,0.26235458225908953
12
+ medlf_qa,conditional,relevance,live_qa,0.15,0.26235458225908953
13
+ medlf_qa,conditional,relevance,live_qa,0.2,0.26235458225908953
14
+ medlf_qa,conditional,relevance,live_qa,0.25,0.26235458225908953
15
+ medlf_qa,conditional,relevance,live_qa,0.3,0.26235458225908953
16
+ medlf_qa,conditional,relevance,live_qa,0.35,0.26235458225908953
17
+ medlf_qa,conditional,relevance,live_qa,0.4,0.26235458225908953
18
+ medlf_qa,conditional,relevance,kqa_golden,0.05,-1
19
+ medlf_qa,conditional,relevance,kqa_golden,0.1,-1
20
+ medlf_qa,conditional,relevance,kqa_golden,0.15,-1
21
+ medlf_qa,conditional,relevance,kqa_golden,0.2,-1
22
+ medlf_qa,conditional,relevance,kqa_golden,0.25,-1
23
+ medlf_qa,conditional,relevance,kqa_golden,0.3,-1
24
+ medlf_qa,conditional,relevance,kqa_golden,0.35,-1
25
+ medlf_qa,conditional,relevance,kqa_golden,0.4,-1
26
+ medlf_qa,conditional,relevance,medication_qa,0.05,-1
27
+ medlf_qa,conditional,relevance,medication_qa,0.1,-1
28
+ medlf_qa,conditional,relevance,medication_qa,0.15,-1
29
+ medlf_qa,conditional,relevance,medication_qa,0.2,-1
30
+ medlf_qa,conditional,relevance,medication_qa,0.25,-1
31
+ medlf_qa,conditional,relevance,medication_qa,0.3,-1
32
+ medlf_qa,conditional,relevance,medication_qa,0.35,-1
33
+ medlf_qa,conditional,relevance,medication_qa,0.4,-1
34
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.05,-1
35
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.1,-1
36
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.15,-1
37
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.2,-1
38
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.25,-1
39
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.3,-1
40
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.35,-1
41
+ medlf_qa,conditional,relevance,kqa_silver_wogold,0.4,-1
42
+ medlf_qa,conditional,relevance,healthsearch_qa,0.05,-1
43
+ medlf_qa,conditional,relevance,healthsearch_qa,0.1,-1
44
+ medlf_qa,conditional,relevance,healthsearch_qa,0.15,-1
45
+ medlf_qa,conditional,relevance,healthsearch_qa,0.2,-1
46
+ medlf_qa,conditional,relevance,healthsearch_qa,0.25,-1
47
+ medlf_qa,conditional,relevance,healthsearch_qa,0.3,-1
48
+ medlf_qa,conditional,relevance,healthsearch_qa,0.35,-1
49
+ medlf_qa,conditional,relevance,healthsearch_qa,0.4,-1
50
+ medlf_qa,marginal,cosine_similarity,default,0.05,0.3992308077305861
51
+ medlf_qa,marginal,cosine_similarity,default,0.1,0.3992308077305861
52
+ medlf_qa,marginal,cosine_similarity,default,0.15,0.3992308077305861
53
+ medlf_qa,marginal,cosine_similarity,default,0.2,0.3992308077305861
54
+ medlf_qa,marginal,cosine_similarity,default,0.25,0.3992308077305861
55
+ medlf_qa,marginal,cosine_similarity,default,0.3,0.3992308077305861
56
+ medlf_qa,marginal,cosine_similarity,default,0.35,-1
57
+ medlf_qa,marginal,cosine_similarity,default,0.4,-1
58
+ medlf_qa,conditional,cosine_similarity,live_qa,0.05,0.3992308077305861
59
+ medlf_qa,conditional,cosine_similarity,live_qa,0.1,0.3992308077305861
60
+ medlf_qa,conditional,cosine_similarity,live_qa,0.15,0.3992308077305861
61
+ medlf_qa,conditional,cosine_similarity,live_qa,0.2,0.3992308077305861
62
+ medlf_qa,conditional,cosine_similarity,live_qa,0.25,0.3992308077305861
63
+ medlf_qa,conditional,cosine_similarity,live_qa,0.3,0.3992308077305861
64
+ medlf_qa,conditional,cosine_similarity,live_qa,0.35,0.3992308077305861
65
+ medlf_qa,conditional,cosine_similarity,live_qa,0.4,0.3992308077305861
66
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.05,-1
67
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.1,-1
68
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.15,-1
69
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.2,-1
70
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.25,-1
71
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.3,-1
72
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.35,-1
73
+ medlf_qa,conditional,cosine_similarity,kqa_golden,0.4,-1
74
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.05,-1
75
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.1,-1
76
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.15,-1
77
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.2,-1
78
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.25,-1
79
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.3,-1
80
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.35,-1
81
+ medlf_qa,conditional,cosine_similarity,medication_qa,0.4,-1
82
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.05,-1
83
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.1,-1
84
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.15,-1
85
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.2,-1
86
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.25,-1
87
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.3,-1
88
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.35,-1
89
+ medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.4,-1
90
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.05,-1
91
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.1,-1
92
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.15,-1
93
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.2,-1
94
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.25,-1
95
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.3,-1
96
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.35,-1
97
+ medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.4,-1
98
+ medlf_qa,marginal,min_log_prob,default,0.05,1.002567365930604
99
+ medlf_qa,marginal,min_log_prob,default,0.1,1.002567365930604
100
+ medlf_qa,marginal,min_log_prob,default,0.15,1.002567365930604
101
+ medlf_qa,marginal,min_log_prob,default,0.2,1.002567365930604
102
+ medlf_qa,marginal,min_log_prob,default,0.25,1.002567365930604
103
+ medlf_qa,marginal,min_log_prob,default,0.3,1.002567365930604
104
+ medlf_qa,marginal,min_log_prob,default,0.35,-1
105
+ medlf_qa,marginal,min_log_prob,default,0.4,-1
106
+ medlf_qa,conditional,min_log_prob,live_qa,0.05,1.002567365930604
107
+ medlf_qa,conditional,min_log_prob,live_qa,0.1,1.002567365930604
108
+ medlf_qa,conditional,min_log_prob,live_qa,0.15,1.002567365930604
109
+ medlf_qa,conditional,min_log_prob,live_qa,0.2,1.002567365930604
110
+ medlf_qa,conditional,min_log_prob,live_qa,0.25,1.002567365930604
111
+ medlf_qa,conditional,min_log_prob,live_qa,0.3,1.002567365930604
112
+ medlf_qa,conditional,min_log_prob,live_qa,0.35,1.002567365930604
113
+ medlf_qa,conditional,min_log_prob,live_qa,0.4,1.002567365930604
114
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.05,-1
115
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.1,-1
116
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.15,-1
117
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.2,-1
118
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.25,-1
119
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.3,-1
120
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.35,-1
121
+ medlf_qa,conditional,min_log_prob,kqa_golden,0.4,-1
122
+ medlf_qa,conditional,min_log_prob,medication_qa,0.05,-1
123
+ medlf_qa,conditional,min_log_prob,medication_qa,0.1,-1
124
+ medlf_qa,conditional,min_log_prob,medication_qa,0.15,-1
125
+ medlf_qa,conditional,min_log_prob,medication_qa,0.2,-1
126
+ medlf_qa,conditional,min_log_prob,medication_qa,0.25,-1
127
+ medlf_qa,conditional,min_log_prob,medication_qa,0.3,-1
128
+ medlf_qa,conditional,min_log_prob,medication_qa,0.35,-1
129
+ medlf_qa,conditional,min_log_prob,medication_qa,0.4,-1
130
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.05,-1
131
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.1,-1
132
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.15,-1
133
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.2,-1
134
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.25,-1
135
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.3,-1
136
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.35,-1
137
+ medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.4,-1
138
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.05,-1
139
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.1,-1
140
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.15,-1
141
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.2,-1
142
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.25,-1
143
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.3,-1
144
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.35,-1
145
+ medlf_qa,conditional,min_log_prob,healthsearch_qa,0.4,-1
demo/inference_api.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo/inference_api.py
2
+ # Single-query inference wrapper around src/ core logic.
3
+ # Does NOT modify src/ — imports only.
4
+ #
5
+ # Heavy resources (FAISSIndexManager, SubclaimScorer) are constructed here
6
+ # via build_faiss_manager() / build_scorer() and cached in app.py with
7
+ # @st.cache_resource, so they survive Streamlit re-runs.
8
+
9
+ import os
10
+ import re
11
+ import glob
12
+ import yaml
13
+ from typing import TYPE_CHECKING, TypedDict
14
+
15
+ if TYPE_CHECKING:
16
+ # FAISSIndexManager and SubclaimScorer transitively import torch/transformers
17
+ # via langchain_text_splitters (file_manager.py). To avoid loading these
18
+ # heavy packages at module import time (e.g. in HF Spaces where FAISS is
19
+ # unused), we guard them here and do the real import inside each build_*()
20
+ # function instead. String annotations ("FAISSIndexManager") ensure the
21
+ # type hints remain valid at runtime without triggering the import.
22
+ from src.common.faiss_manager import FAISSIndexManager
23
+ from src.subclaim_processor.scorer.subclaim_scorer import SubclaimScorer
24
+
25
+
26
+ def _expand_env_vars(obj):
27
+ """Recursively expand environment variables in string values."""
28
+ if isinstance(obj, dict):
29
+ return {k: _expand_env_vars(v) for k, v in obj.items()}
30
+ if isinstance(obj, list):
31
+ return [_expand_env_vars(v) for v in obj]
32
+ if isinstance(obj, str):
33
+ return os.path.expandvars(obj)
34
+ return obj
35
+
36
+ from src.common.llm.openai_rag_agent import OpenAIRAGAgent
37
+ from src.common.llm.openai_atomicfact_generator import OpenAIAtomicFactGenerator
38
+ from src.common.llm.openai_llm_agent import OpenAILLMAgent
39
+ from demo.constants import DEFAULT_SCORING_METHOD
40
+
41
+ # ── Config loading (module-level cache) ──────────────────────────────────────
42
+
43
+ _main_config: dict | None = None
44
+ _dataset_config: dict | None = None
45
+
46
+
47
+ def _load_main_config() -> dict:
48
+ global _main_config
49
+ if _main_config is None:
50
+ with open("conf/config.yaml") as f:
51
+ _main_config = yaml.safe_load(f)
52
+ return _main_config
53
+
54
+
55
+ def _load_dataset_config() -> dict:
56
+ global _dataset_config
57
+ if _dataset_config is None:
58
+ with open("conf/dataset_config.yaml") as f:
59
+ _dataset_config = _expand_env_vars(yaml.safe_load(f))
60
+ return _dataset_config
61
+
62
+
63
+ # ── FAISS index path resolution ───────────────────────────────────────────────
64
+
65
+ def _get_index_paths(dataset: str) -> tuple[str, str]:
66
+ """Resolve index_path and indice2fm_path for a dataset.
67
+
68
+ Scans the index_store directory for files created by main.py
69
+ (index_{query_size}.faiss) and picks the one with the largest query_size.
70
+ """
71
+ dataset_cfg = _load_dataset_config()
72
+ if dataset not in dataset_cfg["datasets"]:
73
+ raise ValueError(f"Unknown dataset: '{dataset}'. "
74
+ f"Available: {list(dataset_cfg['datasets'].keys())}")
75
+ index_store = dataset_cfg["datasets"][dataset]["index_store"]
76
+
77
+ faiss_files = glob.glob(f"{index_store}/index_*.faiss")
78
+ if not faiss_files:
79
+ raise FileNotFoundError(
80
+ f"No FAISS index found in '{index_store}'. "
81
+ "Run main.py first to build the index."
82
+ )
83
+
84
+ def _query_size(path: str) -> int:
85
+ m = re.search(r"index_(\d+)\.faiss$", path)
86
+ return int(m.group(1)) if m else 0
87
+
88
+ index_path = max(faiss_files, key=_query_size)
89
+ indice2fm_path = re.sub(r"index_(\d+)\.faiss$", r"indice2fm_\1.json", index_path)
90
+ return index_path, indice2fm_path
91
+
92
+
93
+ # ── Resource builders (call from app.py wrapped with @st.cache_resource) ─────
94
+
95
+ def build_faiss_manager(dataset: str) -> "FAISSIndexManager":
96
+ """Build a FAISSIndexManager for the given dataset.
97
+
98
+ Intended to be wrapped with @st.cache_resource in app.py.
99
+ FAISSIndexManager is imported here (not at module level) to avoid pulling
100
+ torch/transformers via langchain_text_splitters at startup.
101
+ """
102
+ from src.common.faiss_manager import FAISSIndexManager # noqa: PLC0415
103
+ config = _load_main_config()
104
+ index_path, indice2fm_path = _get_index_paths(dataset)
105
+ return FAISSIndexManager(
106
+ index_truncation_config=config["index"]["truncation_config"],
107
+ index_path=index_path,
108
+ indice2fm_path=indice2fm_path,
109
+ )
110
+
111
+
112
+ def build_scorer(dataset: str) -> "SubclaimScorer":
113
+ """Build a SubclaimScorer for the given dataset.
114
+
115
+ SubclaimScorer loads its own FAISS index internally, so wrap this with
116
+ @st.cache_resource in app.py. Only needed for "relevance" and
117
+ "cosine_similarity" scoring methods.
118
+ SubclaimScorer is imported here (not at module level) for the same reason
119
+ as FAISSIndexManager — to avoid pulling torch/transformers at startup.
120
+ """
121
+ from src.subclaim_processor.scorer.subclaim_scorer import SubclaimScorer # noqa: PLC0415
122
+ config = _load_main_config()
123
+ index_path, indice2fm_path = _get_index_paths(dataset)
124
+ return SubclaimScorer(
125
+ index_truncation_config=config["index"]["truncation_config"],
126
+ embedding_model=config["index"]["embedding_model"],
127
+ index_path=index_path,
128
+ indice2fm_path=indice2fm_path,
129
+ frequency_score_model=config["conformal_prediction"]["frequency_score_model"],
130
+ )
131
+
132
+
133
+ # ── TypedDict definitions ─────────────────────────────────────────────────────
134
+
135
+ class Subclaim(TypedDict):
136
+ id: int
137
+ text: str
138
+ score: float # semantics depend on SubclaimResult.scoring_method
139
+
140
+
141
+ class SubclaimResult(TypedDict):
142
+ query: str
143
+ dataset: str
144
+ mode: str # "marginal" or "conditional"
145
+ group: str # group name, or "default" for marginal / non-grouped datasets
146
+ scoring_method: str # e.g. "relevance"; used as key for threshold lookup in thresholds.csv
147
+ rag_answer: str
148
+ retrieved_docs: list[str]
149
+ subclaims: list[Subclaim]
150
+
151
+
152
+ class FilteredResult(TypedDict):
153
+ subclaims: list[Subclaim]
154
+ q_hat: float
155
+ keep_count: int
156
+ remove_count: int
157
+
158
+
159
+ # ── Score computation ─────────────────────────────────────────────────────────
160
+
161
+ def _compute_score(
162
+ scoring_method: str,
163
+ subclaim_text: str,
164
+ subclaim_probs: list[tuple[str, float]], # (token, probability) — probability = exp(log_prob)
165
+ query: str,
166
+ retrieved_docs: list[str],
167
+ scorer: "SubclaimScorer | None",
168
+ ) -> float:
169
+ if scoring_method == "relevance":
170
+ if scorer is None:
171
+ raise ValueError("scorer is required for scoring_method='relevance'")
172
+ config = _load_main_config()
173
+ cp_config = config["conformal_prediction"]
174
+ return float(scorer.score(
175
+ claim=subclaim_text,
176
+ retrieved_docs=retrieved_docs,
177
+ aggregation_strategy=cp_config["aggregation_strategy"],
178
+ scoring_strategy=cp_config["scoring_strategy"],
179
+ ))
180
+
181
+ elif scoring_method == "cosine_similarity":
182
+ if scorer is None:
183
+ raise ValueError("scorer is required for scoring_method='cosine_similarity'")
184
+ return float(scorer.cosine_similarity(subclaim_text, query))
185
+
186
+ elif scoring_method == "min_log_prob":
187
+ # min(probability) = least confident token; higher value = more confident subclaim
188
+ probs = [p for _, p in subclaim_probs]
189
+ return min(probs) if probs else 0.0
190
+
191
+ else:
192
+ raise ValueError(
193
+ f"Unknown scoring method: '{scoring_method}'. "
194
+ "Supported: 'relevance', 'cosine_similarity', 'min_log_prob'"
195
+ )
196
+
197
+
198
+ # ── Main inference functions ──────────────────────────────────────────────────
199
+
200
+ def process_query(
201
+ query: str,
202
+ dataset: str,
203
+ mode: str,
204
+ group: str,
205
+ faiss_manager: "FAISSIndexManager",
206
+ scorer: "SubclaimScorer | None" = None,
207
+ scoring_method: str = DEFAULT_SCORING_METHOD,
208
+ ) -> SubclaimResult:
209
+ """Run the full inference pipeline for a single query.
210
+
211
+ Heavy: makes OpenAI API calls. Trigger only on button click and cache the
212
+ result in st.session_state.result.
213
+
214
+ scorer is required when scoring_method is "relevance" or "cosine_similarity".
215
+ Pass a @st.cache_resource-wrapped instance from app.py.
216
+ """
217
+ config = _load_main_config()
218
+ rag_cfg = config["rag"]
219
+
220
+ # 1. FAISS search
221
+ retrieved_docs: list[str] = faiss_manager.search_faiss_index(
222
+ query,
223
+ top_k=rag_cfg["retrival_topk"],
224
+ threshold=rag_cfg["retrival_threshold"],
225
+ )
226
+
227
+ # 2. RAG answer generation
228
+ rag_agent = OpenAIRAGAgent(faiss_manager, model=rag_cfg["response_model"])
229
+ response = rag_agent.answer(
230
+ query,
231
+ retrieved_docs,
232
+ temperature=rag_cfg["response_temperature"],
233
+ n_samples=1,
234
+ )
235
+ rag_answer: str = response.choices[0].message.content
236
+
237
+ # 3. Subclaim decomposition (returns zip iterator of (text, [(token, prob)]))
238
+ generator = OpenAIAtomicFactGenerator(model=rag_cfg["fact_generation_model"])
239
+ subclaims_with_probs = list(generator.get_facts_from_text(rag_answer))
240
+
241
+ # 4. Score each subclaim
242
+ subclaims: list[Subclaim] = []
243
+ for i, (text, token_probs) in enumerate(subclaims_with_probs):
244
+ score = _compute_score(
245
+ scoring_method=scoring_method,
246
+ subclaim_text=text,
247
+ subclaim_probs=token_probs,
248
+ query=query,
249
+ retrieved_docs=retrieved_docs,
250
+ scorer=scorer,
251
+ )
252
+ subclaims.append(Subclaim(id=i, text=text, score=score))
253
+
254
+ return SubclaimResult(
255
+ query=query,
256
+ dataset=dataset,
257
+ mode=mode,
258
+ group=group,
259
+ scoring_method=scoring_method,
260
+ rag_answer=rag_answer,
261
+ retrieved_docs=retrieved_docs,
262
+ subclaims=subclaims,
263
+ )
264
+
265
+
266
+ _REINTEGRATE_INSTRUCTION = (
267
+ "You are given an original answer and a list of statements to remove from it. "
268
+ "Rewrite the answer by removing information corresponding to the listed statements "
269
+ "while keeping the remaining text fluent and natural. "
270
+ "Return only the revised answer without any explanation."
271
+ )
272
+
273
+
274
+ def reintegrate_subclaims(
275
+ rag_answer: str,
276
+ removed_subclaims: list[str],
277
+ model: str = "gpt-4o-mini",
278
+ ) -> str:
279
+ """Rewrite rag_answer by removing the specified subclaims via LLM."""
280
+ if not removed_subclaims:
281
+ return rag_answer
282
+ removed_text = "\n".join(f"- {t}" for t in removed_subclaims)
283
+ question = f"Original answer:\n{rag_answer}\n\nStatements to remove:\n{removed_text}"
284
+ agent = OpenAILLMAgent(instruction=_REINTEGRATE_INSTRUCTION, model=model)
285
+ return agent.answer(question)
286
+
287
+
288
+ def apply_threshold(result: SubclaimResult, q_hat: float) -> FilteredResult:
289
+ """Apply a conformal threshold to a cached SubclaimResult.
290
+
291
+ Light: no API calls. Safe to run on every slider move.
292
+ Subclaims with score >= q_hat are kept.
293
+ """
294
+ subclaims = result["subclaims"]
295
+ keep_count = sum(1 for sc in subclaims if sc["score"] >= q_hat)
296
+ remove_count = len(subclaims) - keep_count
297
+ return FilteredResult(
298
+ subclaims=subclaims,
299
+ q_hat=q_hat,
300
+ keep_count=keep_count,
301
+ remove_count=remove_count,
302
+ )
demo/precompute.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo/precompute.py
2
+ # Offline script — run once before launching the demo.
3
+ #
4
+ # Produces:
5
+ # demo/data/thresholds.csv — conformal thresholds for every
6
+ # (dataset, mode, scoring_method, group, alpha)
7
+ # demo/data/samples.json — pre-computed SubclaimResult objects for
8
+ # sample queries (no API calls at demo runtime)
9
+ #
10
+ # Usage:
11
+ # python -m demo.precompute
12
+ #
13
+ # Prerequisites:
14
+ # Run main.py for each dataset first to generate FAISS indices and the
15
+ # subclaim data files in data/out/.
16
+
17
+ import csv
18
+ import glob
19
+ import json
20
+ import os
21
+ import yaml
22
+ from collections import defaultdict
23
+ from datetime import datetime, timezone
24
+
25
+
26
+ def _expand_env_vars(obj):
27
+ """Recursively expand environment variables in string values."""
28
+ if isinstance(obj, dict):
29
+ return {k: _expand_env_vars(v) for k, v in obj.items()}
30
+ if isinstance(obj, list):
31
+ return [_expand_env_vars(v) for v in obj]
32
+ if isinstance(obj, str):
33
+ return os.path.expandvars(obj)
34
+ return obj
35
+
36
+ from src.calibration.utils import compute_threshold, load_subclaim_data
37
+ from demo.constants import (
38
+ ALPHA_LEVELS,
39
+ DEFAULT_SCORING_METHOD,
40
+ SCORING_METHOD_TO_DATA_KEY,
41
+ )
42
+
43
+ # ── Config ────────────────────────────────────────────────────────────────────
44
+
45
+ _main_cfg = yaml.safe_load(open("conf/config.yaml"))
46
+ _path_cfg = _expand_env_vars(yaml.safe_load(open("conf/path_config.yaml")))
47
+ _dataset_cfg = _expand_env_vars(yaml.safe_load(open("conf/dataset_config.yaml")))
48
+
49
+ A_VALUE: float = _main_cfg["conformal_prediction"]["a_value"]
50
+ RESPONSE_MODEL: str = _main_cfg["rag"]["response_model"]
51
+
52
+ OUTPUT_THRESHOLDS = "demo/data/thresholds.csv"
53
+ OUTPUT_SAMPLES = "demo/data/samples.json"
54
+
55
+ # Number of entries held out per group (or per dataset for non-grouped) that
56
+ # are reserved exclusively for samples.json and excluded from threshold
57
+ # calibration. Conformal guarantees require calibration and evaluation data
58
+ # to be disjoint, so these entries must never appear in compute_thresholds().
59
+ HOLDOUT_PER_GROUP: int = 1
60
+
61
+ # ── Helpers ───────────────────────────────────────────────────────────────────
62
+
63
+ def _split_calibration_holdout(
64
+ data: list[dict], is_grouped: bool
65
+ ) -> tuple[list[dict], list[dict]]:
66
+ """Split entries into (calibration, holdout) sets.
67
+
68
+ holdout entries are reserved for samples.json and excluded from threshold
69
+ calibration to prevent data leakage. When a group has only HOLDOUT_PER_GROUP
70
+ or fewer entries it is kept entirely in calibration and the holdout for that
71
+ group is empty (a warning is printed).
72
+
73
+ Returns (calib_data, holdout_data).
74
+ """
75
+ if not is_grouped:
76
+ if len(data) > HOLDOUT_PER_GROUP:
77
+ return data[:-HOLDOUT_PER_GROUP], data[-HOLDOUT_PER_GROUP:]
78
+ print(f" WARNING: only {len(data)} entries — too few to split; all used for calibration, holdout is empty")
79
+ return data, []
80
+
81
+ grouped = _group_data(data)
82
+ calib: list[dict] = []
83
+ holdout: list[dict] = []
84
+ for grp, grp_entries in grouped.items():
85
+ if len(grp_entries) > HOLDOUT_PER_GROUP:
86
+ calib.extend(grp_entries[:-HOLDOUT_PER_GROUP])
87
+ holdout.extend(grp_entries[-HOLDOUT_PER_GROUP:])
88
+ else:
89
+ print(f" WARNING: group '{grp}' has only {len(grp_entries)} entries — too few to split; all used for calibration")
90
+ calib.extend(grp_entries)
91
+ return calib, holdout
92
+
93
+
94
+ def _find_calibration_file(dataset: str) -> str:
95
+ """Return the path to the latest subclaim data file for a dataset.
96
+
97
+ Looks for files matching the naming convention used by main.py:
98
+ data/out/{DatasetName}/{dataset}_{query_size}_subclaims_with_scores_{model}.json
99
+ """
100
+ full_name = _dataset_cfg["datasets"][dataset]["name"]
101
+ response_dir = os.path.join(_path_cfg["paths"]["response_dir"], full_name)
102
+ pattern = os.path.join(response_dir, f"{dataset}_*_subclaims_with_scores_*.json")
103
+ files = glob.glob(pattern)
104
+ if not files:
105
+ raise FileNotFoundError(
106
+ f"No subclaim data file found for dataset '{dataset}' "
107
+ f"(searched: {pattern}). Run main.py first."
108
+ )
109
+ return max(files, key=os.path.getmtime)
110
+
111
+
112
+ def _group_data(data: list[dict]) -> dict[str, list[dict]]:
113
+ """Group entries by their first group label."""
114
+ groups: dict[str, list[dict]] = defaultdict(list)
115
+ for entry in data:
116
+ grp = entry["groups"][0] if entry.get("groups") else "default"
117
+ groups[grp].append(entry)
118
+ return dict(groups)
119
+
120
+
121
+ def _has_score(entry: dict, data_key: str) -> bool:
122
+ return all(
123
+ data_key in sc.get("scores", {})
124
+ for sc in entry.get("subclaims", [])
125
+ )
126
+
127
+
128
+ # ── A. Threshold computation ──────────────────────────────────────────────────
129
+
130
+ def compute_thresholds() -> None:
131
+ """Compute conformal thresholds and write demo/data/thresholds.csv.
132
+
133
+ Iterates over all datasets, scoring methods, modes, groups, and alpha
134
+ values defined in constants.py. Each row corresponds to one lookup key
135
+ used by apply_threshold() at demo runtime.
136
+
137
+ CSV schema: dataset, mode, scoring_method, group, alpha, q_hat
138
+ """
139
+ os.makedirs(os.path.dirname(OUTPUT_THRESHOLDS), exist_ok=True)
140
+ rows: list[dict] = []
141
+
142
+ for dataset, ds_info in _dataset_cfg["datasets"].items():
143
+ is_grouped = ds_info.get("is_grouped", False)
144
+ print(f"\n[{dataset}] Loading calibration data…")
145
+
146
+ try:
147
+ calib_path = _find_calibration_file(dataset)
148
+ except FileNotFoundError as e:
149
+ print(f" SKIP: {e}")
150
+ continue
151
+
152
+ data = load_subclaim_data(calib_path)
153
+ calib_data, holdout_data = _split_calibration_holdout(data, is_grouped)
154
+ print(f" Loaded {len(data)} entries from {calib_path} (calib={len(calib_data)}, holdout={len(holdout_data)})")
155
+
156
+ for method_name, data_key in SCORING_METHOD_TO_DATA_KEY.items():
157
+ valid = [e for e in calib_data if _has_score(e, data_key)]
158
+ if not valid:
159
+ print(f" SKIP scoring_method='{method_name}': no entries with key '{data_key}'")
160
+ continue
161
+
162
+ # ── Marginal mode ────────────────────────────────────────────────
163
+ for alpha in ALPHA_LEVELS:
164
+ q_hat = compute_threshold(alpha, valid, A_VALUE, data_key)
165
+ rows.append({
166
+ "dataset": dataset,
167
+ "mode": "marginal",
168
+ "scoring_method": method_name,
169
+ "group": "default",
170
+ "alpha": alpha,
171
+ "q_hat": q_hat,
172
+ })
173
+
174
+ print(f" marginal / {method_name}: {len(ALPHA_LEVELS)} thresholds computed")
175
+
176
+ # ── Conditional mode (grouped datasets only) ─────────────────────
177
+ if is_grouped:
178
+ grouped = _group_data(valid)
179
+ for grp, grp_data in grouped.items():
180
+ for alpha in ALPHA_LEVELS:
181
+ q_hat = compute_threshold(alpha, grp_data, A_VALUE, data_key)
182
+ rows.append({
183
+ "dataset": dataset,
184
+ "mode": "conditional",
185
+ "scoring_method": method_name,
186
+ "group": grp,
187
+ "alpha": alpha,
188
+ "q_hat": q_hat,
189
+ })
190
+ print(f" conditional / {method_name}: {len(grouped)} groups × {len(ALPHA_LEVELS)} thresholds computed")
191
+
192
+ # Write CSV
193
+ fieldnames = ["dataset", "mode", "scoring_method", "group", "alpha", "q_hat"]
194
+ with open(OUTPUT_THRESHOLDS, "w", newline="") as f:
195
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
196
+ writer.writeheader()
197
+ writer.writerows(rows)
198
+
199
+ print(f"\nThresholds saved → {OUTPUT_THRESHOLDS} ({len(rows)} rows)")
200
+
201
+
202
+ # ── B. Sample query precomputation ────────────────────────────────────────────
203
+
204
+ def _entry_to_subclaim_result(entry: dict, dataset: str, is_grouped: bool) -> dict:
205
+ """Convert an existing subclaim data entry to SubclaimResult format.
206
+
207
+ Uses DEFAULT_SCORING_METHOD and the corresponding data key so that the
208
+ stored score is directly comparable with the thresholds in thresholds.csv.
209
+ """
210
+ data_key = SCORING_METHOD_TO_DATA_KEY[DEFAULT_SCORING_METHOD]
211
+ group = entry["groups"][0] if is_grouped and entry.get("groups") else "default"
212
+ mode = "conditional" if is_grouped else "marginal"
213
+
214
+ subclaims = [
215
+ {
216
+ "id": i,
217
+ "text": sc["subclaim"],
218
+ # Add noise so the score is in the same space as q_hat.
219
+ # q_hat is calibrated using (score + noise) in calibration/utils.py.
220
+ "score": float(sc["scores"].get(data_key, 0.0))
221
+ + float(sc["scores"].get("noise", 0.0)),
222
+ }
223
+ for i, sc in enumerate(entry.get("subclaims", []))
224
+ ]
225
+
226
+ return {
227
+ "query": entry["query"],
228
+ "dataset": dataset,
229
+ "mode": mode,
230
+ "group": group,
231
+ "scoring_method": DEFAULT_SCORING_METHOD,
232
+ "rag_answer": entry["response"],
233
+ "retrieved_docs": entry["retrieved_docs"],
234
+ "subclaims": subclaims,
235
+ }
236
+
237
+
238
+ def _pick_entries(data: list[dict], is_grouped: bool) -> list[dict]:
239
+ """Pick entries with non-empty subclaims from the holdout set.
240
+
241
+ data is expected to be the holdout portion returned by
242
+ _split_calibration_holdout(), so at most HOLDOUT_PER_GROUP entries per
243
+ group are available.
244
+ """
245
+ has_subclaims = [e for e in data if e.get("subclaims")]
246
+
247
+ if not is_grouped:
248
+ return has_subclaims[:HOLDOUT_PER_GROUP]
249
+
250
+ grouped = _group_data(has_subclaims)
251
+ picked = []
252
+ for grp_entries in grouped.values():
253
+ picked.extend(grp_entries[:HOLDOUT_PER_GROUP])
254
+ return picked
255
+
256
+
257
+ def compute_samples() -> None:
258
+ """Precompute SubclaimResult objects for sample queries.
259
+
260
+ Converts existing subclaim data (produced by main.py) to SubclaimResult
261
+ format without making any API calls. Results are saved with metadata so
262
+ it is easy to detect when re-generation is needed (e.g. after a model
263
+ change).
264
+ """
265
+ os.makedirs(os.path.dirname(OUTPUT_SAMPLES), exist_ok=True)
266
+ samples: list[dict] = []
267
+
268
+ for dataset, ds_info in _dataset_cfg["datasets"].items():
269
+ is_grouped = ds_info.get("is_grouped", False)
270
+ print(f"\n[{dataset}] Picking sample queries…")
271
+
272
+ try:
273
+ calib_path = _find_calibration_file(dataset)
274
+ except FileNotFoundError as e:
275
+ print(f" SKIP: {e}")
276
+ continue
277
+
278
+ data = load_subclaim_data(calib_path)
279
+ _, holdout_data = _split_calibration_holdout(data, is_grouped)
280
+ picked = _pick_entries(holdout_data, is_grouped)
281
+
282
+ for entry in picked:
283
+ result = _entry_to_subclaim_result(entry, dataset, is_grouped)
284
+ samples.append(result)
285
+ print(f" + query={result['query'][:60]!r} group={result['group']}")
286
+
287
+ output = {
288
+ "metadata": {
289
+ "generated_at": datetime.now(timezone.utc).isoformat(),
290
+ "model": RESPONSE_MODEL,
291
+ "scoring_method": DEFAULT_SCORING_METHOD,
292
+ },
293
+ "samples": samples,
294
+ }
295
+
296
+ with open(OUTPUT_SAMPLES, "w", encoding="utf-8") as f:
297
+ json.dump(output, f, indent=2, ensure_ascii=False)
298
+
299
+ print(f"\nSamples saved → {OUTPUT_SAMPLES} ({len(samples)} entries)")
300
+
301
+
302
+ # ── Entry point ───────────────────────────────────────────────────────────────
303
+
304
+ if __name__ == "__main__":
305
+ print("=== Step A: Computing conformal thresholds ===")
306
+ compute_thresholds()
307
+
308
+ print("\n=== Step B: Precomputing sample queries ===")
309
+ compute_samples()
310
+
311
+ print("\nDone. Run `streamlit run demo/app.py` to launch the demo.")
docs/context/01_original_architecture.md ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ResponseQualityAssessment リポジトリ解析
2
+
3
+ > 対象コード: `ResponseQualityAssessment/`(conformal-rag-demo サブモジュール)
4
+ > 論文: "Response Quality Assessment for Retrieval-Augmented Generation via Conditional Conformal Factuality"(SIGIR 2025 想定)
5
+
6
+ ---
7
+
8
+ ## 1. システム概要
9
+
10
+ RAG(Retrieval-Augmented Generation)システムの回答品質を評価するパイプラインの実装。**Conformal Prediction** を用いて、LLM の回答をサブクレームに分解し、各サブクレームの事実性スコアを計算・校正することで、統計的保証付きの品質評価を実現する。
11
+
12
+ ### 主な研究貢献
13
+
14
+ - **Split Conformal Prediction**: キャリブレーションデータから閾値を算出し、テストデータへ適用
15
+ - **Group Conditional Conformal**: グループ別(例: 医療カテゴリ)の閾値を独立に算出
16
+ - 複数のスコアリング手法(類似度・頻度・対数確率など)の横断比較
17
+
18
+ ---
19
+
20
+ ## 2. ディレクトリ構成
21
+
22
+ ```
23
+ ResponseQualityAssessment/
24
+ ├── conf/
25
+ │ ├── config.yaml # 実行全体の設定
26
+ │ ├── dataset_config.yaml # データセット別設定
27
+ │ └── path_config.yaml # ファイルパス設定
28
+ ├── data/
29
+ │ ├── raw/ # 生データ(FactScore, HotpotQA, PopQA, MedLFQA, WikiDB)
30
+ │ ├── processed/ # 標準化済みクエリ・文書データ
31
+ │ ├── out/ # サブクレーム+スコア出力
32
+ │ └── result/ # 最終結果・可視化
33
+ ├── index_store/ # FAISS インデックスとマッピング
34
+ ├── logs/ # 実行ログ
35
+ ├── src/
36
+ │ ├── calibration/ # 共形予測キャリブレーション
37
+ │ ├── common/ # 共通コンポーネント(設定・ファイル・LLM 管理)
38
+ │ ├── data_processor/ # データセット処理パイプライン
39
+ │ ├── dataloader/ # HuggingFace データローダー
40
+ │ ├── rag/ # RAG(SQLite ベース文書 DB)
41
+ │ ├── subclaim_processor/ # 回答品質評価の中核パイプライン
42
+ │ └── utils/ # 汎用ユーティリティ
43
+ ├── main.py # エントリーポイント
44
+ └── requirements.txt # 依存ライブラリ
45
+ ```
46
+
47
+ ---
48
+
49
+ ## 3. 使用ライブラリ
50
+
51
+ | ライブラリ | 用途 |
52
+ |---|---|
53
+ | `openai` | LLM API(GPT-4o-mini による応答生成・サブクレーム抽出・アノテーション)、埋め込みモデル |
54
+ | `faiss-cpu` | ベクトル近傍探索(FAISS `IndexFlatIP`、L2 正規化後の内積=コサイン類似度) |
55
+ | `sentence-transformers` / `transformers` | 埋め込みモデル(補助的)|
56
+ | `datasets` | HuggingFace からの QA データセット取得 |
57
+ | `langchain` | LLM オーケストレーション |
58
+ | `numpy` | 数値計算(共形予測の分位点計算、スコア演算)|
59
+ | `torch` | Transformer モデルのバックエンド |
60
+ | `PyPDF2` | PDF 文書のテキスト抽出 |
61
+ | `matplotlib` | キャリブレーション結果の可視化 |
62
+ | `jsonschema` | 各フェーズの出力データのスキーマ検証 |
63
+ | `python-dotenv` | `.env` からの API キー読み込み |
64
+ | `flask` / `flask-cors` | Web API(コードに存在するが現状未使用)|
65
+
66
+ **Python バージョン**: 3.11
67
+
68
+ ---
69
+
70
+ ## 4. アーキテクチャ設計
71
+
72
+ ### 4.1. 設計パターン
73
+
74
+ | パターン | 適用箇所 |
75
+ |---|---|
76
+ | **Strategy パターン** | スコアリング戦略(`ProductScoreStrategy`)、集約戦略(`MeanAggregation` / `MaxAggregation`)、チャンキング戦略 |
77
+ | **Template Method パターン** | `RawDataProcessor`(抽象基底)← データセット別実装、`ICalibration` ← 共形予測実装 |
78
+ | **Factory / Dispatcher パターン** | `QueryProcessor` がデータセット名に基づき各プロセッサへ委譲 |
79
+ | **Manager パターン** | `ConfigManager`、`FAISSIndexManager`、`FileManager`、`OpenAIManager` による各リソースのライフサイクル管理 |
80
+ | **Pipeline アーキテクチャ** | `main.py` で各ステージを順次呼び出す直列パイプライン |
81
+ | **Dependency Injection** | コンストラクタ経由でコンポーネントを注入(`faiss_manager`、`scorer` など)|
82
+
83
+ ### 4.2. コンポーネント相関図
84
+
85
+ ```
86
+ main.py
87
+ ├─ ConfigManager … YAML 読み込み・ログ設定
88
+ ├─ DataLoader … HuggingFace / Wikipedia SQLite DB
89
+ ├─ QueryProcessor … データセット別に標準化
90
+ ├─ FAISSIndexManager … ベクトルインデックス作成・検索
91
+ │ └─ OpenAIManager … text-embedding-3-large で埋め込み生成
92
+ ├─ process_subclaims() … サブクレーム処理オーケスト���ーター
93
+ │ └─ SubclaimProcessor
94
+ │ ├─ OpenAIRAGAgent … RAG 応答生成
95
+ │ ├─ OpenAIAtomicFactGenerator … サブクレーム抽出(logprobs 付き)
96
+ │ ├─ OpenAIClaimVerification … 事実性アノテーション(S/I/U/N)
97
+ │ └─ SubclaimScorer … 7 種スコア計算
98
+ └─ SplitConformalCalibration / GroupConditionalConformal
99
+ └─ calibration/utils … r_a スコア・分位点閾値計算
100
+ ```
101
+
102
+ ---
103
+
104
+ ## 5. 処理パイプライン詳細
105
+
106
+ ### Step 1: データ読み込み・標準化
107
+
108
+ | 処理 | 実装クラス |
109
+ |---|---|
110
+ | HuggingFace からデータセット取得 | `DataLoader.load_qa_data()` |
111
+ | Wikipedia SQLite DB 構築 | `DataLoader.create_wiki_db()` |
112
+ | データセット別正規化 | `FactScoreProcessor`, `HotpotQAProcessor`, `PopQAProcessor`, `MedLFQAProcessor` |
113
+
114
+ #### データセット別 外部コーパス対応表
115
+
116
+ FAISS インデックスの構築・検索に用いる外部コーパスはデータセットによって異なる。
117
+
118
+ | データセット | 外部コーパス | 取得元 / 形式 |
119
+ |---|---|---|
120
+ | **FActScore** | Wikipedia(2023-04-01 ダンプ) | bz2 ダンプ → SQLite DB(`enwiki-20230401.db`)|
121
+ | **HotpotQA** | Wikipedia(2023-04-01 ダンプ) | 同上(HuggingFace `kilt_tasks/hotpotqa` と組み合わせ)|
122
+ | **PopQA** | Wikipedia(2023-04-01 ダンプ) | 同上(HuggingFace `akariasai/PopQA` と組み合わせ)|
123
+ | **MedLFQA** | MedLFQAv2 の QA データセット自体 | GitHub `jjcherian/conformal-safety` の JSONL ファイル群(`healthsearch_qa`, `kqa_golden`, `kqa_silver_wogold`, `live_qa`, `medication_qa`)|
124
+
125
+ FActScore・HotpotQA・PopQA は共通の Wikipedia SQLite DB(`DocDB`)から文書を取得する。MedLFQA のみ Wikipedia を使わず、各 JSONL ファイルに格納された `retrieved_passages` フィールドの文書をそのままコーパスとして使用する。
126
+
127
+ 出力形式(JSON):
128
+ ```json
129
+ {"input": "クエリ文字列", "output": {"answer": "正解", "provenance": [...]}}
130
+ ```
131
+
132
+ ### Step 2: FAISS インデックス構築
133
+
134
+ #### 構築フロー
135
+
136
+ 全データセット共通で、以下の 2 フェーズに分かれる。
137
+
138
+ **フェーズ A: 文書収集(データセット別)**
139
+
140
+ 各プロセッサが「どこから文書を取ってくるか」が異なり、最終的に同一 JSON 形式で保存される。
141
+
142
+ | データセット | 文書取得元 | 1クエリあたり文書数 |
143
+ |---|---|---|
144
+ | **FActScore** | Wikipedia SQLite DB(provenance タイトル 1 件でタイトル引き) | 1 記事 |
145
+ | **PopQA** | Wikipedia SQLite DB(`s_wiki_title` 1 件、重複排除あり) | 1 記事 |
146
+ | **HotpotQA** | Wikipedia SQLite DB(provenance の複数タイトル、重複排除あり) | 2〜3 記事 |
147
+ | **MedLFQA** | JSONL の `Free_form_answer` を文単位に分割 + `Nice_to_have` リスト | 10〜20 文(短文) |
148
+
149
+ Wikipedia 3 データセットは Wikipedia 全文をそのまま FAISS に入れるのではなく、各クエリの provenance タイトルに対応する記事のみを SQLite から引き当てて保存する。MedLFQA は Wikipedia を使わず、JSONL ファイルに格納済みの文書を直接使用する。
150
+
151
+ **フェーズ B: 埋め込み・インデックス化(全データセット共通、[main.py:211-275](main.py#L211-L275))**
152
+
153
+ 1. 文書を **Fixed-Length チャンカー**(デフォルト: 2000 語、25 語オーバーラップ)で分割
154
+ 2. OpenAI `text-embedding-3-large`(次元数: 3072)で埋め込みベクトルを生成
155
+ 3. L2 正規化 → `IndexFlatIP`(内積 ≒ コサイン類似度)に追加
156
+ 4. インデックスバイナリと `indice2fm` マッピング(FAISS インデックス ID → ファイル位置)を保存
157
+
158
+ 各データセット独立したインデックスとして `index_store/{Dataset}/` に保存される。インデックスは `.gitignore` によりリポジトリ管理外であり、**実行のたびに再構築が必要**。
159
+
160
+ #### text-embedding-3-large の構築コスト見積もり(`query_size: 500`、$0.13/1M トークン)
161
+
162
+ | データセット | ユニーク文書数の目安 | チャンク/文書 | 総トークン数 | 推定コスト |
163
+ |---|---|---|---|---|
164
+ | **FActScore** | ~500 記事 | ~2 | ~2.7M | ~$0.35 |
165
+ | **PopQA** | ~500 記事 | ~2 | ~2.7M | ~$0.35 |
166
+ | **HotpotQA** | ~1,000 記事(2〜3 provenance × 重複排除) | ~2 | ~5.4M | ~$0.70 |
167
+ | **MedLFQA** | ~500 クエリ × ~15 文 | チャンク分割なし(短文) | ~0.2M | ~$0.03 |
168
+ | **合計(4 データセット)** | | | ~11M | **~$1.4** |
169
+
170
+ これは**1 回限りのインデックス構築コスト**。推論時のクエリ埋め込み・サブクレーム埋め込みは数千トークン程度で無視できる。Wikipedia 記事長のばらつきにより±2 倍程度の誤差あり。
171
+
172
+ ### Step 3: 応答生成
173
+
174
+ `SubclaimProcessor.generate_responses()`:
175
+ 1. クエリに対し FAISS から上位 k 件(デフォルト: 10)の文書を検索(閾値: 0.3 以上)
176
+ 2. 検索文書をコンテキストとして GPT-4o-mini に渡し、初期回答を生成
177
+
178
+ ### Step 4: サブクレーム抽出
179
+
180
+ `OpenAIAtomicFactGenerator.get_facts_from_text()`:
181
+ - LLM に回答テキストを渡し、セミコロン区切りの原子的事実に分解
182
+ - `logprobs=True, top_logprobs=1` でトークンごとの対数確率を取得
183
+ - 各サブクレームの対数確率リストを保持
184
+
185
+ ### Step 5: サブクレームスコアリング
186
+
187
+ `SubclaimProcessor.score_subclaim()` で 7 種のスコアを計算:
188
+
189
+ | スコア名 | 計算方法 | 意味 |
190
+ |---|---|---|
191
+ | `relavance`(relevance) | FAISS 検索スコア × サブクレーム-文書コサイン類似度 の積、文書間で集約 | 検索文書との総合関連度 |
192
+ | `query_claim_cosine_similarity` | クエリ埋め込み ↔ サブクレーム埋め込み のコサイン類似度 | クエリとの意味的整合性 |
193
+ | `doc_claim_cosine_similarity` | 全検索文書とサブクレームのコサイン類似度の最大値 | 文書との意味的整合性 |
194
+ | `frequency` | 温度 1.0 で 5 回サンプリングし、同内容のサブクレームが出現した割合 | LLM 自体の一貫性(自己信頼度)|
195
+ | `min_log_prob` | サブクレーム中のトークン対数確率の最小値 | LLM の生成確信度 |
196
+ | `random` | Uniform(0, 1) | ベースライン |
197
+ | `ordinal` | `i / サブクレーム数` | 応答内の出現順序(ベースライン)|
198
+
199
+ > 全スコアに N(0, 0.001) のガウスノイズを付加(安定化のため)
200
+
201
+ 集約戦略(文書間集約):
202
+ - `MeanAggregation`: 文書スコアの平均
203
+ - `MaxAggregation`: 文書スコアの最大値
204
+
205
+ スコアリング戦略:
206
+ - `ProductScoreStrategy`: FAISS スコア × コサイン類似度 の積(現状唯一の実装)
207
+
208
+ ### Step 6: アノテーション
209
+
210
+ `OpenAIClaimVerification.annotate()`:
211
+ - サブクレームを検索文書・正解と照合し、GPT-4o-mini が 4 段階ラベルを付与
212
+
213
+ | ラベル | 意味 |
214
+ |---|---|
215
+ | `S` (Supported) | 事実として支持される |
216
+ | `I` (Irrelevant) | 質問と無関係 |
217
+ | `U` (Unverifiable) | 検証不可能 |
218
+ | `N` (Nonfactual) | 事実に反する |
219
+
220
+ ### Step 7: 共形キャリブレーション
221
+
222
+ #### Split Conformal Prediction
223
+
224
+ ```
225
+ For each alpha in [0.05, 0.10, ..., 0.40]:
226
+ Repeat 1000 runs:
227
+ 1. データをシャッフルし 50/50 でキャリブ/テストに分割
228
+ 2. キャリブデータで各サブクレームの r_a スコアを計算:
229
+ r_a(x) = "サブクレームを低スコア順に除去したとき、
230
+ 残存サブクレームの正確率が初めて a 以上になる最小閾値"
231
+ 3. 閾値 q̂ = ceil((n+1)*(1-alpha))/n 分位点(共形予測の保証付き計算)
232
+ 4. テストデータで閾値 q̂ を適用し、除去率・正確率を計算
233
+ ```
234
+
235
+ **キャリブレーションのデータ量について**
236
+
237
+ キャリブレーションに使うデータは `query_size` で処理したクエリ全体を 50/50 分割したものであり、専用のキャリブレーションセットは別途用意しない。デフォルト(`query_size: 500`)の場合、1 ランあたりキャリブレーション 250 件・テスト 250 件となる。各データセットの総件数(HotpotQA ~5,600 件、PopQA ~14,267 件など)に対して 500 件はサンプリング上限であり、`query_size` を増やすほどキャリブレーション精度は上がるが LLM 呼び出しコストも線形に増加する。1000 回のランダムシャッフルを繰り返すことで、少ないサンプル数でも統計的ばらつきを吸収している。
238
+
239
+ #### Group Conditional Conformal
240
+
241
+ `GroupConditionalConformal`: グループ(MedLFQA のカテゴリなど)ごとに独立してキャリブレーションを実行し、グループ別閾値を算出。
242
+
243
+ ### Step 8: 結果出力
244
+
245
+ - **CSV**: alpha × confidence_method ごとの除去率・正確率
246
+ - **PNG プロット**: 共形除去カーブ(除去率 vs. 事実性保証)、事実的正確率カーブ
247
+ - **JSON**: 全サブクレーム(スコア・アノテーション込み)
248
+ - **config/**: 再現性のため実行設定の YAML コピー
249
+
250
+ ---
251
+
252
+ ## 6. データフロー
253
+
254
+ ```
255
+ 生データ(HuggingFace / bz2 Wikipedia)
256
+ ↓ DataLoader
257
+ SQLite Wikipedia DB + 標準化クエリ JSON
258
+ ↓ FileManager + OpenAIManager
259
+ チャンク埋め込み行列(N × 3072)
260
+ ↓ FAISSIndexManager
261
+ FAISS IndexFlatIP + indice2fm マッピング
262
+ ↓ SubclaimProcessor.generate_responses()
263
+ {query, gld_ans, retrieved_docs, response, groups}
264
+ ↓ OpenAIAtomicFactGenerator
265
+ {subclaim_text, log_prob_list}(サブクレーム + 対数確率)
266
+ ↓ SubclaimScorer (7 種スコア)
267
+ {scores: {relavance, query_claim_cosine, ..., ordinal}}
268
+ ↓ OpenAIClaimVerification
269
+ {annotations: {gpt: "S"|"I"|"U"|"N"}}
270
+ ↓ SplitConformalCalibration / GroupConditionalConformal
271
+ 除去率・正確率の統計(1000 runs × alpha × method)
272
+
273
+ CSV / PNG / config YAML
274
+ ```
275
+
276
+ ---
277
+
278
+ ## 7. 設定ファイル
279
+
280
+ ### `conf/config.yaml`(主要パラメータ)
281
+
282
+ ```yaml
283
+ dataset:
284
+ name: "pop_qa" # fact_score | hotpot_qa | pop_qa | medlf_qa
285
+ query_size: 500 # 処理クエリ数(-1 = 全件)
286
+
287
+ index:
288
+ embedding_model: "text-embedding-3-large"
289
+ truncation_config:
290
+ strategy: "fixed_length"
291
+ chunk_size: 2000 # チャンクあたり単語数
292
+ chunk_overlap: 25
293
+
294
+ rag:
295
+ retrival_topk: 10 # 検索上位件数
296
+ retrival_threshold: 0.3 # コサイン類似度の下限閾値
297
+ response_model: "gpt-4o-mini"
298
+
299
+ conformal_prediction:
300
+ aggregation_strategy: "mean" # mean | max
301
+ scoring_strategy: "product"
302
+ split_conformal: true
303
+ conformal_alphas:
304
+ start: 0.05
305
+ end: 0.45
306
+ step: 0.05
307
+ a_value: 1.0 # 目標正確率(0〜1)
308
+ ```
309
+
310
+ ### `conf/dataset_config.yaml`
311
+
312
+ ```yaml
313
+ datasets:
314
+ fact_score:
315
+ is_grouped: false # グループ条件付き共形を無効
316
+ medlf_qa:
317
+ is_grouped: true # グループ条件付き共形を有効
318
+ ```
319
+
320
+ ---
321
+
322
+ ## 8. 主要クラス・関数一覧
323
+
324
+ ### 設定・管理系
325
+
326
+ | クラス / 関数 | 役割 |
327
+ |---|---|
328
+ | `ConfigManager` | YAML 設定の読み込み・保存・更新、ロギング設定 |
329
+ | `FileManager` | 文書(PDF/テキスト)処理、チャンキング、埋め込みキャッシュ |
330
+ | `FAISSIndexManager` | FAISS インデックスの作成・読み込み・検索、ファイル位置マッピング |
331
+ | `OpenAIManager` | OpenAI API ラッパー(埋め込み・アシスタント・スレッド)|
332
+
333
+ ### データ処理系
334
+
335
+ | クラス / 関数 | 役割 |
336
+ |---|---|
337
+ | `DataLoader.load_qa_data()` | 4 データセットを HuggingFace から取得 |
338
+ | `DataLoader.create_wiki_db()` | Wikipedia bz2 ダンプから SQLite DB 構築 |
339
+ | `QueryProcessor.get_queries()` | 標準スキーマへの変換・サンプリング |
340
+ | `DocDB` | SQLite ベースの文書検索(FActScore 実装を参照)|
341
+
342
+ ### RAG・スコアリング系
343
+
344
+ | クラス / 関数 | 役割 |
345
+ |---|---|
346
+ | `OpenAIRAGAgent.answer()` | クエリ + 検索文書 → LLM 回答生成 |
347
+ | `OpenAIAtomicFactGenerator.get_facts_from_text()` | テキスト → 原子的事実リスト(logprobs 付き)|
348
+ | `OpenAIClaimVerification.annotate()` | サブクレームの S/I/U/N ラベル付け |
349
+ | `SubclaimScorer.score()` | コサイン類似度ベースの関連度スコア計算 |
350
+ | `SubclaimScorer.frequency_score()` | LLM 複数サンプリングによる一貫性スコア |
351
+
352
+ ### キャリブレーション系
353
+
354
+ | クラス / 関数 | 役割 |
355
+ |---|---|
356
+ | `get_r_score(entry, method, a)` | `r_a` スコア(閾値を下げたとき正確率が `a` を超える最小値)の計算 |
357
+ | `compute_threshold(alpha, data, a, method)` | 共形予測の分位点閾値 `q̂` の計算 |
358
+ | `SplitConformalCalibration.plot_conformal_removal()` | 除去率 vs. 事実性保証のカーブ生成 |
359
+ | `GroupConditionalConformal` | グループ別閾値による条件付き共形予測 |
360
+
361
+ ---
362
+
363
+ ## 9. テスト・品質保証
364
+
365
+ 専用のテストスイートは存在しない。品質保証は以下の方法で実施:
366
+
367
+ - **JSON スキーマ検証**: `jsonschema` による各ステージ出力の構造検証(`base_schema.json`, `wiki_schema.json`, `subclaims_schema.json`)
368
+ - **インデックス整合性チェック**: `FAISSIndexManager.is_indice_align()` によるインデックスとマッピングの同期確認
369
+ - **バッチ処理の冪等性**: サブクレーム処理は既存の処理済みエントリをスキップ(再実行安全)
370
+ - **設定ログ**: 実行時の設定 YAML を `config/` ディレクトリにコピーして再現性を担保
371
+
372
+ ---
373
+
374
+ ## 10. 既知の制約・拡張ポイント
375
+
376
+ | 項目 | 現状 | 拡張の方向性 |
377
+ |---|---|---|
378
+ | スコアリング戦略 | `ProductScoreStrategy` のみ実装 | Strategy パターンにより追加容易 |
379
+ | テキストチャンキング | Fixed-Length のみ(Recursive は未実装) | `FixedLengthChunker` と同インターフェースで追加可 |
380
+ | 共形予測の種別 | Split / Group Conditional の 2 種 | オンライン共形予測などへの拡張余地あり |
381
+ | 並列処理 | なし(完全逐次処理) | バッチ API や非同期処理で高速化可 |
382
+ | Web API | `flask` が依存関係に含まれるが未使用 | デモ API サーバー化の際に活用可 |
383
+ | FAISS コーパス範囲 | 各クエリの provenance に対応する文書のみインデックス化(正解文書が事前判明している前提) | データセット外の任意クエリに対応するには Wikipedia 全文などコーパス全体をインデックス化する必要がある |
384
+ | 全コーパス検索時の精度 | `IndexFlatIP` は完全探索のため数学的な top-k 精度は落ちない。ただし埋め込み空間に無関係文書が増えるこ���で真に関連する文書が top-k から押し出される可能性がある。FAISS スコアはスコアリングにも使われるため最終的な品質評価精度にも直結する | ドメイン特化の fine-tuned 埋め込みモデルへの切り替えや BM25 との hybrid retrieval が対策として有効 |
385
+
386
+ ---
387
+
388
+ ## 11. 実行方法
389
+
390
+ ```bash
391
+ # 基本実行
392
+ python main.py --config conf/config.yaml
393
+
394
+ # データセット・クエリ数を上書き指定
395
+ python main.py --dataset hotpot_qa --query_size 200
396
+
397
+ # カスタム実行 ID を付与
398
+ python main.py --run_id my_experiment_01
399
+ ```
400
+
401
+ 環境変数(`.env`):
402
+ ```
403
+ OPENAI_API_KEY=sk-...
404
+ ```
docs/context/02_design_patterns_analysis.md ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 設計パターン詳細分析
2
+
3
+ ## 概要
4
+
5
+ 本ドキュメントでは、現在の `src/` に適用されている設計パターンをコードベースに基づいて詳細に解説し、問題点と改善案を記録する。
6
+
7
+ ---
8
+
9
+ ## 1. Strategy パターン
10
+
11
+ ### 目的
12
+ アルゴリズムをオブジェクトとして差し替え可能にする。
13
+
14
+ ### 適用箇所
15
+ ```
16
+ src/subclaim_processor/strategies/
17
+ ├── aggregation/
18
+ │ ├── base.py → AggregationStrategy (ABC)
19
+ │ ├── mean.py → MeanAggregation
20
+ │ └── max.py → MaxAggregation
21
+ └── scoring/
22
+ ├── base.py → ScoringStrategy (ABC)
23
+ └── product.py → ProductScoreStrategy
24
+ ```
25
+
26
+ ### 実際の動作(`subclaim_scorer.py:46-68`)
27
+ ```python
28
+ AGGREGATION_STRATEGIES: Dict[str, Callable] = {
29
+ "max": MaxAggregation,
30
+ "mean": MeanAggregation,
31
+ }
32
+ SCORING_STRATEGIES: Dict[str, Callable] = {
33
+ "product": ProductScoreStrategy,
34
+ }
35
+
36
+ def score(self, claim, retrieved_docs, aggregation_strategy, scoring_strategy):
37
+ agg_func = AGGREGATION_STRATEGIES[aggregation_strategy]() # "mean" or "max"
38
+ scoring_func = SCORING_STRATEGIES[scoring_strategy]() # "product"
39
+
40
+ for doc in retrieved_docs:
41
+ score = scoring_func.compute_score(...) # 差し替え可能
42
+
43
+ return agg_func.aggregate(doc_scores) # 差し替え可能
44
+ ```
45
+
46
+ ### 問題点
47
+ Strategy の選択が文字列キーの辞書で行われており、型安全性がない。`"product"` を typo しても実行時まで気づけない。
48
+
49
+ ### 改善案
50
+ ```python
51
+ # Enum で型安全にする
52
+ from enum import Enum
53
+
54
+ class AggregationMethod(Enum):
55
+ MEAN = "mean"
56
+ MAX = "max"
57
+
58
+ # または Protocol を使う(Python 3.8+)
59
+ from typing import Protocol
60
+
61
+ class AggregationStrategy(Protocol):
62
+ def aggregate(self, scores: list[float]) -> float: ...
63
+ ```
64
+
65
+ ---
66
+
67
+ ## 2. Template Method パターン
68
+
69
+ ### 目的
70
+ 処理の骨格を親クラスで定め、詳細を子クラスに委ねる。
71
+
72
+ ### 適用箇所
73
+ ```
74
+ src/data_processor/raw_data_processor.py → IRawDataProcessor, DatasetProcessor (ABC)
75
+ src/calibration/base_calibration.py → ICalibration (ABC)
76
+ src/common/llm/llm_agent.py → LLMAgent (ABC)
77
+ ```
78
+
79
+ ### 現在の実装(`raw_data_processor.py`)
80
+ ```python
81
+ class IRawDataProcessor(ABC):
82
+ @abstractmethod
83
+ def get_queries(self, input_file, output_file): pass
84
+
85
+ @abstractmethod
86
+ def get_documents(self, query_file, output_file): pass
87
+
88
+ class DatasetProcessor(ABC):
89
+ @abstractmethod
90
+ def process_queries(self, input_file, **kwargs) -> list: pass
91
+
92
+ @abstractmethod
93
+ def process_documents(self, query_file, db, **kwargs) -> dict: pass
94
+ ```
95
+
96
+ ### 問題点
97
+ 抽象クラスが2つある(`IRawDataProcessor` と `DatasetProcessor`)。
98
+
99
+ | クラス | 役割 |
100
+ |--------|------|
101
+ | `IRawDataProcessor` | ファイルI/O込みのインターフェース |
102
+ | `DatasetProcessor` | 純粋な処理ロジックのインターフェース |
103
+
104
+ `QueryProcessor` は `IRawDataProcessor` を継承しつつ、内部で `DatasetProcessor` を利用するというダブル構造になっている。
105
+
106
+ 真の Template Method パターンなら、親クラスに「骨格メソッド」が実装されているはずだが、現在の実装では `QueryProcessor.get_queries()` が骨格を担っている。
107
+
108
+ ---
109
+
110
+ ## 3. Factory / Dispatcher パターン
111
+
112
+ ### 目的
113
+ 入力に基づいて適切なオブジェクトを生成・委譲する。
114
+
115
+ ### 実際のコード(`query_processor.py:26-31`)
116
+ ```python
117
+ self.processors = {
118
+ "fact_score": FactScoreProcessor(),
119
+ "hotpot_qa": HotpotQAProcessor(),
120
+ "pop_qa": PopQAProcessor(),
121
+ "medlf_qa": MedLFQAProcessor(),
122
+ }
123
+
124
+ # 使用時
125
+ processor = self.processors.get(dataset)
126
+ ```
127
+
128
+ ### 問題点
129
+ これは「Factory パターン」ではなく「Registry パターン」に近い。
130
+
131
+ | パターン | 特徴 |
132
+ |---------|------|
133
+ | Factory(正) | 必要時にオブジェクトを生成する |
134
+ | Registry(現在) | 起動時に全オブジェクトを一括生成して辞書に保持 |
135
+
136
+ `QueryProcessor()` を作った時点で4データセット全プロセッサがインスタンス化される。実害は少ないが、設計の意図とずれている。
137
+
138
+ ---
139
+
140
+ ## 4. Manager パターン
141
+
142
+ ### 目的
143
+ リソースのライフサイクル(生成・使用・破棄)を一元管理する。
144
+
145
+ ### 各 Manager の評価
146
+
147
+ | Manager | 管理対象 | 実際の責任数 | 評価 |
148
+ |---------|---------|------------|------|
149
+ | `ConfigManager` | YAMLファイル | 設定読み込み + ログセットアップ | △ 2責任 |
150
+ | `FAISSIndexManager` | FAISSインデックス | インデックス管理 + 検索 + 応答生成 | ✗ 3責任 |
151
+ | `FileManager` | ドキュメント | ファイル読み込み + チャンキング | △ 2責任 |
152
+ | `OpenAIManager` | Embeddings API | API呼び出しのみ | ✓ 適切 |
153
+
154
+ ### 最大の問題:`FAISSIndexManager`(`faiss_manager.py:245-286`)
155
+ ```python
156
+ class FAISSIndexManager:
157
+ def upsert_file_to_faiss(...) # インデックス管理 ✓
158
+ def search_faiss_index(...) # 検索 ✓
159
+ def generate_response_from_context(...) # ← LLM応答生成(無関係)
160
+ def parse_result(...) # ← 文字列パース(雑多)
161
+ ```
162
+
163
+ `generate_response_from_context` はインデックス管理とは無関係なのに `FAISSIndexManager` に含まれている。この責任は `OpenAIRAGAgent` に移すべき。
164
+
165
+ ---
166
+
167
+ ## 5. Pipeline アーキテクチャ
168
+
169
+ ### 目的
170
+ 処理を段階的に分けて順次実行する。
171
+
172
+ ### 現在のパイプライン
173
+ ```
174
+ main.py
175
+ 1. DataLoader → 生データ取得
176
+ 2. QueryProcessor → データ正規化
177
+ 3. FAISSIndexManager → インデックス構築
178
+ 4. SubclaimProcessor → 応答生成 → サブクレーム抽出 → スコアリング → アノテーション
179
+ 5. SplitConformalCalibration → 統計的キャリブレーション
180
+ ```
181
+
182
+ ### 問題点:ステージ間の文字列シリアライズ(`faiss_manager.py:196-216`)
183
+ ```python
184
+ # 検索結果を文字列として返す
185
+ results.append(
186
+ f"{text} indice={idx} fileposition={relative_idx} score={dist:.4f}"
187
+ # TODO reformat this ← 本人コメントあり
188
+ )
189
+ ```
190
+
191
+ その後 `parse_result()` で正規表現パースし直している:
192
+ ```python
193
+ pattern = re.compile(
194
+ r"page_content='(.*?)'\smetadata=(\{.*?\})\sindice=(\d+)\sfileposition=(\d+)\sscore=([\d.]+)",
195
+ re.DOTALL,
196
+ )
197
+ ```
198
+
199
+ ### 改善案
200
+ dataclass でステージ間のデータ型を定義すれば `parse_result()` が不要になる:
201
+ ```python
202
+ @dataclass
203
+ class SearchResult:
204
+ text: str
205
+ indice: int
206
+ fileposition: int
207
+ score: float
208
+ ```
209
+
210
+ ---
211
+
212
+ ## 6. Dependency Injection(DI)
213
+
214
+ ### 目的
215
+ 依存オブジェクトを外から注入してテスト容易性を高める。
216
+
217
+ ### 現在の実装(`subclaim_scorer.py:29-44`)
218
+ ```python
219
+ class SubclaimScorer(IDocumentScorer):
220
+ def __init__(self, index_truncation_config, embedding_model, index_path, ...):
221
+ self.faiss_manager = FAISSIndexManager(...) # ← 内部で直接生成
222
+ self.gen = OpenAIAtomicFactGenerator() # ← 内部で直接生成
223
+ self.openai_client = OpenAI() # ← 内部で直接生成
224
+ ```
225
+
226
+ ### 問題点
227
+ これは DI ではない。本当の DI は依存オブジェクトを外部から受け取る:
228
+ ```python
229
+ # DI(正)
230
+ class SubclaimScorer:
231
+ def __init__(
232
+ self,
233
+ faiss_manager: FAISSIndexManager,
234
+ fact_generator: OpenAIAtomicFactGenerator,
235
+ openai_client: OpenAI,
236
+ ):
237
+ self.faiss_manager = faiss_manager
238
+ ```
239
+
240
+ 現在の設計では `SubclaimScorer` の単体テストに本物の OpenAI API と FAISS インデックスが必要になる。
241
+
242
+ ---
243
+
244
+ ## 設計全体の評価サマリー
245
+
246
+ | パターン | 意図の正確さ | 実装の品質 | 改善優先度 |
247
+ |---------|------------|----------|----------|
248
+ | Strategy | ✓ 正しい | △ 文字列キーで型安全でない | 低(機能はしている) |
249
+ | Template Method | △ 混乱あり | △ 2つの抽象クラスが混在 | 中 |
250
+ | Factory/Dispatcher | △ Registry に近い | △ 起動時に全数生成 | 低 |
251
+ | Manager | △ 責任過多 | ✗ FAISSがLLM応答生成を担当 | 高 |
252
+ | Pipeline | ✓ 正しい | ✗ 文字列シリアライズ+正規表現パース | 高 |
253
+ | Dependency Injection | ✗ 名前だけ | ✗ 内部生成で外部注入不可 | 中 |
254
+
255
+ ### 優先度「高」の改善項目(構造変更不要)
256
+
257
+ 1. **`FAISSIndexManager.generate_response_from_context` の移動**
258
+ - 移動先:`OpenAIRAGAgent`(すでに存在する)
259
+ - 効果:Manager の単一責任原則回復
260
+
261
+ 2. **パイプライン間データ型の構造化**
262
+ - `SearchResult` dataclass の導入
263
+ - `parse_result()` の正規表現解析の廃止
264
+ - 効果:バグ発生箇所の削減、デバッグ容易性の向上
docs/context/03_demo_app_design.md ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # デモアプリケーション設計
2
+
3
+ ## 1. 技術選定
4
+
5
+ 既存のバックエンド資産(Python 3.11, OpenAI, FAISS)を活かしつつ、デモ用途に最適化する。
6
+
7
+ | 層 | 選定技術 | 理由 |
8
+ |----|---------|------|
9
+ | フロントエンド・UI | **Streamlit** | Python のみで完結し、既存推論コードを直接インポート可能。Flask/FastAPI で別途 API サーバーを立てる手間が不要。テキストのハイライト(Markdown/HTML 描画)やスライダーも標準サポート |
10
+ | バックエンド・推論 | 既存 `src/` モジュールを再利用 | - |
11
+ | LLM API | `gpt-4o-mini` | デモの応答速度とコストのバランスが良い |
12
+ | ベクトル検索 | FAISS + `text-embedding-3-large` | そのまま利用 |
13
+ | データストア | 事前計算済み閾値を CSV として静的保持 | デモ実行時にキャリブレーションループを回すのは非現実的なため。alpha を float 列として保持するため型変換問題が発生しない |
14
+
15
+ ---
16
+
17
+ ## 2. UI / UX 設計
18
+
19
+ ユーザーが「Target Factuality ($1-\alpha$) を操作すると、生成された回答のどの部分が削られるか」を体感できる構成にする。
20
+
21
+ ### 画面レイアウト
22
+
23
+ #### サイドバー(コントロールパネル)
24
+
25
+ | コントロール | 内容 | 備考 |
26
+ |------------|------|------|
27
+ | データセット選択 | FactScore / HotpotQA / PopQA / MedLFQA から選択 | - |
28
+ | モード選択 | Marginal(全体)か Conditional(グループ別)か | **MedLFQA 選択時のみ有効**。他のデータセットでは非活性(`disabled=True`)または非表示にする |
29
+ | グループ選択 | healthsearch_qa / medication_qa / kqa_silver_wogold / kqa_golden / live_qa から選択 | **MedLFQA + Conditional モード選択時のみ有効**。それ以外は非表示 |
30
+ | 質問の入力 | サンプルクエリのプルダウン選択のみ | FAISS インデックスは各クエリの provenance に対応する文書のみで構成されているため、任意のフリーテキスト入力は検索精度・スコアリング精度が保証されない。デモではサンプルクエリに限定する |
31
+ | 推論実行ボタン | `st.button("推論実行")` | クリックで重い処理を起動 |
32
+ | Factuality スライダー | Target Factuality ($1-\alpha$) を 60% 〜 95% の範囲で調整 | 推論結果がキャッシュされている場合のみ有効 |
33
+
34
+ #### メインパネル(結果表示)
35
+
36
+ **クエリ**
37
+ - 選択されたサンプルクエリの全文を表示
38
+
39
+ ---
40
+
41
+ **元の RAG 回答 | フィルタ後の回答**(2カラム)
42
+
43
+ 左カラム: 元の RAG 回答
44
+ - `OpenAIRAGAgent` が生成した回答をそのまま表示
45
+
46
+ 右カラム: フィルタ後の回答
47
+ - 「回答を生成」ボタン(`st.button`)を押すと LLM が回答を再生成(`reintegrate_subclaims()` 呼び出し)
48
+ - 生成された回答をボタンの上に表示
49
+ - スライダー($\alpha$)を変更した後に古い結果が残っている場合は警告を表示
50
+
51
+ ---
52
+
53
+ **Conformal-RAG による回答(インタラクティブ表示)**
54
+ - スライダーの値に対応する閾値($\hat{q}$)を事前計算済み CSV から取得し、キャッシュ済みスコアと比較して即時再描画
55
+ - スコア($R(c)$)が閾値($\hat{q}$)以上のサブクレーム → 装飾なし(そのまま表示)
56
+ - スコアが閾値未満のサブクレーム → **取り消し線・赤字**で表示
57
+
58
+ ---
59
+
60
+ **検索されたナレッジ(Retrieved Documents)**
61
+ - `st.expander` でアコーディオン表示(デフォルト折りたたみ)
62
+ - FAISS 検索結果の `page_content` を抽出(`_extract_page_content()` でソースプレフィックスを除去)
63
+
64
+ **分析メトリクス**
65
+ - 算出された閾値($\hat{q}$)
66
+ - 元のサブクレーム数
67
+ - Keep 件数
68
+ - 削除率(Removal Rate)
69
+ - 推論にかかった時間(ライブ推論時のみ表示。サンプルクエリ使用時は非表示)
70
+
71
+ ---
72
+
73
+ ## 3. バックエンド設計変更(デモ向け改修)
74
+
75
+ 現在のリポジトリは「バッチ評価用(論文の実験用)」設計のため、インタラクティブなデモ向けに以下のリファクタリングが必要。
76
+
77
+ > **方針(CLAUDE.md より)**: デモ特有の UI ロジックやエンドポイントは `demo/` ディレクトリ内に隔離し、`src/` のコアロジックと密結合させない。`main.py` を使ったバッチ処理が壊れないよう後方互換性を維持する。
78
+
79
+ ### ディレクトリ構成
80
+
81
+ ```
82
+ demo/
83
+ ├── app.py # Streamlit エントリーポイント
84
+ ├── inference_api.py # 単一クエリ推論ラッパー(src/ を呼び出す)
85
+ ├── precompute.py # キャリブレーション閾値・サンプルクエリの事前計算(オフライン実行用)
86
+ ├── constants.py # ALPHA_LEVELS など demo/ 内の共有定数
87
+ └── data/
88
+ ├── thresholds.csv # 事前計算済みキャリブレーション閾値マスタ
89
+ └── samples.json # 事前計算済みサンプルクエリ推論結果
90
+ ```
91
+
92
+ `src/` のコアロジックは変更せず、`demo/` 側から `import` して利用する。
93
+
94
+ ### A. 事前計算(オフライン)フェーズ — `precompute.py`
95
+
96
+ `precompute.py` は以下 **2つの責務** を持つ。どちらもオフラインで実行し、デモ起動時には読み込むだけにする。
97
+
98
+ #### A-1. キャリブレーション閾値の計算
99
+
100
+ 各データセット・スコアリング手法・$\alpha$ 値・グループごとの閾値($\hat{q}$)を計算して静的ファイルに保存する。
101
+
102
+ > **データ分割(ホールドアウト)**: `main.py` の出力データはキャリブレーション用とサンプル用に分割して使用する(`precompute.py` の `HOLDOUT_PER_GROUP` 定数で制御)。Conformal Prediction の理論的 coverage 保証はキャリブレーションに使っていないデータに対してのみ成立するため、同一データを両方に使ってはならない。`compute_thresholds()` はキャリブレーション側のみ、`compute_samples()` はホールドアウト側のみを参照する。
103
+
104
+ - 対象データセット: `fact_score`(FactScore)、`hotpot_qa`(HotpotQA)、`pop_qa`(PopQA)、`medlf_qa`(MedLFQA)
105
+ - MedLFQA の Conditional モードのグループ: **HealthsearchQA / MedicationQA / K-QA_Silver / K-QA_Golden / LiveQA**(全グループを網羅すること)
106
+ - スコアリング手法: デフォルト `relevance`。手法名は定数として管理し、変更可能にする
107
+ - 保存フォーマットは **CSV(縦持ち)**、出力先: `demo/data/thresholds.csv`
108
+ - スキーマ:
109
+
110
+ | 列名 | 型 | 説明 | 取りうる値 |
111
+ |------|----|------|-----------|
112
+ | `dataset` | `str` | データセット識別子 | `"fact_score"`, `"hotpot_qa"`, `"pop_qa"`, `"medlf_qa"` |
113
+ | `mode` | `str` | キャリブレーションモード | `"marginal"`, `"conditional"` |
114
+ | `scoring_method` | `str` | スコアリング手法名 | `"relevance"`, `"cosine_similarity"`, `"min_log_prob"` など。手法を追加する際は `precompute.py` を再実行して行を追加するだけでよい(スキーマ変更不要) |
115
+ | `group` | `str` | グループ名。Marginal モードまたはグループなしデータセットは `"default"` | `"HealthsearchQA"`, `"MedicationQA"`, `"K-QA_Silver"`, `"K-QA_Golden"`, `"LiveQA"`, `"default"` |
116
+ | `alpha` | `float` | 誤り率($\alpha = 1 - \text{target factuality}$) | `ALPHA_LEVELS`(後述)と完全一致させること |
117
+ | `q_hat` | `float` | 共形予測の閾値 | — |
118
+
119
+ - `alpha` は浮動小数点演算の誤差(例: `1 - 0.9 = 0.09999...`)を避けるため、CSV 書き込み時・lookup 時の両方で `round(alpha, 2)` を適用すること
120
+ - lookup 例: `df.query("dataset==@d and mode==@m and scoring_method==@s and group==@g and alpha==@a")["q_hat"].iloc[0]`
121
+
122
+ #### A-2. サンプルクエリの推論結果の事前計算
123
+
124
+ UI のプルダウンで選択できるサンプルクエリについて、`inference_api.process_query()` を実行済みの `SubclaimResult` を静的ファイルとして保存しておく。これによりプルダウン選択時に API 呼び出しなしで即座に表示できる。
125
+
126
+ - サンプルクエリ結果の保存ファイルには**メタデータ(生成日時・使用モデル名・スコアリング手法)を付与**し、モデルやプロンプト変更後に再生成が必要かどうか判断できるようにする
127
+ - MedLFQA のサンプルクエリは全グループ(HealthsearchQA / MedicationQA / K-QA_Silver / K-QA_Golden / LiveQA)から少なくとも 1 件ずつ含めること
128
+
129
+ 保存フォーマット(`demo/data/samples.json`):
130
+
131
+ ```json
132
+ {
133
+ "metadata": {
134
+ "generated_at": "2026-03-21T12:00:00Z",
135
+ "model": "gpt-4o-mini",
136
+ "scoring_method": "relevance"
137
+ },
138
+ "samples": [
139
+ {
140
+ "query": "What is the first-line treatment for hypertension?",
141
+ "dataset": "medlf_qa",
142
+ "mode": "conditional",
143
+ "group": "MedicationQA",
144
+ "scoring_method": "relevance",
145
+ "rag_answer": "...",
146
+ "retrieved_docs": ["...", "..."],
147
+ "subclaims": [
148
+ {"id": 0, "text": "...", "score": 0.85},
149
+ {"id": 1, "text": "...", "score": 0.42}
150
+ ]
151
+ }
152
+ ]
153
+ }
154
+ ```
155
+
156
+ - `samples` 配列の各要素は `SubclaimResult` と同一スキーマ
157
+ - プルダウン表示ラベルは `query` フィールドから生成する
158
+
159
+ ### B. 状態管理(`st.session_state`)
160
+
161
+ Streamlit はウィジェット操作のたびにスクリプト全体を再実行するため、重い処理の結果を `st.session_state` に保持して再計算を回避する。
162
+
163
+ #### B-1. セッションキー一覧
164
+
165
+ | キー | 型 | 内容 |
166
+ |------|----|------|
167
+ | `query_processed` | `bool` | クエリが処理済みかどうか |
168
+ | `result` | `SubclaimResult \| None` | `process_query()` の戻り値(サブクレーム・スコア・検索結果・RAG 回答をまとめて保持) |
169
+ | `elapsed_sec` | `float \| None` | ライブ推論の所要時間。サンプルクエリ使用時は `None`(推論時間メトリクスの表示制御に使用) |
170
+ | `prev_params` | `tuple \| None` | 直前の `(dataset, mode, group)` の組。変更検出してセッションをリセットするために使用 |
171
+ | `integrated_answer` | `str \| None` | `reintegrate_subclaims()` が生成した再統合済み回答。パラメータ変更や新規推論時に `None` にリセット |
172
+ | `integrated_alpha` | `float \| None` | `integrated_answer` 生成時の $\alpha$ 値。現在のスライダー値と異なれば警告を表示するために保持 |
173
+
174
+ #### B-2. データ型定義
175
+
176
+ `demo/inference_api.py` で定義する型。`app.py` との境界を明確にするため `TypedDict` で宣言する。
177
+
178
+ ```python
179
+ from typing import TypedDict
180
+ from demo.constants import DEFAULT_SCORING_METHOD # constants.py を単一の定義元とする
181
+
182
+ class Subclaim(TypedDict):
183
+ id: int
184
+ text: str # サブクレームのテキスト
185
+ score: float # 計算されたスコア値(手法名は SubclaimResult.scoring_method を参照)
186
+
187
+ class SubclaimResult(TypedDict):
188
+ query: str
189
+ dataset: str # 例: "medlf_qa"
190
+ mode: str # 例: "marginal" / "conditional"。thresholds.csv の mode 列に対応
191
+ group: str # 例: "MedicationQA"。Marginal モードまたは非グループデータセットは "default"
192
+ scoring_method: str # 例: "relevance"。apply_threshold() での閾値ルックアップキーに使用
193
+ rag_answer: str
194
+ retrieved_docs: list[str]
195
+ subclaims: list[Subclaim]
196
+
197
+ class FilteredResult(TypedDict):
198
+ subclaims: list[Subclaim]
199
+ q_hat: float
200
+ keep_count: int
201
+ remove_count: int
202
+ ```
203
+
204
+ - `Subclaim.score` は手法を問わず汎用名とし、何のスコアかは `SubclaimResult.scoring_method` で示す
205
+ - `SubclaimResult.group` は CSV の `group` 列と対応。Marginal モードや非グループデータセットでは `"default"` を使い、`None` は使わない
206
+ - `status`(Keep/Remove)は UI 側でスコアと `q_hat` を比較して動的に判定するため、バックエンドからは返さない
207
+
208
+ #### 共有定数 — `demo/constants.py`
209
+
210
+ `precompute.py`(CSV 生成)と `app.py`(スライダー)の両方が参照する定数をここに一元管理する。スライダーの選択値と CSV の `alpha` 列が一致しないと lookup が失敗するため、両者は必ずこのファイルを `import` して使う。
211
+
212
+ ```python
213
+ # demo/constants.py
214
+
215
+ # スコアリング手法のデフォルト値。変更時はここのみ修正し、precompute.py を再実行する
216
+ DEFAULT_SCORING_METHOD = "relevance"
217
+
218
+ # クリーンな手法名 → main.py の出力データ内のキー名 のマッピング
219
+ # "relavance" は元コードベースのタイポ。データファイルとの互換性のため保持
220
+ SCORING_METHOD_TO_DATA_KEY: dict[str, str] = {
221
+ "relevance": "relavance",
222
+ "cosine_similarity": "query_claim_cosine_similarity",
223
+ "min_log_prob": "min_log_prob",
224
+ }
225
+
226
+ # Target Factuality スライダーの刻み(alpha = 1 - target_factuality)
227
+ # precompute.py はこのリストの alpha 値のみ計算する
228
+ # app.py の st.select_slider はこのリストをそのまま使う
229
+ ALPHA_LEVELS: list[float] = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40]
230
+ # → Target Factuality: 95%, 90%, 85%, 80%, 75%, 70%, 65%, 60%
231
+ ```
232
+
233
+ ---
234
+
235
+ ### C. 推論パイプライン — 重い処理と軽い処理の分離
236
+
237
+ **重い処理① — 「推論実行」ボタンのクリック時のみ起動**
238
+
239
+ ```
240
+ 1. クエリのベクトル化 & FAISS 検索 ← src/common/faiss_manager.py
241
+ 2. RAG による初期回答(ŷ)の生成 ← src/common/llm/openai_rag_agent.py
242
+ 3. OpenAI API によるサブクレームへの分解 ← src/common/llm/openai_atomicfact_generator.py
243
+ 4. サブクレームごとのスコア計算 ← src/subclaim_processor/scorer/
244
+ 5. 結果を st.session_state に格納 → st.rerun() でスライダーを即時有効化
245
+ ```
246
+
247
+ - `st.spinner` でユーザーに待機(30〜60 秒)を明示する
248
+ - アノテーション(S/I/U/N ラベル付け)はデモの表示に不要なため**スキップ**する
249
+ - frequency スコア(`src/subclaim_processor/scorer/subclaim_scorer.py` の `frequency_score()`)は1サブクレームあたり LLM を 5 回呼び出すため、**デモでは除外する**。表示するスコアは relevance・cosine similarity・min_log_prob に限定する
250
+
251
+ **重い処理② — 「回答を生成」ボタンのクリック時のみ起動**
252
+
253
+ ```
254
+ 1. 閾値未満のサブクレームを removed_subclaims として収集
255
+ 2. reintegrate_subclaims(rag_answer, removed_subclaims) を呼び出し
256
+ → OpenAILLMAgent 経由で元の回答から除去対象サブク��ームを削除した自然文を生成
257
+ 3. 結果を st.session_state.integrated_answer に格納 → st.rerun()
258
+ ```
259
+
260
+ **軽い処理(リアルタイム実行)— スライダー操作のたびに実行**
261
+
262
+ ```
263
+ 1. スライダーの alpha 値に対応する q̂ を事前計算済み CSV から取得
264
+ 2. st.session_state.result のスコアと q̂ を比較
265
+ 3. Keep/Remove の HTML タグを生成して st.markdown で再描画
266
+ <span style="padding:2px 5px;">サブクレーム</span> ← Keep(装飾なし)
267
+ <span style="text-decoration:line-through; color:#dc3545;">サブクレーム</span> ← Remove
268
+ ```
269
+
270
+ API 呼び出しは一切行わない。
271
+
272
+ ### D. リソースキャッシュ(`@st.cache_resource`)
273
+
274
+ FAISS インデックスはデータセットごとに数百 MB あるため、データセット切り替えのたびにロードし直すと数十秒かかる。`@st.cache_resource` でデータセット名をキーにキャッシュし、初回のみロードする。
275
+
276
+ ```python
277
+ @st.cache_resource
278
+ def get_faiss_manager(dataset: str) -> FAISSIndexManager:
279
+ return build_faiss_manager(dataset)
280
+
281
+ @st.cache_resource
282
+ def get_scorer(dataset: str) -> SubclaimScorer:
283
+ return build_scorer(dataset)
284
+ ```
285
+
286
+ `thresholds.csv` と `samples.json` の静的ファイルは `@st.cache_data` でキャッシュする(`@st.cache_resource` はオブジェクト向け。DataFrameや辞書は `@st.cache_data`)。
287
+
288
+ > **実装上の注意(HF Spaces 対応)**: `FAISSIndexManager` / `SubclaimScorer` は
289
+ > `langchain_text_splitters` を経由して `torch` / `transformers` 等の重い依存を引き込む。
290
+ > HF Spaces など FAISS 不要な環境での起動コストを避けるため、これらのクラスは
291
+ > `inference_api.py` のモジュールレベルではなく `build_faiss_manager()` / `build_scorer()`
292
+ > 関数内で遅延 import している。型ヒントは `TYPE_CHECKING` ガードで維持。
293
+
294
+ ---
295
+
296
+ ## 4. 開発ステップ
297
+
298
+ ### ① `demo/constants.py`(共有定数)
299
+
300
+ 他モジュールが依存するため最初に作成する。`ALPHA_LEVELS` と `DEFAULT_SCORING_METHOD` のみ定義。
301
+
302
+ ### ② `demo/inference_api.py`(推論ラッパー)
303
+
304
+ `precompute.py` がこれを呼び出すため、先に実装する。
305
+
306
+ - `src/` モジュールを `import` するのみで、コアロジックは変更しない
307
+ - インターフェース:
308
+ - `process_query(query: str, dataset: str, mode: str, group: str) -> SubclaimResult`(重い処理①)
309
+ - `reintegrate_subclaims(rag_answer: str, removed_subclaims: list[str], model: str) -> str`(重い処理②)
310
+ - `apply_threshold(result: SubclaimResult, q_hat: float) -> FilteredResult`(軽い処理)
311
+ - 使用するスコアリング手法は `DEFAULT_SCORING_METHOD` 定数で管理し、ハードコーディングしない
312
+
313
+ ### ③ `demo/precompute.py` + データ生成(オフライン)
314
+
315
+ `inference_api` が完成してから実行する。
316
+
317
+ - 既存の `main.py` を実行し、デモ用データセットの FAISS インデックスを作成(未作成の場合)
318
+ - `demo/precompute.py` を実装・実行し、以下を生成:
319
+ 1. $\alpha$ ごとのキャリブレーション閾値マスタ(CSV: `demo/data/thresholds.csv`)
320
+ 2. サンプルクエリ分の `SubclaimResult` キャッシュ(`demo/data/samples.json`、生成日時・モデル名のメタデータ付き)
321
+
322
+ ### ④ `demo/app.py`(Streamlit UI)
323
+
324
+ 上記 3 つが揃ってから実装する。
325
+
326
+ - ボタン・スライダー・`st.session_state` を連動させる
327
+ - MedLFQA 以外のデータセット選択時は Conditional モードおよびグループ選択を非活性にする
docs/context/04_environment_setup.md ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 環境構築・実行手順
2
+
3
+ 実行環境は **Singularity コンテナ経由** と **ホスト直接実行** の2通りをサポートする。
4
+
5
+ ---
6
+
7
+ ## 1. 前提条件
8
+
9
+ | ツール | バージョン要件 | 用途 |
10
+ |--------|-------------|------|
11
+ | Git | 任意 | リポジトリのクローン |
12
+ | Singularity (SingularityCE) | 3.x 以上 | コンテナ実行・SIF ビルド(コンテナ経由の場合) |
13
+ | `fakeroot` | — | SIF ビルド時に必要(`singularity build --fakeroot`) |
14
+ | Python 3.11 + `requirements-dev.txt` | — | ホスト直接実行の場合のみ |
15
+
16
+ > **推奨**: HPC 環境など Python 環境の管理が難しい場合は Singularity コンテナ経由を使う。ローカル開発では直接実行でも動作する。
17
+
18
+ ---
19
+
20
+ ## 2. ホスト側ディレクトリ構成
21
+
22
+ ホストの任意のディレクトリ(`DATA_ROOT` の親)に以下の構造を作成する。
23
+ 重いファイル(FAISS インデックス、生データ、HuggingFace キャッシュ)はすべてここに集約し、Singularity の `--bind` でコンテナ内にマウントする。
24
+
25
+ ```
26
+ /mnt/data/<username>/
27
+ ├── sif/
28
+ │ └── response_quality.sif ← SIF_DIR
29
+ ├── hf_cache/ ← HF_HOME
30
+ └── ResponseQualityAssessment/ ← DATA_ROOT
31
+ ├── data/
32
+ │ ├── raw/
33
+ │ │ └── WikiDB/ ← Wikipedia SQLite DB
34
+ │ ├── processed/
35
+ │ ├── out/ ← main.py 出力(サブクレーム JSON)
36
+ │ └── result/
37
+ └── index_store/
38
+ ├── FactScore/ ← FAISS インデックス
39
+ ├── HotpotQA/
40
+ ├── PopQA/
41
+ └── MedLFQA/
42
+ ```
43
+
44
+ > `demo/data/`(thresholds.csv, samples.json)はリポジトリ内に含まれるため `DATA_ROOT` 側への配置・バインドは不要。
45
+
46
+ ---
47
+
48
+ ## 3. `.env` の設定
49
+
50
+ `.env.example` をコピーして `.env` を作成し、環境に合わせて値を設定する。
51
+
52
+ ```bash
53
+ cp .env.example .env
54
+ ```
55
+
56
+ `.env` の内容(`.gitignore` 対象):
57
+
58
+ ```bash
59
+ # Machine-specific absolute paths
60
+ DATA_ROOT=/mnt/data/<username>/ResponseQualityAssessment
61
+ SIF_DIR=/mnt/data/<username>/sif
62
+ HF_HOME=/mnt/data/<username>/hf_cache
63
+
64
+ # API keys
65
+ OPENAI_API_KEY=sk-...
66
+ ```
67
+
68
+ `conf/path_config.yaml` と `conf/dataset_config.yaml` の各パスは `${DATA_ROOT}` を起点として定義されており、実行時にこの環境変数を展開して解釈する。
69
+
70
+ > - **ホスト直接実行時**: `.env` の `DATA_ROOT`(例: `/mnt/data/<username>/ResponseQualityAssessment`)をそのまま使用する。
71
+ > - **コンテナ経由実行時**: `run_*.sh` スクリプトがホストの `data/` と `index_store/` を `REPO_ROOT` 以下にバインドマウントするため、コンテナ内での `DATA_ROOT` は `REPO_ROOT` に上書きされる。スクリプト内の `--env DATA_ROOT="${REPO_ROOT}"` がこれを担っている。
72
+
73
+ ---
74
+
75
+ ## 4. Singularity コンテナのビルド(初回のみ)
76
+
77
+ ```bash
78
+ bash scripts/build_sif.sh
79
+ ```
80
+
81
+ - `singularity/response_quality.def`(`docker://python:3.11-slim` ベース)から SIF をビルドする
82
+ - 出力先: `${SIF_DIR}/response_quality.sif`
83
+ - ビルドには数分かかる。`fakeroot` が必要
84
+ - `torch` はコンテナ内では CPU 専用 wheel(`--index-url https://download.pytorch.org/whl/cpu`)でインストールされる。ローカル開発で `requirements-dev.txt` から直接インストールする場合は同オプションを手動で指定すること
85
+
86
+ **requirements ファイルの構成:**
87
+
88
+ | ファイル | 用途 |
89
+ |----------|------|
90
+ | `requirements.txt` | HF Spaces 向け最小セット(`torch` 等の重い依存を除く)。Spaces はこのファイルを自動で読む |
91
+ | `requirements-dev.txt` | ローカル・Singularity 向け全依存(`-r requirements.txt` で共通部分を継承) |
92
+
93
+ ホスト直接実行時のインストール:
94
+ ```bash
95
+ pip install -r requirements-dev.txt
96
+ ```
97
+
98
+ ---
99
+
100
+ ## 5. データ準備(初回のみ)
101
+
102
+ | 方法 | 用途 | API 消費 |
103
+ |--------|------|---------|
104
+ | **A. 著者提供データを使う** | 生データ・FAISS インデックスを再利用して API 呼び出しを削減 | Embeddings 不要(FAISS スキップ)、Chat は必要 |
105
+ | **B. 自前でフルパイプライン実行** | 新モデルや新データセットで結果を再現 | あり(Chat + Embeddings) |
106
+
107
+ ---
108
+
109
+ ### A. 著者提供データを使う(FAISS 構築スキップ)
110
+
111
+ 著者提供の Google Drive フォルダには以下が含まれる(**注**: 事前計算済みサブクレーム JSON は含まれていない):
112
+
113
+ - `data.zip` — 生クエリデータ(`HotpotQA/raw_hotpot_qa.json`、`PopQA/raw_pop_qa.json`、`FactScore/factscore_names.txt`、`MedLFQA/*.jsonl`)
114
+ - `index_store.zip` — 事前構築済み FAISS インデックス(各データセット × 各 query_size 分の `.faiss` と `indice2fm_*.json`)
115
+
116
+ FAISS インデックスが揃っていれば Embeddings API(インデックス構築)はスキップされる。ただし、RAG 回答生成・サブクレーム抽出・スコアリング・アノテーションには Chat API が必要。
117
+
118
+ **1. Wikipedia DB をダウンロードして配置する**
119
+
120
+ 全データセット共通で必要(QueryProcessor が常に参照する)。
121
+
122
+ > [enwiki-20230401.db(Google Drive)](https://drive.google.com/file/d/1mekls6OGOKLmt7gYtHs0WGf5oTamTNat/view?usp=drive_link)
123
+
124
+ ```
125
+ DATA_ROOT/data/raw/WikiDB/enwiki-20230401.db
126
+ ```
127
+
128
+ **2. `medlf_qa` を使う場合は生データを手動取得する**
129
+
130
+ `fact_score` / `hotpot_qa` / `pop_qa` は DataLoader が HuggingFace から自動取得するため不要。
131
+
132
+ ```bash
133
+ source .env
134
+ # 以下どちらか一方
135
+ git clone https://github.com/dmis-lab/OLAPH.git /tmp/olaph
136
+ cp -r /tmp/olaph/MedLFQA ${DATA_ROOT}/data/raw/MedLFQA
137
+ # または
138
+ git clone https://github.com/jjcherian/conformal-safety.git /tmp/cs
139
+ cp -r /tmp/cs/data/MedLFQAv2 ${DATA_ROOT}/data/raw/MedLFQA
140
+ ```
141
+
142
+ **3. 著者提供データをダウンロードして展開する**
143
+
144
+ > https://drive.google.com/drive/folders/1aLbHxS6V1ipMH8FpVCxKmr8oMYfqmRgb
145
+
146
+ ```bash
147
+ source .env
148
+ # gdown でフォルダをダウンロード("conformal prediction with RAG/" という名前で保存される)
149
+ gdown --folder https://drive.google.com/drive/folders/1aLbHxS6V1ipMH8FpVCxKmr8oMYfqmRgb \
150
+ -O "${DATA_ROOT}/data/out/"
151
+
152
+ # 生データを data/raw/ に展開
153
+ unzip "${DATA_ROOT}/data/out/conformal prediction with RAG/data.zip" \
154
+ -d "${DATA_ROOT}/data/raw/"
155
+
156
+ # FAISS インデックスを DATA_ROOT 直下に展開(index_store/ が作られる)
157
+ unzip "${DATA_ROOT}/data/out/conformal prediction with RAG/index_store.zip" \
158
+ -d "${DATA_ROOT}/"
159
+ ```
160
+
161
+ 展開後の構造:
162
+
163
+ ```
164
+ DATA_ROOT/
165
+ ├── data/
166
+ │ └── raw/
167
+ │ ├── HotpotQA/raw_hotpot_qa.json
168
+ │ ├── PopQA/raw_pop_qa.json
169
+ │ ├── FactScore/factscore_names.txt
170
+ │ └── MedLFQA/*.jsonl
171
+ └── index_store/
172
+ ├── FactScore/index_500.faiss, indice2fm_500.json, ...
173
+ ├── HotpotQA/index_500.faiss, indice2fm_500.json
174
+ └── MedLFQA/index_10.faiss, index_500.faiss, indice2fm_*.json, ...
175
+ ```
176
+
177
+ **4. 実行する**
178
+
179
+ FAISS インデックスが存在する場合、インデックス構築(Embeddings API)はスキップされる。RAG 推論・スコアリング・アノテーションには Chat API が必要。
180
+
181
+ ```bash
182
+ # コンテナ経由
183
+ bash scripts/run_main.sh --dataset medlf_qa
184
+
185
+ # ホスト直接実行
186
+ source .env && python main.py --dataset medlf_qa
187
+ ```
188
+
189
+ ---
190
+
191
+ ### B. 自前でフルパイプライン実行(API あり)
192
+
193
+ **1. Wikipedia DB をダウンロードして配置する**
194
+
195
+ 上記 A の手順 1 と同じ。
196
+
197
+ **2. `medlf_qa` を使う場合は生データを手動取得する**
198
+
199
+ 上記 A の手順 2 と同じ。`fact_score` / `hotpot_qa` / `pop_qa` は不要。
200
+
201
+ **3. 実行する**
202
+
203
+ データセットごとに1つずつ実行する(並列不可)。
204
+
205
+ ```bash
206
+ # コンテナ経由
207
+ bash scripts/run_main.sh --dataset medlf_qa
208
+ # 利用可能: fact_score, hotpot_qa, pop_qa, medlf_qa
209
+
210
+ # ホスト直接実行
211
+ source .env && python main.py --dataset medlf_qa
212
+ ```
213
+
214
+ ---
215
+
216
+ 実行後に生成されるファイル(`DATA_ROOT` 以下):
217
+
218
+ | パス | 内容 |
219
+ |------|------|
220
+ | `data/out/{DatasetName}/` | サブクレームスコア付き JSON |
221
+ | `index_store/{DatasetName}/` | FAISS インデックス (`index_*.faiss`, `indice2fm_*.json`) |
222
+
223
+ ---
224
+
225
+ ### CLI オプション一覧
226
+
227
+ | オプション | デフォルト | 説明 |
228
+ |---|---|---|
229
+ | `--dataset` | config 参照 | データセット名(`fact_score` / `hotpot_qa` / `pop_qa` / `medlf_qa`) |
230
+ | `--query_size` | 500 | 処理するクエリ件数 |
231
+ | `--run_id` | タイムスタンプ | ログ・結果ディレクトリの識別子 |
232
+ | `--lite` | false | `frequency`・`doc_claim_cosine_similarity` スコアと conformal グラフ生成をスキップ |
233
+ | `--log_level` | `INFO` | ログレベル(`DEBUG` / `INFO` / `WARNING` / `ERROR`) |
234
+
235
+ ### API コールのログ確認
236
+
237
+ `--log_level DEBUG` を付けると、OpenAI API の個別コール(purpose・model・トークン数)が `[API:Chat]` / `[API:Embeddings]` プレフィックスで記録される。
238
+
239
+ ```bash
240
+ # API ログを有効にして実行(コンテナ経由)
241
+ bash scripts/run_main.sh --dataset medlf_qa --query_size 10 --lite --log_level DEBUG
242
+
243
+ # ログから API コールのみ抽出
244
+ grep "\[API:" logs/<run_id>/run_<run_id>.log
245
+
246
+ # purpose 別の集計
247
+ grep "\[API:" logs/<run_id>/run_<run_id>.log \
248
+ | sed 's/.*\[API:\([^]]*\)\] purpose=\([^ ]*\).*/\1 \2/' \
249
+ | sort | uniq -c | sort -rn
250
+ ```
251
+
252
+ 出力例:
253
+ ```
254
+ 103 Embeddings cosine_similarity(claim)
255
+ 103 Embeddings cosine_similarity(query)
256
+ 103 Embeddings relevance_score
257
+ 103 Chat claim_verification
258
+ 10 Chat subclaim_extraction
259
+ 10 Chat rag_response
260
+ ```
261
+
262
+ ---
263
+
264
+ ## 6. 事前計算(データ更新・モデル変更時)
265
+
266
+ `demo/data/thresholds.csv` と `demo/data/samples.json` を(再)生成する。
267
+ モデル変更やデータセット追加後に実行する。
268
+
269
+ **コンテナ経由:**
270
+ ```bash
271
+ bash scripts/run_precompute.sh
272
+ ```
273
+
274
+ **ホスト直接実行:**
275
+ ```bash
276
+ source .env
277
+ python -m demo.precompute
278
+ ```
279
+
280
+ 出力はリポジトリ内の `demo/data/` に書き込まれる。
281
+ 生成後は `git add demo/data/ && git commit` でコミットする。
282
+
283
+ ---
284
+
285
+ ## 7. デモの起動
286
+
287
+ **コンテナ経由:**
288
+ ```bash
289
+ bash scripts/run_demo.sh # デフォルト: ポート 8502
290
+ bash scripts/run_demo.sh --port 8503 # ポートを指定する場合
291
+ ```
292
+
293
+ **ホスト直接実行:**
294
+ ```bash
295
+ source .env
296
+ streamlit run demo/app.py --server.port 8502
297
+ ```
298
+
299
+ - デフォルトポート: **8502**(8501 はサーバー上の別アプリが使用中のため)
300
+ - ポートが使用中の場合は `--port` オプションで別ポートを指定する
301
+
302
+ **Singularity バインド構成(スクリプト内部):**
303
+
304
+ ```bash
305
+ singularity run \
306
+ --bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
307
+ --bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
308
+ --bind "${HF_HOME}:${HF_HOME}" \
309
+ --env HF_HOME="${HF_HOME}" \
310
+ --env OPENAI_API_KEY="${OPENAI_API_KEY}" \
311
+ --env DATA_ROOT="${REPO_ROOT}" \
312
+ --env PYTHONPATH="${REPO_ROOT}" \
313
+ --pwd "${REPO_ROOT}" \
314
+ "${SIF_DIR}/response_quality.sif" \
315
+ streamlit run demo/app.py --server.port 8502
316
+ ```
317
+
318
+ (`REPO_ROOT` はスクリプト自身の場所から自動解決される。コンテナ内では `DATA_ROOT` を `REPO_ROOT` に上書きすることで、バインドマウント先のパスと整合させている。)
319
+
320
+ `demo/data/` はリポジトリ内に存在するため、バインド不要。
321
+
322
+ ---
323
+
324
+ ## 8. トラブルシューティング
325
+
326
+ ### `.env` が読み込まれない
327
+
328
+ 各スクリプトはリポジトリルートの `.env` を自動的に探して読み込む(`REPO_ROOT` をスクリプトの場所から導出するため、どのディレクトリから実行しても動作する)。`.env` ファイル自体が存在するか確認する。
329
+
330
+ ```bash
331
+ ls /path/to/ResponseQualityAssessment/.env
332
+ ```
333
+
334
+ ### `DATA_ROOT` 以下のディレクトリが存在しない
335
+
336
+ `data/out/` や `index_store/` は `main.py` 実行時に自動作成されるが、
337
+ `DATA_ROOT` 自体と `DATA_ROOT/data/`, `DATA_ROOT/index_store/` は事前に作成が必要。
338
+
339
+ ```bash
340
+ source .env
341
+ mkdir -p ${DATA_ROOT}/data/raw/WikiDB \
342
+ ${DATA_ROOT}/data/processed \
343
+ ${DATA_ROOT}/data/out \
344
+ ${DATA_ROOT}/data/result \
345
+ ${DATA_ROOT}/index_store \
346
+ ${HF_HOME} \
347
+ ${SIF_DIR}
348
+ ```
349
+
350
+ ### FAISS インデックスが見つからない(デモ起動時にエラー)
351
+
352
+ `scripts/run_main.sh` が未実行、または対象データセットが処理されていない。
353
+ §5「データ準備」を実行する。
354
+
355
+ ### `demo/data/thresholds.csv` が見つからない
356
+
357
+ `scripts/run_precompute.sh` が未実行。§6「事前計算」を実行する。
358
+ または、リポジトリの `demo/data/` にコミット済みファイルがあるか確認する(`git pull` で最新化)。
359
+
360
+ ### `subclaims_schema.json` が見つからない(コンテナ経由実行時)
361
+
362
+ コンテナ実行時は `DATA_ROOT/data/` が `REPO_ROOT/data/` にバインドマウントされるため、リポジトリ内の `data/out/subclaims_schema.json` が隠れてしまう。初回実行前に手動でコピーする。
363
+
364
+ ```bash
365
+ source .env
366
+ cp data/out/subclaims_schema.json ${DATA_ROOT}/data/out/subclaims_schema.json
367
+ ```
368
+
369
+ ### OpenAI API エラー
370
+
371
+ `.env` の `OPENAI_API_KEY` が正しく設定されているか確認する。
372
+ `#` で始まる行はコメントとして無視されるため、`OPENAI_API_KEY=sk-...` のように `#` なしで記述する。
373
+ 推論実行ボタンを押した場合のみ API が呼ばれる。サンプルクエリはボタンを押さずにプルダウン選択するだけのため API 不要。
docs/context/05_medrag_corpus_integration.md ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MedRAGコーパス導入・システム拡張要件書
2
+
3
+ ## 1. プロジェクト概要
4
+
5
+ ### 1.1. 目的
6
+
7
+ 現状のRAG回答品質評価システム(Conformal-RAG)は、MedLFQAデータセット評価において、データセット内にあらかじめ用意された参照文書(JSONL)から検索インデックスを構築している。本プロジェクトでは、この閉じた検証環境を脱却し、外部の大規模医療特化コーパス(MedRAG)を統合することで、実運用に近いノイズ環境下でのRAGファクチュアリティ(事実性)評価基盤を構築する。
8
+
9
+ ### 1.2. 課題と解決方針
10
+
11
+ 大規模コーパスを導入すると、検索時のディストラクター(意味的に類似するが無関係なノイズ文書)が激増し、コンフォーマル予測のスコアリングが破綻するリスクがある。これに対し、以下の3点をシステム要件として定義する。
12
+
13
+ - 医療特化の埋め込みモデルへの切り替え
14
+ - ハイブリッド検索・リランカーの導入
15
+ - チャンク戦略の最適化
16
+
17
+ ---
18
+
19
+ ## 2. システム構成・データパイプライン要件
20
+
21
+ ### 2.1. コーパスデータベースの構築
22
+
23
+ - **要件**: MedRAGが提供する生データ(PubMedアブストラクト、StatPearls等)をシステムに取り込む。
24
+ - **MedRAGコーパスのデータ形式**: MedRAGは以下のコーパスをHuggingFace(`MedRAG/pubmed` 等)でチャンク済みデータセットとして公開している。スキーマは全コーパス共通の4フィールド。
25
+
26
+ | コーパス | スニペット数 | 平均トークン数 | チャンク戦略 | 備考 |
27
+ |---------|------------|------------|------------|------|
28
+ | PubMed | 約2,390万件 | **296 tokens** | 論文1件=1スニペット(title + abstract) | HuggingFaceで配布済み |
29
+ | StatPearls | 約30万件 | 119 tokens | 段落単位 | Privacy Policyによりコンテンツ非公開。自前処理必要 |
30
+ | Textbooks | 約13万件 | 182 tokens | `RecursiveCharacterTextSplitter`(上限1,000文字) | HuggingFaceで配布済み |
31
+
32
+ ```json
33
+ {"id": "pubmed23n0001_0", "title": "...", "content": "...", "contents": "title + content の結合(BM25用)"}
34
+ ```
35
+
36
+ - **実装方針**: HuggingFaceから配布済みのParquet/JSONL形式データをロードし、`DocDB.build_db()`(`src/rag/retrieval.py`)で`title`・`content`フィールドをSQLiteに格納する。`DataLoader.create_wiki_db()` と同様の仕組みを活用できる。
37
+
38
+ ### 2.2. チャンキング戦略の方針
39
+
40
+ - **現状の課題**: `FixedLengthChunker`(`src/common/chunker.py`)がデフォルトで2000語(オーバーラップ25語)と大きく設定されているため、情報の境界が曖昧になりやすい。設定は `conf/config.yaml` の `index.truncation_config` で管理されている。
41
+ - **コーパス別の対応方針**:
42
+ - **PubMed**: チャンク済みデータ(平均296トークン)をそのまま使用。MedCPT Article Encoderの最大入力長512トークンの範囲内に収まっており、**追加のチャンキングは不要**。MedCPT使用時は `[title, content]` をリスト形式で渡す必要がある(文字列連結不可)。
43
+ - **StatPearls / Textbooks**: 自前チャンキングが必要な場合は `RecursiveCharacterChunker` を `FileManager` に追加し、BPEトークンベースで256〜512トークン程度を目標とする。`conf/config.yaml` の `strategy` キーで選択可能にする。
44
+
45
+ ---
46
+
47
+ ## 3. 検索アルゴリズム要件(Retrieval & Embedding)
48
+
49
+ ### 3.1. 医療特化Embeddingモデルの導入
50
+
51
+ - **現状の課題**: `conf/config.yaml` の `index.embedding_model` に `text-embedding-3-large`(汎用モデル)が固定で指定されており、`OpenAIManager.create_openai_embeddings()` を通じて埋め込みを生成している。
52
+ - **要件**: 医療専門用語や略語を正確にベクトル空間へマッピングする。
53
+ - **実装方針**: `OpenAIManager` に依存しないローカルモデル用のEmbeddingクラスを新設し(例: `LocalEmbeddingManager`)、以下のモデルをロードして埋め込みを生成できるようにする。`conf/config.yaml` の `embedding_model` キーで切り替え可能な設計とする。
54
+ - **`MedCPT`**(`ncbi/MedCPT-Query-Encoder` / `ncbi/MedCPT-Article-Encoder`): NCBIが255MのPubMed検索ログでContrastive学習したbi-encoderであり、RAGの密ベクトル検索に直接適用できる。MedRAGベンチマークでも高い性能を示しており、第一候補。
55
+ - **`MedEmbed`**(`abhinand/MedEmbed`シリーズ): 汎用モデルを医療データでContrastive学習させたモデル。MTEBの医療タスクで競争力のある性能を示す代替選択肢。
56
+ - > **注意**: `PubMedBERT` や `BioLinkBERT` はContrastive学習を経ていないMLMであり、fine-tuneなしでのbi-encoderとしての利用は非推奨。
57
+
58
+ ### 3.2. ��イブリッド検索(Hybrid Search)の実装
59
+
60
+ - **現状の課題**: `FAISSIndexManager`(`src/common/faiss_manager.py`)による密ベクトル(Dense)検索のみが実装されており、薬品名・疾患名などの固有表現の完全一致を取りこぼすケースがある。
61
+ - **要件**: 固有表現の完全一致を取りこぼさない検索機構の構築。
62
+ - **実装方針**: `FAISSIndexManager` に加えて疎ベクトル(Sparse)検索エンジンを導入し、両者のスコアをReciprocal Rank Fusion(RRF)で統合して上位文書を取得する仕組みを追加する。
63
+ - **Sparse検索の選択肢**: BM25は実装が容易で安定したベースラインとなる。ただし医療・科学文書では専門用語の語彙ミスマッチが起きやすいため、意味的な語彙展開が可能な **SPLADE** も代替として検討する。
64
+ - **スコア統合**: RRFは訓練データなしで使えるロバストな手法であり第一選択。ラベル付きデータが入手できる場合は重み付きRRFや凸結合(Convex Combination)でさらなる精度改善が見込める。
65
+
66
+ ### 3.3. リランカー(Re-ranker)の導入
67
+
68
+ - **要件**: 大規模コーパス特有のノイズ文書(ディストラクター)を排除する。
69
+ - **実装方針**: `SubclaimProcessor.generate_responses()`(`src/subclaim_processor/subclaim_processor.py`)において、ハイブリッド検索で上位100件程度を粗く取得したのち、Cross-Encoderで精緻な関連度再計算を行い、最終的に上位10件に絞り込む処理を追加する。
70
+ - **推奨モデル(セルフホスト)**: `BAAI/bge-reranker-v2-m3`(多言語・汎用)または `BAAI/bge-reranker-v2.5-gemma2-lightweight`(2025年時点のSOTA、トークン圧縮機能付き)
71
+ - **推奨モデル(商用API)**: Cohere Rerank(現行v4)。精度は高いがセルフホスト不可のため、運用コストと可搬性を考慮して選択する。
72
+
73
+ ---
74
+
75
+ ## 4. 評価・スコアリング要件
76
+
77
+ ### 4.1. スコアリング戦略(Strategy)の改修
78
+
79
+ - **現状の課題**: `ProductScoreStrategy`(`src/subclaim_processor/strategies/scoring/product.py`)がFAISSの検索スコアとサブクレームのコサイン類似度の積でスコアを算出している。ハイブリッド検索やリランカーを導入するとベーススコアのスケールが変化するため、この計算式が破綻する。
80
+ - **要件**: 検索手法に依存しない普遍的なサブクレーム関連度スコアの定義。
81
+ - **実装方針**: `ScoringStrategy`(`src/subclaim_processor/strategies/scoring/base.py`)を継承した新たな戦略クラスを追加し、Cross-Encoderが出力する「クエリ+サブクレーム」と「検索文書」の推論ロジットを直接スコアとして採用する。`conf/config.yaml` の `conformal_prediction.scoring_strategy` キーで切り替え可能にする。
82
+
83
+ ---
84
+
85
+ ## 5. 非機能要件(パフォーマンス・拡張性)
86
+
87
+ ### 5.1. 非同期処理・バッチ処理の導入
88
+
89
+ - **現状の課題**: RAG応答生成からサブクレーム抽出・アノテーションに至るまで、パイプライン全体が逐次処理で実装されており、大量クエリ評価時のスループットが低い。OpenAI APIのレートリミット(429エラー)も問題になりうる。
90
+ - **要件**: 大規模コーパスおよび大量クエリ評価における実行時間の大幅な短縮とAPIレートリミット回避。
91
+ - **実装方針**: 以下のいずれか、または組み合わせを採用する。
92
+ - `asyncio` を用いた非同期パイプラインへの書き換え(リアルタイム性が求められる場合)
93
+ - OpenAI Batch APIを用いたアノテーションの一括処理(コスト重視の場合、レイテンシは最大24時間)
94
+
95
+ ### 5.2. インデックス管理の堅牢化
96
+
97
+ - **現状の課題**: `FAISSIndexManager` は `IndexFlatIP`(内積)を使用しており、ファイルパスとFAISSインデックス番号のリストを紐づけるマッピング `indice2fm`(`{file_path: [indices]}`、`index_store/indice2fm.json` に保存)をオンメモリで管理している。数百万規模のドキュメントではメモリ不足が発生しうる。
98
+ - **要件**: 大規模インデックスにおけるメモリ効率の確保と、インデックス更新の容易性。
99
+ - **実装方針**: インデックス構築をバッチ処理化する。必要に応じてディスクベースのベクトルDB(Milvus、Qdrant等)への移行を見据え、`FAISSIndexManager` のインターフェースを抽象化しておく。
100
+
101
+ ---
102
+
103
+ ## 6. ハイパーパラメータ調整・実験計画
104
+
105
+ ### 6.1. 調整対象ハイパーパラメータの一覧
106
+
107
+ パイプラインを構成する4つのステージごとにハイパーパラメータを整理する。各パラメータの現行値は `conf/config.yaml` に基づく。
108
+
109
+ #### ステージA: チャンキング(StatPearls / Textbooks のみ)
110
+
111
+ | パラメータ | 現行値 | 探索範囲 | 備考 |
112
+ |-----------|------|--------|------|
113
+ | `chunk_size` | 2000語 | 128 / 256 / **512** / 1024 トークン | PubMedは296トークン固定のため対象外 |
114
+ | `chunk_overlap` | 25語 | 0 / 32 / **64** / 128 トークン | chunkサイズの10〜20%が目安 |
115
+
116
+ #### ステージB: 検索(Retrieval)
117
+
118
+ | パラメータ | 現行値 | 探索範囲 | 備考 |
119
+ |-----------|------|--------|------|
120
+ | `retrival_topk` | 10 | 10 / 20 / **50** / 100 | リランカー導入後は候補プール用に拡大 |
121
+ | `retrival_threshold` | 0.3 | 0.1 / **0.2** / 0.3 | Denseスコア下限。低すぎるとノイズ増加 |
122
+ | BM25 `k1` | —(未実装) | 0.5 / **1.2** / 2.0 | 語頻度の飽和速度。医療文書は1.2前後が標準 |
123
+ | BM25 `b` | —(未実装) | 0.25 / **0.75** | 文書長正規化。長いアブストラクトには0.75推奨 |
124
+ | RRF `k` | —(未実装) | 10 / **60** | 標準値60。訓練データなしで使える頑健な設定 |
125
+ | RRF Dense重み | —(未実装) | 0.3 / **0.5** / 0.7 | Dense寄りにするほど意味検索優先 |
126
+
127
+ #### ステージC: リランキング
128
+
129
+ | パラメータ | 現行値 | 探索範囲 | 備考 |
130
+ |-----------|------|--------|------|
131
+ | リランカー候補プール数 | —(未実装) | 50 / **100** / 200 | 大きいほど再現率↑、遅延↑ |
132
+ | リランカー後 `topk` | —(未実装) | 5 / **10** / 20 | 最終的にLLMに渡す文書数 |
133
+
134
+ #### ステージD: スコアリング・集約(Conformal Prediction)
135
+
136
+ | パラメータ | 現行値 | 探索範囲 | 備考 |
137
+ |-----------|------|--------|------|
138
+ | `scoring_strategy` | `product` | `product` / `cross_encoder` | リランカー導入後は `cross_encoder` に切り替えを検討 |
139
+ | `aggregation_strategy` | `mean` | `mean` / `max` | `max`は最もサポートされた証拠を採用 |
140
+ | `a_value` | 1.0 | 0.8 / 0.9 / **1.0** | 「ファクチュアル」と見なす正解保持率の閾値 |
141
+
142
+ ---
143
+
144
+ ### 6.2. 評価指標
145
+
146
+ 実験には以下の指標を使用する。`SplitConformalCalibration`(`src/calibration/conformal.py`)が既に以下の計算を実装済み。
147
+
148
+ | 指標 | 定義 | 担当ステージ |
149
+ |-----|------|------------|
150
+ | **Recall@k** | 正解文書がtop-k内に含まれる割合 | B, C |
151
+ | **MRR@k** | Mean Reciprocal Rank(正解文書の順位逆数の平均) | B, C |
152
+ | **Coverage rate** | 実経験的ファクチュアリティ ≥ 1−α を満たすクエリの割合(≥ 1−α であれば保証成立) | D |
153
+ | **Efficiency** | 各αにおける平均サブクレーム除去率(高いほど効率的にノイズを除去できている) | D |
154
+
155
+ ---
156
+
157
+ ### 6.3. 実験フロー
158
+
159
+ 各ステージで最良の設定を固定したうえで次ステージの探索を行う、逐次的な**グリーディサーチ**を採用する。ステージ間の依存関係があるため、並列探索は同一ステージ内のみで行う。
160
+
161
+ ```
162
+ Phase 1: チャンキング探索(StatPearls / Textbooks のみ)
163
+ └─ 評価指標: Recall@10 on MedLFQA
164
+ └─ 固定: Embedding=MedCPT, Dense-only 検索
165
+ └─ 探索: chunk_size × chunk_overlap の格子探索(8条件)
166
+
167
+ Phase 2: ハイブリッド検索の探索
168
+ └─ 評価指標: Recall@10, MRR@10
169
+ └─ 固定: Phase 1 のベストチャンク設定
170
+ └─ 探索ステップ:
171
+ 2a. Dense-only ベースライン(現行)
172
+ 2b. BM25 k1 × b の格子探索(6条件)、Denseとの RRF 融合
173
+ 2c. SPLADE vs BM25 の比較(GPU使用可能な場合のみ)
174
+ 2d. RRF Dense重み の探索(0.3 / 0.5 / 0.7)
175
+
176
+ Phase 3: リランカーの探索
177
+ └─ 評価指標: Recall@10 after rerank, NDCG@10
178
+ └─ 固定: Phase 2 のベストハイブリッド設定
179
+ └─ 探索: 候補プール数 × 最終topk の格子探索(6条件)
180
+ └─ モデル比較: bge-reranker-v2-m3 vs bge-reranker-v2.5-gemma2-lightweight
181
+
182
+ Phase 4: Conformal Prediction の探索
183
+ └─ 評価指標: Coverage rate, Efficiency(fraction_removed)
184
+ └─ 固定: Phase 3 のベスト検索設定
185
+ └─ 探索: scoring_strategy × aggregation_strategy × a_value(12条件)
186
+ └─ alphaは 0.05〜0.45 を既定のステップ幅でスイープ(conformal.py の既存実装を使用)
187
+ ```
188
+
189
+ ---
190
+
191
+ ### 6.4. 実験管理上の注意点
192
+
193
+ - **PubMedのチャンク設定は固定**: PubMedスニペットは平均296トークンで配布されており、MedCPTの512トークン上限に収まる。Phase 1 の探索対象はStatPearls / Textbooks のみとする。
194
+ - **MedCPTの入力形式**: `[title, content]` をリスト要素として渡す必要があり、他のモデルとの入力フォーマットが異なる。`LocalEmbeddingManager` 実装時に対応すること。
195
+ - **キャッシュの活用**: `conformal.py` はスコア計算結果を `.npy` ファイルにキャッシュする仕組みを持つ��チャンクや検索設定を変えた場合はキャッシュを削除してから再実行すること。
196
+ - **Phase間の依存**: 各フェーズのベスト設定を `conf/config.yaml` に反映してから次フェーズを実行する。バージョン管理のため、フェーズごとに `git tag` を付けることを推奨する。
docs/context/06_test_strategy.md ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # テスト計画と実行フロー
2
+
3
+ OpenAI API への依存度に基づいてステップを分離し、API キーなしでも中核ロジックを早期に検証できる構成にする。
4
+
5
+ ## 全体像
6
+
7
+ | ステップ | 目的 | API 使用 | 使用データ | 対象コンポーネント |
8
+ | :--- | :--- | :---: | :--- | :--- |
9
+ | **1. ユニットテスト** | コアロジック・純粋関数の単体テスト | **不要** | 合成データ(手作り) | `src/calibration/`、`demo/inference_api.py` |
10
+ | **2. スモークテスト** | 実データの生成とパイプラインの疎通確認 | **必要** | 実データ(極小スケール) | `main.py`(データ取得〜スコア計算) |
11
+ | **3. 統合テスト** | 事前計算ロジックの動作確認 | **不要** | ステップ 2 の出力 JSON | `demo/precompute.py` |
12
+ | **4. UI テスト** | Streamlit UI と状態管理の動作確認 | **不要** | ステップ 3 の出力 CSV/JSON | `demo/app.py`、`st.session_state` |
13
+
14
+ ---
15
+
16
+ ## 1. ユニットテスト(API 不要)
17
+
18
+ 論文の中核である「Conformal Guarantee(事実性の統計的保証)」を計算するロジックと、デモの表示切り替えロジックを最優先でテストする。`pytest` で即座に実行できる。
19
+
20
+ 実装ファイル:
21
+ - `tests/test_calibration.py` — キャリブレーション計算関数
22
+ - `tests/test_precompute.py` — `demo/precompute.py` のデータ変換・サンプリング関数
23
+
24
+ **準備するデータ:** ダミーのサブクレームリスト(スコア `relevance: 0.8`、ラベル `S` などを持たせた数十件の Python 辞書)
25
+
26
+ **テスト対象と確認事項:**
27
+
28
+ 1. **`src/calibration/utils.py` — `get_r_score`**
29
+ - ダミーのサブクレーム群を与え、目標正確率 $a$ を満たすための最小スコア $r_a$ が数学的に正しく算出されるか。
30
+ - 全サブクレームが正解の場合に `-1`(常に安全)が返るか。
31
+
32
+ 2. **`src/calibration/utils.py` — `compute_threshold`**
33
+ - 特定の $\alpha$(例: `0.10`)を与えたとき、`ceil((n+1)*(1-alpha))` 分位点ベースの閾値 $\hat{q}$ が正しく計算されるか。
34
+
35
+ 3. **`src/calibration/utils.py` — `split_group`**
36
+ - グループ別に calibration/test が指定比率(デフォルト 50/50)で分割されるか。
37
+ - グループをまたいだデータ汚染(calibration と test のオーバーラップ)がないか。
38
+
39
+ 4. **`src/calibration/conformal.py` — `_evaluate_conformal_correctness`**
40
+ - 閾値を適用したとき、スコアが閾値以上のサブクレームのみ retain され、`correctly_retained` と `fraction_removed` が正しく計算されるか。
41
+ - サブクレームが全件 remove された場合(`retained_cnt == 0`)に `correctly_retained_percentage = 1` のフォールバックが効くか。
42
+
43
+ 5. **`demo/inference_api.py` — `apply_threshold`**
44
+ - モックの `SubclaimResult` と任意の $\hat{q}$ を渡し、スコアが $\hat{q}$ 以上のサブクレームが `keep_count` として正しく集計され、意図した `FilteredResult` が返るか。
45
+ - > **ノイズについて**: `calibration/utils.py` の `get_r_score` は `score + noise` の空間でスコアを集計し、`q_hat` もその空間で計算される。`noise` は Conformal Prediction のタイブレーカーとして `N(0, 0.001)` で付与される微小乱数(`subclaim_processor.py:150`)。`precompute.py` はサンプルクエリのスコアを `score + noise` で保存するため、`apply_threshold` の `score >= q_hat` 比較は理論的に整合している。ライブ推論(`process_query`)では noise を生成しないため、同一保証は成立しない点に注意。
46
+
47
+ 6. **`demo/precompute.py` — `_entry_to_subclaim_result`**(リグレッションテスト)
48
+ - `subclaims[i]["score"]` が `score + noise` の値になっているか(ノイズ加算の回帰防止)。
49
+ - `noise` キーが欠落した古いデータ形式でもクラッシュしないか(デフォルト `0.0`)。
50
+ - `is_grouped=False` のとき `mode="marginal"`、`group="default"` が設定されるか。
51
+ - `is_grouped=True` のとき `mode="conditional"`、`group` が `entry["groups"][0]` から取得されるか。
52
+
53
+ 7. **`demo/precompute.py` — `_group_data`、`_has_score`、`_pick_entries`**
54
+ - エントリが正しくグループ別に分類されるか。
55
+ - サブクレームに指定スコアキーがない場合に `_has_score` が `False` を返すか。
56
+ - `_pick_entries` がサブクレーム空のエントリをスキップし、指定件数(`N_SAMPLES_PER_DATASET` / `N_SAMPLES_PER_GROUP`)を返すか。
57
+
58
+ ---
59
+
60
+ ## 2. スモークテスト(API 必要)
61
+
62
+ 以降のテストには本物のデータ構造が必要なため、既存のバッチ処理を最小規模で走らせる。既存の `src/` ロジックが壊れていないことも同時に確認する。
63
+
64
+ **実行手順:**
65
+
66
+ ```bash
67
+ # フル��行(全スコア)
68
+ python main.py --dataset medlf_qa --query_size 10
69
+
70
+ # スモークテスト推奨(frequency・doc_claim_cosine_similarity をスキップ)
71
+ python main.py --dataset medlf_qa --query_size 10 --lite
72
+ ```
73
+
74
+ - データセットは `medlf_qa` を使用する(グループ構造を持つ最も複雑なケースのため)。
75
+ - `--query_size 10` で引数上書きするため、`conf/config.yaml` の編集は不要。
76
+ - `2-3` だと 50/50 split(`runs=1000`)に対してデータが少なすぎ、キャリブレーションが意味をなさない。統合テストの動作確認には最低 10 件程度が必要。
77
+
78
+ **API 呼び出し数の目安(`--query_size 10`、平均サブクレーム ~5/クエリ → 計 ~50 サブクレーム):**
79
+
80
+ | ステップ | API | 呼び出し数 |
81
+ |---------|-----|---------|
82
+ | FAISS インデックス構築 | Embeddings | **0**(`index_10.faiss` 既存の場合スキップ) |
83
+ | RAG 回答生成 | Chat (gpt-4.1-mini) | **10** |
84
+ | サブクレーム抽出 | Chat (gpt-4.1-mini) | **10** |
85
+ | `relavance` スコア(claim 埋め込み) | Embeddings | **~50** |
86
+ | `query_claim_cosine_similarity` | Embeddings | **~100**(2 コール × 50) |
87
+ | `doc_claim_cosine_similarity` | Embeddings | **~1,000**(2 × top_k=10 × 50) |
88
+ | `frequency` 代替回答生成(n=5 バッチ) | Chat (gpt-4.1-mini) | **~50 バッチ** |
89
+ | `frequency` カウントプロンプト(× 5) | Chat (gpt-4.1-mini) | **~250** |
90
+ | アノテーション(claim verification) | Chat (gpt-4.1-mini) | **~50** |
91
+ | **合計 Chat** | | **~370 コール** |
92
+ | **合計 Embeddings** | | **~1,150 コール** |
93
+
94
+ `index_store.zip`(著者提供)に `MedLFQA/index_10.faiss` が含まれているため、それを配置すれば Embeddings API によるインデックス構築はスキップできる。
95
+
96
+ `--lite` フラグを付けると以下のスコアをスキップし、API コストを大幅に削減できる。
97
+
98
+ | スキップされるスコア | 理由 |
99
+ |---|---|
100
+ | `frequency` | Chat ~300 コール(総コストの ~80%)、デモ未使用 |
101
+ | `doc_claim_cosine_similarity` | Embeddings ~1,000 コール(~87%)、デモ未使用 |
102
+
103
+ またキャリブレーション・グラフ生成(split conformal / group conditional conformal)もスキップされる。
104
+
105
+ **確認事項:**
106
+
107
+ 1. **E2E パイプラインの疎通**: データ取得・FAISS インデックス生成・LLM 推論・スコア計算がエラーなく完走するか。
108
+ 2. **出力ファイルの検証**: `data/out/` に正しい JSON スキーマのファイルが生成されているか(`subclaims[].scores` に `relavance`・`noise` 等のキーが存在するか)。`--lite` 使用時は `frequency`・`doc_claim_cosine_similarity` キーは生成されず、`relavance`・`noise`・`query_claim_cosine_similarity`・`min_log_prob` は常に計算される。
109
+
110
+ ---
111
+
112
+ ## 3. 統合テスト(API 不要)
113
+
114
+ ステップ 2 で得られた少量の本物データ(`data/out/` 内の JSON)を入力として、デモ用の事前計算スクリプトが正しく動作するかをテストする。
115
+
116
+ **実行手順:**
117
+
118
+ ```bash
119
+ python demo/precompute.py
120
+ ```
121
+
122
+ **確認事項:**
123
+
124
+ 1. **`demo/data/thresholds.csv` の生成**
125
+ - `ALPHA_LEVELS`(例: `0.05, 0.10, ...`)に対応する行が全件生成されているか。
126
+ - 浮動小数点の丸め誤差が発生していないか(`alpha` 列が `0.09999...` ではなく `0.1` になっているか)。
127
+ - キャリブレーションに使われるのがホールドアウト分を除いたデータのみであるか(ログの `calib=N, holdout=M` で確認)。
128
+ - > **注意**: `query_size: 10` 相当の少量データを使っている場合、$\hat{q}$ の値自体は統計的に意味が薄い。ここではスキーマと生成フローの正常動作の確認が目的。本番品質の閾値は `query_size: 500` で再実行する。
129
+
130
+ 2. **`demo/data/samples.json` の生成**
131
+ - メタデータ(生成日時・モデル名・スコアリング手法)が正しく付与されているか。
132
+ - MedLFQA の場合、全グループ(HealthsearchQA / MedicationQA / K-QA_Silver / K-QA_Golden / LiveQA)から少なくとも 1 件ずつ含まれているか。
133
+ - `subclaims[].score` フィールドが `FilteredResult` の lookup キーと一致するか(`constants.py` の `SCORING_METHOD_TO_DATA_KEY` のマッピングが正しいか)。
134
+
135
+ ---
136
+
137
+ ## 4. UI テスト(API 不要)
138
+
139
+ 生成された静的ファイル(CSV/JSON)を用いて、ブラウザ上での UI コンポーネントと状態遷移をテストする。Streamlit の `AppTest` フレームワークを使えば自動テストも可能。
140
+
141
+ **実装ファイル:** `tests/test_app_ui.py`(16 テスト、API 不要)
142
+
143
+ **実行手順:**
144
+
145
+ ```bash
146
+ # コンテナ経由
147
+ bash scripts/run_tests.sh tests/test_app_ui.py -v
148
+
149
+ # ホスト直接実行
150
+ source .env && python -m pytest tests/test_app_ui.py -v
151
+ ```
152
+
153
+ **前提条件:** `demo/data/thresholds.csv` と `demo/data/samples.json` が存在すること(ステップ 3 の実行済み状態)。
154
+
155
+ **確認事項:**
156
+
157
+ 1. **条件付き UI のレンダリング**(`TestConditionalUIRendering`、4 テスト)
158
+ - 「MedLFQA」選択時のみ「Marginal / Conditional」ラジオボタンが活性化するか。
159
+ - 「Conditional」選択時のみグループ選択プルダウンが表示されるか。
160
+ - 検証方法: `at.radio[0].disabled`、`len(at.selectbox)` の件数(Marginal=2, Conditional=3)。
161
+
162
+ 2. **セッションリセットの動作**(`TestSessionStateReset`、4 テスト)
163
+ - データセット・モード・グループを切り替えたとき、`prev_params` の変化検出により `st.session_state` がリセットされ、前回の推論結果が表示されたままにならないか。
164
+ - 変更なしの再実行では `query_processed` と `result` が保持されるか。
165
+
166
+ 3. **キャッシュの動作(`@st.cache_resource`)**(`TestCacheBehavior`、2 テスト)
167
+ - `get_faiss_manager` / `get_scorer` が `dataset` 引数を持ち、データセットごとにキャッシュが分離される設計になっているかをシグネチャで静的検証する。
168
+ - > **注意**: AppTest 環境ではキャッシュ実体の動作検証が難しいため、Streamlit がキャッシュキーに使う関数シグネチャの存在を確認する形にとどめている。
169
+
170
+ 4. **スライダーのインタラクティビティ(最重要)**(`TestSliderInteractivity`、6 テスト)
171
+ - Target Factuality スライダーを動かしたとき、「推論実行」の重い処理が再トリガーされず、`st.session_state` のスコアと CSV の $\hat{q}$ を比較するハイライト再描画(軽い処理)のみが即座に実行されるか。
172
+ - 推論前は disabled、推論後は enabled になるか。
173
+ - Keep / 削除率メトリクスの値が `apply_threshold` の計算結果と一致するか。
174
+ - 検証方法: `process_query` を `MagicMock` で差し替え、スライダー操作後に `assert_not_called()` で確認。
175
+
176
+ **テスト設計上の注意点:**
177
+
178
+ - `samples.json` には `medlf_qa`(conditional)のサンプルのみ含まれるため、結果パネルの描画テストはすべて `medlf_qa / marginal / default` で行い、セッション状態を手動注入する。
179
+ - `st.metric` の値は AppTest では文字列として返るため、数値との比較は `== "1"` のように文字列で行う。
180
+ - `get_faiss_manager` / `get_scorer` / `process_query` は `unittest.mock.patch` で差し替え、API・FAISS・重い推論をすべてモックアウトする。
docs/context/07_huggingface_spaces_deployment.md ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces デプロイ設計
2
+
3
+ ## 1. 目標とスコープ
4
+
5
+ デモアプリ(`demo/app.py`)を Hugging Face Spaces(公開)で動かす。
6
+
7
+ **スコープ: サンプルクエリ専用モード**
8
+ - `demo/data/samples.json` に事前計算済みのクエリのみ動作する
9
+ - 「推論実行」ボタンは FAISS・Wikipedia DB を使わずに事前計算済み結果を即時表示
10
+ - 「回答を生成」ボタン(LLM 再統合)は OpenAI API を呼ぶため、API キーが必要
11
+ - ライブ推論(任意クエリ入力)は行わない
12
+
13
+ **Spaces の制約**
14
+ - Singularity は使えない(Docker / Python 環境)
15
+ - リポジトリに大きなバイナリファイル(FAISS インデックス、SQLite DB)は置けない
16
+ - 環境変数(Secrets)で API キーを設定できる
17
+
18
+ ---
19
+
20
+ ## 2. 現行コードの移植可否分析
21
+
22
+ | 処理 | Spaces での動作 | 対応方針 |
23
+ |------|---------------|---------|
24
+ | `samples.json` / `thresholds.csv` の読み込み | ○ そのまま動く(リポジトリ内) | 変更不要 |
25
+ | 「推論実行」→ 事前計算済み結果の表示 | ○ API 呼び出しなし | 変更不要 |
26
+ | 「回答を生成」(`reintegrate_subclaims`) | ○ OpenAI API のみ | Secrets に `OPENAI_API_KEY` を設定 |
27
+ | `build_faiss_manager()` / `build_scorer()` | △ ファイルが存在しないためエラー | ライブ推論パスに到達しなければ呼ばれない(後述) |
28
+ | `@st.cache_resource` でのリソース初期化 | △ 呼ばれた時点でエラー | 環境フラグで無効化 |
29
+ | `.env` の読み込み | △ Spaces では `.env` ファイルは使わない | Spaces Secrets → 環境変数として自動注入 |
30
+
31
+ ### ライブ推論パスが呼ばれる条件
32
+
33
+ 現在のコードでは、`build_faiss_manager()` / `build_scorer()` はライブ推論パス内でしか呼ばれない。
34
+
35
+ ```python
36
+ # app.py 内の live inference ブランチ(sample_map にないクエリのみ)
37
+ faiss_manager = get_faiss_manager(dataset) ← ここでエラーになる
38
+ scorer = get_scorer(dataset)
39
+ ```
40
+
41
+ サンプルクエリのプルダウンには `samples.json` 内のクエリしか表示されないため、
42
+ 通常操作では **このパスには到達しない**。
43
+
44
+ ただし、予期しないエラー時のフォールバックや将来の機能追加を考慮し、
45
+ **環境変数フラグ `SPACES_DEMO=1`** でライブ推論パスを明示的に無効化する。
46
+
47
+ ---
48
+
49
+ ## 3. 必要な変更
50
+
51
+ ### 3-1. `app.py` の変更(最小限)
52
+
53
+ `SPACES_DEMO=1` が設定されている場合、ライブ推論パスをブロックする。
54
+
55
+ ```python
56
+ import os
57
+
58
+ SPACES_DEMO = os.getenv("SPACES_DEMO", "0") == "1"
59
+ ```
60
+
61
+ 「推論実行」ボタン処理内:
62
+
63
+ ```python
64
+ if run_btn and query_input:
65
+ precomputed = sample_map.get(query_input)
66
+ if precomputed is not None and ...:
67
+ # 事前計算済み → そのまま表示(変更なし)
68
+ ...
69
+ elif SPACES_DEMO:
70
+ st.error("このデモではサンプルクエリのみ対応しています。")
71
+ else:
72
+ # ライブ推論(ローカル環境のみ)
73
+ ...
74
+ ```
75
+
76
+ ### 3-2. Spaces 用設定ファイル
77
+
78
+ Spaces は **リポジトリの `README.md`(frontmatter)** で設定を宣言する。
79
+
80
+ ```yaml
81
+ ---
82
+ title: Response Quality Assessment Demo
83
+ emoji: 📊
84
+ colorFrom: blue
85
+ colorTo: green
86
+ sdk: streamlit
87
+ sdk_version: 1.43.2
88
+ app_file: demo/app.py
89
+ pinned: false
90
+ ---
91
+ ```
92
+
93
+ `app_file` に `demo/app.py` を指定することでリポジトリ構造を変えずに済む。
94
+
95
+ ### 3-3. `requirements.txt` の整理(実装済み)
96
+
97
+ Spaces はリポジトリルートの `requirements.txt` を自動で `pip install` する。
98
+ リポジトリでは以下の構成を採用している:
99
+
100
+ | ファイル | 用途 |
101
+ |----------|------|
102
+ | `requirements.txt` | HF Spaces 向け最小セット。Spaces はこのファイルを自動で読む |
103
+ | `requirements-dev.txt` | ローカル・Singularity 向け全依存。`-r requirements.txt` で共通部分を継承 |
104
+
105
+ **`requirements.txt`(Spaces 向け最小セット、§4-1 調査結果に基づく):**
106
+
107
+ ```
108
+ openai>=2.0
109
+ python-dotenv>=1.0
110
+ numpy>=1.24
111
+ pandas>=2.0
112
+ pyyaml>=6.0
113
+ streamlit>=1.43
114
+ ```
115
+
116
+ `faiss-cpu` / `torch` / `transformers` / `sentence-transformers` / `langchain-core` /
117
+ `scikit-learn` はいずれも不要(遅延 import により Spaces モードでは読み込まれない)。
118
+
119
+ ローカル開発時は `pip install -r requirements-dev.txt` を使う。
120
+
121
+ ### 3-4. Secrets の設定
122
+
123
+ Spaces の **Settings > Variables and secrets** に以下を追加する:
124
+
125
+ | キー | 値 | 用途 |
126
+ |------|----|------|
127
+ | `OPENAI_API_KEY` | `sk-...` | 「回答を生成」ボタン(`reintegrate_subclaims`) |
128
+ | `SPACES_DEMO` | `1` | ライブ推論パスを無効化 |
129
+
130
+ `.env` ファイルは Spaces 環境では使わない(`python-dotenv` の `load_dotenv()` は
131
+ 環境変数が既にセットされていれば上書きしないため、ローカルとの互換性は保たれる)。
132
+
133
+ ---
134
+
135
+ ## 4. 事前調査結果
136
+
137
+ ### 4-1. `src/` の import チェーンの影響調査(済)
138
+
139
+ `sys.modules` を import 前後で比較し、重い依存パッケージの混入を確認した。
140
+
141
+ **原因**: `src.common.faiss_manager` → `src.common.file_manager` → `langchain_text_splitters`
142
+ が `torch` / `transformers` / `sentence_transformers` 等を連鎖的に引き込む。
143
+ `src.subclaim_processor.scorer.subclaim_scorer` も同様。
144
+
145
+ **対処**: `FAISSIndexManager` / `SubclaimScorer` の import を `build_faiss_manager()` /
146
+ `build_scorer()` 関数内に移動(遅延 import)。型ヒントは `TYPE_CHECKING` ガードで維持。
147
+ これにより、サンプルクエリ専用モードでは import 時に重い依存が読み込まれない。
148
+
149
+ **Spaces 向け最小パッケージセット**(`build_*()` が呼ばれない前提):
150
+
151
+ ```
152
+ openai
153
+ python-dotenv
154
+ numpy
155
+ pandas
156
+ pyyaml
157
+ langchain-core
158
+ streamlit>=1.43
159
+ ```
160
+
161
+ `faiss-cpu` / `torch` / `transformers` / `sentence-transformers` / `scipy` /
162
+ `scikit-learn` は不要。
163
+
164
+ ### 4-2. `DATA_ROOT` / config パスの解決(済)
165
+
166
+ `_load_main_config()` / `_load_dataset_config()` は関数呼び出し時に初めて実行される
167
+ (モジュールレベルでは実行されない)ため、`DATA_ROOT` 未設定でも import は成功する。
168
+ サンプルクエリ専用モードではこれらの関数は呼ばれないため問題なし。
169
+
170
+ ---
171
+
172
+ ## 5. デプロイ手順
173
+
174
+ 1. **HF Space の作成**
175
+ - `huggingface.co/spaces/<username>/<space-name>` を新規作成(SDK: Streamlit)
176
+
177
+ 2. **`spaces` remote を追加**
178
+ ```bash
179
+ git remote add spaces https://huggingface.co/spaces/EQUES/Response-Quality-Assessment
180
+ ```
181
+ 認証はリモート URL にトークンを埋め込む方法を使う(`.git/config` は `.gitignore` 対象外だが
182
+ `git` の管理ファイルであり GitHub にはプッシュされない):
183
+ ```bash
184
+ git remote set-url spaces https://<username>:<hf_token>@huggingface.co/spaces/EQUES/Response-Quality-Assessment
185
+ ```
186
+
187
+ 3. **Orphan ブランチで push**
188
+
189
+ HF Spaces はプッシュ時に **全コミット履歴** をスキャンし、10 MiB 超のファイルを拒否する。
190
+ `feature/hf-spaces` の祖先コミットに大きなデータファイル(`data/out/Medlfqav2/` 等)が含まれるため、
191
+ 履歴なしの Orphan ブランチを一時作成してから push する:
192
+
193
+ ```bash
194
+ git checkout --orphan spaces-deploy
195
+ git add -A
196
+ git rm --cached data/out/ data/raw/ -r --ignore-unmatch
197
+ git commit -m "deploy: initial push to HF Spaces"
198
+ git push --force spaces spaces-deploy:main
199
+ git checkout feature/hf-spaces
200
+ git branch -D spaces-deploy
201
+ ```
202
+
203
+ > **なぜ Orphan か**: 通常の push では過去コミットごと送られる。Orphan ブランチは親コミットが
204
+ > 存在しない「起点」なので、HF Spaces には現在のスナップショット 1 コミットだけが届く。
205
+ > ローカルと GitHub の履歴には一切影響しない。
206
+
207
+ 4. **Secrets の設定**
208
+ - Space の **Settings > Variables and secrets** に以下を追加:
209
+
210
+ | キー | 値 |
211
+ |------|----|
212
+ | `OPENAI_API_KEY` | `sk-...` |
213
+ | `SPACES_DEMO` | `1` |
214
+
215
+ 5. **動作確認**
216
+ - サンプルクエリの表示・スライダー操作
217
+ - 「回答を生成」ボタン(OpenAI API 呼び出し)
218
+
219
+ ### 再デプロイ(コード変更後)
220
+
221
+ コードを変更したら同じ Orphan 手順を繰り返す。`spaces` remote は設定済みのため手順 2 は不要。
222
+
223
+ ---
224
+
225
+ ## 6. 既知の問題・注意点
226
+
227
+ ### MedLFQA Marginal モードのサンプル一致
228
+
229
+ `precompute.py` は grouped データセット(medlf_qa)のサンプルを全て `mode="conditional"` で
230
+ 生成するため、`samples.json` に `mode="marginal"` エントリが存在しない。
231
+
232
+ `app.py` では Marginal モード選択時にモード不一致でサンプルが見つからない問題を回避するため、
233
+ Marginal モードではサンプルの `mode`/`group` フィールドを無視してマッチさせる:
234
+
235
+ ```python
236
+ if (
237
+ precomputed is not None
238
+ and (
239
+ mode == "marginal"
240
+ or (precomputed["mode"] == mode and precomputed["group"] == group)
241
+ )
242
+ ):
243
+ ```
244
+
245
+ サブクレームスコア自体はモードに依存しないため、閾値(`_lookup_q_hat` が marginal 用 q_hat を
246
+ 参照)との比較は正しく行われる。
247
+
248
+ ---
249
+
250
+ ## 7. 開発ステップ
251
+
252
+ | ステップ | 状態 | 担当ファイル |
253
+ |----------|------|------------|
254
+ | ① import 影響調査 | 完了 | — |
255
+ | ② requirements 整理 | 完了 | `requirements.txt`, `requirements-dev.txt` |
256
+ | ③ `app.py` にフラグ追加 | 完了 | `demo/app.py` |
257
+ | ④ `README.md` の作成 | 完了 | `README.md` |
258
+ | ⑤ Spaces へ push | 完了 | — |
259
+ | ⑥ Marginal モードのバグ修正 | 完了 | `demo/app.py` |
main.py ADDED
@@ -0,0 +1,376 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ import logging
5
+ import yaml
6
+ from pathlib import Path
7
+
8
+ from src.common.config_manager import ConfigManager
9
+ from src.dataloader.dataloader import DataLoader
10
+ from src.data_processor.query_processor import QueryProcessor
11
+ from src.common.file_manager import FileManager
12
+ from src.common.faiss_manager import FAISSIndexManager
13
+ from src.subclaim_processor.scorer.subclaim_scorer import SubclaimScorer
14
+ from src.subclaim_processor.subclaim_processor import process_subclaims
15
+ from src.calibration.conformal import SplitConformalCalibration
16
+ from src.calibration.conditional_conformal import GroupConditionalConformal
17
+
18
+
19
+ def parse_args(dataset_aliases):
20
+ """Parse command line arguments"""
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "--config",
24
+ type=str,
25
+ default="conf/config.yaml",
26
+ help="Path to configuration file",
27
+ )
28
+ parser.add_argument(
29
+ "--dataset",
30
+ type=str,
31
+ help="Override dataset name from config",
32
+ choices=dataset_aliases,
33
+ )
34
+ parser.add_argument(
35
+ "--query_size", type=int, default=500, help="Override query size from config"
36
+ )
37
+ parser.add_argument("--run_id", type=str, help="Custom run identifier")
38
+ parser.add_argument(
39
+ "--lite",
40
+ action="store_true",
41
+ help="Skip frequency and doc_claim_cosine_similarity scoring (demo-unused) and conformal prediction plots",
42
+ )
43
+ parser.add_argument(
44
+ "--log_level",
45
+ type=str,
46
+ default="INFO",
47
+ choices=["DEBUG", "INFO", "WARNING", "ERROR"],
48
+ help="Logging level. Use DEBUG to record individual API calls (default: INFO)",
49
+ )
50
+ return parser.parse_args()
51
+
52
+
53
+ def main():
54
+ avaliable_datasets = []
55
+ with open("conf/dataset_config.yaml", 'r') as f:
56
+ dataset_config = yaml.safe_load(f)
57
+ avaliable_datasets = list(dataset_config["datasets"].keys())
58
+ # Parse arguments
59
+ args = parse_args(avaliable_datasets)
60
+
61
+ # Initialize config manager
62
+ config_manager = ConfigManager(
63
+ config_path=args.config,
64
+ path_config_path="conf/path_config.yaml",
65
+ dataset_config_path="conf/dataset_config.yaml",
66
+ run_id=args.run_id,
67
+ )
68
+
69
+ dataset_aliases = list(dataset_config["datasets"].keys())
70
+
71
+ # Setup logging
72
+ log_file, run_id = config_manager.setup_logging(
73
+ log_level=getattr(logging, args.log_level)
74
+ )
75
+
76
+ # Update config with command line arguments if provided
77
+ if args.dataset or args.query_size:
78
+ updates = {"dataset": {}}
79
+ if args.dataset:
80
+ updates["dataset"]["name"] = args.dataset
81
+ if args.query_size:
82
+ updates["dataset"]["query_size"] = args.query_size
83
+ config_manager.update_config(updates)
84
+
85
+ lite = args.lite
86
+
87
+ # Save updated config
88
+ config_file = config_manager.save_config()
89
+ logging.info(f"Configuration saved to: {config_file}")
90
+
91
+ # Log important config values
92
+ config_manager.log_config()
93
+
94
+ # Get the config
95
+ config = config_manager.config
96
+ path_config = config_manager.path_config
97
+ dataset_config = config_manager.dataset_config
98
+
99
+ ####################################### Data and Folder Set up ############################################
100
+ dataset_name = config["dataset"]["name"]
101
+ query_size = config["dataset"]["query_size"]
102
+ wiki_db_file = config["dataset"]["wiki_db_file"]
103
+
104
+ delete_existing_index = config["index"]["delete_existing"]
105
+ embedding_model = config["index"]["embedding_model"]
106
+ index_truncation_config = config["index"]["truncation_config"]
107
+ truncation_strategy = index_truncation_config["strategy"]
108
+ truncate_by = index_truncation_config["truncate_by"]
109
+
110
+ response_model = config["rag"]["response_model"]
111
+ frequency_score_model = config["conformal_prediction"]["frequency_score_model"]
112
+
113
+ alpha_config = config["conformal_prediction"]["conformal_alphas"]
114
+ conformal_alphas = np.arange(
115
+ alpha_config["start"], alpha_config["end"], alpha_config["step"]
116
+ )
117
+ a_value = config["conformal_prediction"]["a_value"]
118
+
119
+ dataset_custom_config = dataset_config["datasets"].get(dataset_name)
120
+ if not dataset_custom_config:
121
+ raise ValueError(f"Unknown dataset: {dataset_name}")
122
+ full_dataset_name = dataset_custom_config["name"]
123
+ index_store_dir = dataset_custom_config["index_store"]
124
+ group_conditional_conformal = dataset_custom_config.get("is_grouped", False)
125
+
126
+ raw_data_dir = os.path.join(path_config["paths"]["raw_data_dir"], full_dataset_name)
127
+ processed_data_dir = os.path.join(
128
+ path_config["paths"]["processed_data_dir"], full_dataset_name
129
+ )
130
+ response_dir = os.path.join(path_config["paths"]["response_dir"], full_dataset_name)
131
+ wiki_db_path = os.path.join(path_config["paths"]["wiki_db_dir"], wiki_db_file)
132
+ result_dir = os.path.join(
133
+ path_config["paths"]["result_dir"], full_dataset_name, run_id
134
+ )
135
+
136
+ # set up directories
137
+ for dir_path in [raw_data_dir, processed_data_dir, response_dir, result_dir]:
138
+ os.makedirs(dir_path, exist_ok=True)
139
+ logging.info(f"Directory ensured: {dir_path}")
140
+
141
+ # Determine raw data file path
142
+ if dataset_name == "medlf_qa":
143
+ input_file = os.path.join(path_config["paths"]["raw_data_dir"], "MedLFQA")
144
+ raw_data_path = input_file
145
+ else:
146
+ raw_data_file = f"raw_{dataset_name}.json"
147
+ raw_data_path = os.path.join(raw_data_dir, raw_data_file)
148
+
149
+ logging.info(f"Raw data path: {raw_data_path}")
150
+
151
+ # Load data if needed
152
+ if not os.path.exists(raw_data_path):
153
+ logging.info(f"Raw data not found. Loading data for {dataset_name}")
154
+ data_loader = DataLoader(dataset_name)
155
+ data_loader.load_qa_data(output_path=raw_data_path)
156
+ logging.info(f"Data loaded and saved to {raw_data_path}")
157
+
158
+ # create wiki db if needed
159
+ if not os.path.exists(wiki_db_path) or not os.path.isfile(wiki_db_path):
160
+ wiki_source = os.path.join(
161
+ path_config["paths"]["wiki_db_dir"],
162
+ "enwiki-20171001-pages-meta-current-withlinks-abstracts",
163
+ )
164
+ if not os.path.exists(wiki_source):
165
+ raise FileNotFoundError(f"Wiki source data not found at {wiki_source}")
166
+ logging.info(f"Wiki DB not found. Creating from source {wiki_source}")
167
+ data_loader = DataLoader(dataset_name)
168
+ data_loader.create_wiki_db(source_path=wiki_source, output_path=wiki_db_path)
169
+ logging.info(f"Wiki DB created at {wiki_db_path}")
170
+
171
+ # Process queries and documents
172
+ input_file = raw_data_path
173
+ if dataset_name == "medlf_qa":
174
+ input_file = os.path.join(path_config["paths"]["raw_data_dir"], "MedLFQA")
175
+
176
+ query_output_file = f"{dataset_name}_queries.json"
177
+ document_output_file = f"{dataset_name}_documents.txt"
178
+
179
+ subclaims_path = os.path.join(
180
+ response_dir,
181
+ f"{dataset_name}_{query_size}_subclaims_with_scores_{response_model}.json",
182
+ )
183
+ CP_result_fig_path = os.path.join(
184
+ result_dir, f"{dataset_name}_{query_size}_a={a_value:.2f}_CP_removal.png"
185
+ )
186
+ GCP_result_fig_path = os.path.join(
187
+ result_dir, f"{dataset_name}_{query_size}_a={a_value:.2f}_GCP_removal.png"
188
+ )
189
+ factual_result_fig_path = os.path.join(
190
+ result_dir,
191
+ f"{dataset_name}_{query_size}_a={a_value:.2f}_factual_correctness.png",
192
+ )
193
+ group_factual_result_fig_path = os.path.join(
194
+ result_dir,
195
+ f"group_{dataset_name}_{query_size}_a={a_value:.2f}_factual_correctness.png",
196
+ )
197
+ result_path = os.path.join(
198
+ result_dir, f"{dataset_name}_{query_size}_a={a_value:.2f}.csv"
199
+ )
200
+ group_result_path = os.path.join(
201
+ result_dir, f"group_{dataset_name}_{query_size}_a={a_value:.2f}.csv"
202
+ )
203
+ ####################################### End of Data and Folder Set up ######################################
204
+
205
+ # Create QueryProcessor
206
+ logging.info("Initializing QueryProcessor")
207
+ query_processor = QueryProcessor(db_path=wiki_db_path, query_size=query_size)
208
+
209
+ # Create queries data
210
+ logging.info("Processing queries")
211
+ queries, query_path = query_processor.get_queries(
212
+ dataset=dataset_name,
213
+ input_file=input_file,
214
+ output_dir=processed_data_dir,
215
+ output_file=query_output_file,
216
+ )
217
+ logging.info(f"Query size: {len(queries)}")
218
+
219
+ # Create documents data
220
+ logging.info("Processing documents")
221
+ document_path = query_processor.get_documents(
222
+ query_dir=query_path,
223
+ output_dir=processed_data_dir,
224
+ output_file=document_output_file,
225
+ )
226
+ logging.info(f"Documents saved to {document_path}")
227
+
228
+ # Index creation and retrieval
229
+ os.makedirs(index_store_dir, exist_ok=True)
230
+ index_file_path = os.path.join(index_store_dir, f"index_{query_size}.faiss")
231
+ indice2fm_path = os.path.join(index_store_dir, f"indice2fm_{query_size}.json")
232
+
233
+ logging.info(f"Setting up FAISS index manager")
234
+ faiss_manager = FAISSIndexManager(
235
+ index_truncation_config=index_truncation_config,
236
+ index_path=index_file_path,
237
+ indice2fm_path=indice2fm_path,
238
+ )
239
+
240
+ if delete_existing_index:
241
+ logging.info("Deleting existing index as requested")
242
+ faiss_manager.delete_index()
243
+
244
+ # Create index if it does not exist
245
+ document_file = FileManager(
246
+ document_path, index_truncation_config=index_truncation_config
247
+ )
248
+
249
+ logging.info(
250
+ f"Using truncation strategy: {truncation_strategy}, truncate_by: {truncate_by}"
251
+ )
252
+
253
+ # If Index doesn't exist yet
254
+ if not os.path.exists(index_file_path):
255
+ try:
256
+ logging.info(f"Creating new index with document '{document_path}'")
257
+ faiss_manager.upsert_file_to_faiss(
258
+ document_file,
259
+ truncation_strategy=truncation_strategy,
260
+ truncate_by=truncate_by,
261
+ )
262
+ logging.info("Index created successfully")
263
+ except Exception as e:
264
+ error_msg = f"Failed to create new index: {str(e)}"
265
+ logging.error(error_msg)
266
+ raise RuntimeError(error_msg)
267
+
268
+ # If Index exists but current document isn't indexed
269
+ elif document_path not in faiss_manager.indice2fm:
270
+ # Verify index integrity
271
+ logging.info("Checking index integrity")
272
+ if not faiss_manager.is_indice_align():
273
+ error_msg = "Index corruption detected: index and indice2fm are not aligned"
274
+ logging.error(error_msg)
275
+ raise ValueError(error_msg)
276
+
277
+ try:
278
+ logging.info(f"Adding document '{document_path}' to existing index")
279
+ faiss_manager.upsert_file_to_faiss(
280
+ document_file,
281
+ truncation_strategy=truncation_strategy,
282
+ truncate_by=truncate_by,
283
+ )
284
+ logging.info("Document added to index successfully")
285
+ except Exception as e:
286
+ error_msg = f"Failed to add document to index: {str(e)}"
287
+ logging.error(error_msg)
288
+ raise RuntimeError(error_msg)
289
+
290
+ # Case 3: Document is already indexed
291
+ else:
292
+ logging.info(f"Document '{document_path}' is already indexed")
293
+
294
+ # generate subclaims with scores
295
+ logging.info(f"Initializing SubclaimScorer with embedding model {embedding_model}")
296
+ scorer = SubclaimScorer(
297
+ index_truncation_config=index_truncation_config,
298
+ embedding_model=embedding_model,
299
+ index_path=index_file_path,
300
+ indice2fm_path=indice2fm_path,
301
+ frequency_score_model=frequency_score_model,
302
+ )
303
+
304
+ logging.info(f"Processing subclaims and generating scores")
305
+ subclaim_with_annotation_data = process_subclaims(
306
+ query_path=query_path,
307
+ subclaims_path=subclaims_path,
308
+ faiss_manager=faiss_manager,
309
+ scorer=scorer,
310
+ config=config,
311
+ lite=lite,
312
+ )
313
+ logging.info(f"Subclaims processed and saved to {subclaims_path}")
314
+
315
+ # calibration and conformal prediction results
316
+ if not lite and config["conformal_prediction"]["split_conformal"]:
317
+ logging.info("Running split conformal prediction")
318
+ conformal = SplitConformalCalibration(dataset_name=dataset_name)
319
+ logging.info(
320
+ f"Plotting conformal removal with alphas: {conformal_alphas}, a={a_value}"
321
+ )
322
+ conformal.plot_conformal_removal(
323
+ data=subclaim_with_annotation_data,
324
+ alphas=conformal_alphas,
325
+ a=a_value,
326
+ fig_filename=CP_result_fig_path,
327
+ csv_filename=result_path,
328
+ )
329
+ logging.info(f"CP removal plot saved to {CP_result_fig_path}")
330
+
331
+ logging.info("Plotting factual removal")
332
+ conformal.plot_factual_removal(
333
+ data=subclaim_with_annotation_data,
334
+ alphas=conformal_alphas,
335
+ a=a_value,
336
+ fig_filename=factual_result_fig_path,
337
+ csv_filename=result_path,
338
+ )
339
+ logging.info(f"Factual removal plot saved to {factual_result_fig_path}")
340
+ logging.info(f"Results saved to {result_path}")
341
+
342
+ if not lite and group_conditional_conformal:
343
+ logging.info("Running group conditional conformal prediction")
344
+ conformal = GroupConditionalConformal(dataset_name=dataset_name, result_dir=result_dir)
345
+ logging.info(
346
+ f"Plotting conformal removal with alphas: {conformal_alphas}, a={a_value}"
347
+ )
348
+ conformal.plot_conformal_removal(
349
+ data=subclaim_with_annotation_data,
350
+ alphas=conformal_alphas,
351
+ a=a_value,
352
+ fig_filename=GCP_result_fig_path,
353
+ csv_filename=group_result_path,
354
+ )
355
+ logging.info(f"CP removal plot saved to {GCP_result_fig_path}")
356
+
357
+ logging.info("Plotting factual removal")
358
+ conformal.plot_factual_removal(
359
+ data=subclaim_with_annotation_data,
360
+ alphas=conformal_alphas,
361
+ a=a_value,
362
+ fig_filename=group_factual_result_fig_path,
363
+ csv_filename=group_result_path,
364
+ )
365
+ logging.info(f"Factual removal plot saved to {factual_result_fig_path}")
366
+ logging.info(f"Results saved to {result_path}")
367
+
368
+ # Copy config and log files to result directory for reproducibility
369
+ result_run_dir = config_manager.copy_run_artifacts(result_dir)
370
+ logging.info(
371
+ f"Run completed successfully. Results and logs saved to {result_run_dir}"
372
+ )
373
+
374
+
375
+ if __name__ == "__main__":
376
+ main()
requirements-dev.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -r requirements.txt
2
+ faiss-cpu
3
+ torch
4
+ transformers
5
+ sentence-transformers
6
+ huggingface-hub
7
+ datasets
8
+ langchain
9
+ scikit-learn
10
+ PyPDF2
11
+ jsonschema
12
+ matplotlib
13
+ tqdm
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Minimal requirements for HF Spaces (SPACES_DEMO=1, sample queries only).
2
+ # FAISS / torch / transformers are NOT needed because FAISSIndexManager and
3
+ # SubclaimScorer are lazily imported and never called in this mode.
4
+ #
5
+ # Usage: rename or symlink this file to requirements.txt in the Space repo,
6
+ # or specify it via the Spaces build configuration.
7
+
8
+ openai>=2.0
9
+ python-dotenv>=1.0
10
+ numpy>=1.24
11
+ pandas>=2.0
12
+ pyyaml>=6.0
13
+ streamlit>=1.43
scripts/build_sif.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Build the Singularity SIF image from singularity/response_quality.def.
3
+ # Usage: bash scripts/build_sif.sh
4
+ set -euo pipefail
5
+
6
+ REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
7
+ source "${REPO_ROOT}/.env"
8
+
9
+ mkdir -p "${SIF_DIR}"
10
+
11
+ singularity build --fakeroot --force \
12
+ "${SIF_DIR}/response_quality.sif" \
13
+ "${REPO_ROOT}/singularity/response_quality.def"
14
+
15
+ echo "Built: ${SIF_DIR}/response_quality.sif"
scripts/run_demo.sh ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Launch the Streamlit demo inside the Singularity container.
3
+ # Usage: bash scripts/run_demo.sh [--port PORT]
4
+ # Access: http://localhost:<PORT> (default: 8502)
5
+ set -euo pipefail
6
+
7
+ REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
8
+ source "${REPO_ROOT}/.env"
9
+
10
+ PORT=8502
11
+ while [[ $# -gt 0 ]]; do
12
+ case "$1" in
13
+ --port) PORT="$2"; shift 2 ;;
14
+ *) echo "Unknown option: $1" >&2; exit 1 ;;
15
+ esac
16
+ done
17
+
18
+ singularity run \
19
+ --bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
20
+ --bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
21
+ --bind "${HF_HOME}:${HF_HOME}" \
22
+ --env HF_HOME="${HF_HOME}" \
23
+ --env OPENAI_API_KEY="${OPENAI_API_KEY}" \
24
+ --env DATA_ROOT="${REPO_ROOT}" \
25
+ --env PYTHONPATH="${REPO_ROOT}" \
26
+ --pwd "${REPO_ROOT}" \
27
+ "${SIF_DIR}/response_quality.sif" \
28
+ streamlit run demo/app.py --server.port "${PORT}"
scripts/run_main.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Run main.py inside the Singularity container.
3
+ # All arguments are forwarded to main.py as-is.
4
+ #
5
+ # Usage:
6
+ # bash scripts/run_main.sh --dataset hotpot_qa
7
+ # bash scripts/run_main.sh --dataset hotpot_qa --query_size 200
8
+ set -euo pipefail
9
+
10
+ REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
11
+ source "${REPO_ROOT}/.env"
12
+
13
+ singularity exec \
14
+ --bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
15
+ --bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
16
+ --bind "${HF_HOME}:${HF_HOME}" \
17
+ --env HF_HOME="${HF_HOME}" \
18
+ --env OPENAI_API_KEY="${OPENAI_API_KEY}" \
19
+ --env DATA_ROOT="${REPO_ROOT}" \
20
+ --pwd "${REPO_ROOT}" \
21
+ "${SIF_DIR}/response_quality.sif" \
22
+ python main.py "$@"
scripts/run_precompute.sh ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Run demo/precompute.py inside the Singularity container.
3
+ # Usage: bash scripts/run_precompute.sh
4
+ set -euo pipefail
5
+
6
+ REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
7
+ source "${REPO_ROOT}/.env"
8
+
9
+ singularity exec \
10
+ --bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
11
+ --bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
12
+ --bind "${HF_HOME}:${HF_HOME}" \
13
+ --env HF_HOME="${HF_HOME}" \
14
+ --env OPENAI_API_KEY="${OPENAI_API_KEY}" \
15
+ --env DATA_ROOT="${REPO_ROOT}" \
16
+ --pwd "${REPO_ROOT}" \
17
+ "${SIF_DIR}/response_quality.sif" \
18
+ python -m demo.precompute
19
+
20
+ echo ""
21
+ echo "Generated:"
22
+ echo " ${REPO_ROOT}/demo/data/thresholds.csv"
23
+ echo " ${REPO_ROOT}/demo/data/samples.json"
24
+ echo ""
25
+ echo "Commit these files: git add demo/data/ && git commit"
scripts/run_tests.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+ # Run pytest inside the Singularity container.
3
+ # All arguments are forwarded to pytest as-is.
4
+ #
5
+ # Prerequisites: .env must define SIF_DIR (and optionally DATA_ROOT, HF_HOME).
6
+ # See docs/context/04_environment_setup.md and .env.example.
7
+ #
8
+ # Usage:
9
+ # bash scripts/run_tests.sh # run all tests
10
+ # bash scripts/run_tests.sh tests/test_calibration.py -v
11
+ set -euo pipefail
12
+
13
+ REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
14
+ source "${REPO_ROOT}/.env"
15
+
16
+ if [[ -z "${SIF_DIR:-}" ]]; then
17
+ echo "ERROR: SIF_DIR is not set. Fill in .env (see .env.example)." >&2
18
+ exit 1
19
+ fi
20
+
21
+ # Build the bind list only for paths that are set.
22
+ BINDS=()
23
+ if [[ -n "${DATA_ROOT:-}" ]]; then
24
+ BINDS+=(--bind "${DATA_ROOT}/data:${REPO_ROOT}/data")
25
+ BINDS+=(--bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store")
26
+ fi
27
+ if [[ -n "${HF_HOME:-}" ]]; then
28
+ BINDS+=(--bind "${HF_HOME}:${HF_HOME}")
29
+ fi
30
+
31
+ singularity exec \
32
+ "${BINDS[@]}" \
33
+ --env HF_HOME="${HF_HOME:-}" \
34
+ --env DATA_ROOT="${REPO_ROOT}" \
35
+ --pwd "${REPO_ROOT}" \
36
+ "${SIF_DIR}/response_quality.sif" \
37
+ python -m pytest "$@"
singularity/response_quality.def ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Bootstrap: docker
2
+ From: python:3.11-slim
3
+
4
+ %post
5
+ apt-get update && apt-get install -y --no-install-recommends \
6
+ build-essential \
7
+ git \
8
+ && rm -rf /var/lib/apt/lists/*
9
+
10
+ # Install torch (CPU-only) explicitly before other packages to avoid
11
+ # pulling in the full CUDA build as a transitive dependency.
12
+ pip install --no-cache-dir \
13
+ torch --index-url https://download.pytorch.org/whl/cpu
14
+
15
+ pip install --no-cache-dir \
16
+ streamlit \
17
+ openai \
18
+ faiss-cpu \
19
+ numpy \
20
+ pandas \
21
+ pyyaml \
22
+ python-dotenv \
23
+ transformers \
24
+ sentence-transformers \
25
+ huggingface-hub \
26
+ datasets \
27
+ langchain \
28
+ scikit-learn \
29
+ PyPDF2 \
30
+ jsonschema \
31
+ matplotlib \
32
+ tqdm \
33
+ langchain-text-splitters \
34
+ pytest
35
+
36
+ %environment
37
+ export PYTHONUNBUFFERED=1
38
+ export PYTHONDONTWRITEBYTECODE=1
39
+
40
+ %runscript
41
+ exec "$@"
42
+
43
+ %labels
44
+ Author ryoya.awano
45
+ Description "Response Quality Assessment demo (Conformal RAG)"
src/calibration/__init__.py ADDED
File without changes
src/calibration/base_calibration.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class ICalibration(ABC):
5
+ """
6
+ Interface for calibration methods.
7
+ """
8
+
9
+ @abstractmethod
10
+ def plot_conformal_removal(
11
+ self, data, alphas, a, fig_filename, csv_filename, plot_group_results=False
12
+ ):
13
+ pass
14
+
15
+ @abstractmethod
16
+ def plot_factual_removal(
17
+ self, data, alphas, a, fig_filename, csv_filename, plot_group_results=False
18
+ ):
19
+ pass
src/calibration/conditional_conformal.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import random
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from tqdm import tqdm
7
+ from collections import defaultdict
8
+
9
+
10
+ from src.calibration.conformal import SplitConformalCalibration
11
+ from src.calibration.utils import compute_threshold
12
+ from src.calibration.utils import append_result_to_csv
13
+ from src.calibration.utils import split_group
14
+
15
+
16
+ CORRECT_ANNOTATIONS = ["S"]
17
+
18
+
19
+ class GroupConditionalConformal(SplitConformalCalibration):
20
+ def __init__(self, dataset_name: str, result_dir: str, runs: int = 1000):
21
+ super().__init__(dataset_name, runs)
22
+ self.result_dir = result_dir
23
+
24
+
25
+ def compute_conformal_results(
26
+ self, data: list, alphas: np.ndarray, a: float, plot_group_results: bool = False
27
+ ):
28
+
29
+ results = {}
30
+ for confidence_method in self.confidence_method:
31
+ results[confidence_method] = {}
32
+ for alpha in tqdm(
33
+ alphas, desc=f"Computing conformal results for {confidence_method}"
34
+ ):
35
+ # TODO add grouping
36
+ groups = None
37
+
38
+
39
+ thresholds_result = []
40
+ correctness_list = []
41
+ fraction_removed_list = []
42
+ test_data = []
43
+ for _ in range(self.runs):
44
+ random.shuffle(data)
45
+ calibration_data, test_data = split_group(data)
46
+ groups = list(calibration_data.keys())
47
+
48
+ assert (
49
+ len(calibration_data) != 0
50
+ ), "Calibration data should not be empty"
51
+ assert len(test_data) != 0, "Test data should not be empty"
52
+ thresholds = self._compute_threshold_by_group(
53
+ alpha, calibration_data, a, confidence_method, groups=groups
54
+ )
55
+
56
+ correctness, fraction_removed = (
57
+ self._evaluate_conformal_correctness(
58
+ test_data, thresholds, a, confidence_method
59
+ )
60
+ )
61
+ thresholds_result.append(thresholds)
62
+ correctness_list.append(correctness)
63
+ fraction_removed_list.append(fraction_removed)
64
+
65
+ results[confidence_method][alpha] = {
66
+ "threshold": thresholds_result,
67
+ "correctness": correctness_list,
68
+ "fraction_removed": fraction_removed_list,
69
+ }
70
+
71
+ return results
72
+
73
+ def _compute_threshold_by_group(
74
+ self,
75
+ alpha: float,
76
+ calibration_data: list,
77
+ a: float,
78
+ confidence_method: str,
79
+ groups: list | None = None,
80
+ ):
81
+ thresholds = {}
82
+ for group in groups:
83
+ group_data = calibration_data[group]
84
+ thresholds[group] = compute_threshold(
85
+ alpha, group_data, a, confidence_method
86
+ )
87
+ return thresholds
88
+
89
+ def _evaluate_conformal_correctness(
90
+ self, data: list, thresholds: dict, a: float, confidence_method: str
91
+ ):
92
+ """
93
+ Evaluates the performance of a conformal prediction model on test data.
94
+ Parameters:
95
+ data (list): A list of dictionaries, where each dictionary represents an entry with subclaims.
96
+ threshold (float): The similarity score threshold to determine if a subclaim is correctly retained.
97
+ a (float): The threshold for the correctly retained percentage to consider an entry as correctly retained.
98
+ Returns:
99
+ tuple: A tuple containing two lists:
100
+ - correctly_retained (float): Percentage of data that are correctly retained.
101
+ - fraction_removed (float): Percentage of subclaims removed for each entry.
102
+ """
103
+
104
+ correctly_retained = []
105
+
106
+ fraction_removed = []
107
+
108
+ for entry in data:
109
+ removal_count = 0
110
+ retained_cnt = 0
111
+ correctly_retained_count = 0
112
+ threshold = thresholds[entry["groups"][0]] # Get threshold for the group
113
+ if threshold is None:
114
+ raise ValueError(
115
+ f"Threshold for group {entry['groups'][0]} is None. Check your calibration data."
116
+ )
117
+ for subclaim in entry["subclaims"]:
118
+ # Find similarity score
119
+ score = subclaim["scores"][confidence_method]
120
+ noise = subclaim["scores"]["noise"]
121
+ if score + noise >= threshold:
122
+ retained_cnt += 1
123
+ if (
124
+ subclaim.get("annotations", {}).get("gpt", "")
125
+ in CORRECT_ANNOTATIONS
126
+ ):
127
+ correctly_retained_count += 1
128
+
129
+ else:
130
+ removal_count += 1
131
+
132
+ total_subclaims = len(entry["subclaims"])
133
+
134
+ # Calculate fraction of removed subclaims
135
+ entry_removal_rate = (
136
+ 0 if total_subclaims == 0 else removal_count / total_subclaims
137
+ )
138
+ fraction_removed.append(
139
+ entry_removal_rate
140
+ ) # e.g. fraction_removed = [0.2, 0.5, 0.6, 0.2, 0.7] - one element per data entry
141
+
142
+ # Calculate correctly retained rate
143
+ correctly_retained_percentage = (
144
+ correctly_retained_count / retained_cnt if retained_cnt > 0 else 1
145
+ )
146
+ correctly_retained.append(correctly_retained_percentage >= a)
147
+
148
+ return np.mean(correctly_retained), np.mean(fraction_removed)
149
+
150
+ def compute_factual_results(self, data, alphas, a, calibrate_range=0.5):
151
+ overall_results = {}
152
+ per_group_results = defaultdict(lambda: defaultdict(dict))
153
+
154
+ for method in self.confidence_method:
155
+ overall_results[method] = {}
156
+ for alpha in tqdm(
157
+ alphas, desc=f"Computing factual results for {method}"
158
+ ):
159
+ # trackers for this (method, alpha)
160
+ overall_correctness = []
161
+ thresholds_per_group = defaultdict(list)
162
+ correctness_per_group = defaultdict(list)
163
+
164
+ for _ in range(self.runs):
165
+ random.shuffle(data)
166
+ calibration_data, test_data = split_group(data, calibrate_range)
167
+ groups = list(calibration_data.keys())
168
+
169
+ # assert on nonempty
170
+ assert calibration_data, "No calibration groups"
171
+ assert len(test_data) != 0, "Test data should not be empty"
172
+
173
+ # compute all thresholds at once
174
+ thresholds = self._compute_threshold_by_group(
175
+ alpha, calibration_data, a, method, groups=groups
176
+ )
177
+
178
+ fraction_correct = self._evaluate_factual_correctness(
179
+ test_data, thresholds, a, method
180
+ )
181
+ for group in groups:
182
+ thresholds_per_group[group].append(thresholds[group])
183
+ correctness_per_group[group].append(
184
+ fraction_correct[group]
185
+ )
186
+ overall_correctness.append(fraction_correct["overall"])
187
+
188
+ # package overall
189
+ overall_results[method][alpha] = {
190
+ "threshold": thresholds_per_group,
191
+ "correctness": overall_correctness,
192
+ "factuality": 1 - alpha,
193
+ }
194
+
195
+ # package per‐group
196
+ for grp in thresholds_per_group:
197
+ per_group_results[grp][method][alpha] = {
198
+ "threshold": thresholds_per_group[grp],
199
+ "correctness": correctness_per_group[grp],
200
+ "factuality": 1 - alpha,
201
+ }
202
+
203
+ # now write your CSVs
204
+ for grp, grp_results in per_group_results.items():
205
+ csv_name = os.path.join(self.result_dir, f"{self.dataset_name}_{grp}_factual_correctness.csv")
206
+ self._write_csv_header(csv_name, alphas)
207
+ for method, res in grp_results.items():
208
+ lvl, corr, err = self.process_factual_correctness_results(res)
209
+ append_result_to_csv(
210
+ csv_filename=csv_name,
211
+ label=f"{method}_factual_correctness",
212
+ y=corr,
213
+ yerr=err,
214
+ )
215
+
216
+ return overall_results
217
+
218
+
219
+ def _evaluate_factual_correctness(
220
+ self,
221
+ data: list,
222
+ thresholds: dict,
223
+ a: float,
224
+ confidence_method: str,
225
+ ):
226
+ """
227
+ Evaluates the factual correctness of subclaims within the provided data,
228
+ using a per‑group threshold but computing overall accuracy over all entries.
229
+
230
+ Args:
231
+ data (list): A list of dicts, each with "groups" and "subclaims".
232
+ thresholds (dict): Mapping from group_name -> threshold float.
233
+ a (float): The accuracy level to compare the correctly retained percentage against.
234
+ confidence_method (str): Which score key to use for similarity.
235
+
236
+ Returns:
237
+ dict: { "overall": float, "<group1>": float, "<group2>": float, … }
238
+ """
239
+ per_group_percentages = defaultdict(list)
240
+ total_pass = 0
241
+ total_entries = 0
242
+
243
+ for entry in data:
244
+ group_name = entry["groups"][0]
245
+ retained_cnt = 0
246
+ correctly_retained_cnt = 0
247
+
248
+ # count retained & correctly retained
249
+ for sub in entry["subclaims"]:
250
+ score = sub["scores"][confidence_method]
251
+ noise = sub["scores"]["noise"]
252
+ if score + noise >= thresholds[group_name]:
253
+ retained_cnt += 1
254
+ if sub.get("annotations", {}).get("gpt", "") in CORRECT_ANNOTATIONS:
255
+ correctly_retained_cnt += 1
256
+
257
+ # pct for this entry (1.0 if nothing retained)
258
+ pct = (correctly_retained_cnt / retained_cnt) if retained_cnt > 0 else 1.0
259
+ per_group_percentages[group_name].append(pct)
260
+
261
+ # update global pass/fail
262
+ total_entries += 1
263
+ if pct >= a:
264
+ total_pass += 1
265
+
266
+ # compute per‑group correctness
267
+ per_group_correctness = {
268
+ grp: sum(1 for pct in pct_list if pct >= a) / len(pct_list)
269
+ for grp, pct_list in per_group_percentages.items()
270
+ }
271
+
272
+ # compute overall exactly as “fraction of all entries passing”
273
+ overall = total_pass / total_entries if total_entries > 0 else 0.0
274
+
275
+ # assemble result
276
+ result = {"overall": overall}
277
+ result.update(per_group_correctness)
278
+ return result
src/calibration/conformal.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import csv
3
+ import random
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from tqdm import tqdm
7
+
8
+ from src.calibration.base_calibration import ICalibration
9
+ from src.calibration.utils import compute_threshold
10
+ from src.calibration.utils import append_result_to_csv
11
+
12
+
13
+ CORRECT_ANNOTATIONS = ["S"]
14
+
15
+
16
+ class SplitConformalCalibration(ICalibration):
17
+ """
18
+ Implementation of standard conformal calibration.
19
+ """
20
+
21
+ def __init__(self, dataset_name: str, runs: int = 1000):
22
+ self.dataset_name = dataset_name
23
+ self.confidence_method = [
24
+ "relavance",
25
+ "frequency",
26
+ "query_claim_cosine_similarity",
27
+ "doc_claim_cosine_similarity",
28
+ "min_log_prob",
29
+ "random",
30
+ "ordinal",
31
+ ]
32
+ self.runs = runs
33
+
34
+ def plot_conformal_removal(
35
+ self, data, alphas, a, fig_filename, csv_filename
36
+ ):
37
+
38
+ # compute the correctness and fraction removed for each alpha
39
+
40
+ cache_filename = f"{os.path.splitext(os.path.abspath(csv_filename))[0]}_conformal_removal_cache.npy"
41
+ if not os.path.exists(cache_filename):
42
+ results = self.compute_conformal_results(
43
+ data, alphas, a
44
+ )
45
+ print(f"Caching results to {cache_filename}")
46
+ np.save(cache_filename, results)
47
+
48
+ else:
49
+ print(f"Loading cached results from {cache_filename}")
50
+ results = np.load(cache_filename, allow_pickle=True).item()
51
+
52
+ ax = None
53
+ for confidence_method, result in results.items():
54
+ correctness, fraction_removed, yerr = (
55
+ self.process_conformal_removal_results(result)
56
+ )
57
+
58
+ # write the results to csv file
59
+ self._write_csv_header(csv_filename, alphas)
60
+ append_result_to_csv(
61
+ csv_filename=csv_filename,
62
+ label=f"{confidence_method}_conformal_removal_rate",
63
+ y=fraction_removed,
64
+ yerr=yerr,
65
+ )
66
+
67
+ # plot the results
68
+ print(f"Producing conformal plot for {confidence_method}")
69
+ ax = self.plot_conformal_removal_rate_by_alpha(
70
+ correctness,
71
+ fraction_removed,
72
+ yerr,
73
+ a,
74
+ confidence_method,
75
+ fig_filename,
76
+ ax,
77
+ )
78
+ print(f"Conformal plot saved to {fig_filename}")
79
+
80
+ def compute_conformal_results(
81
+ self, data: list, alphas: np.ndarray, a: float
82
+ ):
83
+
84
+ results = {}
85
+ for confidence_method in self.confidence_method:
86
+ results[confidence_method] = {}
87
+ for alpha in tqdm(
88
+ alphas, desc=f"Computing conformal results for {confidence_method}"
89
+ ):
90
+
91
+ thresholds = []
92
+ correctness_list = []
93
+ fraction_removed_list = []
94
+ for _ in range(self.runs):
95
+ random.shuffle(data)
96
+ split_index = len(data) // 2
97
+ calibration_data = data[:split_index]
98
+ test_data = data[split_index:]
99
+
100
+ assert (
101
+ len(calibration_data) != 0
102
+ ), "Calibration data should not be empty"
103
+ assert len(test_data) != 0, "Test data should not be empty"
104
+
105
+ threshold = compute_threshold(alpha, calibration_data, a, confidence_method)
106
+
107
+ correctness, fraction_removed = (
108
+ self._evaluate_conformal_correctness(
109
+ test_data, threshold, a, confidence_method
110
+ )
111
+ )
112
+ thresholds.append(threshold)
113
+ correctness_list.append(correctness)
114
+ fraction_removed_list.append(fraction_removed)
115
+
116
+ results[confidence_method][alpha] = {
117
+ "threshold": thresholds,
118
+ "correctness": correctness_list,
119
+ "fraction_removed": fraction_removed_list,
120
+ }
121
+
122
+ return results
123
+
124
+ def process_conformal_removal_results(self, results: dict):
125
+ """
126
+ x: list of average correctness
127
+ y: list of average fraction removed
128
+ yerr: list of standard error of fraction removed
129
+ """
130
+ x, y, yerr = [], [], []
131
+ for alpha, results_for_alpha in results.items():
132
+ x_per_alpha = np.mean(
133
+ results_for_alpha["correctness"]
134
+ ) # correct retainment percentage at a specific alpha value, averaging over 1000 times of shuffled data
135
+ y_per_alpha = np.mean(
136
+ results_for_alpha["fraction_removed"]
137
+ ) # removal percentage at a specific alpha value, averaging, averaging over 1000 times of shuffled data
138
+ x.append(x_per_alpha)
139
+ y.append(y_per_alpha)
140
+ yerr.append(
141
+ (
142
+ np.std(results_for_alpha["fraction_removed"])
143
+ * 1.96
144
+ / np.sqrt(len(results_for_alpha["fraction_removed"]))
145
+ )
146
+ )
147
+
148
+ return x, y, yerr
149
+
150
+ def plot_conformal_removal_rate_by_alpha(
151
+ self, x, y, yerr, a, confidence_method, fig_filename, ax=None
152
+ ):
153
+ if not ax:
154
+ fig, ax = plt.subplots(figsize=(8, 6), dpi=800)
155
+ ax.set_title(
156
+ f"Conformal Plots for {self.dataset_name} Datasets (a={a})", fontsize=20
157
+ )
158
+ x_label = (
159
+ f"Fraction achieving avg factuality >= {a}"
160
+ if a != 1
161
+ else "Fraction of factual outputs"
162
+ )
163
+ ax.set_xlabel(x_label, fontsize=16)
164
+ ax.set_ylabel("Average percent removed", fontsize=16)
165
+ else:
166
+ fig = ax.figure
167
+
168
+ # Plot the data
169
+ ax.errorbar(x, y, yerr=yerr, label=confidence_method, linewidth=2)
170
+
171
+ # set the legend
172
+ ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98), fontsize=10)
173
+
174
+ # Save the figure
175
+ fig.savefig(fig_filename, bbox_inches="tight")
176
+
177
+ return ax # Return the ax for further modifications if needed
178
+
179
+ def _write_csv_header(self, csv_filename, alphas):
180
+ target_factuality = [f"{(1-x):.2f}" for x in alphas][::-1]
181
+ header = ["target_factuality"] + target_factuality
182
+
183
+ # Ensure the directory exists
184
+ os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
185
+
186
+ if not os.path.exists(csv_filename):
187
+ with open(csv_filename, mode="w", newline="") as file:
188
+ csv.writer(file).writerow(header)
189
+
190
+ def _evaluate_conformal_correctness(
191
+ self, data: list, threshold: float, a: float, confidence_method: str
192
+ ):
193
+ """
194
+ Evaluates the performance of a conformal prediction model on test data.
195
+ Parameters:
196
+ data (list): A list of dictionaries, where each dictionary represents an entry with subclaims.
197
+ threshold (float): The similarity score threshold to determine if a subclaim is correctly retained.
198
+ a (float): The threshold for the correctly retained percentage to consider an entry as correctly retained.
199
+ Returns:
200
+ tuple: A tuple containing two lists:
201
+ - correctly_retained (float): Percentage of data that are correctly retained.
202
+ - fraction_removed (float): Percentage of subclaims removed for each entry.
203
+ """
204
+
205
+ correctly_retained = []
206
+ fraction_removed = []
207
+
208
+ for entry in data:
209
+ removal_count = 0
210
+ retained_cnt = 0
211
+ correctly_retained_count = 0
212
+
213
+ for subclaim in entry["subclaims"]:
214
+ # Find similarity score
215
+ score = subclaim["scores"][confidence_method]
216
+ noise = subclaim["scores"]["noise"]
217
+ if score + noise >= threshold:
218
+ retained_cnt += 1
219
+ if (
220
+ subclaim.get("annotations", {}).get("gpt", "")
221
+ in CORRECT_ANNOTATIONS
222
+ ):
223
+ correctly_retained_count += 1
224
+
225
+ else:
226
+ removal_count += 1
227
+
228
+ total_subclaims = len(entry["subclaims"])
229
+
230
+ # Calculate fraction of removed subclaims
231
+ entry_removal_rate = (
232
+ 0 if total_subclaims == 0 else removal_count / total_subclaims
233
+ )
234
+ fraction_removed.append(
235
+ entry_removal_rate
236
+ ) # e.g. fraction_removed = [0.2, 0.5, 0.6, 0.2, 0.7] - one element per data entry
237
+
238
+ # Calculate correctly retained rate
239
+ correctly_retained_percentage = (
240
+ correctly_retained_count / retained_cnt if retained_cnt > 0 else 1
241
+ )
242
+ correctly_retained.append(correctly_retained_percentage >= a)
243
+
244
+ return np.mean(correctly_retained), np.mean(fraction_removed)
245
+
246
+ def plot_factual_removal(
247
+ self, data, alphas, a, fig_filename, csv_filename, plot_group_results=False
248
+ ):
249
+ x_values = np.linspace(1 - alphas[-1] - 0.05, 1 - alphas[0] + 0.03, 100)
250
+ fig, ax = plt.subplots(figsize=(8, 6), dpi=800)
251
+ ax.plot(
252
+ x_values,
253
+ x_values,
254
+ "--",
255
+ color="gray",
256
+ linewidth=2,
257
+ label="Conformal guarantee lower bounds",
258
+ )
259
+
260
+ cache_filename = f"{os.path.splitext(os.path.abspath(csv_filename))[0]}_factual_correctness_cache.npy"
261
+ if not os.path.exists(cache_filename):
262
+ results = self.compute_factual_results(data, alphas, a)
263
+ print(f"Caching results to {cache_filename}")
264
+ np.save(cache_filename, results)
265
+
266
+ else:
267
+ print(f"Loading cached results from {cache_filename}")
268
+ results = np.load(cache_filename, allow_pickle=True).item()
269
+
270
+ for confidence_method, result in results.items():
271
+ conf_level, corretness, yerr = self.process_factual_correctness_results(
272
+ result
273
+ )
274
+
275
+ append_result_to_csv(
276
+ csv_filename=csv_filename,
277
+ label=f"{confidence_method}_factual_correctness",
278
+ y=corretness,
279
+ yerr=yerr,
280
+ )
281
+
282
+ print(
283
+ f"Producing factual removal plot for {confidence_method}: {fig_filename}"
284
+ )
285
+ ax = self.plot_factual_removal_rate_by_alpha(
286
+ conf_level, corretness, a, confidence_method, fig_filename, ax
287
+ )
288
+ print(f"Conformal plot saved to {fig_filename}")
289
+
290
+ if plot_group_results:
291
+ # self.plot_factual_group_results(results, csv_filename, x)
292
+ raise NotImplementedError("Not implemented")
293
+
294
+ def compute_factual_results(self, data, alphas, a):
295
+ results = {}
296
+ for confidence_method in self.confidence_method:
297
+ results[confidence_method] = {}
298
+ for alpha in tqdm(
299
+ alphas, desc=f"Computing factual results for {confidence_method}"
300
+ ):
301
+ thresholds = []
302
+ correctness = []
303
+ for _ in range(self.runs):
304
+ random.shuffle(data)
305
+ split_index = len(data) // 2
306
+ calibration_data = data[:split_index]
307
+ test_data = data[split_index:]
308
+
309
+ assert (
310
+ len(calibration_data) != 0
311
+ ), "Calibration data should not be empty"
312
+ assert len(test_data) != 0, "Test data should not be empty"
313
+
314
+ threshold = compute_threshold(alpha, calibration_data, a, confidence_method)
315
+ fraction_correct = self._evaluate_factual_correctness(
316
+ test_data, threshold, a, confidence_method
317
+ )
318
+ thresholds.append(threshold)
319
+ correctness.append(fraction_correct)
320
+
321
+ results[confidence_method][alpha] = {
322
+ "threshold": thresholds,
323
+ "correctness": correctness,
324
+ "factuality": 1 - alpha,
325
+ }
326
+ return results
327
+
328
+ def process_factual_correctness_results(self, results: dict):
329
+ """
330
+ x: confidence level
331
+ y: list of average factual correctness
332
+ yerr: list of standard error of factual correctness
333
+ """
334
+ x, y, yerr = [], [], []
335
+ for alpha, results_for_alpha in results.items():
336
+
337
+ x.append(1 - alpha)
338
+ y.append(np.mean(results_for_alpha["correctness"]))
339
+ yerr.append(
340
+ (
341
+ np.std(results_for_alpha["correctness"])
342
+ * 1.96
343
+ / np.sqrt(len(results_for_alpha["correctness"]))
344
+ )
345
+ )
346
+
347
+ return x, y, yerr
348
+
349
+ def plot_factual_removal_rate_by_alpha(
350
+ self, x, y, a, confidence_method, fig_filename, ax=None
351
+ ):
352
+ if not ax:
353
+ fig, ax = plt.subplots(figsize=(8, 6), dpi=800)
354
+ else:
355
+ fig = ax.figure # Get the figure from the provided ax
356
+
357
+ ax.set_xlabel(f"Target factuality (1 - {chr(945)})", fontsize=16)
358
+ ax.set_ylabel("Empirical factuality", fontsize=16)
359
+ ax.set_title(
360
+ f"Factual correctness for {self.dataset_name} Datasets (a={a})", fontsize=20
361
+ )
362
+
363
+ # Plot the data
364
+ ax.plot(x, y, label=confidence_method, linewidth=2)
365
+
366
+ # Set legend
367
+ ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98), fontsize=10)
368
+
369
+ # Save the figure
370
+ fig.savefig(fig_filename, bbox_inches="tight", dpi=800)
371
+
372
+ return ax # Return the ax for further modifications if needed
373
+
374
+ def _evaluate_factual_correctness(
375
+ self, data: list, threshold: float, a: float, confidence_method: str
376
+ ):
377
+ """
378
+ Evaluates the factual correctness of subclaims within the provided data.
379
+ This function processes a list of data entries, each containing subclaims with similarity scores.
380
+ It calculates the percentage of correctly retained subclaims based on a given threshold and
381
+ compares it to a specified accuracy level `a`.
382
+ Args:
383
+ data (list): A list of dictionaries, where each dictionary represents an entry containing subclaims.
384
+ threshold (float): The similarity score threshold above which subclaims are considered retained.
385
+ a (float): The accuracy level to compare the correctly retained percentage against.
386
+ Returns:
387
+ float: The percentage of entries in the data that satisfy the correct level of accuracy `a`.
388
+ """
389
+
390
+ correctly_retained = []
391
+ # Process each item in the list
392
+ for entry in data:
393
+ # Extract subclaims from each item
394
+ retained_cnt = 0
395
+ correctly_retained_count = 0
396
+ for subclaim in entry["subclaims"]:
397
+
398
+ # Extract the score and noise
399
+ score = subclaim["scores"][confidence_method]
400
+ noise = subclaim["scores"]["noise"]
401
+
402
+ # Add the subclaim to the collection if similarity score is above threshold
403
+ if score + noise >= threshold:
404
+ retained_cnt += 1
405
+ if (
406
+ subclaim.get("annotations", {}).get("gpt", "")
407
+ in CORRECT_ANNOTATIONS
408
+ ):
409
+ correctly_retained_count += 1
410
+
411
+ # Calculate correctly retained rate
412
+ correctly_retained_percentage = (
413
+ correctly_retained_count / retained_cnt if retained_cnt > 0 else 1
414
+ )
415
+ correctly_retained.append(correctly_retained_percentage)
416
+
417
+ correctness_list = [
418
+ correctly_retained_percentage >= a
419
+ for correctly_retained_percentage in correctly_retained
420
+ ]
421
+ # percentage of test data satisfying correct level of a
422
+ return sum(correctness_list) / len(correctness_list)
src/calibration/utils.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import csv
3
+ import numpy as np
4
+ from math import ceil
5
+ from collections import defaultdict
6
+
7
+ CORRECT_ANNOTATIONS = ["Y", "S"]
8
+
9
+
10
+ def load_subclaim_data(file_path):
11
+ """Load calibration data from a JSON file"""
12
+ with open(file_path, "r", encoding="utf-8") as file:
13
+ return json.load(file)
14
+
15
+
16
+ def append_result_to_csv(csv_filename, label, y, yerr):
17
+ """Append calibration results to CSV file"""
18
+ formatted_results = [f"{y:.4f} ± {yerr:.4f}" for y, yerr in zip(y, yerr)]
19
+ formatted_results.reverse()
20
+ row = [label] + formatted_results
21
+ with open(csv_filename, mode="a", newline="") as file:
22
+ writer = csv.writer(file)
23
+ writer.writerow(row)
24
+
25
+
26
+ def _get_accepted_subclaims(entry, threshold, confidence_method):
27
+ """Helper function to get accepted subclaims based on threshold"""
28
+ return [
29
+ subclaim
30
+ for subclaim in entry["subclaims"]
31
+ if subclaim["scores"][confidence_method] + subclaim["scores"]["noise"]
32
+ >= threshold
33
+ ]
34
+
35
+
36
+ def _calculate_entailed_fraction(subclaims):
37
+ """Helper function to calculate fraction of entailed/correct subclaims"""
38
+ if not subclaims:
39
+ return 1.0
40
+ return np.mean(
41
+ [
42
+ subclaim["annotations"]["gpt"] in CORRECT_ANNOTATIONS
43
+ for subclaim in subclaims
44
+ ]
45
+ )
46
+
47
+
48
+ def get_r_score(entry: list, confidence_method: str, a: float):
49
+ """
50
+ Compute the r_a score for each data entry when confidence_method is used as the sub-claim scoring function.
51
+
52
+ This function calculates the minimum threshold at which the fraction of correct subclaims
53
+ falls below the required threshold 'a'. The r_a score represents the confidence score
54
+ at which the model's reliability drops below the acceptable level.
55
+
56
+ The algorithm works by:
57
+ 1. First checking if the score was already calculated and cached
58
+ 2. Sorting all subclaim scores in descending order
59
+ 3. Testing each score as a potential threshold
60
+ 4. For each threshold, accepting only subclaims with scores >= threshold
61
+ 5. Calculating the fraction of correct subclaims among the accepted ones
62
+ 6. Returning the first threshold where this fraction falls below 'a'
63
+ 7. Returning -1 if all possible thresholds maintain accuracy above 'a'
64
+
65
+ Args:
66
+ entry: Dictionary containing claims data
67
+ confidence_method: Method used for scoring subclaims
68
+ a: Required fraction correct threshold
69
+
70
+ Returns:
71
+ float: r_a score for the entry
72
+ """
73
+ r_score_key = f"r_score_{a}_{confidence_method}"
74
+ if r_score_key in entry:
75
+ return entry[r_score_key]
76
+ #add a cache in entry to remember it's r_score
77
+
78
+ scores = [
79
+ subclaim["scores"][confidence_method] + subclaim["scores"]["noise"]
80
+ for subclaim in entry["subclaims"]
81
+ ]
82
+ threshold_set = sorted(scores, reverse=True)
83
+
84
+ for threshold in threshold_set:
85
+ accepted_subclaims = _get_accepted_subclaims(
86
+ entry, threshold, confidence_method
87
+ )
88
+ entailed_fraction = _calculate_entailed_fraction(accepted_subclaims)
89
+
90
+ if entailed_fraction < a:
91
+ entry[r_score_key] = threshold
92
+ return threshold
93
+
94
+ entry[r_score_key] = -1
95
+ return -1
96
+
97
+
98
+ def compute_threshold(alpha, calibration_data, a, confidence_method):
99
+ """
100
+ Computes the quantile/threshold from conformal prediction.
101
+ # alpha: float in (0, 1)
102
+ # calibration_data: calibration data
103
+ # a: as in paper, required fraction correct, section 4.1
104
+ # confidence_method: string
105
+ """
106
+ # Compute r score for each example.
107
+ r_scores = [get_r_score(entry, confidence_method, a) for entry in calibration_data]
108
+
109
+ # Compute threshold for conformal prection. The quantile is ceil((n+1)*(1-alpha))/n, and
110
+ # We map this to the index by dropping the division by n and subtracting one (for zero-index).
111
+ quantile_target_index = min(ceil((len(r_scores) + 1) * (1 - alpha)), len(r_scores))
112
+ threshold = sorted(r_scores)[quantile_target_index - 1]
113
+ return threshold
114
+
115
+
116
+ # Make sure the split calibrate_range ratio are all same not just in overall level but in group level
117
+ # not return data in list but in a map with each group name as key
118
+ def split_group(data, calibrate_range=0.5):
119
+ group_data = defaultdict(list)
120
+ calibration_data = defaultdict(list)
121
+ test_data = []
122
+
123
+ for entry in data:
124
+ group = entry["groups"][0] # Use first group as default
125
+ group_data[group].append(entry)
126
+
127
+ for group, group_entries in group_data.items():
128
+ split_index = ceil(len(group_entries) * calibrate_range)
129
+ calibration_data[group].extend(group_entries[:split_index])
130
+ test_data.extend(group_entries[split_index:])
131
+
132
+ return calibration_data, test_data
133
+
134
+ # Analyze Functions #
135
+
136
+ def percentage_highest_not_S(data, key="relavance"):
137
+ count_total = 0
138
+ count_not_S = 0
139
+
140
+ for item in data:
141
+ subclaims = item.get("subclaims", [])
142
+ if not subclaims:
143
+ continue
144
+
145
+ # Sort subclaims by (score[key] + score[noise]), descending
146
+ subclaims_sorted = sorted(
147
+ subclaims,
148
+ key=lambda sc: sc["scores"].get(key, 0) + sc["scores"].get("noise", 0),
149
+ reverse=True
150
+ )
151
+
152
+ top_annotation = subclaims_sorted[0].get("annotations", {}).get("gpt", None)
153
+
154
+ count_total += 1
155
+ if top_annotation != "S":
156
+ count_not_S += 1
157
+
158
+ if count_total == 0:
159
+ return 0.0 # Avoid division by zero
160
+
161
+ return (count_not_S / count_total) * 100
src/common/__init__.py ADDED
File without changes
src/common/chunker.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Any
3
+
4
+ class BaseChunker(ABC):
5
+ """
6
+ Abstract base class for all chunking strategies.
7
+ """
8
+
9
+ def __init__(self, document, chunk_size: int, overlap_size: int = 0):
10
+ self.document = document
11
+ self.chunk_size = chunk_size
12
+ self.overlap_size = overlap_size
13
+
14
+ @abstractmethod
15
+ def create_chunks(self) -> list[dict[str, Any]]:
16
+ """
17
+ Abstract method to be implemented by subclasses for chunking text.
18
+ """
19
+ pass
20
+
21
+ class FixedLengthChunker(BaseChunker):
22
+ """
23
+ Chunker that splits text into overlapping fixed-size chunks of words.
24
+ """
25
+
26
+ def create_chunks(self) -> list[str]:
27
+
28
+ chunks: list[str] = []
29
+
30
+ text = self.document
31
+ words = text.split()
32
+ start = 0
33
+ chunk_num = 0
34
+
35
+ while start < len(words):
36
+ end = start + self.chunk_size
37
+ chunks.append(" ".join(words[start:end]))
38
+ start += self.chunk_size - self.overlap_size
39
+ chunk_num += 1
40
+
41
+ return chunks, len(words)
src/common/config_manager.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import logging
4
+ import datetime
5
+ import json
6
+ import shutil
7
+ from pathlib import Path
8
+
9
+ class ConfigManager:
10
+ """Utility class to manage configuration loading, saving and logging"""
11
+
12
+ def __init__(self, config_path=None, path_config_path=None, dataset_config_path=None, run_id=None):
13
+ """
14
+ Initialize the ConfigManager with a config file path
15
+
16
+ Args:
17
+ config_path (str): Path to the YAML config file
18
+ run_id (str): Optional identifier for the run
19
+ """
20
+ self.config = {}
21
+ self.run_id = run_id or datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
22
+ self.log_dir = f"logs/{self.run_id}"
23
+
24
+ if config_path:
25
+ self.config = self.load_config(config_path)
26
+ if path_config_path:
27
+ self.path_config = self.load_config(path_config_path)
28
+ if dataset_config_path:
29
+ self.dataset_config = self.load_config(dataset_config_path)
30
+
31
+ def load_config(self, config_path):
32
+ """
33
+ Load configuration from a YAML file
34
+
35
+ Args:
36
+ config_path (str): Path to the YAML config file
37
+
38
+ Returns:
39
+ dict: The loaded configuration
40
+ """
41
+ with open(config_path, 'r') as f:
42
+ return self._expand_env_vars(yaml.safe_load(f))
43
+
44
+ def _expand_env_vars(self, obj):
45
+ """Recursively expand environment variables in string values."""
46
+ if isinstance(obj, dict):
47
+ return {k: self._expand_env_vars(v) for k, v in obj.items()}
48
+ if isinstance(obj, list):
49
+ return [self._expand_env_vars(v) for v in obj]
50
+ if isinstance(obj, str):
51
+ return os.path.expandvars(obj)
52
+ return obj
53
+
54
+ def save_config(self, output_path=None):
55
+ """
56
+ Save the current configuration to a YAML file
57
+
58
+ Args:
59
+ output_path (str): Path to save the config file, defaults to log directory
60
+
61
+ Returns:
62
+ str: Path to the saved config file
63
+ """
64
+ if output_path is None:
65
+ os.makedirs(self.log_dir, exist_ok=True)
66
+ output_path = os.path.join(self.log_dir, f"config_{self.run_id}.yaml")
67
+
68
+ # Ensure directory exists
69
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
70
+
71
+ with open(output_path, 'w') as f:
72
+ yaml.dump(self.config, f, default_flow_style=False)
73
+
74
+ return output_path
75
+
76
+ def setup_logging(self, log_level=logging.INFO):
77
+ """
78
+ Setup logging configuration
79
+
80
+ Args:
81
+ log_level: Logging level
82
+
83
+ Returns:
84
+ str: Path to the log file
85
+ """
86
+ os.makedirs(self.log_dir, exist_ok=True)
87
+ log_file = os.path.join(self.log_dir, f"run_{self.run_id}.log")
88
+
89
+ logging.basicConfig(
90
+ level=log_level,
91
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
92
+ handlers=[
93
+ logging.FileHandler(log_file),
94
+ logging.StreamHandler()
95
+ ]
96
+ )
97
+
98
+ # Disable httpx logs
99
+ logging.getLogger("httpx").setLevel(logging.WARNING)
100
+
101
+ # Log some initial information
102
+ logging.info(f"Starting run with ID: {self.run_id}")
103
+ logging.info(f"Log file: {log_file}")
104
+
105
+ return log_file, self.run_id
106
+
107
+ def log_config(self):
108
+ """Log the important parts of the configuration"""
109
+ if not self.config:
110
+ logging.warning("No configuration loaded to log")
111
+ return
112
+
113
+ logging.info("=== Run Configuration ===")
114
+
115
+ # Log dataset info
116
+ if 'dataset' in self.config:
117
+ logging.info(f"Dataset: {self.config['dataset']['name']}")
118
+ logging.info(f"Query size: {self.config['dataset']['query_size']}")
119
+
120
+ # Log index info
121
+ if 'index' in self.config:
122
+ logging.info(f"Embedding model: {self.config['index']['embedding_model']}")
123
+ logging.info(f"Delete existing index: {self.config['index']['delete_existing']}")
124
+
125
+ logging.info("========================")
126
+
127
+ def update_config(self, updates):
128
+ """
129
+ Update the configuration with new values
130
+
131
+ Args:
132
+ updates (dict): Dictionary containing updates to apply
133
+
134
+ Returns:
135
+ dict: The updated configuration
136
+ """
137
+ # This is a simple implementation that only handles top-level keys
138
+ for key, value in updates.items():
139
+ if isinstance(value, dict) and key in self.config and isinstance(self.config[key], dict):
140
+ self.config[key].update(value)
141
+ else:
142
+ self.config[key] = value
143
+
144
+ return self.config
145
+
146
+ def copy_run_artifacts(self, result_dir):
147
+ """
148
+ Copy config and logs to a results directory for reproducibility
149
+
150
+ Args:
151
+ result_dir (str): Path to the results directory
152
+
153
+ Returns:
154
+ str: Path to the result run directory
155
+ """
156
+ result_run_dir = os.path.join(result_dir, "config")
157
+ os.makedirs(result_run_dir, exist_ok=True)
158
+
159
+ # Get the latest config and log files
160
+ config_files = sorted(Path(self.log_dir).glob("config_*.yaml"))
161
+ # log_files = sorted(Path(self.log_dir).glob("run_*.log"))
162
+
163
+ if config_files:
164
+ latest_config = str(config_files[-1])
165
+ shutil.copy2(latest_config, os.path.join(result_run_dir, "config.yaml"))
166
+
167
+ # if log_files:
168
+ # latest_log = str(log_files[-1])
169
+ # shutil.copy2(latest_log, os.path.join(result_run_dir, "run.log"))
170
+
171
+ return result_run_dir
src/common/faiss_manager.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import re
4
+ import ast
5
+ import faiss
6
+ from typing import Union, Optional
7
+ from dotenv import load_dotenv
8
+ import numpy as np
9
+ from sklearn.preprocessing import normalize
10
+ from src.common.file_manager import FileManager
11
+ from src.common.llm.openai_manager import OpenAIManager
12
+
13
+
14
+ class FAISSIndexManager:
15
+ def __init__(
16
+ self,
17
+ index_truncation_config,
18
+ dimension=3072,
19
+ index_path="index_store/index.faiss",
20
+ indice2fm_path="index_store/indice2fm.json",
21
+ ):
22
+
23
+ dotenv_path = os.path.join(os.getcwd(), ".env")
24
+ load_dotenv(dotenv_path)
25
+ self.openaiManager = OpenAIManager()
26
+ self.dimension = dimension
27
+ self.index = faiss.IndexFlatIP(dimension)
28
+ self.file_managers = []
29
+ self.indice2fm = (
30
+ {}
31
+ ) # Mapping from file texts tracking from file_path to faiss index indices, guarantee indice in asc order
32
+ self.index_path = index_path
33
+ self.indice2fm_path = indice2fm_path
34
+
35
+ # initialize index and indice2fm from saved files
36
+ if os.path.exists(index_path):
37
+ self.index = faiss.read_index(index_path)
38
+ print(f"Loaded FAISS index from {index_path}")
39
+
40
+ if os.path.exists(indice2fm_path):
41
+ with open(indice2fm_path, "r") as file:
42
+ self.indice2fm = json.load(file)
43
+ for file_path, _ in self.indice2fm.items():
44
+ self.file_managers.append(
45
+ FileManager(
46
+ file_path=file_path,
47
+ index_truncation_config=index_truncation_config,
48
+ )
49
+ )
50
+
51
+ def is_indice_align(self):
52
+ last_index_id = self.index.ntotal - 1
53
+ return last_index_id == max(max(values) for values in self.indice2fm.values())
54
+
55
+ def save_index(self, index_path, indice2fm_path):
56
+ if self.index:
57
+ os.makedirs(os.path.dirname(index_path), exist_ok=True)
58
+ faiss.write_index(self.index, index_path)
59
+ # also save file_path to indice mapping, self.indice2fm should be updated before calling this function
60
+ with open(indice2fm_path, mode="w") as file:
61
+ json.dump(self.indice2fm, file, indent=4)
62
+
63
+ def delete_index(self):
64
+ self.index.reset()
65
+ self.indice2fm = {}
66
+ if os.path.exists(self.index_path):
67
+ os.remove(self.index_path)
68
+ if os.path.exists(self.indice2fm_path):
69
+ os.remove(self.indice2fm_path)
70
+ print("FAISS index deleted.")
71
+
72
+ def upsert_file_to_faiss(
73
+ self,
74
+ file_manager,
75
+ model="text-embedding-3-large",
76
+ truncation_strategy: Optional[Union[str, bool]] = "fixed_length",
77
+ truncate_by: Optional[str] = "\n",
78
+ ):
79
+ if not file_manager.file_path in [
80
+ file_manager.file_path for file_manager in self.file_managers
81
+ ]:
82
+ self.file_managers.append(file_manager)
83
+ else:
84
+ print(f"File '{file_manager.file_path}' already exists in the FAISS index.")
85
+ return
86
+
87
+ # Process the file if necessary
88
+ # TODO: check if file_manager.texts will in any case be empty, if not, remove the below block
89
+ if not file_manager.texts:
90
+ print("Processing documents...")
91
+ file_manager.process_document(
92
+ truncation_strategy=truncation_strategy, truncate_by=truncate_by
93
+ )
94
+ print("Documents processing done.")
95
+
96
+ # Generate embeddings and append to index if not already present
97
+ if not file_manager.file_path in self.indice2fm:
98
+ print("Creating embedding for the document...")
99
+ embeddings = self.openaiManager.create_openai_embeddings(
100
+ file_manager.texts, model=model
101
+ )
102
+
103
+ # Normalize embeddings
104
+ embeddings_np = self.normalize_embeddings(embeddings)
105
+ start_index = self.index.ntotal
106
+ # Add embeddings to FAISS index
107
+ self.index.add(embeddings_np)
108
+ end_index = self.index.ntotal
109
+ added_indices = list(range(start_index, end_index))
110
+
111
+ # Update the self.indice2fm dictionary
112
+ self.indice2fm[file_manager.file_path] = added_indices
113
+ self.save_index(
114
+ index_path=self.index_path, indice2fm_path=self.indice2fm_path
115
+ )
116
+ print(
117
+ f"Embeddings from file '{file_manager.file_path}' added to FAISS index between indice {start_index} to {end_index}."
118
+ )
119
+ else:
120
+ print(f"File '{file_manager.file_path}' already exists in the FAISS index.")
121
+
122
+ def normalize_embeddings(self, embeddings):
123
+ if np.isnan(embeddings).any() or np.isinf(embeddings).any():
124
+ raise ValueError("Embeddings contain NaNs or Infs.")
125
+ embeddings_np = np.array(embeddings).astype("float32")
126
+ #faiss normalize give error zsh: segmentation fault python faiss manager at some edge case in hotpotqa
127
+ #faiss.normalize_L2(embeddings_np)
128
+ embeddings_normalized = normalize(embeddings_np, norm='l2', axis=1)
129
+ return embeddings_normalized
130
+
131
+ def search_faiss_index(
132
+ self,
133
+ query,
134
+ top_k=10,
135
+ threshold=0.5,
136
+ truncation_strategy: Optional[Union[str, bool]] = "fixed_length",
137
+ truncate_by: Optional[str] = "\n",
138
+ ):
139
+ if self.index.ntotal == 0:
140
+ return []
141
+
142
+ # Create a normalized embedding for the query
143
+ query_embedding = self.normalize_embeddings(
144
+ [
145
+ self.openaiManager.client.embeddings.create(
146
+ input=[query], model="text-embedding-3-large"
147
+ )
148
+ .data[0]
149
+ .embedding
150
+ ]
151
+ )[0].reshape(1, -1)
152
+
153
+ # Perform the search
154
+ similarity, indices = self.index.search(query_embedding, top_k)
155
+ filtered_results = [
156
+ (idx, similar)
157
+ for idx, similar in zip(indices[0], similarity[0])
158
+ if similar >= threshold
159
+ ]
160
+ results = []
161
+
162
+ # Reverse map indices to file paths and text
163
+ for idx, dist in filtered_results:
164
+ file_path_found = None
165
+ relative_idx = None
166
+
167
+ # Find the file_path and relative index using self.indice2fm
168
+ for file_path, indice_list in self.indice2fm.items():
169
+ if idx in indice_list:
170
+ file_path_found = file_path
171
+ relative_idx = indice_list.index(idx)
172
+ break
173
+
174
+ if file_path_found is not None and relative_idx is not None:
175
+ # Find the corresponding file_manager
176
+ file_manager = next(
177
+ (
178
+ fm
179
+ for fm in self.file_managers
180
+ if fm.file_path == file_path_found
181
+ ),
182
+ None,
183
+ )
184
+
185
+ if file_manager:
186
+ # Process the file if necessary
187
+ file_manager.process_document(
188
+ truncation_strategy=truncation_strategy, truncate_by=truncate_by
189
+ )
190
+ try:
191
+ # Get the text from the file_manager
192
+ text = file_manager.texts[relative_idx][
193
+ 1
194
+ ] # Assuming (index, text) tuples in file_manager.texts
195
+ results.append(
196
+ f"{text} indice={idx} fileposition={relative_idx} score={dist:.4f}"
197
+ # TODO reformat this
198
+ # {
199
+ # "text": text,
200
+ # "indice": idx,
201
+ # "fileposition": relative_idx,
202
+ # "score": round(dist, 4),
203
+ # }
204
+ )
205
+ except:
206
+ print(
207
+ f"Error while retriving id={relative_idx} from file manager. Skipping over id={relative_idx}."
208
+ )
209
+
210
+ else:
211
+ results.append(
212
+ f"File manager not found for '{file_path_found}' score={dist:.4f}"
213
+ )
214
+ else:
215
+ # TODO reformat this
216
+ results.append(f"Index not mapped, score={dist:.4f}")
217
+
218
+ return results
219
+
220
+ def parse_result(self, result):
221
+ """
222
+ Parse the result from the search and return the page content, metadata, indice, and score.
223
+ """
224
+ # Parse the input
225
+ parsed_item = None
226
+ pattern = re.compile(
227
+ r"page_content='(.*?)'\smetadata=(\{.*?\})\sindice=(\d+)\sfileposition=(\d+)\sscore=([\d.]+)",
228
+ re.DOTALL,
229
+ )
230
+ matches = pattern.findall(result)
231
+ # assume only 1 row with matched pattern will be feed in each time, only remain last item
232
+ for match in matches:
233
+ page_content, metadata, indice, fileposition, score = match
234
+ # Convert metadata string to a dictionary
235
+ metadata_dict = ast.literal_eval(metadata)
236
+ parsed_item = {
237
+ "page_content": page_content.strip(),
238
+ "metadata": metadata_dict,
239
+ "indice": int(indice),
240
+ "fileposition": int(fileposition),
241
+ "score": float(score),
242
+ }
243
+ return parsed_item
244
+
245
+ def generate_response_from_context(self, query, retrieved_docs, model="gpt-4o"):
246
+ if not retrieved_docs:
247
+ return "No relevant documents found in the FAISS index."
248
+
249
+ # Process retrieved documents into a clean context
250
+ formatted_docs = []
251
+ for doc in retrieved_docs:
252
+ try:
253
+ # Split the document string into page_content and metadata
254
+ doc_parts = doc.split("metadata=")
255
+ page_content = doc_parts[0].replace("page_content=", "").strip()
256
+ metadata = (
257
+ doc_parts[1].strip() if len(doc_parts) > 1 else "Unknown source"
258
+ )
259
+
260
+ # Format each document clearly
261
+ formatted_doc = f"Content: {page_content}\nSource: {metadata}"
262
+ formatted_docs.append(formatted_doc)
263
+ except Exception as e:
264
+ formatted_docs.append(f"Error processing document: {e}")
265
+
266
+ # Combine the formatted documents into a single context
267
+ context = "\n\n---\n\n".join(formatted_docs)
268
+
269
+ # Construct the prompt for the OpenAI API
270
+ messages = [
271
+ {
272
+ "role": "system",
273
+ "content": "You are a helpful assistant that answers questions based on provided context.",
274
+ },
275
+ {"role": "user", "content": query},
276
+ {
277
+ "role": "assistant",
278
+ "content": f"The following context was retrieved from the database:\n\n{context}",
279
+ },
280
+ ]
281
+
282
+ # Generate response using OpenAI Chat API
283
+ response = self.openaiManager.client.chat.completions.create(
284
+ model=model, messages=messages, max_tokens=4096, temperature=0.7
285
+ )
286
+ return response.choices[0].message.content
287
+
288
+
289
+ def main():
290
+ # Example Usage
291
+ file_path1 = os.path.join(os.getcwd(), "documents", "2024_Corrective_RAGv2.pdf")
292
+ file_manager1 = FileManager(file_path1)
293
+ manager = FAISSIndexManager(dimension=3072)
294
+ manager.upsert_file_to_faiss(file_manager1)
295
+
296
+ file_path2 = os.path.join(os.getcwd(), "documents", "2023_Iterative_RGen.pdf")
297
+ file_manager2 = FileManager(file_path2)
298
+ manager.upsert_file_to_faiss(file_manager2)
299
+
300
+ query = "tell me about corrective rag system."
301
+ retrieved_docs = manager.search_faiss_index(query, top_k=10, threshold=0.1)
302
+ print(retrieved_docs)
303
+ response = manager.generate_response_from_context(query, retrieved_docs)
304
+ print(response)
305
+
306
+
307
+ if __name__ == "__main__":
308
+ print("Running faiss_manager.py")
309
+ main()
src/common/file_manager.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from typing import Union, Optional
4
+ from PyPDF2 import PdfReader
5
+ from langchain_core.documents import Document
6
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
7
+ from src.common.chunker import FixedLengthChunker
8
+
9
+
10
+ class FileManager:
11
+ def __init__(self, file_path: str, index_truncation_config: dict):
12
+ self.file_path = file_path
13
+ self.chunk_size = index_truncation_config["chunk_size"]
14
+ self.chunk_overlap = index_truncation_config["chunk_overlap"]
15
+ self.texts = []
16
+ directory = os.path.dirname(file_path)
17
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
18
+ self.texts_file = os.path.join(directory, f"{base_name}_texts.json")
19
+ self.text_splitter = RecursiveCharacterTextSplitter(
20
+ chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
21
+ ) # TODO
22
+
23
+ # Load texts from file if it exists
24
+ if os.path.exists(self.texts_file):
25
+ with open(self.texts_file, "r", encoding="utf-8-sig") as f:
26
+ self.texts = json.load(f)
27
+ print(f"Loaded texts from file: {self.texts_file}")
28
+
29
+ def load_pdf_document(self):
30
+ pdf_reader = PdfReader(self.file_path)
31
+ documents = []
32
+
33
+ for page_num, page in enumerate(pdf_reader.pages):
34
+ page_text = page.extract_text()
35
+ if page_text: # Ensure the page has text
36
+ document = Document(
37
+ metadata={"source": self.file_path, "page": page_num},
38
+ page_content=page_text,
39
+ )
40
+ documents.append(document)
41
+
42
+ return documents
43
+
44
+ def dump_documents(self, texts):
45
+ if texts and not os.path.exists(self.texts_file):
46
+ with open(self.texts_file, "w") as f:
47
+ json.dump(texts, f)
48
+ print(f"Associated texts saved to file: {self.texts_file}")
49
+ else:
50
+ raise FileExistsError(
51
+ f"File {self.texts_file} already exists. Please remove it before saving."
52
+ )
53
+
54
+ def process_pdf(self):
55
+ data = self.load_pdf_document()
56
+
57
+ documents = self.text_splitter.split_documents(data)
58
+ self.texts = [(i, str(doc)) for i, doc in enumerate(documents)]
59
+ self.dump_documents(self.texts)
60
+
61
+ def process_document(
62
+ self,
63
+ truncation_strategy: Optional[Union[str, bool]] = "fixed_length",
64
+ chunk_size: int = 2000,
65
+ overlap_size: int = 25,
66
+ truncate_by: Optional[str] = "\n",
67
+ ):
68
+ """
69
+ Process document according to the specified strategy.
70
+ Either truncation_strategy or truncate_by must be provided, but not both.
71
+ """
72
+ if truncation_strategy is None and truncate_by is None:
73
+ raise ValueError(
74
+ "Either truncation_strategy or truncate_by must be provided"
75
+ )
76
+
77
+ if self.texts:
78
+ return
79
+
80
+ chunks = []
81
+
82
+ with open(self.file_path, "r", encoding="utf-8") as f:
83
+ data = json.load(f)
84
+
85
+ for title, texts in data.items():
86
+ if not truncation_strategy and not truncate_by:
87
+ chunks.append(self.create_document(title, texts, self.file_path))
88
+ print(f"{title} - No text splitting. Chunk size: {len(texts)}")
89
+ elif truncation_strategy == "fixed_length":
90
+ chunk_list = []
91
+ for text in texts:
92
+ fixed_length_chunks, texts_word_cnt = FixedLengthChunker(
93
+ text, chunk_size, overlap_size
94
+ ).create_chunks()
95
+ chunk_list.extend(fixed_length_chunks)
96
+ print(
97
+ f"Document '{title}' is splitted into {len(chunk_list)} chunk(s) by length of {chunk_size} words. Initial text size: {texts_word_cnt}."
98
+ )
99
+ for text in chunk_list:
100
+ if text.strip():
101
+ chunks.append(self.create_document(title, text, self.file_path))
102
+ elif truncation_strategy == "recursive": # Fixed typo in strategy name
103
+ raise NotImplementedError(
104
+ "Recursive truncation is currently not supported"
105
+ )
106
+ else:
107
+ # print("splitting by specific char")
108
+ if isinstance(texts, str):
109
+ if truncate_by in texts:
110
+ split_texts = texts.split(truncate_by)
111
+ else:
112
+ split_texts = [texts]
113
+ elif isinstance(texts, list):
114
+ split_texts = texts
115
+
116
+ for text in split_texts:
117
+ if text.strip():
118
+ chunks.append(self.create_document(title, text, self.file_path))
119
+
120
+ self.texts = [(i, str(doc)) for i, doc in enumerate(chunks)]
121
+ self.dump_documents(self.texts)
122
+
123
+ def create_document(self, title, text, file_path):
124
+ """Create a document with the given title and text."""
125
+ return Document(
126
+ page_content=f"{title}: {text}",
127
+ metadata={"source": title, "file_path": file_path},
128
+ )
src/common/llm/__init__.py ADDED
File without changes
src/common/llm/llm_agent.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+ class LLMAgent(ABC):
4
+ @abstractmethod
5
+ def answer(self, question) -> str:
6
+ pass
7
+
8
+ @abstractmethod
9
+ def preProcess(self, query):
10
+ pass
11
+
12
+ @abstractmethod
13
+ def postProcess(self, response):
14
+ pass