Spaces:
Sleeping
Sleeping
Ryoya Awano commited on
Commit ·
19fc84f
0
Parent(s):
deploy: fix MedLFQA Marginal mode sample matching
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .DS_Store +0 -0
- .env.example +16 -0
- .gitignore +11 -0
- CLAUDE.md +57 -0
- README.md +95 -0
- conf/config.yaml +40 -0
- conf/dataset_config.yaml +17 -0
- conf/path_config.yaml +9 -0
- data/.DS_Store +0 -0
- data/processed/FactScore/fact_score_queries.json +552 -0
- data/processed/HotpotQA/hotpot_qa_queries.json +0 -0
- data/processed/PopQA/pop_qa_queries.json +0 -0
- data/processed/base_schema.json +16 -0
- data/processed/document_schema.json +12 -0
- data/processed/wiki_schema.json +26 -0
- demo/__init__.py +0 -0
- demo/app.py +362 -0
- demo/constants.py +23 -0
- demo/data/samples.json +322 -0
- demo/data/thresholds.csv +145 -0
- demo/inference_api.py +302 -0
- demo/precompute.py +311 -0
- docs/context/01_original_architecture.md +404 -0
- docs/context/02_design_patterns_analysis.md +264 -0
- docs/context/03_demo_app_design.md +327 -0
- docs/context/04_environment_setup.md +373 -0
- docs/context/05_medrag_corpus_integration.md +196 -0
- docs/context/06_test_strategy.md +180 -0
- docs/context/07_huggingface_spaces_deployment.md +259 -0
- main.py +376 -0
- requirements-dev.txt +13 -0
- requirements.txt +13 -0
- scripts/build_sif.sh +15 -0
- scripts/run_demo.sh +28 -0
- scripts/run_main.sh +22 -0
- scripts/run_precompute.sh +25 -0
- scripts/run_tests.sh +37 -0
- singularity/response_quality.def +45 -0
- src/calibration/__init__.py +0 -0
- src/calibration/base_calibration.py +19 -0
- src/calibration/conditional_conformal.py +278 -0
- src/calibration/conformal.py +422 -0
- src/calibration/utils.py +161 -0
- src/common/__init__.py +0 -0
- src/common/chunker.py +41 -0
- src/common/config_manager.py +171 -0
- src/common/faiss_manager.py +309 -0
- src/common/file_manager.py +128 -0
- src/common/llm/__init__.py +0 -0
- src/common/llm/llm_agent.py +14 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.env.example
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy this file to .env and fill in the values for your environment.
|
| 2 |
+
# .env is listed in .gitignore and must never be committed.
|
| 3 |
+
|
| 4 |
+
# Absolute path to the data root on the host.
|
| 5 |
+
# Heavy files (FAISS indices, raw data, main.py output) are stored here
|
| 6 |
+
# and bind-mounted into the container.
|
| 7 |
+
DATA_ROOT=/mnt/data/your_username/ResponseQualityAssessment
|
| 8 |
+
|
| 9 |
+
# Directory where the Singularity .sif file is stored.
|
| 10 |
+
SIF_DIR=/mnt/data/your_username/sif
|
| 11 |
+
|
| 12 |
+
# HuggingFace model cache directory.
|
| 13 |
+
HF_HOME=/mnt/data/your_username/hf_cache
|
| 14 |
+
|
| 15 |
+
# OpenAI API key (required for live inference in the demo).
|
| 16 |
+
OPENAI_API_KEY=sk-...
|
.gitignore
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/raw/WikiDB/enwiki-20230401.db
|
| 2 |
+
logs/
|
| 3 |
+
index_store/
|
| 4 |
+
data/
|
| 5 |
+
# demo/data/ is generated by precompute.py but committed to the repo
|
| 6 |
+
!demo/data/
|
| 7 |
+
!demo/data/*
|
| 8 |
+
*.pyc
|
| 9 |
+
__pycache__/
|
| 10 |
+
launch.json
|
| 11 |
+
.env
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 開発ガイドライン
|
| 2 |
+
|
| 3 |
+
## プロジェクトの目的
|
| 4 |
+
|
| 5 |
+
RAG の回答品質評価システム。共形予測(Conformal Prediction)を用いて、LLM 回答のサブクレームごとに統計的保証付きのスコアリングを行う。既存のバッチ処理ロジック(`src/` や `main.py`)を活かしつつ、インタラクティブに操作できるデモアプリ(`demo/`)をアジャイルに開発している。
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## ディレクトリ構成
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
src/
|
| 13 |
+
├── calibration/ # 共形予測キャリブレーション
|
| 14 |
+
├── common/ # 共通コンポーネント(設定・ファイル・LLM 管理)
|
| 15 |
+
├── data_processor/ # データセット処理パイプライン
|
| 16 |
+
├── dataloader/ # HuggingFace データローダー
|
| 17 |
+
├── rag/ # RAG(SQLite ベース文書 DB)
|
| 18 |
+
├── subclaim_processor/ # 回答品質評価の中核パイプライン
|
| 19 |
+
└── utils/ # 汎用ヘルパー関数
|
| 20 |
+
|
| 21 |
+
demo/
|
| 22 |
+
├── app.py # Streamlit エントリーポイント
|
| 23 |
+
├── inference_api.py # 単一クエリ推論ラッパー(src/ を呼び出す)
|
| 24 |
+
└── precompute.py # 閾値の事前計算スクリプト(オフライン実行用)
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## コンテキストと設計
|
| 30 |
+
|
| 31 |
+
詳細は `docs/context/` を参照。
|
| 32 |
+
|
| 33 |
+
| ドキュメント | 内容 |
|
| 34 |
+
|------------|------|
|
| 35 |
+
| `01_original_architecture.md` | アーキテクチャ全体像・データフロー・クラス設計 |
|
| 36 |
+
| `02_design_patterns_analysis.md` | 設計パターンの評価と改善案 |
|
| 37 |
+
| `03_demo_app_design.md` | デモアプリの技術選定・UI 設計・開発ステップ |
|
| 38 |
+
| `04_environment_setup.md` | 環境構築・実行手順(Singularity、`.env` 設定、スクリプト一覧) |
|
| 39 |
+
| `06_test_strategy.md` | テスト計画と実行フロー(ユニットテスト〜UI テスト) |
|
| 40 |
+
|
| 41 |
+
ロジックを変更した際、アーキテクチャに大きな変更が生じた場合は `docs/context/` 内のドキュメントも合わせて更新してください。
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## コーディングルール
|
| 46 |
+
|
| 47 |
+
### src/(コアロジック)
|
| 48 |
+
- `src/` 配下のコアロジックの修正・リファクタリングは許可されています。
|
| 49 |
+
- ただし、既存の `main.py` を使ったバッチ処理が壊れないよう、後方互換性を意識するか、影響範囲を報告してください。
|
| 50 |
+
|
| 51 |
+
### demo/(デモアプリ)
|
| 52 |
+
- デモ特有の UI ロジックやエンドポイントは `demo/` ディレクトリ内に隔離し、`src/` のコアロジックと密結合させないでください。
|
| 53 |
+
- `demo/` から `src/` モジュールを `import` するのみとし、コアロジックは変更しないでください。
|
| 54 |
+
- キャリブレーション閾値($\hat{q}$)は `demo/precompute.py` で事前計算し、JSON/CSV として静的保持してください。デモ実行時にキャリブレーションループを回さないでください。
|
| 55 |
+
|
| 56 |
+
### Git
|
| 57 |
+
- 機能単位でキリが良いところで `git commit` を提案してください。
|
README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Conformal RAG Quality Assessment Demo
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.43.2
|
| 8 |
+
app_file: demo/app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Response Quality Assessment for Retrieval-Augmented Generation via Conditional Conformal Factuality
|
| 13 |
+
|
| 14 |
+
This repository contains production-ready code and resources for our research for paper "Response Quality Assessment for Retrieval-Augmented
|
| 15 |
+
Generation via Conditional Conformal Factuality" Accepted by SIGIR 2025
|
| 16 |
+
|
| 17 |
+
Arxiv ver: [https://www.arxiv.org/pdf/2506.20978](https://www.arxiv.org/abs/2506.20978)
|
| 18 |
+
|
| 19 |
+
Poster: https://drive.google.com/file/d/1k7fSa4k07dPJUCvsjU0hE0e486T5UNoy/view?usp=sharing
|
| 20 |
+
|
| 21 |
+
## Table of Contents
|
| 22 |
+
- [Structure](#Structure)
|
| 23 |
+
- [Data](#data)
|
| 24 |
+
- [Query Data](#query-data)
|
| 25 |
+
- [Wikipedia Extraction](#wikipedia-extraction)
|
| 26 |
+
- [Usage](#usage)
|
| 27 |
+
- [References](#references)
|
| 28 |
+
- [More Information](#more-information)
|
| 29 |
+
|
| 30 |
+
## Structure
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
.
|
| 34 |
+
├── conf/ # Configuration file location
|
| 35 |
+
├── data/
|
| 36 |
+
│ ├── out/ # Final subclaims with scores (follows `subclaims_schema`)
|
| 37 |
+
│ ├── processed/ # Standardized test data (follows `base_schema`)
|
| 38 |
+
│ └── raw/ # Original raw data from source (unstructured)
|
| 39 |
+
├── index_store/ # Chunked documents and embeddings
|
| 40 |
+
├── logs/ # Config and logs in format `run_{data}_{run_id}`
|
| 41 |
+
├── src/
|
| 42 |
+
│ ├── calibration/ # Conformal prediction calibration logic
|
| 43 |
+
│ ├── common/ # Reusable components (e.g., config manager, FAISS vector DB manager)
|
| 44 |
+
│ ├── data_processor/ # Converts raw QA data to standardized format (see `data/processed`)
|
| 45 |
+
│ ├── dataloader/ # Loads data from source datasets (e.g., AkariASAI/PopQA, KILT benchmark)
|
| 46 |
+
│ ├── rag/ # RAG system components for document retrieval
|
| 47 |
+
│ ├── subclaim_processor/# Generates, scores, and annotates subclaims for different datasets
|
| 48 |
+
│ └── utils/ # Miscellaneous utilities
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
## Data
|
| 53 |
+
### Query Data
|
| 54 |
+
This repository includes the following query datasets:
|
| 55 |
+
- [FactScore](https://github.com/shmsw25/FActScore)
|
| 56 |
+
- [PopQA](https://huggingface.co/datasets/akariasai/PopQA)
|
| 57 |
+
- [HotpotQA](https://huggingface.co/datasets/hotpotqa/hotpot_qa)
|
| 58 |
+
- [MedLFQA](https://github.com/dmis-lab/OLAPH/tree/main/MedLFQA) or [MedLFQAv2](https://github.com/jjcherian/conformal-safety/tree/main/data/MedLFQAv2)
|
| 59 |
+
|
| 60 |
+
### Wikipedia Extraction
|
| 61 |
+
We utilize Wikipedia dumps for knowledge retrieval:
|
| 62 |
+
- [enwiki-20230401.db](https://drive.google.com/file/d/1mekls6OGOKLmt7gYtHs0WGf5oTamTNat/view?usp=drive_link)
|
| 63 |
+
This file is not included in this GitHub repository; you can download it from the Google Drive link above (source: https://github.com/shmsw25/FActScore) and put it under
|
| 64 |
+
the `data/raw` folder in order to generate reference documents for the Wikipedia-based queries (PopQA and HotpotQA).
|
| 65 |
+
|
| 66 |
+
## Usage
|
| 67 |
+
The project is built on Python 3.11.
|
| 68 |
+
First, set up project env using [requirements.txt](requirements.txt).
|
| 69 |
+
To run the pipeline:
|
| 70 |
+
```python
|
| 71 |
+
python main.py --config conf/config.yaml --dataset fact_score --query_size 500
|
| 72 |
+
```
|
| 73 |
+
Only one dataset can be processed at a time, in a single thread.
|
| 74 |
+
The currently available datasets are:
|
| 75 |
+
["fact_score", "hotpot_qa", "pop_qa", "medlf_qa"]
|
| 76 |
+
|
| 77 |
+
## Conditional Conformal
|
| 78 |
+
This repo only supports conditional conformal prediction for the medlf_qa dataset. Defaults are configured in `conf/dataset_config.yaml`:
|
| 79 |
+
`medlf_qa.is_grouped` is set to `true`, while all other datasets are set to `false`.
|
| 80 |
+
The factuality results are written to separate CSV files under the `result/${datetime}_${run_id}` folder, one file per pre-defined group name.
|
| 81 |
+
|
| 82 |
+
## Start with metadata
|
| 83 |
+
To avoid any OpenAI token consumption, you can use the pre-produced metadata below to verify only the conformal prediction part. The results will be stable.
|
| 84 |
+
You can get required metadata here: https://drive.google.com/drive/folders/1aLbHxS6V1ipMH8FpVCxKmr8oMYfqmRgb?usp=drive_link
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
## More Information
|
| 88 |
+
For further details, please refer to our Paper (link on top)
|
| 89 |
+
The baseline group conditional conformal (https://arxiv.org/abs/2406.09714) result
|
| 90 |
+
for medlfqav2 is produced by their code: github.com/jjcherian/conformal-safety
|
| 91 |
+
and is not part of this repo.
|
| 92 |
+
|
| 93 |
+
## License
|
| 94 |
+
|
| 95 |
+
This project is licensed under the [MIT License](https://opensource.org/license/mit).
|
conf/config.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset:
|
| 2 |
+
name: "pop_qa" # Options: fact_score, hotpot_qa, pop_qa, medlf_qa
|
| 3 |
+
query_size: 500
|
| 4 |
+
wiki_db_file: "enwiki-20230401.db"
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Index configuration
|
| 8 |
+
index:
|
| 9 |
+
delete_existing: false
|
| 10 |
+
embedding_model: "text-embedding-3-large"
|
| 11 |
+
# Truncation strategies by dataset
|
| 12 |
+
truncation_config: # TODO
|
| 13 |
+
strategy: "fixed_length" # false
|
| 14 |
+
truncate_by: null # "\n"
|
| 15 |
+
chunk_size: 2000
|
| 16 |
+
chunk_overlap: 25
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# RAG configuration
|
| 20 |
+
rag:
|
| 21 |
+
retrival_topk: 10
|
| 22 |
+
retrival_threshold: 0.3
|
| 23 |
+
response_model: "gpt-4.1-mini"
|
| 24 |
+
response_temperature: 0.7
|
| 25 |
+
fact_generation_model: "gpt-4.1-mini"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Prediction configuration
|
| 29 |
+
conformal_prediction:
|
| 30 |
+
aggregation_strategy: "mean"
|
| 31 |
+
scoring_strategy: "product"
|
| 32 |
+
claim_verification_model: "gpt-4.1-mini" # "gpt-4o-mini"
|
| 33 |
+
frequency_score_model: "gpt-4.1-mini"
|
| 34 |
+
split_conformal: true
|
| 35 |
+
conformal_alphas:
|
| 36 |
+
start: 0.05
|
| 37 |
+
end: 0.45
|
| 38 |
+
step: 0.05
|
| 39 |
+
a_value: 1.0
|
| 40 |
+
|
conf/dataset_config.yaml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
datasets:
|
| 2 |
+
fact_score:
|
| 3 |
+
name: "FactScore"
|
| 4 |
+
index_store: "${DATA_ROOT}/index_store/FactScore"
|
| 5 |
+
is_grouped: false
|
| 6 |
+
hotpot_qa:
|
| 7 |
+
name: "HotpotQA"
|
| 8 |
+
index_store: "${DATA_ROOT}/index_store/HotpotQA"
|
| 9 |
+
is_grouped: false
|
| 10 |
+
pop_qa:
|
| 11 |
+
name: "PopQA"
|
| 12 |
+
index_store: "${DATA_ROOT}/index_store/PopQA"
|
| 13 |
+
is_grouped: false
|
| 14 |
+
medlf_qa:
|
| 15 |
+
name: "MedLFQA"
|
| 16 |
+
index_store: "${DATA_ROOT}/index_store/MedLFQA"
|
| 17 |
+
is_grouped: true
|
conf/path_config.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Path configuration
|
| 2 |
+
# Paths are relative to DATA_ROOT (set in .env).
|
| 3 |
+
# The loader must call os.path.expandvars() after yaml.safe_load().
|
| 4 |
+
paths:
|
| 5 |
+
raw_data_dir: "${DATA_ROOT}/data/raw"
|
| 6 |
+
processed_data_dir: "${DATA_ROOT}/data/processed"
|
| 7 |
+
response_dir: "${DATA_ROOT}/data/out"
|
| 8 |
+
wiki_db_dir: "${DATA_ROOT}/data/raw/WikiDB"
|
| 9 |
+
result_dir: "${DATA_ROOT}/data/result"
|
data/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
data/processed/FactScore/fact_score_queries.json
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"input": "What is Lanny Flaherty's occupation?",
|
| 4 |
+
"output": {
|
| 5 |
+
"answer": "",
|
| 6 |
+
"provenance": [
|
| 7 |
+
{
|
| 8 |
+
"title": "Lanny Flaherty"
|
| 9 |
+
}
|
| 10 |
+
]
|
| 11 |
+
}
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"input": "What is Marianne McAndrew's occupation?",
|
| 15 |
+
"output": {
|
| 16 |
+
"answer": "",
|
| 17 |
+
"provenance": [
|
| 18 |
+
{
|
| 19 |
+
"title": "Marianne McAndrew"
|
| 20 |
+
}
|
| 21 |
+
]
|
| 22 |
+
}
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"input": "What is Chaim Malinowitz's occupation?",
|
| 26 |
+
"output": {
|
| 27 |
+
"answer": "",
|
| 28 |
+
"provenance": [
|
| 29 |
+
{
|
| 30 |
+
"title": "Chaim Malinowitz"
|
| 31 |
+
}
|
| 32 |
+
]
|
| 33 |
+
}
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"input": "What is Doug Sheehan's occupation?",
|
| 37 |
+
"output": {
|
| 38 |
+
"answer": "",
|
| 39 |
+
"provenance": [
|
| 40 |
+
{
|
| 41 |
+
"title": "Doug Sheehan"
|
| 42 |
+
}
|
| 43 |
+
]
|
| 44 |
+
}
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"input": "What is Gerhard Fischer (inventor)'s occupation?",
|
| 48 |
+
"output": {
|
| 49 |
+
"answer": "",
|
| 50 |
+
"provenance": [
|
| 51 |
+
{
|
| 52 |
+
"title": "Gerhard Fischer (inventor)"
|
| 53 |
+
}
|
| 54 |
+
]
|
| 55 |
+
}
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"input": "What is Chief Jones's occupation?",
|
| 59 |
+
"output": {
|
| 60 |
+
"answer": "",
|
| 61 |
+
"provenance": [
|
| 62 |
+
{
|
| 63 |
+
"title": "Chief Jones"
|
| 64 |
+
}
|
| 65 |
+
]
|
| 66 |
+
}
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"input": "What is Focus...'s occupation?",
|
| 70 |
+
"output": {
|
| 71 |
+
"answer": "",
|
| 72 |
+
"provenance": [
|
| 73 |
+
{
|
| 74 |
+
"title": "Focus..."
|
| 75 |
+
}
|
| 76 |
+
]
|
| 77 |
+
}
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"input": "What is Joey D. Vieira's occupation?",
|
| 81 |
+
"output": {
|
| 82 |
+
"answer": "",
|
| 83 |
+
"provenance": [
|
| 84 |
+
{
|
| 85 |
+
"title": "Joey D. Vieira"
|
| 86 |
+
}
|
| 87 |
+
]
|
| 88 |
+
}
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"input": "What is Taral Hicks's occupation?",
|
| 92 |
+
"output": {
|
| 93 |
+
"answer": "",
|
| 94 |
+
"provenance": [
|
| 95 |
+
{
|
| 96 |
+
"title": "Taral Hicks"
|
| 97 |
+
}
|
| 98 |
+
]
|
| 99 |
+
}
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"input": "What is Lees Knowles's occupation?",
|
| 103 |
+
"output": {
|
| 104 |
+
"answer": "",
|
| 105 |
+
"provenance": [
|
| 106 |
+
{
|
| 107 |
+
"title": "Lees Knowles"
|
| 108 |
+
}
|
| 109 |
+
]
|
| 110 |
+
}
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"input": "What is Carolina Portesi Peroni's occupation?",
|
| 114 |
+
"output": {
|
| 115 |
+
"answer": "",
|
| 116 |
+
"provenance": [
|
| 117 |
+
{
|
| 118 |
+
"title": "Carolina Portesi Peroni"
|
| 119 |
+
}
|
| 120 |
+
]
|
| 121 |
+
}
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"input": "What is Quintus Sosius Senecio's occupation?",
|
| 125 |
+
"output": {
|
| 126 |
+
"answer": "",
|
| 127 |
+
"provenance": [
|
| 128 |
+
{
|
| 129 |
+
"title": "Quintus Sosius Senecio"
|
| 130 |
+
}
|
| 131 |
+
]
|
| 132 |
+
}
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"input": "What is Joeri Adams's occupation?",
|
| 136 |
+
"output": {
|
| 137 |
+
"answer": "",
|
| 138 |
+
"provenance": [
|
| 139 |
+
{
|
| 140 |
+
"title": "Joeri Adams"
|
| 141 |
+
}
|
| 142 |
+
]
|
| 143 |
+
}
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"input": "What is Dominic King's occupation?",
|
| 147 |
+
"output": {
|
| 148 |
+
"answer": "",
|
| 149 |
+
"provenance": [
|
| 150 |
+
{
|
| 151 |
+
"title": "Dominic King"
|
| 152 |
+
}
|
| 153 |
+
]
|
| 154 |
+
}
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"input": "What is Vincenzo Tusa's occupation?",
|
| 158 |
+
"output": {
|
| 159 |
+
"answer": "",
|
| 160 |
+
"provenance": [
|
| 161 |
+
{
|
| 162 |
+
"title": "Vincenzo Tusa"
|
| 163 |
+
}
|
| 164 |
+
]
|
| 165 |
+
}
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"input": "What is Jean Daull\u00e9's occupation?",
|
| 169 |
+
"output": {
|
| 170 |
+
"answer": "",
|
| 171 |
+
"provenance": [
|
| 172 |
+
{
|
| 173 |
+
"title": "Jean Daull\u00e9"
|
| 174 |
+
}
|
| 175 |
+
]
|
| 176 |
+
}
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"input": "What is Richard Burkewood Welbourn's occupation?",
|
| 180 |
+
"output": {
|
| 181 |
+
"answer": "",
|
| 182 |
+
"provenance": [
|
| 183 |
+
{
|
| 184 |
+
"title": "Richard Burkewood Welbourn"
|
| 185 |
+
}
|
| 186 |
+
]
|
| 187 |
+
}
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"input": "What is Damir Memovi\u0107's occupation?",
|
| 191 |
+
"output": {
|
| 192 |
+
"answer": "",
|
| 193 |
+
"provenance": [
|
| 194 |
+
{
|
| 195 |
+
"title": "Damir Memovi\u0107"
|
| 196 |
+
}
|
| 197 |
+
]
|
| 198 |
+
}
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"input": "What is Mar\u00eda Elena Medina-Mora Icaza's occupation?",
|
| 202 |
+
"output": {
|
| 203 |
+
"answer": "",
|
| 204 |
+
"provenance": [
|
| 205 |
+
{
|
| 206 |
+
"title": "Mar\u00eda Elena Medina-Mora Icaza"
|
| 207 |
+
}
|
| 208 |
+
]
|
| 209 |
+
}
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"input": "What is Hesham Nazih's occupation?",
|
| 213 |
+
"output": {
|
| 214 |
+
"answer": "",
|
| 215 |
+
"provenance": [
|
| 216 |
+
{
|
| 217 |
+
"title": "Hesham Nazih"
|
| 218 |
+
}
|
| 219 |
+
]
|
| 220 |
+
}
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"input": "What is Luis Guillermo Rivera's occupation?",
|
| 224 |
+
"output": {
|
| 225 |
+
"answer": "",
|
| 226 |
+
"provenance": [
|
| 227 |
+
{
|
| 228 |
+
"title": "Luis Guillermo Rivera"
|
| 229 |
+
}
|
| 230 |
+
]
|
| 231 |
+
}
|
| 232 |
+
},
|
| 233 |
+
{
|
| 234 |
+
"input": "What is Desmond Luke's occupation?",
|
| 235 |
+
"output": {
|
| 236 |
+
"answer": "",
|
| 237 |
+
"provenance": [
|
| 238 |
+
{
|
| 239 |
+
"title": "Desmond Luke"
|
| 240 |
+
}
|
| 241 |
+
]
|
| 242 |
+
}
|
| 243 |
+
},
|
| 244 |
+
{
|
| 245 |
+
"input": "What is Miguel D\u00edaz (baseball)'s occupation?",
|
| 246 |
+
"output": {
|
| 247 |
+
"answer": "",
|
| 248 |
+
"provenance": [
|
| 249 |
+
{
|
| 250 |
+
"title": "Miguel D\u00edaz (baseball)"
|
| 251 |
+
}
|
| 252 |
+
]
|
| 253 |
+
}
|
| 254 |
+
},
|
| 255 |
+
{
|
| 256 |
+
"input": "What is Roxana D\u00edaz (athlete)'s occupation?",
|
| 257 |
+
"output": {
|
| 258 |
+
"answer": "",
|
| 259 |
+
"provenance": [
|
| 260 |
+
{
|
| 261 |
+
"title": "Roxana D\u00edaz (athlete)"
|
| 262 |
+
}
|
| 263 |
+
]
|
| 264 |
+
}
|
| 265 |
+
},
|
| 266 |
+
{
|
| 267 |
+
"input": "What is Fernando (footballer, born 1984)'s occupation?",
|
| 268 |
+
"output": {
|
| 269 |
+
"answer": "",
|
| 270 |
+
"provenance": [
|
| 271 |
+
{
|
| 272 |
+
"title": "Fernando (footballer, born 1984)"
|
| 273 |
+
}
|
| 274 |
+
]
|
| 275 |
+
}
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"input": "What is Gerardo Fernandez Fe's occupation?",
|
| 279 |
+
"output": {
|
| 280 |
+
"answer": "",
|
| 281 |
+
"provenance": [
|
| 282 |
+
{
|
| 283 |
+
"title": "Gerardo Fernandez Fe"
|
| 284 |
+
}
|
| 285 |
+
]
|
| 286 |
+
}
|
| 287 |
+
},
|
| 288 |
+
{
|
| 289 |
+
"input": "What is Augusto Mart\u00ednez S\u00e1nchez's occupation?",
|
| 290 |
+
"output": {
|
| 291 |
+
"answer": "",
|
| 292 |
+
"provenance": [
|
| 293 |
+
{
|
| 294 |
+
"title": "Augusto Mart\u00ednez S\u00e1nchez"
|
| 295 |
+
}
|
| 296 |
+
]
|
| 297 |
+
}
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"input": "What is Sharad Kumar (athlete)'s occupation?",
|
| 301 |
+
"output": {
|
| 302 |
+
"answer": "",
|
| 303 |
+
"provenance": [
|
| 304 |
+
{
|
| 305 |
+
"title": "Sharad Kumar (athlete)"
|
| 306 |
+
}
|
| 307 |
+
]
|
| 308 |
+
}
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"input": "What is Ko Itakura's occupation?",
|
| 312 |
+
"output": {
|
| 313 |
+
"answer": "",
|
| 314 |
+
"provenance": [
|
| 315 |
+
{
|
| 316 |
+
"title": "Ko Itakura"
|
| 317 |
+
}
|
| 318 |
+
]
|
| 319 |
+
}
|
| 320 |
+
},
|
| 321 |
+
{
|
| 322 |
+
"input": "What is Besant Ravi's occupation?",
|
| 323 |
+
"output": {
|
| 324 |
+
"answer": "",
|
| 325 |
+
"provenance": [
|
| 326 |
+
{
|
| 327 |
+
"title": "Besant Ravi"
|
| 328 |
+
}
|
| 329 |
+
]
|
| 330 |
+
}
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"input": "What is Thomas Houghton (rugby league)'s occupation?",
|
| 334 |
+
"output": {
|
| 335 |
+
"answer": "",
|
| 336 |
+
"provenance": [
|
| 337 |
+
{
|
| 338 |
+
"title": "Thomas Houghton (rugby league)"
|
| 339 |
+
}
|
| 340 |
+
]
|
| 341 |
+
}
|
| 342 |
+
},
|
| 343 |
+
{
|
| 344 |
+
"input": "What is Katsunosuke Hori's occupation?",
|
| 345 |
+
"output": {
|
| 346 |
+
"answer": "",
|
| 347 |
+
"provenance": [
|
| 348 |
+
{
|
| 349 |
+
"title": "Katsunosuke Hori"
|
| 350 |
+
}
|
| 351 |
+
]
|
| 352 |
+
}
|
| 353 |
+
},
|
| 354 |
+
{
|
| 355 |
+
"input": "What is Ra Jong-yil's occupation?",
|
| 356 |
+
"output": {
|
| 357 |
+
"answer": "",
|
| 358 |
+
"provenance": [
|
| 359 |
+
{
|
| 360 |
+
"title": "Ra Jong-yil"
|
| 361 |
+
}
|
| 362 |
+
]
|
| 363 |
+
}
|
| 364 |
+
},
|
| 365 |
+
{
|
| 366 |
+
"input": "What is Song Kang's occupation?",
|
| 367 |
+
"output": {
|
| 368 |
+
"answer": "",
|
| 369 |
+
"provenance": [
|
| 370 |
+
{
|
| 371 |
+
"title": "Song Kang"
|
| 372 |
+
}
|
| 373 |
+
]
|
| 374 |
+
}
|
| 375 |
+
},
|
| 376 |
+
{
|
| 377 |
+
"input": "What is Rin Iwanaga's occupation?",
|
| 378 |
+
"output": {
|
| 379 |
+
"answer": "",
|
| 380 |
+
"provenance": [
|
| 381 |
+
{
|
| 382 |
+
"title": "Rin Iwanaga"
|
| 383 |
+
}
|
| 384 |
+
]
|
| 385 |
+
}
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"input": "What is Keita Kadokura's occupation?",
|
| 389 |
+
"output": {
|
| 390 |
+
"answer": "",
|
| 391 |
+
"provenance": [
|
| 392 |
+
{
|
| 393 |
+
"title": "Keita Kadokura"
|
| 394 |
+
}
|
| 395 |
+
]
|
| 396 |
+
}
|
| 397 |
+
},
|
| 398 |
+
{
|
| 399 |
+
"input": "What is Lily Branscombe's occupation?",
|
| 400 |
+
"output": {
|
| 401 |
+
"answer": "",
|
| 402 |
+
"provenance": [
|
| 403 |
+
{
|
| 404 |
+
"title": "Lily Branscombe"
|
| 405 |
+
}
|
| 406 |
+
]
|
| 407 |
+
}
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"input": "What is Luis N. Rivera-Pag\u00e1n's occupation?",
|
| 411 |
+
"output": {
|
| 412 |
+
"answer": "",
|
| 413 |
+
"provenance": [
|
| 414 |
+
{
|
| 415 |
+
"title": "Luis N. Rivera-Pag\u00e1n"
|
| 416 |
+
}
|
| 417 |
+
]
|
| 418 |
+
}
|
| 419 |
+
},
|
| 420 |
+
{
|
| 421 |
+
"input": "What is Jesse Foppert's occupation?",
|
| 422 |
+
"output": {
|
| 423 |
+
"answer": "",
|
| 424 |
+
"provenance": [
|
| 425 |
+
{
|
| 426 |
+
"title": "Jesse Foppert"
|
| 427 |
+
}
|
| 428 |
+
]
|
| 429 |
+
}
|
| 430 |
+
},
|
| 431 |
+
{
|
| 432 |
+
"input": "What is Mike Trivisonno's occupation?",
|
| 433 |
+
"output": {
|
| 434 |
+
"answer": "",
|
| 435 |
+
"provenance": [
|
| 436 |
+
{
|
| 437 |
+
"title": "Mike Trivisonno"
|
| 438 |
+
}
|
| 439 |
+
]
|
| 440 |
+
}
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"input": "What is Zelma Wilson's occupation?",
|
| 444 |
+
"output": {
|
| 445 |
+
"answer": "",
|
| 446 |
+
"provenance": [
|
| 447 |
+
{
|
| 448 |
+
"title": "Zelma Wilson"
|
| 449 |
+
}
|
| 450 |
+
]
|
| 451 |
+
}
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"input": "What is Michael Valpy's occupation?",
|
| 455 |
+
"output": {
|
| 456 |
+
"answer": "",
|
| 457 |
+
"provenance": [
|
| 458 |
+
{
|
| 459 |
+
"title": "Michael Valpy"
|
| 460 |
+
}
|
| 461 |
+
]
|
| 462 |
+
}
|
| 463 |
+
},
|
| 464 |
+
{
|
| 465 |
+
"input": "What is Prince Rivers's occupation?",
|
| 466 |
+
"output": {
|
| 467 |
+
"answer": "",
|
| 468 |
+
"provenance": [
|
| 469 |
+
{
|
| 470 |
+
"title": "Prince Rivers"
|
| 471 |
+
}
|
| 472 |
+
]
|
| 473 |
+
}
|
| 474 |
+
},
|
| 475 |
+
{
|
| 476 |
+
"input": "What is John Estes's occupation?",
|
| 477 |
+
"output": {
|
| 478 |
+
"answer": "",
|
| 479 |
+
"provenance": [
|
| 480 |
+
{
|
| 481 |
+
"title": "John Estes"
|
| 482 |
+
}
|
| 483 |
+
]
|
| 484 |
+
}
|
| 485 |
+
},
|
| 486 |
+
{
|
| 487 |
+
"input": "What is Eric Hacker's occupation?",
|
| 488 |
+
"output": {
|
| 489 |
+
"answer": "",
|
| 490 |
+
"provenance": [
|
| 491 |
+
{
|
| 492 |
+
"title": "Eric Hacker"
|
| 493 |
+
}
|
| 494 |
+
]
|
| 495 |
+
}
|
| 496 |
+
},
|
| 497 |
+
{
|
| 498 |
+
"input": "What is Alma Katsu's occupation?",
|
| 499 |
+
"output": {
|
| 500 |
+
"answer": "",
|
| 501 |
+
"provenance": [
|
| 502 |
+
{
|
| 503 |
+
"title": "Alma Katsu"
|
| 504 |
+
}
|
| 505 |
+
]
|
| 506 |
+
}
|
| 507 |
+
},
|
| 508 |
+
{
|
| 509 |
+
"input": "What is Carlos Alfonso's occupation?",
|
| 510 |
+
"output": {
|
| 511 |
+
"answer": "",
|
| 512 |
+
"provenance": [
|
| 513 |
+
{
|
| 514 |
+
"title": "Carlos Alfonso"
|
| 515 |
+
}
|
| 516 |
+
]
|
| 517 |
+
}
|
| 518 |
+
},
|
| 519 |
+
{
|
| 520 |
+
"input": "What is Rennie Fritchie, Baroness Fritchie's occupation?",
|
| 521 |
+
"output": {
|
| 522 |
+
"answer": "",
|
| 523 |
+
"provenance": [
|
| 524 |
+
{
|
| 525 |
+
"title": "Rennie Fritchie, Baroness Fritchie"
|
| 526 |
+
}
|
| 527 |
+
]
|
| 528 |
+
}
|
| 529 |
+
},
|
| 530 |
+
{
|
| 531 |
+
"input": "What is Daniel Charles's occupation?",
|
| 532 |
+
"output": {
|
| 533 |
+
"answer": "",
|
| 534 |
+
"provenance": [
|
| 535 |
+
{
|
| 536 |
+
"title": "Daniel Charles"
|
| 537 |
+
}
|
| 538 |
+
]
|
| 539 |
+
}
|
| 540 |
+
},
|
| 541 |
+
{
|
| 542 |
+
"input": "What is Zamfir Arbore's occupation?",
|
| 543 |
+
"output": {
|
| 544 |
+
"answer": "",
|
| 545 |
+
"provenance": [
|
| 546 |
+
{
|
| 547 |
+
"title": "Zamfir Arbore"
|
| 548 |
+
}
|
| 549 |
+
]
|
| 550 |
+
}
|
| 551 |
+
}
|
| 552 |
+
]
|
data/processed/HotpotQA/hotpot_qa_queries.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/processed/PopQA/pop_qa_queries.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/processed/base_schema.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$schema": "https://json-schema.org/draft-07/schema",
|
| 3 |
+
"type": "object",
|
| 4 |
+
"properties": {
|
| 5 |
+
"input": { "type": "string" },
|
| 6 |
+
"output": {
|
| 7 |
+
"type": "object",
|
| 8 |
+
"properties": {
|
| 9 |
+
"answer": { "type": "string" }
|
| 10 |
+
},
|
| 11 |
+
"required": ["answer"]
|
| 12 |
+
}
|
| 13 |
+
},
|
| 14 |
+
"required": ["input", "output"]
|
| 15 |
+
}
|
| 16 |
+
|
data/processed/document_schema.json
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$schema": "https://json-schema.org/draft-07/schema",
|
| 3 |
+
"type": "object",
|
| 4 |
+
"unevaluatedProperties": {
|
| 5 |
+
"type": "array",
|
| 6 |
+
"minItems": 1,
|
| 7 |
+
"items": {
|
| 8 |
+
"type": "string"
|
| 9 |
+
}
|
| 10 |
+
}
|
| 11 |
+
}
|
| 12 |
+
|
data/processed/wiki_schema.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"$schema": "https://json-schema.org/draft-07/schema",
|
| 3 |
+
"allOf": [
|
| 4 |
+
{ "$ref": "base_schema.json" },
|
| 5 |
+
{
|
| 6 |
+
"properties": {
|
| 7 |
+
"output": {
|
| 8 |
+
"properties": {
|
| 9 |
+
"provenance": {
|
| 10 |
+
"type": "array",
|
| 11 |
+
"items": {
|
| 12 |
+
"type": "object",
|
| 13 |
+
"properties": {
|
| 14 |
+
"wikipedia_id": { "type": "integer" },
|
| 15 |
+
"title": { "type": ["string", "null"] }
|
| 16 |
+
},
|
| 17 |
+
"required": ["title"]
|
| 18 |
+
}
|
| 19 |
+
}
|
| 20 |
+
}
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
|
demo/__init__.py
ADDED
|
File without changes
|
demo/app.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo/app.py
|
| 2 |
+
# Streamlit entry point. Run with: streamlit run demo/app.py
|
| 3 |
+
#
|
| 4 |
+
# Prerequisites:
|
| 5 |
+
# 1. FAISS indices built by main.py
|
| 6 |
+
# 2. demo/data/thresholds.csv and demo/data/samples.json built by precompute.py
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import time
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import streamlit as st
|
| 14 |
+
|
| 15 |
+
# When SPACES_DEMO=1, live inference (FAISS + LLM pipeline) is disabled.
# Only precomputed sample queries are available. Set this in HF Spaces Secrets.
# Read once at import time — changing the env var requires a process restart.
SPACES_DEMO: bool = os.getenv("SPACES_DEMO", "0") == "1"
|
| 18 |
+
|
| 19 |
+
from demo.constants import ALPHA_LEVELS, DEFAULT_SCORING_METHOD
|
| 20 |
+
from demo.inference_api import (
|
| 21 |
+
SubclaimResult,
|
| 22 |
+
FilteredResult,
|
| 23 |
+
build_faiss_manager,
|
| 24 |
+
build_scorer,
|
| 25 |
+
process_query,
|
| 26 |
+
apply_threshold,
|
| 27 |
+
reintegrate_subclaims,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# ── Constants ─────────────────────────────────────────────────────────────────

# Internal dataset keys → human-readable labels shown in the sidebar selectbox.
DATASET_LABELS: dict[str, str] = {
    "fact_score": "FactScore",
    "hotpot_qa": "HotpotQA",
    "pop_qa": "PopQA",
    "medlf_qa": "MedLFQA",
}

# Sub-dataset groups selectable when MedLFQA runs in "conditional" mode.
MEDLF_GROUPS: list[str] = [
    "healthsearch_qa",
    "medication_qa",
    "kqa_silver_wogold",
    "kqa_golden",
    "live_qa",
]

# Slider options in ascending order (95 → highest factuality at the right)
# Derived from ALPHA_LEVELS so the slider and precomputed thresholds stay in sync.
FACTUALITY_OPTIONS: list[int] = sorted(
    int(round((1 - a) * 100)) for a in ALPHA_LEVELS
)

# Precomputed artifacts written by demo/precompute.py; loaded lazily below.
THRESHOLDS_PATH = "demo/data/thresholds.csv"
SAMPLES_PATH = "demo/data/samples.json"
|
| 54 |
+
|
| 55 |
+
# ── Cached resources ──────────────────────────────────────────────────────────
|
| 56 |
+
|
| 57 |
+
@st.cache_resource
def get_faiss_manager(dataset: str):
    """Build the FAISS index manager for *dataset* (cached for the process lifetime)."""
    return build_faiss_manager(dataset)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@st.cache_resource
def get_scorer(dataset: str):
    """Build the subclaim scorer for *dataset* (cached for the process lifetime)."""
    return build_scorer(dataset)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@st.cache_data
def load_thresholds() -> pd.DataFrame:
    """Load the precomputed conformal-threshold table from THRESHOLDS_PATH (cached)."""
    return pd.read_csv(THRESHOLDS_PATH)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@st.cache_data
def load_samples() -> list[dict]:
    """Load the precomputed sample queries from SAMPLES_PATH (cached)."""
    with open(SAMPLES_PATH, encoding="utf-8") as fh:
        payload = json.load(fh)
    return payload["samples"]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
| 79 |
+
|
| 80 |
+
def _lookup_q_hat(
    df: pd.DataFrame, dataset: str, mode: str, group: str, alpha: float
) -> float:
    """Return the conformal threshold q̂ for the given configuration.

    Selects the single row of the precomputed thresholds table matching
    (dataset, mode, DEFAULT_SCORING_METHOD, group, alpha). Alpha is rounded
    to two decimals to match the values precompute.py wrote.

    Raises:
        ValueError: when no matching row exists (precompute.py must be re-run).
    """
    alpha = round(alpha, 2)
    mask = (
        (df["dataset"] == dataset)
        & (df["mode"] == mode)
        & (df["scoring_method"] == DEFAULT_SCORING_METHOD)
        & (df["group"] == group)
        & (df["alpha"] == alpha)
    )
    matches = df[mask]
    if matches.empty:
        raise ValueError(
            f"Threshold not found for "
            f"({dataset}, {mode}, {DEFAULT_SCORING_METHOD}, {group}, α={alpha}). "
            "Re-run precompute.py."
        )
    return float(matches["q_hat"].iloc[0])
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def _render_subclaims(result: SubclaimResult, q_hat: float) -> str:
    """Return HTML rendering every subclaim; filtered-out claims are struck through in red."""
    rendered = []
    for claim in result["subclaims"]:
        if claim["score"] >= q_hat:
            # Kept claim: plain inline-block span.
            style = "padding:2px 5px; display:inline-block;"
        else:
            # Removed claim: red strikethrough, same padding.
            style = (
                "text-decoration:line-through; color:#dc3545; "
                "padding:2px 5px; display:inline-block;"
            )
        rendered.append(f'<span style="{style}">{claim["text"]}</span>')
    return " ".join(rendered)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _extract_page_content(doc: str) -> str:
|
| 118 |
+
"""Extract page_content text from FAISSIndexManager search result string.
|
| 119 |
+
|
| 120 |
+
Strips the source query prefix that is prepended in the format
|
| 121 |
+
"source_text: document_text".
|
| 122 |
+
"""
|
| 123 |
+
m = re.search(r"page_content='(.*?)'\s+metadata=\{", doc, re.DOTALL)
|
| 124 |
+
if not m:
|
| 125 |
+
return doc
|
| 126 |
+
content = m.group(1)
|
| 127 |
+
source_m = re.search(r"'source':\s*'(.*?)'", doc)
|
| 128 |
+
if source_m:
|
| 129 |
+
prefix = source_m.group(1) + ": "
|
| 130 |
+
if content.startswith(prefix):
|
| 131 |
+
content = content[len(prefix):]
|
| 132 |
+
return content
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def _params_changed(dataset: str, mode: str, group: str) -> bool:
    """Return True when (dataset, mode, group) differs from the previously saved tuple."""
    return st.session_state.get("prev_params") != (dataset, mode, group)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def _save_params(dataset: str, mode: str, group: str) -> None:
    """Persist the current (dataset, mode, group) tuple for change detection on rerun."""
    st.session_state.prev_params = (dataset, mode, group)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
# ── App ───────────────────────────────────────────────────────────────────────
|
| 145 |
+
|
| 146 |
+
def main() -> None:
    """Streamlit entry point: sidebar controls, inference trigger, and results panel.

    Heavy work (FAISS retrieval + LLM pipeline) runs only when the 推論実行
    button is clicked; threshold filtering re-runs cheaply on every slider move
    because q̂ is looked up from the precomputed thresholds table.
    """
    st.set_page_config(page_title="Conformal RAG Demo", layout="wide")
    st.title("Conformal RAG — 回答品質評価デモ")

    # Session state defaults
    st.session_state.setdefault("query_processed", False)
    st.session_state.setdefault("result", None)
    st.session_state.setdefault("elapsed_sec", None)
    st.session_state.setdefault("prev_params", None)
    st.session_state.setdefault("integrated_answer", None)
    st.session_state.setdefault("integrated_alpha", None)

    # Load static data (cached after first call)
    try:
        thresholds_df = load_thresholds()
        samples = load_samples()
    except FileNotFoundError as e:
        st.error(
            f"データファイルが見つかりません: {e}\n\n"
            "`python -m demo.precompute` を実行してください。"
        )
        return

    # ── Sidebar ──────────────────────────────────────────────────────────────

    with st.sidebar:
        st.header("設定")

        # Dataset
        dataset: str = st.selectbox(
            "データセット",
            options=list(DATASET_LABELS.keys()),
            format_func=lambda k: DATASET_LABELS[k],
        )
        is_medlf = dataset == "medlf_qa"

        # Mode (MedLFQA only) — the radio stays visible but disabled for
        # other datasets; `mode` is forced to "marginal" in that case.
        mode_raw: str = st.radio(
            "モード",
            options=["marginal", "conditional"],
            format_func=lambda m: (
                "Marginal(全体)" if m == "marginal" else "Conditional(グループ別)"
            ),
            disabled=not is_medlf,
        )
        mode: str = mode_raw if is_medlf else "marginal"

        # Group (MedLFQA + Conditional only)
        is_conditional = is_medlf and mode == "conditional"
        if is_conditional:
            group: str = st.selectbox("グループ", options=MEDLF_GROUPS)
        else:
            group = "default"

        # Clear cached result when key params change
        if _params_changed(dataset, mode, group):
            st.session_state.query_processed = False
            st.session_state.result = None
            st.session_state.elapsed_sec = None
            st.session_state.integrated_answer = None
            st.session_state.integrated_alpha = None
        _save_params(dataset, mode, group)

        st.divider()

        # Query input — sample list is narrowed to the selected group only
        # in conditional mode; marginal mode shows all samples of the dataset.
        dataset_samples = [
            s for s in samples
            if s["dataset"] == dataset
            and (not is_conditional or s.get("group") == group)
        ]
        sample_map: dict[str, dict] = {s["query"]: s for s in dataset_samples}

        query_input: str = st.selectbox(
            "サンプルクエリ", options=list(sample_map.keys())
        )

        st.divider()

        run_btn = st.button("推論実行", use_container_width=True, type="primary")

        # Factuality slider (active only after inference)
        target_factuality_pct: int = st.select_slider(
            "Target Factuality (1−α)",
            options=FACTUALITY_OPTIONS,
            value=90,
            format_func=lambda v: f"{v}%",
            disabled=not st.session_state.query_processed,
        )

    alpha = round((100 - target_factuality_pct) / 100, 2)

    # ── Heavy processing (button click only) ─────────────────────────────────

    if run_btn and query_input:
        st.session_state.integrated_answer = None
        st.session_state.integrated_alpha = None
        # Use precomputed result if available for this exact (mode, group).
        # Marginal mode accepts any precomputed sample of the dataset — the
        # threshold lookup below still uses the UI-selected (mode, group).
        precomputed = sample_map.get(query_input)
        if (
            precomputed is not None
            and (
                mode == "marginal"
                or (precomputed["mode"] == mode and precomputed["group"] == group)
            )
        ):
            st.session_state.result = precomputed
            st.session_state.elapsed_sec = None
            st.session_state.query_processed = True
            st.rerun()
        elif SPACES_DEMO:
            # Live inference is unavailable on HF Spaces (no FAISS/LLM).
            st.error("このデモではサンプルクエリのみ対応しています。")
        else:
            t0 = time.perf_counter()
            with st.spinner("推論中… (30〜60 秒ほどかかります)"):
                faiss_manager = get_faiss_manager(dataset)
                scorer = get_scorer(dataset)
                st.session_state.result = process_query(
                    query=query_input,
                    dataset=dataset,
                    mode=mode,
                    group=group,
                    faiss_manager=faiss_manager,
                    scorer=scorer,
                )
            st.session_state.elapsed_sec = time.perf_counter() - t0
            st.session_state.query_processed = True
            st.rerun()

    # ── Results panel ─────────────────────────────────────────────────────────

    if not st.session_state.query_processed or st.session_state.result is None:
        st.info("サイドバーからクエリを選択して「推論実行」を押してください。")
        return

    result: SubclaimResult = st.session_state.result

    try:
        q_hat = _lookup_q_hat(thresholds_df, dataset, mode, group, alpha)
    except ValueError as e:
        st.error(str(e))
        return

    filtered: FilteredResult = apply_threshold(result, q_hat)

    # Query
    st.subheader("クエリ")
    st.markdown(result["query"])

    st.divider()

    # Original answer (left) | Re-integrated answer (right)
    col_orig, col_integ = st.columns(2)

    with col_orig:
        st.subheader("元の RAG 回答")
        st.markdown(result["rag_answer"])

    with col_integ:
        st.subheader("フィルタ後の回答")

        # A previously generated answer may be stale relative to the slider.
        if st.session_state.integrated_answer is not None:
            if st.session_state.integrated_alpha != alpha:
                st.warning(
                    f"α={st.session_state.integrated_alpha:.2f} 時点の結果です。"
                    "「回答を生成」を再度押して更新してください。"
                )
            st.markdown(st.session_state.integrated_answer)

        if st.button("回答を生成", type="secondary"):
            removed = [sc["text"].strip() for sc in result["subclaims"] if sc["score"] < q_hat]
            with st.spinner("回答を生成中…"):
                st.session_state.integrated_answer = reintegrate_subclaims(
                    rag_answer=result["rag_answer"],
                    removed_subclaims=removed,
                )
            st.session_state.integrated_alpha = alpha
            st.rerun()

    st.divider()

    # Subclaims with keep/remove highlighting
    st.subheader("Conformal-RAG による回答")
    html = _render_subclaims(result, q_hat)
    st.markdown(html, unsafe_allow_html=True)

    st.divider()

    # Retrieved documents
    with st.expander(
        f"検索されたナレッジ({len(result['retrieved_docs'])} 件)",
        expanded=False,
    ):
        idx = 1
        for doc in result["retrieved_docs"]:
            content = _extract_page_content(doc)
            # Skip very short fragments — heuristic cutoff for truncated chunks.
            if len(content) < 15:
                continue
            st.markdown(f"**[{idx}]** {content}")
            idx += 1

    # Metrics — the fifth column (elapsed time) only exists after live inference;
    # precomputed samples leave elapsed_sec as None.
    st.subheader("分析メトリクス")
    cols = st.columns(4 if st.session_state.elapsed_sec is None else 5)
    cols[0].metric("閾値 q̂", f"{q_hat:.4f}")
    cols[1].metric("サブクレーム数", len(result["subclaims"]))
    cols[2].metric("Keep", filtered["keep_count"])
    cols[3].metric(
        "削除率",
        f"{filtered['remove_count'] / max(len(result['subclaims']), 1) * 100:.0f}%",
    )
    if st.session_state.elapsed_sec is not None:
        cols[4].metric("推論時間", f"{st.session_state.elapsed_sec:.1f} 秒")


if __name__ == "__main__":
    main()
|
demo/constants.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo/constants.py — shared constants; precompute.py and app.py must import from here.

# ── Scoring ──────────────────────────────────────────────────────────────────

# To switch scoring methods, change this line and re-run precompute.py.
DEFAULT_SCORING_METHOD: str = "relevance"

# Maps clean scoring method names to the keys used in subclaim["scores"]
# in data files produced by main.py.
# "relavance" is a typo in the original codebase — preserved for compatibility.
SCORING_METHOD_TO_DATA_KEY: dict[str, str] = {
    "relevance": "relavance",
    "cosine_similarity": "query_claim_cosine_similarity",
    "min_log_prob": "min_log_prob",
}

# ── Conformal prediction ─────────────────────────────────────────────────────

# alpha = 1 - target_factuality.
# precompute.py computes thresholds for exactly these values;
# app.py passes this list to st.select_slider.
# → Target Factuality: 95%, 90%, 85%, 80%, 75%, 70%, 65%, 60%
# NOTE: thresholds.csv rows are keyed by these exact two-decimal values —
# keep this list in sync with the precomputed table.
ALPHA_LEVELS: list[float] = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40]
|
demo/data/samples.json
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"generated_at": "2026-03-29T17:28:52.797800+00:00",
|
| 4 |
+
"model": "gpt-4.1-mini",
|
| 5 |
+
"scoring_method": "relevance"
|
| 6 |
+
},
|
| 7 |
+
"samples": [
|
| 8 |
+
{
|
| 9 |
+
"query": "Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?",
|
| 10 |
+
"dataset": "medlf_qa",
|
| 11 |
+
"mode": "conditional",
|
| 12 |
+
"group": "live_qa",
|
| 13 |
+
"scoring_method": "relevance",
|
| 14 |
+
"rag_answer": "Yes, you can stop using the menopausal hormone therapy patch after 4.5 months, but it is generally recommended that attempts to taper or discontinue the medication be made at intervals of 3 to 6 months. It is important to consider your symptoms and discuss with your healthcare provider the best approach for stopping the therapy safely.",
|
| 15 |
+
"retrieved_docs": [
|
| 16 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: For the treatment of vasomotor symptoms, treatment should be initiated with the 6' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=21 fileposition=21 score=0.8116",
|
| 17 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: Attempts to taper or discontinue the medication should be made at 3- to 6-month intervals' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=26 fileposition=26 score=0.7919",
|
| 18 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: Six (6' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=16 fileposition=16 score=0.7711",
|
| 19 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: The dose should be adjusted as necessary to control symptoms' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=24 fileposition=24 score=0.7696",
|
| 20 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 5, 9' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=17 fileposition=17 score=0.7678",
|
| 21 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: This system is applied to the skin once weekly for the treatment of vasomotor symptoms related to menopause.' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=28 fileposition=28 score=0.7627",
|
| 22 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 375, 12' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=18 fileposition=18 score=0.7588",
|
| 23 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 5 cm2 (0' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=22 fileposition=22 score=0.7534",
|
| 24 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: 5, 15, 18' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=19 fileposition=19 score=0.7519",
|
| 25 |
+
"page_content='Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?: Clinical responses (relief of symptoms) at the lowest effective dose should be the guide for establishing administration of the Climara system, especially in women with an intact uterus' metadata={'source': 'Can I stop using the patch (menopausal hormone therapy) after only 4.5 months?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=25 fileposition=25 score=0.7484"
|
| 26 |
+
],
|
| 27 |
+
"subclaims": [
|
| 28 |
+
{
|
| 29 |
+
"id": 0,
|
| 30 |
+
"text": "You can stop using the menopausal hormone therapy patch after 4.5 months",
|
| 31 |
+
"score": 0.5595795678160917
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"id": 1,
|
| 35 |
+
"text": " It is generally recommended to attempt tapering or discontinuing the medication at intervals of 3 to 6 months",
|
| 36 |
+
"score": 0.31861125529509804
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"id": 2,
|
| 40 |
+
"text": " It is important to consider your symptoms when stopping menopausal hormone therapy",
|
| 41 |
+
"score": 0.3732023054183166
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"id": 3,
|
| 45 |
+
"text": " You should discuss with your healthcare provider the best approach for safely stopping the therapy",
|
| 46 |
+
"score": 0.2703604207878165
|
| 47 |
+
}
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
{
|
| 51 |
+
"query": "Would you recommend something a bit stronger than meloxicam?",
|
| 52 |
+
"dataset": "medlf_qa",
|
| 53 |
+
"mode": "conditional",
|
| 54 |
+
"group": "kqa_golden",
|
| 55 |
+
"scoring_method": "relevance",
|
| 56 |
+
"rag_answer": "If you are looking for something stronger than meloxicam, options include combining drugs from different classes or considering opioid medications. Opioids such as tramadol (Conzip) and oxycodone (Oxyado, Oxycontin) are more potent pain relievers but come with significant risks, including addiction, toxicity, and adverse cognitive and respiratory effects. It's important to consult with a physician before changing your treatment regimen to carefully weigh the benefits and risks.",
|
| 57 |
+
"retrieved_docs": [
|
| 58 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: For a stronger pain relief effect, combining two drugs from different classes or medications from the opioid group is possible' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=41 fileposition=41 score=0.7034",
|
| 59 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: NSAIDs are generally more potent but carry a risk of harming the kidneys and other side effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=38 fileposition=38 score=0.6924",
|
| 60 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: In case of acute and severe pain, NSAIDs like Meloxicam are often recommended' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=40 fileposition=40 score=0.6871",
|
| 61 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: Opioids such as Tramadol (Conzip) and Oxycodone (Oxyado, Oxycontin) are the most potent but carry a significant risk of addiction, toxicity, and cognitive and respiratory adverse effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=39 fileposition=39 score=0.6863",
|
| 62 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: It is essential to consult with a physician before changing treatment regimens due to the risk of addiction and potential side effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=42 fileposition=42 score=0.6812",
|
| 63 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: There are generally three main types of drugs used for pain management: acetaminophen (paracetamol), NSAIDs (like Meloxicam), and opioids' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=36 fileposition=36 score=0.6810",
|
| 64 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: NSAIDs like Meloxicam carry a risk of harming the kidneys and other side effects' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=46 fileposition=46 score=0.6799",
|
| 65 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: Opioids such as Tramadol (Conzip) and Oxycodone (Oxyado, Oxycontin) are the most potent pain management drugs' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=47 fileposition=47 score=0.6730",
|
| 66 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: NSAIDs like Meloxicam are generally more potent than acetaminophen' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=45 fileposition=45 score=0.6688",
|
| 67 |
+
"page_content='Would you recommend something a bit stronger than meloxicam?: Acetaminophen (paracetamol) is the least potent but has the safest side effect profile' metadata={'source': 'Would you recommend something a bit stronger than meloxicam?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=37 fileposition=37 score=0.6409"
|
| 68 |
+
],
|
| 69 |
+
"subclaims": [
|
| 70 |
+
{
|
| 71 |
+
"id": 0,
|
| 72 |
+
"text": "If you are looking for something stronger than meloxicam, options include combining drugs from different classes",
|
| 73 |
+
"score": 0.38080484418115046
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"id": 1,
|
| 77 |
+
"text": " Opioids such as tramadol (Conzip) and oxycodone (Oxyado, Oxycontin) are more potent pain relievers",
|
| 78 |
+
"score": 0.2872582910028152
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"id": 2,
|
| 82 |
+
"text": " Opioids come with significant risks including addiction, toxicity, and adverse cognitive effects",
|
| 83 |
+
"score": 0.19291751764168302
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"id": 3,
|
| 87 |
+
"text": " Opioids also have adverse respiratory effects",
|
| 88 |
+
"score": 0.1868863743184725
|
| 89 |
+
},
|
| 90 |
+
{
|
| 91 |
+
"id": 4,
|
| 92 |
+
"text": " It is important to consult with a physician before changing your treatment regimen",
|
| 93 |
+
"score": 0.1655506771407161
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"id": 5,
|
| 97 |
+
"text": " Consulting a physician helps to carefully weigh the benefits and risks of the treatment",
|
| 98 |
+
"score": 0.17235453684696078
|
| 99 |
+
}
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"query": "how soon does losartan afffect blood pressure",
|
| 104 |
+
"dataset": "medlf_qa",
|
| 105 |
+
"mode": "conditional",
|
| 106 |
+
"group": "medication_qa",
|
| 107 |
+
"scoring_method": "relevance",
|
| 108 |
+
"rag_answer": "The clinical effect of losartan on blood pressure typically becomes apparent within two weeks. Peak effects are observed around 6 hours after taking the medication, with these peak effects being moderately larger than the trough (lowest) effects.",
|
| 109 |
+
"retrieved_docs": [
|
| 110 |
+
"page_content='how soon does losartan afffect blood pressure: Hypertension … Peak (6 hour) effects were uniformly, but moderately, larger than trough effects,' metadata={'source': 'how soon does losartan afffect blood pressure', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=51 fileposition=51 score=0.7362",
|
| 111 |
+
"page_content='how soon does losartan afffect blood pressure: The peak effects of losartan were uniformly, but moderately, larger than trough effects, with the trough-to-peak ratio for systolic and diastolic responses between 50-95%.' metadata={'source': 'how soon does losartan afffect blood pressure', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=52 fileposition=52 score=0.7016",
|
| 112 |
+
"page_content='how soon does losartan afffect blood pressure: Clinical effect of losartan typically becomes apparent within two weeks.' metadata={'source': 'how soon does losartan afffect blood pressure', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=53 fileposition=53 score=0.6849"
|
| 113 |
+
],
|
| 114 |
+
"subclaims": [
|
| 115 |
+
{
|
| 116 |
+
"id": 0,
|
| 117 |
+
"text": "The clinical effect of losartan on blood pressure typically becomes apparent within two weeks",
|
| 118 |
+
"score": 0.46106826938943024
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"id": 1,
|
| 122 |
+
"text": " Peak effects of losartan are observed around 6 hours after taking the medication",
|
| 123 |
+
"score": 0.42004030890343347
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"id": 2,
|
| 127 |
+
"text": " The peak effects of losartan are moderately larger than the trough (lowest) effects",
|
| 128 |
+
"score": 0.4264968324332668
|
| 129 |
+
}
|
| 130 |
+
]
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"query": "Are there any signs I should look out for if Lyme develops and I should follow up?",
|
| 134 |
+
"dataset": "medlf_qa",
|
| 135 |
+
"mode": "conditional",
|
| 136 |
+
"group": "kqa_silver_wogold",
|
| 137 |
+
"scoring_method": "relevance",
|
| 138 |
+
"rag_answer": "Yes, there are several signs to watch for if Lyme disease develops after a tick bite, and you should follow up with your healthcare provider if you notice any of these symptoms:\n\n- Early symptoms (within a week to a few weeks after the bite) may include:\n - Fever\n - Fatigue\n - Headache\n - Muscle and joint aches\n - Swollen lymph nodes\n - A characteristic skin rash called erythema migrans, which often looks like a bull’s-eye on the skin\n\n- If left untreated, symptoms can progress to more severe issues such as:\n - Severe fatigue\n - Stiff, aching neck\n - Severe headaches\n - Irregularities in heartbeat\n - Changes in vision\n - Inflammatory arthritis, often affecting the knee\n\nIf you experience any of these signs or symptoms after suspected or known exposure to ticks, it’s important to schedule an appointment with your healthcare provider for evaluation and possible treatment. Regular check-ins can be helpful to monitor for any developments.",
|
| 139 |
+
"retrieved_docs": [
|
| 140 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: Regular check-ins with a healthcare provider are beneficial for managing and observing the potential developments of Lyme disease following suspected or known tick exposure.' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=68 fileposition=68 score=0.7716",
|
| 141 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: If you exhibit any of these signs or symptoms after suspected or known exposure to ticks, schedule an appointment with your healthcare provider' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=66 fileposition=66 score=0.7560",
|
| 142 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: If such an infection develops, early signs of Lyme disease may occur within a week to a few weeks of a tick bite' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=62 fileposition=62 score=0.7273",
|
| 143 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: If left untreated, symptoms can progress to severe fatigue, a stiff, aching neck, and severe headaches' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=64 fileposition=64 score=0.7108",
|
| 144 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: One may experience irregularities in heartbeat, changes in vision, or an inflammatory arthritis often affecting the knee as well' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=65 fileposition=65 score=0.6972",
|
| 145 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: These early symptoms can include: fever, fatigue, headache, muscle and joint aches, swollen lymph nodes, and a characteristic skin rash called erythema migrans, which often looks like a bull’s-eye on the skin' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=63 fileposition=63 score=0.6907",
|
| 146 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: Lyme disease is a bacterial infection that can be transmitted to humans through the bite of infected black-legged ticks' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=61 fileposition=61 score=0.6902",
|
| 147 |
+
"page_content='Are there any signs I should look out for if Lyme develops and I should follow up?: Lyme disease is a bacterial infection that is transmitted to humans through the bite of infected black-legged ticks.' metadata={'source': 'Are there any signs I should look out for if Lyme develops and I should follow up?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=67 fileposition=67 score=0.6879"
|
| 148 |
+
],
|
| 149 |
+
"subclaims": [
|
| 150 |
+
{
|
| 151 |
+
"id": 0,
|
| 152 |
+
"text": "Lyme disease can develop after a tick bite",
|
| 153 |
+
"score": 0.36037872147634314
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"id": 1,
|
| 157 |
+
"text": " Early symptoms of Lyme disease may appear within a week to a few weeks after the tick bite",
|
| 158 |
+
"score": 0.3590103639552498
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"id": 2,
|
| 162 |
+
"text": " Early symptoms of Lyme disease include fever",
|
| 163 |
+
"score": 0.330641627741826
|
| 164 |
+
},
|
| 165 |
+
{
|
| 166 |
+
"id": 3,
|
| 167 |
+
"text": " Early symptoms of Lyme disease include fatigue",
|
| 168 |
+
"score": 0.3326410987371023
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"id": 4,
|
| 172 |
+
"text": " Early symptoms of Lyme disease include headache",
|
| 173 |
+
"score": 0.3409522381113656
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"id": 5,
|
| 177 |
+
"text": " Early symptoms of Lyme disease include muscle and joint aches",
|
| 178 |
+
"score": 0.31654988182829147
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"id": 6,
|
| 182 |
+
"text": " Early symptoms of Lyme disease include swollen lymph nodes",
|
| 183 |
+
"score": 0.3134207503110557
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"id": 7,
|
| 187 |
+
"text": " Early symptoms of Lyme disease include a characteristic skin rash called erythema migrans",
|
| 188 |
+
"score": 0.31260172708098244
|
| 189 |
+
},
|
| 190 |
+
{
|
| 191 |
+
"id": 8,
|
| 192 |
+
"text": " Erythema migrans often looks like a bull’s-eye on the skin",
|
| 193 |
+
"score": 0.2383296001897222
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"id": 9,
|
| 197 |
+
"text": " If Lyme disease is left untreated, symptoms can progress to more severe issues",
|
| 198 |
+
"score": 0.3353740959350072
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"id": 10,
|
| 202 |
+
"text": " Severe issues from untreated Lyme disease include severe fatigue",
|
| 203 |
+
"score": 0.3008999281508631
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"id": 11,
|
| 207 |
+
"text": " Severe issues from untreated Lyme disease include stiff, aching neck",
|
| 208 |
+
"score": 0.32575460132098455
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"id": 12,
|
| 212 |
+
"text": " Severe issues from untreated Lyme disease include severe headaches",
|
| 213 |
+
"score": 0.31177354749866565
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"id": 13,
|
| 217 |
+
"text": " Severe issues from untreated Lyme disease include irregularities in heartbeat",
|
| 218 |
+
"score": 0.31990244889782615
|
| 219 |
+
},
|
| 220 |
+
{
|
| 221 |
+
"id": 14,
|
| 222 |
+
"text": " Severe issues from untreated Lyme disease include changes in vision",
|
| 223 |
+
"score": 0.30950684605158113
|
| 224 |
+
},
|
| 225 |
+
{
|
| 226 |
+
"id": 15,
|
| 227 |
+
"text": " Severe issues from untreated Lyme disease include inflammatory arthritis",
|
| 228 |
+
"score": 0.3049160752333055
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"id": 16,
|
| 232 |
+
"text": " Inflammatory arthritis from Lyme disease often affects the knee",
|
| 233 |
+
"score": 0.2874386307287703
|
| 234 |
+
},
|
| 235 |
+
{
|
| 236 |
+
"id": 17,
|
| 237 |
+
"text": " If you experience signs or symptoms after tick exposure, you should schedule an appointment with your healthcare provider",
|
| 238 |
+
"score": 0.3763904855114536
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"id": 18,
|
| 242 |
+
"text": " Regular check-ins with a healthcare provider can help monitor for developments in Lyme disease",
|
| 243 |
+
"score": 0.3957411848526053
|
| 244 |
+
}
|
| 245 |
+
]
|
| 246 |
+
},
|
| 247 |
+
{
|
| 248 |
+
"query": "How do adults usually get botulism?",
|
| 249 |
+
"dataset": "medlf_qa",
|
| 250 |
+
"mode": "conditional",
|
| 251 |
+
"group": "healthsearch_qa",
|
| 252 |
+
"scoring_method": "relevance",
|
| 253 |
+
"rag_answer": "Adults usually get botulism in one of three main ways: \n\n1. **Foodborne botulism**: This is the most common form in adults and occurs when they eat foods contaminated with the botulinum toxin. This often happens with home-canned, home-preserved, or fermented foods prepared under conditions that encourage bacterial growth.\n\n2. **Wound botulism**: This occurs when the bacteria enter a wound, which is more common in people who inject drugs.\n\n3. **Inhalation botulism**: This is very rare and can happen if the pure toxin is released into the air, such as during laboratory accidents or if used as a biological weapon.\n\nOverall, foodborne botulism is the most frequent way adults contract the illness.",
|
| 254 |
+
"retrieved_docs": [
|
| 255 |
+
"page_content='How do adults usually get botulism?: In adults, botulism can be contracted in a few different ways: through food, wound botulism, and inhalation botulism' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=78 fileposition=78 score=0.6993",
|
| 256 |
+
"page_content='How do adults usually get botulism?: Foodborne botulism is the most common form in adults and occurs when adults eat foods that are contaminated with the botulism toxin' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=79 fileposition=79 score=0.6891",
|
| 257 |
+
"page_content='How do adults usually get botulism?: Wound botulism happens when the bacteria enter a wound, usually in people who inject drugs' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=81 fileposition=81 score=0.6884",
|
| 258 |
+
"page_content='How do adults usually get botulism?: Botulism is a rare but serious illness caused by a toxin that attacks the body's nerves' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=77 fileposition=77 score=0.6836",
|
| 259 |
+
"page_content='How do adults usually get botulism?: This typically happens with home-canned, home-preserved, or fermented foods that have been prepared in conditions that encourage bacterial growth' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=80 fileposition=80 score=0.6832",
|
| 260 |
+
"page_content='How do adults usually get botulism?: In very rare instances, inhalation botulism can occur when the pure toxin is released into the air, such as during laboratory accidents or if used as a biological weapon.' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=83 fileposition=83 score=0.6773",
|
| 261 |
+
"page_content='How do adults usually get botulism?: Inhalation botulism is very rare and can occur when the pure toxin is released into the air as it can in laboratory accidents or if used as a biological weapon' metadata={'source': 'How do adults usually get botulism?', 'file_path': '/home/ryoya.awano/ResponseQualityAssessment/data/processed/MedLFQA/sampled_10_medlf_qa_documents.txt'} indice=82 fileposition=82 score=0.6543"
|
| 262 |
+
],
|
| 263 |
+
"subclaims": [
|
| 264 |
+
{
|
| 265 |
+
"id": 0,
|
| 266 |
+
"text": "Adults usually get botulism in one of three main ways",
|
| 267 |
+
"score": 0.42031683919779217
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"id": 1,
|
| 271 |
+
"text": " Foodborne botulism is the most common form in adults",
|
| 272 |
+
"score": 0.39594970883087005
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"id": 2,
|
| 276 |
+
"text": " Foodborne botulism occurs when adults eat foods contaminated with the botulinum toxin",
|
| 277 |
+
"score": 0.3942247114653798
|
| 278 |
+
},
|
| 279 |
+
{
|
| 280 |
+
"id": 3,
|
| 281 |
+
"text": " Foodborne botulism often happens with home-canned, home-preserved, or fermented foods prepared under conditions that encourage bacterial growth",
|
| 282 |
+
"score": 0.3146562102397841
|
| 283 |
+
},
|
| 284 |
+
{
|
| 285 |
+
"id": 4,
|
| 286 |
+
"text": " Wound botulism occurs when the bacteria enter a wound",
|
| 287 |
+
"score": 0.2937099827086362
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"id": 5,
|
| 291 |
+
"text": " Wound botulism is more common in people who inject drugs",
|
| 292 |
+
"score": 0.29189200145131067
|
| 293 |
+
},
|
| 294 |
+
{
|
| 295 |
+
"id": 6,
|
| 296 |
+
"text": " Inhalation botulism is very rare",
|
| 297 |
+
"score": 0.3344457211496001
|
| 298 |
+
},
|
| 299 |
+
{
|
| 300 |
+
"id": 7,
|
| 301 |
+
"text": " Inhalation botulism can happen if the pure toxin is released into the air",
|
| 302 |
+
"score": 0.3194226107069425
|
| 303 |
+
},
|
| 304 |
+
{
|
| 305 |
+
"id": 8,
|
| 306 |
+
"text": " Inhalation botulism can occur during laboratory accidents",
|
| 307 |
+
"score": 0.3357236089778982
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
"id": 9,
|
| 311 |
+
"text": " Inhalation botulism can occur if used as a biological weapon",
|
| 312 |
+
"score": 0.31917884005382047
|
| 313 |
+
},
|
| 314 |
+
{
|
| 315 |
+
"id": 10,
|
| 316 |
+
"text": " Foodborne botulism is the most frequent way adults contract botulism",
|
| 317 |
+
"score": 0.3974304925095534
|
| 318 |
+
}
|
| 319 |
+
]
|
| 320 |
+
}
|
| 321 |
+
]
|
| 322 |
+
}
|
demo/data/thresholds.csv
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset,mode,scoring_method,group,alpha,q_hat
|
| 2 |
+
medlf_qa,marginal,relevance,default,0.05,0.26235458225908953
|
| 3 |
+
medlf_qa,marginal,relevance,default,0.1,0.26235458225908953
|
| 4 |
+
medlf_qa,marginal,relevance,default,0.15,0.26235458225908953
|
| 5 |
+
medlf_qa,marginal,relevance,default,0.2,0.26235458225908953
|
| 6 |
+
medlf_qa,marginal,relevance,default,0.25,0.26235458225908953
|
| 7 |
+
medlf_qa,marginal,relevance,default,0.3,0.26235458225908953
|
| 8 |
+
medlf_qa,marginal,relevance,default,0.35,-1
|
| 9 |
+
medlf_qa,marginal,relevance,default,0.4,-1
|
| 10 |
+
medlf_qa,conditional,relevance,live_qa,0.05,0.26235458225908953
|
| 11 |
+
medlf_qa,conditional,relevance,live_qa,0.1,0.26235458225908953
|
| 12 |
+
medlf_qa,conditional,relevance,live_qa,0.15,0.26235458225908953
|
| 13 |
+
medlf_qa,conditional,relevance,live_qa,0.2,0.26235458225908953
|
| 14 |
+
medlf_qa,conditional,relevance,live_qa,0.25,0.26235458225908953
|
| 15 |
+
medlf_qa,conditional,relevance,live_qa,0.3,0.26235458225908953
|
| 16 |
+
medlf_qa,conditional,relevance,live_qa,0.35,0.26235458225908953
|
| 17 |
+
medlf_qa,conditional,relevance,live_qa,0.4,0.26235458225908953
|
| 18 |
+
medlf_qa,conditional,relevance,kqa_golden,0.05,-1
|
| 19 |
+
medlf_qa,conditional,relevance,kqa_golden,0.1,-1
|
| 20 |
+
medlf_qa,conditional,relevance,kqa_golden,0.15,-1
|
| 21 |
+
medlf_qa,conditional,relevance,kqa_golden,0.2,-1
|
| 22 |
+
medlf_qa,conditional,relevance,kqa_golden,0.25,-1
|
| 23 |
+
medlf_qa,conditional,relevance,kqa_golden,0.3,-1
|
| 24 |
+
medlf_qa,conditional,relevance,kqa_golden,0.35,-1
|
| 25 |
+
medlf_qa,conditional,relevance,kqa_golden,0.4,-1
|
| 26 |
+
medlf_qa,conditional,relevance,medication_qa,0.05,-1
|
| 27 |
+
medlf_qa,conditional,relevance,medication_qa,0.1,-1
|
| 28 |
+
medlf_qa,conditional,relevance,medication_qa,0.15,-1
|
| 29 |
+
medlf_qa,conditional,relevance,medication_qa,0.2,-1
|
| 30 |
+
medlf_qa,conditional,relevance,medication_qa,0.25,-1
|
| 31 |
+
medlf_qa,conditional,relevance,medication_qa,0.3,-1
|
| 32 |
+
medlf_qa,conditional,relevance,medication_qa,0.35,-1
|
| 33 |
+
medlf_qa,conditional,relevance,medication_qa,0.4,-1
|
| 34 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.05,-1
|
| 35 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.1,-1
|
| 36 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.15,-1
|
| 37 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.2,-1
|
| 38 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.25,-1
|
| 39 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.3,-1
|
| 40 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.35,-1
|
| 41 |
+
medlf_qa,conditional,relevance,kqa_silver_wogold,0.4,-1
|
| 42 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.05,-1
|
| 43 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.1,-1
|
| 44 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.15,-1
|
| 45 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.2,-1
|
| 46 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.25,-1
|
| 47 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.3,-1
|
| 48 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.35,-1
|
| 49 |
+
medlf_qa,conditional,relevance,healthsearch_qa,0.4,-1
|
| 50 |
+
medlf_qa,marginal,cosine_similarity,default,0.05,0.3992308077305861
|
| 51 |
+
medlf_qa,marginal,cosine_similarity,default,0.1,0.3992308077305861
|
| 52 |
+
medlf_qa,marginal,cosine_similarity,default,0.15,0.3992308077305861
|
| 53 |
+
medlf_qa,marginal,cosine_similarity,default,0.2,0.3992308077305861
|
| 54 |
+
medlf_qa,marginal,cosine_similarity,default,0.25,0.3992308077305861
|
| 55 |
+
medlf_qa,marginal,cosine_similarity,default,0.3,0.3992308077305861
|
| 56 |
+
medlf_qa,marginal,cosine_similarity,default,0.35,-1
|
| 57 |
+
medlf_qa,marginal,cosine_similarity,default,0.4,-1
|
| 58 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.05,0.3992308077305861
|
| 59 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.1,0.3992308077305861
|
| 60 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.15,0.3992308077305861
|
| 61 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.2,0.3992308077305861
|
| 62 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.25,0.3992308077305861
|
| 63 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.3,0.3992308077305861
|
| 64 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.35,0.3992308077305861
|
| 65 |
+
medlf_qa,conditional,cosine_similarity,live_qa,0.4,0.3992308077305861
|
| 66 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.05,-1
|
| 67 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.1,-1
|
| 68 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.15,-1
|
| 69 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.2,-1
|
| 70 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.25,-1
|
| 71 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.3,-1
|
| 72 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.35,-1
|
| 73 |
+
medlf_qa,conditional,cosine_similarity,kqa_golden,0.4,-1
|
| 74 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.05,-1
|
| 75 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.1,-1
|
| 76 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.15,-1
|
| 77 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.2,-1
|
| 78 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.25,-1
|
| 79 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.3,-1
|
| 80 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.35,-1
|
| 81 |
+
medlf_qa,conditional,cosine_similarity,medication_qa,0.4,-1
|
| 82 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.05,-1
|
| 83 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.1,-1
|
| 84 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.15,-1
|
| 85 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.2,-1
|
| 86 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.25,-1
|
| 87 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.3,-1
|
| 88 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.35,-1
|
| 89 |
+
medlf_qa,conditional,cosine_similarity,kqa_silver_wogold,0.4,-1
|
| 90 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.05,-1
|
| 91 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.1,-1
|
| 92 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.15,-1
|
| 93 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.2,-1
|
| 94 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.25,-1
|
| 95 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.3,-1
|
| 96 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.35,-1
|
| 97 |
+
medlf_qa,conditional,cosine_similarity,healthsearch_qa,0.4,-1
|
| 98 |
+
medlf_qa,marginal,min_log_prob,default,0.05,1.002567365930604
|
| 99 |
+
medlf_qa,marginal,min_log_prob,default,0.1,1.002567365930604
|
| 100 |
+
medlf_qa,marginal,min_log_prob,default,0.15,1.002567365930604
|
| 101 |
+
medlf_qa,marginal,min_log_prob,default,0.2,1.002567365930604
|
| 102 |
+
medlf_qa,marginal,min_log_prob,default,0.25,1.002567365930604
|
| 103 |
+
medlf_qa,marginal,min_log_prob,default,0.3,1.002567365930604
|
| 104 |
+
medlf_qa,marginal,min_log_prob,default,0.35,-1
|
| 105 |
+
medlf_qa,marginal,min_log_prob,default,0.4,-1
|
| 106 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.05,1.002567365930604
|
| 107 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.1,1.002567365930604
|
| 108 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.15,1.002567365930604
|
| 109 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.2,1.002567365930604
|
| 110 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.25,1.002567365930604
|
| 111 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.3,1.002567365930604
|
| 112 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.35,1.002567365930604
|
| 113 |
+
medlf_qa,conditional,min_log_prob,live_qa,0.4,1.002567365930604
|
| 114 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.05,-1
|
| 115 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.1,-1
|
| 116 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.15,-1
|
| 117 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.2,-1
|
| 118 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.25,-1
|
| 119 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.3,-1
|
| 120 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.35,-1
|
| 121 |
+
medlf_qa,conditional,min_log_prob,kqa_golden,0.4,-1
|
| 122 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.05,-1
|
| 123 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.1,-1
|
| 124 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.15,-1
|
| 125 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.2,-1
|
| 126 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.25,-1
|
| 127 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.3,-1
|
| 128 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.35,-1
|
| 129 |
+
medlf_qa,conditional,min_log_prob,medication_qa,0.4,-1
|
| 130 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.05,-1
|
| 131 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.1,-1
|
| 132 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.15,-1
|
| 133 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.2,-1
|
| 134 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.25,-1
|
| 135 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.3,-1
|
| 136 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.35,-1
|
| 137 |
+
medlf_qa,conditional,min_log_prob,kqa_silver_wogold,0.4,-1
|
| 138 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.05,-1
|
| 139 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.1,-1
|
| 140 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.15,-1
|
| 141 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.2,-1
|
| 142 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.25,-1
|
| 143 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.3,-1
|
| 144 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.35,-1
|
| 145 |
+
medlf_qa,conditional,min_log_prob,healthsearch_qa,0.4,-1
|
demo/inference_api.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo/inference_api.py
|
| 2 |
+
# Single-query inference wrapper around src/ core logic.
|
| 3 |
+
# Does NOT modify src/ — imports only.
|
| 4 |
+
#
|
| 5 |
+
# Heavy resources (FAISSIndexManager, SubclaimScorer) are constructed here
|
| 6 |
+
# via build_faiss_manager() / build_scorer() and cached in app.py with
|
| 7 |
+
# @st.cache_resource, so they survive Streamlit re-runs.
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import re
|
| 11 |
+
import glob
|
| 12 |
+
import yaml
|
| 13 |
+
from typing import TYPE_CHECKING, TypedDict
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
# FAISSIndexManager and SubclaimScorer transitively import torch/transformers
|
| 17 |
+
# via langchain_text_splitters (file_manager.py). To avoid loading these
|
| 18 |
+
# heavy packages at module import time (e.g. in HF Spaces where FAISS is
|
| 19 |
+
# unused), we guard them here and do the real import inside each build_*()
|
| 20 |
+
# function instead. String annotations ("FAISSIndexManager") ensure the
|
| 21 |
+
# type hints remain valid at runtime without triggering the import.
|
| 22 |
+
from src.common.faiss_manager import FAISSIndexManager
|
| 23 |
+
from src.subclaim_processor.scorer.subclaim_scorer import SubclaimScorer
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _expand_env_vars(obj):
|
| 27 |
+
"""Recursively expand environment variables in string values."""
|
| 28 |
+
if isinstance(obj, dict):
|
| 29 |
+
return {k: _expand_env_vars(v) for k, v in obj.items()}
|
| 30 |
+
if isinstance(obj, list):
|
| 31 |
+
return [_expand_env_vars(v) for v in obj]
|
| 32 |
+
if isinstance(obj, str):
|
| 33 |
+
return os.path.expandvars(obj)
|
| 34 |
+
return obj
|
| 35 |
+
|
| 36 |
+
from src.common.llm.openai_rag_agent import OpenAIRAGAgent
|
| 37 |
+
from src.common.llm.openai_atomicfact_generator import OpenAIAtomicFactGenerator
|
| 38 |
+
from src.common.llm.openai_llm_agent import OpenAILLMAgent
|
| 39 |
+
from demo.constants import DEFAULT_SCORING_METHOD
|
| 40 |
+
|
| 41 |
+
# ── Config loading (module-level cache) ──────────────────────────────────────
|
| 42 |
+
|
| 43 |
+
_main_config: dict | None = None
|
| 44 |
+
_dataset_config: dict | None = None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _load_main_config() -> dict:
    """Load conf/config.yaml once and memoise it in a module-level cache."""
    global _main_config
    if _main_config is not None:
        return _main_config
    with open("conf/config.yaml") as fh:
        _main_config = yaml.safe_load(fh)
    return _main_config
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _load_dataset_config() -> dict:
    """Load conf/dataset_config.yaml (env vars expanded), with module-level caching."""
    global _dataset_config
    if _dataset_config is not None:
        return _dataset_config
    with open("conf/dataset_config.yaml") as fh:
        _dataset_config = _expand_env_vars(yaml.safe_load(fh))
    return _dataset_config
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ── FAISS index path resolution ───────────────────────────────────────────────
|
| 64 |
+
|
| 65 |
+
def _get_index_paths(dataset: str) -> tuple[str, str]:
    """Resolve (index_path, indice2fm_path) for *dataset*.

    main.py writes one FAISS file per run named index_{query_size}.faiss;
    the file with the largest query_size wins, and the companion mapping
    file indice2fm_{query_size}.json is derived from its name.
    """
    datasets = _load_dataset_config()["datasets"]
    if dataset not in datasets:
        raise ValueError(f"Unknown dataset: '{dataset}'. "
                         f"Available: {list(datasets.keys())}")
    index_store = datasets[dataset]["index_store"]

    candidates = glob.glob(f"{index_store}/index_*.faiss")
    if not candidates:
        raise FileNotFoundError(
            f"No FAISS index found in '{index_store}'. "
            "Run main.py first to build the index."
        )

    pattern = re.compile(r"index_(\d+)\.faiss$")

    def _size(path: str) -> int:
        # Non-matching names sort first (size 0), so a well-formed file wins.
        match = pattern.search(path)
        return int(match.group(1)) if match else 0

    index_path = max(candidates, key=_size)
    indice2fm_path = pattern.sub(r"indice2fm_\1.json", index_path)
    return index_path, indice2fm_path
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ── Resource builders (call from app.py wrapped with @st.cache_resource) ─────
|
| 94 |
+
|
| 95 |
+
def build_faiss_manager(dataset: str) -> "FAISSIndexManager":
    """Construct a FAISSIndexManager for *dataset*.

    Intended to be wrapped with @st.cache_resource in app.py so the index is
    loaded only once. The import is deferred to avoid pulling
    torch/transformers (via langchain_text_splitters) at module import time.
    """
    from src.common.faiss_manager import FAISSIndexManager  # noqa: PLC0415

    index_path, indice2fm_path = _get_index_paths(dataset)
    cfg = _load_main_config()
    return FAISSIndexManager(
        index_truncation_config=cfg["index"]["truncation_config"],
        index_path=index_path,
        indice2fm_path=indice2fm_path,
    )
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def build_scorer(dataset: str) -> "SubclaimScorer":
    """Construct a SubclaimScorer for *dataset*.

    SubclaimScorer loads its own FAISS index internally, so wrap this with
    @st.cache_resource in app.py. Only needed for the "relevance" and
    "cosine_similarity" scoring methods. The import is deferred (same reason
    as in build_faiss_manager) to avoid torch/transformers at startup.
    """
    from src.subclaim_processor.scorer.subclaim_scorer import SubclaimScorer  # noqa: PLC0415

    index_path, indice2fm_path = _get_index_paths(dataset)
    cfg = _load_main_config()
    return SubclaimScorer(
        index_truncation_config=cfg["index"]["truncation_config"],
        embedding_model=cfg["index"]["embedding_model"],
        index_path=index_path,
        indice2fm_path=indice2fm_path,
        frequency_score_model=cfg["conformal_prediction"]["frequency_score_model"],
    )
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ── TypedDict definitions ─────────────────────────────────────────────────────
|
| 134 |
+
|
| 135 |
+
class Subclaim(TypedDict):
    """One atomic factual claim extracted from a RAG answer, with its score."""
    id: int  # 0-based position in decomposition order (see process_query)
    text: str  # the subclaim sentence
    score: float  # semantics depend on SubclaimResult.scoring_method
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class SubclaimResult(TypedDict):
    """Full output of process_query() for one query; cached by the demo app."""
    query: str  # the user question that was processed
    dataset: str  # dataset key the FAISS index was built for
    mode: str  # "marginal" or "conditional"
    group: str  # group name, or "default" for marginal / non-grouped datasets
    scoring_method: str  # e.g. "relevance"; used as key for threshold lookup in thresholds.csv
    rag_answer: str  # full LLM answer before any filtering
    retrieved_docs: list[str]  # documents returned by the FAISS search
    subclaims: list[Subclaim]  # scored atomic claims decomposed from rag_answer
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class FilteredResult(TypedDict):
    """Output of apply_threshold(): the full subclaim list plus keep/remove counts."""
    subclaims: list[Subclaim]  # all subclaims from the input result (list is NOT filtered)
    q_hat: float  # the conformal threshold that was applied
    keep_count: int  # number of subclaims with score >= q_hat
    remove_count: int  # number of subclaims with score < q_hat
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ── Score computation ─────────────────────────────────────────────────────────
|
| 160 |
+
|
| 161 |
+
def _compute_score(
|
| 162 |
+
scoring_method: str,
|
| 163 |
+
subclaim_text: str,
|
| 164 |
+
subclaim_probs: list[tuple[str, float]], # (token, probability) — probability = exp(log_prob)
|
| 165 |
+
query: str,
|
| 166 |
+
retrieved_docs: list[str],
|
| 167 |
+
scorer: "SubclaimScorer | None",
|
| 168 |
+
) -> float:
|
| 169 |
+
if scoring_method == "relevance":
|
| 170 |
+
if scorer is None:
|
| 171 |
+
raise ValueError("scorer is required for scoring_method='relevance'")
|
| 172 |
+
config = _load_main_config()
|
| 173 |
+
cp_config = config["conformal_prediction"]
|
| 174 |
+
return float(scorer.score(
|
| 175 |
+
claim=subclaim_text,
|
| 176 |
+
retrieved_docs=retrieved_docs,
|
| 177 |
+
aggregation_strategy=cp_config["aggregation_strategy"],
|
| 178 |
+
scoring_strategy=cp_config["scoring_strategy"],
|
| 179 |
+
))
|
| 180 |
+
|
| 181 |
+
elif scoring_method == "cosine_similarity":
|
| 182 |
+
if scorer is None:
|
| 183 |
+
raise ValueError("scorer is required for scoring_method='cosine_similarity'")
|
| 184 |
+
return float(scorer.cosine_similarity(subclaim_text, query))
|
| 185 |
+
|
| 186 |
+
elif scoring_method == "min_log_prob":
|
| 187 |
+
# min(probability) = least confident token; higher value = more confident subclaim
|
| 188 |
+
probs = [p for _, p in subclaim_probs]
|
| 189 |
+
return min(probs) if probs else 0.0
|
| 190 |
+
|
| 191 |
+
else:
|
| 192 |
+
raise ValueError(
|
| 193 |
+
f"Unknown scoring method: '{scoring_method}'. "
|
| 194 |
+
"Supported: 'relevance', 'cosine_similarity', 'min_log_prob'"
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ── Main inference functions ──────────────────────────────────────────────────
|
| 199 |
+
|
| 200 |
+
def process_query(
    query: str,
    dataset: str,
    mode: str,
    group: str,
    faiss_manager: "FAISSIndexManager",
    scorer: "SubclaimScorer | None" = None,
    scoring_method: str = DEFAULT_SCORING_METHOD,
) -> SubclaimResult:
    """Run retrieval → RAG answer → subclaim decomposition → scoring for one query.

    Heavy: issues OpenAI API calls. Trigger only on button click and cache the
    result in st.session_state.result.

    scorer is required when scoring_method is "relevance" or
    "cosine_similarity"; pass a @st.cache_resource-wrapped instance from app.py.
    """
    rag_cfg = _load_main_config()["rag"]

    # 1. Retrieve supporting documents from the FAISS index.
    docs: list[str] = faiss_manager.search_faiss_index(
        query,
        top_k=rag_cfg["retrival_topk"],
        threshold=rag_cfg["retrival_threshold"],
    )

    # 2. Generate the RAG answer (single sample).
    rag_agent = OpenAIRAGAgent(faiss_manager, model=rag_cfg["response_model"])
    completion = rag_agent.answer(
        query,
        docs,
        temperature=rag_cfg["response_temperature"],
        n_samples=1,
    )
    answer: str = completion.choices[0].message.content

    # 3. Decompose the answer into atomic subclaims; the generator yields
    #    (text, [(token, probability)]) pairs.
    fact_generator = OpenAIAtomicFactGenerator(model=rag_cfg["fact_generation_model"])

    # 4. Score each subclaim with the selected method.
    scored: list[Subclaim] = []
    for idx, (text, token_probs) in enumerate(fact_generator.get_facts_from_text(answer)):
        value = _compute_score(
            scoring_method=scoring_method,
            subclaim_text=text,
            subclaim_probs=token_probs,
            query=query,
            retrieved_docs=docs,
            scorer=scorer,
        )
        scored.append(Subclaim(id=idx, text=text, score=value))

    return SubclaimResult(
        query=query,
        dataset=dataset,
        mode=mode,
        group=group,
        scoring_method=scoring_method,
        rag_answer=answer,
        retrieved_docs=docs,
        subclaims=scored,
    )
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
_REINTEGRATE_INSTRUCTION = (
|
| 267 |
+
"You are given an original answer and a list of statements to remove from it. "
|
| 268 |
+
"Rewrite the answer by removing information corresponding to the listed statements "
|
| 269 |
+
"while keeping the remaining text fluent and natural. "
|
| 270 |
+
"Return only the revised answer without any explanation."
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def reintegrate_subclaims(
    rag_answer: str,
    removed_subclaims: list[str],
    model: str = "gpt-4o-mini",
) -> str:
    """Ask an LLM to rewrite *rag_answer* with the listed subclaims removed.

    Short-circuits (no API call) when there is nothing to remove.
    """
    if not removed_subclaims:
        return rag_answer
    bullet_list = "\n".join(f"- {claim}" for claim in removed_subclaims)
    prompt = f"Original answer:\n{rag_answer}\n\nStatements to remove:\n{bullet_list}"
    return OpenAILLMAgent(instruction=_REINTEGRATE_INSTRUCTION, model=model).answer(prompt)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def apply_threshold(result: SubclaimResult, q_hat: float) -> FilteredResult:
    """Apply a conformal threshold to a cached SubclaimResult.

    Light: no API calls — safe to run on every slider move. A subclaim is
    "kept" when its score >= q_hat; the full subclaim list is returned
    unchanged, and only the keep/remove counts reflect the threshold.
    """
    all_subclaims = result["subclaims"]
    kept = sum(sc["score"] >= q_hat for sc in all_subclaims)
    return FilteredResult(
        subclaims=all_subclaims,
        q_hat=q_hat,
        keep_count=kept,
        remove_count=len(all_subclaims) - kept,
    )
|
demo/precompute.py
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# demo/precompute.py
|
| 2 |
+
# Offline script — run once before launching the demo.
|
| 3 |
+
#
|
| 4 |
+
# Produces:
|
| 5 |
+
# demo/data/thresholds.csv — conformal thresholds for every
|
| 6 |
+
# (dataset, mode, scoring_method, group, alpha)
|
| 7 |
+
# demo/data/samples.json — pre-computed SubclaimResult objects for
|
| 8 |
+
# sample queries (no API calls at demo runtime)
|
| 9 |
+
#
|
| 10 |
+
# Usage:
|
| 11 |
+
# python -m demo.precompute
|
| 12 |
+
#
|
| 13 |
+
# Prerequisites:
|
| 14 |
+
# Run main.py for each dataset first to generate FAISS indices and the
|
| 15 |
+
# subclaim data files in data/out/.
|
| 16 |
+
|
| 17 |
+
import csv
|
| 18 |
+
import glob
|
| 19 |
+
import json
|
| 20 |
+
import os
|
| 21 |
+
import yaml
|
| 22 |
+
from collections import defaultdict
|
| 23 |
+
from datetime import datetime, timezone
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _expand_env_vars(obj):
|
| 27 |
+
"""Recursively expand environment variables in string values."""
|
| 28 |
+
if isinstance(obj, dict):
|
| 29 |
+
return {k: _expand_env_vars(v) for k, v in obj.items()}
|
| 30 |
+
if isinstance(obj, list):
|
| 31 |
+
return [_expand_env_vars(v) for v in obj]
|
| 32 |
+
if isinstance(obj, str):
|
| 33 |
+
return os.path.expandvars(obj)
|
| 34 |
+
return obj
|
| 35 |
+
|
| 36 |
+
from src.calibration.utils import compute_threshold, load_subclaim_data
|
| 37 |
+
from demo.constants import (
|
| 38 |
+
ALPHA_LEVELS,
|
| 39 |
+
DEFAULT_SCORING_METHOD,
|
| 40 |
+
SCORING_METHOD_TO_DATA_KEY,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
# ── Config ────────────────────────────────────────────────────────────────────

def _load_yaml(path: str) -> dict:
    """Read a YAML file, closing the handle deterministically.

    The previous form, yaml.safe_load(open(path)), leaked the file object
    (left it for the GC to close, triggering ResourceWarning).
    """
    with open(path) as fh:
        return yaml.safe_load(fh)


_main_cfg = _load_yaml("conf/config.yaml")
_path_cfg = _expand_env_vars(_load_yaml("conf/path_config.yaml"))
_dataset_cfg = _expand_env_vars(_load_yaml("conf/dataset_config.yaml"))

# Conformal-prediction noise weight and the RAG response model, taken from the
# main config so thresholds/samples match what main.py produced.
A_VALUE: float = _main_cfg["conformal_prediction"]["a_value"]
RESPONSE_MODEL: str = _main_cfg["rag"]["response_model"]

OUTPUT_THRESHOLDS = "demo/data/thresholds.csv"
OUTPUT_SAMPLES = "demo/data/samples.json"

# Number of entries held out per group (or per dataset for non-grouped) that
# are reserved exclusively for samples.json and excluded from threshold
# calibration. Conformal guarantees require calibration and evaluation data
# to be disjoint, so these entries must never appear in compute_thresholds().
HOLDOUT_PER_GROUP: int = 1
|
| 60 |
+
|
| 61 |
+
# ── Helpers ───────────────────────────────────────────────────────────────────
|
| 62 |
+
|
| 63 |
+
def _split_calibration_holdout(
    data: list[dict], is_grouped: bool
) -> tuple[list[dict], list[dict]]:
    """Partition entries into (calibration, holdout) sets.

    Holdout entries feed samples.json only and are excluded from threshold
    calibration to prevent data leakage. A group (or the whole dataset when
    not grouped) with HOLDOUT_PER_GROUP or fewer entries stays entirely in
    calibration and contributes nothing to the holdout (a warning is printed).

    Returns (calib_data, holdout_data).
    """
    if not is_grouped:
        if len(data) > HOLDOUT_PER_GROUP:
            return data[:-HOLDOUT_PER_GROUP], data[-HOLDOUT_PER_GROUP:]
        print(f" WARNING: only {len(data)} entries — too few to split; all used for calibration, holdout is empty")
        return data, []

    calib: list[dict] = []
    holdout: list[dict] = []
    for grp, entries in _group_data(data).items():
        if len(entries) > HOLDOUT_PER_GROUP:
            calib.extend(entries[:-HOLDOUT_PER_GROUP])
            holdout.extend(entries[-HOLDOUT_PER_GROUP:])
        else:
            print(f" WARNING: group '{grp}' has only {len(entries)} entries — too few to split; all used for calibration")
            calib.extend(entries)
    return calib, holdout
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _find_calibration_file(dataset: str) -> str:
    """Return the newest subclaim data file that main.py wrote for *dataset*.

    Matches the naming convention
    data/out/{DatasetName}/{dataset}_{query_size}_subclaims_with_scores_{model}.json
    and picks the most recently modified match.
    """
    full_name = _dataset_cfg["datasets"][dataset]["name"]
    search_dir = os.path.join(_path_cfg["paths"]["response_dir"], full_name)
    pattern = os.path.join(search_dir, f"{dataset}_*_subclaims_with_scores_*.json")
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(
            f"No subclaim data file found for dataset '{dataset}' "
            f"(searched: {pattern}). Run main.py first."
        )
    return max(matches, key=os.path.getmtime)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _group_data(data: list[dict]) -> dict[str, list[dict]]:
|
| 113 |
+
"""Group entries by their first group label."""
|
| 114 |
+
groups: dict[str, list[dict]] = defaultdict(list)
|
| 115 |
+
for entry in data:
|
| 116 |
+
grp = entry["groups"][0] if entry.get("groups") else "default"
|
| 117 |
+
groups[grp].append(entry)
|
| 118 |
+
return dict(groups)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _has_score(entry: dict, data_key: str) -> bool:
|
| 122 |
+
return all(
|
| 123 |
+
data_key in sc.get("scores", {})
|
| 124 |
+
for sc in entry.get("subclaims", [])
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ── A. Threshold computation ──────────────────────────────────────────────────
|
| 129 |
+
|
| 130 |
+
def compute_thresholds() -> None:
    """Compute conformal thresholds and write demo/data/thresholds.csv.

    Iterates over all datasets, scoring methods, modes, groups, and alpha
    values defined in constants.py. Each row corresponds to one lookup key
    used by apply_threshold() at demo runtime.

    CSV schema: dataset, mode, scoring_method, group, alpha, q_hat
    """
    os.makedirs(os.path.dirname(OUTPUT_THRESHOLDS), exist_ok=True)
    rows: list[dict] = []

    for dataset, ds_info in _dataset_cfg["datasets"].items():
        is_grouped = ds_info.get("is_grouped", False)
        print(f"\n[{dataset}] Loading calibration data…")

        try:
            calib_path = _find_calibration_file(dataset)
        except FileNotFoundError as e:
            # Dataset not processed by main.py yet — skip instead of aborting.
            print(f" SKIP: {e}")
            continue

        data = load_subclaim_data(calib_path)
        # Holdout entries are reserved for samples.json; calibrating on them
        # would break the calibration/evaluation disjointness the conformal
        # guarantee relies on.
        calib_data, holdout_data = _split_calibration_holdout(data, is_grouped)
        print(f" Loaded {len(data)} entries from {calib_path} (calib={len(calib_data)}, holdout={len(holdout_data)})")

        for method_name, data_key in SCORING_METHOD_TO_DATA_KEY.items():
            # Only entries whose every subclaim has this score key are usable.
            valid = [e for e in calib_data if _has_score(e, data_key)]
            if not valid:
                print(f" SKIP scoring_method='{method_name}': no entries with key '{data_key}'")
                continue

            # ── Marginal mode ────────────────────────────────────────────────
            # One threshold per alpha, computed over the whole calibration set.
            for alpha in ALPHA_LEVELS:
                q_hat = compute_threshold(alpha, valid, A_VALUE, data_key)
                rows.append({
                    "dataset": dataset,
                    "mode": "marginal",
                    "scoring_method": method_name,
                    "group": "default",
                    "alpha": alpha,
                    "q_hat": q_hat,
                })

            print(f" marginal / {method_name}: {len(ALPHA_LEVELS)} thresholds computed")

            # ── Conditional mode (grouped datasets only) ─────────────────────
            # Independent thresholds per group (group-conditional conformal).
            if is_grouped:
                grouped = _group_data(valid)
                for grp, grp_data in grouped.items():
                    for alpha in ALPHA_LEVELS:
                        q_hat = compute_threshold(alpha, grp_data, A_VALUE, data_key)
                        rows.append({
                            "dataset": dataset,
                            "mode": "conditional",
                            "scoring_method": method_name,
                            "group": grp,
                            "alpha": alpha,
                            "q_hat": q_hat,
                        })
                print(f" conditional / {method_name}: {len(grouped)} groups × {len(ALPHA_LEVELS)} thresholds computed")

    # Write CSV
    fieldnames = ["dataset", "mode", "scoring_method", "group", "alpha", "q_hat"]
    with open(OUTPUT_THRESHOLDS, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    print(f"\nThresholds saved → {OUTPUT_THRESHOLDS} ({len(rows)} rows)")
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# ── B. Sample query precomputation ────────────────────────────────────────────
|
| 203 |
+
|
| 204 |
+
def _entry_to_subclaim_result(entry: dict, dataset: str, is_grouped: bool) -> dict:
    """Convert one main.py subclaim entry into the demo's SubclaimResult shape.

    Uses DEFAULT_SCORING_METHOD and its corresponding data key. The stored
    score is score + noise, because q_hat is calibrated on (score + noise)
    in calibration/utils.py — this keeps both in the same comparison space.
    """
    data_key = SCORING_METHOD_TO_DATA_KEY[DEFAULT_SCORING_METHOD]
    has_group = is_grouped and entry.get("groups")

    subclaim_rows = []
    for idx, sc in enumerate(entry.get("subclaims", [])):
        scores = sc["scores"]
        subclaim_rows.append({
            "id": idx,
            "text": sc["subclaim"],
            "score": float(scores.get(data_key, 0.0)) + float(scores.get("noise", 0.0)),
        })

    return {
        "query": entry["query"],
        "dataset": dataset,
        "mode": "conditional" if is_grouped else "marginal",
        "group": entry["groups"][0] if has_group else "default",
        "scoring_method": DEFAULT_SCORING_METHOD,
        "rag_answer": entry["response"],
        "retrieved_docs": entry["retrieved_docs"],
        "subclaims": subclaim_rows,
    }
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def _pick_entries(data: list[dict], is_grouped: bool) -> list[dict]:
    """Select up to HOLDOUT_PER_GROUP entries with non-empty subclaims.

    *data* is expected to be the holdout slice from
    _split_calibration_holdout(), so each group contributes at most
    HOLDOUT_PER_GROUP candidates to begin with.
    """
    candidates = [entry for entry in data if entry.get("subclaims")]

    if not is_grouped:
        return candidates[:HOLDOUT_PER_GROUP]

    selected: list[dict] = []
    for members in _group_data(candidates).values():
        selected.extend(members[:HOLDOUT_PER_GROUP])
    return selected
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def compute_samples() -> None:
    """Precompute SubclaimResult objects for sample queries.

    Converts existing subclaim data (produced by main.py) to SubclaimResult
    format without making any API calls. Results are saved with metadata so
    it is easy to detect when re-generation is needed (e.g. after a model
    change).
    """
    os.makedirs(os.path.dirname(OUTPUT_SAMPLES), exist_ok=True)
    samples: list[dict] = []

    for dataset, ds_info in _dataset_cfg["datasets"].items():
        is_grouped = ds_info.get("is_grouped", False)
        print(f"\n[{dataset}] Picking sample queries…")

        try:
            calib_path = _find_calibration_file(dataset)
        except FileNotFoundError as e:
            # Dataset not processed by main.py yet — skip instead of aborting.
            print(f" SKIP: {e}")
            continue

        data = load_subclaim_data(calib_path)
        # Only the holdout slice may appear in samples.json (kept disjoint
        # from the calibration data used in compute_thresholds()).
        _, holdout_data = _split_calibration_holdout(data, is_grouped)
        picked = _pick_entries(holdout_data, is_grouped)

        for entry in picked:
            result = _entry_to_subclaim_result(entry, dataset, is_grouped)
            samples.append(result)
            print(f" + query={result['query'][:60]!r} group={result['group']}")

    output = {
        "metadata": {
            # UTC timestamp plus model/scoring method let the demo detect
            # stale samples after a config change.
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "model": RESPONSE_MODEL,
            "scoring_method": DEFAULT_SCORING_METHOD,
        },
        "samples": samples,
    }

    with open(OUTPUT_SAMPLES, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)

    print(f"\nSamples saved → {OUTPUT_SAMPLES} ({len(samples)} entries)")
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# ── Entry point ───────────────────────────────────────────────────────────────
|
| 303 |
+
|
| 304 |
+
# Offline entry point: Step A writes thresholds.csv, Step B writes samples.json.
if __name__ == "__main__":
    print("=== Step A: Computing conformal thresholds ===")
    compute_thresholds()

    print("\n=== Step B: Precomputing sample queries ===")
    compute_samples()

    print("\nDone. Run `streamlit run demo/app.py` to launch the demo.")
|
docs/context/01_original_architecture.md
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ResponseQualityAssessment リポジトリ解析
|
| 2 |
+
|
| 3 |
+
> 対象コード: `ResponseQualityAssessment/`(conformal-rag-demo サブモジュール)
|
| 4 |
+
> 論文: "Response Quality Assessment for Retrieval-Augmented Generation via Conditional Conformal Factuality"(SIGIR 2025 想定)
|
| 5 |
+
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## 1. システム概要
|
| 9 |
+
|
| 10 |
+
RAG(Retrieval-Augmented Generation)システムの回答品質を評価するパイプラインの実装。**Conformal Prediction** を用いて、LLM の回答をサブクレームに分解し、各サブクレームの事実性スコアを計算・校正することで、統計的保証付きの品質評価を実現する。
|
| 11 |
+
|
| 12 |
+
### 主な研究貢献
|
| 13 |
+
|
| 14 |
+
- **Split Conformal Prediction**: キャリブレーションデータから閾値を算出し、テストデータへ適用
|
| 15 |
+
- **Group Conditional Conformal**: グループ別(例: 医療カテゴリ)の閾値を独立に算出
|
| 16 |
+
- 複数のスコアリング手法(類似度・頻度・対数確率など)の横断比較
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## 2. ディレクトリ構成
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
ResponseQualityAssessment/
|
| 24 |
+
├── conf/
|
| 25 |
+
│ ├── config.yaml # 実行全体の設定
|
| 26 |
+
│ ├── dataset_config.yaml # データセット別設定
|
| 27 |
+
│ └── path_config.yaml # ファイルパス設定
|
| 28 |
+
├── data/
|
| 29 |
+
│ ├── raw/ # 生データ(FactScore, HotpotQA, PopQA, MedLFQA, WikiDB)
|
| 30 |
+
│ ├── processed/ # 標準化済みクエリ・文書データ
|
| 31 |
+
│ ├── out/ # サブクレーム+スコア出力
|
| 32 |
+
│ └── result/ # 最終結果・可視化
|
| 33 |
+
├── index_store/ # FAISS インデックスとマッピング
|
| 34 |
+
├── logs/ # 実行ログ
|
| 35 |
+
├── src/
|
| 36 |
+
│ ├── calibration/ # 共形予測キャリブレーション
|
| 37 |
+
│ ├── common/ # 共通コンポーネント(設定・ファイル・LLM 管理)
|
| 38 |
+
│ ├── data_processor/ # データセット処理パイプライン
|
| 39 |
+
│ ├── dataloader/ # HuggingFace データローダー
|
| 40 |
+
│ ├── rag/ # RAG(SQLite ベース文書 DB)
|
| 41 |
+
│ ├── subclaim_processor/ # 回答品質評価の中核パイプライン
|
| 42 |
+
│ └── utils/ # 汎用ユーティリティ
|
| 43 |
+
├── main.py # エントリーポイント
|
| 44 |
+
└── requirements.txt # 依存ライブラリ
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## 3. 使用ライブラリ
|
| 50 |
+
|
| 51 |
+
| ライブラリ | 用途 |
|
| 52 |
+
|---|---|
|
| 53 |
+
| `openai` | LLM API(GPT-4o-mini による応答生成・サブクレーム抽出・アノテーション)、埋め込みモデル |
|
| 54 |
+
| `faiss-cpu` | ベクトル近傍探索(FAISS `IndexFlatIP`、L2 正規化後の内積=コサイン類似度) |
|
| 55 |
+
| `sentence-transformers` / `transformers` | 埋め込みモデル(補助的)|
|
| 56 |
+
| `datasets` | HuggingFace からの QA データセット取得 |
|
| 57 |
+
| `langchain` | LLM オーケストレーション |
|
| 58 |
+
| `numpy` | 数値計算(共形予測の分位点計算、スコア演算)|
|
| 59 |
+
| `torch` | Transformer モデルのバックエンド |
|
| 60 |
+
| `PyPDF2` | PDF 文書のテキスト抽出 |
|
| 61 |
+
| `matplotlib` | キャリブレーション結果の可視化 |
|
| 62 |
+
| `jsonschema` | 各フェーズの出力データのスキーマ検証 |
|
| 63 |
+
| `python-dotenv` | `.env` からの API キー読み込み |
|
| 64 |
+
| `flask` / `flask-cors` | Web API(コードに存在するが現状未使用)|
|
| 65 |
+
|
| 66 |
+
**Python バージョン**: 3.11
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## 4. アーキテクチャ設計
|
| 71 |
+
|
| 72 |
+
### 4.1. 設計パターン
|
| 73 |
+
|
| 74 |
+
| パターン | 適用箇所 |
|
| 75 |
+
|---|---|
|
| 76 |
+
| **Strategy パターン** | スコアリング戦略(`ProductScoreStrategy`)、集約戦略(`MeanAggregation` / `MaxAggregation`)、チャンキング戦略 |
|
| 77 |
+
| **Template Method パターン** | `RawDataProcessor`(抽象基底)← データセット別実装、`ICalibration` ← 共形予測実装 |
|
| 78 |
+
| **Factory / Dispatcher パターン** | `QueryProcessor` がデータセット名に基づき各プロセッサへ委譲 |
|
| 79 |
+
| **Manager パターン** | `ConfigManager`、`FAISSIndexManager`、`FileManager`、`OpenAIManager` による各リソースのライフサイクル管理 |
|
| 80 |
+
| **Pipeline アーキテクチャ** | `main.py` で各ステージを順次呼び出す直列パイプライン |
|
| 81 |
+
| **Dependency Injection** | コンストラクタ経由でコンポーネントを注入(`faiss_manager`、`scorer` など)|
|
| 82 |
+
|
| 83 |
+
### 4.2. コンポーネント相関図
|
| 84 |
+
|
| 85 |
+
```
|
| 86 |
+
main.py
|
| 87 |
+
├─ ConfigManager … YAML 読み込み・ログ設定
|
| 88 |
+
├─ DataLoader … HuggingFace / Wikipedia SQLite DB
|
| 89 |
+
├─ QueryProcessor … データセット別に標準化
|
| 90 |
+
├─ FAISSIndexManager … ベクトルインデックス作成・検索
|
| 91 |
+
│ └─ OpenAIManager … text-embedding-3-large で埋め込み生成
|
| 92 |
+
├─ process_subclaims() … サブクレーム処理オーケストレーター
|
| 93 |
+
│ └─ SubclaimProcessor
|
| 94 |
+
│ ├─ OpenAIRAGAgent … RAG 応答生成
|
| 95 |
+
│ ├─ OpenAIAtomicFactGenerator … サブクレーム抽出(logprobs 付き)
|
| 96 |
+
│ ├─ OpenAIClaimVerification … 事実性アノテーション(S/I/U/N)
|
| 97 |
+
│ └─ SubclaimScorer … 7 種スコア計算
|
| 98 |
+
└─ SplitConformalCalibration / GroupConditionalConformal
|
| 99 |
+
└─ calibration/utils … r_a スコア・分位点閾値計算
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
---
|
| 103 |
+
|
| 104 |
+
## 5. 処理パイプライン詳細
|
| 105 |
+
|
| 106 |
+
### Step 1: データ読み込み・標準化
|
| 107 |
+
|
| 108 |
+
| 処理 | 実装クラス |
|
| 109 |
+
|---|---|
|
| 110 |
+
| HuggingFace からデータセット取得 | `DataLoader.load_qa_data()` |
|
| 111 |
+
| Wikipedia SQLite DB 構築 | `DataLoader.create_wiki_db()` |
|
| 112 |
+
| データセット別正規化 | `FactScoreProcessor`, `HotpotQAProcessor`, `PopQAProcessor`, `MedLFQAProcessor` |
|
| 113 |
+
|
| 114 |
+
#### データセット別 外部コーパス対応表
|
| 115 |
+
|
| 116 |
+
FAISS インデックスの構築・検索に用いる外部コーパスはデータセットによって異なる。
|
| 117 |
+
|
| 118 |
+
| データセット | 外部コーパス | 取得元 / 形式 |
|
| 119 |
+
|---|---|---|
|
| 120 |
+
| **FActScore** | Wikipedia(2023-04-01 ダンプ) | bz2 ダンプ → SQLite DB(`enwiki-20230401.db`)|
|
| 121 |
+
| **HotpotQA** | Wikipedia(2023-04-01 ダンプ) | 同上(HuggingFace `kilt_tasks/hotpotqa` と組み合わせ)|
|
| 122 |
+
| **PopQA** | Wikipedia(2023-04-01 ダンプ) | 同上(HuggingFace `akariasai/PopQA` と組み合わせ)|
|
| 123 |
+
| **MedLFQA** | MedLFQAv2 の QA データセット自体 | GitHub `jjcherian/conformal-safety` の JSONL ファイル群(`healthsearch_qa`, `kqa_golden`, `kqa_silver_wogold`, `live_qa`, `medication_qa`)|
|
| 124 |
+
|
| 125 |
+
FActScore・HotpotQA・PopQA は共通の Wikipedia SQLite DB(`DocDB`)から文書を取得する。MedLFQA のみ Wikipedia を使わず、各 JSONL ファイルに格納された `retrieved_passages` フィールドの文書をそのままコーパスとして使用する。
|
| 126 |
+
|
| 127 |
+
出力形式(JSON):
|
| 128 |
+
```json
|
| 129 |
+
{"input": "クエリ文字列", "output": {"answer": "正解", "provenance": [...]}}
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
### Step 2: FAISS インデックス構築
|
| 133 |
+
|
| 134 |
+
#### 構築フロー
|
| 135 |
+
|
| 136 |
+
全データセット共通で、以下の 2 フェーズに分かれる。
|
| 137 |
+
|
| 138 |
+
**フェーズ A: 文書収集(データセット別)**
|
| 139 |
+
|
| 140 |
+
各プロセッサが「どこから文書を取ってくるか」が異なり、最終的に同一 JSON 形式で保存される。
|
| 141 |
+
|
| 142 |
+
| データセット | 文書取得元 | 1クエリあたり文書数 |
|
| 143 |
+
|---|---|---|
|
| 144 |
+
| **FActScore** | Wikipedia SQLite DB(provenance タイトル 1 件でタイトル引き) | 1 記事 |
|
| 145 |
+
| **PopQA** | Wikipedia SQLite DB(`s_wiki_title` 1 件、重複排除あり) | 1 記事 |
|
| 146 |
+
| **HotpotQA** | Wikipedia SQLite DB(provenance の複数タイトル、重複排除あり) | 2〜3 記事 |
|
| 147 |
+
| **MedLFQA** | JSONL の `Free_form_answer` を文単位に分割 + `Nice_to_have` リスト | 10〜20 文(短文) |
|
| 148 |
+
|
| 149 |
+
Wikipedia 3 データセットは Wikipedia 全文をそのまま FAISS に入れるのではなく、各クエリの provenance タイトルに対応する記事のみを SQLite から引き当てて保存する。MedLFQA は Wikipedia を使わず、JSONL ファイルに格納済みの文書を直接使用する。
|
| 150 |
+
|
| 151 |
+
**フェーズ B: 埋め込み・インデックス化(全データセット共通、[main.py:211-275](main.py#L211-L275))**
|
| 152 |
+
|
| 153 |
+
1. 文書を **Fixed-Length チャンカー**(デフォルト: 2000 語、25 語オーバーラップ)で分割
|
| 154 |
+
2. OpenAI `text-embedding-3-large`(次元数: 3072)で埋め込みベクトルを生成
|
| 155 |
+
3. L2 正規化 → `IndexFlatIP`(内積 ≒ コサイン類似度)に追加
|
| 156 |
+
4. インデックスバイナリと `indice2fm` マッピング(FAISS インデックス ID → ファイル位置)を保存
|
| 157 |
+
|
| 158 |
+
各データセット独立したインデックスとして `index_store/{Dataset}/` に保存される。インデックスは `.gitignore` によりリポジトリ管理外であり、**実行のたびに再構築が必要**。
|
| 159 |
+
|
| 160 |
+
#### text-embedding-3-large の構築コスト見積もり(`query_size: 500`、$0.13/1M トークン)
|
| 161 |
+
|
| 162 |
+
| データセット | ユニーク文書数の目安 | チャンク/文書 | 総トークン数 | 推定コスト |
|
| 163 |
+
|---|---|---|---|---|
|
| 164 |
+
| **FActScore** | ~500 記事 | ~2 | ~2.7M | ~$0.35 |
|
| 165 |
+
| **PopQA** | ~500 記事 | ~2 | ~2.7M | ~$0.35 |
|
| 166 |
+
| **HotpotQA** | ~1,000 記事(2〜3 provenance × 重複排除) | ~2 | ~5.4M | ~$0.70 |
|
| 167 |
+
| **MedLFQA** | ~500 クエリ × ~15 文 | チャンク分割なし(短文) | ~0.2M | ~$0.03 |
|
| 168 |
+
| **合計(4 データセット)** | | | ~11M | **~$1.4** |
|
| 169 |
+
|
| 170 |
+
これは**1 回限りのインデックス構築コスト**。推論時のクエリ埋め込み・サブクレーム埋め込みは数千トークン程度で無視できる。Wikipedia 記事長のばらつきにより±2 倍程度の誤差あり。
|
| 171 |
+
|
| 172 |
+
### Step 3: 応答生成
|
| 173 |
+
|
| 174 |
+
`SubclaimProcessor.generate_responses()`:
|
| 175 |
+
1. クエリに対し FAISS から上位 k 件(デフォルト: 10)の文書を検索(閾値: 0.3 以上)
|
| 176 |
+
2. 検索文書をコンテキストとして GPT-4o-mini に渡し、初期回答を生成
|
| 177 |
+
|
| 178 |
+
### Step 4: サブクレーム抽出
|
| 179 |
+
|
| 180 |
+
`OpenAIAtomicFactGenerator.get_facts_from_text()`:
|
| 181 |
+
- LLM に回答テキストを渡し、セミコロン区切りの原子的事実に分解
|
| 182 |
+
- `logprobs=True, top_logprobs=1` でトークンごとの対数確率を取得
|
| 183 |
+
- 各サブクレームの対数確率リストを保持
|
| 184 |
+
|
| 185 |
+
### Step 5: サブクレームスコアリング
|
| 186 |
+
|
| 187 |
+
`SubclaimProcessor.score_subclaim()` で 7 種のスコアを計算:
|
| 188 |
+
|
| 189 |
+
| スコア名 | 計算方法 | 意味 |
|
| 190 |
+
|---|---|---|
|
| 191 |
+
| `relavance`(relevance) | FAISS 検索スコア × サブクレーム-文書コサイン類似度 の積、文書間で集約 | 検索文書との総合関連度 |
|
| 192 |
+
| `query_claim_cosine_similarity` | クエリ埋め込み ↔ サブクレーム埋め込み のコサイン類似度 | クエリとの意味的整合性 |
|
| 193 |
+
| `doc_claim_cosine_similarity` | 全検索文書とサブクレームのコサイン類似度の最大値 | 文書との意味的整合性 |
|
| 194 |
+
| `frequency` | 温度 1.0 で 5 回サンプリングし、同内容のサブクレームが出現した割合 | LLM 自体の一貫性(自己信頼度)|
|
| 195 |
+
| `min_log_prob` | サブクレーム中のトークン対数確率の最小値 | LLM の生成確信度 |
|
| 196 |
+
| `random` | Uniform(0, 1) | ベースライン |
|
| 197 |
+
| `ordinal` | `i / サブクレーム数` | 応答内の出現順序(ベースライン)|
|
| 198 |
+
|
| 199 |
+
> 全スコアに N(0, 0.001) のガウスノイズを付加(安定化のため)
|
| 200 |
+
|
| 201 |
+
集約戦略(文書間集約):
|
| 202 |
+
- `MeanAggregation`: 文書スコアの平均
|
| 203 |
+
- `MaxAggregation`: 文書スコアの最大値
|
| 204 |
+
|
| 205 |
+
スコアリング戦略:
|
| 206 |
+
- `ProductScoreStrategy`: FAISS スコア × コサイン類似度 の積(現状唯一の実装)
|
| 207 |
+
|
| 208 |
+
### Step 6: アノテーション
|
| 209 |
+
|
| 210 |
+
`OpenAIClaimVerification.annotate()`:
|
| 211 |
+
- サブクレームを検索文書・正解と照合し、GPT-4o-mini が 4 段階ラベルを付与
|
| 212 |
+
|
| 213 |
+
| ラベル | 意味 |
|
| 214 |
+
|---|---|
|
| 215 |
+
| `S` (Supported) | 事実として支持される |
|
| 216 |
+
| `I` (Irrelevant) | 質問と無関係 |
|
| 217 |
+
| `U` (Unverifiable) | 検証不可能 |
|
| 218 |
+
| `N` (Nonfactual) | 事実に反する |
|
| 219 |
+
|
| 220 |
+
### Step 7: 共形キャリブレーション
|
| 221 |
+
|
| 222 |
+
#### Split Conformal Prediction
|
| 223 |
+
|
| 224 |
+
```
|
| 225 |
+
For each alpha in [0.05, 0.10, ..., 0.40]:
|
| 226 |
+
Repeat 1000 runs:
|
| 227 |
+
1. データをシャッフルし 50/50 でキャリブ/テストに分割
|
| 228 |
+
2. キャリブデータで各サブクレームの r_a スコアを計算:
|
| 229 |
+
r_a(x) = "サブクレームを低スコア順に除去したとき、
|
| 230 |
+
残存サブクレームの正確率が初めて a 以上になる最小閾値"
|
| 231 |
+
3. 閾値 q̂ = ceil((n+1)*(1-alpha))/n 分位点(共形予測の保証付き計算)
|
| 232 |
+
4. テストデータで閾値 q̂ を適用し、除去率・正確率を計算
|
| 233 |
+
```
|
| 234 |
+
|
| 235 |
+
**キャリブレーションのデータ量について**
|
| 236 |
+
|
| 237 |
+
キャリブレーションに使うデータは `query_size` で処理したクエリ全体を 50/50 分割したものであり、専用のキャリブレーションセットは別途用意しない。デフォルト(`query_size: 500`)の場合、1 ランあたりキャリブレーション 250 件・テスト 250 件となる。各データセットの総件数(HotpotQA ~5,600 件、PopQA ~14,267 件など)に対して 500 件はサンプリング上限であり、`query_size` を増やすほどキャリブレーション精度は上がるが LLM 呼び出しコストも線形に増加する。1000 回のランダムシャッフルを繰り返すことで、少ないサンプル数でも統計的ばらつきを吸収している。
|
| 238 |
+
|
| 239 |
+
#### Group Conditional Conformal
|
| 240 |
+
|
| 241 |
+
`GroupConditionalConformal`: グループ(MedLFQA のカテゴリなど)ごとに独立してキャリブレーションを実行し、グループ別閾値を算出。
|
| 242 |
+
|
| 243 |
+
### Step 8: 結果出力
|
| 244 |
+
|
| 245 |
+
- **CSV**: alpha × confidence_method ごとの除去率・正確率
|
| 246 |
+
- **PNG プロット**: 共形除去カーブ(除去率 vs. 事実性保証)、事実的正確率カーブ
|
| 247 |
+
- **JSON**: 全サブクレーム(スコア・アノテーション込み)
|
| 248 |
+
- **config/**: 再現性のため実行設定の YAML コピー
|
| 249 |
+
|
| 250 |
+
---
|
| 251 |
+
|
| 252 |
+
## 6. データフロー
|
| 253 |
+
|
| 254 |
+
```
|
| 255 |
+
生データ(HuggingFace / bz2 Wikipedia)
|
| 256 |
+
↓ DataLoader
|
| 257 |
+
SQLite Wikipedia DB + 標準化クエリ JSON
|
| 258 |
+
↓ FileManager + OpenAIManager
|
| 259 |
+
チャンク埋め込み行列(N × 3072)
|
| 260 |
+
↓ FAISSIndexManager
|
| 261 |
+
FAISS IndexFlatIP + indice2fm マッピング
|
| 262 |
+
↓ SubclaimProcessor.generate_responses()
|
| 263 |
+
{query, gld_ans, retrieved_docs, response, groups}
|
| 264 |
+
↓ OpenAIAtomicFactGenerator
|
| 265 |
+
{subclaim_text, log_prob_list}(サブクレーム + 対数確率)
|
| 266 |
+
↓ SubclaimScorer (7 種スコア)
|
| 267 |
+
{scores: {relavance, query_claim_cosine, ..., ordinal}}
|
| 268 |
+
↓ OpenAIClaimVerification
|
| 269 |
+
{annotations: {gpt: "S"|"I"|"U"|"N"}}
|
| 270 |
+
↓ SplitConformalCalibration / GroupConditionalConformal
|
| 271 |
+
除去率・正確率の統計(1000 runs × alpha × method)
|
| 272 |
+
↓
|
| 273 |
+
CSV / PNG / config YAML
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
---
|
| 277 |
+
|
| 278 |
+
## 7. 設定ファイル
|
| 279 |
+
|
| 280 |
+
### `conf/config.yaml`(主要パラメータ)
|
| 281 |
+
|
| 282 |
+
```yaml
|
| 283 |
+
dataset:
|
| 284 |
+
name: "pop_qa" # fact_score | hotpot_qa | pop_qa | medlf_qa
|
| 285 |
+
query_size: 500 # 処理クエリ数(-1 = 全件)
|
| 286 |
+
|
| 287 |
+
index:
|
| 288 |
+
embedding_model: "text-embedding-3-large"
|
| 289 |
+
truncation_config:
|
| 290 |
+
strategy: "fixed_length"
|
| 291 |
+
chunk_size: 2000 # チャンクあたり単語数
|
| 292 |
+
chunk_overlap: 25
|
| 293 |
+
|
| 294 |
+
rag:
|
| 295 |
+
retrival_topk: 10 # 検索上位件数
|
| 296 |
+
retrival_threshold: 0.3 # コサイン類似度の下限閾値
|
| 297 |
+
response_model: "gpt-4o-mini"
|
| 298 |
+
|
| 299 |
+
conformal_prediction:
|
| 300 |
+
aggregation_strategy: "mean" # mean | max
|
| 301 |
+
scoring_strategy: "product"
|
| 302 |
+
split_conformal: true
|
| 303 |
+
conformal_alphas:
|
| 304 |
+
start: 0.05
|
| 305 |
+
end: 0.45
|
| 306 |
+
step: 0.05
|
| 307 |
+
a_value: 1.0 # 目標正確率(0〜1)
|
| 308 |
+
```
|
| 309 |
+
|
| 310 |
+
### `conf/dataset_config.yaml`
|
| 311 |
+
|
| 312 |
+
```yaml
|
| 313 |
+
datasets:
|
| 314 |
+
fact_score:
|
| 315 |
+
is_grouped: false # グループ条件付き共形を無効
|
| 316 |
+
medlf_qa:
|
| 317 |
+
is_grouped: true # グループ条件付き共形を有効
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
---
|
| 321 |
+
|
| 322 |
+
## 8. 主要クラス・関数一覧
|
| 323 |
+
|
| 324 |
+
### 設定・管理系
|
| 325 |
+
|
| 326 |
+
| クラス / 関数 | 役割 |
|
| 327 |
+
|---|---|
|
| 328 |
+
| `ConfigManager` | YAML 設定の読み込み・保存・更新、ロギング設定 |
|
| 329 |
+
| `FileManager` | 文書(PDF/テキスト)処理、チャンキング、埋め込みキャッシュ |
|
| 330 |
+
| `FAISSIndexManager` | FAISS インデックスの作成・読み込み・検索、ファイル位置マッピング |
|
| 331 |
+
| `OpenAIManager` | OpenAI API ラッパー(埋め込み・アシスタント・スレッド)|
|
| 332 |
+
|
| 333 |
+
### データ処理系
|
| 334 |
+
|
| 335 |
+
| クラス / 関数 | 役割 |
|
| 336 |
+
|---|---|
|
| 337 |
+
| `DataLoader.load_qa_data()` | 4 データセットを HuggingFace から取得 |
|
| 338 |
+
| `DataLoader.create_wiki_db()` | Wikipedia bz2 ダンプから SQLite DB 構築 |
|
| 339 |
+
| `QueryProcessor.get_queries()` | 標準スキーマへの変換・サンプリング |
|
| 340 |
+
| `DocDB` | SQLite ベースの文書検索(FActScore 実装を参照)|
|
| 341 |
+
|
| 342 |
+
### RAG・スコアリング系
|
| 343 |
+
|
| 344 |
+
| クラス / 関数 | 役割 |
|
| 345 |
+
|---|---|
|
| 346 |
+
| `OpenAIRAGAgent.answer()` | クエリ + 検索文書 → LLM 回答生成 |
|
| 347 |
+
| `OpenAIAtomicFactGenerator.get_facts_from_text()` | テキスト → 原子的事実リスト(logprobs 付き)|
|
| 348 |
+
| `OpenAIClaimVerification.annotate()` | サブクレームの S/I/U/N ラベル付け |
|
| 349 |
+
| `SubclaimScorer.score()` | コサイン類似度ベースの関連度スコア計算 |
|
| 350 |
+
| `SubclaimScorer.frequency_score()` | LLM 複数サンプリングによる一貫性スコア |
|
| 351 |
+
|
| 352 |
+
### キャリブレーション系
|
| 353 |
+
|
| 354 |
+
| クラス / 関数 | 役割 |
|
| 355 |
+
|---|---|
|
| 356 |
+
| `get_r_score(entry, method, a)` | `r_a` スコア(閾値を下げたとき残存サブクレームの正確率が `a` 以上になる最小閾値)の計算 |
|
| 357 |
+
| `compute_threshold(alpha, data, a, method)` | 共形予測の分位点閾値 `q̂` の計算 |
|
| 358 |
+
| `SplitConformalCalibration.plot_conformal_removal()` | 除去率 vs. 事実性保証のカーブ生成 |
|
| 359 |
+
| `GroupConditionalConformal` | グループ別閾値による条件付き共形予測 |
|
| 360 |
+
|
| 361 |
+
---
|
| 362 |
+
|
| 363 |
+
## 9. テスト・品質保証
|
| 364 |
+
|
| 365 |
+
専用のテストスイートは存在しない。品質保証は以下の方法で実施:
|
| 366 |
+
|
| 367 |
+
- **JSON スキーマ検証**: `jsonschema` による各ステージ出力の構造検証(`base_schema.json`, `wiki_schema.json`, `subclaims_schema.json`)
|
| 368 |
+
- **インデックス整合性チェック**: `FAISSIndexManager.is_indice_align()` によるインデックスとマッピングの同期確認
|
| 369 |
+
- **バッチ処理の冪等性**: サブクレーム処理は既存の処理済みエントリをスキップ(再実行安全)
|
| 370 |
+
- **設定ログ**: 実行時の設定 YAML を `config/` ディレクトリにコピーして再現性を担保
|
| 371 |
+
|
| 372 |
+
---
|
| 373 |
+
|
| 374 |
+
## 10. 既知の制約・拡張ポイント
|
| 375 |
+
|
| 376 |
+
| 項目 | 現状 | 拡張の方向性 |
|
| 377 |
+
|---|---|---|
|
| 378 |
+
| スコアリング戦略 | `ProductScoreStrategy` のみ実装 | Strategy パターンにより追加容易 |
|
| 379 |
+
| テキストチャンキング | Fixed-Length のみ(Recursive は未実装) | `FixedLengthChunker` と同インターフェースで追加可 |
|
| 380 |
+
| 共形予測の種別 | Split / Group Conditional の 2 種 | オンライン共形予測などへの拡張余地あり |
|
| 381 |
+
| 並列処理 | なし(完全逐次処理) | バッチ API や非同期処理で高速化可 |
|
| 382 |
+
| Web API | `flask` が依存関係に含まれるが未使用 | デモ API サーバー化の際に活用可 |
|
| 383 |
+
| FAISS コーパス範囲 | 各クエリの provenance に対応する文書のみインデックス化(正解文書が事前判明している前提) | データセット外の任意クエリに対応するには Wikipedia 全文などコーパス全体をインデックス化する必要がある |
|
| 384 |
+
| 全コーパス検索時の精度 | `IndexFlatIP` は完全探索のため数学的な top-k 精度は落ちない。ただし埋め込み空間に無関係文書が増えることで真に関連する文書が top-k から押し出される可能性がある。FAISS スコアはスコアリングにも使われるため最終的な品質評価精度にも直結する | ドメイン特化の fine-tuned 埋め込みモデルへの切り替えや BM25 との hybrid retrieval が対策として有効 |
|
| 385 |
+
|
| 386 |
+
---
|
| 387 |
+
|
| 388 |
+
## 11. 実行方法
|
| 389 |
+
|
| 390 |
+
```bash
|
| 391 |
+
# 基本実行
|
| 392 |
+
python main.py --config conf/config.yaml
|
| 393 |
+
|
| 394 |
+
# データセット・クエリ数を上書き指定
|
| 395 |
+
python main.py --dataset hotpot_qa --query_size 200
|
| 396 |
+
|
| 397 |
+
# カスタム実行 ID を付与
|
| 398 |
+
python main.py --run_id my_experiment_01
|
| 399 |
+
```
|
| 400 |
+
|
| 401 |
+
環境変数(`.env`):
|
| 402 |
+
```
|
| 403 |
+
OPENAI_API_KEY=sk-...
|
| 404 |
+
```
|
docs/context/02_design_patterns_analysis.md
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 設計パターン詳細分析
|
| 2 |
+
|
| 3 |
+
## 概要
|
| 4 |
+
|
| 5 |
+
本ドキュメントでは、現在の `src/` に適用されている設計パターンをコードベースに基づいて詳細に解説し、問題点と改善案を記録する。
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## 1. Strategy パターン
|
| 10 |
+
|
| 11 |
+
### 目的
|
| 12 |
+
アルゴリズムをオブジェクトとして差し替え可能にする。
|
| 13 |
+
|
| 14 |
+
### 適用箇所
|
| 15 |
+
```
|
| 16 |
+
src/subclaim_processor/strategies/
|
| 17 |
+
├── aggregation/
|
| 18 |
+
│ ├── base.py → AggregationStrategy (ABC)
|
| 19 |
+
│ ├── mean.py → MeanAggregation
|
| 20 |
+
│ └── max.py → MaxAggregation
|
| 21 |
+
└── scoring/
|
| 22 |
+
├── base.py → ScoringStrategy (ABC)
|
| 23 |
+
└── product.py → ProductScoreStrategy
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### 実際の動作(`subclaim_scorer.py:46-68`)
|
| 27 |
+
```python
|
| 28 |
+
AGGREGATION_STRATEGIES: Dict[str, Callable] = {
|
| 29 |
+
"max": MaxAggregation,
|
| 30 |
+
"mean": MeanAggregation,
|
| 31 |
+
}
|
| 32 |
+
SCORING_STRATEGIES: Dict[str, Callable] = {
|
| 33 |
+
"product": ProductScoreStrategy,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
def score(self, claim, retrieved_docs, aggregation_strategy, scoring_strategy):
|
| 37 |
+
agg_func = AGGREGATION_STRATEGIES[aggregation_strategy]() # "mean" or "max"
|
| 38 |
+
scoring_func = SCORING_STRATEGIES[scoring_strategy]() # "product"
|
| 39 |
+
|
| 40 |
+
for doc in retrieved_docs:
|
| 41 |
+
score = scoring_func.compute_score(...) # 差し替え可能
|
| 42 |
+
|
| 43 |
+
return agg_func.aggregate(doc_scores) # 差し替え可能
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
### 問題点
|
| 47 |
+
Strategy の選択が文字列キーの辞書で行われており、型安全性がない。`"product"` を typo しても実行時まで気づけない。
|
| 48 |
+
|
| 49 |
+
### 改善案
|
| 50 |
+
```python
|
| 51 |
+
# Enum で型安全にする
|
| 52 |
+
from enum import Enum
|
| 53 |
+
|
| 54 |
+
class AggregationMethod(Enum):
|
| 55 |
+
MEAN = "mean"
|
| 56 |
+
MAX = "max"
|
| 57 |
+
|
| 58 |
+
# または Protocol を使う(Python 3.8+)
|
| 59 |
+
from typing import Protocol
|
| 60 |
+
|
| 61 |
+
class AggregationStrategy(Protocol):
|
| 62 |
+
def aggregate(self, scores: list[float]) -> float: ...
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
## 2. Template Method パターン
|
| 68 |
+
|
| 69 |
+
### 目的
|
| 70 |
+
処理の骨格を親クラスで定め、詳細を子クラスに委ねる。
|
| 71 |
+
|
| 72 |
+
### 適用箇所
|
| 73 |
+
```
|
| 74 |
+
src/data_processor/raw_data_processor.py → IRawDataProcessor, DatasetProcessor (ABC)
|
| 75 |
+
src/calibration/base_calibration.py → ICalibration (ABC)
|
| 76 |
+
src/common/llm/llm_agent.py → LLMAgent (ABC)
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
### 現在の実装(`raw_data_processor.py`)
|
| 80 |
+
```python
|
| 81 |
+
class IRawDataProcessor(ABC):
|
| 82 |
+
@abstractmethod
|
| 83 |
+
def get_queries(self, input_file, output_file): pass
|
| 84 |
+
|
| 85 |
+
@abstractmethod
|
| 86 |
+
def get_documents(self, query_file, output_file): pass
|
| 87 |
+
|
| 88 |
+
class DatasetProcessor(ABC):
|
| 89 |
+
@abstractmethod
|
| 90 |
+
def process_queries(self, input_file, **kwargs) -> list: pass
|
| 91 |
+
|
| 92 |
+
@abstractmethod
|
| 93 |
+
def process_documents(self, query_file, db, **kwargs) -> dict: pass
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### 問題点
|
| 97 |
+
抽象クラスが2つある(`IRawDataProcessor` と `DatasetProcessor`)。
|
| 98 |
+
|
| 99 |
+
| クラス | 役割 |
|
| 100 |
+
|--------|------|
|
| 101 |
+
| `IRawDataProcessor` | ファイルI/O込みのインターフェース |
|
| 102 |
+
| `DatasetProcessor` | 純粋な処理ロジックのインターフェース |
|
| 103 |
+
|
| 104 |
+
`QueryProcessor` は `IRawDataProcessor` を継承しつつ、内部で `DatasetProcessor` を利用するというダブル構造になっている。
|
| 105 |
+
|
| 106 |
+
真の Template Method パターンなら、親クラスに「骨格メソッド」が実装されているはずだが、現在の実装では `QueryProcessor.get_queries()` が骨格を担っている。
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## 3. Factory / Dispatcher パターン
|
| 111 |
+
|
| 112 |
+
### 目的
|
| 113 |
+
入力に基づいて適切なオブジェクトを生成・委譲する。
|
| 114 |
+
|
| 115 |
+
### 実際のコード(`query_processor.py:26-31`)
|
| 116 |
+
```python
|
| 117 |
+
self.processors = {
|
| 118 |
+
"fact_score": FactScoreProcessor(),
|
| 119 |
+
"hotpot_qa": HotpotQAProcessor(),
|
| 120 |
+
"pop_qa": PopQAProcessor(),
|
| 121 |
+
"medlf_qa": MedLFQAProcessor(),
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
# 使用時
|
| 125 |
+
processor = self.processors.get(dataset)
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
### 問題点
|
| 129 |
+
これは「Factory パターン」ではなく「Registry パターン」に近い。
|
| 130 |
+
|
| 131 |
+
| パターン | 特徴 |
|
| 132 |
+
|---------|------|
|
| 133 |
+
| Factory(正) | 必要時にオブジェクトを生成する |
|
| 134 |
+
| Registry(現在) | 起動時に全オブジェクトを一括生成して辞書に保持 |
|
| 135 |
+
|
| 136 |
+
`QueryProcessor()` を作った時点で4データセット全プロセッサがインスタンス化される。実害は少ないが、設計の意図とずれている。
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## 4. Manager パターン
|
| 141 |
+
|
| 142 |
+
### 目的
|
| 143 |
+
リソースのライフサイクル(生成・使用・破棄)を一元管理する。
|
| 144 |
+
|
| 145 |
+
### 各 Manager の評価
|
| 146 |
+
|
| 147 |
+
| Manager | 管理対象 | 実際の責任数 | 評価 |
|
| 148 |
+
|---------|---------|------------|------|
|
| 149 |
+
| `ConfigManager` | YAMLファイル | 設定読み込み + ログセットアップ | △ 2責任 |
|
| 150 |
+
| `FAISSIndexManager` | FAISSインデックス | インデックス管理 + 検索 + 応答生成 | ✗ 3責任 |
|
| 151 |
+
| `FileManager` | ドキュメント | ファイル読み込み + チャンキング | △ 2責任 |
|
| 152 |
+
| `OpenAIManager` | Embeddings API | API呼び出しのみ | ✓ 適切 |
|
| 153 |
+
|
| 154 |
+
### 最大の問題:`FAISSIndexManager`(`faiss_manager.py:245-286`)
|
| 155 |
+
```python
|
| 156 |
+
class FAISSIndexManager:
|
| 157 |
+
def upsert_file_to_faiss(...) # インデックス管理 ✓
|
| 158 |
+
def search_faiss_index(...) # 検索 ✓
|
| 159 |
+
def generate_response_from_context(...) # ← LLM応答生成(無関係)
|
| 160 |
+
def parse_result(...) # ← 文字列パース(雑多)
|
| 161 |
+
```
|
| 162 |
+
|
| 163 |
+
`generate_response_from_context` はインデックス管理とは無関係なのに `FAISSIndexManager` に含まれている。この責任は `OpenAIRAGAgent` に移すべき。
|
| 164 |
+
|
| 165 |
+
---
|
| 166 |
+
|
| 167 |
+
## 5. Pipeline アーキテクチャ
|
| 168 |
+
|
| 169 |
+
### 目的
|
| 170 |
+
処理を段階的に分けて順次実行する。
|
| 171 |
+
|
| 172 |
+
### 現在のパイプライン
|
| 173 |
+
```
|
| 174 |
+
main.py
|
| 175 |
+
1. DataLoader → 生データ取得
|
| 176 |
+
2. QueryProcessor → データ正規化
|
| 177 |
+
3. FAISSIndexManager → インデックス構築
|
| 178 |
+
4. SubclaimProcessor → 応答生成 → サブクレーム抽出 → スコアリング → アノテーション
|
| 179 |
+
5. SplitConformalCalibration → 統計的キャリブレーション
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
### 問題点:ステージ間の文字列シリアライズ(`faiss_manager.py:196-216`)
|
| 183 |
+
```python
|
| 184 |
+
# 検索結果を文字列として返す
|
| 185 |
+
results.append(
|
| 186 |
+
f"{text} indice={idx} fileposition={relative_idx} score={dist:.4f}"
|
| 187 |
+
# TODO reformat this ← 本人コメントあり
|
| 188 |
+
)
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
その後 `parse_result()` で正規表現パースし直している:
|
| 192 |
+
```python
|
| 193 |
+
pattern = re.compile(
|
| 194 |
+
r"page_content='(.*?)'\smetadata=(\{.*?\})\sindice=(\d+)\sfileposition=(\d+)\sscore=([\d.]+)",
|
| 195 |
+
re.DOTALL,
|
| 196 |
+
)
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
### 改善案
|
| 200 |
+
dataclass でステージ間のデータ型を定義すれば `parse_result()` が不要になる:
|
| 201 |
+
```python
|
| 202 |
+
@dataclass
|
| 203 |
+
class SearchResult:
|
| 204 |
+
text: str
|
| 205 |
+
indice: int
|
| 206 |
+
fileposition: int
|
| 207 |
+
score: float
|
| 208 |
+
```
|
| 209 |
+
|
| 210 |
+
---
|
| 211 |
+
|
| 212 |
+
## 6. Dependency Injection(DI)
|
| 213 |
+
|
| 214 |
+
### 目的
|
| 215 |
+
依存オブジェクトを外から注入してテスト容易性を高める。
|
| 216 |
+
|
| 217 |
+
### 現在の実装(`subclaim_scorer.py:29-44`)
|
| 218 |
+
```python
|
| 219 |
+
class SubclaimScorer(IDocumentScorer):
|
| 220 |
+
def __init__(self, index_truncation_config, embedding_model, index_path, ...):
|
| 221 |
+
self.faiss_manager = FAISSIndexManager(...) # ← 内部で直接生成
|
| 222 |
+
self.gen = OpenAIAtomicFactGenerator() # ← 内部で直接生成
|
| 223 |
+
self.openai_client = OpenAI() # ← 内部で直接生成
|
| 224 |
+
```
|
| 225 |
+
|
| 226 |
+
### 問題点
|
| 227 |
+
これは DI ではない。本当の DI は依存オブジェクトを外部から受け取る:
|
| 228 |
+
```python
|
| 229 |
+
# DI(正)
|
| 230 |
+
class SubclaimScorer:
|
| 231 |
+
def __init__(
|
| 232 |
+
self,
|
| 233 |
+
faiss_manager: FAISSIndexManager,
|
| 234 |
+
fact_generator: OpenAIAtomicFactGenerator,
|
| 235 |
+
openai_client: OpenAI,
|
| 236 |
+
):
|
| 237 |
+
self.faiss_manager = faiss_manager
|
| 238 |
+
```
|
| 239 |
+
|
| 240 |
+
現在の設計では `SubclaimScorer` の単体テストに本物の OpenAI API と FAISS インデックスが必要になる。
|
| 241 |
+
|
| 242 |
+
---
|
| 243 |
+
|
| 244 |
+
## 設計全体の評価サマリー
|
| 245 |
+
|
| 246 |
+
| パターン | 意図の正確さ | 実装の品質 | 改善優先度 |
|
| 247 |
+
|---------|------------|----------|----------|
|
| 248 |
+
| Strategy | ✓ 正しい | △ 文字列キーで型安全でない | 低(機能はしている) |
|
| 249 |
+
| Template Method | △ 混乱あり | △ 2つの抽象クラスが混在 | 中 |
|
| 250 |
+
| Factory/Dispatcher | △ Registry に近い | △ 起動時に全数生成 | 低 |
|
| 251 |
+
| Manager | △ 責任過多 | ✗ FAISSがLLM応答生成を担当 | 高 |
|
| 252 |
+
| Pipeline | ✓ 正しい | ✗ 文字列シリアライズ+正規表現パース | 高 |
|
| 253 |
+
| Dependency Injection | ✗ 名前だけ | ✗ 内部生成で外部注入不可 | 中 |
|
| 254 |
+
|
| 255 |
+
### 優先度「高」の改善項目(構造変更不要)
|
| 256 |
+
|
| 257 |
+
1. **`FAISSIndexManager.generate_response_from_context` の移動**
|
| 258 |
+
- 移動先:`OpenAIRAGAgent`(すでに存在する)
|
| 259 |
+
- 効果:Manager の単一責任原則回復
|
| 260 |
+
|
| 261 |
+
2. **パイプライン間データ型の構造化**
|
| 262 |
+
- `SearchResult` dataclass の導入
|
| 263 |
+
- `parse_result()` の正規表現解析の廃止
|
| 264 |
+
- 効果:バグ発生箇所の削減、デバッグ容易性の向上
|
docs/context/03_demo_app_design.md
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# デモアプリケーション設計
|
| 2 |
+
|
| 3 |
+
## 1. 技術選定
|
| 4 |
+
|
| 5 |
+
既存のバックエンド資産(Python 3.11, OpenAI, FAISS)を活かしつつ、デモ用途に最適化する。
|
| 6 |
+
|
| 7 |
+
| 層 | 選定技術 | 理由 |
|
| 8 |
+
|----|---------|------|
|
| 9 |
+
| フロントエンド・UI | **Streamlit** | Python のみで完結し、既存推論コードを直接インポート可能。Flask/FastAPI で別途 API サーバーを立てる手間が不要。テキストのハイライト(Markdown/HTML 描画)やスライダーも標準サポート |
|
| 10 |
+
| バックエンド・推論 | 既存 `src/` モジュールを再利用 | - |
|
| 11 |
+
| LLM API | `gpt-4o-mini` | デモの応答速度とコストのバランスが良い |
|
| 12 |
+
| ベクトル検索 | FAISS + `text-embedding-3-large` | そのまま利用 |
|
| 13 |
+
| データストア | 事前計算済み閾値を CSV として静的保持 | デモ実行時にキャリブレーションループを回すのは非現実的なため。alpha を float 列として保持するため型変換問題が発生しない |
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## 2. UI / UX 設計
|
| 18 |
+
|
| 19 |
+
ユーザーが「Target Factuality ($1-\alpha$) を操作すると、生成された回答のどの部分が削られるか」を体感できる構成にする。
|
| 20 |
+
|
| 21 |
+
### 画面レイアウト
|
| 22 |
+
|
| 23 |
+
#### サイドバー(コントロールパネル)
|
| 24 |
+
|
| 25 |
+
| コントロール | 内容 | 備考 |
|
| 26 |
+
|------------|------|------|
|
| 27 |
+
| データセット選択 | FactScore / HotpotQA / PopQA / MedLFQA から選択 | - |
|
| 28 |
+
| モード選択 | Marginal(全体)か Conditional(グループ別)か | **MedLFQA 選択時のみ有効**。他のデータセットでは非活性(`disabled=True`)または非表示にする |
|
| 29 |
+
| グループ選択 | healthsearch_qa / medication_qa / kqa_silver_wogold / kqa_golden / live_qa から選択 | **MedLFQA + Conditional モード選択時のみ有効**。それ以外は非表示 |
|
| 30 |
+
| 質問の入力 | サンプルクエリのプルダウン選択のみ | FAISS インデックスは各クエリの provenance に対応する文書のみで構成されているため、任意のフリーテキスト入力は検索精度・スコアリング精度が保証されない。デモではサンプルクエリに限定する |
|
| 31 |
+
| 推論実行ボタン | `st.button("推論実行")` | クリックで重い処理を起動 |
|
| 32 |
+
| Factuality スライダー | Target Factuality ($1-\alpha$) を 60% 〜 95% の範囲で調整 | 推論結果がキャッシュされている場合のみ有効 |
|
| 33 |
+
|
| 34 |
+
#### メインパネル(結果表示)
|
| 35 |
+
|
| 36 |
+
**クエリ**
|
| 37 |
+
- 選択されたサンプルクエリの全文を表示
|
| 38 |
+
|
| 39 |
+
---
|
| 40 |
+
|
| 41 |
+
**元の RAG 回答 | フィルタ後の回答**(2カラム)
|
| 42 |
+
|
| 43 |
+
左カラム: 元の RAG 回答
|
| 44 |
+
- `OpenAIRAGAgent` が生成した回答をそのまま表示
|
| 45 |
+
|
| 46 |
+
右カラム: フィルタ後の回答
|
| 47 |
+
- 「回答を生成」ボタン(`st.button`)を押すと LLM が回答を再生成(`reintegrate_subclaims()` 呼び出し)
|
| 48 |
+
- 生成された回答をボタンの上に表示
|
| 49 |
+
- スライダー($\alpha$)を変更した後に古い結果が残っている場合は警告を表示
|
| 50 |
+
|
| 51 |
+
---
|
| 52 |
+
|
| 53 |
+
**Conformal-RAG による回答(インタラクティブ表示)**
|
| 54 |
+
- スライダーの値に対応する閾値($\hat{q}$)を事前計算済み CSV から取得し、キャッシュ済みスコアと比較して即時再描画
|
| 55 |
+
- スコア($R(c)$)が閾値($\hat{q}$)以上のサブクレーム → 装飾なし(そのまま表示)
|
| 56 |
+
- スコアが閾値未満のサブクレーム → **取り消し線・赤字**で表示
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
**検索されたナレッジ(Retrieved Documents)**
|
| 61 |
+
- `st.expander` でアコーディオン表示(デフォルト折りたたみ)
|
| 62 |
+
- FAISS 検索結果の `page_content` を抽出(`_extract_page_content()` でソースプレフィックスを除去)
|
| 63 |
+
|
| 64 |
+
**分析メトリクス**
|
| 65 |
+
- 算出された閾値($\hat{q}$)
|
| 66 |
+
- 元のサブクレーム数
|
| 67 |
+
- Keep 件数
|
| 68 |
+
- 削除率(Removal Rate)
|
| 69 |
+
- 推論にかかった時間(ライブ推論時のみ表示。サンプルクエリ使用時は非表示)
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## 3. バックエンド設計変更(デモ向け改修)
|
| 74 |
+
|
| 75 |
+
現在のリポジトリは「バッチ評価用(論文の実験用)」設計のため、インタラクティブなデモ向けに以下のリファクタリングが必要。
|
| 76 |
+
|
| 77 |
+
> **方針(CLAUDE.md より)**: デモ特有の UI ロジックやエンドポイントは `demo/` ディレクトリ内に隔離し、`src/` のコアロジックと密結合させない。`main.py` を使ったバッチ処理が壊れないよう後方互換性を維持する。
|
| 78 |
+
|
| 79 |
+
### ディレクトリ構成
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
demo/
|
| 83 |
+
├── app.py # Streamlit エントリーポイント
|
| 84 |
+
├── inference_api.py # 単一クエリ推論ラッパー(src/ を呼び出す)
|
| 85 |
+
├── precompute.py # キャリブレーション閾値・サンプルクエリの事前計算(オフライン実行用)
|
| 86 |
+
├── constants.py # ALPHA_LEVELS など demo/ 内の共有定数
|
| 87 |
+
└── data/
|
| 88 |
+
├── thresholds.csv # 事前計算済みキャリブレーション閾値マスタ
|
| 89 |
+
└── samples.json # 事前計算済みサンプルクエリ推論結果
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
`src/` のコアロジックは変更せず、`demo/` 側から `import` して利用する。
|
| 93 |
+
|
| 94 |
+
### A. 事前計算(オフライン)フェーズ — `precompute.py`
|
| 95 |
+
|
| 96 |
+
`precompute.py` は以下 **2つの責務** を持つ。どちらもオフラインで実行し、デモ起動時には読み込むだけにする。
|
| 97 |
+
|
| 98 |
+
#### A-1. キャリブレーション閾値の計算
|
| 99 |
+
|
| 100 |
+
各データセット・スコアリング手法・$\alpha$ 値・グループごとの閾値($\hat{q}$)を計算して静的ファイルに保存する。
|
| 101 |
+
|
| 102 |
+
> **データ分割(ホールドアウト)**: `main.py` の出力データはキャリブレーション用とサンプル用に分割して使用する(`precompute.py` の `HOLDOUT_PER_GROUP` 定数で制御)。Conformal Prediction の理論的 coverage 保証はキャリブレーションに使っていないデータに対してのみ成立するため、同一データを両方に使ってはならない。`compute_thresholds()` はキャリブレーション側のみ、`compute_samples()` はホールドアウト側のみを参照する。
|
| 103 |
+
|
| 104 |
+
- 対象データセット: `fact_score`(FactScore)、`hotpot_qa`(HotpotQA)、`pop_qa`(PopQA)、`medlf_qa`(MedLFQA)
|
| 105 |
+
- MedLFQA の Conditional モードのグループ: **HealthsearchQA / MedicationQA / K-QA_Silver / K-QA_Golden / LiveQA**(全グループを網羅すること)
|
| 106 |
+
- スコアリング手法: デフォルト `relevance`。手法名は定数として管理し、変更可能にする
|
| 107 |
+
- 保存フォーマットは **CSV(縦持ち)**、出力先: `demo/data/thresholds.csv`
|
| 108 |
+
- スキーマ:
|
| 109 |
+
|
| 110 |
+
| 列名 | 型 | 説明 | 取りうる値 |
|
| 111 |
+
|------|----|------|-----------|
|
| 112 |
+
| `dataset` | `str` | データセット識別子 | `"fact_score"`, `"hotpot_qa"`, `"pop_qa"`, `"medlf_qa"` |
|
| 113 |
+
| `mode` | `str` | キャリブレーションモード | `"marginal"`, `"conditional"` |
|
| 114 |
+
| `scoring_method` | `str` | スコアリング手法名 | `"relevance"`, `"cosine_similarity"`, `"min_log_prob"` など。手法を追加する際は `precompute.py` を再実行して行を追加するだけでよい(スキーマ変更不要) |
|
| 115 |
+
| `group` | `str` | グループ名。Marginal モードまたはグループなしデータセットは `"default"` | `"HealthsearchQA"`, `"MedicationQA"`, `"K-QA_Silver"`, `"K-QA_Golden"`, `"LiveQA"`, `"default"` |
|
| 116 |
+
| `alpha` | `float` | 誤り率($\alpha = 1 - \text{target factuality}$) | `ALPHA_LEVELS`(後述)と完全一致させること |
|
| 117 |
+
| `q_hat` | `float` | 共形予測の閾値 | — |
|
| 118 |
+
|
| 119 |
+
- `alpha` は浮動小数点演算の誤差(例: `1 - 0.9 = 0.09999...`)を避けるため、CSV 書き込み時・lookup 時の両方で `round(alpha, 2)` を適用すること
|
| 120 |
+
- lookup 例: `df.query("dataset==@d and mode==@m and scoring_method==@s and group==@g and alpha==@a")["q_hat"].iloc[0]`
|
| 121 |
+
|
| 122 |
+
#### A-2. サンプルクエリの推論結果の事前計算
|
| 123 |
+
|
| 124 |
+
UI のプルダウンで選択できるサンプルクエリについて、`inference_api.process_query()` を実行済みの `SubclaimResult` を静的ファイルとして保存しておく。これによりプルダウン選択時に API 呼び出しなしで即座に表示できる。
|
| 125 |
+
|
| 126 |
+
- サンプルクエリ結果の保存ファイルには**メタデータ(生成日時・使用モデル名・スコアリング手法)を付与**し、モデルやプロンプト変更後に再生成が必要かどうか判断できるようにする
|
| 127 |
+
- MedLFQA のサンプルクエリは全グループ(HealthsearchQA / MedicationQA / K-QA_Silver / K-QA_Golden / LiveQA)から少なくとも 1 件ずつ含めること
|
| 128 |
+
|
| 129 |
+
保存フォーマット(`demo/data/samples.json`):
|
| 130 |
+
|
| 131 |
+
```json
|
| 132 |
+
{
|
| 133 |
+
"metadata": {
|
| 134 |
+
"generated_at": "2026-03-21T12:00:00Z",
|
| 135 |
+
"model": "gpt-4o-mini",
|
| 136 |
+
"scoring_method": "relevance"
|
| 137 |
+
},
|
| 138 |
+
"samples": [
|
| 139 |
+
{
|
| 140 |
+
"query": "What is the first-line treatment for hypertension?",
|
| 141 |
+
"dataset": "medlf_qa",
|
| 142 |
+
"mode": "conditional",
|
| 143 |
+
"group": "MedicationQA",
|
| 144 |
+
"scoring_method": "relevance",
|
| 145 |
+
"rag_answer": "...",
|
| 146 |
+
"retrieved_docs": ["...", "..."],
|
| 147 |
+
"subclaims": [
|
| 148 |
+
{"id": 0, "text": "...", "score": 0.85},
|
| 149 |
+
{"id": 1, "text": "...", "score": 0.42}
|
| 150 |
+
]
|
| 151 |
+
}
|
| 152 |
+
]
|
| 153 |
+
}
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
- `samples` 配列の各要素は `SubclaimResult` と同一スキーマ
|
| 157 |
+
- プルダウン表示ラベルは `query` フィールドから生成する
|
| 158 |
+
|
| 159 |
+
### B. 状態管理(`st.session_state`)
|
| 160 |
+
|
| 161 |
+
Streamlit はウィジェット操作のたびにスクリプト全体を再実行するため、重い処理の結果を `st.session_state` に保持して再計算を回避する。
|
| 162 |
+
|
| 163 |
+
#### B-1. セッションキー一覧
|
| 164 |
+
|
| 165 |
+
| キー | 型 | 内容 |
|
| 166 |
+
|------|----|------|
|
| 167 |
+
| `query_processed` | `bool` | クエリが処理済みかどうか |
|
| 168 |
+
| `result` | `SubclaimResult \| None` | `process_query()` の戻り値(サブクレーム・スコア・検索結果・RAG 回答をまとめて保持) |
|
| 169 |
+
| `elapsed_sec` | `float \| None` | ライブ推論の所要時間。サンプルクエリ使用時は `None`(推論時間メトリクスの表示制御に使用) |
|
| 170 |
+
| `prev_params` | `tuple \| None` | 直前の `(dataset, mode, group)` の組。変更検出してセッションをリセットするために使用 |
|
| 171 |
+
| `integrated_answer` | `str \| None` | `reintegrate_subclaims()` が生成した再統合済み回答。パラメータ変更や新規推論時に `None` にリセット |
|
| 172 |
+
| `integrated_alpha` | `float \| None` | `integrated_answer` 生成時の $\alpha$ 値。現在のスライダー値と異なれば警告を表示するために保持 |
|
| 173 |
+
|
| 174 |
+
#### B-2. データ型定義
|
| 175 |
+
|
| 176 |
+
`demo/inference_api.py` で定義する型。`app.py` との境界を明確にするため `TypedDict` で宣言する。
|
| 177 |
+
|
| 178 |
+
```python
|
| 179 |
+
from typing import TypedDict
|
| 180 |
+
from demo.constants import DEFAULT_SCORING_METHOD # constants.py を単一の定義元とする
|
| 181 |
+
|
| 182 |
+
class Subclaim(TypedDict):
|
| 183 |
+
id: int
|
| 184 |
+
text: str # サブクレームのテキスト
|
| 185 |
+
score: float # 計算されたスコア値(手法名は SubclaimResult.scoring_method を参照)
|
| 186 |
+
|
| 187 |
+
class SubclaimResult(TypedDict):
|
| 188 |
+
query: str
|
| 189 |
+
dataset: str # 例: "medlf_qa"
|
| 190 |
+
mode: str # 例: "marginal" / "conditional"。thresholds.csv の mode 列に対応
|
| 191 |
+
group: str # 例: "MedicationQA"。Marginal モードまたは非グループデータセットは "default"
|
| 192 |
+
scoring_method: str # 例: "relevance"。apply_threshold() での閾値ルックアップキーに使用
|
| 193 |
+
rag_answer: str
|
| 194 |
+
retrieved_docs: list[str]
|
| 195 |
+
subclaims: list[Subclaim]
|
| 196 |
+
|
| 197 |
+
class FilteredResult(TypedDict):
|
| 198 |
+
subclaims: list[Subclaim]
|
| 199 |
+
q_hat: float
|
| 200 |
+
keep_count: int
|
| 201 |
+
remove_count: int
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
- `Subclaim.score` は手法を問わず汎用名とし、何のスコアかは `SubclaimResult.scoring_method` で示す
|
| 205 |
+
- `SubclaimResult.group` は CSV の `group` 列と対応。Marginal モードや非グループデータセットでは `"default"` を使い、`None` は使わない
|
| 206 |
+
- `status`(Keep/Remove)は UI 側でスコアと `q_hat` を比較して動的に判定するため、バックエンドからは返さない
|
| 207 |
+
|
| 208 |
+
#### 共有定数 — `demo/constants.py`
|
| 209 |
+
|
| 210 |
+
`precompute.py`(CSV 生成)と `app.py`(スライダー)の両方が参照する定数をここに一元管理する。スライダーの選択値と CSV の `alpha` 列が一致しないと lookup が失敗するため、両者は必ずこのファイルを `import` して使う。
|
| 211 |
+
|
| 212 |
+
```python
|
| 213 |
+
# demo/constants.py
|
| 214 |
+
|
| 215 |
+
# スコアリング手法のデフォルト値。変更時はここのみ修正し、precompute.py を再実行する
|
| 216 |
+
DEFAULT_SCORING_METHOD = "relevance"
|
| 217 |
+
|
| 218 |
+
# クリーンな手法名 → main.py の出力データ内のキー名 のマッピング
|
| 219 |
+
# "relavance" は元コードベースのタイポ。データファイルとの互換性のため保持
|
| 220 |
+
SCORING_METHOD_TO_DATA_KEY: dict[str, str] = {
|
| 221 |
+
"relevance": "relavance",
|
| 222 |
+
"cosine_similarity": "query_claim_cosine_similarity",
|
| 223 |
+
"min_log_prob": "min_log_prob",
|
| 224 |
+
}
|
| 225 |
+
|
| 226 |
+
# Target Factuality スライダーの刻み(alpha = 1 - target_factuality)
|
| 227 |
+
# precompute.py はこのリストの alpha 値のみ計算する
|
| 228 |
+
# app.py の st.select_slider はこのリストをそのまま使う
|
| 229 |
+
ALPHA_LEVELS: list[float] = [0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40]
|
| 230 |
+
# → Target Factuality: 95%, 90%, 85%, 80%, 75%, 70%, 65%, 60%
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
---
|
| 234 |
+
|
| 235 |
+
### C. 推論パイプライン — 重い処理と軽い処理の分離
|
| 236 |
+
|
| 237 |
+
**重い処理① — 「推論実行」ボタンのクリック時のみ起動**
|
| 238 |
+
|
| 239 |
+
```
|
| 240 |
+
1. クエリのベクトル化 & FAISS 検索 ← src/common/faiss_manager.py
|
| 241 |
+
2. RAG による初期回答(ŷ)の生成 ← src/common/llm/openai_rag_agent.py
|
| 242 |
+
3. OpenAI API によるサブクレームへの分解 ← src/common/llm/openai_atomicfact_generator.py
|
| 243 |
+
4. サブクレームごとのスコア計算 ← src/subclaim_processor/scorer/
|
| 244 |
+
5. 結果を st.session_state に格納 → st.rerun() でスライダーを即時有効化
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
- `st.spinner` でユーザーに待機(30〜60 秒)を明示する
|
| 248 |
+
- アノテーション(S/I/U/N ラベル付け)はデモの表示に不要なため**スキップ**する
|
| 249 |
+
- frequency スコア(`src/subclaim_processor/scorer/subclaim_scorer.py` の `frequency_score()`)は1サブクレームあたり LLM を 5 回呼び出すため、**デモでは除外する**。表示するスコアは relevance・cosine similarity・min_log_prob に限定する
|
| 250 |
+
|
| 251 |
+
**重い処理② — 「回答を生成」ボタンのクリック時のみ起動**
|
| 252 |
+
|
| 253 |
+
```
|
| 254 |
+
1. 閾値未満のサブクレームを removed_subclaims として収集
|
| 255 |
+
2. reintegrate_subclaims(rag_answer, removed_subclaims) を呼び出し
|
| 256 |
+
→ OpenAILLMAgent 経由で元の回答から除去対象サブクレームを削除した自然文を生成
|
| 257 |
+
3. 結果を st.session_state.integrated_answer に格納 → st.rerun()
|
| 258 |
+
```
|
| 259 |
+
|
| 260 |
+
**軽い処理(リアルタイム実行)— スライダー操作のたびに実行**
|
| 261 |
+
|
| 262 |
+
```
|
| 263 |
+
1. スライダーの alpha 値に対応する q̂ を事前計算済み CSV から取得
|
| 264 |
+
2. st.session_state.result のスコアと q̂ を比較
|
| 265 |
+
3. Keep/Remove の HTML タグを生成して st.markdown で再描画
|
| 266 |
+
<span style="padding:2px 5px;">サブクレーム</span> ← Keep(装飾なし)
|
| 267 |
+
<span style="text-decoration:line-through; color:#dc3545;">サブクレーム</span> ← Remove
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
API 呼び出しは一切行わない。
|
| 271 |
+
|
| 272 |
+
### D. リソースキャッシュ(`@st.cache_resource`)
|
| 273 |
+
|
| 274 |
+
FAISS インデックスはデータセットごとに数百 MB あるため、データセット切り替えのたびにロードし直すと数十秒かかる。`@st.cache_resource` でデータセット名をキーにキャッシュし、初回のみロードする。
|
| 275 |
+
|
| 276 |
+
```python
|
| 277 |
+
@st.cache_resource
|
| 278 |
+
def get_faiss_manager(dataset: str) -> FAISSIndexManager:
|
| 279 |
+
return build_faiss_manager(dataset)
|
| 280 |
+
|
| 281 |
+
@st.cache_resource
|
| 282 |
+
def get_scorer(dataset: str) -> SubclaimScorer:
|
| 283 |
+
return build_scorer(dataset)
|
| 284 |
+
```
|
| 285 |
+
|
| 286 |
+
`thresholds.csv` と `samples.json` の静的ファイルは `@st.cache_data` でキャッシュする(`@st.cache_resource` はオブジェクト向け。DataFrameや辞書は `@st.cache_data`)。
|
| 287 |
+
|
| 288 |
+
> **実装上の注意(HF Spaces 対応)**: `FAISSIndexManager` / `SubclaimScorer` は
|
| 289 |
+
> `langchain_text_splitters` を経由して `torch` / `transformers` 等の重い依存を引き込む。
|
| 290 |
+
> HF Spaces など FAISS 不要な環境での起動コストを避けるため、これらのクラスは
|
| 291 |
+
> `inference_api.py` のモジュールレベルではなく `build_faiss_manager()` / `build_scorer()`
|
| 292 |
+
> 関数内で遅延 import している。型ヒントは `TYPE_CHECKING` ガードで維持。
|
| 293 |
+
|
| 294 |
+
---
|
| 295 |
+
|
| 296 |
+
## 4. 開発ステップ
|
| 297 |
+
|
| 298 |
+
### ① `demo/constants.py`(共有定数)
|
| 299 |
+
|
| 300 |
+
他モジュールが依存するため最初に作成する。`ALPHA_LEVELS` と `DEFAULT_SCORING_METHOD` のみ定義。
|
| 301 |
+
|
| 302 |
+
### ② `demo/inference_api.py`(推論ラッパー)
|
| 303 |
+
|
| 304 |
+
`precompute.py` がこれを呼び出すため、先に実装する。
|
| 305 |
+
|
| 306 |
+
- `src/` モジュールを `import` するのみで、コアロジックは変更しない
|
| 307 |
+
- インターフェース:
|
| 308 |
+
- `process_query(query: str, dataset: str, mode: str, group: str) -> SubclaimResult`(重い処理①)
|
| 309 |
+
- `reintegrate_subclaims(rag_answer: str, removed_subclaims: list[str], model: str) -> str`(重い処理②)
|
| 310 |
+
- `apply_threshold(result: SubclaimResult, q_hat: float) -> FilteredResult`(軽い処理)
|
| 311 |
+
- 使用するスコアリング手法は `DEFAULT_SCORING_METHOD` 定数で管理し、ハードコーディングしない
|
| 312 |
+
|
| 313 |
+
### ③ `demo/precompute.py` + データ生成(オフライン)
|
| 314 |
+
|
| 315 |
+
`inference_api` が完成してから実行する。
|
| 316 |
+
|
| 317 |
+
- 既存の `main.py` を実行し、デモ用データセットの FAISS インデックスを作成(未作成の場合)
|
| 318 |
+
- `demo/precompute.py` を実装・実行し、以下を生成:
|
| 319 |
+
1. $\alpha$ ごとのキャリブレーション閾値マスタ(CSV: `demo/data/thresholds.csv`)
|
| 320 |
+
2. サンプルクエリ分の `SubclaimResult` キャッシュ(`demo/data/samples.json`、生成日時・モデル名のメタデータ付き)
|
| 321 |
+
|
| 322 |
+
### ④ `demo/app.py`(Streamlit UI)
|
| 323 |
+
|
| 324 |
+
上記 3 つが揃ってから実装する。
|
| 325 |
+
|
| 326 |
+
- ボタン・スライダー・`st.session_state` を連動させる
|
| 327 |
+
- MedLFQA 以外のデータセット選択時は Conditional モードおよびグループ選択を非活性にする
|
docs/context/04_environment_setup.md
ADDED
|
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 環境構築・実行手順
|
| 2 |
+
|
| 3 |
+
実行環境は **Singularity コンテナ経由** と **ホスト直接実行** の2通りをサポートする。
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 1. 前提条件
|
| 8 |
+
|
| 9 |
+
| ツール | バージョン要件 | 用途 |
|
| 10 |
+
|--------|-------------|------|
|
| 11 |
+
| Git | 任意 | リポジトリのクローン |
|
| 12 |
+
| Singularity (SingularityCE) | 3.x 以上 | コンテナ実行・SIF ビルド(コンテナ経由の場合) |
|
| 13 |
+
| `fakeroot` | — | SIF ビルド時に必要(`singularity build --fakeroot`) |
|
| 14 |
+
| Python 3.11 + `requirements-dev.txt` | — | ホスト直接実行の場合のみ |
|
| 15 |
+
|
| 16 |
+
> **推奨**: HPC 環境など Python 環境の管理が難しい場合は Singularity コンテナ経由を使う。ローカル開発では直接実行でも動作する。
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## 2. ホスト側ディレクトリ構成
|
| 21 |
+
|
| 22 |
+
ホストの任意のディレクトリ(`DATA_ROOT` の親)に以下の構造を作成する。
|
| 23 |
+
重いファイル(FAISS インデックス、生データ、HuggingFace キャッシュ)はすべてここに集約し、Singularity の `--bind` でコンテナ内にマウントする。
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
/mnt/data/<username>/
|
| 27 |
+
├── sif/
|
| 28 |
+
│ └── response_quality.sif ← SIF_DIR
|
| 29 |
+
├── hf_cache/ ← HF_HOME
|
| 30 |
+
└── ResponseQualityAssessment/ ← DATA_ROOT
|
| 31 |
+
├── data/
|
| 32 |
+
│ ├── raw/
|
| 33 |
+
│ │ └── WikiDB/ ← Wikipedia SQLite DB
|
| 34 |
+
│ ├── processed/
|
| 35 |
+
│ ├── out/ ← main.py 出力(サブクレーム JSON)
|
| 36 |
+
│ └── result/
|
| 37 |
+
└── index_store/
|
| 38 |
+
├── FactScore/ ← FAISS インデックス
|
| 39 |
+
├── HotpotQA/
|
| 40 |
+
├── PopQA/
|
| 41 |
+
└── MedLFQA/
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
> `demo/data/`(thresholds.csv, samples.json)はリポジトリ内に含まれるため `DATA_ROOT` 側への配置・バインドは不要。
|
| 45 |
+
|
| 46 |
+
---
|
| 47 |
+
|
| 48 |
+
## 3. `.env` の設定
|
| 49 |
+
|
| 50 |
+
`.env.example` をコピーして `.env` を作成し、環境に合わせて値を設定する。
|
| 51 |
+
|
| 52 |
+
```bash
|
| 53 |
+
cp .env.example .env
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
`.env` の内容(`.gitignore` 対象):
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# Machine-specific absolute paths
|
| 60 |
+
DATA_ROOT=/mnt/data/<username>/ResponseQualityAssessment
|
| 61 |
+
SIF_DIR=/mnt/data/<username>/sif
|
| 62 |
+
HF_HOME=/mnt/data/<username>/hf_cache
|
| 63 |
+
|
| 64 |
+
# API keys
|
| 65 |
+
OPENAI_API_KEY=sk-...
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
`conf/path_config.yaml` と `conf/dataset_config.yaml` の各パスは `${DATA_ROOT}` を起点として定義されており、実行時にこの環境変数を展開して解釈する。
|
| 69 |
+
|
| 70 |
+
> - **ホスト直接実行時**: `.env` の `DATA_ROOT`(例: `/mnt/data/<username>/ResponseQualityAssessment`)をそのまま使用する。
|
| 71 |
+
> - **コンテナ経由実行時**: `run_*.sh` スクリプトがホストの `data/` と `index_store/` を `REPO_ROOT` 以下にバインドマウントするため、コンテナ内での `DATA_ROOT` は `REPO_ROOT` に上書きされる。スクリプト内の `--env DATA_ROOT="${REPO_ROOT}"` がこれを担っている。
|
| 72 |
+
|
| 73 |
+
---
|
| 74 |
+
|
| 75 |
+
## 4. Singularity コンテナのビルド(初回のみ)
|
| 76 |
+
|
| 77 |
+
```bash
|
| 78 |
+
bash scripts/build_sif.sh
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
- `singularity/response_quality.def`(`docker://python:3.11-slim` ベース)から SIF をビルドする
|
| 82 |
+
- 出力先: `${SIF_DIR}/response_quality.sif`
|
| 83 |
+
- ビルドには数分かかる。`fakeroot` が必要
|
| 84 |
+
- `torch` はコンテナ内では CPU 専用 wheel(`--index-url https://download.pytorch.org/whl/cpu`)でインストールされる。ローカル開発で `requirements-dev.txt` から直接インストールする場合は同オプションを手動で指定すること
|
| 85 |
+
|
| 86 |
+
**requirements ファイルの構成:**
|
| 87 |
+
|
| 88 |
+
| ファイル | 用途 |
|
| 89 |
+
|----------|------|
|
| 90 |
+
| `requirements.txt` | HF Spaces 向け最小セット(`torch` 等の重い依存を除く)。Spaces はこのファイルを自動で読む |
|
| 91 |
+
| `requirements-dev.txt` | ローカル・Singularity 向け全依存(`-r requirements.txt` で共通部分を継承) |
|
| 92 |
+
|
| 93 |
+
ホスト直接実行時のインストール:
|
| 94 |
+
```bash
|
| 95 |
+
pip install -r requirements-dev.txt
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
---
|
| 99 |
+
|
| 100 |
+
## 5. データ準備(初回のみ)
|
| 101 |
+
|
| 102 |
+
| 方法 | 用途 | API 消費 |
|
| 103 |
+
|--------|------|---------|
|
| 104 |
+
| **A. 著者提供データを使う** | 生データ・FAISS インデックスを再利用して API 呼び出しを削減 | Embeddings 不要(FAISS スキップ)、Chat は必要 |
|
| 105 |
+
| **B. 自前でフルパイプライン実行** | 新モデルや新データセットで結果を再現 | あり(Chat + Embeddings) |
|
| 106 |
+
|
| 107 |
+
---
|
| 108 |
+
|
| 109 |
+
### A. 著者提供データを使う(FAISS 構築スキップ)
|
| 110 |
+
|
| 111 |
+
著者提供の Google Drive フォルダには以下が含まれる(**注**: 事前計算済みサブクレーム JSON は含まれていない):
|
| 112 |
+
|
| 113 |
+
- `data.zip` — 生クエリデータ(`HotpotQA/raw_hotpot_qa.json`、`PopQA/raw_pop_qa.json`、`FactScore/factscore_names.txt`、`MedLFQA/*.jsonl`)
|
| 114 |
+
- `index_store.zip` — 事前構築済み FAISS インデックス(各データセット × 各 query_size 分の `.faiss` と `indice2fm_*.json`)
|
| 115 |
+
|
| 116 |
+
FAISS インデックスが揃っていれば Embeddings API(インデックス構築)はスキップされる。ただし、RAG 回答生成・サブクレーム抽出・スコアリング・アノテーションには Chat API が必要。
|
| 117 |
+
|
| 118 |
+
**1. Wikipedia DB をダウンロードして配置する**
|
| 119 |
+
|
| 120 |
+
全データセット共通で必要(QueryProcessor が常に参照する)。
|
| 121 |
+
|
| 122 |
+
> [enwiki-20230401.db(Google Drive)](https://drive.google.com/file/d/1mekls6OGOKLmt7gYtHs0WGf5oTamTNat/view?usp=drive_link)
|
| 123 |
+
|
| 124 |
+
```
|
| 125 |
+
DATA_ROOT/data/raw/WikiDB/enwiki-20230401.db
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
**2. `medlf_qa` を使う場合は生データを手動取得する**
|
| 129 |
+
|
| 130 |
+
`fact_score` / `hotpot_qa` / `pop_qa` は DataLoader が HuggingFace から自動取得するため不要。
|
| 131 |
+
|
| 132 |
+
```bash
|
| 133 |
+
source .env
|
| 134 |
+
# 以下どちらか一方
|
| 135 |
+
git clone https://github.com/dmis-lab/OLAPH.git /tmp/olaph
|
| 136 |
+
cp -r /tmp/olaph/MedLFQA ${DATA_ROOT}/data/raw/MedLFQA
|
| 137 |
+
# または
|
| 138 |
+
git clone https://github.com/jjcherian/conformal-safety.git /tmp/cs
|
| 139 |
+
cp -r /tmp/cs/data/MedLFQAv2 ${DATA_ROOT}/data/raw/MedLFQA
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
**3. 著者提供データをダウンロードして展開する**
|
| 143 |
+
|
| 144 |
+
> https://drive.google.com/drive/folders/1aLbHxS6V1ipMH8FpVCxKmr8oMYfqmRgb
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
source .env
|
| 148 |
+
# gdown でフォルダをダウンロード("conformal prediction with RAG/" という名前で保存される)
|
| 149 |
+
gdown --folder https://drive.google.com/drive/folders/1aLbHxS6V1ipMH8FpVCxKmr8oMYfqmRgb \
|
| 150 |
+
-O "${DATA_ROOT}/data/out/"
|
| 151 |
+
|
| 152 |
+
# 生データを data/raw/ に展開
|
| 153 |
+
unzip "${DATA_ROOT}/data/out/conformal prediction with RAG/data.zip" \
|
| 154 |
+
-d "${DATA_ROOT}/data/raw/"
|
| 155 |
+
|
| 156 |
+
# FAISS インデックスを DATA_ROOT 直下に展開(index_store/ が作られる)
|
| 157 |
+
unzip "${DATA_ROOT}/data/out/conformal prediction with RAG/index_store.zip" \
|
| 158 |
+
-d "${DATA_ROOT}/"
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
展開後の構造:
|
| 162 |
+
|
| 163 |
+
```
|
| 164 |
+
DATA_ROOT/
|
| 165 |
+
├── data/
|
| 166 |
+
│ └── raw/
|
| 167 |
+
│ ├── HotpotQA/raw_hotpot_qa.json
|
| 168 |
+
│ ├── PopQA/raw_pop_qa.json
|
| 169 |
+
│ ├── FactScore/factscore_names.txt
|
| 170 |
+
│ └── MedLFQA/*.jsonl
|
| 171 |
+
└── index_store/
|
| 172 |
+
├── FactScore/index_500.faiss, indice2fm_500.json, ...
|
| 173 |
+
├── HotpotQA/index_500.faiss, indice2fm_500.json
|
| 174 |
+
└── MedLFQA/index_10.faiss, index_500.faiss, indice2fm_*.json, ...
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
**4. 実行する**
|
| 178 |
+
|
| 179 |
+
FAISS インデックスが存在する場合、インデックス構築(Embeddings API)はスキップされる。RAG 推論・スコアリング・アノテーションには Chat API が必要。
|
| 180 |
+
|
| 181 |
+
```bash
|
| 182 |
+
# コンテナ経由
|
| 183 |
+
bash scripts/run_main.sh --dataset medlf_qa
|
| 184 |
+
|
| 185 |
+
# ホスト直接実行
|
| 186 |
+
source .env && python main.py --dataset medlf_qa
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
---
|
| 190 |
+
|
| 191 |
+
### B. 自前でフルパイプライン実行(API あり)
|
| 192 |
+
|
| 193 |
+
**1. Wikipedia DB をダウンロードして配置する**
|
| 194 |
+
|
| 195 |
+
上記 A の手順 1 と同じ。
|
| 196 |
+
|
| 197 |
+
**2. `medlf_qa` を使う場合は生データを手動取得する**
|
| 198 |
+
|
| 199 |
+
上記 A の手順 2 と同じ。`fact_score` / `hotpot_qa` / `pop_qa` は不要。
|
| 200 |
+
|
| 201 |
+
**3. 実行する**
|
| 202 |
+
|
| 203 |
+
データセットごとに1つずつ実行する(並列不可)。
|
| 204 |
+
|
| 205 |
+
```bash
|
| 206 |
+
# コンテナ経由
|
| 207 |
+
bash scripts/run_main.sh --dataset medlf_qa
|
| 208 |
+
# 利用可能: fact_score, hotpot_qa, pop_qa, medlf_qa
|
| 209 |
+
|
| 210 |
+
# ホスト直接実行
|
| 211 |
+
source .env && python main.py --dataset medlf_qa
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
---
|
| 215 |
+
|
| 216 |
+
実行後に生成されるファイル(`DATA_ROOT` 以下):
|
| 217 |
+
|
| 218 |
+
| パス | 内容 |
|
| 219 |
+
|------|------|
|
| 220 |
+
| `data/out/{DatasetName}/` | サブクレームスコア付き JSON |
|
| 221 |
+
| `index_store/{DatasetName}/` | FAISS インデックス (`index_*.faiss`, `indice2fm_*.json`) |
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
### CLI オプション一覧
|
| 226 |
+
|
| 227 |
+
| オプション | デフォルト | 説明 |
|
| 228 |
+
|---|---|---|
|
| 229 |
+
| `--dataset` | config 参照 | データセット名(`fact_score` / `hotpot_qa` / `pop_qa` / `medlf_qa`) |
|
| 230 |
+
| `--query_size` | 500 | 処理するクエリ件数 |
|
| 231 |
+
| `--run_id` | タイムスタンプ | ログ・結果ディレクトリの識別子 |
|
| 232 |
+
| `--lite` | false | `frequency`・`doc_claim_cosine_similarity` スコアと conformal グラフ生成をスキップ |
|
| 233 |
+
| `--log_level` | `INFO` | ログレベル(`DEBUG` / `INFO` / `WARNING` / `ERROR`) |
|
| 234 |
+
|
| 235 |
+
### API コールのログ確認
|
| 236 |
+
|
| 237 |
+
`--log_level DEBUG` を付けると、OpenAI API の個別コール(purpose・model・トークン数)が `[API:Chat]` / `[API:Embeddings]` プレフィックスで記録される。
|
| 238 |
+
|
| 239 |
+
```bash
|
| 240 |
+
# API ログを有効にして実行(コンテナ経由)
|
| 241 |
+
bash scripts/run_main.sh --dataset medlf_qa --query_size 10 --lite --log_level DEBUG
|
| 242 |
+
|
| 243 |
+
# ログから API コールのみ抽出
|
| 244 |
+
grep "\[API:" logs/<run_id>/run_<run_id>.log
|
| 245 |
+
|
| 246 |
+
# purpose 別の集計
|
| 247 |
+
grep "\[API:" logs/<run_id>/run_<run_id>.log \
|
| 248 |
+
| sed 's/.*\[API:\([^]]*\)\] purpose=\([^ ]*\).*/\1 \2/' \
|
| 249 |
+
| sort | uniq -c | sort -rn
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
出力例:
|
| 253 |
+
```
|
| 254 |
+
103 Embeddings cosine_similarity(claim)
|
| 255 |
+
103 Embeddings cosine_similarity(query)
|
| 256 |
+
103 Embeddings relevance_score
|
| 257 |
+
103 Chat claim_verification
|
| 258 |
+
10 Chat subclaim_extraction
|
| 259 |
+
10 Chat rag_response
|
| 260 |
+
```
|
| 261 |
+
|
| 262 |
+
---
|
| 263 |
+
|
| 264 |
+
## 6. 事前計算(データ更新・モデル変更時)
|
| 265 |
+
|
| 266 |
+
`demo/data/thresholds.csv` と `demo/data/samples.json` を(再)生成する。
|
| 267 |
+
モデル変更やデータセット追加後に実行する。
|
| 268 |
+
|
| 269 |
+
**コンテナ経由:**
|
| 270 |
+
```bash
|
| 271 |
+
bash scripts/run_precompute.sh
|
| 272 |
+
```
|
| 273 |
+
|
| 274 |
+
**ホスト直接実行:**
|
| 275 |
+
```bash
|
| 276 |
+
source .env
|
| 277 |
+
python -m demo.precompute
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
出力はリポジトリ内の `demo/data/` に書き込まれる。
|
| 281 |
+
生成後は `git add demo/data/ && git commit` でコミットする。
|
| 282 |
+
|
| 283 |
+
---
|
| 284 |
+
|
| 285 |
+
## 7. デモの起動
|
| 286 |
+
|
| 287 |
+
**コンテナ経由:**
|
| 288 |
+
```bash
|
| 289 |
+
bash scripts/run_demo.sh # デフォルト: ポート 8502
|
| 290 |
+
bash scripts/run_demo.sh --port 8503 # ポートを指定する場合
|
| 291 |
+
```
|
| 292 |
+
|
| 293 |
+
**ホスト直接実行:**
|
| 294 |
+
```bash
|
| 295 |
+
source .env
|
| 296 |
+
streamlit run demo/app.py --server.port 8502
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
- デフォルトポート: **8502**(8501 はサーバー上の別アプリが使用中のため)
|
| 300 |
+
- ポートが使用中の場合は `--port` オプションで別ポートを指定する
|
| 301 |
+
|
| 302 |
+
**Singularity バインド構成(スクリプト内部):**
|
| 303 |
+
|
| 304 |
+
```bash
|
| 305 |
+
singularity run \
|
| 306 |
+
--bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
|
| 307 |
+
--bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
|
| 308 |
+
--bind "${HF_HOME}:${HF_HOME}" \
|
| 309 |
+
--env HF_HOME="${HF_HOME}" \
|
| 310 |
+
--env OPENAI_API_KEY="${OPENAI_API_KEY}" \
|
| 311 |
+
--env DATA_ROOT="${REPO_ROOT}" \
|
| 312 |
+
--env PYTHONPATH="${REPO_ROOT}" \
|
| 313 |
+
--pwd "${REPO_ROOT}" \
|
| 314 |
+
"${SIF_DIR}/response_quality.sif" \
|
| 315 |
+
streamlit run demo/app.py --server.port 8502
|
| 316 |
+
```
|
| 317 |
+
|
| 318 |
+
(`REPO_ROOT` はスクリプト自身の場所から自動解決される。コンテナ内では `DATA_ROOT` を `REPO_ROOT` に上書きすることで、バインドマウント先のパスと整合させている。)
|
| 319 |
+
|
| 320 |
+
`demo/data/` はリポジトリ内に存在するため、バインド不要。
|
| 321 |
+
|
| 322 |
+
---
|
| 323 |
+
|
| 324 |
+
## 8. トラブルシューティング
|
| 325 |
+
|
| 326 |
+
### `.env` が読み込まれない
|
| 327 |
+
|
| 328 |
+
各スクリプトはリポジトリルートの `.env` を自動的に探して読み込む(`REPO_ROOT` をスクリプトの場所から導出するため、どのディレクトリから実行しても動作する)。`.env` ファイル自体が存在するか確認する。
|
| 329 |
+
|
| 330 |
+
```bash
|
| 331 |
+
ls /path/to/ResponseQualityAssessment/.env
|
| 332 |
+
```
|
| 333 |
+
|
| 334 |
+
### `DATA_ROOT` 以下のディレクトリが存在しない
|
| 335 |
+
|
| 336 |
+
`data/out/` や `index_store/` は `main.py` 実行時に自動作成されるが、
|
| 337 |
+
`DATA_ROOT` 自体と `DATA_ROOT/data/`, `DATA_ROOT/index_store/` は事前に作成が必要。
|
| 338 |
+
|
| 339 |
+
```bash
|
| 340 |
+
source .env
|
| 341 |
+
mkdir -p ${DATA_ROOT}/data/raw/WikiDB \
|
| 342 |
+
${DATA_ROOT}/data/processed \
|
| 343 |
+
${DATA_ROOT}/data/out \
|
| 344 |
+
${DATA_ROOT}/data/result \
|
| 345 |
+
${DATA_ROOT}/index_store \
|
| 346 |
+
${HF_HOME} \
|
| 347 |
+
${SIF_DIR}
|
| 348 |
+
```
|
| 349 |
+
|
| 350 |
+
### FAISS インデックスが見つからない(デモ起動時にエラー)
|
| 351 |
+
|
| 352 |
+
`scripts/run_main.sh` が未実行、または対象データセットが処理されていない。
|
| 353 |
+
§5「データ準備」を実行する。
|
| 354 |
+
|
| 355 |
+
### `demo/data/thresholds.csv` が見つからない
|
| 356 |
+
|
| 357 |
+
`scripts/run_precompute.sh` が未実行。§6「事前計算」を実行する。
|
| 358 |
+
または、リポジトリの `demo/data/` にコミット済みファイルがあるか確認する(`git pull` で最新化)。
|
| 359 |
+
|
| 360 |
+
### `subclaims_schema.json` が見つからない(コンテナ経由実行時)
|
| 361 |
+
|
| 362 |
+
コンテナ実行時は `DATA_ROOT/data/` が `REPO_ROOT/data/` にバインドマウントされるため、リポジトリ内の `data/out/subclaims_schema.json` が隠れてしまう。初回実行前に手動でコピーする。
|
| 363 |
+
|
| 364 |
+
```bash
|
| 365 |
+
source .env
|
| 366 |
+
cp data/out/subclaims_schema.json ${DATA_ROOT}/data/out/subclaims_schema.json
|
| 367 |
+
```
|
| 368 |
+
|
| 369 |
+
### OpenAI API エラー
|
| 370 |
+
|
| 371 |
+
`.env` の `OPENAI_API_KEY` が正しく設定されているか確認する。
|
| 372 |
+
`#` で始まる行はコメントとして無視されるため、`OPENAI_API_KEY=sk-...` のように `#` なしで記述する。
|
| 373 |
+
推論実行ボタンを押した場合のみ API が呼ばれる。サンプルクエリはボタンを押さずにプルダウン選択するだけのため API 不要。
|
docs/context/05_medrag_corpus_integration.md
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MedRAGコーパス導入・システム拡張要件書
|
| 2 |
+
|
| 3 |
+
## 1. プロジェクト概要
|
| 4 |
+
|
| 5 |
+
### 1.1. 目的
|
| 6 |
+
|
| 7 |
+
現状のRAG回答品質評価システム(Conformal-RAG)は、MedLFQAデータセット評価において、データセット内にあらかじめ用意された参照文書(JSONL)から検索インデックスを構築している。本プロジェクトでは、この閉じた検証環境を脱却し、外部の大規模医療特化コーパス(MedRAG)を統合することで、実運用に近いノイズ環境下でのRAGファクチュアリティ(事実性)評価基盤を構築する。
|
| 8 |
+
|
| 9 |
+
### 1.2. 課題と解決方針
|
| 10 |
+
|
| 11 |
+
大規模コーパスを導入すると、検索時のディストラクター(意味的に類似するが無関係なノイズ文書)が激増し、コンフォーマル予測のスコアリングが破綻するリスクがある。これに対し、以下の3点をシステム要件として定義する。
|
| 12 |
+
|
| 13 |
+
- 医療特化の埋め込みモデルへの切り替え
|
| 14 |
+
- ハイブリッド検索・リランカーの導入
|
| 15 |
+
- チャンク戦略の最適化
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## 2. システム構成・データパイプライン要件
|
| 20 |
+
|
| 21 |
+
### 2.1. コーパスデータベースの構築
|
| 22 |
+
|
| 23 |
+
- **要件**: MedRAGが提供する生データ(PubMedアブストラクト、StatPearls等)をシステムに取り込む。
|
| 24 |
+
- **MedRAGコーパスのデータ形式**: MedRAGは以下のコーパスをHuggingFace(`MedRAG/pubmed` 等)でチャンク済みデータセットとして公開している。スキーマは全コーパス共通の4フィールド。
|
| 25 |
+
|
| 26 |
+
| コーパス | スニペット数 | 平均トークン数 | チャンク戦略 | 備考 |
|
| 27 |
+
|---------|------------|------------|------------|------|
|
| 28 |
+
| PubMed | 約2,390万件 | **296 tokens** | 論文1件=1スニペット(title + abstract) | HuggingFaceで配布済み |
|
| 29 |
+
| StatPearls | 約30万件 | 119 tokens | 段落単位 | Privacy Policyによりコンテンツ非公開。自前処理必要 |
|
| 30 |
+
| Textbooks | 約13万件 | 182 tokens | `RecursiveCharacterTextSplitter`(上限1,000文字) | HuggingFaceで配布済み |
|
| 31 |
+
|
| 32 |
+
```json
|
| 33 |
+
{"id": "pubmed23n0001_0", "title": "...", "content": "...", "contents": "title + content の結合(BM25用)"}
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
- **実装方針**: HuggingFaceから配布済みのParquet/JSONL形式データをロードし、`DocDB.build_db()`(`src/rag/retrieval.py`)で`title`・`content`フィールドをSQLiteに格納する。`DataLoader.create_wiki_db()` と同様の仕組みを活用できる。
|
| 37 |
+
|
| 38 |
+
### 2.2. チャンキング戦略の方針
|
| 39 |
+
|
| 40 |
+
- **現状の課題**: `FixedLengthChunker`(`src/common/chunker.py`)がデフォルトで2000語(オーバーラップ25語)と大きく設定されているため、情報の境界が曖昧になりやすい。設定は `conf/config.yaml` の `index.truncation_config` で管理されている。
|
| 41 |
+
- **コーパス別の対応方針**:
|
| 42 |
+
- **PubMed**: チャンク済みデータ(平均296トークン)をそのまま使用。MedCPT Article Encoderの最大入力長512トークンの範囲内に収まっており、**追加のチャンキングは不要**。MedCPT使用時は `[title, content]` をリスト形式で渡す必要がある(文字列連結不可)。
|
| 43 |
+
- **StatPearls / Textbooks**: 自前チャンキングが必要な場合は `RecursiveCharacterChunker` を `FileManager` に追加し、BPEトークンベースで256〜512トークン程度を目標とする。`conf/config.yaml` の `strategy` キーで選択可能にする。
|
| 44 |
+
|
| 45 |
+
---
|
| 46 |
+
|
| 47 |
+
## 3. 検索アルゴリズム要件(Retrieval & Embedding)
|
| 48 |
+
|
| 49 |
+
### 3.1. 医療特化Embeddingモデルの導入
|
| 50 |
+
|
| 51 |
+
- **現状の課題**: `conf/config.yaml` の `index.embedding_model` に `text-embedding-3-large`(汎用モデル)が固定で指定されており、`OpenAIManager.create_openai_embeddings()` を通じて埋め込みを生成している。
|
| 52 |
+
- **要件**: 医療専門用語や略語を正確にベクトル空間へマッピングする。
|
| 53 |
+
- **実装方針**: `OpenAIManager` に依存しないローカルモデル用のEmbeddingクラスを新設し(例: `LocalEmbeddingManager`)、以下のモデルをロードして埋め込みを生成できるようにする。`conf/config.yaml` の `embedding_model` キーで切り替え可能な設計とする。
|
| 54 |
+
- **`MedCPT`**(`ncbi/MedCPT-Query-Encoder` / `ncbi/MedCPT-Article-Encoder`): NCBIが255MのPubMed検索ログでContrastive学習したbi-encoderであり、RAGの密ベクトル検索に直接適用できる。MedRAGベンチマークでも高い性能を示しており、第一候補。
|
| 55 |
+
- **`MedEmbed`**(`abhinand/MedEmbed`シリーズ): 汎用モデルを医療データでContrastive学習させたモデル。MTEBの医療タスクで競争力のある性能を示す代替選択肢。
|
| 56 |
+
- > **注意**: `PubMedBERT` や `BioLinkBERT` はContrastive学習を経ていないMLMであり、fine-tuneなしでのbi-encoderとしての利用は非推奨。
|
| 57 |
+
|
| 58 |
+
### 3.2. ハイブリッド検索(Hybrid Search)の実装
|
| 59 |
+
|
| 60 |
+
- **現状の課題**: `FAISSIndexManager`(`src/common/faiss_manager.py`)による密ベクトル(Dense)検索のみが実装されており、薬品名・疾患名などの固有表現の完全一致を取りこぼすケースがある。
|
| 61 |
+
- **要件**: 固有表現の完全一致を取りこぼさない検索機構の構築。
|
| 62 |
+
- **実装方針**: `FAISSIndexManager` に加えて疎ベクトル(Sparse)検索エンジンを導入し、両者のスコアをReciprocal Rank Fusion(RRF)で統合して上位文書を取得する仕組みを追加する。
|
| 63 |
+
- **Sparse検索の選択肢**: BM25は実装が容易で安定したベースラインとなる。ただし医療・科学文書では専門用語の語彙ミスマッチが起きやすいため、意味的な語彙展開が可能な **SPLADE** も代替として検討する。
|
| 64 |
+
- **スコア統合**: RRFは訓練データなしで使えるロバストな手法であり第一選択。ラベル付きデータが入手できる場合は重み付きRRFや凸結合(Convex Combination)でさらなる精度改善が見込める。
|
| 65 |
+
|
| 66 |
+
### 3.3. リランカー(Re-ranker)の導入
|
| 67 |
+
|
| 68 |
+
- **要件**: 大規模コーパス特有のノイズ文書(ディストラクター)を排除する。
|
| 69 |
+
- **実装方針**: `SubclaimProcessor.generate_responses()`(`src/subclaim_processor/subclaim_processor.py`)において、ハイブリッド検索で上位100件程度を粗く取得したのち、Cross-Encoderで精緻な関連度再計算を行い、最終的に上位10件に絞り込む処理を追加する。
|
| 70 |
+
- **推奨モデル(セルフホスト)**: `BAAI/bge-reranker-v2-m3`(多言語・汎用)または `BAAI/bge-reranker-v2.5-gemma2-lightweight`(2025年時点のSOTA、トークン圧縮機能付き)
|
| 71 |
+
- **推奨モデル(商用API)**: Cohere Rerank(現行v4)。精度は高いがセルフホスト不可のため、運用コストと可搬性を考慮して選択する。
|
| 72 |
+
|
| 73 |
+
---
|
| 74 |
+
|
| 75 |
+
## 4. 評価・スコアリング要件
|
| 76 |
+
|
| 77 |
+
### 4.1. スコアリング戦略(Strategy)の改修
|
| 78 |
+
|
| 79 |
+
- **現状の課題**: `ProductScoreStrategy`(`src/subclaim_processor/strategies/scoring/product.py`)がFAISSの検索スコアとサブクレームのコサイン類似度の積でスコアを算出している。ハイブリッド検索やリランカーを導入するとベーススコアのスケールが変化するため、この計算式が破綻する。
|
| 80 |
+
- **要件**: 検索手法に依存しない普遍的なサブクレーム関連度スコアの定義。
|
| 81 |
+
- **実装方針**: `ScoringStrategy`(`src/subclaim_processor/strategies/scoring/base.py`)を継承した新たな戦略クラスを追加し、Cross-Encoderが出力する「クエリ+サブクレーム」と「検索文書」の推論ロジットを直接スコアとして採用する。`conf/config.yaml` の `conformal_prediction.scoring_strategy` キーで切り替え可能にする。
|
| 82 |
+
|
| 83 |
+
---
|
| 84 |
+
|
| 85 |
+
## 5. 非機能要件(パフォーマンス・拡張性)
|
| 86 |
+
|
| 87 |
+
### 5.1. 非同期処理・バッチ処理の導入
|
| 88 |
+
|
| 89 |
+
- **現状の課題**: RAG応答生成からサブクレーム抽出・アノテーションに至るまで、パイプライン全体が逐次処理で実装されており、大量クエリ評価時のスループットが低い。OpenAI APIのレートリミット(429エラー)も問題になりうる。
|
| 90 |
+
- **要件**: 大規模コーパスおよび大量クエリ評価における実行時間の大幅な短縮とAPIレートリミット回避。
|
| 91 |
+
- **実装方針**: 以下のいずれか、または組み合わせを採用する。
|
| 92 |
+
- `asyncio` を用いた非同期パイプラインへの書き換え(リアルタイム性が求められる場合)
|
| 93 |
+
- OpenAI Batch APIを用いたアノテーションの一括処理(コスト重視の場合、レイテンシは最大24時間)
|
| 94 |
+
|
| 95 |
+
### 5.2. インデックス管理の堅牢化
|
| 96 |
+
|
| 97 |
+
- **現状の課題**: `FAISSIndexManager` は `IndexFlatIP`(内積)を使用しており、ファイルパスとFAISSインデックス番号のリストを紐づけるマッピング `indice2fm`(`{file_path: [indices]}`、`index_store/indice2fm.json` に保存)をオンメモリで管理している。数百万規模のドキュメントではメモリ不足が発生しうる。
|
| 98 |
+
- **要件**: 大規模インデックスにおけるメモリ効率の確保と、インデックス更新の容易性。
|
| 99 |
+
- **実装方針**: インデックス構築をバッチ処理化する。必要に応じてディスクベースのベクトルDB(Milvus、Qdrant等)への移行を見据え、`FAISSIndexManager` のインターフェースを抽象化しておく。
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## 6. ハイパーパラメータ調整・実験計画
|
| 104 |
+
|
| 105 |
+
### 6.1. 調整対象ハイパーパラメータの一覧
|
| 106 |
+
|
| 107 |
+
パイプラインを構成する4つのステージごとにハイパーパラメータを整理する。各パラメータの現行値は `conf/config.yaml` に基づく。
|
| 108 |
+
|
| 109 |
+
#### ステージA: チャンキング(StatPearls / Textbooks のみ)
|
| 110 |
+
|
| 111 |
+
| パラメータ | 現行値 | 探索範囲 | 備考 |
|
| 112 |
+
|-----------|------|--------|------|
|
| 113 |
+
| `chunk_size` | 2000語 | 128 / 256 / **512** / 1024 トークン | PubMedは296トークン固定のため対象外 |
|
| 114 |
+
| `chunk_overlap` | 25語 | 0 / 32 / **64** / 128 トークン | chunkサイズの10〜20%が目安 |
|
| 115 |
+
|
| 116 |
+
#### ステージB: 検索(Retrieval)
|
| 117 |
+
|
| 118 |
+
| パラメータ | 現行値 | 探索範囲 | 備考 |
|
| 119 |
+
|-----------|------|--------|------|
|
| 120 |
+
| `retrival_topk` | 10 | 10 / 20 / **50** / 100 | リランカー導入後は候補プール用に拡大 |
|
| 121 |
+
| `retrival_threshold` | 0.3 | 0.1 / **0.2** / 0.3 | Denseスコア下限。低すぎるとノイズ増加 |
|
| 122 |
+
| BM25 `k1` | —(未実装) | 0.5 / **1.2** / 2.0 | 語頻度の飽和速度。医療文書は1.2前後が標準 |
|
| 123 |
+
| BM25 `b` | —(未実装) | 0.25 / **0.75** | 文書長正規化。長いアブストラクトには0.75推奨 |
|
| 124 |
+
| RRF `k` | —(未実装) | 10 / **60** | 標準値60。訓練データなしで使える頑健な設定 |
|
| 125 |
+
| RRF Dense重み | —(未実装) | 0.3 / **0.5** / 0.7 | Dense寄りにするほど意味検索優先 |
|
| 126 |
+
|
| 127 |
+
#### ステージC: リランキング
|
| 128 |
+
|
| 129 |
+
| パラメータ | 現行値 | 探索範囲 | 備考 |
|
| 130 |
+
|-----------|------|--------|------|
|
| 131 |
+
| リランカー候補プール数 | —(未実装) | 50 / **100** / 200 | 大きいほど再現率↑、遅延↑ |
|
| 132 |
+
| リランカー後 `topk` | —(未実装) | 5 / **10** / 20 | 最終的にLLMに渡す文書数 |
|
| 133 |
+
|
| 134 |
+
#### ステージD: スコアリング・集約(Conformal Prediction)
|
| 135 |
+
|
| 136 |
+
| パラメータ | 現行値 | 探索範囲 | 備考 |
|
| 137 |
+
|-----------|------|--------|------|
|
| 138 |
+
| `scoring_strategy` | `product` | `product` / `cross_encoder` | リランカー導入後は `cross_encoder` に切り替えを検討 |
|
| 139 |
+
| `aggregation_strategy` | `mean` | `mean` / `max` | `max`は最もサポートされた証拠を採用 |
|
| 140 |
+
| `a_value` | 1.0 | 0.8 / 0.9 / **1.0** | 「ファクチュアル」と見なす正解保持率の閾値 |
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
### 6.2. 評価指標
|
| 145 |
+
|
| 146 |
+
実験には以下の指標を使用する。`SplitConformalCalibration`(`src/calibration/conformal.py`)が既に以下の計算を実装済み。
|
| 147 |
+
|
| 148 |
+
| 指標 | 定義 | 担当ステージ |
|
| 149 |
+
|-----|------|------------|
|
| 150 |
+
| **Recall@k** | 正解文書がtop-k内に含まれる割合 | B, C |
|
| 151 |
+
| **MRR@k** | Mean Reciprocal Rank(正解文書の順位逆数の平均) | B, C |
|
| 152 |
+
| **Coverage rate** | 実経験的ファクチュアリティ ≥ 1−α を満たすクエリの割合(≥ 1−α であれば保証成立) | D |
|
| 153 |
+
| **Efficiency** | 各αにおける平均サブクレーム除去率(高いほど効率的にノイズを除去できている) | D |
|
| 154 |
+
|
| 155 |
+
---
|
| 156 |
+
|
| 157 |
+
### 6.3. 実験フロー
|
| 158 |
+
|
| 159 |
+
各ステージで最良の設定を固定したうえで次ステージの探索を行う、逐次的な**グリーディサーチ**を採用する。ステージ間の依存関係があるため、並列探索は同一ステージ内のみで行う。
|
| 160 |
+
|
| 161 |
+
```
|
| 162 |
+
Phase 1: チャンキング探索(StatPearls / Textbooks のみ)
|
| 163 |
+
└─ 評価指標: Recall@10 on MedLFQA
|
| 164 |
+
└─ 固定: Embedding=MedCPT, Dense-only 検索
|
| 165 |
+
└─ 探索: chunk_size × chunk_overlap の格子探索(8条件)
|
| 166 |
+
|
| 167 |
+
Phase 2: ハイブリッド検索の探索
|
| 168 |
+
└─ 評価指標: Recall@10, MRR@10
|
| 169 |
+
└─ 固定: Phase 1 のベストチャンク設定
|
| 170 |
+
└─ 探索ステップ:
|
| 171 |
+
2a. Dense-only ベースライン(現行)
|
| 172 |
+
2b. BM25 k1 × b の格子探索(6条件)、Denseとの RRF 融合
|
| 173 |
+
2c. SPLADE vs BM25 の比較(GPU使用可能な場合のみ)
|
| 174 |
+
2d. RRF Dense重み の探索(0.3 / 0.5 / 0.7)
|
| 175 |
+
|
| 176 |
+
Phase 3: リランカーの探索
|
| 177 |
+
└─ 評価指標: Recall@10 after rerank, NDCG@10
|
| 178 |
+
└─ 固定: Phase 2 のベストハイブリッド設定
|
| 179 |
+
└─ 探索: 候補プール数 × 最終topk の格子探索(6条件)
|
| 180 |
+
└─ モデル比較: bge-reranker-v2-m3 vs bge-reranker-v2.5-gemma2-lightweight
|
| 181 |
+
|
| 182 |
+
Phase 4: Conformal Prediction の探索
|
| 183 |
+
└─ 評価指標: Coverage rate, Efficiency(fraction_removed)
|
| 184 |
+
└─ 固定: Phase 3 のベスト検索設定
|
| 185 |
+
└─ 探索: scoring_strategy × aggregation_strategy × a_value(12条件)
|
| 186 |
+
└─ alphaは 0.05〜0.45 を既定のステップ幅でスイープ(conformal.py の既存実装を使用)
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
---
|
| 190 |
+
|
| 191 |
+
### 6.4. 実験管理上の注意点
|
| 192 |
+
|
| 193 |
+
- **PubMedのチャンク設定は固定**: PubMedスニペットは平均296トークンで配布されており、MedCPTの512トークン上限に収まる。Phase 1 の探索対象はStatPearls / Textbooks のみとする。
|
| 194 |
+
- **MedCPTの入力形式**: `[title, content]` をリスト要素として渡す必要があり、他のモデルとの入力フォーマットが異なる。`LocalEmbeddingManager` 実装時に対応すること。
|
| 195 |
+
- **キャッシュの活用**: `conformal.py` はスコア計算結果を `.npy` ファイルにキャッシュする仕組みを持つ。チャンクや検索設定を変えた場合はキャッシュを削除してから再実行すること。
|
| 196 |
+
- **Phase間の依存**: 各フェーズのベスト設定を `conf/config.yaml` に反映してから次フェーズを実行する。バージョン管理のため、フェーズごとに `git tag` を付けることを推奨する。
|
docs/context/06_test_strategy.md
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# テスト計画と実行フロー
|
| 2 |
+
|
| 3 |
+
OpenAI API への依存度に基づいてステップを分離し、API キーなしでも中核ロジックを早期に検証できる構成にする。
|
| 4 |
+
|
| 5 |
+
## 全体像
|
| 6 |
+
|
| 7 |
+
| ステップ | 目的 | API 使用 | 使用データ | 対象コンポーネント |
|
| 8 |
+
| :--- | :--- | :---: | :--- | :--- |
|
| 9 |
+
| **1. ユニットテスト** | コアロジック・純粋関数の単体テスト | **不要** | 合成データ(手作り) | `src/calibration/`、`demo/inference_api.py` |
|
| 10 |
+
| **2. スモークテスト** | 実データの生成とパイプラインの疎通確認 | **必要** | 実データ(極小スケール) | `main.py`(データ取得〜スコア計算) |
|
| 11 |
+
| **3. 統合テスト** | 事前計算ロジックの動作確認 | **不要** | ステップ 2 の出力 JSON | `demo/precompute.py` |
|
| 12 |
+
| **4. UI テスト** | Streamlit UI と状態管理の動作確認 | **不要** | ステップ 3 の出力 CSV/JSON | `demo/app.py`、`st.session_state` |
|
| 13 |
+
|
| 14 |
+
---
|
| 15 |
+
|
| 16 |
+
## 1. ユニットテスト(API 不要)
|
| 17 |
+
|
| 18 |
+
論文の中核である「Conformal Guarantee(事実性の統計的保証)」を計算するロジックと、デモの表示切り替えロジックを最優先でテストする。`pytest` で即座に実行できる。
|
| 19 |
+
|
| 20 |
+
実装ファイル:
|
| 21 |
+
- `tests/test_calibration.py` — キャリブレーション計算関数
|
| 22 |
+
- `tests/test_precompute.py` — `demo/precompute.py` のデータ変換・サンプリング関数
|
| 23 |
+
|
| 24 |
+
**準備するデータ:** ダミーのサブクレームリスト(スコア `relevance: 0.8`、ラベル `S` などを持たせた数十件の Python 辞書)
|
| 25 |
+
|
| 26 |
+
**テスト対象と確認事項:**
|
| 27 |
+
|
| 28 |
+
1. **`src/calibration/utils.py` — `get_r_score`**
|
| 29 |
+
- ダミーのサブクレーム群を与え、目標正確率 $a$ を満たすための最小スコア $r_a$ が数学的に正しく算出されるか。
|
| 30 |
+
- 全サブクレームが正解の場合に `-1`(常に安全)が返るか。
|
| 31 |
+
|
| 32 |
+
2. **`src/calibration/utils.py` — `compute_threshold`**
|
| 33 |
+
- 特定の $\alpha$(例: `0.10`)を与えたとき、`ceil((n+1)*(1-alpha))` 分位点ベースの閾値 $\hat{q}$ が正しく計算されるか。
|
| 34 |
+
|
| 35 |
+
3. **`src/calibration/utils.py` — `split_group`**
|
| 36 |
+
- グループ別に calibration/test が指定比率(デフォルト 50/50)で分割されるか。
|
| 37 |
+
- グループをまたいだデータ汚染(calibration と test のオーバーラップ)がないか。
|
| 38 |
+
|
| 39 |
+
4. **`src/calibration/conformal.py` — `_evaluate_conformal_correctness`**
|
| 40 |
+
- 閾値を適用したとき、スコアが閾値以上のサブクレームのみ retain され、`correctly_retained` と `fraction_removed` が正しく計算されるか。
|
| 41 |
+
- サブクレームが全件 remove された場合(`retained_cnt == 0`)に `correctly_retained_percentage = 1` のフォールバックが効くか。
|
| 42 |
+
|
| 43 |
+
5. **`demo/inference_api.py` — `apply_threshold`**
|
| 44 |
+
- モックの `SubclaimResult` と任意の $\hat{q}$ を渡し、スコアが $\hat{q}$ 以上のサブクレームが `keep_count` として正しく集計され、意図した `FilteredResult` が返るか。
|
| 45 |
+
- > **ノイズについて**: `calibration/utils.py` の `get_r_score` は `score + noise` の空間でスコアを集計し、`q_hat` もその空間で計算される。`noise` は Conformal Prediction のタイブレーカーとして `N(0, 0.001)` で付与される微小乱数(`subclaim_processor.py:150`)。`precompute.py` はサンプルクエリのスコアを `score + noise` で保存するため、`apply_threshold` の `score >= q_hat` 比較は理論的に整合している。ライブ推論(`process_query`)では noise を生成しないため、同一保証は成立しない点に注意。
|
| 46 |
+
|
| 47 |
+
6. **`demo/precompute.py` — `_entry_to_subclaim_result`**(リグレッションテスト)
|
| 48 |
+
- `subclaims[i]["score"]` が `score + noise` の値になっているか(ノイズ加算の回帰防止)。
|
| 49 |
+
- `noise` キーが欠落した古いデータ形式でもクラッシュしないか(デフォルト `0.0`)。
|
| 50 |
+
- `is_grouped=False` のとき `mode="marginal"`、`group="default"` が設定されるか。
|
| 51 |
+
- `is_grouped=True` のとき `mode="conditional"`、`group` が `entry["groups"][0]` から取得されるか。
|
| 52 |
+
|
| 53 |
+
7. **`demo/precompute.py` — `_group_data`、`_has_score`、`_pick_entries`**
|
| 54 |
+
- エントリが正しくグループ別に分類されるか。
|
| 55 |
+
- サブクレームに指定スコアキーがない場合に `_has_score` が `False` を返すか。
|
| 56 |
+
- `_pick_entries` がサブクレーム空のエントリをスキップし、指定件数(`N_SAMPLES_PER_DATASET` / `N_SAMPLES_PER_GROUP`)を返すか。
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## 2. スモークテスト(API 必要)
|
| 61 |
+
|
| 62 |
+
以降のテストには本物のデータ構造が必要なため、既存のバッチ処理を最小規模で走らせる。既存の `src/` ロジックが壊れていないことも同時に確認する。
|
| 63 |
+
|
| 64 |
+
**実行手順:**
|
| 65 |
+
|
| 66 |
+
```bash
|
| 67 |
+
# フル実行(全スコア)
|
| 68 |
+
python main.py --dataset medlf_qa --query_size 10
|
| 69 |
+
|
| 70 |
+
# スモークテスト推奨(frequency・doc_claim_cosine_similarity をスキップ)
|
| 71 |
+
python main.py --dataset medlf_qa --query_size 10 --lite
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
- データセットは `medlf_qa` を使用する(グループ構造を持つ最も複雑なケースのため)。
|
| 75 |
+
- `--query_size 10` で引数上書きするため、`conf/config.yaml` の編集は不要。
|
| 76 |
+
- `2-3` だと 50/50 split(`runs=1000`)に対してデータが少なすぎ、キャリブレーションが意味をなさない。統合テストの動作確認には最低 10 件程度が必要。
|
| 77 |
+
|
| 78 |
+
**API 呼び出し数の目安(`--query_size 10`、平均サブクレーム ~5/クエリ → 計 ~50 サブクレーム):**
|
| 79 |
+
|
| 80 |
+
| ステップ | API | 呼び出し数 |
|
| 81 |
+
|---------|-----|---------|
|
| 82 |
+
| FAISS インデックス構築 | Embeddings | **0**(`index_10.faiss` 既存の場合スキップ) |
|
| 83 |
+
| RAG 回答生成 | Chat (gpt-4.1-mini) | **10** |
|
| 84 |
+
| サブクレーム抽出 | Chat (gpt-4.1-mini) | **10** |
|
| 85 |
+
| `relavance` スコア(claim 埋め込み) | Embeddings | **~50** |
|
| 86 |
+
| `query_claim_cosine_similarity` | Embeddings | **~100**(2 コール × 50) |
|
| 87 |
+
| `doc_claim_cosine_similarity` | Embeddings | **~1,000**(2 × top_k=10 × 50) |
|
| 88 |
+
| `frequency` 代替回答生成(n=5 バッチ) | Chat (gpt-4.1-mini) | **~50 バッチ** |
|
| 89 |
+
| `frequency` カウントプロンプト(× 5) | Chat (gpt-4.1-mini) | **~250** |
|
| 90 |
+
| アノテーション(claim verification) | Chat (gpt-4.1-mini) | **~50** |
|
| 91 |
+
| **合計 Chat** | | **~370 コール** |
|
| 92 |
+
| **合計 Embeddings** | | **~1,150 コール** |
|
| 93 |
+
|
| 94 |
+
`index_store.zip`(著者提供)に `MedLFQA/index_10.faiss` が含まれているため、それを配置すれば Embeddings API によるインデックス構築はスキップできる。
|
| 95 |
+
|
| 96 |
+
`--lite` フラグを付けると以下のスコアをスキップし、API コストを大幅に削減できる。
|
| 97 |
+
|
| 98 |
+
| スキップされるスコア | 理由 |
|
| 99 |
+
|---|---|
|
| 100 |
+
| `frequency` | Chat ~300 コール(総コストの ~80%)、デモ未使用 |
|
| 101 |
+
| `doc_claim_cosine_similarity` | Embeddings ~1,000 コール(~87%)、デモ未使用 |
|
| 102 |
+
|
| 103 |
+
またキャリブレーション・グラフ生成(split conformal / group conditional conformal)もスキップされる。
|
| 104 |
+
|
| 105 |
+
**確認事項:**
|
| 106 |
+
|
| 107 |
+
1. **E2E パイプラインの疎通**: データ取得・FAISS インデックス生成・LLM 推論・スコア計算がエラーなく完走するか。
|
| 108 |
+
2. **出力ファイルの検証**: `data/out/` に正しい JSON スキーマのファイルが生成されているか(`subclaims[].scores` に `relavance`・`noise` 等のキーが存在するか)。`--lite` 使用時は `frequency`・`doc_claim_cosine_similarity` キーは生成されず、`relavance`・`noise`・`query_claim_cosine_similarity`・`min_log_prob` は常に計算される。
|
| 109 |
+
|
| 110 |
+
---
|
| 111 |
+
|
| 112 |
+
## 3. 統合テスト(API 不要)
|
| 113 |
+
|
| 114 |
+
ステップ 2 で得られた少量の本物データ(`data/out/` 内の JSON)を入力として、デモ用の事前計算スクリプトが正しく動作するかをテストする。
|
| 115 |
+
|
| 116 |
+
**実行手順:**
|
| 117 |
+
|
| 118 |
+
```bash
|
| 119 |
+
python demo/precompute.py
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
**確認事項:**
|
| 123 |
+
|
| 124 |
+
1. **`demo/data/thresholds.csv` の生成**
|
| 125 |
+
- `ALPHA_LEVELS`(例: `0.05, 0.10, ...`)に対応する行が全件生成されているか。
|
| 126 |
+
- 浮動小数点の丸め誤差が発生していないか(`alpha` 列が `0.09999...` ではなく `0.1` になっているか)。
|
| 127 |
+
- キャリブレーションに使われるのがホールドアウト分を除いたデータのみであるか(ログの `calib=N, holdout=M` で確認)。
|
| 128 |
+
- > **注意**: `query_size: 10` 相当の少量データを使っている場合、$\hat{q}$ の値自体は統計的に意味が薄い。ここではスキーマと生成フローの正常動作の確認が目的。本番品質の閾値は `query_size: 500` で再実行する。
|
| 129 |
+
|
| 130 |
+
2. **`demo/data/samples.json` の生成**
|
| 131 |
+
- メタデータ(生成日時・モデル名・スコアリング手法)が正しく付与されているか。
|
| 132 |
+
- MedLFQA の場合、全グループ(HealthsearchQA / MedicationQA / K-QA_Silver / K-QA_Golden / LiveQA)から少なくとも 1 件ずつ含まれているか。
|
| 133 |
+
- `subclaims[].score` フィールドが `FilteredResult` の lookup キーと一致するか(`constants.py` の `SCORING_METHOD_TO_DATA_KEY` のマッピングが正しいか)。
|
| 134 |
+
|
| 135 |
+
---
|
| 136 |
+
|
| 137 |
+
## 4. UI テスト(API 不要)
|
| 138 |
+
|
| 139 |
+
生成された静的ファイル(CSV/JSON)を用いて、ブラウザ上での UI コンポーネントと状態遷移をテストする。Streamlit の `AppTest` フレームワークを使えば自動テストも可能。
|
| 140 |
+
|
| 141 |
+
**実装ファイル:** `tests/test_app_ui.py`(16 テスト、API 不要)
|
| 142 |
+
|
| 143 |
+
**実行手順:**
|
| 144 |
+
|
| 145 |
+
```bash
|
| 146 |
+
# コンテナ経由
|
| 147 |
+
bash scripts/run_tests.sh tests/test_app_ui.py -v
|
| 148 |
+
|
| 149 |
+
# ホスト直接実行
|
| 150 |
+
source .env && python -m pytest tests/test_app_ui.py -v
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
**前提条件:** `demo/data/thresholds.csv` と `demo/data/samples.json` が存在すること(ステップ 3 の実行済み状態)。
|
| 154 |
+
|
| 155 |
+
**確認事項:**
|
| 156 |
+
|
| 157 |
+
1. **条件付き UI のレンダリング**(`TestConditionalUIRendering`、4 テスト)
|
| 158 |
+
- 「MedLFQA」選択時のみ「Marginal / Conditional」ラジオボタンが活性化するか。
|
| 159 |
+
- 「Conditional」選択時のみグループ選択プルダウンが表示されるか。
|
| 160 |
+
- 検証方法: `at.radio[0].disabled`、`len(at.selectbox)` の件数(Marginal=2, Conditional=3)。
|
| 161 |
+
|
| 162 |
+
2. **セッションリセットの動作**(`TestSessionStateReset`、4 テスト)
|
| 163 |
+
- データセット・モード・グループを切り替えたとき、`prev_params` の変化検出により `st.session_state` がリセットされ、前回の推論結果が表示されたままにならないか。
|
| 164 |
+
- 変更なしの再実行では `query_processed` と `result` が保持されるか。
|
| 165 |
+
|
| 166 |
+
3. **キャッシュの動作(`@st.cache_resource`)**(`TestCacheBehavior`、2 テスト)
|
| 167 |
+
- `get_faiss_manager` / `get_scorer` が `dataset` 引数を持ち、データセットごとにキャッシュが分離される設計になっているかをシグネチャで静的検証する。
|
| 168 |
+
- > **注意**: AppTest 環境ではキャッシュ実体の動作検証が難しいため、Streamlit がキャッシュキーに使う関数シグネチャの存在を確認する形にとどめている。
|
| 169 |
+
|
| 170 |
+
4. **スライダーのインタラクティビティ(最重要)**(`TestSliderInteractivity`、6 テスト)
|
| 171 |
+
- Target Factuality スライダーを動かしたとき、「推論実行」の重い処理が再トリガーされず、`st.session_state` のスコアと CSV の $\hat{q}$ を比較するハイライト再描画(軽い処理)のみが即座に実行されるか。
|
| 172 |
+
- 推論前は disabled、推論後は enabled になるか。
|
| 173 |
+
- Keep / 削除率メトリクスの値が `apply_threshold` の計算結果と一致するか。
|
| 174 |
+
- 検証方法: `process_query` を `MagicMock` で差し替え、スライダー操作後に `assert_not_called()` で確認。
|
| 175 |
+
|
| 176 |
+
**テスト設計上の注意点:**
|
| 177 |
+
|
| 178 |
+
- `samples.json` には `medlf_qa`(conditional)のサンプルのみ含まれるため、結果パネルの描画テストはすべて `medlf_qa / marginal / default` で行い、セッション状態を手動注入する。
|
| 179 |
+
- `st.metric` の値は AppTest では文字列として返るため、数値との比較は `== "1"` のように文字列で行う。
|
| 180 |
+
- `get_faiss_manager` / `get_scorer` / `process_query` は `unittest.mock.patch` で差し替え、API・FAISS・重い推論をすべてモックアウトする。
|
docs/context/07_huggingface_spaces_deployment.md
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Hugging Face Spaces デプロイ設計
|
| 2 |
+
|
| 3 |
+
## 1. 目標とスコープ
|
| 4 |
+
|
| 5 |
+
デモアプリ(`demo/app.py`)を Hugging Face Spaces(公開)で動かす。
|
| 6 |
+
|
| 7 |
+
**スコープ: サンプルクエリ専用モード**
|
| 8 |
+
- `demo/data/samples.json` に事前計算済みのクエリのみ動作する
|
| 9 |
+
- 「推論実行」ボタンは FAISS・Wikipedia DB を使わずに事前計算済み結果を即時表示
|
| 10 |
+
- 「回答を生成」ボタン(LLM 再統合)は OpenAI API を呼ぶため、API キーが必要
|
| 11 |
+
- ライブ推論(任意クエリ入力)は行わない
|
| 12 |
+
|
| 13 |
+
**Spaces の制約**
|
| 14 |
+
- Singularity は使えない(Docker / Python 環境)
|
| 15 |
+
- リポジトリに大きなバイナリファイル(FAISS インデックス、SQLite DB)は置けない
|
| 16 |
+
- 環境変数(Secrets)で API キーを設定できる
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## 2. 現行コードの移植可否分析
|
| 21 |
+
|
| 22 |
+
| 処理 | Spaces での動作 | 対応方針 |
|
| 23 |
+
|------|---------------|---------|
|
| 24 |
+
| `samples.json` / `thresholds.csv` の読み込み | ○ そのまま動く(リポジトリ内) | 変更不要 |
|
| 25 |
+
| 「推論実行」→ 事前計算済み結果の表示 | ○ API 呼び出しなし | 変更不要 |
|
| 26 |
+
| 「回答を生成」(`reintegrate_subclaims`) | ○ OpenAI API のみ | Secrets に `OPENAI_API_KEY` を設定 |
|
| 27 |
+
| `build_faiss_manager()` / `build_scorer()` | △ ファイルが存在しないためエラー | ライブ推論パスに到達しなければ呼ばれない(後述) |
|
| 28 |
+
| `@st.cache_resource` でのリソース初期化 | △ 呼ばれた時点でエラー | 環境フラグで無効化 |
|
| 29 |
+
| `.env` の読み込み | △ Spaces では `.env` ファイルは使わない | Spaces Secrets → 環境変数として自動注入 |
|
| 30 |
+
|
| 31 |
+
### ライブ推論パスが呼ばれる条件
|
| 32 |
+
|
| 33 |
+
現在のコードでは、`build_faiss_manager()` / `build_scorer()` はライブ推論パス内でしか呼ばれない。
|
| 34 |
+
|
| 35 |
+
```python
|
| 36 |
+
# app.py 内の live inference ブランチ(sample_map にないクエリのみ)
|
| 37 |
+
faiss_manager = get_faiss_manager(dataset) ← ここでエラーになる
|
| 38 |
+
scorer = get_scorer(dataset)
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
サンプルクエリのプルダウンには `samples.json` 内のクエリしか表示されないため、
|
| 42 |
+
通常操作では **このパスには到達しない**。
|
| 43 |
+
|
| 44 |
+
ただし、予期しないエラー時のフォールバックや将来の機能追加を考慮し、
|
| 45 |
+
**環境変数フラグ `SPACES_DEMO=1`** でライブ推論パスを明示的に無効化する。
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## 3. 必要な変更
|
| 50 |
+
|
| 51 |
+
### 3-1. `app.py` の変更(最小限)
|
| 52 |
+
|
| 53 |
+
`SPACES_DEMO=1` が設定されている場合、ライブ推論パスをブロックする。
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
import os
|
| 57 |
+
|
| 58 |
+
SPACES_DEMO = os.getenv("SPACES_DEMO", "0") == "1"
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
「推論実行」ボタン処理内:
|
| 62 |
+
|
| 63 |
+
```python
|
| 64 |
+
if run_btn and query_input:
|
| 65 |
+
precomputed = sample_map.get(query_input)
|
| 66 |
+
if precomputed is not None and ...:
|
| 67 |
+
# 事前計算済み → そのまま表示(変更なし)
|
| 68 |
+
...
|
| 69 |
+
elif SPACES_DEMO:
|
| 70 |
+
st.error("このデモではサンプルクエリのみ対応しています。")
|
| 71 |
+
else:
|
| 72 |
+
# ライブ推論(ローカル環境のみ)
|
| 73 |
+
...
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### 3-2. Spaces 用設定ファイル
|
| 77 |
+
|
| 78 |
+
Spaces は **リポジトリの `README.md`(frontmatter)** で設定を宣言する。
|
| 79 |
+
|
| 80 |
+
```yaml
|
| 81 |
+
---
|
| 82 |
+
title: Response Quality Assessment Demo
|
| 83 |
+
emoji: 📊
|
| 84 |
+
colorFrom: blue
|
| 85 |
+
colorTo: green
|
| 86 |
+
sdk: streamlit
|
| 87 |
+
sdk_version: 1.43.2
|
| 88 |
+
app_file: demo/app.py
|
| 89 |
+
pinned: false
|
| 90 |
+
---
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
`app_file` に `demo/app.py` を指定することでリポジトリ構造を変えずに済む。
|
| 94 |
+
|
| 95 |
+
### 3-3. `requirements.txt` の整理(実装済み)
|
| 96 |
+
|
| 97 |
+
Spaces はリポジトリルートの `requirements.txt` を自動で `pip install` する。
|
| 98 |
+
リポジトリでは以下の構成を採用している:
|
| 99 |
+
|
| 100 |
+
| ファイル | 用途 |
|
| 101 |
+
|----------|------|
|
| 102 |
+
| `requirements.txt` | HF Spaces 向け最小セット。Spaces はこのファイルを自動で読む |
|
| 103 |
+
| `requirements-dev.txt` | ローカル・Singularity 向け全依存。`-r requirements.txt` で共通部分を継承 |
|
| 104 |
+
|
| 105 |
+
**`requirements.txt`(Spaces 向け最小セット、§4-1 調査結果に基づく):**
|
| 106 |
+
|
| 107 |
+
```
|
| 108 |
+
openai>=2.0
|
| 109 |
+
python-dotenv>=1.0
|
| 110 |
+
numpy>=1.24
|
| 111 |
+
pandas>=2.0
|
| 112 |
+
pyyaml>=6.0
|
| 113 |
+
streamlit>=1.43
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
`faiss-cpu` / `torch` / `transformers` / `sentence-transformers` / `langchain-core` /
|
| 117 |
+
`scikit-learn` はいずれも不要(遅延 import により Spaces モードでは読み込まれない)。
|
| 118 |
+
|
| 119 |
+
ローカル開発時は `pip install -r requirements-dev.txt` を使う。
|
| 120 |
+
|
| 121 |
+
### 3-4. Secrets の設定
|
| 122 |
+
|
| 123 |
+
Spaces の **Settings > Variables and secrets** に以下を追加する:
|
| 124 |
+
|
| 125 |
+
| キー | 値 | 用途 |
|
| 126 |
+
|------|----|------|
|
| 127 |
+
| `OPENAI_API_KEY` | `sk-...` | 「回答を生成」ボタン(`reintegrate_subclaims`) |
|
| 128 |
+
| `SPACES_DEMO` | `1` | ライブ推論パスを無効化 |
|
| 129 |
+
|
| 130 |
+
`.env` ファイルは Spaces 環境では使わない(`python-dotenv` の `load_dotenv()` は
|
| 131 |
+
環境変数が既にセットされていれば上書きしないため、ローカルとの互換性は保たれる)。
|
| 132 |
+
|
| 133 |
+
---
|
| 134 |
+
|
| 135 |
+
## 4. 事前調査結果
|
| 136 |
+
|
| 137 |
+
### 4-1. `src/` の import チェーンの影響調査(済)
|
| 138 |
+
|
| 139 |
+
`sys.modules` を import 前後で比較し、重い依存パッケージの混入を確認した。
|
| 140 |
+
|
| 141 |
+
**原因**: `src.common.faiss_manager` → `src.common.file_manager` → `langchain_text_splitters`
|
| 142 |
+
が `torch` / `transformers` / `sentence_transformers` 等を連鎖的に引き込む。
|
| 143 |
+
`src.subclaim_processor.scorer.subclaim_scorer` も同様。
|
| 144 |
+
|
| 145 |
+
**対処**: `FAISSIndexManager` / `SubclaimScorer` の import を `build_faiss_manager()` /
|
| 146 |
+
`build_scorer()` 関数内に移動(遅延 import)。型ヒントは `TYPE_CHECKING` ガードで維持。
|
| 147 |
+
これにより、サンプルクエリ専用モードでは import 時に重い依存が読み込まれない。
|
| 148 |
+
|
| 149 |
+
**Spaces 向け最小パッケージセット**(`build_*()` が呼ばれない前提):
|
| 150 |
+
|
| 151 |
+
```
|
| 152 |
+
openai
|
| 153 |
+
python-dotenv
|
| 154 |
+
numpy
|
| 155 |
+
pandas
|
| 156 |
+
pyyaml
|
| 157 |
+
langchain-core
|
| 158 |
+
streamlit>=1.43
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
`faiss-cpu` / `torch` / `transformers` / `sentence-transformers` / `scipy` /
|
| 162 |
+
`scikit-learn` は不要。
|
| 163 |
+
|
| 164 |
+
### 4-2. `DATA_ROOT` / config パスの解決(済)
|
| 165 |
+
|
| 166 |
+
`_load_main_config()` / `_load_dataset_config()` は関数呼び出し時に初めて実行される
|
| 167 |
+
(モジュールレベルでは実行されない)ため、`DATA_ROOT` 未設定でも import は成功する。
|
| 168 |
+
サンプルクエリ専用モードではこれらの関数は呼ばれないため問題なし。
|
| 169 |
+
|
| 170 |
+
---
|
| 171 |
+
|
| 172 |
+
## 5. デプロイ手順
|
| 173 |
+
|
| 174 |
+
1. **HF Space の作成**
|
| 175 |
+
- `huggingface.co/spaces/<username>/<space-name>` を新規作成(SDK: Streamlit)
|
| 176 |
+
|
| 177 |
+
2. **`spaces` remote を追加**
|
| 178 |
+
```bash
|
| 179 |
+
git remote add spaces https://huggingface.co/spaces/EQUES/Response-Quality-Assessment
|
| 180 |
+
```
|
| 181 |
+
認証はリモート URL にトークンを埋め込む方法を使う(`.git/config` は `.gitignore` 対象外だが
|
| 182 |
+
`git` の管理ファイルであり GitHub にはプッシュされない):
|
| 183 |
+
```bash
|
| 184 |
+
git remote set-url spaces https://<username>:<hf_token>@huggingface.co/spaces/EQUES/Response-Quality-Assessment
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
3. **Orphan ブランチで push**
|
| 188 |
+
|
| 189 |
+
HF Spaces はプッシュ時に **全コミット履歴** をスキャンし、10 MiB 超のファイルを拒否する。
|
| 190 |
+
`feature/hf-spaces` の祖先コミットに大きなデータファイル(`data/out/Medlfqav2/` 等)が含まれるため、
|
| 191 |
+
履歴なしの Orphan ブランチを一時作成してから push する:
|
| 192 |
+
|
| 193 |
+
```bash
|
| 194 |
+
git checkout --orphan spaces-deploy
|
| 195 |
+
git add -A
|
| 196 |
+
git rm --cached data/out/ data/raw/ -r --ignore-unmatch
|
| 197 |
+
git commit -m "deploy: initial push to HF Spaces"
|
| 198 |
+
git push --force spaces spaces-deploy:main
|
| 199 |
+
git checkout feature/hf-spaces
|
| 200 |
+
git branch -D spaces-deploy
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
> **なぜ Orphan か**: 通常の push では過去コミットごと送られる。Orphan ブランチは親コミットが
|
| 204 |
+
> 存在しない「起点」なので、HF Spaces には現在のスナップショット 1 コミットだけが届く。
|
| 205 |
+
> ローカルと GitHub の履歴には一切影響しない。
|
| 206 |
+
|
| 207 |
+
4. **Secrets の設定**
|
| 208 |
+
- Space の **Settings > Variables and secrets** に以下を追加:
|
| 209 |
+
|
| 210 |
+
| キー | 値 |
|
| 211 |
+
|------|----|
|
| 212 |
+
| `OPENAI_API_KEY` | `sk-...` |
|
| 213 |
+
| `SPACES_DEMO` | `1` |
|
| 214 |
+
|
| 215 |
+
5. **動作確認**
|
| 216 |
+
- サンプルクエリの表示・スライダー操作
|
| 217 |
+
- 「回答を生成」ボタン(OpenAI API 呼び出し)
|
| 218 |
+
|
| 219 |
+
### 再デプロイ(コード変更後)
|
| 220 |
+
|
| 221 |
+
コードを変更したら同じ Orphan 手順を繰り返す。`spaces` remote は設定済みのため手順 2 は不要。
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
## 6. 既知の問題・注意点
|
| 226 |
+
|
| 227 |
+
### MedLFQA Marginal モードのサンプル一致
|
| 228 |
+
|
| 229 |
+
`precompute.py` は grouped データセット(medlf_qa)のサンプルを全て `mode="conditional"` で
|
| 230 |
+
生成するため、`samples.json` に `mode="marginal"` エントリが存在しない。
|
| 231 |
+
|
| 232 |
+
`app.py` では Marginal モード選択時にモード不一致でサンプルが見つからない問題を回避するため、
|
| 233 |
+
Marginal モードではサンプルの `mode`/`group` フィールドを無視してマッチさせる:
|
| 234 |
+
|
| 235 |
+
```python
|
| 236 |
+
if (
|
| 237 |
+
precomputed is not None
|
| 238 |
+
and (
|
| 239 |
+
mode == "marginal"
|
| 240 |
+
or (precomputed["mode"] == mode and precomputed["group"] == group)
|
| 241 |
+
)
|
| 242 |
+
):
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
サブクレームスコア自体はモードに依存しないため、閾値(`_lookup_q_hat` が marginal 用 q_hat を
|
| 246 |
+
参照)との比較は正しく行われる。
|
| 247 |
+
|
| 248 |
+
---
|
| 249 |
+
|
| 250 |
+
## 7. 開発ステップ
|
| 251 |
+
|
| 252 |
+
| ステップ | 状態 | 担当ファイル |
|
| 253 |
+
|----------|------|------------|
|
| 254 |
+
| ① import 影響調査 | 完了 | — |
|
| 255 |
+
| ② requirements 整理 | 完了 | `requirements.txt`, `requirements-dev.txt` |
|
| 256 |
+
| ③ `app.py` にフラグ追加 | 完了 | `demo/app.py` |
|
| 257 |
+
| ④ `README.md` の作成 | 完了 | `README.md` |
|
| 258 |
+
| ⑤ Spaces へ push | 完了 | — |
|
| 259 |
+
| ⑥ Marginal モードのバグ修正 | 完了 | `demo/app.py` |
|
main.py
ADDED
|
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import numpy as np
|
| 4 |
+
import logging
|
| 5 |
+
import yaml
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from src.common.config_manager import ConfigManager
|
| 9 |
+
from src.dataloader.dataloader import DataLoader
|
| 10 |
+
from src.data_processor.query_processor import QueryProcessor
|
| 11 |
+
from src.common.file_manager import FileManager
|
| 12 |
+
from src.common.faiss_manager import FAISSIndexManager
|
| 13 |
+
from src.subclaim_processor.scorer.subclaim_scorer import SubclaimScorer
|
| 14 |
+
from src.subclaim_processor.subclaim_processor import process_subclaims
|
| 15 |
+
from src.calibration.conformal import SplitConformalCalibration
|
| 16 |
+
from src.calibration.conditional_conformal import GroupConditionalConformal
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parse_args(dataset_aliases, argv=None):
    """Parse command-line arguments for the pipeline entry point.

    Args:
        dataset_aliases: Valid dataset names, used as ``choices`` for the
            ``--dataset`` option (the caller loads them from
            ``conf/dataset_config.yaml``).
        argv: Optional list of argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]``
            (original behavior). Passing an explicit list makes the
            function unit-testable without mutating ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``config``, ``dataset``,
        ``query_size``, ``run_id``, ``lite``, and ``log_level``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="conf/config.yaml",
        help="Path to configuration file",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        help="Override dataset name from config",
        choices=dataset_aliases,
    )
    parser.add_argument(
        "--query_size", type=int, default=500, help="Override query size from config"
    )
    parser.add_argument("--run_id", type=str, help="Custom run identifier")
    parser.add_argument(
        "--lite",
        action="store_true",
        help="Skip frequency and doc_claim_cosine_similarity scoring (demo-unused) and conformal prediction plots",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level. Use DEBUG to record individual API calls (default: INFO)",
    )
    # None delegates to sys.argv[1:] inside argparse, preserving the
    # original call signature's behavior for existing callers.
    return parser.parse_args(argv)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
    """Run the full response-quality pipeline for one dataset.

    Steps: load configs, fetch/prepare raw data and the wiki DB, process
    queries and documents, build or update the FAISS index, score
    subclaims, and (unless --lite) run split and group-conditional
    conformal prediction, saving plots and CSVs under the run's result
    directory. Finally copies the run artifacts for reproducibility.

    Fixes vs. previous revision:
      * The group-conditional branch logged the *ungrouped* figure/CSV
        paths (`factual_result_fig_path`, `result_path`) instead of the
        group paths it actually wrote to.
      * Removed the unused `dataset_aliases` local and the dead
        initializer of the (misspelled) `avaliable_datasets` list.
    """
    with open("conf/dataset_config.yaml", "r") as f:
        dataset_config = yaml.safe_load(f)
    available_datasets = list(dataset_config["datasets"].keys())

    # Parse arguments
    args = parse_args(available_datasets)

    # Initialize config manager
    config_manager = ConfigManager(
        config_path=args.config,
        path_config_path="conf/path_config.yaml",
        dataset_config_path="conf/dataset_config.yaml",
        run_id=args.run_id,
    )

    # Setup logging
    log_file, run_id = config_manager.setup_logging(
        log_level=getattr(logging, args.log_level)
    )

    # Update config with command line arguments if provided.
    # NOTE(review): --query_size has a default (500), so this branch always
    # fires and the config's query_size is always overridden by the CLI
    # value — confirm this is intended.
    if args.dataset or args.query_size:
        updates = {"dataset": {}}
        if args.dataset:
            updates["dataset"]["name"] = args.dataset
        if args.query_size:
            updates["dataset"]["query_size"] = args.query_size
        config_manager.update_config(updates)

    lite = args.lite

    # Save updated config
    config_file = config_manager.save_config()
    logging.info(f"Configuration saved to: {config_file}")

    # Log important config values
    config_manager.log_config()

    # Get the config
    config = config_manager.config
    path_config = config_manager.path_config
    dataset_config = config_manager.dataset_config

    ####################################### Data and Folder Set up ############################################
    dataset_name = config["dataset"]["name"]
    query_size = config["dataset"]["query_size"]
    wiki_db_file = config["dataset"]["wiki_db_file"]

    delete_existing_index = config["index"]["delete_existing"]
    embedding_model = config["index"]["embedding_model"]
    index_truncation_config = config["index"]["truncation_config"]
    truncation_strategy = index_truncation_config["strategy"]
    truncate_by = index_truncation_config["truncate_by"]

    response_model = config["rag"]["response_model"]
    frequency_score_model = config["conformal_prediction"]["frequency_score_model"]

    # Sweep of miscoverage levels for the conformal experiments.
    alpha_config = config["conformal_prediction"]["conformal_alphas"]
    conformal_alphas = np.arange(
        alpha_config["start"], alpha_config["end"], alpha_config["step"]
    )
    a_value = config["conformal_prediction"]["a_value"]

    dataset_custom_config = dataset_config["datasets"].get(dataset_name)
    if not dataset_custom_config:
        raise ValueError(f"Unknown dataset: {dataset_name}")
    full_dataset_name = dataset_custom_config["name"]
    index_store_dir = dataset_custom_config["index_store"]
    group_conditional_conformal = dataset_custom_config.get("is_grouped", False)

    raw_data_dir = os.path.join(path_config["paths"]["raw_data_dir"], full_dataset_name)
    processed_data_dir = os.path.join(
        path_config["paths"]["processed_data_dir"], full_dataset_name
    )
    response_dir = os.path.join(path_config["paths"]["response_dir"], full_dataset_name)
    wiki_db_path = os.path.join(path_config["paths"]["wiki_db_dir"], wiki_db_file)
    result_dir = os.path.join(
        path_config["paths"]["result_dir"], full_dataset_name, run_id
    )

    # set up directories
    for dir_path in [raw_data_dir, processed_data_dir, response_dir, result_dir]:
        os.makedirs(dir_path, exist_ok=True)
        logging.info(f"Directory ensured: {dir_path}")

    # Determine raw data file path. MedLFQA ships as a directory of files,
    # not a single JSON file, so it is handled specially.
    if dataset_name == "medlf_qa":
        raw_data_path = os.path.join(path_config["paths"]["raw_data_dir"], "MedLFQA")
    else:
        raw_data_file = f"raw_{dataset_name}.json"
        raw_data_path = os.path.join(raw_data_dir, raw_data_file)

    logging.info(f"Raw data path: {raw_data_path}")

    # Load data if needed
    if not os.path.exists(raw_data_path):
        logging.info(f"Raw data not found. Loading data for {dataset_name}")
        data_loader = DataLoader(dataset_name)
        data_loader.load_qa_data(output_path=raw_data_path)
        logging.info(f"Data loaded and saved to {raw_data_path}")

    # create wiki db if needed
    if not os.path.exists(wiki_db_path) or not os.path.isfile(wiki_db_path):
        wiki_source = os.path.join(
            path_config["paths"]["wiki_db_dir"],
            "enwiki-20171001-pages-meta-current-withlinks-abstracts",
        )
        if not os.path.exists(wiki_source):
            raise FileNotFoundError(f"Wiki source data not found at {wiki_source}")
        logging.info(f"Wiki DB not found. Creating from source {wiki_source}")
        data_loader = DataLoader(dataset_name)
        data_loader.create_wiki_db(source_path=wiki_source, output_path=wiki_db_path)
        logging.info(f"Wiki DB created at {wiki_db_path}")

    # Process queries and documents. raw_data_path already points at the
    # MedLFQA directory in the medlf_qa case, so no special-casing is
    # needed here (previously the same path was recomputed).
    input_file = raw_data_path

    query_output_file = f"{dataset_name}_queries.json"
    document_output_file = f"{dataset_name}_documents.txt"

    subclaims_path = os.path.join(
        response_dir,
        f"{dataset_name}_{query_size}_subclaims_with_scores_{response_model}.json",
    )
    CP_result_fig_path = os.path.join(
        result_dir, f"{dataset_name}_{query_size}_a={a_value:.2f}_CP_removal.png"
    )
    GCP_result_fig_path = os.path.join(
        result_dir, f"{dataset_name}_{query_size}_a={a_value:.2f}_GCP_removal.png"
    )
    factual_result_fig_path = os.path.join(
        result_dir,
        f"{dataset_name}_{query_size}_a={a_value:.2f}_factual_correctness.png",
    )
    group_factual_result_fig_path = os.path.join(
        result_dir,
        f"group_{dataset_name}_{query_size}_a={a_value:.2f}_factual_correctness.png",
    )
    result_path = os.path.join(
        result_dir, f"{dataset_name}_{query_size}_a={a_value:.2f}.csv"
    )
    group_result_path = os.path.join(
        result_dir, f"group_{dataset_name}_{query_size}_a={a_value:.2f}.csv"
    )
    ####################################### End of Data and Folder Set up ######################################

    # Create QueryProcessor
    logging.info("Initializing QueryProcessor")
    query_processor = QueryProcessor(db_path=wiki_db_path, query_size=query_size)

    # Create queries data
    logging.info("Processing queries")
    queries, query_path = query_processor.get_queries(
        dataset=dataset_name,
        input_file=input_file,
        output_dir=processed_data_dir,
        output_file=query_output_file,
    )
    logging.info(f"Query size: {len(queries)}")

    # Create documents data
    logging.info("Processing documents")
    document_path = query_processor.get_documents(
        query_dir=query_path,
        output_dir=processed_data_dir,
        output_file=document_output_file,
    )
    logging.info(f"Documents saved to {document_path}")

    # Index creation and retrieval
    os.makedirs(index_store_dir, exist_ok=True)
    index_file_path = os.path.join(index_store_dir, f"index_{query_size}.faiss")
    indice2fm_path = os.path.join(index_store_dir, f"indice2fm_{query_size}.json")

    logging.info("Setting up FAISS index manager")
    faiss_manager = FAISSIndexManager(
        index_truncation_config=index_truncation_config,
        index_path=index_file_path,
        indice2fm_path=indice2fm_path,
    )

    if delete_existing_index:
        logging.info("Deleting existing index as requested")
        faiss_manager.delete_index()

    # Create index if it does not exist
    document_file = FileManager(
        document_path, index_truncation_config=index_truncation_config
    )

    logging.info(
        f"Using truncation strategy: {truncation_strategy}, truncate_by: {truncate_by}"
    )

    # Case 1: Index doesn't exist yet
    if not os.path.exists(index_file_path):
        try:
            logging.info(f"Creating new index with document '{document_path}'")
            faiss_manager.upsert_file_to_faiss(
                document_file,
                truncation_strategy=truncation_strategy,
                truncate_by=truncate_by,
            )
            logging.info("Index created successfully")
        except Exception as e:
            error_msg = f"Failed to create new index: {str(e)}"
            logging.error(error_msg)
            raise RuntimeError(error_msg)

    # Case 2: Index exists but current document isn't indexed
    elif document_path not in faiss_manager.indice2fm:
        # Verify index integrity before mutating it.
        logging.info("Checking index integrity")
        if not faiss_manager.is_indice_align():
            error_msg = "Index corruption detected: index and indice2fm are not aligned"
            logging.error(error_msg)
            raise ValueError(error_msg)

        try:
            logging.info(f"Adding document '{document_path}' to existing index")
            faiss_manager.upsert_file_to_faiss(
                document_file,
                truncation_strategy=truncation_strategy,
                truncate_by=truncate_by,
            )
            logging.info("Document added to index successfully")
        except Exception as e:
            error_msg = f"Failed to add document to index: {str(e)}"
            logging.error(error_msg)
            raise RuntimeError(error_msg)

    # Case 3: Document is already indexed
    else:
        logging.info(f"Document '{document_path}' is already indexed")

    # generate subclaims with scores
    logging.info(f"Initializing SubclaimScorer with embedding model {embedding_model}")
    scorer = SubclaimScorer(
        index_truncation_config=index_truncation_config,
        embedding_model=embedding_model,
        index_path=index_file_path,
        indice2fm_path=indice2fm_path,
        frequency_score_model=frequency_score_model,
    )

    logging.info("Processing subclaims and generating scores")
    subclaim_with_annotation_data = process_subclaims(
        query_path=query_path,
        subclaims_path=subclaims_path,
        faiss_manager=faiss_manager,
        scorer=scorer,
        config=config,
        lite=lite,
    )
    logging.info(f"Subclaims processed and saved to {subclaims_path}")

    # calibration and conformal prediction results
    if not lite and config["conformal_prediction"]["split_conformal"]:
        logging.info("Running split conformal prediction")
        conformal = SplitConformalCalibration(dataset_name=dataset_name)
        logging.info(
            f"Plotting conformal removal with alphas: {conformal_alphas}, a={a_value}"
        )
        conformal.plot_conformal_removal(
            data=subclaim_with_annotation_data,
            alphas=conformal_alphas,
            a=a_value,
            fig_filename=CP_result_fig_path,
            csv_filename=result_path,
        )
        logging.info(f"CP removal plot saved to {CP_result_fig_path}")

        logging.info("Plotting factual removal")
        conformal.plot_factual_removal(
            data=subclaim_with_annotation_data,
            alphas=conformal_alphas,
            a=a_value,
            fig_filename=factual_result_fig_path,
            csv_filename=result_path,
        )
        logging.info(f"Factual removal plot saved to {factual_result_fig_path}")
        logging.info(f"Results saved to {result_path}")

    if not lite and group_conditional_conformal:
        logging.info("Running group conditional conformal prediction")
        conformal = GroupConditionalConformal(dataset_name=dataset_name, result_dir=result_dir)
        logging.info(
            f"Plotting conformal removal with alphas: {conformal_alphas}, a={a_value}"
        )
        conformal.plot_conformal_removal(
            data=subclaim_with_annotation_data,
            alphas=conformal_alphas,
            a=a_value,
            fig_filename=GCP_result_fig_path,
            csv_filename=group_result_path,
        )
        logging.info(f"CP removal plot saved to {GCP_result_fig_path}")

        logging.info("Plotting factual removal")
        conformal.plot_factual_removal(
            data=subclaim_with_annotation_data,
            alphas=conformal_alphas,
            a=a_value,
            fig_filename=group_factual_result_fig_path,
            csv_filename=group_result_path,
        )
        # BUGFIX: previously logged the ungrouped paths here.
        logging.info(f"Factual removal plot saved to {group_factual_result_fig_path}")
        logging.info(f"Results saved to {group_result_path}")

    # Copy config and log files to result directory for reproducibility
    result_run_dir = config_manager.copy_run_artifacts(result_dir)
    logging.info(
        f"Run completed successfully. Results and logs saved to {result_run_dir}"
    )
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
# Script entry point: run the pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
requirements-dev.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-r requirements.txt
|
| 2 |
+
faiss-cpu
|
| 3 |
+
torch
|
| 4 |
+
transformers
|
| 5 |
+
sentence-transformers
|
| 6 |
+
huggingface-hub
|
| 7 |
+
datasets
|
| 8 |
+
langchain
|
| 9 |
+
scikit-learn
|
| 10 |
+
PyPDF2
|
| 11 |
+
jsonschema
|
| 12 |
+
matplotlib
|
| 13 |
+
tqdm
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal requirements for HF Spaces (SPACES_DEMO=1, sample queries only).
|
| 2 |
+
# FAISS / torch / transformers are NOT needed because FAISSIndexManager and
|
| 3 |
+
# SubclaimScorer are lazily imported and never called in this mode.
|
| 4 |
+
#
|
| 5 |
+
# Usage: rename or symlink this file to requirements.txt in the Space repo,
|
| 6 |
+
# or specify it via the Spaces build configuration.
|
| 7 |
+
|
| 8 |
+
openai>=2.0
|
| 9 |
+
python-dotenv>=1.0
|
| 10 |
+
numpy>=1.24
|
| 11 |
+
pandas>=2.0
|
| 12 |
+
pyyaml>=6.0
|
| 13 |
+
streamlit>=1.43
|
scripts/build_sif.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Build the Singularity SIF image from singularity/response_quality.def.
# Usage: bash scripts/build_sif.sh
# Prerequisite: .env at the repo root defining SIF_DIR (see .env.example).
set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
# Pull SIF_DIR (output directory for the image) from the repo's .env.
source "${REPO_ROOT}/.env"

mkdir -p "${SIF_DIR}"

# --fakeroot: build without root privileges; --force: overwrite existing image.
singularity build --fakeroot --force \
    "${SIF_DIR}/response_quality.sif" \
    "${REPO_ROOT}/singularity/response_quality.def"

echo "Built: ${SIF_DIR}/response_quality.sif"
|
scripts/run_demo.sh
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Launch the Streamlit demo inside the Singularity container.
# Usage: bash scripts/run_demo.sh [--port PORT]
# Access: http://localhost:<PORT> (default: 8502)
# Prerequisite: .env defining DATA_ROOT, HF_HOME, SIF_DIR, OPENAI_API_KEY.
set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
source "${REPO_ROOT}/.env"

# Minimal flag parsing: only --port is supported; anything else is an error.
PORT=8502
while [[ $# -gt 0 ]]; do
  case "$1" in
    --port) PORT="$2"; shift 2 ;;
    *) echo "Unknown option: $1" >&2; exit 1 ;;
  esac
done

# Bind the external data/index stores into the repo layout the app expects,
# and forward the HF cache and API key into the container environment.
singularity run \
  --bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
  --bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
  --bind "${HF_HOME}:${HF_HOME}" \
  --env HF_HOME="${HF_HOME}" \
  --env OPENAI_API_KEY="${OPENAI_API_KEY}" \
  --env DATA_ROOT="${REPO_ROOT}" \
  --env PYTHONPATH="${REPO_ROOT}" \
  --pwd "${REPO_ROOT}" \
  "${SIF_DIR}/response_quality.sif" \
  streamlit run demo/app.py --server.port "${PORT}"
|
scripts/run_main.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run main.py inside the Singularity container.
# All arguments are forwarded to main.py as-is.
# Prerequisite: .env defining DATA_ROOT, HF_HOME, SIF_DIR, OPENAI_API_KEY.
#
# Usage:
#   bash scripts/run_main.sh --dataset hotpot_qa
#   bash scripts/run_main.sh --dataset hotpot_qa --query_size 200
set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
source "${REPO_ROOT}/.env"

# Bind external data/index stores into the repo layout and run the pipeline.
singularity exec \
  --bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
  --bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
  --bind "${HF_HOME}:${HF_HOME}" \
  --env HF_HOME="${HF_HOME}" \
  --env OPENAI_API_KEY="${OPENAI_API_KEY}" \
  --env DATA_ROOT="${REPO_ROOT}" \
  --pwd "${REPO_ROOT}" \
  "${SIF_DIR}/response_quality.sif" \
  python main.py "$@"
|
scripts/run_precompute.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run demo/precompute.py inside the Singularity container.
# Usage: bash scripts/run_precompute.sh
# Prerequisite: .env defining DATA_ROOT, HF_HOME, SIF_DIR, OPENAI_API_KEY.
set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
source "${REPO_ROOT}/.env"

# Bind external data/index stores and regenerate the demo's precomputed files.
singularity exec \
  --bind "${DATA_ROOT}/data:${REPO_ROOT}/data" \
  --bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store" \
  --bind "${HF_HOME}:${HF_HOME}" \
  --env HF_HOME="${HF_HOME}" \
  --env OPENAI_API_KEY="${OPENAI_API_KEY}" \
  --env DATA_ROOT="${REPO_ROOT}" \
  --pwd "${REPO_ROOT}" \
  "${SIF_DIR}/response_quality.sif" \
  python -m demo.precompute

echo ""
echo "Generated:"
echo "  ${REPO_ROOT}/demo/data/thresholds.csv"
echo "  ${REPO_ROOT}/demo/data/samples.json"
echo ""
echo "Commit these files: git add demo/data/ && git commit"
|
scripts/run_tests.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# Run pytest inside the Singularity container.
# All arguments are forwarded to pytest as-is.
#
# Prerequisites: .env must define SIF_DIR (and optionally DATA_ROOT, HF_HOME).
# See docs/context/04_environment_setup.md and .env.example.
#
# Usage:
#   bash scripts/run_tests.sh                 # run all tests
#   bash scripts/run_tests.sh tests/test_calibration.py -v
set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
source "${REPO_ROOT}/.env"

# SIF_DIR is mandatory: without it there is no container image to exec into.
if [[ -z "${SIF_DIR:-}" ]]; then
  echo "ERROR: SIF_DIR is not set. Fill in .env (see .env.example)." >&2
  exit 1
fi

# Build the bind list only for paths that are set.
BINDS=()
if [[ -n "${DATA_ROOT:-}" ]]; then
  BINDS+=(--bind "${DATA_ROOT}/data:${REPO_ROOT}/data")
  BINDS+=(--bind "${DATA_ROOT}/index_store:${REPO_ROOT}/index_store")
fi
if [[ -n "${HF_HOME:-}" ]]; then
  BINDS+=(--bind "${HF_HOME}:${HF_HOME}")
fi

singularity exec \
  "${BINDS[@]}" \
  --env HF_HOME="${HF_HOME:-}" \
  --env DATA_ROOT="${REPO_ROOT}" \
  --pwd "${REPO_ROOT}" \
  "${SIF_DIR}/response_quality.sif" \
  python -m pytest "$@"
|
singularity/response_quality.def
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Bootstrap: docker
|
| 2 |
+
From: python:3.11-slim
|
| 3 |
+
|
| 4 |
+
%post
|
| 5 |
+
apt-get update && apt-get install -y --no-install-recommends \
|
| 6 |
+
build-essential \
|
| 7 |
+
git \
|
| 8 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Install torch (CPU-only) explicitly before other packages to avoid
|
| 11 |
+
# pulling in the full CUDA build as a transitive dependency.
|
| 12 |
+
pip install --no-cache-dir \
|
| 13 |
+
torch --index-url https://download.pytorch.org/whl/cpu
|
| 14 |
+
|
| 15 |
+
pip install --no-cache-dir \
|
| 16 |
+
streamlit \
|
| 17 |
+
openai \
|
| 18 |
+
faiss-cpu \
|
| 19 |
+
numpy \
|
| 20 |
+
pandas \
|
| 21 |
+
pyyaml \
|
| 22 |
+
python-dotenv \
|
| 23 |
+
transformers \
|
| 24 |
+
sentence-transformers \
|
| 25 |
+
huggingface-hub \
|
| 26 |
+
datasets \
|
| 27 |
+
langchain \
|
| 28 |
+
scikit-learn \
|
| 29 |
+
PyPDF2 \
|
| 30 |
+
jsonschema \
|
| 31 |
+
matplotlib \
|
| 32 |
+
tqdm \
|
| 33 |
+
langchain-text-splitters \
|
| 34 |
+
pytest
|
| 35 |
+
|
| 36 |
+
%environment
|
| 37 |
+
export PYTHONUNBUFFERED=1
|
| 38 |
+
export PYTHONDONTWRITEBYTECODE=1
|
| 39 |
+
|
| 40 |
+
%runscript
|
| 41 |
+
exec "$@"
|
| 42 |
+
|
| 43 |
+
%labels
|
| 44 |
+
Author ryoya.awano
|
| 45 |
+
Description "Response Quality Assessment demo (Conformal RAG)"
|
src/calibration/__init__.py
ADDED
|
File without changes
|
src/calibration/base_calibration.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ICalibration(ABC):
    """
    Interface for calibration methods.

    Implementations compute conformal thresholds over annotated subclaim
    data and persist figure/CSV artifacts for the requested alphas.
    """

    @abstractmethod
    def plot_conformal_removal(
        self, data, alphas, a, fig_filename, csv_filename, plot_group_results=False
    ):
        """Plot and record conformal removal results across `alphas`,
        writing a figure to `fig_filename` and rows to `csv_filename`."""
        pass

    @abstractmethod
    def plot_factual_removal(
        self, data, alphas, a, fig_filename, csv_filename, plot_group_results=False
    ):
        """Plot and record factual-correctness results across `alphas`,
        writing a figure to `fig_filename` and rows to `csv_filename`."""
        pass
|
src/calibration/conditional_conformal.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
from src.calibration.conformal import SplitConformalCalibration
|
| 11 |
+
from src.calibration.utils import compute_threshold
|
| 12 |
+
from src.calibration.utils import append_result_to_csv
|
| 13 |
+
from src.calibration.utils import split_group
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
CORRECT_ANNOTATIONS = ["S"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class GroupConditionalConformal(SplitConformalCalibration):
|
| 20 |
+
    def __init__(self, dataset_name: str, result_dir: str, runs: int = 1000):
        """Initialize group-conditional conformal calibration.

        Args:
            dataset_name: Dataset key, forwarded to SplitConformalCalibration.
            result_dir: Directory where per-group result artifacts are stored.
            runs: Number of random calibration/test splits per alpha.
        """
        super().__init__(dataset_name, runs)
        self.result_dir = result_dir
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
    def compute_conformal_results(
        self, data: list, alphas: np.ndarray, a: float, plot_group_results: bool = False
    ):
        """Run group-conditional conformal calibration over random splits.

        For each confidence method and each alpha, repeats ``self.runs``
        times: shuffle `data` in place, split it into per-group calibration
        sets and a test set (via ``split_group``), compute one threshold per
        group, and evaluate retention/removal on the test set.

        Args:
            data: Entries with "subclaims" and "groups" keys; NOTE this list
                is shuffled in place on every run.
            alphas: Miscoverage levels to sweep.
            a: Minimum correctly-retained fraction for an entry to count as
                correctly retained (passed through to threshold computation
                and evaluation).
            plot_group_results: Accepted for interface compatibility; not
                used in this method.

        Returns:
            dict: results[confidence_method][alpha] -> dict with lists
            (one element per run) under "threshold" (per-group dicts),
            "correctness", and "fraction_removed".
        """
        results = {}
        for confidence_method in self.confidence_method:
            results[confidence_method] = {}
            for alpha in tqdm(
                alphas, desc=f"Computing conformal results for {confidence_method}"
            ):
                # TODO add grouping
                # `groups` is reassigned from the calibration split on every
                # run; this initializer only matters if runs == 0.
                groups = None

                thresholds_result = []
                correctness_list = []
                fraction_removed_list = []
                test_data = []
                for _ in range(self.runs):
                    # Fresh random calibration/test split each run.
                    random.shuffle(data)
                    calibration_data, test_data = split_group(data)
                    groups = list(calibration_data.keys())

                    assert (
                        len(calibration_data) != 0
                    ), "Calibration data should not be empty"
                    assert len(test_data) != 0, "Test data should not be empty"
                    # One conformal threshold per group from the calibration split.
                    thresholds = self._compute_threshold_by_group(
                        alpha, calibration_data, a, confidence_method, groups=groups
                    )

                    correctness, fraction_removed = (
                        self._evaluate_conformal_correctness(
                            test_data, thresholds, a, confidence_method
                        )
                    )
                    thresholds_result.append(thresholds)
                    correctness_list.append(correctness)
                    fraction_removed_list.append(fraction_removed)

                results[confidence_method][alpha] = {
                    "threshold": thresholds_result,
                    "correctness": correctness_list,
                    "fraction_removed": fraction_removed_list,
                }

        return results
|
| 72 |
+
|
| 73 |
+
def _compute_threshold_by_group(
|
| 74 |
+
self,
|
| 75 |
+
alpha: float,
|
| 76 |
+
calibration_data: list,
|
| 77 |
+
a: float,
|
| 78 |
+
confidence_method: str,
|
| 79 |
+
groups: list | None = None,
|
| 80 |
+
):
|
| 81 |
+
thresholds = {}
|
| 82 |
+
for group in groups:
|
| 83 |
+
group_data = calibration_data[group]
|
| 84 |
+
thresholds[group] = compute_threshold(
|
| 85 |
+
alpha, group_data, a, confidence_method
|
| 86 |
+
)
|
| 87 |
+
return thresholds
|
| 88 |
+
|
| 89 |
+
    def _evaluate_conformal_correctness(
        self, data: list, thresholds: dict, a: float, confidence_method: str
    ):
        """
        Evaluates the performance of a conformal prediction model on test data.

        A subclaim is *retained* when its noised confidence score meets the
        threshold of the entry's (first) group; a retained subclaim is
        *correctly* retained when its "gpt" annotation is in
        CORRECT_ANNOTATIONS.

        Parameters:
            data (list): A list of dictionaries, where each dictionary represents
                an entry with "subclaims" and "groups" keys.
            thresholds (dict): Per-group score thresholds, keyed by group
                (the old docstring described a single float, which was wrong).
            a (float): The threshold for the correctly retained percentage to
                consider an entry as correctly retained.
            confidence_method (str): Key into each subclaim's "scores" dict.

        Returns:
            tuple: Two floats:
                - mean fraction of entries that are correctly retained.
                - mean fraction of subclaims removed per entry.

        Raises:
            ValueError: If the entry's group has a None threshold.
        """

        correctly_retained = []

        fraction_removed = []

        for entry in data:
            removal_count = 0
            retained_cnt = 0
            correctly_retained_count = 0
            threshold = thresholds[entry["groups"][0]]  # Get threshold for the group
            if threshold is None:
                raise ValueError(
                    f"Threshold for group {entry['groups'][0]} is None. Check your calibration data."
                )
            for subclaim in entry["subclaims"]:
                # Find similarity score; "noise" is a precomputed tie-breaking
                # perturbation stored alongside the scores.
                score = subclaim["scores"][confidence_method]
                noise = subclaim["scores"]["noise"]
                if score + noise >= threshold:
                    retained_cnt += 1
                    if (
                        subclaim.get("annotations", {}).get("gpt", "")
                        in CORRECT_ANNOTATIONS
                    ):
                        correctly_retained_count += 1

                else:
                    removal_count += 1

            total_subclaims = len(entry["subclaims"])

            # Calculate fraction of removed subclaims
            entry_removal_rate = (
                0 if total_subclaims == 0 else removal_count / total_subclaims
            )
            fraction_removed.append(
                entry_removal_rate
            )  # e.g. fraction_removed = [0.2, 0.5, 0.6, 0.2, 0.7] - one element per data entry

            # Calculate correctly retained rate; an entry with nothing
            # retained counts as vacuously correct (rate 1).
            correctly_retained_percentage = (
                correctly_retained_count / retained_cnt if retained_cnt > 0 else 1
            )
            correctly_retained.append(correctly_retained_percentage >= a)

        return np.mean(correctly_retained), np.mean(fraction_removed)
|
| 149 |
+
|
| 150 |
+
def compute_factual_results(self, data, alphas, a, calibrate_range=0.5):
    """
    Run repeated group-conditional conformal calibration and collect
    factual-correctness statistics.

    For each confidence method and each miscoverage level alpha, the data
    is shuffled and split (per group) into calibration/test sets
    `self.runs` times; a per-group threshold is computed on the
    calibration split and evaluated on the test split. Per-group results
    are also appended to one CSV file per group.

    Args:
        data: list of entries, each with "groups" and "subclaims".
        alphas: iterable of miscoverage levels in (0, 1).
        a: required fraction of correct retained subclaims per entry.
        calibrate_range: fraction of each group used for calibration.

    Returns:
        dict: method -> alpha -> {"threshold", "correctness", "factuality"}.

    NOTE(review): writes CSVs under self.result_dir and uses
    self.dataset_name — both must be set by the enclosing class's
    __init__ (not visible here); confirm.
    """
    overall_results = {}
    per_group_results = defaultdict(lambda: defaultdict(dict))

    for method in self.confidence_method:
        overall_results[method] = {}
        for alpha in tqdm(
            alphas, desc=f"Computing factual results for {method}"
        ):
            # trackers for this (method, alpha)
            overall_correctness = []
            thresholds_per_group = defaultdict(list)
            correctness_per_group = defaultdict(list)

            for _ in range(self.runs):
                # In-place shuffle so each run sees a fresh random split.
                random.shuffle(data)
                calibration_data, test_data = split_group(data, calibrate_range)
                groups = list(calibration_data.keys())

                # assert on nonempty
                assert calibration_data, "No calibration groups"
                assert len(test_data) != 0, "Test data should not be empty"

                # compute all per-group thresholds at once
                thresholds = self._compute_threshold_by_group(
                    alpha, calibration_data, a, method, groups=groups
                )

                fraction_correct = self._evaluate_factual_correctness(
                    test_data, thresholds, a, method
                )
                for group in groups:
                    thresholds_per_group[group].append(thresholds[group])
                    correctness_per_group[group].append(
                        fraction_correct[group]
                    )
                overall_correctness.append(fraction_correct["overall"])

            # package overall results for this (method, alpha)
            overall_results[method][alpha] = {
                "threshold": thresholds_per_group,
                "correctness": overall_correctness,
                "factuality": 1 - alpha,
            }

            # package per-group results
            for grp in thresholds_per_group:
                per_group_results[grp][method][alpha] = {
                    "threshold": thresholds_per_group[grp],
                    "correctness": correctness_per_group[grp],
                    "factuality": 1 - alpha,
                }

    # Append one CSV per group with the aggregated correctness series.
    for grp, grp_results in per_group_results.items():
        csv_name = os.path.join(self.result_dir, f"{self.dataset_name}_{grp}_factual_correctness.csv")
        self._write_csv_header(csv_name, alphas)
        for method, res in grp_results.items():
            lvl, corr, err = self.process_factual_correctness_results(res)
            append_result_to_csv(
                csv_filename=csv_name,
                label=f"{method}_factual_correctness",
                y=corr,
                yerr=err,
            )

    return overall_results
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def _evaluate_factual_correctness(
    self,
    data: list,
    thresholds: dict,
    a: float,
    confidence_method: str,
):
    """
    Score factual correctness using a per-group threshold.

    A subclaim is retained when its score plus noise clears the threshold
    of the entry's first group; an entry passes when the fraction of
    retained subclaims annotated correct is at least `a` (vacuously 1.0
    when nothing is retained).

    Args:
        data: entries, each with "groups" and "subclaims".
        thresholds: mapping group name -> threshold float.
        a: required fraction of correct retained subclaims.
        confidence_method: key into each subclaim's "scores" dict.

    Returns:
        dict: {"overall": float} plus one pass-rate entry per group.
    """
    group_fractions = defaultdict(list)
    passed = 0
    seen = 0

    for entry in data:
        group = entry["groups"][0]
        kept = 0
        kept_correct = 0

        for claim in entry["subclaims"]:
            combined = (
                claim["scores"][confidence_method] + claim["scores"]["noise"]
            )
            if combined >= thresholds[group]:
                kept += 1
                label = claim.get("annotations", {}).get("gpt", "")
                if label in CORRECT_ANNOTATIONS:
                    kept_correct += 1

        # Per-entry fraction of correct retained subclaims (1.0 if none kept).
        frac = kept_correct / kept if kept > 0 else 1.0
        group_fractions[group].append(frac)

        seen += 1
        if frac >= a:
            passed += 1

    # Per-group pass rates.
    result = {"overall": passed / seen if seen > 0 else 0.0}
    for grp, fracs in group_fractions.items():
        result[grp] = sum(1 for f in fracs if f >= a) / len(fracs)
    return result
|
src/calibration/conformal.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import csv
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
from src.calibration.base_calibration import ICalibration
|
| 9 |
+
from src.calibration.utils import compute_threshold
|
| 10 |
+
from src.calibration.utils import append_result_to_csv
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
CORRECT_ANNOTATIONS = ["S"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SplitConformalCalibration(ICalibration):
|
| 17 |
+
"""
|
| 18 |
+
Implementation of standard conformal calibration.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, dataset_name: str, runs: int = 1000):
    """
    Args:
        dataset_name: label used in plot titles and result file names.
        runs: number of random calibration/test splits to average over.
    """
    self.dataset_name = dataset_name
    # Subclaim scoring functions this calibrator evaluates.
    # NOTE: "relavance" (sic) matches the exact key used in the subclaim
    # score dicts elsewhere in this project; do not "fix" the spelling.
    self.confidence_method = [
        "relavance",
        "frequency",
        "query_claim_cosine_similarity",
        "doc_claim_cosine_similarity",
        "min_log_prob",
        "random",
        "ordinal",
    ]
    self.runs = runs
|
| 33 |
+
|
| 34 |
+
def plot_conformal_removal(
    self, data, alphas, a, fig_filename, csv_filename
):
    """
    Compute (or load cached) conformal removal results, append them to a
    CSV, and draw one removal-rate curve per confidence method.

    Args:
        data: list of entries with subclaims.
        alphas: miscoverage levels in (0, 1).
        a: required fraction of correct retained subclaims.
        fig_filename: output path for the figure.
        csv_filename: output path for the CSV (also keys the cache file).

    NOTE(review): the .npy cache is keyed only on csv_filename — rerunning
    with different `alphas`/`a` but the same csv_filename reuses stale
    results; delete the cache file to force recomputation.
    """
    # compute the correctness and fraction removed for each alpha

    cache_filename = f"{os.path.splitext(os.path.abspath(csv_filename))[0]}_conformal_removal_cache.npy"
    if not os.path.exists(cache_filename):
        results = self.compute_conformal_results(
            data, alphas, a
        )
        print(f"Caching results to {cache_filename}")
        np.save(cache_filename, results)

    else:
        print(f"Loading cached results from {cache_filename}")
        results = np.load(cache_filename, allow_pickle=True).item()

    ax = None
    for confidence_method, result in results.items():
        correctness, fraction_removed, yerr = (
            self.process_conformal_removal_results(result)
        )

        # write the results to csv file
        self._write_csv_header(csv_filename, alphas)
        append_result_to_csv(
            csv_filename=csv_filename,
            label=f"{confidence_method}_conformal_removal_rate",
            y=fraction_removed,
            yerr=yerr,
        )

        # plot the results; ax is threaded through so all methods share one figure
        print(f"Producing conformal plot for {confidence_method}")
        ax = self.plot_conformal_removal_rate_by_alpha(
            correctness,
            fraction_removed,
            yerr,
            a,
            confidence_method,
            fig_filename,
            ax,
        )
        print(f"Conformal plot saved to {fig_filename}")
|
| 79 |
+
|
| 80 |
+
def compute_conformal_results(
    self, data: list, alphas: np.ndarray, a: float
):
    """
    Run repeated split-conformal calibration for every confidence method.

    For each method and alpha, the data is shuffled and split in half
    `self.runs` times; each run computes a calibration threshold and
    evaluates retention correctness / removal rate on the held-out half.

    Returns:
        dict: method -> alpha -> {"threshold", "correctness",
        "fraction_removed"}, each a list with one value per run.
    """
    results = {}
    for method in self.confidence_method:
        per_alpha = {}
        for alpha in tqdm(
            alphas, desc=f"Computing conformal results for {method}"
        ):
            run_thresholds = []
            run_correctness = []
            run_removed = []

            for _ in range(self.runs):
                # Fresh random half/half split each run.
                random.shuffle(data)
                half = len(data) // 2
                cal_split, test_split = data[:half], data[half:]

                assert (
                    len(cal_split) != 0
                ), "Calibration data should not be empty"
                assert len(test_split) != 0, "Test data should not be empty"

                tau = compute_threshold(alpha, cal_split, a, method)
                corr, removed = self._evaluate_conformal_correctness(
                    test_split, tau, a, method
                )

                run_thresholds.append(tau)
                run_correctness.append(corr)
                run_removed.append(removed)

            per_alpha[alpha] = {
                "threshold": run_thresholds,
                "correctness": run_correctness,
                "fraction_removed": run_removed,
            }
        results[method] = per_alpha

    return results
|
| 123 |
+
|
| 124 |
+
def process_conformal_removal_results(self, results: dict):
    """
    Aggregate per-alpha conformal results across runs.

    Returns:
        x: mean correctness per alpha (fraction of entries meeting level a).
        y: mean fraction of subclaims removed per alpha.
        yerr: 95% CI half-width (1.96 * std / sqrt(n_runs)) of the
              removal fraction per alpha.
    """
    avg_correct = []
    avg_removed = []
    ci_removed = []

    for stats in results.values():
        removed = stats["fraction_removed"]
        # Averages are taken over the shuffled-split runs for this alpha.
        avg_correct.append(np.mean(stats["correctness"]))
        avg_removed.append(np.mean(removed))
        ci_removed.append(np.std(removed) * 1.96 / np.sqrt(len(removed)))

    return avg_correct, avg_removed, ci_removed
|
| 149 |
+
|
| 150 |
+
def plot_conformal_removal_rate_by_alpha(
    self, x, y, yerr, a, confidence_method, fig_filename, ax=None
):
    """
    Draw one errorbar curve (removal rate vs. correctness) on a shared axis.

    Args:
        x: mean correctness per alpha.
        y: mean fraction removed per alpha.
        yerr: CI half-widths for y.
        a: accuracy level (used only in labels/title).
        confidence_method: legend label for this curve.
        fig_filename: path the figure is (re)saved to on every call.
        ax: existing axis to add to; a new labelled figure is created when None.

    Returns:
        The matplotlib axis, so callers can layer further curves onto it.
    """
    if not ax:
        # First call: create the figure and set titles/labels once.
        fig, ax = plt.subplots(figsize=(8, 6), dpi=800)
        ax.set_title(
            f"Conformal Plots for {self.dataset_name} Datasets (a={a})", fontsize=20
        )
        x_label = (
            f"Fraction achieving avg factuality >= {a}"
            if a != 1
            else "Fraction of factual outputs"
        )
        ax.set_xlabel(x_label, fontsize=16)
        ax.set_ylabel("Average percent removed", fontsize=16)
    else:
        fig = ax.figure

    # Plot the data
    ax.errorbar(x, y, yerr=yerr, label=confidence_method, linewidth=2)

    # set the legend
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98), fontsize=10)

    # Save the figure (overwritten after each added curve)
    fig.savefig(fig_filename, bbox_inches="tight")

    return ax  # Return the ax for further modifications if needed
|
| 178 |
+
|
| 179 |
+
def _write_csv_header(self, csv_filename, alphas):
    """
    Write the CSV header row if the file does not exist yet.

    The header lists target factuality values (1 - alpha) formatted to
    two decimals, reversed so they run from low alpha to high.

    Args:
        csv_filename: destination CSV path.
        alphas: iterable of miscoverage levels.
    """
    target_factuality = [f"{(1-x):.2f}" for x in alphas][::-1]
    header = ["target_factuality"] + target_factuality

    # Ensure the directory exists. Guard against a bare filename:
    # os.path.dirname("file.csv") == "" and os.makedirs("") raises
    # FileNotFoundError even with exist_ok=True.
    directory = os.path.dirname(csv_filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    if not os.path.exists(csv_filename):
        with open(csv_filename, mode="w", newline="") as file:
            csv.writer(file).writerow(header)
|
| 189 |
+
|
| 190 |
+
def _evaluate_conformal_correctness(
    self, data: list, threshold: float, a: float, confidence_method: str
):
    """
    Evaluates the performance of a conformal prediction model on test data.

    A subclaim is retained when its score plus noise is >= `threshold`.
    An entry counts as correctly retained when the fraction of its
    retained subclaims annotated correct is at least `a` (vacuously so
    when nothing is retained).

    Parameters:
        data (list): entries, each a dict with "subclaims".
        threshold (float): retention threshold on score + noise.
        a (float): required fraction of correct retained subclaims.
        confidence_method (str): key into each subclaim's "scores" dict.

    Returns:
        tuple[float, float]: (mean fraction of entries correctly retained,
        mean fraction of subclaims removed per entry).
    """

    correctly_retained = []
    fraction_removed = []

    for entry in data:
        removal_count = 0
        retained_cnt = 0
        correctly_retained_count = 0

        for subclaim in entry["subclaims"]:
            # Noisy confidence score used for the retention decision.
            score = subclaim["scores"][confidence_method]
            noise = subclaim["scores"]["noise"]
            if score + noise >= threshold:
                retained_cnt += 1
                if (
                    subclaim.get("annotations", {}).get("gpt", "")
                    in CORRECT_ANNOTATIONS
                ):
                    correctly_retained_count += 1

            else:
                removal_count += 1

        total_subclaims = len(entry["subclaims"])

        # Calculate fraction of removed subclaims (0 for entries with none)
        entry_removal_rate = (
            0 if total_subclaims == 0 else removal_count / total_subclaims
        )
        fraction_removed.append(
            entry_removal_rate
        )  # e.g. fraction_removed = [0.2, 0.5, 0.6, 0.2, 0.7] - one element per data entry

        # Calculate correctly retained rate (vacuously 1 if nothing retained)
        correctly_retained_percentage = (
            correctly_retained_count / retained_cnt if retained_cnt > 0 else 1
        )
        correctly_retained.append(correctly_retained_percentage >= a)

    return np.mean(correctly_retained), np.mean(fraction_removed)
|
| 245 |
+
|
| 246 |
+
def plot_factual_removal(
    self, data, alphas, a, fig_filename, csv_filename, plot_group_results=False
):
    """
    Plot empirical factuality vs. target factuality, with the conformal
    lower-bound diagonal, for every confidence method.

    Results are cached to a .npy file next to csv_filename; aggregated
    series are also appended to csv_filename.

    Raises:
        NotImplementedError: if plot_group_results is True.

    NOTE(review): unlike plot_conformal_removal, this never calls
    self._write_csv_header before appending rows — confirm the header is
    written by the caller, otherwise the CSV starts without one.
    """
    # Diagonal y = x: the conformal guarantee lower bound.
    x_values = np.linspace(1 - alphas[-1] - 0.05, 1 - alphas[0] + 0.03, 100)
    fig, ax = plt.subplots(figsize=(8, 6), dpi=800)
    ax.plot(
        x_values,
        x_values,
        "--",
        color="gray",
        linewidth=2,
        label="Conformal guarantee lower bounds",
    )

    # NOTE(review): cache is keyed only on csv_filename; stale for new alphas/a.
    cache_filename = f"{os.path.splitext(os.path.abspath(csv_filename))[0]}_factual_correctness_cache.npy"
    if not os.path.exists(cache_filename):
        results = self.compute_factual_results(data, alphas, a)
        print(f"Caching results to {cache_filename}")
        np.save(cache_filename, results)

    else:
        print(f"Loading cached results from {cache_filename}")
        results = np.load(cache_filename, allow_pickle=True).item()

    for confidence_method, result in results.items():
        conf_level, corretness, yerr = self.process_factual_correctness_results(
            result
        )

        append_result_to_csv(
            csv_filename=csv_filename,
            label=f"{confidence_method}_factual_correctness",
            y=corretness,
            yerr=yerr,
        )

        print(
            f"Producing factual removal plot for {confidence_method}: {fig_filename}"
        )
        ax = self.plot_factual_removal_rate_by_alpha(
            conf_level, corretness, a, confidence_method, fig_filename, ax
        )
        print(f"Conformal plot saved to {fig_filename}")

    if plot_group_results:
        # self.plot_factual_group_results(results, csv_filename, x)
        raise NotImplementedError("Not implemented")
|
| 293 |
+
|
| 294 |
+
def compute_factual_results(self, data, alphas, a):
    """
    Compute factual-correctness statistics via repeated split conformal.

    For every confidence method and alpha, the data is shuffled and split
    in half `self.runs` times; each run calibrates a threshold and
    measures the fraction of held-out entries meeting factuality level a.

    Returns:
        dict: method -> alpha -> {"threshold", "correctness",
        "factuality"}; the first two are per-run lists.
    """
    results = {}
    for method in self.confidence_method:
        per_alpha = {}
        for alpha in tqdm(
            alphas, desc=f"Computing factual results for {method}"
        ):
            run_thresholds = []
            run_correctness = []

            for _ in range(self.runs):
                # Fresh random half/half split each run.
                random.shuffle(data)
                half = len(data) // 2
                cal_split, test_split = data[:half], data[half:]

                assert (
                    len(cal_split) != 0
                ), "Calibration data should not be empty"
                assert len(test_split) != 0, "Test data should not be empty"

                tau = compute_threshold(alpha, cal_split, a, method)
                run_thresholds.append(tau)
                run_correctness.append(
                    self._evaluate_factual_correctness(
                        test_split, tau, a, method
                    )
                )

            per_alpha[alpha] = {
                "threshold": run_thresholds,
                "correctness": run_correctness,
                "factuality": 1 - alpha,
            }
        results[method] = per_alpha
    return results
|
| 327 |
+
|
| 328 |
+
def process_factual_correctness_results(self, results: dict):
    """
    Aggregate factual-correctness results into plottable series.

    Returns:
        x: target factuality levels (1 - alpha).
        y: mean empirical correctness per alpha, averaged over runs.
        yerr: 95% CI half-width (1.96 * std / sqrt(n_runs)) per alpha.
    """
    levels = []
    mean_correct = []
    ci = []

    for alpha, stats in results.items():
        correctness = stats["correctness"]
        levels.append(1 - alpha)
        mean_correct.append(np.mean(correctness))
        ci.append(np.std(correctness) * 1.96 / np.sqrt(len(correctness)))

    return levels, mean_correct, ci
|
| 348 |
+
|
| 349 |
+
def plot_factual_removal_rate_by_alpha(
    self, x, y, a, confidence_method, fig_filename, ax=None
):
    """
    Draw one empirical-factuality curve on a (possibly shared) axis.

    Args:
        x: target factuality levels (1 - alpha).
        y: mean empirical correctness per level.
        a: accuracy level (used only in the title).
        confidence_method: legend label for this curve.
        fig_filename: path the figure is (re)saved to on every call.
        ax: existing axis to add to; a new figure is created when None.

    Returns:
        The matplotlib axis, for layering further curves.
    """
    if not ax:
        fig, ax = plt.subplots(figsize=(8, 6), dpi=800)
    else:
        fig = ax.figure  # Get the figure from the provided ax

    # Labels/title are (re)set on every call, unlike the conformal plot.
    ax.set_xlabel(f"Target factuality (1 - {chr(945)})", fontsize=16)
    ax.set_ylabel("Empirical factuality", fontsize=16)
    ax.set_title(
        f"Factual correctness for {self.dataset_name} Datasets (a={a})", fontsize=20
    )

    # Plot the data
    ax.plot(x, y, label=confidence_method, linewidth=2)

    # Set legend
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98), fontsize=10)

    # Save the figure (overwritten after each added curve)
    fig.savefig(fig_filename, bbox_inches="tight", dpi=800)

    return ax  # Return the ax for further modifications if needed
|
| 373 |
+
|
| 374 |
+
def _evaluate_factual_correctness(
    self, data: list, threshold: float, a: float, confidence_method: str
):
    """
    Evaluates the factual correctness of subclaims within the provided data.

    For each entry, subclaims whose score + noise is >= `threshold` are
    retained; the entry passes when the fraction of retained subclaims
    annotated correct is at least `a` (vacuously 1 when nothing is
    retained).

    Args:
        data (list): entries, each a dict containing "subclaims".
        threshold (float): retention threshold on score + noise.
        a (float): required fraction of correct retained subclaims.
        confidence_method (str): key into each subclaim's "scores" dict.

    Returns:
        float: fraction of entries satisfying accuracy level `a`; 0.0 for
        empty input, consistent with the grouped variant's "overall".
    """
    # Guard against ZeroDivisionError on empty input (the original
    # divided by len(data) unconditionally).
    if not data:
        return 0.0

    passed = 0
    for entry in data:
        retained_cnt = 0
        correctly_retained_count = 0

        for subclaim in entry["subclaims"]:
            score = subclaim["scores"][confidence_method]
            noise = subclaim["scores"]["noise"]

            # Retain the subclaim if its noisy score clears the threshold
            if score + noise >= threshold:
                retained_cnt += 1
                if (
                    subclaim.get("annotations", {}).get("gpt", "")
                    in CORRECT_ANNOTATIONS
                ):
                    correctly_retained_count += 1

        # Fraction of correct retained subclaims (1 when nothing retained)
        correctly_retained_percentage = (
            correctly_retained_count / retained_cnt if retained_cnt > 0 else 1
        )
        if correctly_retained_percentage >= a:
            passed += 1

    # percentage of test data satisfying correct level of a
    return passed / len(data)
|
src/calibration/utils.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import csv
|
| 3 |
+
import numpy as np
|
| 4 |
+
from math import ceil
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
|
| 7 |
+
CORRECT_ANNOTATIONS = ["Y", "S"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def load_subclaim_data(file_path):
    """Load calibration data (subclaim entries) from a JSON file."""
    with open(file_path, "r", encoding="utf-8") as fp:
        data = json.load(fp)
    return data
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def append_result_to_csv(csv_filename, label, y, yerr):
    """Append one labelled row of "value ± error" cells to a CSV file.

    Cells are formatted in input order then reversed, so they line up
    with the header, which lists target factuality from low alpha to high.
    """
    cells = [f"{val:.4f} ± {err:.4f}" for val, err in zip(y, yerr)][::-1]
    with open(csv_filename, mode="a", newline="") as fp:
        csv.writer(fp).writerow([label, *cells])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _get_accepted_subclaims(entry, threshold, confidence_method):
    """Return the subclaims of `entry` whose noisy score clears `threshold`."""
    accepted = []
    for subclaim in entry["subclaims"]:
        noisy_score = (
            subclaim["scores"][confidence_method] + subclaim["scores"]["noise"]
        )
        if noisy_score >= threshold:
            accepted.append(subclaim)
    return accepted
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _calculate_entailed_fraction(subclaims):
    """Fraction of `subclaims` whose GPT annotation marks them correct.

    An empty selection counts as fully correct (1.0).
    """
    if not subclaims:
        return 1.0
    flags = [
        subclaim["annotations"]["gpt"] in CORRECT_ANNOTATIONS
        for subclaim in subclaims
    ]
    return np.mean(flags)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def get_r_score(entry: dict, confidence_method: str, a: float):
    """
    Compute the r_a score for each data entry when confidence_method is used as the sub-claim scoring function.

    This function calculates the minimum threshold at which the fraction of correct subclaims
    falls below the required threshold 'a'. The r_a score represents the confidence score
    at which the model's reliability drops below the acceptable level.

    The algorithm works by:
    1. First checking if the score was already calculated and cached
    2. Sorting all subclaim scores in descending order
    3. Testing each score as a potential threshold
    4. For each threshold, accepting only subclaims with scores >= threshold
    5. Calculating the fraction of correct subclaims among the accepted ones
    6. Returning the first threshold where this fraction falls below 'a'
    7. Returning -1 if all possible thresholds maintain accuracy above 'a'

    Args:
        entry: Dictionary containing claims data (mutated: the computed
            score is cached on it under a per-(a, method) key)
        confidence_method: Method used for scoring subclaims
        a: Required fraction correct threshold

    Returns:
        float: r_a score for the entry
    """
    r_score_key = f"r_score_{a}_{confidence_method}"
    if r_score_key in entry:
        return entry[r_score_key]
    # Cache miss: compute once below and store the result on the entry.

    # Candidate thresholds are the entries' own noisy scores, high to low.
    scores = [
        subclaim["scores"][confidence_method] + subclaim["scores"]["noise"]
        for subclaim in entry["subclaims"]
    ]
    threshold_set = sorted(scores, reverse=True)

    for threshold in threshold_set:
        accepted_subclaims = _get_accepted_subclaims(
            entry, threshold, confidence_method
        )
        entailed_fraction = _calculate_entailed_fraction(accepted_subclaims)

        if entailed_fraction < a:
            entry[r_score_key] = threshold
            return threshold

    entry[r_score_key] = -1
    return -1
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def compute_threshold(alpha, calibration_data, a, confidence_method):
    """
    Computes the quantile/threshold from conformal prediction.

    Args:
        alpha: miscoverage level, float in (0, 1).
        calibration_data: list of calibration entries.
        a: required fraction correct per entry (paper, section 4.1).
        confidence_method: subclaim scoring key, string.

    Returns:
        The conformal threshold: the ceil((n+1)*(1-alpha))-th smallest
        r score (clamped to the largest one).
    """
    # Compute r score for each example.
    r_scores = [get_r_score(entry, confidence_method, a) for entry in calibration_data]

    # Compute threshold for conformal prection. The quantile is ceil((n+1)*(1-alpha))/n, and
    # we map this to the index by dropping the division by n and subtracting one (for zero-index).
    # min(..., len) clamps the index for very small alpha.
    quantile_target_index = min(ceil((len(r_scores) + 1) * (1 - alpha)), len(r_scores))
    threshold = sorted(r_scores)[quantile_target_index - 1]
    return threshold
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# Make sure the split calibrate_range ratio are all same not just in overall level but in group level
|
| 117 |
+
# not return data in list but in a map with each group name as key
|
| 118 |
+
def split_group(data, calibrate_range=0.5):
    """Split entries into calibration/test sets with the same ratio per group.

    Entries are bucketed by their first group name; the first
    ceil(len * calibrate_range) entries of each bucket go to the
    calibration set (keyed by group), the remainder to a flat test list.

    Returns:
        (calibration_data, test_data): a defaultdict of group -> entries,
        and a list of test entries.
    """
    buckets = defaultdict(list)
    for entry in data:
        buckets[entry["groups"][0]].append(entry)  # first group is canonical

    calibration_data = defaultdict(list)
    test_data = []
    for name, members in buckets.items():
        cut = ceil(len(members) * calibrate_range)
        calibration_data[name].extend(members[:cut])
        test_data.extend(members[cut:])

    return calibration_data, test_data
|
| 133 |
+
|
| 134 |
+
# Analyze Functions #
|
| 135 |
+
|
| 136 |
+
def percentage_highest_not_S(data, key="relavance"):
    """Percentage of entries whose top-scoring subclaim is NOT annotated "S".

    The top subclaim maximises scores[key] + scores["noise"] (first one
    wins on ties); entries without subclaims are skipped. Returns 0.0
    when no entry has subclaims.
    """
    considered = 0
    not_supported = 0

    for item in data:
        subclaims = item.get("subclaims", [])
        if not subclaims:
            continue

        # max() returns the first maximal element, matching a stable
        # descending sort followed by [0].
        best = max(
            subclaims,
            key=lambda sc: sc["scores"].get(key, 0) + sc["scores"].get("noise", 0),
        )

        considered += 1
        if best.get("annotations", {}).get("gpt", None) != "S":
            not_supported += 1

    if considered == 0:
        return 0.0  # Avoid division by zero

    return (not_supported / considered) * 100
|
src/common/__init__.py
ADDED
|
File without changes
|
src/common/chunker.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Any
|
| 3 |
+
|
| 4 |
+
class BaseChunker(ABC):
    """Abstract base class for all chunking strategies.

    Subclasses receive the raw document plus sizing parameters and implement
    ``create_chunks`` to produce the actual chunks.
    """

    def __init__(self, document, chunk_size: int, overlap_size: int = 0):
        self.document = document          # raw text to be chunked
        self.chunk_size = chunk_size      # target chunk length (in words)
        self.overlap_size = overlap_size  # words shared by adjacent chunks

    @abstractmethod
    def create_chunks(self) -> tuple[list[str], int]:
        """Split the document; return (chunks, original word count).

        NOTE: the original annotation claimed ``list[dict]`` / ``list[str]``
        but every implementation returns a (chunks, word_count) tuple, which
        is what callers unpack.
        """
        ...


class FixedLengthChunker(BaseChunker):
    """Chunker that splits text into overlapping fixed-size chunks of words."""

    def create_chunks(self) -> tuple[list[str], int]:
        """Split ``self.document`` into windows of ``chunk_size`` words.

        Consecutive windows overlap by ``overlap_size`` words.

        Returns:
            tuple: (list of chunk strings, total word count of the document)

        Raises:
            ValueError: if ``overlap_size >= chunk_size`` — the window would
                never advance (the original loop spun forever in that case).
        """
        step = self.chunk_size - self.overlap_size
        if step <= 0:
            raise ValueError(
                "overlap_size must be smaller than chunk_size to make progress"
            )

        words = self.document.split()
        # Same start positions as the original while-loop: 0, step, 2*step, ...
        chunks: list[str] = [
            " ".join(words[start:start + self.chunk_size])
            for start in range(0, len(words), step)
        ]
        return chunks, len(words)
|
src/common/config_manager.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import yaml
|
| 3 |
+
import logging
|
| 4 |
+
import datetime
|
| 5 |
+
import json
|
| 6 |
+
import shutil
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
class ConfigManager:
    """Utility class to manage configuration loading, saving and logging"""

    def __init__(self, config_path=None, path_config_path=None, dataset_config_path=None, run_id=None):
        """
        Initialize the ConfigManager with optional config file paths.

        Args:
            config_path (str): Path to the main YAML config file
            path_config_path (str): Path to the YAML path-config file
            dataset_config_path (str): Path to the YAML dataset-config file
            run_id (str): Optional identifier for the run; defaults to a
                timestamp so each run gets its own log directory
        """
        self.config = {}
        # Timestamp-based id keeps log directories unique when no id is given.
        self.run_id = run_id or datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_dir = f"logs/{self.run_id}"

        # NOTE: path_config / dataset_config attributes only exist when the
        # corresponding path was provided — callers must guard access.
        if config_path:
            self.config = self.load_config(config_path)
        if path_config_path:
            self.path_config = self.load_config(path_config_path)
        if dataset_config_path:
            self.dataset_config = self.load_config(dataset_config_path)

    def load_config(self, config_path):
        """
        Load configuration from a YAML file, expanding ${ENV} references.

        Args:
            config_path (str): Path to the YAML config file

        Returns:
            dict: The loaded configuration with environment variables expanded
        """
        with open(config_path, 'r') as f:
            return self._expand_env_vars(yaml.safe_load(f))

    def _expand_env_vars(self, obj):
        """Recursively expand environment variables in string values."""
        if isinstance(obj, dict):
            return {k: self._expand_env_vars(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [self._expand_env_vars(v) for v in obj]
        if isinstance(obj, str):
            # os.path.expandvars leaves unknown $VARS untouched.
            return os.path.expandvars(obj)
        return obj

    def save_config(self, output_path=None):
        """
        Save the current configuration to a YAML file

        Args:
            output_path (str): Path to save the config file, defaults to the
                run's log directory

        Returns:
            str: Path to the saved config file
        """
        if output_path is None:
            os.makedirs(self.log_dir, exist_ok=True)
            output_path = os.path.join(self.log_dir, f"config_{self.run_id}.yaml")

        # Ensure directory exists
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, 'w') as f:
            yaml.dump(self.config, f, default_flow_style=False)

        return output_path

    def setup_logging(self, log_level=logging.INFO):
        """
        Setup logging configuration (file + console handlers).

        Args:
            log_level: Logging level

        Returns:
            tuple: (path to the log file, run id)
        """
        os.makedirs(self.log_dir, exist_ok=True)
        log_file = os.path.join(self.log_dir, f"run_{self.run_id}.log")

        logging.basicConfig(
            level=log_level,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(log_file),
                logging.StreamHandler()
            ]
        )

        # Disable httpx logs (noisy per-request INFO lines from API clients)
        logging.getLogger("httpx").setLevel(logging.WARNING)

        # Log some initial information
        logging.info(f"Starting run with ID: {self.run_id}")
        logging.info(f"Log file: {log_file}")

        return log_file, self.run_id

    def log_config(self):
        """Log the important parts of the configuration"""
        if not self.config:
            logging.warning("No configuration loaded to log")
            return

        logging.info("=== Run Configuration ===")

        # Log dataset info
        if 'dataset' in self.config:
            logging.info(f"Dataset: {self.config['dataset']['name']}")
            logging.info(f"Query size: {self.config['dataset']['query_size']}")

        # Log index info
        if 'index' in self.config:
            logging.info(f"Embedding model: {self.config['index']['embedding_model']}")
            logging.info(f"Delete existing index: {self.config['index']['delete_existing']}")

        logging.info("========================")

    def update_config(self, updates):
        """
        Update the configuration with new values

        Args:
            updates (dict): Dictionary containing updates to apply

        Returns:
            dict: The updated configuration
        """
        # This is a simple implementation that only handles top-level keys:
        # nested dicts are shallow-merged one level deep, everything else
        # overwrites the existing value.
        for key, value in updates.items():
            if isinstance(value, dict) and key in self.config and isinstance(self.config[key], dict):
                self.config[key].update(value)
            else:
                self.config[key] = value

        return self.config

    def copy_run_artifacts(self, result_dir):
        """
        Copy config and logs to a results directory for reproducibility

        Args:
            result_dir (str): Path to the results directory

        Returns:
            str: Path to the result run directory
        """
        result_run_dir = os.path.join(result_dir, "config")
        os.makedirs(result_run_dir, exist_ok=True)

        # Get the latest config and log files
        config_files = sorted(Path(self.log_dir).glob("config_*.yaml"))
        # log_files = sorted(Path(self.log_dir).glob("run_*.log"))

        if config_files:
            latest_config = str(config_files[-1])
            shutil.copy2(latest_config, os.path.join(result_run_dir, "config.yaml"))

        # if log_files:
        #     latest_log = str(log_files[-1])
        #     shutil.copy2(latest_log, os.path.join(result_run_dir, "run.log"))

        return result_run_dir
|
src/common/faiss_manager.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import ast
|
| 5 |
+
import faiss
|
| 6 |
+
from typing import Union, Optional
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
import numpy as np
|
| 9 |
+
from sklearn.preprocessing import normalize
|
| 10 |
+
from src.common.file_manager import FileManager
|
| 11 |
+
from src.common.llm.openai_manager import OpenAIManager
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class FAISSIndexManager:
    """Manage a FAISS inner-product index over document chunks.

    Wraps index persistence, per-file embedding upserts via OpenAI, and
    retrieval + answer generation.  ``indice2fm`` maps each source file path
    to the list of FAISS row ids its chunks occupy.
    """

    def __init__(
        self,
        index_truncation_config,
        dimension=3072,
        index_path="index_store/index.faiss",
        indice2fm_path="index_store/indice2fm.json",
    ):
        # Load API credentials from the project-local .env file.
        dotenv_path = os.path.join(os.getcwd(), ".env")
        load_dotenv(dotenv_path)
        self.openaiManager = OpenAIManager()
        self.dimension = dimension
        # Inner-product index; embeddings are L2-normalized before insertion,
        # so inner product equals cosine similarity.
        self.index = faiss.IndexFlatIP(dimension)
        self.file_managers = []
        self.indice2fm = (
            {}
        )  # Mapping from file texts tracking from file_path to faiss index indices, guarantee indice in asc order
        self.index_path = index_path
        self.indice2fm_path = indice2fm_path

        # initialize index and indice2fm from saved files
        if os.path.exists(index_path):
            self.index = faiss.read_index(index_path)
            print(f"Loaded FAISS index from {index_path}")

        if os.path.exists(indice2fm_path):
            with open(indice2fm_path, "r") as file:
                self.indice2fm = json.load(file)
            # Recreate one FileManager per previously-indexed file.
            for file_path, _ in self.indice2fm.items():
                self.file_managers.append(
                    FileManager(
                        file_path=file_path,
                        index_truncation_config=index_truncation_config,
                    )
                )

    def is_indice_align(self):
        """Return True when the highest recorded indice matches the last FAISS row id."""
        last_index_id = self.index.ntotal - 1
        return last_index_id == max(max(values) for values in self.indice2fm.values())

    def save_index(self, index_path, indice2fm_path):
        """Persist the FAISS index and the file-path -> indices mapping to disk."""
        if self.index:
            os.makedirs(os.path.dirname(index_path), exist_ok=True)
            faiss.write_index(self.index, index_path)
            # also save file_path to indice mapping, self.indice2fm should be updated before calling this function
            with open(indice2fm_path, mode="w") as file:
                json.dump(self.indice2fm, file, indent=4)

    def delete_index(self):
        """Reset the in-memory index/mapping and remove their on-disk files."""
        self.index.reset()
        self.indice2fm = {}
        if os.path.exists(self.index_path):
            os.remove(self.index_path)
        if os.path.exists(self.indice2fm_path):
            os.remove(self.indice2fm_path)
        print("FAISS index deleted.")

    def upsert_file_to_faiss(
        self,
        file_manager,
        model="text-embedding-3-large",
        truncation_strategy: Optional[Union[str, bool]] = "fixed_length",
        truncate_by: Optional[str] = "\n",
    ):
        """Chunk, embed and add a file's texts to the index (no-op if present).

        Args:
            file_manager: FileManager wrapping the source document.
            model: OpenAI embedding model name.
            truncation_strategy: chunking strategy forwarded to FileManager.
            truncate_by: delimiter used when splitting by character.
        """
        # NOTE(review): the comprehension variable shadows the outer
        # ``file_manager`` parameter; behavior is correct but confusing.
        if not file_manager.file_path in [
            file_manager.file_path for file_manager in self.file_managers
        ]:
            self.file_managers.append(file_manager)
        else:
            print(f"File '{file_manager.file_path}' already exists in the FAISS index.")
            return

        # Process the file if necessary
        # TODO: check if file_manager.texts will in any case be empty, if not, remove the below block
        if not file_manager.texts:
            print("Processing documents...")
            file_manager.process_document(
                truncation_strategy=truncation_strategy, truncate_by=truncate_by
            )
            print("Documents processing done.")

        # Generate embeddings and append to index if not already present
        if not file_manager.file_path in self.indice2fm:
            print("Creating embedding for the document...")
            embeddings = self.openaiManager.create_openai_embeddings(
                file_manager.texts, model=model
            )

            # Normalize embeddings
            embeddings_np = self.normalize_embeddings(embeddings)
            start_index = self.index.ntotal
            # Add embeddings to FAISS index
            self.index.add(embeddings_np)
            end_index = self.index.ntotal
            added_indices = list(range(start_index, end_index))

            # Update the self.indice2fm dictionary
            self.indice2fm[file_manager.file_path] = added_indices
            self.save_index(
                index_path=self.index_path, indice2fm_path=self.indice2fm_path
            )
            print(
                f"Embeddings from file '{file_manager.file_path}' added to FAISS index between indice {start_index} to {end_index}."
            )
        else:
            print(f"File '{file_manager.file_path}' already exists in the FAISS index.")

    def normalize_embeddings(self, embeddings):
        """L2-normalize embeddings row-wise; raises ValueError on NaN/Inf."""
        if np.isnan(embeddings).any() or np.isinf(embeddings).any():
            raise ValueError("Embeddings contain NaNs or Infs.")
        embeddings_np = np.array(embeddings).astype("float32")
        # faiss normalize give error zsh: segmentation fault python faiss manager at some edge case in hotpotqa
        # faiss.normalize_L2(embeddings_np)
        embeddings_normalized = normalize(embeddings_np, norm='l2', axis=1)
        return embeddings_normalized

    def search_faiss_index(
        self,
        query,
        top_k=10,
        threshold=0.5,
        truncation_strategy: Optional[Union[str, bool]] = "fixed_length",
        truncate_by: Optional[str] = "\n",
    ):
        """Embed *query*, search the index and return formatted result strings.

        Results below *threshold* similarity are dropped.  Each hit is mapped
        back to its source file via ``indice2fm`` and rendered as
        "<text> indice=... fileposition=... score=...".
        """
        if self.index.ntotal == 0:
            return []

        # Create a normalized embedding for the query
        query_embedding = self.normalize_embeddings(
            [
                self.openaiManager.client.embeddings.create(
                    input=[query], model="text-embedding-3-large"
                )
                .data[0]
                .embedding
            ]
        )[0].reshape(1, -1)

        # Perform the search
        similarity, indices = self.index.search(query_embedding, top_k)
        filtered_results = [
            (idx, similar)
            for idx, similar in zip(indices[0], similarity[0])
            if similar >= threshold
        ]
        results = []

        # Reverse map indices to file paths and text
        for idx, dist in filtered_results:
            file_path_found = None
            relative_idx = None

            # Find the file_path and relative index using self.indice2fm
            for file_path, indice_list in self.indice2fm.items():
                if idx in indice_list:
                    file_path_found = file_path
                    relative_idx = indice_list.index(idx)
                    break

            if file_path_found is not None and relative_idx is not None:
                # Find the corresponding file_manager
                file_manager = next(
                    (
                        fm
                        for fm in self.file_managers
                        if fm.file_path == file_path_found
                    ),
                    None,
                )

                if file_manager:
                    # Process the file if necessary (no-op when texts are loaded)
                    file_manager.process_document(
                        truncation_strategy=truncation_strategy, truncate_by=truncate_by
                    )
                    # NOTE(review): bare except below swallows every error
                    # (including KeyboardInterrupt) — consider narrowing.
                    try:
                        # Get the text from the file_manager
                        text = file_manager.texts[relative_idx][
                            1
                        ]  # Assuming (index, text) tuples in file_manager.texts
                        results.append(
                            f"{text} indice={idx} fileposition={relative_idx} score={dist:.4f}"
                            # TODO reformat this
                            # {
                            #     "text": text,
                            #     "indice": idx,
                            #     "fileposition": relative_idx,
                            #     "score": round(dist, 4),
                            # }
                        )
                    except:
                        print(
                            f"Error while retriving id={relative_idx} from file manager. Skipping over id={relative_idx}."
                        )

                else:
                    results.append(
                        f"File manager not found for '{file_path_found}' score={dist:.4f}"
                    )
            else:
                # TODO reformat this
                results.append(f"Index not mapped, score={dist:.4f}")

        return results

    def parse_result(self, result):
        """
        Parse the result from the search and return the page content, metadata, indice, and score.
        """
        # Parse the input
        parsed_item = None
        pattern = re.compile(
            r"page_content='(.*?)'\smetadata=(\{.*?\})\sindice=(\d+)\sfileposition=(\d+)\sscore=([\d.]+)",
            re.DOTALL,
        )
        matches = pattern.findall(result)
        # assume only 1 row with matched pattern will be feed in each time, only remain last item
        for match in matches:
            page_content, metadata, indice, fileposition, score = match
            # Convert metadata string to a dictionary
            metadata_dict = ast.literal_eval(metadata)
            parsed_item = {
                "page_content": page_content.strip(),
                "metadata": metadata_dict,
                "indice": int(indice),
                "fileposition": int(fileposition),
                "score": float(score),
            }
        return parsed_item

    def generate_response_from_context(self, query, retrieved_docs, model="gpt-4o"):
        """Answer *query* with an OpenAI chat call grounded in *retrieved_docs*."""
        if not retrieved_docs:
            return "No relevant documents found in the FAISS index."

        # Process retrieved documents into a clean context
        formatted_docs = []
        for doc in retrieved_docs:
            try:
                # Split the document string into page_content and metadata
                doc_parts = doc.split("metadata=")
                page_content = doc_parts[0].replace("page_content=", "").strip()
                metadata = (
                    doc_parts[1].strip() if len(doc_parts) > 1 else "Unknown source"
                )

                # Format each document clearly
                formatted_doc = f"Content: {page_content}\nSource: {metadata}"
                formatted_docs.append(formatted_doc)
            except Exception as e:
                formatted_docs.append(f"Error processing document: {e}")

        # Combine the formatted documents into a single context
        context = "\n\n---\n\n".join(formatted_docs)

        # Construct the prompt for the OpenAI API
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant that answers questions based on provided context.",
            },
            {"role": "user", "content": query},
            {
                "role": "assistant",
                "content": f"The following context was retrieved from the database:\n\n{context}",
            },
        ]

        # Generate response using OpenAI Chat API
        response = self.openaiManager.client.chat.completions.create(
            model=model, messages=messages, max_tokens=4096, temperature=0.7
        )
        return response.choices[0].message.content
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def main():
    """Smoke-test: index two local PDFs and answer one query against them.

    Requires a configured .env (OpenAI credentials) and the PDFs under
    ./documents; performs network calls and writes the index store.
    """
    # FileManager and FAISSIndexManager both take a required
    # index_truncation_config; the original example omitted it, so main()
    # raised TypeError before doing any work.
    index_truncation_config = {"chunk_size": 2000, "chunk_overlap": 25}

    file_path1 = os.path.join(os.getcwd(), "documents", "2024_Corrective_RAGv2.pdf")
    file_manager1 = FileManager(file_path1, index_truncation_config)
    manager = FAISSIndexManager(index_truncation_config, dimension=3072)
    manager.upsert_file_to_faiss(file_manager1)

    file_path2 = os.path.join(os.getcwd(), "documents", "2023_Iterative_RGen.pdf")
    file_manager2 = FileManager(file_path2, index_truncation_config)
    manager.upsert_file_to_faiss(file_manager2)

    query = "tell me about corrective rag system."
    retrieved_docs = manager.search_faiss_index(query, top_k=10, threshold=0.1)
    print(retrieved_docs)
    response = manager.generate_response_from_context(query, retrieved_docs)
    print(response)


if __name__ == "__main__":
    print("Running faiss_manager.py")
    main()
|
src/common/file_manager.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
from typing import Union, Optional
|
| 4 |
+
from PyPDF2 import PdfReader
|
| 5 |
+
from langchain_core.documents import Document
|
| 6 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 7 |
+
from src.common.chunker import FixedLengthChunker
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FileManager:
    """Load, chunk and cache the texts of a single source document.

    Chunked texts are cached on disk next to the source file as
    ``<name>_texts.json`` and reloaded on construction when present.
    """

    def __init__(self, file_path: str, index_truncation_config: dict):
        """
        Args:
            file_path: path to the source document (PDF or JSON of title->texts).
            index_truncation_config: dict with "chunk_size" and "chunk_overlap".
        """
        self.file_path = file_path
        self.chunk_size = index_truncation_config["chunk_size"]
        self.chunk_overlap = index_truncation_config["chunk_overlap"]
        self.texts = []  # list of (index, str(Document)) tuples once processed
        directory = os.path.dirname(file_path)
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        self.texts_file = os.path.join(directory, f"{base_name}_texts.json")
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
        )  # TODO
 
        # Load texts from file if it exists
        if os.path.exists(self.texts_file):
            with open(self.texts_file, "r", encoding="utf-8-sig") as f:
                self.texts = json.load(f)
            print(f"Loaded texts from file: {self.texts_file}")

    def load_pdf_document(self):
        """Read the PDF at ``self.file_path`` into one Document per text page."""
        pdf_reader = PdfReader(self.file_path)
        documents = []

        for page_num, page in enumerate(pdf_reader.pages):
            page_text = page.extract_text()
            if page_text:  # Ensure the page has text
                document = Document(
                    metadata={"source": self.file_path, "page": page_num},
                    page_content=page_text,
                )
                documents.append(document)

        return documents

    def dump_documents(self, texts):
        """Persist chunked texts to the cache file (refuses to overwrite).

        NOTE(review): the else-branch also fires when *texts* is empty, in
        which case the FileExistsError message is misleading — confirm the
        intended behavior for empty input.
        """
        if texts and not os.path.exists(self.texts_file):
            with open(self.texts_file, "w") as f:
                json.dump(texts, f)
            print(f"Associated texts saved to file: {self.texts_file}")
        else:
            raise FileExistsError(
                f"File {self.texts_file} already exists. Please remove it before saving."
            )

    def process_pdf(self):
        """Load, split and cache a PDF as (index, str(Document)) tuples."""
        data = self.load_pdf_document()

        documents = self.text_splitter.split_documents(data)
        self.texts = [(i, str(doc)) for i, doc in enumerate(documents)]
        self.dump_documents(self.texts)

    def process_document(
        self,
        truncation_strategy: Optional[Union[str, bool]] = "fixed_length",
        chunk_size: int = 2000,
        overlap_size: int = 25,
        truncate_by: Optional[str] = "\n",
    ):
        """
        Process document according to the specified strategy.
        Either truncation_strategy or truncate_by must be provided, but not both.

        Reads ``self.file_path`` as JSON mapping title -> texts; no-op when
        texts were already loaded/processed.

        NOTE(review): the chunk_size/overlap_size parameters shadow the
        instance attributes set from index_truncation_config — confirm which
        values should win.
        """
        if truncation_strategy is None and truncate_by is None:
            raise ValueError(
                "Either truncation_strategy or truncate_by must be provided"
            )

        # Already chunked (loaded from cache or a previous call): nothing to do.
        if self.texts:
            return

        chunks = []

        with open(self.file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        for title, texts in data.items():
            if not truncation_strategy and not truncate_by:
                # No splitting: one document per title.
                chunks.append(self.create_document(title, texts, self.file_path))
                print(f"{title} - No text splitting. Chunk size: {len(texts)}")
            elif truncation_strategy == "fixed_length":
                chunk_list = []
                for text in texts:
                    fixed_length_chunks, texts_word_cnt = FixedLengthChunker(
                        text, chunk_size, overlap_size
                    ).create_chunks()
                    chunk_list.extend(fixed_length_chunks)
                print(
                    f"Document '{title}' is splitted into {len(chunk_list)} chunk(s) by length of {chunk_size} words. Initial text size: {texts_word_cnt}."
                )
                for text in chunk_list:
                    if text.strip():
                        chunks.append(self.create_document(title, text, self.file_path))
            elif truncation_strategy == "recursive":  # Fixed typo in strategy name
                raise NotImplementedError(
                    "Recursive truncation is currently not supported"
                )
            else:
                # print("splitting by specific char")
                # NOTE(review): split_texts is unbound when texts is neither
                # str nor list — a NameError would follow; confirm inputs.
                if isinstance(texts, str):
                    if truncate_by in texts:
                        split_texts = texts.split(truncate_by)
                    else:
                        split_texts = [texts]
                elif isinstance(texts, list):
                    split_texts = texts

                for text in split_texts:
                    if text.strip():
                        chunks.append(self.create_document(title, text, self.file_path))

        self.texts = [(i, str(doc)) for i, doc in enumerate(chunks)]
        self.dump_documents(self.texts)

    def create_document(self, title, text, file_path):
        """Create a document with the given title and text."""
        return Document(
            page_content=f"{title}: {text}",
            metadata={"source": title, "file_path": file_path},
        )
|
src/common/llm/__init__.py
ADDED
|
File without changes
|
src/common/llm/llm_agent.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
class LLMAgent(ABC):
    """Interface for LLM-backed agents.

    Implementations pre-process an incoming query, produce an answer, and
    post-process the raw model response.
    """

    @abstractmethod
    def answer(self, question) -> str:
        """Return an answer string for *question*."""

    @abstractmethod
    def preProcess(self, query):
        """Transform *query* before it is sent to the model."""

    @abstractmethod
    def postProcess(self, response):
        """Transform the raw model *response* before returning it."""
|