Upload 34 files
Browse files- .gitattributes +6 -0
- README.md +150 -3
- assets/overview.png +3 -0
- data/2wikimqa.jsonl +0 -0
- data/2wikimqa_antifact.jsonl +0 -0
- data/hotpotqa.jsonl +3 -0
- data/hotpotqa_antifact.jsonl +3 -0
- data/hotpotqa_random.jsonl +3 -0
- data/multifieldqa_en.jsonl +0 -0
- data/multifieldqa_en_antifact.jsonl +0 -0
- data/musique.jsonl +3 -0
- data/musique_antifact.jsonl +3 -0
- detect/contextleakage.py +194 -0
- detect/contextleakage_api.py +87 -0
- detect/question_rephrase_answer_api.py +153 -0
- detect/question_rephrase_answer_qwen3.py +235 -0
- detect/question_rephrase_answer_vllm.py +226 -0
- eval/__init__.py +2 -0
- eval/eval_with_api.py +94 -0
- eval/evaluation.py +115 -0
- main_gpu.py +427 -0
- random_alternative_answer.py +418 -0
- requirements.txt +38 -0
- training_result/training_loss_antifact_llama.csv +41 -0
- training_result/training_loss_antifact_qwen38.csv +41 -0
- training_result/training_loss_llama.csv +41 -0
- training_result/training_loss_phi4.csv +41 -0
- training_result/training_loss_phi4_antifact.csv +41 -0
- training_result/training_loss_qwen38.csv +41 -0
- utils/__init__.py +2 -0
- utils/convert.py +52 -0
- utils/draw.py +82 -0
- utils/llmjudge.py +49 -0
- utils/metrics.py +152 -0
- utils/util.py +104 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/overview.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/hotpotqa_antifact.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
data/hotpotqa_random.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
data/hotpotqa.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
data/musique_antifact.jsonl filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
data/musique.jsonl filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,150 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LastingBench: Defend Benchmarks Against Knowledge Leakage.
|
| 2 |
+
|
| 3 |
+
Welcome to the repository for the research paper: "LastingBench: Defend Benchmarks Against Knowledge Leakage." This project addresses the growing concern about large language models (LLMs) "cheating" on standard Question Answering (QA) benchmarks by memorizing task-specific data, which undermines the validity of benchmark evaluations as they no longer reflect genuine model capabilities but instead the effects of data leakage.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
LastingBench introduces a novel framework designed to continuously reinforce and safeguard existing benchmarks against knowledge leakage. The project aims to:
|
| 10 |
+
- **Detect knowledge leakage** through context and question perturbation techniques
|
| 11 |
+
- **Rewrite leaked content** to counterfactual alternatives that disrupt memorization while preserving the benchmark's original evaluative intent
|
| 12 |
+
- **Evaluate model responses** to contextual evidence and reasoning patterns
|
| 13 |
+
- **Provide practical solutions** to ensure benchmark robustness over time, promoting fairer and more interpretable evaluations of LLMs
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
## Installation
|
| 17 |
+
|
| 18 |
+
1. Clone the repository:
|
| 19 |
+
```bash
|
| 20 |
+
git clone https://github.com/Seriousss/lastingbench
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
2. Create and activate conda environment:
|
| 24 |
+
```bash
|
| 25 |
+
conda create -n lastingbench python=3.12
|
| 26 |
+
conda activate lastingbench
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
3. Install dependencies:
|
| 30 |
+
```bash
|
| 31 |
+
pip install -r requirements.txt
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
4. Set up environment variables:
|
| 35 |
+
```bash
|
| 36 |
+
export OPENAI_BASE_URL="your-api-base-url"
|
| 37 |
+
export OPENAI_API_KEY="your-api-key"
|
| 38 |
+
export CUDA_VISIBLE_DEVICES="0,1,2,3" # Adjust based on your GPU setup
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Usage
|
| 42 |
+
|
| 43 |
+
LastingBench provides three main functionalities: **Detection**, **Rewrite**, and **Training Comparision**.
|
| 44 |
+
|
| 45 |
+
### 🔍 Detection
|
| 46 |
+
|
| 47 |
+
Detect knowledge leakage through various perturbation techniques.
|
| 48 |
+
|
| 49 |
+
#### 1. Context Leakage Detection
|
| 50 |
+
Evaluate models using exact-match scoring on benchmark datasets:
|
| 51 |
+
```bash
|
| 52 |
+
# Using vLLM for most models
|
| 53 |
+
python -m detect.contextleakage --hf_model "Qwen/Qwen2.5-7B-Instruct" \
|
| 54 |
+
--dataset_subset "hotpotqa" --cuda_devices "0,1"
|
| 55 |
+
|
| 56 |
+
# Using Transformers for Qwen3 models
|
| 57 |
+
python -m detect.contextleakage --hf_model "Qwen/Qwen3-8B" \
|
| 58 |
+
--is_qwen3 --max_new_tokens 30
|
| 59 |
+
|
| 60 |
+
python -m detect.contextleakage_api --model "deepseek-r1" --dataset_subset "hotpotqa"
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
#### 2. Question Perturbation Detection
|
| 65 |
+
Rephrase questions to opposite meanings and test model consistency:
|
| 66 |
+
```bash
|
| 67 |
+
# Using OpenAI API
|
| 68 |
+
python -m detect.question_rephrase_answer_api \
|
| 69 |
+
--model_name "gpt-4o" --dataset_subset "2wikimqa" \
|
| 70 |
+
--rephrase_type "opposite" --sample_count 100
|
| 71 |
+
|
| 72 |
+
# Using local vLLM models
|
| 73 |
+
python -m detect.question_rephrase_answer_vllm \
|
| 74 |
+
--model_name "Qwen/Qwen2.5-7B-Instruct" --dataset_subset "hotpotqa" --rephrase_type "similar"
|
| 75 |
+
|
| 76 |
+
# Using Qwen3 with Transformers
|
| 77 |
+
python -m detect.question_rephrase_answer_qwen3 \
|
| 78 |
+
--model_name "Qwen/Qwen3-8B" --dataset_subset "2wikimqa"
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
### ✏️ Rewrite
|
| 83 |
+
|
| 84 |
+
Generate counterfactual answers and rewrite leaked evidence to create robust benchmarks.
|
| 85 |
+
`
|
| 86 |
+
|
| 87 |
+
#### 1. Evidence Finding and Counterfactual Rewriting Pipeline
|
| 88 |
+
Run the complete finding and rewriting pipeline:
|
| 89 |
+
```bash
|
| 90 |
+
|
| 91 |
+
# Specify custom output file and dataset
|
| 92 |
+
python main_gpu.py --output custom_output.jsonl \
|
| 93 |
+
--dataset_subset "hotpotqa" --start_idx 0 --max_samples 100
|
| 94 |
+
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
Convert and merge JSONL files with question-answer mappings:
|
| 98 |
+
```bash
|
| 99 |
+
# Merge single mapping file with original dataset
|
| 100 |
+
python utils/convert.py original.jsonl revised.jsonl custom_output.jsonl
|
| 101 |
+
|
| 102 |
+
```
|
| 103 |
+
The original and revised dataset can be found under the **data** folder.
|
| 104 |
+
|
| 105 |
+
#### 2. Random Answer Rewriting
|
| 106 |
+
Create random alternatives to disrupt memorization:
|
| 107 |
+
```bash
|
| 108 |
+
# Specify custom output file and dataset
|
| 109 |
+
python random_alternative_answer.py --output random_hotpot.jsonl \
|
| 110 |
+
--dataset_subset "hotpotqa" --start_idx 0 --max_samples 50
|
| 111 |
+
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
### 🚀Dataset evaluations on model inference and training
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
#### 1. Model Inference Evaluation
|
| 119 |
+
Comprehensive evaluation on original and revised benchmarks:
|
| 120 |
+
```bash
|
| 121 |
+
# Transformers-based evaluation
|
| 122 |
+
python -m eval.evaluation -i data/hotpotqa.jsonl -model "Qwen/Qwen3-8B" -k 40 -t 0.5
|
| 123 |
+
|
| 124 |
+
# API-based evaluation
|
| 125 |
+
python -m eval.eval_with_api.py --input data/hotpotqa_antifact.jsonl \
|
| 126 |
+
--model "deepseek-r1" --max_tokens 30 --temperature 0.5
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
#### 2. Model training Evaluation
|
| 130 |
+
Compare training dynamics between original and rewritten datasets:
|
| 131 |
+
|
| 132 |
+
The training loss data can be found under **training_result**.
|
| 133 |
+
|
| 134 |
+
To repoduce the picture in our paper:
|
| 135 |
+
```bash
|
| 136 |
+
python utils/draw.py training_result/training_loss_qwen38.csv training_result/training_loss_antifact_qwen38.csv \
|
| 137 |
+
--title "Original vs Rewritten Training Loss"
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
### 📊 Utility Functions
|
| 143 |
+
|
| 144 |
+
Additional tools for analysis and metrics:
|
| 145 |
+
|
| 146 |
+
- **Metrics Calculation**: F1 scores, EM scores, and custom evaluation metrics
|
| 147 |
+
- **Document Retrieval**: BM25-based retrieval for evidence analysis
|
| 148 |
+
|
| 149 |
+
All scripts support various parameters for customization. Use `--help` with any script to see available options.
|
| 150 |
+
|
assets/overview.png
ADDED
|
Git LFS Details
|
data/2wikimqa.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/2wikimqa_antifact.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/hotpotqa.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0005ab2a1bc2ac3a70352dccbf96cccc4e0aac6bb677f6a55180fa51b92ef6f
|
| 3 |
+
size 11483614
|
data/hotpotqa_antifact.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94d00ea5c5f14ce5c0638bc6c66658e93425d87f377bf1d3da4a004fd15fcb6f
|
| 3 |
+
size 11447185
|
data/hotpotqa_random.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0d2072c7caac2fdc462c0777bbcb38b41292974383db46a2fedb4c3b86f3c825
|
| 3 |
+
size 11461405
|
data/multifieldqa_en.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/multifieldqa_en_antifact.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/musique.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ac69b91281c4ec6b21316cb7282e83fb6b4dda04fc68480bb8d8ed1e19ff7bd
|
| 3 |
+
size 14085077
|
data/musique_antifact.jsonl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e80355c61ec5d4f79e2d3610fa9c25156746079bff21268932ae9cc8d23acdfc
|
| 3 |
+
size 14057004
|
detect/contextleakage.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Evaluate models on a LongBench subset with Exact-Match (EM).
|
| 5 |
+
Supports both Qwen3 (Transformers) and other models (vLLM).
|
| 6 |
+
|
| 7 |
+
Requirements
|
| 8 |
+
------------
|
| 9 |
+
pip install vllm datasets tqdm transformers accelerate
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import argparse, logging, time, torch
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from utils.metrics import qa_em_score
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
# ---------------------------- CLI ------------------------------------
|
| 21 |
+
parser = argparse.ArgumentParser()
|
| 22 |
+
parser.add_argument("--hf_model",
|
| 23 |
+
default="Qwen/Qwen3-8B-Instruct",
|
| 24 |
+
help="Model name or local path")
|
| 25 |
+
parser.add_argument("--is_qwen3", action="store_true",
|
| 26 |
+
help="Set this flag if using Qwen3 model (uses Transformers). Otherwise uses vLLM.")
|
| 27 |
+
parser.add_argument("--max_new_tokens", type=int, default=20)
|
| 28 |
+
parser.add_argument("--max_tokens", type=int, default=20,
|
| 29 |
+
help="For vLLM models (ignored if --is_qwen3)")
|
| 30 |
+
parser.add_argument("--temperature", type=float, default=0.0)
|
| 31 |
+
parser.add_argument("--top_p", type=float, default=1.0)
|
| 32 |
+
parser.add_argument("--tensor_parallel_size", type=int, default=2,
|
| 33 |
+
help="GPU parallel size for vLLM (ignored if --is_qwen3)")
|
| 34 |
+
|
| 35 |
+
parser.add_argument("--dataset_repo", default="THUDM/LongBench")
|
| 36 |
+
parser.add_argument("--dataset_subset", default="hotpotqa")
|
| 37 |
+
parser.add_argument("--split", default="test")
|
| 38 |
+
parser.add_argument("--sleep", type=float, default=0.0)
|
| 39 |
+
parser.add_argument("--log", default="summary.log")
|
| 40 |
+
parser.add_argument("--cuda_devices", default="1,6",
|
| 41 |
+
help="CUDA visible devices")
|
| 42 |
+
args = parser.parse_args()
|
| 43 |
+
|
| 44 |
+
# Set CUDA devices
|
| 45 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices
|
| 46 |
+
|
| 47 |
+
# --------------------------- logging ---------------------------------
|
| 48 |
+
logging.basicConfig(
|
| 49 |
+
filename=args.log,
|
| 50 |
+
level=logging.INFO,
|
| 51 |
+
format="%(asctime)s - %(message)s",
|
| 52 |
+
filemode="a",
|
| 53 |
+
)
|
| 54 |
+
logging.getLogger().addHandler(logging.StreamHandler())
|
| 55 |
+
|
| 56 |
+
# ------------------------- dataset -----------------------------------
|
| 57 |
+
ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split)
|
| 58 |
+
total = len(ds)
|
| 59 |
+
logging.info("Loaded %d samples from %s/%s[%s]",
|
| 60 |
+
total, args.dataset_repo, args.dataset_subset, args.split)
|
| 61 |
+
|
| 62 |
+
if args.is_qwen3:
|
| 63 |
+
# ---------------------- Qwen3 with Transformers ----------------------------
|
| 64 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 65 |
+
|
| 66 |
+
load_kwargs = dict(
|
| 67 |
+
trust_remote_code=True,
|
| 68 |
+
device_map="auto",
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
tokenizer = AutoTokenizer.from_pretrained(args.hf_model,
|
| 72 |
+
trust_remote_code=True)
|
| 73 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 74 |
+
args.hf_model,
|
| 75 |
+
torch_dtype=torch.float16,
|
| 76 |
+
**load_kwargs
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
EOS_ID = tokenizer.eos_token_id
|
| 80 |
+
THINK_ENDID = 151668 # </think> token id
|
| 81 |
+
|
| 82 |
+
gen_kwargs = dict(
|
| 83 |
+
max_new_tokens=args.max_new_tokens,
|
| 84 |
+
temperature=args.temperature,
|
| 85 |
+
top_p=args.top_p,
|
| 86 |
+
do_sample=args.temperature > 0,
|
| 87 |
+
eos_token_id=EOS_ID,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
# -------------------------- Qwen3 loop -------------------------------------
|
| 91 |
+
correct_em = 0
|
| 92 |
+
|
| 93 |
+
for ex in tqdm(ds, desc="Evaluating with Transformers (Qwen3)"):
|
| 94 |
+
q = ex["input"]
|
| 95 |
+
golds = ex["answers"]
|
| 96 |
+
|
| 97 |
+
msgs = [
|
| 98 |
+
{"role": "system", "content": "You are a QA assistant."},
|
| 99 |
+
{"role": "user",
|
| 100 |
+
"content": f"Question: {q}\n"
|
| 101 |
+
"Please reply with *only* the final answer—no extra words."}
|
| 102 |
+
]
|
| 103 |
+
prompt = tokenizer.apply_chat_template(
|
| 104 |
+
msgs,
|
| 105 |
+
tokenize=False,
|
| 106 |
+
add_generation_prompt=True,
|
| 107 |
+
enable_thinking=False # Qwen3 thinking mode
|
| 108 |
+
)
|
| 109 |
+
inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
|
| 110 |
+
|
| 111 |
+
with torch.no_grad():
|
| 112 |
+
outs = model.generate(**inputs, **gen_kwargs)[0]
|
| 113 |
+
|
| 114 |
+
# Extract newly generated tokens
|
| 115 |
+
new_ids = outs[len(inputs.input_ids[0]):].tolist()
|
| 116 |
+
|
| 117 |
+
# Find </think> (if not exist idx=0)
|
| 118 |
+
try:
|
| 119 |
+
idx = len(new_ids) - new_ids[::-1].index(THINK_ENDID)
|
| 120 |
+
except ValueError:
|
| 121 |
+
idx = 0
|
| 122 |
+
|
| 123 |
+
content = tokenizer.decode(new_ids[idx:],
|
| 124 |
+
skip_special_tokens=True).strip("\n").strip()
|
| 125 |
+
|
| 126 |
+
# Only use content for EM comparison
|
| 127 |
+
if any(qa_em_score(content, g) for g in golds):
|
| 128 |
+
correct_em += 1
|
| 129 |
+
|
| 130 |
+
if args.sleep:
|
| 131 |
+
time.sleep(args.sleep)
|
| 132 |
+
|
| 133 |
+
else:
|
| 134 |
+
# ---------------------- Other models with vLLM ----------------------------
|
| 135 |
+
from vllm import LLM, SamplingParams
|
| 136 |
+
|
| 137 |
+
# Initialize vLLM
|
| 138 |
+
llm = LLM(
|
| 139 |
+
model=args.hf_model,
|
| 140 |
+
tensor_parallel_size=args.tensor_parallel_size,
|
| 141 |
+
)
|
| 142 |
+
sampler = SamplingParams(
|
| 143 |
+
temperature=args.temperature,
|
| 144 |
+
max_tokens=args.max_tokens,
|
| 145 |
+
top_p=args.top_p,
|
| 146 |
+
stop=["</assistant>", "</s>", "<|end_of_text|>"],
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# -------------------------- vLLM loop -------------------------------------
|
| 150 |
+
correct_em = 0
|
| 151 |
+
|
| 152 |
+
for ex in tqdm(ds, desc="Evaluating with vLLM"):
|
| 153 |
+
question = ex["input"]
|
| 154 |
+
golds = ex["answers"] # list[str]
|
| 155 |
+
|
| 156 |
+
chat_params = SamplingParams(
|
| 157 |
+
temperature=args.temperature,
|
| 158 |
+
max_tokens=args.max_tokens,
|
| 159 |
+
top_p=args.top_p,
|
| 160 |
+
stop=["</s>", "<|end_of_text|>"], # Safety stop tokens
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
messages = [
|
| 164 |
+
{"role": "system",
|
| 165 |
+
"content": "You are a QA assistant."},
|
| 166 |
+
{"role": "user",
|
| 167 |
+
"content": f"Question: {question}\n"
|
| 168 |
+
"Please first reply with *only* the final answer—no extra words.\n Answer:"}
|
| 169 |
+
]
|
| 170 |
+
|
| 171 |
+
result = llm.chat(messages, sampling_params=chat_params)
|
| 172 |
+
# vLLM returns list[RequestOutput]; take first output's first candidate
|
| 173 |
+
pred = result[0].outputs[0].text.strip()
|
| 174 |
+
print(f"A: {pred}\nG: {golds}\n")
|
| 175 |
+
|
| 176 |
+
if any(qa_em_score(pred, g) for g in golds):
|
| 177 |
+
correct_em += 1
|
| 178 |
+
|
| 179 |
+
if args.sleep:
|
| 180 |
+
time.sleep(args.sleep)
|
| 181 |
+
|
| 182 |
+
# -------------------------- result -----------------------------------
|
| 183 |
+
em = correct_em / total
|
| 184 |
+
model_type = "Qwen3 (Transformers)" if args.is_qwen3 else "vLLM"
|
| 185 |
+
logging.info("RESULT | model=%s | type=%s | subset=%s | EM=%.4f",
|
| 186 |
+
args.hf_model, model_type, args.dataset_subset, em)
|
| 187 |
+
print(
|
| 188 |
+
f"\n=== SUMMARY ===\n"
|
| 189 |
+
f"Model : {args.hf_model}\n"
|
| 190 |
+
f"Type : {model_type}\n"
|
| 191 |
+
f"Subset : {args.dataset_subset} ({args.split})\n"
|
| 192 |
+
f"EM : {em:.4f}\n"
|
| 193 |
+
f"(Log in {Path(args.log).resolve()})"
|
| 194 |
+
)
|
detect/contextleakage_api.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, time, argparse, logging
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from openai import OpenAI
|
| 4 |
+
from tqdm import tqdm
|
| 5 |
+
from utils.metrics import qa_em_score
|
| 6 |
+
|
| 7 |
+
# ----------------------------------------------------------------------
|
| 8 |
+
# CLI
|
| 9 |
+
# ----------------------------------------------------------------------
|
| 10 |
+
parser = argparse.ArgumentParser()
|
| 11 |
+
parser.add_argument("--model", default="gpt-4o")
|
| 12 |
+
parser.add_argument("--dataset_repo", default="THUDM/LongBench")
|
| 13 |
+
parser.add_argument("--dataset_subset", default="hotpotqa")
|
| 14 |
+
parser.add_argument("--split", default="test")
|
| 15 |
+
parser.add_argument("--max_tokens", type=int, default=30)
|
| 16 |
+
parser.add_argument("--temperature", type=float, default=0.0)
|
| 17 |
+
parser.add_argument("--sleep", type=float, default=0.5,
|
| 18 |
+
help="seconds to wait between requests")
|
| 19 |
+
parser.add_argument("--log", default="summary.log",
|
| 20 |
+
help="append overall score here")
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
# ----------------------------------------------------------------------
|
| 24 |
+
# Logging (append mode)
|
| 25 |
+
# ----------------------------------------------------------------------
|
| 26 |
+
logging.basicConfig(
|
| 27 |
+
filename=args.log,
|
| 28 |
+
level=logging.INFO,
|
| 29 |
+
format="%(asctime)s - %(message)s",
|
| 30 |
+
filemode="a",
|
| 31 |
+
)
|
| 32 |
+
console = logging.StreamHandler()
|
| 33 |
+
console.setLevel(logging.INFO)
|
| 34 |
+
logging.getLogger().addHandler(console)
|
| 35 |
+
|
| 36 |
+
# ----------------------------------------------------------------------
|
| 37 |
+
# OpenAI client
|
| 38 |
+
# ----------------------------------------------------------------------
|
| 39 |
+
client = OpenAI(
|
| 40 |
+
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 41 |
+
base_url=os.environ.get("OPENAI_BASE_URL")
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# ----------------------------------------------------------------------
|
| 45 |
+
# Load dataset
|
| 46 |
+
# ----------------------------------------------------------------------
|
| 47 |
+
ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split)
|
| 48 |
+
total = len(ds)
|
| 49 |
+
logging.info("Loaded %d samples from %s/%s[%s]",
|
| 50 |
+
total, args.dataset_repo, args.dataset_subset, args.split)
|
| 51 |
+
|
| 52 |
+
# ----------------------------------------------------------------------
|
| 53 |
+
# Evaluation loop
|
| 54 |
+
# ----------------------------------------------------------------------
|
| 55 |
+
correct_em = 0
|
| 56 |
+
|
| 57 |
+
for ex in tqdm(ds, desc="Evaluating"):
|
| 58 |
+
question = ex["input"]
|
| 59 |
+
golds = ex["answers"]
|
| 60 |
+
|
| 61 |
+
resp = client.chat.completions.create(
|
| 62 |
+
model=args.model,
|
| 63 |
+
messages=[
|
| 64 |
+
{"role": "system", "content": "You are a QA assistant."},
|
| 65 |
+
{"role": "user",
|
| 66 |
+
"content": f"Question: {question}\n"
|
| 67 |
+
"Please first reply with *only* the final answer—no extra words.\n Answer:"}
|
| 68 |
+
],
|
| 69 |
+
temperature=args.temperature,
|
| 70 |
+
max_tokens=args.max_tokens,
|
| 71 |
+
)
|
| 72 |
+
pred = resp.choices[0].message.content.strip()
|
| 73 |
+
print(f"A: {pred}\n G: {golds}")
|
| 74 |
+
|
| 75 |
+
if any(qa_em_score(pred, g) for g in golds):
|
| 76 |
+
correct_em += 1
|
| 77 |
+
|
| 78 |
+
time.sleep(args.sleep)
|
| 79 |
+
|
| 80 |
+
em_score = correct_em / total
|
| 81 |
+
logging.info("RESULT | model=%s | subset=%s | EM=%.4f",
|
| 82 |
+
args.model, args.dataset_subset, em_score)
|
| 83 |
+
|
| 84 |
+
print(f"\n=== SUMMARY ===\nModel : {args.model}"
|
| 85 |
+
f"\nDataset : {args.dataset_subset} ({args.split})"
|
| 86 |
+
f"\nEM : {em_score:.4f}\n"
|
| 87 |
+
f"(Appended to {args.log})")
|
detect/question_rephrase_answer_api.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from utils.metrics import qa_f1_score, qa_em_score # Import evaluation functions
|
| 9 |
+
|
| 10 |
+
# Configure OpenAI API
|
| 11 |
+
client = OpenAI(
|
| 12 |
+
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 13 |
+
base_url=os.environ.get("OPENAI_BASE_URL")
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
def get_openai_response(prompt, model="gpt-4o", retries=3, delay=2):
|
| 17 |
+
"""Call OpenAI API to get response with retry mechanism"""
|
| 18 |
+
for attempt in range(retries):
|
| 19 |
+
try:
|
| 20 |
+
completion = client.chat.completions.create(
|
| 21 |
+
model=model,
|
| 22 |
+
messages=[{'role': 'user', 'content': prompt}],
|
| 23 |
+
max_tokens=100
|
| 24 |
+
)
|
| 25 |
+
return completion.choices[0].message.content.strip()
|
| 26 |
+
except Exception as e:
|
| 27 |
+
print(f"Attempt {attempt + 1} failed: {e}")
|
| 28 |
+
if attempt < retries - 1:
|
| 29 |
+
print(f"Retrying in {delay} seconds...")
|
| 30 |
+
time.sleep(delay)
|
| 31 |
+
else:
|
| 32 |
+
print("Max retries reached. Skipping this request.")
|
| 33 |
+
return "Failed to get response"
|
| 34 |
+
|
| 35 |
+
def rephrase_question_api(question, model_name, rephrase_type="opposite"):
|
| 36 |
+
"""Use OpenAI API to rephrase question (English prompt)"""
|
| 37 |
+
if rephrase_type == "opposite":
|
| 38 |
+
prompt = f"""Please rephrase the following question to have the exact opposite meaning.
|
| 39 |
+
Question: {question}
|
| 40 |
+
|
| 41 |
+
Return only the rephrased question with the opposite meaning, without any explanations or other content."""
|
| 42 |
+
elif rephrase_type == "similar":
|
| 43 |
+
prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
|
| 44 |
+
Question: {question}
|
| 45 |
+
|
| 46 |
+
Return only the rephrased question, without any explanations or other content."""
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")
|
| 49 |
+
|
| 50 |
+
return get_openai_response(prompt, model=model_name)
|
| 51 |
+
|
| 52 |
+
def answer_question_with_context_api(question, context, model_name, max_tokens_for_answer=30):
|
| 53 |
+
"""Use OpenAI API to answer question based on context (English prompt)"""
|
| 54 |
+
prompt = f"""Please answer the question based on the following context:
|
| 55 |
+
|
| 56 |
+
Context:
|
| 57 |
+
{context}
|
| 58 |
+
|
| 59 |
+
Question: {question}
|
| 60 |
+
|
| 61 |
+
Only output the answer, no any other text. If the answer is not in the context, please say "I don't know".
|
| 62 |
+
|
| 63 |
+
Answer:"""
|
| 64 |
+
try:
|
| 65 |
+
completion = client.chat.completions.create(
|
| 66 |
+
model=model_name,
|
| 67 |
+
messages=[{'role': 'user', 'content': prompt}],
|
| 68 |
+
max_tokens=max_tokens_for_answer
|
| 69 |
+
)
|
| 70 |
+
return completion.choices[0].message.content.strip()
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"Answer generation failed for model {model_name}: {e}")
|
| 73 |
+
return "Failed to get answer"
|
| 74 |
+
|
| 75 |
+
def main(args):
|
| 76 |
+
# Load dataset
|
| 77 |
+
print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
|
| 78 |
+
try:
|
| 79 |
+
dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
|
| 80 |
+
print(f"Successfully loaded dataset with {len(dataset)} samples.")
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"Failed to load dataset: {e}")
|
| 83 |
+
return
|
| 84 |
+
|
| 85 |
+
em_match_count = 0 # Counter for EM matches
|
| 86 |
+
successfully_processed_samples = 0 # Counter for successfully processed samples
|
| 87 |
+
|
| 88 |
+
num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))
|
| 89 |
+
|
| 90 |
+
print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o (opposite meaning). Answering with {args.model_name} (max 30 tokens for answer)...")
|
| 91 |
+
|
| 92 |
+
for i in tqdm(range(num_samples_to_process), desc="Processing samples"):
|
| 93 |
+
example = dataset[i]
|
| 94 |
+
original_question = example['input']
|
| 95 |
+
context = example['context']
|
| 96 |
+
ground_truth_answers = example['answers']
|
| 97 |
+
|
| 98 |
+
print(f"Original question: {original_question}")
|
| 99 |
+
|
| 100 |
+
# Use API to rephrase question, fixed using gpt-4o
|
| 101 |
+
rephrased_question = rephrase_question_api(original_question, "gpt-4o", args.rephrase_type)
|
| 102 |
+
print(f"Rephrased question (opposite): {rephrased_question}")
|
| 103 |
+
|
| 104 |
+
if rephrased_question == "Failed to get response" or rephrased_question == "Failed to rephrase question": # Broader check
|
| 105 |
+
print(f"Skipping sample {i+1} due to rephrasing failure.")
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
# Use rephrased question and context to get answer, using args.model_name, answer length limited to 30 tokens
|
| 109 |
+
rephrased_answer = answer_question_with_context_api(rephrased_question, context, args.model_name, max_tokens_for_answer=30)
|
| 110 |
+
# print(f"Answer to rephrased question: {rephrased_answer}")
|
| 111 |
+
|
| 112 |
+
if rephrased_answer == "Failed to get answer":
|
| 113 |
+
print(f"Skipping sample {i+1} due to answer generation failure.")
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
if not ground_truth_answers:
|
| 117 |
+
print(f"Skipping sample {i+1} due to missing ground truth answers.")
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
successfully_processed_samples += 1
|
| 121 |
+
sample_had_em_match = False
|
| 122 |
+
for gt_ans in ground_truth_answers:
|
| 123 |
+
em = qa_em_score(rephrased_answer, gt_ans)
|
| 124 |
+
if em > 0: # EM is 1.0 for a match
|
| 125 |
+
sample_had_em_match = True
|
| 126 |
+
break
|
| 127 |
+
|
| 128 |
+
if sample_had_em_match:
|
| 129 |
+
em_match_count += 1
|
| 130 |
+
# print(f"Sample EM with original GT: {1 if sample_had_em_match else 0}")
|
| 131 |
+
|
| 132 |
+
if successfully_processed_samples > 0:
|
| 133 |
+
print(f"\n--- Evaluation Summary ---")
|
| 134 |
+
print(f"Answering Model : {args.model_name}")
|
| 135 |
+
print(f"Dataset : {args.dataset_name} ({args.dataset_subset})")
|
| 136 |
+
print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
|
| 137 |
+
print(f"Max Answer Tokens: 30")
|
| 138 |
+
print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
|
| 139 |
+
else:
|
| 140 |
+
print("\nNo samples were processed adequately to provide an evaluation summary.")
|
| 141 |
+
|
| 142 |
+
print("Processing complete!")
|
| 143 |
+
|
| 144 |
+
if __name__ == "__main__":
|
| 145 |
+
parser = argparse.ArgumentParser(description="Rephrase questions to opposite meaning with GPT-4o, answer with specified OpenAI model, then count EM against original GT.")
|
| 146 |
+
parser.add_argument("--model_name", type=str, default="gpt-4o", help="Name of the OpenAI model to use for Answering.")
|
| 147 |
+
parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
|
| 148 |
+
parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
|
| 149 |
+
parser.add_argument("--sample_count", type=int, default=-1, help="Number of samples to process. -1 for all samples.")
|
| 150 |
+
parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")
|
| 151 |
+
|
| 152 |
+
args = parser.parse_args()
|
| 153 |
+
main(args)
|
detect/question_rephrase_answer_qwen3.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
import torch
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
+
from openai import OpenAI # Added for GPT-4o rephrasing
|
| 9 |
+
from utils.metrics import qa_f1_score, qa_em_score
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
THINK_END_ID = 151668 # </think> token ID for Qwen models (like Qwen1.5/Qwen2)
|
| 13 |
+
|
| 14 |
+
# --- OpenAI Client for Rephrasing ---
|
| 15 |
+
openai_client = OpenAI(
|
| 16 |
+
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 17 |
+
base_url=os.environ.get("OPENAI_BASE_URL")
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
def get_openai_rephrase_response(prompt, model="gpt-4o", retries=3, delay=2):
|
| 21 |
+
"""Call OpenAI API for rephrasing."""
|
| 22 |
+
for attempt in range(retries):
|
| 23 |
+
try:
|
| 24 |
+
completion = openai_client.chat.completions.create(
|
| 25 |
+
model=model,
|
| 26 |
+
messages=[{'role': 'user', 'content': prompt}],
|
| 27 |
+
max_tokens=100
|
| 28 |
+
)
|
| 29 |
+
return completion.choices[0].message.content.strip()
|
| 30 |
+
except Exception as e:
|
| 31 |
+
print(f"OpenAI Rephrase attempt {attempt + 1} failed: {e}")
|
| 32 |
+
if attempt < retries - 1:
|
| 33 |
+
print(f"Retrying OpenAI rephrase in {delay} seconds...")
|
| 34 |
+
time.sleep(delay)
|
| 35 |
+
else:
|
| 36 |
+
print("Max retries for OpenAI rephrase reached.")
|
| 37 |
+
return "Failed to rephrase question"
|
| 38 |
+
|
| 39 |
+
def rephrase_question_with_gpt4o(question, rephrase_type="opposite"):
|
| 40 |
+
if rephrase_type == "opposite":
|
| 41 |
+
prompt = f"""Please rephrase the following question to have the exact opposite meaning.
|
| 42 |
+
Question: {question}
|
| 43 |
+
|
| 44 |
+
Return only the rephrased question with the opposite meaning, without any explanations or other content."""
|
| 45 |
+
elif rephrase_type == "similar":
|
| 46 |
+
prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
|
| 47 |
+
Question: {question}
|
| 48 |
+
|
| 49 |
+
Return only the rephrased question, without any explanations or other content."""
|
| 50 |
+
else:
|
| 51 |
+
raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")
|
| 52 |
+
|
| 53 |
+
return get_openai_rephrase_response(prompt)
|
| 54 |
+
|
| 55 |
+
# --- Qwen3-Specific Hugging Face Model Functions (for Answering) ---
|
| 56 |
+
def get_qwen3_hf_response(prompt_text, model, tokenizer, device, max_new_tokens=40, retries=2, delay=5):
|
| 57 |
+
"""Generate a response from a Qwen3-like HF model. max_new_tokens default to 30."""
|
| 58 |
+
for attempt in range(retries):
|
| 59 |
+
try:
|
| 60 |
+
messages = [{"role": "user", 'content': prompt_text}]
|
| 61 |
+
|
| 62 |
+
chat_template_args = {
|
| 63 |
+
"tokenize": False,
|
| 64 |
+
"add_generation_prompt": True
|
| 65 |
+
}
|
| 66 |
+
# Qwen models (like Qwen1.5, Qwen2) often use/support enable_thinking
|
| 67 |
+
# Check if tokenizer's apply_chat_template supports 'enable_thinking'
|
| 68 |
+
# This check is simplified; for robust production, inspect.signature might be better
|
| 69 |
+
# but for Qwen-specific, we assume it or it gracefully ignores.
|
| 70 |
+
try:
|
| 71 |
+
# Attempt to use enable_thinking=False for Qwen models
|
| 72 |
+
processed_prompt = tokenizer.apply_chat_template(
|
| 73 |
+
messages, **chat_template_args, enable_thinking=False
|
| 74 |
+
)
|
| 75 |
+
except TypeError:
|
| 76 |
+
# Fallback if enable_thinking is not a valid kwarg for the specific tokenizer version
|
| 77 |
+
print("Warning: Tokenizer does not support 'enable_thinking' in apply_chat_template. Proceeding without it.")
|
| 78 |
+
processed_prompt = tokenizer.apply_chat_template(messages, **chat_template_args)
|
| 79 |
+
except Exception as e:
|
| 80 |
+
print(f"Warning: Error applying chat template: {e}. Using raw prompt.")
|
| 81 |
+
processed_prompt = prompt_text # Fallback to raw prompt
|
| 82 |
+
|
| 83 |
+
inputs = tokenizer(processed_prompt, return_tensors="pt", padding=True, truncation=True).to(device)
|
| 84 |
+
|
| 85 |
+
generated_ids_full = model.generate(
|
| 86 |
+
inputs.input_ids,
|
| 87 |
+
attention_mask=inputs.attention_mask,
|
| 88 |
+
max_new_tokens=max_new_tokens,
|
| 89 |
+
pad_token_id=tokenizer.eos_token_id
|
| 90 |
+
)
|
| 91 |
+
# Get only newly generated tokens
|
| 92 |
+
output_only_ids_list = generated_ids_full[0][inputs.input_ids.shape[1]:].tolist()
|
| 93 |
+
|
| 94 |
+
# Strip <think>...</think> tags specifically for Qwen
|
| 95 |
+
try:
|
| 96 |
+
# Find the last occurrence of THINK_END_ID and take tokens after it
|
| 97 |
+
cut_index = len(output_only_ids_list) - output_only_ids_list[::-1].index(THINK_END_ID)
|
| 98 |
+
final_ids_to_decode = output_only_ids_list[cut_index:]
|
| 99 |
+
except ValueError:
|
| 100 |
+
# THINK_END_ID not found, use all generated new tokens
|
| 101 |
+
final_ids_to_decode = output_only_ids_list
|
| 102 |
+
|
| 103 |
+
response = tokenizer.decode(final_ids_to_decode, skip_special_tokens=True).strip()
|
| 104 |
+
return response
|
| 105 |
+
except Exception as e:
|
| 106 |
+
print(f"Qwen HF Model generation attempt {attempt + 1} failed: {e}")
|
| 107 |
+
if attempt < retries - 1:
|
| 108 |
+
print(f"Retrying in {delay} seconds...")
|
| 109 |
+
time.sleep(delay)
|
| 110 |
+
else:
|
| 111 |
+
print("Max retries for Qwen HF model reached. Skipping this request.")
|
| 112 |
+
return "Failed to get Qwen HF response"
|
| 113 |
+
|
| 114 |
+
def answer_question_with_context_qwen3_hf(question, context, model, tokenizer, device):
|
| 115 |
+
"""Answer a question with context using a Qwen3-like HF model."""
|
| 116 |
+
prompt = f"""Please answer the question based on the following context:
|
| 117 |
+
|
| 118 |
+
Context:
|
| 119 |
+
{context}
|
| 120 |
+
|
| 121 |
+
Question: {question}
|
| 122 |
+
|
| 123 |
+
Only output the answer, no any other text. If the answer is not in the context, please say "I don't know".
|
| 124 |
+
|
| 125 |
+
Answer:"""
|
| 126 |
+
return get_qwen3_hf_response(prompt, model, tokenizer, device)
|
| 127 |
+
|
| 128 |
+
def main(args):
|
| 129 |
+
hf_device_setting = "auto"
|
| 130 |
+
print(f"Attempting to use device: {hf_device_setting} for Qwen HF model.")
|
| 131 |
+
|
| 132 |
+
print(f"Loading Qwen HF model for Answering: {args.model_name}...")
|
| 133 |
+
hf_model = None
|
| 134 |
+
hf_tokenizer = None
|
| 135 |
+
try:
|
| 136 |
+
hf_tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code_hf)
|
| 137 |
+
hf_model = AutoModelForCausalLM.from_pretrained(
|
| 138 |
+
args.model_name,
|
| 139 |
+
device_map=hf_device_setting,
|
| 140 |
+
trust_remote_code=args.trust_remote_code_hf,
|
| 141 |
+
torch_dtype="bfloat16"
|
| 142 |
+
)
|
| 143 |
+
hf_model.eval()
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
print(f"Successfully loaded Qwen HF model {args.model_name}.")
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f"Failed to load Qwen HF model {args.model_name}: {e}")
|
| 150 |
+
return
|
| 151 |
+
|
| 152 |
+
print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
|
| 153 |
+
try:
|
| 154 |
+
dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
|
| 155 |
+
print(f"Successfully loaded dataset with {len(dataset)} samples.")
|
| 156 |
+
except Exception as e:
|
| 157 |
+
print(f"Failed to load dataset: {e}")
|
| 158 |
+
return
|
| 159 |
+
|
| 160 |
+
em_match_count = 0 # Counter for EM matches
|
| 161 |
+
em_match_original_count = 0 # Counter for EM matches
|
| 162 |
+
successfully_processed_samples = 0 # Counter for successfully processed samples
|
| 163 |
+
|
| 164 |
+
num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))
|
| 165 |
+
|
| 166 |
+
print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o (opposite meaning). Answering with Qwen HF model {args.model_name} (max 30 tokens)...")
|
| 167 |
+
|
| 168 |
+
for i in tqdm(range(num_samples_to_process), desc="Processing samples"):
|
| 169 |
+
example = dataset[i]
|
| 170 |
+
original_question = example['input']
|
| 171 |
+
context = example['context']
|
| 172 |
+
ground_truth_answers = example['answers']
|
| 173 |
+
print(original_question)
|
| 174 |
+
|
| 175 |
+
rephrased_question = rephrase_question_with_gpt4o(original_question, args.rephrase_type)
|
| 176 |
+
print(rephrased_question)
|
| 177 |
+
|
| 178 |
+
if rephrased_question == "Failed to rephrase question":
|
| 179 |
+
print(f"Skipping sample {i+1} due to rephrasing failure.")
|
| 180 |
+
continue
|
| 181 |
+
|
| 182 |
+
rephrased_answer = answer_question_with_context_qwen3_hf(rephrased_question, context, hf_model, hf_tokenizer, hf_model.device)
|
| 183 |
+
print(rephrased_answer)
|
| 184 |
+
original_answer = answer_question_with_context_qwen3_hf(original_question, context, hf_model, hf_tokenizer, hf_model.device)
|
| 185 |
+
if not ground_truth_answers:
|
| 186 |
+
print(f"Skipping sample {i+1} due to missing ground truth answers.")
|
| 187 |
+
continue
|
| 188 |
+
print(original_answer)
|
| 189 |
+
successfully_processed_samples += 1
|
| 190 |
+
|
| 191 |
+
sample_had_em_match = False
|
| 192 |
+
for gt_ans in ground_truth_answers:
|
| 193 |
+
em = qa_em_score(rephrased_answer, gt_ans)
|
| 194 |
+
if em > 0: # Check for exact match (assuming qa_em_score returns 1.0 for EM)
|
| 195 |
+
sample_had_em_match = True
|
| 196 |
+
break
|
| 197 |
+
|
| 198 |
+
if sample_had_em_match:
|
| 199 |
+
em_match_count += 1
|
| 200 |
+
|
| 201 |
+
sample_had_em_match = False
|
| 202 |
+
for gt_ans in ground_truth_answers:
|
| 203 |
+
em = qa_em_score(original_answer, gt_ans)
|
| 204 |
+
if em > 0: # Check for exact match (assuming qa_em_score returns 1.0 for EM)
|
| 205 |
+
sample_had_em_match = True
|
| 206 |
+
break
|
| 207 |
+
if sample_had_em_match:
|
| 208 |
+
em_match_original_count += 1
|
| 209 |
+
|
| 210 |
+
if successfully_processed_samples > 0:
|
| 211 |
+
print(f"\n--- Evaluation Summary ---")
|
| 212 |
+
print(f"Answering Qwen HF Model: {args.model_name}")
|
| 213 |
+
print(f"Dataset: {args.dataset_name} ({args.dataset_subset})")
|
| 214 |
+
print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
|
| 215 |
+
print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
|
| 216 |
+
print(f"Count of EM with original ground truth (before rephrase): {em_match_original_count}")
|
| 217 |
+
else:
|
| 218 |
+
print("\nNo samples were processed adequately to provide an evaluation summary.")
|
| 219 |
+
|
| 220 |
+
print("Processing complete!")
|
| 221 |
+
|
| 222 |
+
if __name__ == "__main__":
|
| 223 |
+
parser = argparse.ArgumentParser(description="Rephrase with GPT-4o, Answer with local Qwen3-like HF Model, then Evaluate.")
|
| 224 |
+
parser.add_argument("--model_name", type=str, default="Qwen/Qwen1.5-7B-Chat", help="Name of the Qwen3-like Hugging Face model for Answering.")
|
| 225 |
+
parser.add_argument("--trust_remote_code_hf", action="store_true", default=True, help="Set to true if the Hugging Face model requires remote code (default: True for Qwen). Argument is present for explicitness but defaults to True.")
|
| 226 |
+
parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
|
| 227 |
+
parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
|
| 228 |
+
parser.add_argument("--sample_count", type=int, default=5, help="Number of samples to process. -1 for all. Default: 5.")
|
| 229 |
+
parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")
|
| 230 |
+
|
| 231 |
+
args = parser.parse_args()
|
| 232 |
+
if openai_client.api_key == "your_api_key_here":
|
| 233 |
+
print("CRITICAL ERROR: Please replace 'your_api_key_here' with your actual OpenAI API key in the script.")
|
| 234 |
+
else:
|
| 235 |
+
main(args)
|
detect/question_rephrase_answer_vllm.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import os
|
| 3 |
+
import argparse
|
| 4 |
+
# import torch # torch might not be directly needed if vLLM handles all device aspects
|
| 5 |
+
from datasets import load_dataset
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from openai import OpenAI # For GPT-4o rephrasing
|
| 8 |
+
from vllm import LLM, SamplingParams # For vLLM inference
|
| 9 |
+
from transformers import AutoTokenizer # Import AutoTokenizer
|
| 10 |
+
from utils.metrics import qa_f1_score, qa_em_score
|
| 11 |
+
|
| 12 |
+
# This will be respected by vLLM if CUDA_VISIBLE_DEVICES is set before vLLM import
|
| 13 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" # User can set this outside the script
|
| 14 |
+
|
| 15 |
+
# --- OpenAI Client for Rephrasing ---
|
| 16 |
+
openai_client = OpenAI(
|
| 17 |
+
api_key=os.environ.get("OPENAI_API_KEY"),
|
| 18 |
+
base_url=os.environ.get("OPENAI_BASE_URL")
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def get_openai_rephrase_response(prompt, model="gpt-4o", retries=3, delay=2):
|
| 22 |
+
"""Call OpenAI API for rephrasing."""
|
| 23 |
+
for attempt in range(retries):
|
| 24 |
+
try:
|
| 25 |
+
completion = openai_client.chat.completions.create(
|
| 26 |
+
model=model,
|
| 27 |
+
messages=[{'role': 'user', 'content': prompt}],
|
| 28 |
+
max_tokens=100 # Max tokens for rephrased question
|
| 29 |
+
)
|
| 30 |
+
return completion.choices[0].message.content.strip()
|
| 31 |
+
except Exception as e:
|
| 32 |
+
print(f"OpenAI Rephrase attempt {attempt + 1} failed: {e}")
|
| 33 |
+
if attempt < retries - 1:
|
| 34 |
+
print(f"Retrying OpenAI rephrase in {delay} seconds...")
|
| 35 |
+
time.sleep(delay)
|
| 36 |
+
else:
|
| 37 |
+
print("Max retries for OpenAI rephrase reached.")
|
| 38 |
+
return "Failed to rephrase question"
|
| 39 |
+
|
| 40 |
+
def rephrase_question_with_gpt4o(question, rephrase_type="opposite"):
|
| 41 |
+
"""Rephrase a question using GPT-4o (English prompt)."""
|
| 42 |
+
if rephrase_type == "opposite":
|
| 43 |
+
prompt = f"""Please rephrase the following question to have the exact opposite meaning.
|
| 44 |
+
Question: {question}
|
| 45 |
+
|
| 46 |
+
Return only the rephrased question with the opposite meaning, without any explanations or other content."""
|
| 47 |
+
elif rephrase_type == "similar":
|
| 48 |
+
prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
|
| 49 |
+
Question: {question}
|
| 50 |
+
|
| 51 |
+
Return only the rephrased question, without any explanations or other content."""
|
| 52 |
+
else:
|
| 53 |
+
raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")
|
| 54 |
+
|
| 55 |
+
return get_openai_rephrase_response(prompt)
|
| 56 |
+
|
| 57 |
+
# --- vLLM Model Functions (for Answering) ---
|
| 58 |
+
def get_vllm_response(prompt_text, llm_instance, sampling_params_instance, retries=2, delay=5):
|
| 59 |
+
"""Generate a response from a vLLM instance."""
|
| 60 |
+
for attempt in range(retries):
|
| 61 |
+
try:
|
| 62 |
+
# vLLM generate method expects a list of prompts
|
| 63 |
+
outputs = llm_instance.generate([prompt_text], sampling_params_instance)
|
| 64 |
+
# For a single prompt, the result is in the first element of the output list
|
| 65 |
+
# Each output object has a list of `outputs` (for n>1 in SamplingParams)
|
| 66 |
+
response = outputs[0].outputs[0].text.strip()
|
| 67 |
+
return response
|
| 68 |
+
except Exception as e:
|
| 69 |
+
print(f"vLLM generation attempt {attempt + 1} failed: {e}")
|
| 70 |
+
if attempt < retries - 1:
|
| 71 |
+
print(f"Retrying vLLM generation in {delay} seconds...")
|
| 72 |
+
time.sleep(delay)
|
| 73 |
+
else:
|
| 74 |
+
print("Max retries for vLLM generation reached.")
|
| 75 |
+
return "Failed to get vLLM response"
|
| 76 |
+
|
| 77 |
+
def answer_question_with_context_vllm(question, context, llm_instance, sampling_params_instance, tokenizer):
|
| 78 |
+
"""Answer a question with context using a vLLM model and chat template (English prompt)."""
|
| 79 |
+
# Construct prompt using chat template, similar to evaluation.py
|
| 80 |
+
prompt_content = (
|
| 81 |
+
f"Answer the question based on the given passages. "
|
| 82 |
+
"Only give me your answer and do not output any other words.\\n"
|
| 83 |
+
"The following are given passages:\\n"
|
| 84 |
+
f"{context}\\n"
|
| 85 |
+
"Please strictly follow the context. "
|
| 86 |
+
f"Question: {question}\\n"
|
| 87 |
+
"Answer:"
|
| 88 |
+
)
|
| 89 |
+
messages = [{"role": "user", "content": prompt_content}]
|
| 90 |
+
|
| 91 |
+
# Apply chat template
|
| 92 |
+
# Note: Some tokenizers might not have a chat template configured, or might have different ways to apply it.
|
| 93 |
+
# This is a common way for many models.
|
| 94 |
+
try:
|
| 95 |
+
final_prompt_text = tokenizer.apply_chat_template(
|
| 96 |
+
messages,
|
| 97 |
+
tokenize=False,
|
| 98 |
+
add_generation_prompt=True
|
| 99 |
+
)
|
| 100 |
+
except Exception as e:
|
| 101 |
+
print(f"Failed to apply chat template: {e}. Falling back to basic prompt string.")
|
| 102 |
+
# Fallback to a simpler prompt if template application fails
|
| 103 |
+
final_prompt_text = f"Context:\\n{context}\\n\\nQuestion: {question}\\n\\nAnswer:"
|
| 104 |
+
|
| 105 |
+
return get_vllm_response(final_prompt_text, llm_instance, sampling_params_instance)
|
| 106 |
+
|
| 107 |
+
def main(args):
|
| 108 |
+
# Load Tokenizer for the vLLM model
|
| 109 |
+
print(f"Loading tokenizer for model: {args.model_name}...")
|
| 110 |
+
try:
|
| 111 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code)
|
| 112 |
+
print("Successfully loaded tokenizer.")
|
| 113 |
+
except Exception as e:
|
| 114 |
+
print(f"Failed to load tokenizer for {args.model_name}: {e}")
|
| 115 |
+
print("Please ensure the model name is correct and the tokenizer can be loaded.")
|
| 116 |
+
return
|
| 117 |
+
|
| 118 |
+
# Load vLLM Model (for Answering)
|
| 119 |
+
print(f"Loading vLLM model for Answering: {args.model_name}...")
|
| 120 |
+
print(f"(This may take a while depending on the model size and download speed if not cached).")
|
| 121 |
+
vllm_model = None
|
| 122 |
+
try:
|
| 123 |
+
# You can expose more vLLM LLM parameters as args if needed
|
| 124 |
+
# (e.g., tensor_parallel_size, dtype, gpu_memory_utilization)
|
| 125 |
+
vllm_model = LLM(
|
| 126 |
+
model=args.model_name,
|
| 127 |
+
trust_remote_code=args.trust_remote_code,
|
| 128 |
+
dtype="bfloat16", # Use dtype from command line arguments
|
| 129 |
+
# Add other vLLM LLM constructor arguments here if needed, e.g.:
|
| 130 |
+
tensor_parallel_size=2
|
| 131 |
+
)
|
| 132 |
+
print(f"Successfully loaded vLLM model {args.model_name} with dtype='{args.dtype}' and tensor_parallel_size={args.tensor_parallel_size}.")
|
| 133 |
+
except Exception as e:
|
| 134 |
+
print(f"Failed to load vLLM model {args.model_name}: {e}")
|
| 135 |
+
print("Please ensure vLLM is installed correctly and the model identifier is valid.")
|
| 136 |
+
return
|
| 137 |
+
|
| 138 |
+
# Define Sampling Parameters for vLLM
|
| 139 |
+
# max_tokens is equivalent to max_new_tokens in HF
|
| 140 |
+
# temperature=0.0 for greedy decoding, good for QA tasks for more deterministic output.
|
| 141 |
+
# Adjust temperature (e.g., 0.7) and top_p (e.g., 0.95) for more diverse outputs if needed.
|
| 142 |
+
sampling_params = SamplingParams(temperature=0.0, max_tokens=30) # Set temperature to 0.0 for deterministic QA
|
| 143 |
+
|
| 144 |
+
# Load dataset
|
| 145 |
+
print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
|
| 146 |
+
try:
|
| 147 |
+
dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
|
| 148 |
+
print(f"Successfully loaded dataset with {len(dataset)} samples.")
|
| 149 |
+
except Exception as e:
|
| 150 |
+
print(f"Failed to load dataset: {e}")
|
| 151 |
+
return
|
| 152 |
+
|
| 153 |
+
em_match_count = 0 # Counter for EM matches
|
| 154 |
+
em_match_original_count = 0 # Counter for EM matches
|
| 155 |
+
successfully_processed_samples = 0 # Counter for successfully processed samples
|
| 156 |
+
|
| 157 |
+
num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))
|
| 158 |
+
|
| 159 |
+
print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o (opposite meaning). Answering with vLLM model {args.model_name} (max 30 tokens)...")
|
| 160 |
+
|
| 161 |
+
for i in tqdm(range(num_samples_to_process), desc="Processing samples with vLLM"):
|
| 162 |
+
example = dataset[i]
|
| 163 |
+
original_question = example['input']
|
| 164 |
+
context = example['context']
|
| 165 |
+
ground_truth_answers = example['answers']
|
| 166 |
+
|
| 167 |
+
rephrased_question = rephrase_question_with_gpt4o(original_question, args.rephrase_type) # Use new rephrasing
|
| 168 |
+
|
| 169 |
+
if rephrased_question == "Failed to rephrase question":
|
| 170 |
+
print(f"Skipping sample {i+1} due to rephrasing failure.")
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
rephrased_answer = answer_question_with_context_vllm(rephrased_question, context, vllm_model, sampling_params, tokenizer)
|
| 174 |
+
# print(f"Rephrased question: {rephrased_question}") # Optional: for debugging
|
| 175 |
+
# print(f"Answer to rephrased: {rephrased_answer}") # Optional: for debugging
|
| 176 |
+
|
| 177 |
+
original_answer = answer_question_with_context_vllm(original_question, context, vllm_model, sampling_params, tokenizer)
|
| 178 |
+
# print(f"Original question: {original_question}") # Optional: for debugging
|
| 179 |
+
# print(f"Answer to original: {original_answer}") # Optional: for debugging
|
| 180 |
+
|
| 181 |
+
if not ground_truth_answers:
|
| 182 |
+
print(f"Skipping sample {i+1} due to missing ground truth answers.")
|
| 183 |
+
continue
|
| 184 |
+
print(original_answer)
|
| 185 |
+
successfully_processed_samples += 1
|
| 186 |
+
|
| 187 |
+
sample_had_em_match = False
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
em_match_count += qa_em_score(rephrased_answer, ground_truth_answers[0])
|
| 193 |
+
|
| 194 |
+
sample_had_em_match = False
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
print(original_answer)
|
| 198 |
+
print(ground_truth_answers[0])
|
| 199 |
+
|
| 200 |
+
em_match_original_count += qa_em_score(original_answer, ground_truth_answers[0])
|
| 201 |
+
|
| 202 |
+
if successfully_processed_samples > 0:
|
| 203 |
+
print(f"Answering vLLM Model: {args.model_name}")
|
| 204 |
+
print(f"Dataset : {args.dataset_name} ({args.dataset_subset})")
|
| 205 |
+
print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
|
| 206 |
+
print(f"Max Answer Tokens : 30") # Reflects SamplingParams
|
| 207 |
+
print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
|
| 208 |
+
print(f"Count of EM with original ground truth (before rephrase): {em_match_original_count}")
|
| 209 |
+
else:
|
| 210 |
+
print("\nNo samples were processed adequately to provide an evaluation summary.")
|
| 211 |
+
|
| 212 |
+
print("vLLM processing complete!")
|
| 213 |
+
|
| 214 |
+
if __name__ == "__main__":
|
| 215 |
+
parser = argparse.ArgumentParser(description="Rephrase with GPT-4o, Answer with local vLLM-hosted Model, then Evaluate.")
|
| 216 |
+
parser.add_argument("--model_name", type=str, default="facebook/opt-125m", help="Name/path of the Hugging Face model for Answering via vLLM (e.g., 'mistralai/Mistral-7B-Instruct-v0.1').")
|
| 217 |
+
parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
|
| 218 |
+
parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
|
| 219 |
+
parser.add_argument("--sample_count", type=int, default=3, help="Number of samples to process. -1 for all. Default: 3 for quick testing.")
|
| 220 |
+
parser.add_argument("--trust_remote_code", action="store_true", help="Set to true if the Hugging Face model for vLLM requires remote code.")
|
| 221 |
+
parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel size for vLLM.")
|
| 222 |
+
parser.add_argument("--dtype", type=str, default="auto", help="Data type for the model. Examples: 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'. Default is 'auto'.")
|
| 223 |
+
parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")
|
| 224 |
+
|
| 225 |
+
args = parser.parse_args()
|
| 226 |
+
main(args)
|
eval/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from .evaluation import *
|
| 2 |
+
|
eval/eval_with_api.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
#!/usr/bin/env python3
|
| 5 |
+
# -*- coding: utf-8 -*-
|
| 6 |
+
"""
|
| 7 |
+
Single-phase evaluator (DeepSeek API) — Calculate EM / F1 only.
|
| 8 |
+
|
| 9 |
+
Usage Example
|
| 10 |
+
--------
|
| 11 |
+
python eval_single_phase.py --input data/2wikimqa.jsonl
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import argparse, time, jsonlines, os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from openai import OpenAI
|
| 18 |
+
from utils.metrics import qa_em_score, qa_f1_score
|
| 19 |
+
from utils.llmjudge import judge_answer_with_api
|
| 20 |
+
|
| 21 |
+
# -------------------- CLI --------------------
|
| 22 |
+
p = argparse.ArgumentParser("Single-phase evaluator")
|
| 23 |
+
p.add_argument("--input", required=True, help="Path to the *.jsonl file to evaluate")
|
| 24 |
+
p.add_argument("--model", default="deepseek-r1")
|
| 25 |
+
p.add_argument("--temperature", type=float, default=0.5)
|
| 26 |
+
p.add_argument("--max_tokens", type=int, default=30)
|
| 27 |
+
p.add_argument("--sleep", type=float, default=0.0)
|
| 28 |
+
args = p.parse_args()
|
| 29 |
+
|
| 30 |
+
client = OpenAI(
|
| 31 |
+
base_url=os.environ.get("OPENAI_BASE_URL"),
|
| 32 |
+
api_key=os.environ.get("OPENAI_API_KEY")
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# -------------------- helper --------------------
|
| 36 |
+
def ask(context: str, question: str) -> str:
|
| 37 |
+
"""Call DeepSeek to get answer (return final answer only)"""
|
| 38 |
+
messages = [
|
| 39 |
+
{"role": "system",
|
| 40 |
+
"content": ("You are a QA assistant. "
|
| 41 |
+
"Answer strictly based on the passages; "
|
| 42 |
+
"output only the final answer.")},
|
| 43 |
+
{"role": "user",
|
| 44 |
+
"content": f"Answer the question and output only the final answer without extra words. Passages:\n{context}\n\nQuestion: {question}\nAnswer:"}
|
| 45 |
+
]
|
| 46 |
+
resp = client.chat.completions.create(
|
| 47 |
+
model=args.model,
|
| 48 |
+
messages=messages,
|
| 49 |
+
temperature=args.temperature,
|
| 50 |
+
max_tokens=args.max_tokens
|
| 51 |
+
)
|
| 52 |
+
if not resp.choices[0].message.content:
|
| 53 |
+
return "None"
|
| 54 |
+
|
| 55 |
+
return resp.choices[0].message.content.strip()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# -------------------- core eval --------------------
|
| 59 |
+
def evaluate_file(path: Path):
|
| 60 |
+
dataset = path.stem
|
| 61 |
+
data = {obj["input"]: obj for obj in jsonlines.open(path)}
|
| 62 |
+
|
| 63 |
+
total = len(data)
|
| 64 |
+
em_hits = 0
|
| 65 |
+
f1_sum = 0.0
|
| 66 |
+
|
| 67 |
+
for q, item in tqdm(data.items(), desc=f"{dataset}"):
|
| 68 |
+
ctx = item["context"]
|
| 69 |
+
golds = item["answers"] if isinstance(item["answers"], list) else [item["answers"]]
|
| 70 |
+
|
| 71 |
+
pred = ask(ctx, q).split('.', 1)[0] # Cut off extra explanations
|
| 72 |
+
if pred == "None":
|
| 73 |
+
continue
|
| 74 |
+
em = max(qa_em_score(pred, g) for g in golds)
|
| 75 |
+
f1 = max(qa_f1_score(pred, g) for g in golds)
|
| 76 |
+
|
| 77 |
+
em_hits += em
|
| 78 |
+
f1_sum += f1
|
| 79 |
+
if args.sleep:
|
| 80 |
+
time.sleep(args.sleep)
|
| 81 |
+
|
| 82 |
+
print(f"\n=== {dataset.upper()} SUMMARY ===")
|
| 83 |
+
print(f"Total samples : {total}")
|
| 84 |
+
print(f"Exact Match : {em_hits}/{total} ({em_hits/total:.2%})")
|
| 85 |
+
print(f"Average F1 : {f1_sum/total:.4f}")
|
| 86 |
+
print("-" * 40 + "\n")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# -------------------- run --------------------
|
| 90 |
+
input_path = Path(args.input)
|
| 91 |
+
if not input_path.exists():
|
| 92 |
+
raise SystemExit(f"File does not exist: {input_path}")
|
| 93 |
+
|
| 94 |
+
evaluate_file(input_path)
|
eval/evaluation.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import argparse, os, jsonlines, torch
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
|
| 4 |
+
from utils.metrics import qa_f1_score, qa_em_score
|
| 5 |
+
THINK_END_ID = 151668 # "</think>" token id for Qwen3
|
| 6 |
+
|
| 7 |
+
# --------------------------------------------------
|
| 8 |
+
def strip_think(token_ids):
|
| 9 |
+
try:
|
| 10 |
+
cut = len(token_ids) - token_ids[::-1].index(THINK_END_ID)
|
| 11 |
+
return token_ids[cut:]
|
| 12 |
+
except ValueError:
|
| 13 |
+
return token_ids
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
# ---------- CLI ----------
|
| 17 |
+
parser = argparse.ArgumentParser(
|
| 18 |
+
description="Evaluate HotpotQA JSONL with Transformers + Qwen3-8B"
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument("-i", "--input", required=True,
|
| 21 |
+
help="Path to input JSONL file")
|
| 22 |
+
parser.add_argument("--model", required=True,
|
| 23 |
+
help="HF model name, e.g. Qwen/Qwen3-8B")
|
| 24 |
+
parser.add_argument("-d", "--devices", default="0",
|
| 25 |
+
help="CUDA_VISIBLE_DEVICES (comma-separated)")
|
| 26 |
+
parser.add_argument("-t", "--temperature", type=float, default=0.5,
|
| 27 |
+
help="Sampling temperature")
|
| 28 |
+
parser.add_argument("-k", "--max_tokens", type=int, default=40,
|
| 29 |
+
help="max_new_tokens")
|
| 30 |
+
args = parser.parse_args()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
|
| 35 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 36 |
+
args.model,
|
| 37 |
+
torch_dtype="auto",
|
| 38 |
+
device_map="auto",
|
| 39 |
+
trust_remote_code=True
|
| 40 |
+
)
|
| 41 |
+
gen_cfg = GenerationConfig(
|
| 42 |
+
temperature=args.temperature,
|
| 43 |
+
max_new_tokens=args.max_tokens,
|
| 44 |
+
do_sample=args.temperature > 0
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
with jsonlines.open(args.input) as reader:
|
| 49 |
+
data = list(reader)
|
| 50 |
+
|
| 51 |
+
total_f1 = total_em = 0.0
|
| 52 |
+
|
| 53 |
+
for idx, item in enumerate(data):
|
| 54 |
+
question = item.get("input", "")
|
| 55 |
+
context = item.get("context", "")
|
| 56 |
+
answers = item.get("answers", [])
|
| 57 |
+
if not answers:
|
| 58 |
+
print(f"[{idx}] no gold answer, skip")
|
| 59 |
+
continue
|
| 60 |
+
gold = answers[0]
|
| 61 |
+
print(gold)
|
| 62 |
+
|
| 63 |
+
# ----- Prompt -----
|
| 64 |
+
prompt = (
|
| 65 |
+
"Answer the question based on the given passages. "
|
| 66 |
+
"Only give me your answer and do not output any other words.\n"
|
| 67 |
+
"Passages:\n"
|
| 68 |
+
f"{context}\n"
|
| 69 |
+
f"Question: {question}\n"
|
| 70 |
+
"Answer:"
|
| 71 |
+
)
|
| 72 |
+
messages = [{"role": "user", "content": prompt}]
|
| 73 |
+
chat_text = tokenizer.apply_chat_template(
|
| 74 |
+
messages,
|
| 75 |
+
tokenize=False,
|
| 76 |
+
add_generation_prompt=True,
|
| 77 |
+
enable_thinking=False
|
| 78 |
+
)
|
| 79 |
+
inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ----- Generate -----
|
| 83 |
+
try:
|
| 84 |
+
with torch.no_grad():
|
| 85 |
+
outputs = model.generate(**inputs, max_new_tokens=args.max_tokens)
|
| 86 |
+
except ValueError as e:
|
| 87 |
+
if "position ids exceed" in str(e).lower() or "sequence length" in str(e).lower():
|
| 88 |
+
print(f"[{idx}] prompt too long – skipped")
|
| 89 |
+
continue
|
| 90 |
+
raise
|
| 91 |
+
print("im here")
|
| 92 |
+
new_ids = outputs[0][len(inputs.input_ids[0]):].tolist()
|
| 93 |
+
try:
|
| 94 |
+
index = len(new_ids) - new_ids[::-1].index(151668)
|
| 95 |
+
except ValueError:
|
| 96 |
+
index = 0
|
| 97 |
+
answer = tokenizer.decode(new_ids[index:], skip_special_tokens=True).strip("\n")
|
| 98 |
+
answer = answer.strip()
|
| 99 |
+
|
| 100 |
+
# ----- Score -----
|
| 101 |
+
f1 = qa_f1_score(answer, gold)
|
| 102 |
+
em = qa_em_score(answer, gold)
|
| 103 |
+
total_f1 += f1
|
| 104 |
+
total_em += em
|
| 105 |
+
|
| 106 |
+
print(f"[{idx}] Q: {question}")
|
| 107 |
+
print(f" Resp: {answer!r} | Gold: {gold!r}")
|
| 108 |
+
print(f" F1={f1:.2f}, EM={em:.2f}")
|
| 109 |
+
|
| 110 |
+
n = len(data)
|
| 111 |
+
print(f"\nOverall F1: {total_f1/n:.4f}")
|
| 112 |
+
print(f"Overall EM: {total_em/n:.4f}")
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
main()
|
main_gpu.py
ADDED
|
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
import argparse
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from haystack import Pipeline, Document
|
| 8 |
+
from haystack.utils import Secret
|
| 9 |
+
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
| 10 |
+
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
|
| 11 |
+
from haystack import Document # haystack Document for storing context
|
| 12 |
+
from nltk.tokenize import sent_tokenize
|
| 13 |
+
from utils.util import retriveDoc,compute_best_sentence_f1
|
| 14 |
+
from openai import OpenAI
|
| 15 |
+
import asyncio, json, torch, math
|
| 16 |
+
from typing import List, Tuple
|
| 17 |
+
# Hugging Face transformers related
|
| 18 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 19 |
+
from utils.metrics import qa_f1_score
|
| 20 |
+
from utils.llmjudge import judge_answer_with_api
|
| 21 |
+
|
| 22 |
+
client = OpenAI(
|
| 23 |
+
base_url=os.environ.get("OPENAI_BASE_URL"),
|
| 24 |
+
api_key=os.environ.get("OPENAI_API_KEY")
|
| 25 |
+
)
|
| 26 |
+
# Load models using transformers
|
| 27 |
+
|
| 28 |
+
tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
|
| 29 |
+
model1 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True,device_map="cuda:0",torch_dtype=torch.bfloat16)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
tok_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
|
| 33 |
+
model_qwen = AutoModelForCausalLM.from_pretrained(
|
| 34 |
+
"Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True,
|
| 35 |
+
device_map="cuda:1",torch_dtype=torch.bfloat16
|
| 36 |
+
).eval()
|
| 37 |
+
|
| 38 |
+
def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
|
| 39 |
+
"""
|
| 40 |
+
Use transformers model.generate method for inference with retry mechanism,
|
| 41 |
+
strip the input prompt part through token-level slicing,
|
| 42 |
+
and return the newly generated text.
|
| 43 |
+
"""
|
| 44 |
+
import time
|
| 45 |
+
for attempt in range(retries):
|
| 46 |
+
try:
|
| 47 |
+
# Encode prompt as model input tensor
|
| 48 |
+
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 49 |
+
# Call generate, the generated id sequence contains both prompt and subsequent generated text
|
| 50 |
+
generated_ids = model.generate(
|
| 51 |
+
**model_inputs,
|
| 52 |
+
max_new_tokens=max_new_tokens,
|
| 53 |
+
temperature=temperature,
|
| 54 |
+
top_p=top_p
|
| 55 |
+
)
|
| 56 |
+
# Calculate the token count corresponding to the prompt
|
| 57 |
+
input_length = model_inputs.input_ids.shape[1]
|
| 58 |
+
# Strip the prompt part from the front of the output, keeping only the newly added part
|
| 59 |
+
output_ids = generated_ids[0][input_length:]
|
| 60 |
+
# Decode generated text
|
| 61 |
+
answer = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
|
| 62 |
+
return answer
|
| 63 |
+
except Exception as e:
|
| 64 |
+
print(f"Error on attempt {attempt + 1}: {e}")
|
| 65 |
+
if attempt < retries - 1:
|
| 66 |
+
print(f"Retrying in {delay} seconds...")
|
| 67 |
+
time.sleep(delay)
|
| 68 |
+
else:
|
| 69 |
+
print("Max retries reached, skipping this request.")
|
| 70 |
+
return None
|
| 71 |
+
|
| 72 |
+
def truncate_answer(answer):
|
| 73 |
+
"""Truncate answer, only take the part before the first period"""
|
| 74 |
+
return answer.split('.')[0].strip() if answer else "No answer"
|
| 75 |
+
|
| 76 |
+
def write_to_log(filename, data):
|
| 77 |
+
"""Write data to log file"""
|
| 78 |
+
with open(filename, 'a', encoding='utf-8') as file:
|
| 79 |
+
file.write(data + '\n')
|
| 80 |
+
|
| 81 |
+
def remove_think_tags(text: str) -> str:
|
| 82 |
+
"""Remove all <think> ... </think> blocks"""
|
| 83 |
+
return re.sub(r'<think>(.*?)</think>', '', text, flags=re.DOTALL).strip()
|
| 84 |
+
|
| 85 |
+
def build_prompt(context: str, question: str) -> str:
|
| 86 |
+
prompt = (
|
| 87 |
+
f"Answer the question based on the given passages. The following are the passages:\n"
|
| 88 |
+
f"{context}\n"
|
| 89 |
+
f"Answer the question based on the given passages.\n"
|
| 90 |
+
f"Question: {question}.\n"
|
| 91 |
+
f"Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) "
|
| 92 |
+
f"At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. "
|
| 93 |
+
f"Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence. "
|
| 94 |
+
f"Follow this format:\n"
|
| 95 |
+
f"Answer: [Your answer]\n"
|
| 96 |
+
f"Step-by-step Reasoning:\n"
|
| 97 |
+
f"1. [Reasoning step 1]\n"
|
| 98 |
+
f"[replaced by your reference content]\n"
|
| 99 |
+
f"2. [Reasoning step 2]\n"
|
| 100 |
+
f"[replaced by your reference content]\n"
|
| 101 |
+
)
|
| 102 |
+
return prompt
|
| 103 |
+
|
| 104 |
+
def extract_final_bullet_passage(answer_text: str):
|
| 105 |
+
reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
|
| 106 |
+
reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
|
| 107 |
+
if not reasoning_match:
|
| 108 |
+
return None, None
|
| 109 |
+
|
| 110 |
+
reasoning_text = reasoning_match.group(1).strip()
|
| 111 |
+
bullet_pattern = r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)"
|
| 112 |
+
bullets = re.findall(bullet_pattern, reasoning_text, flags=re.DOTALL)
|
| 113 |
+
if not bullets:
|
| 114 |
+
print("No bullet blocks found.")
|
| 115 |
+
return None, None
|
| 116 |
+
|
| 117 |
+
passage_pattern = re.compile(
|
| 118 |
+
r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
|
| 119 |
+
flags=re.DOTALL
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
for bullet in reversed(bullets):
|
| 123 |
+
matches = passage_pattern.findall(bullet)
|
| 124 |
+
if matches:
|
| 125 |
+
last_match = matches[-1]
|
| 126 |
+
passage_number = last_match[0]
|
| 127 |
+
quoted_snippet = last_match[2]
|
| 128 |
+
non_quoted_snippet = last_match[3]
|
| 129 |
+
snippet = non_quoted_snippet.strip() if non_quoted_snippet.strip() else quoted_snippet.strip()
|
| 130 |
+
return passage_number, snippet
|
| 131 |
+
|
| 132 |
+
return None, None
|
| 133 |
+
|
| 134 |
+
def extract_all_bullet_passages(answer_text: str):
|
| 135 |
+
reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
|
| 136 |
+
reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
|
| 137 |
+
if not reasoning_match:
|
| 138 |
+
return []
|
| 139 |
+
|
| 140 |
+
reasoning_text = reasoning_match.group(1).strip()
|
| 141 |
+
bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
|
| 142 |
+
bullets = bullet_pattern.findall(reasoning_text)
|
| 143 |
+
if not bullets:
|
| 144 |
+
return []
|
| 145 |
+
|
| 146 |
+
results = []
|
| 147 |
+
for bullet_index, bullet_text in enumerate(bullets, start=1):
|
| 148 |
+
results.append({
|
| 149 |
+
'bullet_index': bullet_index,
|
| 150 |
+
'snippet': bullet_text.strip()
|
| 151 |
+
})
|
| 152 |
+
print(results)
|
| 153 |
+
return results
|
| 154 |
+
|
| 155 |
+
def extract_evidence(answer_text: str):
|
| 156 |
+
reasoning_pattern = r"(?i)Evidence\s*(.*)"
|
| 157 |
+
reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
|
| 158 |
+
if not reasoning_match:
|
| 159 |
+
return []
|
| 160 |
+
|
| 161 |
+
reasoning_text = reasoning_match.group(1).strip()
|
| 162 |
+
|
| 163 |
+
# Extract all bullet segments
|
| 164 |
+
bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
|
| 165 |
+
bullets = bullet_pattern.findall(reasoning_text)
|
| 166 |
+
if not bullets:
|
| 167 |
+
return []
|
| 168 |
+
|
| 169 |
+
# Find the index of the first bullet starting with 1.
|
| 170 |
+
start_index = -1
|
| 171 |
+
for i, bullet in enumerate(bullets):
|
| 172 |
+
if bullet.strip().startswith("1."):
|
| 173 |
+
start_index = i
|
| 174 |
+
break
|
| 175 |
+
|
| 176 |
+
if start_index == -1:
|
| 177 |
+
return [] # No valid starting bullet
|
| 178 |
+
|
| 179 |
+
# Only keep the part starting from the first valid bullet
|
| 180 |
+
bullets = bullets[start_index:]
|
| 181 |
+
|
| 182 |
+
results = []
|
| 183 |
+
for bullet_index, bullet_text in enumerate(bullets, start=1):
|
| 184 |
+
results.append({
|
| 185 |
+
'bullet_index': bullet_index,
|
| 186 |
+
'snippet': bullet_text.strip()
|
| 187 |
+
})
|
| 188 |
+
return results
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_answer_with_retry(model, prompt, retries=3, delay=5):
|
| 192 |
+
"""Call the model to get the answer based on the prompt, with retry on failure."""
|
| 193 |
+
for attempt in range(retries):
|
| 194 |
+
try:
|
| 195 |
+
completion = client.chat.completions.create(
|
| 196 |
+
model=model,
|
| 197 |
+
messages=[{'role': 'user', 'content': prompt}]
|
| 198 |
+
)
|
| 199 |
+
return completion.choices[0].message.content.strip()
|
| 200 |
+
except Exception as e:
|
| 201 |
+
print(f"Error on attempt {attempt + 1}: {e}")
|
| 202 |
+
if attempt < retries - 1:
|
| 203 |
+
print(f"Retrying in {delay} seconds...")
|
| 204 |
+
time.sleep(delay)
|
| 205 |
+
else:
|
| 206 |
+
print("Max retries reached, skipping this request.")
|
| 207 |
+
return None
|
| 208 |
+
|
| 209 |
+
@torch.no_grad()
|
| 210 |
+
def qwen_answer_and_ppl(question: str, context: str) -> Tuple[str,float]:
|
| 211 |
+
prompt = f"{context}\n\nQuestion: {question}\nAnswer:"
|
| 212 |
+
inputs = tok_qwen(prompt, return_tensors="pt").to(model_qwen.device)
|
| 213 |
+
gen_ids = model_qwen.generate(**inputs, max_new_tokens=30, eos_token_id=tok_qwen.eos_token_id)
|
| 214 |
+
ans_ids = gen_ids[0][inputs.input_ids.shape[1]:]
|
| 215 |
+
answer = tok_qwen.decode(ans_ids, skip_special_tokens=True).strip()
|
| 216 |
+
# Calculate PPL
|
| 217 |
+
full_ids = torch.cat([inputs.input_ids[0], ans_ids])
|
| 218 |
+
logits = model_qwen(full_ids.unsqueeze(0)).logits[0,:-1]
|
| 219 |
+
tgt = full_ids[1:]
|
| 220 |
+
logp = torch.log_softmax(logits, dim=-1)
|
| 221 |
+
sel = logp[range(len(tgt)), tgt]
|
| 222 |
+
ppl = math.exp(-sel.mean().item())
|
| 223 |
+
return answer, ppl
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def extract_json_from_gpt_response(text: str) -> dict | None:
|
| 228 |
+
"""
|
| 229 |
+
Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict.
|
| 230 |
+
"""
|
| 231 |
+
# Try to find a ```json … ``` block first
|
| 232 |
+
m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
|
| 233 |
+
if not m:
|
| 234 |
+
# Fallback: any ``` … ``` block that looks like JSON
|
| 235 |
+
m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
|
| 236 |
+
if not m:
|
| 237 |
+
# Lastly, maybe the model just spit raw JSON without fences
|
| 238 |
+
m = re.search(r"(\{.*?\})", text, flags=re.DOTALL)
|
| 239 |
+
if not m:
|
| 240 |
+
return None
|
| 241 |
+
|
| 242 |
+
json_str = m.group(1)
|
| 243 |
+
try:
|
| 244 |
+
return json.loads(json_str)
|
| 245 |
+
except json.JSONDecodeError:
|
| 246 |
+
# clean up trailing commas, etc.
|
| 247 |
+
cleaned = re.sub(r",\s*([\]}])", r"\1", json_str)
|
| 248 |
+
try:
|
| 249 |
+
return json.loads(cleaned)
|
| 250 |
+
except json.JSONDecodeError:
|
| 251 |
+
return None
|
| 252 |
+
async def multi_proposal_pipeline(
|
| 253 |
+
question: str,
|
| 254 |
+
original_context: str, # directly pass in example['context']
|
| 255 |
+
unique_sents: List[str],
|
| 256 |
+
correct_answer: str,
|
| 257 |
+
rounds: int = 3
|
| 258 |
+
) -> dict:
|
| 259 |
+
best = {"ppl": float("0"), "context": None, "answer": None}
|
| 260 |
+
|
| 261 |
+
for i in range(rounds):
|
| 262 |
+
# Construct GPT-4o prompt
|
| 263 |
+
numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
|
| 264 |
+
prompt = (
|
| 265 |
+
"You are a creative contrarian. Given the question below, and the original answer, first propose a concise alternative answer—that is, a plausible but intentionally misleading answer. "
|
| 266 |
+
"Followed are some sentences supporting the original answer, please rewrite them. When rewriting each sentence, modify only the parts necessary to support the antifact answer. Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. Output must be strictly in the specified JSON format, with no additional text.\n"
|
| 267 |
+
'{\n'
|
| 268 |
+
' "answer": "<your antifact answer here, just provide the answer phrase, no need for complete sentence>",\n'
|
| 269 |
+
' "revised": [\n'
|
| 270 |
+
' "<rewritten sentence 1>",\n'
|
| 271 |
+
' "<rewritten sentence 2>",\n'
|
| 272 |
+
' ...\n'
|
| 273 |
+
' ]\n'
|
| 274 |
+
'}\n\n'
|
| 275 |
+
f"Question:\n{question}\n\n"
|
| 276 |
+
f"Original answer:\n{correct_answer}\n\n"
|
| 277 |
+
f"Sentences to rewrite:\n{numbered}"
|
| 278 |
+
)
|
| 279 |
+
print(f"[Proposal {i+1}] Prompt: {prompt}")
|
| 280 |
+
rsp = client.chat.completions.create(
|
| 281 |
+
model="gpt-4o", temperature=0.7,
|
| 282 |
+
messages=[{"role":"user","content":prompt}]
|
| 283 |
+
)
|
| 284 |
+
js = extract_json_from_gpt_response(rsp.choices[0].message.content)
|
| 285 |
+
if not js:
|
| 286 |
+
print("[Proposal {i+1}] Failed to parse JSON")
|
| 287 |
+
continue
|
| 288 |
+
revised = js["revised"] # List[str]
|
| 289 |
+
proposed = js["answer"] # Answer given by GPT-4o (optional record)
|
| 290 |
+
new_ctx = original_context
|
| 291 |
+
for old, new in zip(unique_sents, revised):
|
| 292 |
+
new_ctx = new_ctx.replace(old, new)
|
| 293 |
+
|
| 294 |
+
# Use Qwen to calculate answer & PPL
|
| 295 |
+
ans_i, ppl_i = qwen_answer_and_ppl(question, new_ctx)
|
| 296 |
+
print(f"[Proposal {i+1}] PPL = {ppl_i:.2f}")
|
| 297 |
+
|
| 298 |
+
if ppl_i > best["ppl"]:
|
| 299 |
+
best.update({"ppl": ppl_i, "context": new_ctx, "answer": proposed})
|
| 300 |
+
|
| 301 |
+
return best
|
| 302 |
+
def main():
|
| 303 |
+
# Parse command line arguments
|
| 304 |
+
parser = argparse.ArgumentParser(description="LastingBench main pipeline for context rewriting")
|
| 305 |
+
parser.add_argument("--output", "-o", type=str, default="output.jsonl",
|
| 306 |
+
help="Output JSONL file path (default: output.jsonl)")
|
| 307 |
+
parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
|
| 308 |
+
help="Dataset repository name (default: THUDM/LongBench)")
|
| 309 |
+
parser.add_argument("--dataset_subset", type=str, default="multifieldqa_en",
|
| 310 |
+
help="Dataset subset name (default: multifieldqa_en)")
|
| 311 |
+
parser.add_argument("--split", type=str, default="test",
|
| 312 |
+
help="Dataset split (default: test)")
|
| 313 |
+
parser.add_argument("--start_idx", type=int, default=0,
|
| 314 |
+
help="Starting index for processing (default: 0)")
|
| 315 |
+
parser.add_argument("--max_samples", type=int, default=-1,
|
| 316 |
+
help="Maximum number of samples to process (-1 for all, default: -1)")
|
| 317 |
+
|
| 318 |
+
args = parser.parse_args()
|
| 319 |
+
|
| 320 |
+
out_file = args.output
|
| 321 |
+
# Load dataset
|
| 322 |
+
longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]
|
| 323 |
+
|
| 324 |
+
print(f"Output file: {out_file}")
|
| 325 |
+
print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
|
| 326 |
+
print(f"Total samples: {len(longbench)}")
|
| 327 |
+
f1_score_total = 0
|
| 328 |
+
llm_judge_score_total = 0
|
| 329 |
+
count = 0
|
| 330 |
+
|
| 331 |
+
# Determine processing range
|
| 332 |
+
start_idx = args.start_idx
|
| 333 |
+
end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))
|
| 334 |
+
|
| 335 |
+
print(f"Processing samples from index {start_idx} to {end_idx-1}")
|
| 336 |
+
|
| 337 |
+
for idx in range(start_idx, end_idx):
|
| 338 |
+
example = longbench[idx]
|
| 339 |
+
question = example['input']
|
| 340 |
+
print(f"Question: {question}")
|
| 341 |
+
context = example['context']
|
| 342 |
+
correct_answer = example['answers'][0]
|
| 343 |
+
|
| 344 |
+
print(f"Processing example {idx + 1}:")
|
| 345 |
+
print(f"Correct Answer: {correct_answer}")
|
| 346 |
+
|
| 347 |
+
# Build prompts
|
| 348 |
+
prompt_with_context = build_prompt(context, question)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# Get answers using transformers pipelines
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
answer_with_context = get_answer_with_retry('deepseek-r1', prompt_with_context)
|
| 355 |
+
# Extract content after "Answer:" from answer_with_context
|
| 356 |
+
answer_with_context_simple = (
|
| 357 |
+
answer_with_context
|
| 358 |
+
.split("Answer:", 1)[-1] # First keep the part after Answer:
|
| 359 |
+
.split("Step-by-step Reasoning", 1)[0] # Then cut before Step-by-step Reasoning
|
| 360 |
+
.strip()
|
| 361 |
+
)
|
| 362 |
+
print(f"Answer with context: {answer_with_context_simple}")
|
| 363 |
+
result = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
|
| 364 |
+
print(f"Answer judge result: {result}")
|
| 365 |
+
if not result:
|
| 366 |
+
continue
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
answer_with_context = remove_think_tags(answer_with_context or "")
|
| 370 |
+
|
| 371 |
+
evidence = extract_all_bullet_passages(answer_with_context)
|
| 372 |
+
|
| 373 |
+
page_contents = []
|
| 374 |
+
if evidence:
|
| 375 |
+
count += 1
|
| 376 |
+
for ev in evidence:
|
| 377 |
+
snippet = ev['snippet']
|
| 378 |
+
result = retriveDoc(context, snippet)
|
| 379 |
+
# result["context"] is a set of Document objects
|
| 380 |
+
page_contents += [doc.page_content for doc in result]
|
| 381 |
+
|
| 382 |
+
unique_page_contents = list(dict.fromkeys(page_contents))
|
| 383 |
+
aggregated_content = "\n".join(unique_page_contents)
|
| 384 |
+
prompt_final = (
|
| 385 |
+
f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
|
| 386 |
+
f"Please only provide your answer. "
|
| 387 |
+
f"Your Answer:"
|
| 388 |
+
)
|
| 389 |
+
final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)
|
| 390 |
+
if judge_answer_with_api(question, correct_answer, final_answer):
|
| 391 |
+
print("correct")
|
| 392 |
+
else:
|
| 393 |
+
print("incorrect")
|
| 394 |
+
result_query = retriveDoc(context, question)
|
| 395 |
+
page_contents += [doc.page_content for doc in result_query]
|
| 396 |
+
unique_page_contents = list(dict.fromkeys(page_contents))
|
| 397 |
+
best = asyncio.run(
|
| 398 |
+
multi_proposal_pipeline(
|
| 399 |
+
question,
|
| 400 |
+
context,
|
| 401 |
+
unique_page_contents,
|
| 402 |
+
correct_answer
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
record = {
|
| 406 |
+
"question": question,
|
| 407 |
+
"answer": best["answer"],
|
| 408 |
+
"context": best["context"]
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
# Append one line of JSON each loop
|
| 412 |
+
with open(out_file, "a", encoding="utf-8") as fout:
|
| 413 |
+
fout.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
if __name__ == "__main__":
|
| 422 |
+
main()
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
|
random_alternative_answer.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import time
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
import argparse
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from nltk.tokenize import sent_tokenize
|
| 8 |
+
from utils.util import retriveDoc,compute_best_sentence_f1
|
| 9 |
+
from openai import OpenAI
|
| 10 |
+
import asyncio, json, torch, math
|
| 11 |
+
from typing import List, Tuple
|
| 12 |
+
# Hugging Face transformers related
|
| 13 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
| 14 |
+
from utils.metrics import qa_f1_score
|
| 15 |
+
from utils.llmjudge import judge_answer_with_api
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
client = OpenAI(
|
| 19 |
+
base_url=os.environ.get("OPENAI_BASE_URL"),
|
| 20 |
+
api_key=os.environ.get("OPENAI_API_KEY")
|
| 21 |
+
)
|
| 22 |
+
# Load models using transformers
|
| 23 |
+
|
| 24 |
+
tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
|
| 25 |
+
model1 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True,device_map="cuda:0",torch_dtype=torch.bfloat16)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
tok_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
|
| 29 |
+
model_qwen = AutoModelForCausalLM.from_pretrained(
|
| 30 |
+
"Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True,
|
| 31 |
+
device_map="cuda:1",torch_dtype=torch.bfloat16
|
| 32 |
+
).eval()
|
| 33 |
+
|
| 34 |
+
def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
|
| 35 |
+
"""
|
| 36 |
+
Use transformers model.generate method for inference with retry mechanism,
|
| 37 |
+
use chat template to format input, and strip the input prompt part through token-level slicing,
|
| 38 |
+
return the newly generated text.
|
| 39 |
+
"""
|
| 40 |
+
import time
|
| 41 |
+
for attempt in range(retries):
|
| 42 |
+
try:
|
| 43 |
+
# Convert original prompt to message format
|
| 44 |
+
messages = [{"role": "user", "content": prompt}]
|
| 45 |
+
|
| 46 |
+
# Try to use chat template to format input
|
| 47 |
+
try:
|
| 48 |
+
formatted_prompt = tokenizer.apply_chat_template(
|
| 49 |
+
messages,
|
| 50 |
+
tokenize=False,
|
| 51 |
+
add_generation_prompt=True
|
| 52 |
+
)
|
| 53 |
+
except Exception as e:
|
| 54 |
+
print(f"Unable to apply chat template: {e}, falling back to basic text input")
|
| 55 |
+
formatted_prompt = prompt # Fall back to original prompt as input
|
| 56 |
+
|
| 57 |
+
# Encode formatted prompt as model input tensor
|
| 58 |
+
model_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
|
| 59 |
+
|
| 60 |
+
# Call generate, the generated id sequence contains both prompt and subsequent generated text
|
| 61 |
+
generated_ids = model.generate(
|
| 62 |
+
**model_inputs,
|
| 63 |
+
max_new_tokens=max_new_tokens,
|
| 64 |
+
temperature=temperature,
|
| 65 |
+
top_p=top_p
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Calculate the token count corresponding to the prompt
|
| 69 |
+
input_length = model_inputs.input_ids.shape[1]
|
| 70 |
+
|
| 71 |
+
# Strip the prompt part from the front of the output, keeping only the newly added part
|
| 72 |
+
output_ids = generated_ids[0][input_length:]
|
| 73 |
+
|
| 74 |
+
# Decode generated text
|
| 75 |
+
answer = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
|
| 76 |
+
return answer
|
| 77 |
+
except Exception as e:
|
| 78 |
+
print(f"Error on attempt {attempt + 1}: {e}")
|
| 79 |
+
if attempt < retries - 1:
|
| 80 |
+
print(f"Retrying in {delay} seconds...")
|
| 81 |
+
time.sleep(delay)
|
| 82 |
+
else:
|
| 83 |
+
print("Max retries reached, skipping this request.")
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
def truncate_answer(answer):
|
| 87 |
+
"""Truncate answer, only take the part before the first period"""
|
| 88 |
+
return answer.split('.')[0].strip() if answer else "No answer"
|
| 89 |
+
|
| 90 |
+
def write_to_log(filename, data):
|
| 91 |
+
"""Write data to log file"""
|
| 92 |
+
with open(filename, 'a', encoding='utf-8') as file:
|
| 93 |
+
file.write(data + '\n')
|
| 94 |
+
|
| 95 |
+
def remove_think_tags(text: str) -> str:
|
| 96 |
+
"""Remove all <think> ... </think> blocks"""
|
| 97 |
+
return re.sub(r'<think>(.*?)</think>', '', text, flags=re.DOTALL).strip()
|
| 98 |
+
|
| 99 |
+
def build_prompt(context: str, question: str) -> str:
|
| 100 |
+
prompt = (
|
| 101 |
+
f"Answer the question based on the given passages. The following are the passages:\n"
|
| 102 |
+
f"{context}\n"
|
| 103 |
+
f"Answer the question based on the given passages.\n"
|
| 104 |
+
f"Question: {question}.\n"
|
| 105 |
+
f"Answer:\n"
|
| 106 |
+
f"Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) "
|
| 107 |
+
f"At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. "
|
| 108 |
+
f"Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence. "
|
| 109 |
+
f"Follow this format:\n"
|
| 110 |
+
f"Answer: [Your answer]\n"
|
| 111 |
+
f"Step-by-step Reasoning:\n"
|
| 112 |
+
f"1. [Reasoning step 1]\n"
|
| 113 |
+
f"[replaced by your reference content]\n"
|
| 114 |
+
f"2. [Reasoning step 2]\n"
|
| 115 |
+
f"[replaced by your reference content]\n"
|
| 116 |
+
)
|
| 117 |
+
return prompt
|
| 118 |
+
|
| 119 |
+
def extract_final_bullet_passage(answer_text: str):
|
| 120 |
+
reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
|
| 121 |
+
reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
|
| 122 |
+
if not reasoning_match:
|
| 123 |
+
return None, None
|
| 124 |
+
|
| 125 |
+
reasoning_text = reasoning_match.group(1).strip()
|
| 126 |
+
bullet_pattern = r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)"
|
| 127 |
+
bullets = re.findall(bullet_pattern, reasoning_text, flags=re.DOTALL)
|
| 128 |
+
if not bullets:
|
| 129 |
+
print("No bullet blocks found.")
|
| 130 |
+
return None, None
|
| 131 |
+
|
| 132 |
+
passage_pattern = re.compile(
|
| 133 |
+
r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
|
| 134 |
+
flags=re.DOTALL
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
for bullet in reversed(bullets):
|
| 138 |
+
matches = passage_pattern.findall(bullet)
|
| 139 |
+
if matches:
|
| 140 |
+
last_match = matches[-1]
|
| 141 |
+
passage_number = last_match[0]
|
| 142 |
+
quoted_snippet = last_match[2]
|
| 143 |
+
non_quoted_snippet = last_match[3]
|
| 144 |
+
snippet = non_quoted_snippet.strip() if non_quoted_snippet.strip() else quoted_snippet.strip()
|
| 145 |
+
return passage_number, snippet
|
| 146 |
+
|
| 147 |
+
return None, None
|
| 148 |
+
|
| 149 |
+
def extract_all_bullet_passages(answer_text: str):
|
| 150 |
+
reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
|
| 151 |
+
reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
|
| 152 |
+
if not reasoning_match:
|
| 153 |
+
return []
|
| 154 |
+
|
| 155 |
+
reasoning_text = reasoning_match.group(1).strip()
|
| 156 |
+
bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
|
| 157 |
+
bullets = bullet_pattern.findall(reasoning_text)
|
| 158 |
+
if not bullets:
|
| 159 |
+
return []
|
| 160 |
+
|
| 161 |
+
results = []
|
| 162 |
+
for bullet_index, bullet_text in enumerate(bullets, start=1):
|
| 163 |
+
results.append({
|
| 164 |
+
'bullet_index': bullet_index,
|
| 165 |
+
'snippet': bullet_text.strip()
|
| 166 |
+
})
|
| 167 |
+
print(results)
|
| 168 |
+
return results
|
| 169 |
+
|
| 170 |
+
def extract_evidence(answer_text: str):
|
| 171 |
+
reasoning_pattern = r"(?i)Evidence\s*(.*)"
|
| 172 |
+
reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
|
| 173 |
+
if not reasoning_match:
|
| 174 |
+
return []
|
| 175 |
+
|
| 176 |
+
reasoning_text = reasoning_match.group(1).strip()
|
| 177 |
+
|
| 178 |
+
# Extract all bullet segments
|
| 179 |
+
bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
|
| 180 |
+
bullets = bullet_pattern.findall(reasoning_text)
|
| 181 |
+
if not bullets:
|
| 182 |
+
return []
|
| 183 |
+
|
| 184 |
+
# Find the index of the first bullet starting with 1.
|
| 185 |
+
start_index = -1
|
| 186 |
+
for i, bullet in enumerate(bullets):
|
| 187 |
+
if bullet.strip().startswith("1."):
|
| 188 |
+
start_index = i
|
| 189 |
+
break
|
| 190 |
+
|
| 191 |
+
if start_index == -1:
|
| 192 |
+
return [] # No valid starting bullet
|
| 193 |
+
|
| 194 |
+
# Only keep the part starting from the first valid bullet
|
| 195 |
+
bullets = bullets[start_index:]
|
| 196 |
+
|
| 197 |
+
results = []
|
| 198 |
+
for bullet_index, bullet_text in enumerate(bullets, start=1):
|
| 199 |
+
results.append({
|
| 200 |
+
'bullet_index': bullet_index,
|
| 201 |
+
'snippet': bullet_text.strip()
|
| 202 |
+
})
|
| 203 |
+
return results
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_answer_with_retry(model, prompt, retries=3, delay=5):
|
| 207 |
+
"""Call the model to get the answer based on the prompt, with retry on failure."""
|
| 208 |
+
for attempt in range(retries):
|
| 209 |
+
try:
|
| 210 |
+
completion = client.chat.completions.create(
|
| 211 |
+
model=model,
|
| 212 |
+
messages=[{'role': 'user', 'content': prompt}]
|
| 213 |
+
)
|
| 214 |
+
return completion.choices[0].message.content.strip()
|
| 215 |
+
except Exception as e:
|
| 216 |
+
print(f"Error on attempt {attempt + 1}: {e}")
|
| 217 |
+
if attempt < retries - 1:
|
| 218 |
+
print(f"Retrying in {delay} seconds...")
|
| 219 |
+
time.sleep(delay)
|
| 220 |
+
else:
|
| 221 |
+
print("Max retries reached, skipping this request.")
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
def extract_json_from_gpt_response(text: str) -> dict | None:
|
| 225 |
+
"""
|
| 226 |
+
Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict.
|
| 227 |
+
"""
|
| 228 |
+
# Try to find a ```json … ``` block first
|
| 229 |
+
m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
|
| 230 |
+
if not m:
|
| 231 |
+
# Fallback: any ``` … ``` block that looks like JSON
|
| 232 |
+
m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
|
| 233 |
+
if not m:
|
| 234 |
+
# Lastly, maybe the model just spit raw JSON without fences
|
| 235 |
+
m = re.search(r"(\{.*?\})", text, flags=re.DOTALL)
|
| 236 |
+
if not m:
|
| 237 |
+
return None
|
| 238 |
+
|
| 239 |
+
json_str = m.group(1)
|
| 240 |
+
try:
|
| 241 |
+
return json.loads(json_str)
|
| 242 |
+
except json.JSONDecodeError:
|
| 243 |
+
# clean up trailing commas, etc.
|
| 244 |
+
cleaned = re.sub(r",\s*([\]}])", r"\1", json_str)
|
| 245 |
+
try:
|
| 246 |
+
return json.loads(cleaned)
|
| 247 |
+
except json.JSONDecodeError:
|
| 248 |
+
return None
|
| 249 |
+
|
| 250 |
+
async def random_alternative_answer(
|
| 251 |
+
question: str,
|
| 252 |
+
original_context: str,
|
| 253 |
+
unique_sents: List[str],
|
| 254 |
+
correct_answer: str
|
| 255 |
+
) -> dict:
|
| 256 |
+
"""Generate random alternative answer and modified evidence"""
|
| 257 |
+
|
| 258 |
+
# Construct GPT-4o prompt
|
| 259 |
+
numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
|
| 260 |
+
prompt = (
|
| 261 |
+
"You are a creative assistant. Given the question below and the original answer, propose a plausible alternative answer that is **different** from the original but still reasonable. "
|
| 262 |
+
"Then rewrite the provided sentences to support your alternative answer. When rewriting each sentence, modify only the parts necessary to support the alternative answer. "
|
| 263 |
+
"Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. "
|
| 264 |
+
"Output must be strictly in the specified JSON format, with no additional text.\n"
|
| 265 |
+
'{\n'
|
| 266 |
+
' "answer": "<your alternative answer here, just provide the answer phrase, no need for complete sentence>",\n'
|
| 267 |
+
' "revised": [\n'
|
| 268 |
+
' "<rewritten sentence 1>",\n'
|
| 269 |
+
' "<rewritten sentence 2>",\n'
|
| 270 |
+
' ...\n'
|
| 271 |
+
' ]\n'
|
| 272 |
+
'}\n\n'
|
| 273 |
+
f"Question:\n{question}\n\n"
|
| 274 |
+
f"Original answer:\n{correct_answer}\n\n"
|
| 275 |
+
f"Sentences to rewrite:\n{numbered}"
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
print(f"[Alternative Answer] Generating prompt: {prompt}")
|
| 279 |
+
|
| 280 |
+
rsp = client.chat.completions.create(
|
| 281 |
+
model="gpt-4o", temperature=0.7,
|
| 282 |
+
messages=[{"role":"user","content":prompt}]
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
js = extract_json_from_gpt_response(rsp.choices[0].message.content)
|
| 286 |
+
if not js:
|
| 287 |
+
print("[Alternative Answer] Failed to parse JSON")
|
| 288 |
+
return {"context": original_context, "answer": "Failed to generate alternative"}
|
| 289 |
+
|
| 290 |
+
revised = js["revised"] # List[str]
|
| 291 |
+
alternative = js["answer"] # Alternative answer
|
| 292 |
+
|
| 293 |
+
# Create new context
|
| 294 |
+
new_ctx = original_context
|
| 295 |
+
for old, new in zip(unique_sents, revised):
|
| 296 |
+
new_ctx = new_ctx.replace(old, new)
|
| 297 |
+
|
| 298 |
+
return {"context": new_ctx, "answer": alternative}
|
| 299 |
+
|
| 300 |
+
def main():
|
| 301 |
+
# Parse command line arguments
|
| 302 |
+
parser = argparse.ArgumentParser(description="LastingBench random alternative answer generation")
|
| 303 |
+
parser.add_argument("--output", "-o", type=str, default="output_random.jsonl",
|
| 304 |
+
help="Output JSONL file path (default: output_random.jsonl)")
|
| 305 |
+
parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
|
| 306 |
+
help="Dataset repository name (default: THUDM/LongBench)")
|
| 307 |
+
parser.add_argument("--dataset_subset", type=str, default="hotpotqa",
|
| 308 |
+
help="Dataset subset name (default: hotpotqa)")
|
| 309 |
+
parser.add_argument("--split", type=str, default="test",
|
| 310 |
+
help="Dataset split (default: test)")
|
| 311 |
+
parser.add_argument("--start_idx", type=int, default=0,
|
| 312 |
+
help="Starting index for processing (default: 0)")
|
| 313 |
+
parser.add_argument("--max_samples", type=int, default=-1,
|
| 314 |
+
help="Maximum number of samples to process (-1 for all, default: -1)")
|
| 315 |
+
|
| 316 |
+
args = parser.parse_args()
|
| 317 |
+
|
| 318 |
+
out_file = args.output
|
| 319 |
+
# Load dataset
|
| 320 |
+
longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]
|
| 321 |
+
|
| 322 |
+
print(f"Output file: {out_file}")
|
| 323 |
+
print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
|
| 324 |
+
print(f"Total samples: {len(longbench)}")
|
| 325 |
+
|
| 326 |
+
count = 0
|
| 327 |
+
|
| 328 |
+
# Determine processing range
|
| 329 |
+
start_idx = args.start_idx
|
| 330 |
+
end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))
|
| 331 |
+
|
| 332 |
+
print(f"Processing samples from index {start_idx} to {end_idx-1}")
|
| 333 |
+
|
| 334 |
+
for idx in range(start_idx, end_idx):
|
| 335 |
+
example = longbench[idx]
|
| 336 |
+
question = example['input']
|
| 337 |
+
print(f"Question: {question}")
|
| 338 |
+
context = example['context']
|
| 339 |
+
correct_answer = example['answers'][0]
|
| 340 |
+
|
| 341 |
+
print(f"Processing example {idx + 1}:")
|
| 342 |
+
print(f"Correct Answer: {correct_answer}")
|
| 343 |
+
|
| 344 |
+
# Build prompts
|
| 345 |
+
prompt_with_context = build_prompt(context, question)
|
| 346 |
+
|
| 347 |
+
# Get answers using transformers pipelines
|
| 348 |
+
answer_with_context = get_answer_with_retry('deepseek-r1', prompt_with_context)
|
| 349 |
+
|
| 350 |
+
# Extract content after "Answer:" from answer_with_context
|
| 351 |
+
answer_with_context_simple = (
|
| 352 |
+
answer_with_context
|
| 353 |
+
.split("Answer:", 1)[-1] # First keep the part after Answer:
|
| 354 |
+
.split("Step-by-step Reasoning", 1)[0] # Then cut before Step-by-step Reasoning
|
| 355 |
+
.strip()
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
print(f"Answer with context: {answer_with_context_simple}")
|
| 359 |
+
result = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
|
| 360 |
+
print(f"Answer judge result: {result}")
|
| 361 |
+
|
| 362 |
+
if not result:
|
| 363 |
+
continue
|
| 364 |
+
|
| 365 |
+
answer_with_context = remove_think_tags(answer_with_context or "")
|
| 366 |
+
evidence = extract_all_bullet_passages(answer_with_context)
|
| 367 |
+
|
| 368 |
+
page_contents = []
|
| 369 |
+
if evidence:
|
| 370 |
+
count += 1
|
| 371 |
+
for ev in evidence:
|
| 372 |
+
snippet = ev['snippet']
|
| 373 |
+
result = retriveDoc(context, snippet)
|
| 374 |
+
# result["context"] is a set of Document objects
|
| 375 |
+
page_contents += [doc.page_content for doc in result]
|
| 376 |
+
|
| 377 |
+
unique_page_contents = list(dict.fromkeys(page_contents))
|
| 378 |
+
aggregated_content = "\n".join(unique_page_contents)
|
| 379 |
+
|
| 380 |
+
prompt_final = (
|
| 381 |
+
f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
|
| 382 |
+
f"Please only provide your answer. "
|
| 383 |
+
f"Your Answer:"
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)
|
| 387 |
+
|
| 388 |
+
if judge_answer_with_api(question, correct_answer, final_answer):
|
| 389 |
+
print("correct")
|
| 390 |
+
else:
|
| 391 |
+
print("incorrect")
|
| 392 |
+
result_query = retriveDoc(context, question)
|
| 393 |
+
page_contents += [doc.page_content for doc in result_query]
|
| 394 |
+
|
| 395 |
+
unique_page_contents = list(dict.fromkeys(page_contents))
|
| 396 |
+
|
| 397 |
+
# Generate random alternative answer instead of selecting the highest ppl answer
|
| 398 |
+
alternative = asyncio.run(
|
| 399 |
+
random_alternative_answer(
|
| 400 |
+
question,
|
| 401 |
+
context,
|
| 402 |
+
unique_page_contents,
|
| 403 |
+
correct_answer
|
| 404 |
+
)
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
record = {
|
| 408 |
+
"question": question,
|
| 409 |
+
"answer": alternative["answer"],
|
| 410 |
+
"context": alternative["context"]
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
# Append one line of JSON each loop
|
| 414 |
+
with open(out_file, "a", encoding="utf-8") as fout:
|
| 415 |
+
fout.write(json.dumps(record, ensure_ascii=False) + "\n")
|
| 416 |
+
|
| 417 |
+
if __name__ == "__main__":
|
| 418 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core ML and NLP libraries
|
| 2 |
+
torch>=2.6.0
|
| 3 |
+
transformers>=4.51.0
|
| 4 |
+
datasets>=2.10.0
|
| 5 |
+
tokenizers>=0.13.0
|
| 6 |
+
|
| 7 |
+
# Haystack for document processing and retrieval
|
| 8 |
+
haystack-ai>=2.0.0
|
| 9 |
+
|
| 10 |
+
# OpenAI API client
|
| 11 |
+
openai == 1.84.0
|
| 12 |
+
|
| 13 |
+
# Data processing and analysis
|
| 14 |
+
pandas>=1.5.0
|
| 15 |
+
numpy>=1.24.0
|
| 16 |
+
|
| 17 |
+
# Natural language processing
|
| 18 |
+
nltk>=3.8.0
|
| 19 |
+
jieba
|
| 20 |
+
fuzzywuzzy
|
| 21 |
+
rouge
|
| 22 |
+
rank_bm25
|
| 23 |
+
langchain_text_splitters
|
| 24 |
+
langchain_community
|
| 25 |
+
langchain_openai
|
| 26 |
+
|
| 27 |
+
# Visualization
|
| 28 |
+
matplotlib>=3.6.0
|
| 29 |
+
|
| 30 |
+
# Async processing
|
| 31 |
+
asyncio-throttle>=1.0.0
|
| 32 |
+
|
| 33 |
+
# Optional: Additional ML utilities
|
| 34 |
+
scikit-learn>=1.2.0
|
| 35 |
+
tqdm>=4.64.0
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
training_result/training_loss_antifact_llama.csv
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Step,Loss,Epoch
|
| 2 |
+
1,1.8891,1
|
| 3 |
+
2,1.9053,1
|
| 4 |
+
3,1.8574,1
|
| 5 |
+
4,1.9114,1
|
| 6 |
+
5,1.8351,1
|
| 7 |
+
6,1.8981,2
|
| 8 |
+
7,1.9048,2
|
| 9 |
+
8,1.8649,2
|
| 10 |
+
9,1.877,2
|
| 11 |
+
10,1.8449,2
|
| 12 |
+
11,1.8986,3
|
| 13 |
+
12,1.8799,3
|
| 14 |
+
13,1.8378,3
|
| 15 |
+
14,1.9004,3
|
| 16 |
+
15,1.8675,3
|
| 17 |
+
16,1.8812,4
|
| 18 |
+
17,1.8635,4
|
| 19 |
+
18,1.8977,4
|
| 20 |
+
19,1.8393,4
|
| 21 |
+
20,1.9017,4
|
| 22 |
+
21,1.8482,5
|
| 23 |
+
22,1.8353,5
|
| 24 |
+
23,1.8514,5
|
| 25 |
+
24,1.9189,5
|
| 26 |
+
25,1.8596,5
|
| 27 |
+
26,1.8672,6
|
| 28 |
+
27,1.8421,6
|
| 29 |
+
28,1.848,6
|
| 30 |
+
29,1.8762,6
|
| 31 |
+
30,1.8964,6
|
| 32 |
+
31,1.8663,7
|
| 33 |
+
32,1.8491,7
|
| 34 |
+
33,1.8637,7
|
| 35 |
+
34,1.8403,7
|
| 36 |
+
35,1.8842,7
|
| 37 |
+
36,1.827,8
|
| 38 |
+
37,1.8486,8
|
| 39 |
+
38,1.8671,8
|
| 40 |
+
39,1.8921,8
|
| 41 |
+
40,1.7564,8
|
training_result/training_loss_antifact_qwen38.csv
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Step,Loss,Epoch
|
| 2 |
+
1,2.3492,1
|
| 3 |
+
2,2.3786,1
|
| 4 |
+
3,2.314,1
|
| 5 |
+
4,2.3676,1
|
| 6 |
+
5,2.2953,1
|
| 7 |
+
6,2.3573,2
|
| 8 |
+
7,2.3483,2
|
| 9 |
+
8,2.3144,2
|
| 10 |
+
9,2.321,2
|
| 11 |
+
10,2.301,2
|
| 12 |
+
11,2.326,3
|
| 13 |
+
12,2.2982,3
|
| 14 |
+
13,2.2573,3
|
| 15 |
+
14,2.2941,3
|
| 16 |
+
15,2.2627,3
|
| 17 |
+
16,2.2579,4
|
| 18 |
+
17,2.2519,4
|
| 19 |
+
18,2.2641,4
|
| 20 |
+
19,2.2128,4
|
| 21 |
+
20,2.2421,4
|
| 22 |
+
21,2.1929,5
|
| 23 |
+
22,2.1757,5
|
| 24 |
+
23,2.1914,5
|
| 25 |
+
24,2.2761,5
|
| 26 |
+
25,2.2079,5
|
| 27 |
+
26,2.1893,6
|
| 28 |
+
27,2.1754,6
|
| 29 |
+
28,2.174,6
|
| 30 |
+
29,2.2038,6
|
| 31 |
+
30,2.2008,6
|
| 32 |
+
31,2.1921,7
|
| 33 |
+
32,2.1533,7
|
| 34 |
+
33,2.1783,7
|
| 35 |
+
34,2.1534,7
|
| 36 |
+
35,2.2158,7
|
| 37 |
+
36,2.1322,8
|
| 38 |
+
37,2.1752,8
|
| 39 |
+
38,2.18,8
|
| 40 |
+
39,2.1953,8
|
| 41 |
+
40,2.0122,8
|
training_result/training_loss_llama.csv
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Step,Loss,Epoch
|
| 2 |
+
1,1.858,1
|
| 3 |
+
2,1.8676,1
|
| 4 |
+
3,1.8243,1
|
| 5 |
+
4,1.8815,1
|
| 6 |
+
5,1.7409,1
|
| 7 |
+
6,1.8664,2
|
| 8 |
+
7,1.8736,2
|
| 9 |
+
8,1.8304,2
|
| 10 |
+
9,1.8358,2
|
| 11 |
+
10,1.8119,2
|
| 12 |
+
11,1.8637,3
|
| 13 |
+
12,1.8456,3
|
| 14 |
+
13,1.7986,3
|
| 15 |
+
14,1.8643,3
|
| 16 |
+
15,1.8602,3
|
| 17 |
+
16,1.8451,4
|
| 18 |
+
17,1.8329,4
|
| 19 |
+
18,1.8589,4
|
| 20 |
+
19,1.8064,4
|
| 21 |
+
20,1.8571,4
|
| 22 |
+
21,1.8139,5
|
| 23 |
+
22,1.8059,5
|
| 24 |
+
23,1.8097,5
|
| 25 |
+
24,1.886,5
|
| 26 |
+
25,1.8094,5
|
| 27 |
+
26,1.8318,6
|
| 28 |
+
27,1.8085,6
|
| 29 |
+
28,1.8128,6
|
| 30 |
+
29,1.842,6
|
| 31 |
+
30,1.8477,6
|
| 32 |
+
31,1.8348,7
|
| 33 |
+
32,1.8133,7
|
| 34 |
+
33,1.8263,7
|
| 35 |
+
34,1.8028,7
|
| 36 |
+
35,1.8589,7
|
| 37 |
+
36,1.7959,8
|
| 38 |
+
37,1.8054,8
|
| 39 |
+
38,1.8296,8
|
| 40 |
+
39,1.8629,8
|
| 41 |
+
40,1.7241,8
|
training_result/training_loss_phi4.csv
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Step,Loss,Epoch
|
| 2 |
+
1,1.8667,1
|
| 3 |
+
2,1.8693,1
|
| 4 |
+
3,1.8411,1
|
| 5 |
+
4,1.9009,1
|
| 6 |
+
5,1.7747,1
|
| 7 |
+
6,1.8776,2
|
| 8 |
+
7,1.8859,2
|
| 9 |
+
8,1.8414,2
|
| 10 |
+
9,1.8464,2
|
| 11 |
+
10,1.8524,2
|
| 12 |
+
11,1.8784,3
|
| 13 |
+
12,1.8608,3
|
| 14 |
+
13,1.8113,3
|
| 15 |
+
14,1.8754,3
|
| 16 |
+
15,1.8512,3
|
| 17 |
+
16,1.8578,4
|
| 18 |
+
17,1.8542,4
|
| 19 |
+
18,1.8738,4
|
| 20 |
+
19,1.8203,4
|
| 21 |
+
20,1.781,4
|
| 22 |
+
21,1.8297,5
|
| 23 |
+
22,1.811,5
|
| 24 |
+
23,1.8162,5
|
| 25 |
+
24,1.9074,5
|
| 26 |
+
25,1.8363,5
|
| 27 |
+
26,1.8388,6
|
| 28 |
+
27,1.8351,6
|
| 29 |
+
28,1.8299,6
|
| 30 |
+
29,1.8478,6
|
| 31 |
+
30,1.8644,6
|
| 32 |
+
31,1.8573,7
|
| 33 |
+
32,1.8156,7
|
| 34 |
+
33,1.8426,7
|
| 35 |
+
34,1.824,7
|
| 36 |
+
35,1.8796,7
|
| 37 |
+
36,1.8153,8
|
| 38 |
+
37,1.8269,8
|
| 39 |
+
38,1.8404,8
|
| 40 |
+
39,1.8772,8
|
| 41 |
+
40,1.7365,8
|
training_result/training_loss_phi4_antifact.csv
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Step,Loss,Epoch
|
| 2 |
+
1,1.8977,1
|
| 3 |
+
2,1.9077,1
|
| 4 |
+
3,1.8738,1
|
| 5 |
+
4,1.9326,1
|
| 6 |
+
5,1.8628,1
|
| 7 |
+
6,1.9098,2
|
| 8 |
+
7,1.9175,2
|
| 9 |
+
8,1.876,2
|
| 10 |
+
9,1.8882,2
|
| 11 |
+
10,1.8819,2
|
| 12 |
+
11,1.9128,3
|
| 13 |
+
12,1.8947,3
|
| 14 |
+
13,1.851,3
|
| 15 |
+
14,1.9117,3
|
| 16 |
+
15,1.8586,3
|
| 17 |
+
16,1.8941,4
|
| 18 |
+
17,1.8842,4
|
| 19 |
+
18,1.9115,4
|
| 20 |
+
19,1.8528,4
|
| 21 |
+
20,1.8236,4
|
| 22 |
+
21,1.8639,5
|
| 23 |
+
22,1.8396,5
|
| 24 |
+
23,1.8569,5
|
| 25 |
+
24,1.9398,5
|
| 26 |
+
25,1.8856,5
|
| 27 |
+
26,1.8731,6
|
| 28 |
+
27,1.8678,6
|
| 29 |
+
28,1.8652,6
|
| 30 |
+
29,1.8808,6
|
| 31 |
+
30,1.914,6
|
| 32 |
+
31,1.8884,7
|
| 33 |
+
32,1.851,7
|
| 34 |
+
33,1.879,7
|
| 35 |
+
34,1.8604,7
|
| 36 |
+
35,1.9046,7
|
| 37 |
+
36,1.8455,8
|
| 38 |
+
37,1.8689,8
|
| 39 |
+
38,1.8771,8
|
| 40 |
+
39,1.9062,8
|
| 41 |
+
40,1.77,8
|
training_result/training_loss_qwen38.csv
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Step,Loss,Epoch
|
| 2 |
+
1,2.3167,1
|
| 3 |
+
2,2.3344,1
|
| 4 |
+
3,2.2796,1
|
| 5 |
+
4,2.3293,1
|
| 6 |
+
5,2.1972,1
|
| 7 |
+
6,2.3196,2
|
| 8 |
+
7,2.3094,2
|
| 9 |
+
8,2.278,2
|
| 10 |
+
9,2.2759,2
|
| 11 |
+
10,2.272,2
|
| 12 |
+
11,2.2851,3
|
| 13 |
+
12,2.2595,3
|
| 14 |
+
13,2.2142,3
|
| 15 |
+
14,2.2586,3
|
| 16 |
+
15,2.2535,3
|
| 17 |
+
16,2.2193,4
|
| 18 |
+
17,2.2181,4
|
| 19 |
+
18,2.2252,4
|
| 20 |
+
19,2.1788,4
|
| 21 |
+
20,2.1908,4
|
| 22 |
+
21,2.1574,5
|
| 23 |
+
22,2.1469,5
|
| 24 |
+
23,2.1484,5
|
| 25 |
+
24,2.2405,5
|
| 26 |
+
25,2.1602,5
|
| 27 |
+
26,2.1534,6
|
| 28 |
+
27,2.1435,6
|
| 29 |
+
28,2.1369,6
|
| 30 |
+
29,2.1685,6
|
| 31 |
+
30,2.1502,6
|
| 32 |
+
31,2.1597,7
|
| 33 |
+
32,2.1169,7
|
| 34 |
+
33,2.1408,7
|
| 35 |
+
34,2.1159,7
|
| 36 |
+
35,2.1897,7
|
| 37 |
+
36,2.1013,8
|
| 38 |
+
37,2.1304,8
|
| 39 |
+
38,2.1441,8
|
| 40 |
+
39,2.1645,8
|
| 41 |
+
40,1.9816,8
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .metrics import *
|
| 2 |
+
|
utils/convert.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import argparse
|
| 3 |
+
|
| 4 |
+
def main():
|
| 5 |
+
parser = argparse.ArgumentParser(description='Convert and merge JSONL files with question-answer mappings')
|
| 6 |
+
parser.add_argument('orig_path', help='Path to the original JSONL file')
|
| 7 |
+
parser.add_argument('out_path', help='Path to the output JSONL file')
|
| 8 |
+
parser.add_argument('mapping_paths', nargs='+', help='Path(s) to mapping JSONL file(s)')
|
| 9 |
+
|
| 10 |
+
args = parser.parse_args()
|
| 11 |
+
|
| 12 |
+
# Original data file paths from command line arguments
|
| 13 |
+
orig_path = args.orig_path
|
| 14 |
+
out_path = args.out_path
|
| 15 |
+
mapping_paths = args.mapping_paths
|
| 16 |
+
|
| 17 |
+
# Step 1: Build question -> {context, answers} mapping
|
| 18 |
+
mapping = {}
|
| 19 |
+
for mp in mapping_paths:
|
| 20 |
+
with open(mp, 'r', encoding='utf-8') as f_map:
|
| 21 |
+
for idx, line in enumerate(f_map):
|
| 22 |
+
obj = json.loads(line)
|
| 23 |
+
q = obj.get("question")
|
| 24 |
+
if q is None:
|
| 25 |
+
continue
|
| 26 |
+
# Ensure we get the context
|
| 27 |
+
ctx = obj.get("context", "")
|
| 28 |
+
# Some files have "answer" field, some have "answers"
|
| 29 |
+
raw_ans = obj.get("answers", obj.get("answer", []))
|
| 30 |
+
# Normalize answer(s) to list format
|
| 31 |
+
if isinstance(raw_ans, list):
|
| 32 |
+
ans = raw_ans
|
| 33 |
+
else:
|
| 34 |
+
ans = [raw_ans]
|
| 35 |
+
# If the same question appears in multiple mapping files, later ones will overwrite earlier ones
|
| 36 |
+
mapping[q] = {"context": ctx, "answers": ans}
|
| 37 |
+
|
| 38 |
+
# Step 2: Read original file, perform replacement and write output
|
| 39 |
+
with open(orig_path, 'r', encoding='utf-8') as f_in, \
|
| 40 |
+
open(out_path, 'w', encoding='utf-8') as f_out:
|
| 41 |
+
for line in f_in:
|
| 42 |
+
item = json.loads(line)
|
| 43 |
+
inp = item.get("input")
|
| 44 |
+
if inp in mapping:
|
| 45 |
+
item["context"] = mapping[inp]["context"]
|
| 46 |
+
item["answers"] = mapping[inp]["answers"]
|
| 47 |
+
f_out.write(json.dumps(item, ensure_ascii=False) + "\n")
|
| 48 |
+
|
| 49 |
+
print(f"Merge completed, output file: {out_path}")
|
| 50 |
+
|
| 51 |
+
if __name__ == "__main__":
|
| 52 |
+
main()
|
utils/draw.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import matplotlib as mpl
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
mpl.rcParams['font.family'] = 'serif'
|
| 8 |
+
mpl.rcParams['font.serif'] = ['Georgia']
|
| 9 |
+
mpl.rcParams['font.size'] = 20
|
| 10 |
+
mpl.rcParams['axes.titlesize']= 20
|
| 11 |
+
mpl.rcParams['axes.labelsize']= 18
|
| 12 |
+
mpl.rcParams['xtick.labelsize']=16
|
| 13 |
+
mpl.rcParams['ytick.labelsize']=16
|
| 14 |
+
# no legend, so no need to set legend.fontsize
|
| 15 |
+
|
| 16 |
+
def plot_two_loss_curves(
|
| 17 |
+
csv_file1,
|
| 18 |
+
csv_file2,
|
| 19 |
+
title="Loss Comparison on Qwen3-8B",
|
| 20 |
+
dataset1_name="Dataset1",
|
| 21 |
+
dataset2_name="Dataset2"
|
| 22 |
+
):
|
| 23 |
+
# Read CSV files
|
| 24 |
+
df1 = pd.read_csv(csv_file1)
|
| 25 |
+
df2 = pd.read_csv(csv_file2)
|
| 26 |
+
|
| 27 |
+
# Check columns
|
| 28 |
+
for df, path in ((df1, csv_file1), (df2, csv_file2)):
|
| 29 |
+
if 'Step' not in df.columns or 'Loss' not in df.columns:
|
| 30 |
+
raise ValueError(f"Missing 'Step' or 'Loss' columns in {path}")
|
| 31 |
+
|
| 32 |
+
# Create figure
|
| 33 |
+
plt.figure(figsize=(12, 8))
|
| 34 |
+
|
| 35 |
+
# Plot two lines with softer colors
|
| 36 |
+
plt.plot(df1['Step'], df1['Loss'],
|
| 37 |
+
color='#1f77b4', linewidth=2.5) # steel blue
|
| 38 |
+
plt.plot(df2['Step'], df2['Loss'],
|
| 39 |
+
color='#2ca02c', linewidth=2.5) # medium sea green
|
| 40 |
+
|
| 41 |
+
# Title and labels
|
| 42 |
+
plt.title(title, fontweight='bold')
|
| 43 |
+
plt.xlabel('Steps', fontweight='bold')
|
| 44 |
+
plt.ylabel('Loss', fontweight='bold')
|
| 45 |
+
|
| 46 |
+
# Grid
|
| 47 |
+
plt.grid(True, linestyle='--', alpha=0.7)
|
| 48 |
+
|
| 49 |
+
# Layout
|
| 50 |
+
plt.tight_layout(pad=3.0)
|
| 51 |
+
|
| 52 |
+
# Save
|
| 53 |
+
plt.savefig('loss_comparison_qwen38b.svg', format='svg')
|
| 54 |
+
plt.savefig('loss_comparison.png', dpi=300)
|
| 55 |
+
|
| 56 |
+
# Display
|
| 57 |
+
plt.show()
|
| 58 |
+
|
| 59 |
+
print("Saved: loss_comparison.svg, loss_comparison.png")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def main():
|
| 63 |
+
parser = argparse.ArgumentParser(description='Plot comparison of two training loss curves')
|
| 64 |
+
parser.add_argument('csv_file1', help='Path to the first CSV file')
|
| 65 |
+
parser.add_argument('csv_file2', help='Path to the second CSV file')
|
| 66 |
+
parser.add_argument('--title', default='Training Loss Comparison', help='Title for the plot')
|
| 67 |
+
parser.add_argument('--dataset1-name', default='Original Dataset', help='Name for the first dataset')
|
| 68 |
+
parser.add_argument('--dataset2-name', default='Revised Dataset', help='Name for the second dataset')
|
| 69 |
+
|
| 70 |
+
args = parser.parse_args()
|
| 71 |
+
|
| 72 |
+
plot_two_loss_curves(
|
| 73 |
+
args.csv_file1,
|
| 74 |
+
args.csv_file2,
|
| 75 |
+
title=args.title,
|
| 76 |
+
dataset1_name=args.dataset1_name,
|
| 77 |
+
dataset2_name=args.dataset2_name
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
if __name__ == "__main__":
|
| 82 |
+
main()
|
utils/llmjudge.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def judge_answer_with_api(question, target, answer):
|
| 7 |
+
from openai import OpenAI
|
| 8 |
+
|
| 9 |
+
client = OpenAI(
|
| 10 |
+
base_url=os.environ.get("OPENAI_BASE_URL"),
|
| 11 |
+
api_key=os.environ.get("OPENAI_API_KEY")
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
prompt = (
|
| 15 |
+
"You will be given a question, a target answer (maybe a list of all possible answers), "
|
| 16 |
+
"and a generated answer. Please judge whether the generated answer is correct. "
|
| 17 |
+
"If it is correct, return 'True'. If it is incorrect, return 'False'.\n"
|
| 18 |
+
f"Question: {question}\n"
|
| 19 |
+
f"Target Answer: {target}\n"
|
| 20 |
+
f"Generated Answer: {answer}\n"
|
| 21 |
+
"Please return only 'True' or 'False', without any other text."
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
response = client.chat.completions.create(
|
| 26 |
+
model="gpt-4o-mini-2024-07-18",
|
| 27 |
+
messages=[{"role": "user", "content": prompt}],
|
| 28 |
+
temperature=0,
|
| 29 |
+
max_tokens=1
|
| 30 |
+
)
|
| 31 |
+
except Exception as e:
|
| 32 |
+
logging.error("API call failed: %s", str(e))
|
| 33 |
+
return 0 # Return default value
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
result = response.choices[0].message.content.strip()
|
| 37 |
+
except (AttributeError, IndexError) as e:
|
| 38 |
+
logging.error("Error parsing API response: %s", str(e))
|
| 39 |
+
return 0
|
| 40 |
+
|
| 41 |
+
if result == "True":
|
| 42 |
+
return 1
|
| 43 |
+
elif result == "False":
|
| 44 |
+
return 0
|
| 45 |
+
else:
|
| 46 |
+
logging.warning("Abnormal response format: %s", result)
|
| 47 |
+
return 0
|
| 48 |
+
|
| 49 |
+
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import string
|
| 3 |
+
|
| 4 |
+
import jieba
|
| 5 |
+
from fuzzywuzzy import fuzz
|
| 6 |
+
import difflib
|
| 7 |
+
|
| 8 |
+
from typing import List
|
| 9 |
+
from collections import Counter
|
| 10 |
+
from rouge import Rouge
|
| 11 |
+
|
| 12 |
+
def normalize_answer(s):
|
| 13 |
+
"""Lower text and remove punctuation, articles and extra whitespace."""
|
| 14 |
+
|
| 15 |
+
def remove_articles(text):
|
| 16 |
+
return re.sub(r"\b(a|an|the)\b", " ", text)
|
| 17 |
+
|
| 18 |
+
def white_space_fix(text):
|
| 19 |
+
return " ".join(text.split())
|
| 20 |
+
|
| 21 |
+
def remove_punc(text):
|
| 22 |
+
exclude = set(string.punctuation)
|
| 23 |
+
return "".join(ch for ch in text if ch not in exclude)
|
| 24 |
+
|
| 25 |
+
def lower(text):
|
| 26 |
+
return text.lower()
|
| 27 |
+
|
| 28 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def normalize_zh_answer(s):
|
| 32 |
+
"""Lower text and remove punctuation, extra whitespace."""
|
| 33 |
+
|
| 34 |
+
def white_space_fix(text):
|
| 35 |
+
return "".join(text.split())
|
| 36 |
+
|
| 37 |
+
def remove_punc(text):
|
| 38 |
+
cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
|
| 39 |
+
all_punctuation = set(string.punctuation + cn_punctuation)
|
| 40 |
+
return "".join(ch for ch in text if ch not in all_punctuation)
|
| 41 |
+
|
| 42 |
+
def lower(text):
|
| 43 |
+
return text.lower()
|
| 44 |
+
|
| 45 |
+
return white_space_fix(remove_punc(lower(s)))
|
| 46 |
+
|
| 47 |
+
def count_score(prediction, ground_truth, **kwargs):
|
| 48 |
+
numbers = re.findall(r"\d+", prediction)
|
| 49 |
+
right_num = 0
|
| 50 |
+
for number in numbers:
|
| 51 |
+
if str(number) == str(ground_truth):
|
| 52 |
+
right_num += 1
|
| 53 |
+
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
|
| 54 |
+
return float(final_score)
|
| 55 |
+
|
| 56 |
+
def retrieval_score(prediction, ground_truth, **kwargs):
|
| 57 |
+
pattern = r'Paragraph (\d+)'
|
| 58 |
+
matches = re.findall(pattern, ground_truth)
|
| 59 |
+
ground_truth_id = matches[0]
|
| 60 |
+
numbers = re.findall(r"\d+", prediction)
|
| 61 |
+
right_num = 0
|
| 62 |
+
for number in numbers:
|
| 63 |
+
if str(number) == str(ground_truth_id):
|
| 64 |
+
right_num += 1
|
| 65 |
+
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
|
| 66 |
+
return float(final_score)
|
| 67 |
+
|
| 68 |
+
def retrieval_zh_score(prediction, ground_truth, **kwargs):
|
| 69 |
+
pattern = r'段落(\d+)'
|
| 70 |
+
matches = re.findall(pattern, ground_truth)
|
| 71 |
+
ground_truth_id = matches[0]
|
| 72 |
+
numbers = re.findall(r"\d+", prediction)
|
| 73 |
+
right_num = 0
|
| 74 |
+
for number in numbers:
|
| 75 |
+
if str(number) == str(ground_truth_id):
|
| 76 |
+
right_num += 1
|
| 77 |
+
final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
|
| 78 |
+
return float(final_score)
|
| 79 |
+
|
| 80 |
+
def code_sim_score(prediction, ground_truth, **kwargs):
|
| 81 |
+
all_lines = prediction.lstrip('\n').split('\n')
|
| 82 |
+
prediction = ""
|
| 83 |
+
for line in all_lines:
|
| 84 |
+
if ('`' not in line) and ('#' not in line) and ('//' not in line):
|
| 85 |
+
prediction = line
|
| 86 |
+
break
|
| 87 |
+
return (fuzz.ratio(prediction, ground_truth) / 100)
|
| 88 |
+
|
| 89 |
+
def classification_score(prediction, ground_truth, **kwargs):
|
| 90 |
+
em_match_list = []
|
| 91 |
+
all_classes = kwargs["all_classes"]
|
| 92 |
+
for class_name in all_classes:
|
| 93 |
+
if class_name in prediction:
|
| 94 |
+
em_match_list.append(class_name)
|
| 95 |
+
for match_term in em_match_list:
|
| 96 |
+
if match_term in ground_truth and match_term != ground_truth:
|
| 97 |
+
em_match_list.remove(match_term)
|
| 98 |
+
if ground_truth in em_match_list:
|
| 99 |
+
score = (1.0 / len(em_match_list))
|
| 100 |
+
else:
|
| 101 |
+
score = 0.0
|
| 102 |
+
return score
|
| 103 |
+
|
| 104 |
+
def rouge_score(prediction, ground_truth, **kwargs):
|
| 105 |
+
rouge = Rouge()
|
| 106 |
+
try:
|
| 107 |
+
scores = rouge.get_scores([prediction], [ground_truth], avg=True)
|
| 108 |
+
except:
|
| 109 |
+
return 0.0
|
| 110 |
+
return scores["rouge-l"]["f"]
|
| 111 |
+
|
| 112 |
+
def rouge_zh_score(prediction, ground_truth, **kwargs):
|
| 113 |
+
prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
|
| 114 |
+
ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
|
| 115 |
+
score = rouge_score(prediction, ground_truth)
|
| 116 |
+
return score
|
| 117 |
+
|
| 118 |
+
def f1_score(prediction, ground_truth, **kwargs):
|
| 119 |
+
common = Counter(prediction) & Counter(ground_truth)
|
| 120 |
+
num_same = sum(common.values())
|
| 121 |
+
if num_same == 0:
|
| 122 |
+
return 0
|
| 123 |
+
precision = 1.0 * num_same / len(prediction)
|
| 124 |
+
recall = 1.0 * num_same / len(ground_truth)
|
| 125 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
| 126 |
+
return f1
|
| 127 |
+
|
| 128 |
+
def qa_f1_score(prediction, ground_truth, **kwargs):
|
| 129 |
+
normalized_prediction = normalize_answer(prediction)
|
| 130 |
+
normalized_ground_truth = normalize_answer(ground_truth)
|
| 131 |
+
|
| 132 |
+
prediction_tokens = normalized_prediction.split()
|
| 133 |
+
ground_truth_tokens = normalized_ground_truth.split()
|
| 134 |
+
return f1_score(prediction_tokens, ground_truth_tokens)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def qa_f1_zh_score(prediction, ground_truth, **kwargs):
|
| 138 |
+
prediction_tokens = list(jieba.cut(prediction, cut_all=False))
|
| 139 |
+
ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
|
| 140 |
+
prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
|
| 141 |
+
ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
|
| 142 |
+
prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
|
| 143 |
+
ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
|
| 144 |
+
return f1_score(prediction_tokens, ground_truth_tokens)
|
| 145 |
+
|
| 146 |
+
def qa_em_score(prediction, ground_truth, **kwargs):
|
| 147 |
+
normalized_prediction = normalize_answer(prediction)
|
| 148 |
+
normalized_ground_truth = normalize_answer(ground_truth)
|
| 149 |
+
return 1 if (normalized_prediction in normalized_ground_truth or normalized_ground_truth in normalized_prediction) else 0
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
utils/util.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import nltk
|
| 3 |
+
nltk.download('punkt_tab')
|
| 4 |
+
from nltk.tokenize import sent_tokenize, word_tokenize
|
| 5 |
+
from rank_bm25 import BM25Okapi
|
| 6 |
+
from langchain_text_splitters import NLTKTextSplitter
|
| 7 |
+
from langchain_community.vectorstores import FAISS
|
| 8 |
+
from langchain_openai import OpenAIEmbeddings
|
| 9 |
+
from collections import Counter
|
| 10 |
+
|
| 11 |
+
def replace_case_insensitive(text: str, old: str, new: str) -> str:
|
| 12 |
+
pattern = re.compile(re.escape(old), re.IGNORECASE)
|
| 13 |
+
|
| 14 |
+
return pattern.sub(new, text)
|
| 15 |
+
def get_word_list(s1):
|
| 16 |
+
# Separate sentences by word, Chinese by word, English by word, numbers by space
|
| 17 |
+
regEx = re.compile('[\W]')
|
| 18 |
+
res = re.compile(r"([\u4e00-\u9fa5])") # [\u4e00-\u9fa5] for Chinese
|
| 19 |
+
|
| 20 |
+
p1 = regEx.split(s1.lower())
|
| 21 |
+
str1_list = []
|
| 22 |
+
for str in p1:
|
| 23 |
+
if res.split(str) == None:
|
| 24 |
+
str1_list.append(str)
|
| 25 |
+
else:
|
| 26 |
+
ret = res.split(str)
|
| 27 |
+
for ch in ret:
|
| 28 |
+
str1_list.append(ch)
|
| 29 |
+
|
| 30 |
+
list_word1 = [w for w in str1_list if len(w.strip()) > 0]
|
| 31 |
+
|
| 32 |
+
return list_word1
|
| 33 |
+
def get_word_len(s1):
|
| 34 |
+
return len(get_word_list(s1))
|
| 35 |
+
|
| 36 |
+
regex = r'([。?!;\n.!?;]\s*)'
|
| 37 |
+
def retriveDoc(text,query,top_k=3):
|
| 38 |
+
import os
|
| 39 |
+
sentences = sent_tokenize(text)
|
| 40 |
+
embeddings = OpenAIEmbeddings(model="text-embedding-3-small", base_url=os.environ.get("OPENAI_BASE_URL"),
|
| 41 |
+
api_key=os.environ.get("OPENAI_API_KEY"))
|
| 42 |
+
# Create vector database through FAISS (built from sentence list)
|
| 43 |
+
vector_store = FAISS.from_texts(sentences, embeddings)
|
| 44 |
+
|
| 45 |
+
retrieved_docs = vector_store.similarity_search(query, k=top_k)
|
| 46 |
+
print("Retrieved sentences:", retrieved_docs)
|
| 47 |
+
|
| 48 |
+
# Return results, can adjust the return structure as needed, here returns a dictionary containing context
|
| 49 |
+
return retrieved_docs
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def most_similar_sentence_bm25(paragraph, target_sentence):
|
| 53 |
+
"""
|
| 54 |
+
Use BM25 algorithm to find the most similar sentence to target_sentence in the given paragraph,
|
| 55 |
+
return (most similar sentence, score).
|
| 56 |
+
"""
|
| 57 |
+
# 1. First split the paragraph into a list of sentences
|
| 58 |
+
sentences = sent_tokenize(paragraph)
|
| 59 |
+
|
| 60 |
+
# 2. Tokenize each sentence
|
| 61 |
+
tokenized_sentences = [word_tokenize(sent) for sent in sentences]
|
| 62 |
+
|
| 63 |
+
# 3. Create a retrieval instance using BM25Okapi
|
| 64 |
+
bm25 = BM25Okapi(tokenized_sentences)
|
| 65 |
+
|
| 66 |
+
# 4. Tokenize the target sentence
|
| 67 |
+
target_tokens = word_tokenize(target_sentence)
|
| 68 |
+
|
| 69 |
+
# 5. Use BM25 to calculate similarity scores for each sentence
|
| 70 |
+
scores = bm25.get_scores(target_tokens)
|
| 71 |
+
# scores.shape == (len(sentences),)
|
| 72 |
+
|
| 73 |
+
# 6. Find the index of the sentence with the highest score
|
| 74 |
+
max_idx = scores.argmax()
|
| 75 |
+
|
| 76 |
+
# Return the most similar sentence and its score
|
| 77 |
+
return sentences[max_idx]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def f1_score_text(pred, gold):
|
| 81 |
+
pred_tokens = word_tokenize(pred)
|
| 82 |
+
gold_tokens = word_tokenize(gold)
|
| 83 |
+
common = Counter(pred_tokens) & Counter(gold_tokens)
|
| 84 |
+
num_same = sum(common.values())
|
| 85 |
+
if num_same == 0:
|
| 86 |
+
return 0.0
|
| 87 |
+
precision = num_same / len(pred_tokens)
|
| 88 |
+
recall = num_same / len(gold_tokens)
|
| 89 |
+
f1 = 2 * precision * recall / (precision + recall)
|
| 90 |
+
return f1
|
| 91 |
+
|
| 92 |
+
def compute_best_sentence_f1(pred_text, gold_text):
|
| 93 |
+
pred_sentences = sent_tokenize(pred_text)
|
| 94 |
+
gold_sentences = sent_tokenize(gold_text)
|
| 95 |
+
f1_scores = []
|
| 96 |
+
for pred in pred_sentences:
|
| 97 |
+
best_f1 = 0.0
|
| 98 |
+
for gold in gold_sentences:
|
| 99 |
+
f1 = f1_score_text(pred, gold)
|
| 100 |
+
if f1 > best_f1:
|
| 101 |
+
best_f1 = f1
|
| 102 |
+
f1_scores.append(best_f1)
|
| 103 |
+
avg_f1 = sum(f1_scores) / len(pred_sentences) if pred_sentences else 0.0
|
| 104 |
+
return avg_f1
|