kixx committed
Commit b1e25b1 · verified · 1 Parent(s): c0bf946

Upload 34 files
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/overview.png filter=lfs diff=lfs merge=lfs -text
+ data/hotpotqa_antifact.jsonl filter=lfs diff=lfs merge=lfs -text
+ data/hotpotqa_random.jsonl filter=lfs diff=lfs merge=lfs -text
+ data/hotpotqa.jsonl filter=lfs diff=lfs merge=lfs -text
+ data/musique_antifact.jsonl filter=lfs diff=lfs merge=lfs -text
+ data/musique.jsonl filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,150 @@
- ---
- license: cc-by-4.0
- ---
# LastingBench: Defend Benchmarks Against Knowledge Leakage

Welcome to the repository for the research paper "LastingBench: Defend Benchmarks Against Knowledge Leakage." This project addresses a growing concern: large language models (LLMs) can "cheat" on standard question-answering (QA) benchmarks by memorizing task-specific data. When that happens, benchmark scores no longer reflect genuine model capability but rather the effects of data leakage.

## Project Overview

![Overview](./assets/overview.png)

LastingBench introduces a novel framework designed to continuously reinforce and safeguard existing benchmarks against knowledge leakage. The project aims to:
- **Detect knowledge leakage** through context and question perturbation techniques
- **Rewrite leaked content** into counterfactual alternatives that disrupt memorization while preserving the benchmark's original evaluative intent
- **Evaluate model responses** to contextual evidence and reasoning patterns
- **Provide practical solutions** that keep benchmarks robust over time, promoting fairer and more interpretable evaluations of LLMs
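
To make the rewrite step concrete, here is an illustrative before/after pair in the LongBench-style schema the scripts use. This example is entirely hypothetical and not taken from the released data files:

```python
# Hypothetical example of a counterfactual rewrite; the field names
# (input / context / answers) follow the schema used by the scripts.
original = {
    "input": "Were Scott Derrickson and Ed Wood of the same nationality?",
    "context": "... Scott Derrickson is an American filmmaker ...",
    "answers": ["yes"],
}
rewritten = {
    "input": "Were Scott Derrickson and Ed Wood of the same nationality?",
    "context": "... Scott Derrickson is a Canadian filmmaker ...",  # counterfactual evidence
    "answers": ["no"],  # gold answer updated to match the new evidence
}
```

A model that has memorized the original benchmark keeps answering "yes" on the rewritten item, while a model that genuinely reads the context answers "no".
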
## Installation

1. Clone the repository:
```bash
git clone https://github.com/Seriousss/lastingbench
```

2. Create and activate a conda environment:
```bash
conda create -n lastingbench python=3.12
conda activate lastingbench
```

3. Install dependencies:
```bash
pip install -r requirements.txt
```

4. Set up environment variables:
```bash
export OPENAI_BASE_URL="your-api-base-url"
export OPENAI_API_KEY="your-api-key"
export CUDA_VISIBLE_DEVICES="0,1,2,3"  # Adjust based on your GPU setup
```

## Usage

LastingBench provides three main functionalities: **Detection**, **Rewrite**, and **Training Comparison**.

### 🔍 Detection

Detect knowledge leakage through various perturbation techniques.

#### 1. Context Leakage Detection
Evaluate models using exact-match scoring on benchmark datasets:
```bash
# Using vLLM for most models
python -m detect.contextleakage --hf_model "Qwen/Qwen2.5-7B-Instruct" \
    --dataset_subset "hotpotqa" --cuda_devices "0,1"

# Using Transformers for Qwen3 models
python -m detect.contextleakage --hf_model "Qwen/Qwen3-8B" \
    --is_qwen3 --max_new_tokens 30

# Using an OpenAI-compatible API
python -m detect.contextleakage_api --model "deepseek-r1" --dataset_subset "hotpotqa"
```
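
The context-leakage scripts deliberately ask each question *without* its supporting passages, so a model can only score well by recalling memorized benchmark content. Hits are counted with `qa_em_score` from `utils.metrics`. As a rough illustration, here is a minimal sketch of SQuAD-style exact match; it assumes the repository's metric applies the usual answer normalization, which may differ in detail:

```python
import re
import string

def normalize(text: str) -> str:
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def em_score(prediction: str, gold: str) -> float:
    """Return 1.0 on an exact normalized match, else 0.0."""
    return float(normalize(prediction) == normalize(gold))

print(em_score("The Beatles.", "the beatles"))  # 1.0
```
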

#### 2. Question Perturbation Detection
Rephrase questions to opposite meanings and test model consistency:
```bash
# Using OpenAI API
python -m detect.question_rephrase_answer_api \
    --model_name "gpt-4o" --dataset_subset "2wikimqa" \
    --rephrase_type "opposite" --sample_count 100

# Using local vLLM models
python -m detect.question_rephrase_answer_vllm \
    --model_name "Qwen/Qwen2.5-7B-Instruct" --dataset_subset "hotpotqa" --rephrase_type "similar"

# Using Qwen3 with Transformers
python -m detect.question_rephrase_answer_qwen3 \
    --model_name "Qwen/Qwen3-8B" --dataset_subset "2wikimqa"
```
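
All three variants rephrase with GPT-4o using the same instruction. For `--rephrase_type "opposite"`, the prompt (as it appears in the scripts) is:

```
Please rephrase the following question to have the exact opposite meaning.
Question: {question}

Return only the rephrased question with the opposite meaning, without any explanations or other content.
```

A model that still returns the original gold answer to the inverted question is likely reciting memorized benchmark data rather than reading the question.
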

### ✏️ Rewrite

Generate counterfactual answers and rewrite leaked evidence to create robust benchmarks.

#### 1. Evidence Finding and Counterfactual Rewriting Pipeline
Run the complete evidence-finding and rewriting pipeline:
```bash
# Specify a custom output file and dataset
python main_gpu.py --output custom_output.jsonl \
    --dataset_subset "hotpotqa" --start_idx 0 --max_samples 100
```

Convert and merge JSONL files with question-answer mappings:
```bash
# Merge a single mapping file with the original dataset
python utils/convert.py original.jsonl revised.jsonl custom_output.jsonl
```
The original and revised datasets can be found under the **data** folder.
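
All dataset files follow the LongBench layout used throughout the scripts, with `input` (the question), `context` (the passages), and `answers` fields. A quick way to inspect a rewritten sample, sketched with the `jsonlines` package that the evaluation scripts also use:

```python
import jsonlines

# Peek at the first counterfactually rewritten HotpotQA sample
with jsonlines.open("data/hotpotqa_antifact.jsonl") as reader:
    sample = next(iter(reader))

print(sample["input"])           # the question
print(sample["answers"])         # gold answer(s) after rewriting
print(sample["context"][:300])   # the rewritten evidence passages
```
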

#### 2. Random Answer Rewriting
Create random alternatives to disrupt memorization:
```bash
# Specify a custom output file and dataset
python random_alternative_answer.py --output random_hotpot.jsonl \
    --dataset_subset "hotpotqa" --start_idx 0 --max_samples 50
```

### 🚀 Dataset Evaluations on Model Inference and Training

#### 1. Model Inference Evaluation
Run comprehensive evaluations on the original and revised benchmarks:
```bash
# Transformers-based evaluation
python -m eval.evaluation -i data/hotpotqa.jsonl --model "Qwen/Qwen3-8B" -k 40 -t 0.5

# API-based evaluation
python -m eval.eval_with_api --input data/hotpotqa_antifact.jsonl \
    --model "deepseek-r1" --max_tokens 30 --temperature 0.5
```

#### 2. Model Training Evaluation
Compare training dynamics between the original and rewritten datasets.

The training loss data can be found under **training_result**.

To reproduce the figure in our paper:
```bash
python utils/draw.py training_result/training_loss_qwen38.csv training_result/training_loss_antifact_qwen38.csv \
    --title "Original vs Rewritten Training Loss"
```

### 📊 Utility Functions

Additional tools for analysis and metrics:

- **Metrics Calculation**: F1 scores, EM scores, and custom evaluation metrics
- **Document Retrieval**: BM25-based retrieval for evidence analysis (see the sketch below)
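
The rewriting pipeline in `main_gpu.py` builds its evidence retrieval on Haystack's in-memory BM25 components (`InMemoryDocumentStore` plus `InMemoryBM25Retriever`). A minimal sketch of that retrieval pattern follows; the repository wraps it in `utils/util.py` (`retriveDoc`), whose exact signature may differ:

```python
from haystack import Document
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever

passages = [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
]

# Index the context passages of one sample
store = InMemoryDocumentStore()
store.write_documents([Document(content=p) for p in passages])

# Retrieve the passages most relevant to a question
retriever = InMemoryBM25Retriever(document_store=store)
hits = retriever.run(query="What is the capital of France?", top_k=1)["documents"]
for doc in hits:
    print(doc.score, doc.content)
```
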

All scripts support various parameters for customization. Use `--help` with any script to see the available options.
assets/overview.png ADDED

Git LFS Details

  • SHA256: af5065239c895aa1097048009adab15ce0369801cafc5f84b82cd147246b2077
  • Pointer size: 131 Bytes
  • Size of remote file: 782 kB
data/2wikimqa.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
data/2wikimqa_antifact.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
data/hotpotqa.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0005ab2a1bc2ac3a70352dccbf96cccc4e0aac6bb677f6a55180fa51b92ef6f
3
+ size 11483614
data/hotpotqa_antifact.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94d00ea5c5f14ce5c0638bc6c66658e93425d87f377bf1d3da4a004fd15fcb6f
3
+ size 11447185
data/hotpotqa_random.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d2072c7caac2fdc462c0777bbcb38b41292974383db46a2fedb4c3b86f3c825
3
+ size 11461405
data/multifieldqa_en.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
data/multifieldqa_en_antifact.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
data/musique.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ac69b91281c4ec6b21316cb7282e83fb6b4dda04fc68480bb8d8ed1e19ff7bd
3
+ size 14085077
data/musique_antifact.jsonl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e80355c61ec5d4f79e2d3610fa9c25156746079bff21268932ae9cc8d23acdfc
3
+ size 14057004
detect/contextleakage.py ADDED
@@ -0,0 +1,194 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Evaluate models on a LongBench subset with Exact-Match (EM).
5
+ Supports both Qwen3 (Transformers) and other models (vLLM).
6
+
7
+ Requirements
8
+ ------------
9
+ pip install vllm datasets tqdm transformers accelerate
10
+ """
11
+
12
+ import argparse, logging, time, torch
13
+ from pathlib import Path
14
+
15
+ from datasets import load_dataset
16
+ from tqdm import tqdm
17
+ from utils.metrics import qa_em_score
18
+ import os
19
+
20
+ # ---------------------------- CLI ------------------------------------
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument("--hf_model",
23
+ default="Qwen/Qwen3-8B-Instruct",
24
+ help="Model name or local path")
25
+ parser.add_argument("--is_qwen3", action="store_true",
26
+ help="Set this flag if using Qwen3 model (uses Transformers). Otherwise uses vLLM.")
27
+ parser.add_argument("--max_new_tokens", type=int, default=20)
28
+ parser.add_argument("--max_tokens", type=int, default=20,
29
+ help="For vLLM models (ignored if --is_qwen3)")
30
+ parser.add_argument("--temperature", type=float, default=0.0)
31
+ parser.add_argument("--top_p", type=float, default=1.0)
32
+ parser.add_argument("--tensor_parallel_size", type=int, default=2,
33
+ help="GPU parallel size for vLLM (ignored if --is_qwen3)")
34
+
35
+ parser.add_argument("--dataset_repo", default="THUDM/LongBench")
36
+ parser.add_argument("--dataset_subset", default="hotpotqa")
37
+ parser.add_argument("--split", default="test")
38
+ parser.add_argument("--sleep", type=float, default=0.0)
39
+ parser.add_argument("--log", default="summary.log")
40
+ parser.add_argument("--cuda_devices", default="1,6",
41
+ help="CUDA visible devices")
42
+ args = parser.parse_args()
43
+
44
+ # Set CUDA devices
45
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices
46
+
47
+ # --------------------------- logging ---------------------------------
48
+ logging.basicConfig(
49
+ filename=args.log,
50
+ level=logging.INFO,
51
+ format="%(asctime)s - %(message)s",
52
+ filemode="a",
53
+ )
54
+ logging.getLogger().addHandler(logging.StreamHandler())
55
+
56
+ # ------------------------- dataset -----------------------------------
57
+ ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split)
58
+ total = len(ds)
59
+ logging.info("Loaded %d samples from %s/%s[%s]",
60
+ total, args.dataset_repo, args.dataset_subset, args.split)
61
+
62
+ if args.is_qwen3:
63
+ # ---------------------- Qwen3 with Transformers ----------------------------
64
+ from transformers import AutoTokenizer, AutoModelForCausalLM
65
+
66
+ load_kwargs = dict(
67
+ trust_remote_code=True,
68
+ device_map="auto",
69
+ )
70
+
71
+ tokenizer = AutoTokenizer.from_pretrained(args.hf_model,
72
+ trust_remote_code=True)
73
+ model = AutoModelForCausalLM.from_pretrained(
74
+ args.hf_model,
75
+ torch_dtype=torch.float16,
76
+ **load_kwargs
77
+ )
78
+
79
+ EOS_ID = tokenizer.eos_token_id
80
+ THINK_ENDID = 151668 # </think> token id
81
+
82
+ gen_kwargs = dict(
83
+ max_new_tokens=args.max_new_tokens,
84
+ temperature=args.temperature,
85
+ top_p=args.top_p,
86
+ do_sample=args.temperature > 0,
87
+ eos_token_id=EOS_ID,
88
+ )
89
+
90
+ # -------------------------- Qwen3 loop -------------------------------------
91
+ correct_em = 0
92
+
93
+ for ex in tqdm(ds, desc="Evaluating with Transformers (Qwen3)"):
94
+ q = ex["input"]
95
+ golds = ex["answers"]
96
+
97
+ msgs = [
98
+ {"role": "system", "content": "You are a QA assistant."},
99
+ {"role": "user",
100
+ "content": f"Question: {q}\n"
101
+ "Please reply with *only* the final answer—no extra words."}
102
+ ]
103
+ prompt = tokenizer.apply_chat_template(
104
+ msgs,
105
+ tokenize=False,
106
+ add_generation_prompt=True,
107
+ enable_thinking=False # Qwen3 thinking mode
108
+ )
109
+ inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
110
+
111
+ with torch.no_grad():
112
+ outs = model.generate(**inputs, **gen_kwargs)[0]
113
+
114
+ # Extract newly generated tokens
115
+ new_ids = outs[len(inputs.input_ids[0]):].tolist()
116
+
117
+ # Find the last </think>; if it is absent, decode from index 0
118
+ try:
119
+ idx = len(new_ids) - new_ids[::-1].index(THINK_ENDID)
120
+ except ValueError:
121
+ idx = 0
122
+
123
+ content = tokenizer.decode(new_ids[idx:],
124
+ skip_special_tokens=True).strip("\n").strip()
125
+
126
+ # Only use content for EM comparison
127
+ if any(qa_em_score(content, g) for g in golds):
128
+ correct_em += 1
129
+
130
+ if args.sleep:
131
+ time.sleep(args.sleep)
132
+
133
+ else:
134
+ # ---------------------- Other models with vLLM ----------------------------
135
+ from vllm import LLM, SamplingParams
136
+
137
+ # Initialize vLLM
138
+ llm = LLM(
139
+ model=args.hf_model,
140
+ tensor_parallel_size=args.tensor_parallel_size,
141
+ )
142
+ sampler = SamplingParams(
143
+ temperature=args.temperature,
144
+ max_tokens=args.max_tokens,
145
+ top_p=args.top_p,
146
+ stop=["</assistant>", "</s>", "<|end_of_text|>"],
147
+ )
148
+
149
+ # -------------------------- vLLM loop -------------------------------------
150
+ correct_em = 0
151
+
152
+ for ex in tqdm(ds, desc="Evaluating with vLLM"):
153
+ question = ex["input"]
154
+ golds = ex["answers"] # list[str]
155
+
156
+ chat_params = SamplingParams(
157
+ temperature=args.temperature,
158
+ max_tokens=args.max_tokens,
159
+ top_p=args.top_p,
160
+ stop=["</s>", "<|end_of_text|>"], # Safety stop tokens
161
+ )
162
+
163
+ messages = [
164
+ {"role": "system",
165
+ "content": "You are a QA assistant."},
166
+ {"role": "user",
167
+ "content": f"Question: {question}\n"
168
+ "Please first reply with *only* the final answer—no extra words.\n Answer:"}
169
+ ]
170
+
171
+ result = llm.chat(messages, sampling_params=chat_params)
172
+ # vLLM returns list[RequestOutput]; take first output's first candidate
173
+ pred = result[0].outputs[0].text.strip()
174
+ print(f"A: {pred}\nG: {golds}\n")
175
+
176
+ if any(qa_em_score(pred, g) for g in golds):
177
+ correct_em += 1
178
+
179
+ if args.sleep:
180
+ time.sleep(args.sleep)
181
+
182
+ # -------------------------- result -----------------------------------
183
+ em = correct_em / total
184
+ model_type = "Qwen3 (Transformers)" if args.is_qwen3 else "vLLM"
185
+ logging.info("RESULT | model=%s | type=%s | subset=%s | EM=%.4f",
186
+ args.hf_model, model_type, args.dataset_subset, em)
187
+ print(
188
+ f"\n=== SUMMARY ===\n"
189
+ f"Model : {args.hf_model}\n"
190
+ f"Type : {model_type}\n"
191
+ f"Subset : {args.dataset_subset} ({args.split})\n"
192
+ f"EM : {em:.4f}\n"
193
+ f"(Log in {Path(args.log).resolve()})"
194
+ )
detect/contextleakage_api.py ADDED
@@ -0,0 +1,87 @@
1
+ import os, time, argparse, logging
2
+ from datasets import load_dataset
3
+ from openai import OpenAI
4
+ from tqdm import tqdm
5
+ from utils.metrics import qa_em_score
6
+
7
+ # ----------------------------------------------------------------------
8
+ # CLI
9
+ # ----------------------------------------------------------------------
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--model", default="gpt-4o")
12
+ parser.add_argument("--dataset_repo", default="THUDM/LongBench")
13
+ parser.add_argument("--dataset_subset", default="hotpotqa")
14
+ parser.add_argument("--split", default="test")
15
+ parser.add_argument("--max_tokens", type=int, default=30)
16
+ parser.add_argument("--temperature", type=float, default=0.0)
17
+ parser.add_argument("--sleep", type=float, default=0.5,
18
+ help="seconds to wait between requests")
19
+ parser.add_argument("--log", default="summary.log",
20
+ help="append overall score here")
21
+ args = parser.parse_args()
22
+
23
+ # ----------------------------------------------------------------------
24
+ # Logging (append mode)
25
+ # ----------------------------------------------------------------------
26
+ logging.basicConfig(
27
+ filename=args.log,
28
+ level=logging.INFO,
29
+ format="%(asctime)s - %(message)s",
30
+ filemode="a",
31
+ )
32
+ console = logging.StreamHandler()
33
+ console.setLevel(logging.INFO)
34
+ logging.getLogger().addHandler(console)
35
+
36
+ # ----------------------------------------------------------------------
37
+ # OpenAI client
38
+ # ----------------------------------------------------------------------
39
+ client = OpenAI(
40
+ api_key=os.environ.get("OPENAI_API_KEY"),
41
+ base_url=os.environ.get("OPENAI_BASE_URL")
42
+ )
43
+
44
+ # ----------------------------------------------------------------------
45
+ # Load dataset
46
+ # ----------------------------------------------------------------------
47
+ ds = load_dataset(args.dataset_repo, args.dataset_subset, split=args.split)
48
+ total = len(ds)
49
+ logging.info("Loaded %d samples from %s/%s[%s]",
50
+ total, args.dataset_repo, args.dataset_subset, args.split)
51
+
52
+ # ----------------------------------------------------------------------
53
+ # Evaluation loop
54
+ # ----------------------------------------------------------------------
55
+ correct_em = 0
56
+
57
+ for ex in tqdm(ds, desc="Evaluating"):
58
+ question = ex["input"]
59
+ golds = ex["answers"]
60
+
61
+ resp = client.chat.completions.create(
62
+ model=args.model,
63
+ messages=[
64
+ {"role": "system", "content": "You are a QA assistant."},
65
+ {"role": "user",
66
+ "content": f"Question: {question}\n"
67
+ "Please first reply with *only* the final answer—no extra words.\n Answer:"}
68
+ ],
69
+ temperature=args.temperature,
70
+ max_tokens=args.max_tokens,
71
+ )
72
+ pred = resp.choices[0].message.content.strip()
73
+ print(f"A: {pred}\n G: {golds}")
74
+
75
+ if any(qa_em_score(pred, g) for g in golds):
76
+ correct_em += 1
77
+
78
+ time.sleep(args.sleep)
79
+
80
+ em_score = correct_em / total
81
+ logging.info("RESULT | model=%s | subset=%s | EM=%.4f",
82
+ args.model, args.dataset_subset, em_score)
83
+
84
+ print(f"\n=== SUMMARY ===\nModel : {args.model}"
85
+ f"\nDataset : {args.dataset_subset} ({args.split})"
86
+ f"\nEM : {em_score:.4f}\n"
87
+ f"(Appended to {args.log})")
detect/question_rephrase_answer_api.py ADDED
@@ -0,0 +1,153 @@
1
+ import json
2
+ import time
3
+ import os
4
+ import argparse
5
+ from datasets import load_dataset
6
+ from openai import OpenAI
7
+ from tqdm import tqdm
8
+ from utils.metrics import qa_f1_score, qa_em_score # Import evaluation functions
9
+
10
+ # Configure OpenAI API
11
+ client = OpenAI(
12
+ api_key=os.environ.get("OPENAI_API_KEY"),
13
+ base_url=os.environ.get("OPENAI_BASE_URL")
14
+ )
15
+
16
+ def get_openai_response(prompt, model="gpt-4o", retries=3, delay=2):
17
+ """Call OpenAI API to get response with retry mechanism"""
18
+ for attempt in range(retries):
19
+ try:
20
+ completion = client.chat.completions.create(
21
+ model=model,
22
+ messages=[{'role': 'user', 'content': prompt}],
23
+ max_tokens=100
24
+ )
25
+ return completion.choices[0].message.content.strip()
26
+ except Exception as e:
27
+ print(f"Attempt {attempt + 1} failed: {e}")
28
+ if attempt < retries - 1:
29
+ print(f"Retrying in {delay} seconds...")
30
+ time.sleep(delay)
31
+ else:
32
+ print("Max retries reached. Skipping this request.")
33
+ return "Failed to get response"
34
+
35
+ def rephrase_question_api(question, model_name, rephrase_type="opposite"):
36
+ """Use OpenAI API to rephrase question (English prompt)"""
37
+ if rephrase_type == "opposite":
38
+ prompt = f"""Please rephrase the following question to have the exact opposite meaning.
39
+ Question: {question}
40
+
41
+ Return only the rephrased question with the opposite meaning, without any explanations or other content."""
42
+ elif rephrase_type == "similar":
43
+ prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
44
+ Question: {question}
45
+
46
+ Return only the rephrased question, without any explanations or other content."""
47
+ else:
48
+ raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")
49
+
50
+ return get_openai_response(prompt, model=model_name)
51
+
52
+ def answer_question_with_context_api(question, context, model_name, max_tokens_for_answer=30):
53
+ """Use OpenAI API to answer question based on context (English prompt)"""
54
+ prompt = f"""Please answer the question based on the following context:
55
+
56
+ Context:
57
+ {context}
58
+
59
+ Question: {question}
60
+
61
+ Only output the answer, no any other text. If the answer is not in the context, please say "I don't know".
62
+
63
+ Answer:"""
64
+ try:
65
+ completion = client.chat.completions.create(
66
+ model=model_name,
67
+ messages=[{'role': 'user', 'content': prompt}],
68
+ max_tokens=max_tokens_for_answer
69
+ )
70
+ return completion.choices[0].message.content.strip()
71
+ except Exception as e:
72
+ print(f"Answer generation failed for model {model_name}: {e}")
73
+ return "Failed to get answer"
74
+
75
+ def main(args):
76
+ # Load dataset
77
+ print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
78
+ try:
79
+ dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
80
+ print(f"Successfully loaded dataset with {len(dataset)} samples.")
81
+ except Exception as e:
82
+ print(f"Failed to load dataset: {e}")
83
+ return
84
+
85
+ em_match_count = 0 # Counter for EM matches
86
+ successfully_processed_samples = 0 # Counter for successfully processed samples
87
+
88
+ num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))
89
+
90
+ print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o (opposite meaning). Answering with {args.model_name} (max 30 tokens for answer)...")
91
+
92
+ for i in tqdm(range(num_samples_to_process), desc="Processing samples"):
93
+ example = dataset[i]
94
+ original_question = example['input']
95
+ context = example['context']
96
+ ground_truth_answers = example['answers']
97
+
98
+ print(f"Original question: {original_question}")
99
+
100
+ # Use API to rephrase question, fixed using gpt-4o
101
+ rephrased_question = rephrase_question_api(original_question, "gpt-4o", args.rephrase_type)
102
+ print(f"Rephrased question (opposite): {rephrased_question}")
103
+
104
+ if rephrased_question == "Failed to get response" or rephrased_question == "Failed to rephrase question": # Broader check
105
+ print(f"Skipping sample {i+1} due to rephrasing failure.")
106
+ continue
107
+
108
+ # Use rephrased question and context to get answer, using args.model_name, answer length limited to 30 tokens
109
+ rephrased_answer = answer_question_with_context_api(rephrased_question, context, args.model_name, max_tokens_for_answer=30)
110
+ # print(f"Answer to rephrased question: {rephrased_answer}")
111
+
112
+ if rephrased_answer == "Failed to get answer":
113
+ print(f"Skipping sample {i+1} due to answer generation failure.")
114
+ continue
115
+
116
+ if not ground_truth_answers:
117
+ print(f"Skipping sample {i+1} due to missing ground truth answers.")
118
+ continue
119
+
120
+ successfully_processed_samples += 1
121
+ sample_had_em_match = False
122
+ for gt_ans in ground_truth_answers:
123
+ em = qa_em_score(rephrased_answer, gt_ans)
124
+ if em > 0: # EM is 1.0 for a match
125
+ sample_had_em_match = True
126
+ break
127
+
128
+ if sample_had_em_match:
129
+ em_match_count += 1
130
+ # print(f"Sample EM with original GT: {1 if sample_had_em_match else 0}")
131
+
132
+ if successfully_processed_samples > 0:
133
+ print(f"\n--- Evaluation Summary ---")
134
+ print(f"Answering Model : {args.model_name}")
135
+ print(f"Dataset : {args.dataset_name} ({args.dataset_subset})")
136
+ print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
137
+ print(f"Max Answer Tokens: 30")
138
+ print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
139
+ else:
140
+ print("\nNo samples were processed adequately to provide an evaluation summary.")
141
+
142
+ print("Processing complete!")
143
+
144
+ if __name__ == "__main__":
145
+ parser = argparse.ArgumentParser(description="Rephrase questions to opposite meaning with GPT-4o, answer with specified OpenAI model, then count EM against original GT.")
146
+ parser.add_argument("--model_name", type=str, default="gpt-4o", help="Name of the OpenAI model to use for Answering.")
147
+ parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
148
+ parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
149
+ parser.add_argument("--sample_count", type=int, default=-1, help="Number of samples to process. -1 for all samples.")
150
+ parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")
151
+
152
+ args = parser.parse_args()
153
+ main(args)
detect/question_rephrase_answer_qwen3.py ADDED
@@ -0,0 +1,235 @@
1
+ import time
2
+ import os
3
+ import argparse
4
+ import torch
5
+ from datasets import load_dataset
6
+ from tqdm import tqdm
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ from openai import OpenAI # Added for GPT-4o rephrasing
9
+ from utils.metrics import qa_f1_score, qa_em_score
10
+
11
+
12
+ THINK_END_ID = 151668 # </think> token ID for Qwen3 models
13
+
14
+ # --- OpenAI Client for Rephrasing ---
15
+ openai_client = OpenAI(
16
+ api_key=os.environ.get("OPENAI_API_KEY"),
17
+ base_url=os.environ.get("OPENAI_BASE_URL")
18
+ )
19
+
20
+ def get_openai_rephrase_response(prompt, model="gpt-4o", retries=3, delay=2):
21
+ """Call OpenAI API for rephrasing."""
22
+ for attempt in range(retries):
23
+ try:
24
+ completion = openai_client.chat.completions.create(
25
+ model=model,
26
+ messages=[{'role': 'user', 'content': prompt}],
27
+ max_tokens=100
28
+ )
29
+ return completion.choices[0].message.content.strip()
30
+ except Exception as e:
31
+ print(f"OpenAI Rephrase attempt {attempt + 1} failed: {e}")
32
+ if attempt < retries - 1:
33
+ print(f"Retrying OpenAI rephrase in {delay} seconds...")
34
+ time.sleep(delay)
35
+ else:
36
+ print("Max retries for OpenAI rephrase reached.")
37
+ return "Failed to rephrase question"
38
+
39
+ def rephrase_question_with_gpt4o(question, rephrase_type="opposite"):
40
+ if rephrase_type == "opposite":
41
+ prompt = f"""Please rephrase the following question to have the exact opposite meaning.
42
+ Question: {question}
43
+
44
+ Return only the rephrased question with the opposite meaning, without any explanations or other content."""
45
+ elif rephrase_type == "similar":
46
+ prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
47
+ Question: {question}
48
+
49
+ Return only the rephrased question, without any explanations or other content."""
50
+ else:
51
+ raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")
52
+
53
+ return get_openai_rephrase_response(prompt)
54
+
55
+ # --- Qwen3-Specific Hugging Face Model Functions (for Answering) ---
56
+ def get_qwen3_hf_response(prompt_text, model, tokenizer, device, max_new_tokens=40, retries=2, delay=5):
57
+ """Generate a response from a Qwen3-like HF model. max_new_tokens default to 30."""
58
+ for attempt in range(retries):
59
+ try:
60
+ messages = [{"role": "user", 'content': prompt_text}]
61
+
62
+ chat_template_args = {
63
+ "tokenize": False,
64
+ "add_generation_prompt": True
65
+ }
66
+ # Qwen models (like Qwen1.5, Qwen2) often use/support enable_thinking
67
+ # Check if tokenizer's apply_chat_template supports 'enable_thinking'
68
+ # This check is simplified; for robust production, inspect.signature might be better
69
+ # but for Qwen-specific, we assume it or it gracefully ignores.
70
+ try:
71
+ # Attempt to use enable_thinking=False for Qwen models
72
+ processed_prompt = tokenizer.apply_chat_template(
73
+ messages, **chat_template_args, enable_thinking=False
74
+ )
75
+ except TypeError:
76
+ # Fallback if enable_thinking is not a valid kwarg for the specific tokenizer version
77
+ print("Warning: Tokenizer does not support 'enable_thinking' in apply_chat_template. Proceeding without it.")
78
+ processed_prompt = tokenizer.apply_chat_template(messages, **chat_template_args)
79
+ except Exception as e:
80
+ print(f"Warning: Error applying chat template: {e}. Using raw prompt.")
81
+ processed_prompt = prompt_text # Fallback to raw prompt
82
+
83
+ inputs = tokenizer(processed_prompt, return_tensors="pt", padding=True, truncation=True).to(device)
84
+
85
+ generated_ids_full = model.generate(
86
+ inputs.input_ids,
87
+ attention_mask=inputs.attention_mask,
88
+ max_new_tokens=max_new_tokens,
89
+ pad_token_id=tokenizer.eos_token_id
90
+ )
91
+ # Get only newly generated tokens
92
+ output_only_ids_list = generated_ids_full[0][inputs.input_ids.shape[1]:].tolist()
93
+
94
+ # Strip <think>...</think> tags specifically for Qwen
95
+ try:
96
+ # Find the last occurrence of THINK_END_ID and take tokens after it
97
+ cut_index = len(output_only_ids_list) - output_only_ids_list[::-1].index(THINK_END_ID)
98
+ final_ids_to_decode = output_only_ids_list[cut_index:]
99
+ except ValueError:
100
+ # THINK_END_ID not found, use all generated new tokens
101
+ final_ids_to_decode = output_only_ids_list
102
+
103
+ response = tokenizer.decode(final_ids_to_decode, skip_special_tokens=True).strip()
104
+ return response
105
+ except Exception as e:
106
+ print(f"Qwen HF Model generation attempt {attempt + 1} failed: {e}")
107
+ if attempt < retries - 1:
108
+ print(f"Retrying in {delay} seconds...")
109
+ time.sleep(delay)
110
+ else:
111
+ print("Max retries for Qwen HF model reached. Skipping this request.")
112
+ return "Failed to get Qwen HF response"
113
+
114
+ def answer_question_with_context_qwen3_hf(question, context, model, tokenizer, device):
115
+ """Answer a question with context using a Qwen3-like HF model."""
116
+ prompt = f"""Please answer the question based on the following context:
117
+
118
+ Context:
119
+ {context}
120
+
121
+ Question: {question}
122
+
123
+ Only output the answer, no any other text. If the answer is not in the context, please say "I don't know".
124
+
125
+ Answer:"""
126
+ return get_qwen3_hf_response(prompt, model, tokenizer, device)
127
+
128
+ def main(args):
129
+ hf_device_setting = "auto"
130
+ print(f"Attempting to use device: {hf_device_setting} for Qwen HF model.")
131
+
132
+ print(f"Loading Qwen HF model for Answering: {args.model_name}...")
133
+ hf_model = None
134
+ hf_tokenizer = None
135
+ try:
136
+ hf_tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code_hf)
137
+ hf_model = AutoModelForCausalLM.from_pretrained(
138
+ args.model_name,
139
+ device_map=hf_device_setting,
140
+ trust_remote_code=args.trust_remote_code_hf,
141
+ torch_dtype="bfloat16"
142
+ )
143
+ hf_model.eval()
144
+
145
+
146
+
147
+ print(f"Successfully loaded Qwen HF model {args.model_name}.")
148
+ except Exception as e:
149
+ print(f"Failed to load Qwen HF model {args.model_name}: {e}")
150
+ return
151
+
152
+ print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
153
+ try:
154
+ dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
155
+ print(f"Successfully loaded dataset with {len(dataset)} samples.")
156
+ except Exception as e:
157
+ print(f"Failed to load dataset: {e}")
158
+ return
159
+
160
+ em_match_count = 0 # Counter for EM matches
161
+ em_match_original_count = 0 # Counter for EM matches
162
+ successfully_processed_samples = 0 # Counter for successfully processed samples
163
+
164
+ num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))
165
+
166
+ print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o (opposite meaning). Answering with Qwen HF model {args.model_name} (max 30 tokens)...")
167
+
168
+ for i in tqdm(range(num_samples_to_process), desc="Processing samples"):
169
+ example = dataset[i]
170
+ original_question = example['input']
171
+ context = example['context']
172
+ ground_truth_answers = example['answers']
173
+ print(original_question)
174
+
175
+ rephrased_question = rephrase_question_with_gpt4o(original_question, args.rephrase_type)
176
+ print(rephrased_question)
177
+
178
+ if rephrased_question == "Failed to rephrase question":
179
+ print(f"Skipping sample {i+1} due to rephrasing failure.")
180
+ continue
181
+
182
+ rephrased_answer = answer_question_with_context_qwen3_hf(rephrased_question, context, hf_model, hf_tokenizer, hf_model.device)
183
+ print(rephrased_answer)
184
+ original_answer = answer_question_with_context_qwen3_hf(original_question, context, hf_model, hf_tokenizer, hf_model.device)
185
+ if not ground_truth_answers:
186
+ print(f"Skipping sample {i+1} due to missing ground truth answers.")
187
+ continue
188
+ print(original_answer)
189
+ successfully_processed_samples += 1
190
+
191
+ sample_had_em_match = False
192
+ for gt_ans in ground_truth_answers:
193
+ em = qa_em_score(rephrased_answer, gt_ans)
194
+ if em > 0: # Check for exact match (assuming qa_em_score returns 1.0 for EM)
195
+ sample_had_em_match = True
196
+ break
197
+
198
+ if sample_had_em_match:
199
+ em_match_count += 1
200
+
201
+ sample_had_em_match = False
202
+ for gt_ans in ground_truth_answers:
203
+ em = qa_em_score(original_answer, gt_ans)
204
+ if em > 0: # Check for exact match (assuming qa_em_score returns 1.0 for EM)
205
+ sample_had_em_match = True
206
+ break
207
+ if sample_had_em_match:
208
+ em_match_original_count += 1
209
+
210
+ if successfully_processed_samples > 0:
211
+ print(f"\n--- Evaluation Summary ---")
212
+ print(f"Answering Qwen HF Model: {args.model_name}")
213
+ print(f"Dataset: {args.dataset_name} ({args.dataset_subset})")
214
+ print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
215
+ print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
216
+ print(f"Count of EM with original ground truth (before rephrase): {em_match_original_count}")
217
+ else:
218
+ print("\nNo samples were processed adequately to provide an evaluation summary.")
219
+
220
+ print("Processing complete!")
221
+
222
+ if __name__ == "__main__":
223
+ parser = argparse.ArgumentParser(description="Rephrase with GPT-4o, Answer with local Qwen3-like HF Model, then Evaluate.")
224
+ parser.add_argument("--model_name", type=str, default="Qwen/Qwen1.5-7B-Chat", help="Name of the Qwen3-like Hugging Face model for Answering.")
225
+ parser.add_argument("--trust_remote_code_hf", action="store_true", default=True, help="Set to true if the Hugging Face model requires remote code (default: True for Qwen). Argument is present for explicitness but defaults to True.")
226
+ parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
227
+ parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
228
+ parser.add_argument("--sample_count", type=int, default=5, help="Number of samples to process. -1 for all. Default: 5.")
229
+ parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")
230
+
231
+ args = parser.parse_args()
232
+ if not os.environ.get("OPENAI_API_KEY"):
233
+ print("CRITICAL ERROR: Please replace 'your_api_key_here' with your actual OpenAI API key in the script.")
234
+ else:
235
+ main(args)
detect/question_rephrase_answer_vllm.py ADDED
@@ -0,0 +1,226 @@
1
+ import time
2
+ import os
3
+ import argparse
4
+ # import torch # torch might not be directly needed if vLLM handles all device aspects
5
+ from datasets import load_dataset
6
+ from tqdm import tqdm
7
+ from openai import OpenAI # For GPT-4o rephrasing
8
+ from vllm import LLM, SamplingParams # For vLLM inference
9
+ from transformers import AutoTokenizer # Import AutoTokenizer
10
+ from utils.metrics import qa_f1_score, qa_em_score
11
+
12
+ # This will be respected by vLLM if CUDA_VISIBLE_DEVICES is set before vLLM import
13
+ # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2" # User can set this outside the script
14
+
15
+ # --- OpenAI Client for Rephrasing ---
16
+ openai_client = OpenAI(
17
+ api_key=os.environ.get("OPENAI_API_KEY"),
18
+ base_url=os.environ.get("OPENAI_BASE_URL")
19
+ )
20
+
21
+ def get_openai_rephrase_response(prompt, model="gpt-4o", retries=3, delay=2):
22
+ """Call OpenAI API for rephrasing."""
23
+ for attempt in range(retries):
24
+ try:
25
+ completion = openai_client.chat.completions.create(
26
+ model=model,
27
+ messages=[{'role': 'user', 'content': prompt}],
28
+ max_tokens=100 # Max tokens for rephrased question
29
+ )
30
+ return completion.choices[0].message.content.strip()
31
+ except Exception as e:
32
+ print(f"OpenAI Rephrase attempt {attempt + 1} failed: {e}")
33
+ if attempt < retries - 1:
34
+ print(f"Retrying OpenAI rephrase in {delay} seconds...")
35
+ time.sleep(delay)
36
+ else:
37
+ print("Max retries for OpenAI rephrase reached.")
38
+ return "Failed to rephrase question"
39
+
40
+ def rephrase_question_with_gpt4o(question, rephrase_type="opposite"):
41
+ """Rephrase a question using GPT-4o (English prompt)."""
42
+ if rephrase_type == "opposite":
43
+ prompt = f"""Please rephrase the following question to have the exact opposite meaning.
44
+ Question: {question}
45
+
46
+ Return only the rephrased question with the opposite meaning, without any explanations or other content."""
47
+ elif rephrase_type == "similar":
48
+ prompt = f"""Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:
49
+ Question: {question}
50
+
51
+ Return only the rephrased question, without any explanations or other content."""
52
+ else:
53
+ raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")
54
+
55
+ return get_openai_rephrase_response(prompt)
56
+
57
+ # --- vLLM Model Functions (for Answering) ---
58
+ def get_vllm_response(prompt_text, llm_instance, sampling_params_instance, retries=2, delay=5):
59
+ """Generate a response from a vLLM instance."""
60
+ for attempt in range(retries):
61
+ try:
62
+ # vLLM generate method expects a list of prompts
63
+ outputs = llm_instance.generate([prompt_text], sampling_params_instance)
64
+ # For a single prompt, the result is in the first element of the output list
65
+ # Each output object has a list of `outputs` (for n>1 in SamplingParams)
66
+ response = outputs[0].outputs[0].text.strip()
67
+ return response
68
+ except Exception as e:
69
+ print(f"vLLM generation attempt {attempt + 1} failed: {e}")
70
+ if attempt < retries - 1:
71
+ print(f"Retrying vLLM generation in {delay} seconds...")
72
+ time.sleep(delay)
73
+ else:
74
+ print("Max retries for vLLM generation reached.")
75
+ return "Failed to get vLLM response"
76
+
77
+ def answer_question_with_context_vllm(question, context, llm_instance, sampling_params_instance, tokenizer):
78
+ """Answer a question with context using a vLLM model and chat template (English prompt)."""
79
+ # Construct prompt using chat template, similar to evaluation.py
80
+ prompt_content = (
81
+ f"Answer the question based on the given passages. "
82
+ "Only give me your answer and do not output any other words.\\n"
83
+ "The following are given passages:\\n"
84
+ f"{context}\\n"
85
+ "Please strictly follow the context. "
86
+ f"Question: {question}\\n"
87
+ "Answer:"
88
+ )
89
+ messages = [{"role": "user", "content": prompt_content}]
90
+
91
+ # Apply chat template
92
+ # Note: Some tokenizers might not have a chat template configured, or might have different ways to apply it.
93
+ # This is a common way for many models.
94
+ try:
95
+ final_prompt_text = tokenizer.apply_chat_template(
96
+ messages,
97
+ tokenize=False,
98
+ add_generation_prompt=True
99
+ )
100
+ except Exception as e:
101
+ print(f"Failed to apply chat template: {e}. Falling back to basic prompt string.")
102
+ # Fallback to a simpler prompt if template application fails
103
+ final_prompt_text = f"Context:\\n{context}\\n\\nQuestion: {question}\\n\\nAnswer:"
104
+
105
+ return get_vllm_response(final_prompt_text, llm_instance, sampling_params_instance)
106
+
107
+ def main(args):
108
+ # Load Tokenizer for the vLLM model
109
+ print(f"Loading tokenizer for model: {args.model_name}...")
110
+ try:
111
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=args.trust_remote_code)
112
+ print("Successfully loaded tokenizer.")
113
+ except Exception as e:
114
+ print(f"Failed to load tokenizer for {args.model_name}: {e}")
115
+ print("Please ensure the model name is correct and the tokenizer can be loaded.")
116
+ return
117
+
118
+ # Load vLLM Model (for Answering)
119
+ print(f"Loading vLLM model for Answering: {args.model_name}...")
120
+ print(f"(This may take a while depending on the model size and download speed if not cached).")
121
+ vllm_model = None
122
+ try:
123
+ # You can expose more vLLM LLM parameters as args if needed
124
+ # (e.g., tensor_parallel_size, dtype, gpu_memory_utilization)
125
+ vllm_model = LLM(
126
+ model=args.model_name,
127
+ trust_remote_code=args.trust_remote_code,
128
+ dtype="bfloat16", # Use dtype from command line arguments
129
+ # Add other vLLM LLM constructor arguments here if needed, e.g.:
130
+ tensor_parallel_size=args.tensor_parallel_size
131
+ )
132
+ print(f"Successfully loaded vLLM model {args.model_name} with dtype='{args.dtype}' and tensor_parallel_size={args.tensor_parallel_size}.")
133
+ except Exception as e:
134
+ print(f"Failed to load vLLM model {args.model_name}: {e}")
135
+ print("Please ensure vLLM is installed correctly and the model identifier is valid.")
136
+ return
137
+
138
+ # Define Sampling Parameters for vLLM
139
+ # max_tokens is equivalent to max_new_tokens in HF
140
+ # temperature=0.0 for greedy decoding, good for QA tasks for more deterministic output.
141
+ # Adjust temperature (e.g., 0.7) and top_p (e.g., 0.95) for more diverse outputs if needed.
142
+ sampling_params = SamplingParams(temperature=0.0, max_tokens=30) # Set temperature to 0.0 for deterministic QA
143
+
144
+ # Load dataset
145
+ print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
146
+ try:
147
+ dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
148
+ print(f"Successfully loaded dataset with {len(dataset)} samples.")
149
+ except Exception as e:
150
+ print(f"Failed to load dataset: {e}")
151
+ return
152
+
153
+ em_match_count = 0 # Counter for EM matches
154
+ em_match_original_count = 0 # Counter for EM matches
155
+ successfully_processed_samples = 0 # Counter for successfully processed samples
156
+
157
+ num_samples_to_process = len(dataset) if args.sample_count == -1 else min(args.sample_count, len(dataset))
158
+
159
+ print(f"Processing {num_samples_to_process} samples. Rephrasing with GPT-4o (opposite meaning). Answering with vLLM model {args.model_name} (max 30 tokens)...")
160
+
161
+ for i in tqdm(range(num_samples_to_process), desc="Processing samples with vLLM"):
162
+ example = dataset[i]
163
+ original_question = example['input']
164
+ context = example['context']
165
+ ground_truth_answers = example['answers']
166
+
167
+ rephrased_question = rephrase_question_with_gpt4o(original_question, args.rephrase_type) # Use new rephrasing
168
+
169
+ if rephrased_question == "Failed to rephrase question":
170
+ print(f"Skipping sample {i+1} due to rephrasing failure.")
171
+ continue
172
+
173
+ rephrased_answer = answer_question_with_context_vllm(rephrased_question, context, vllm_model, sampling_params, tokenizer)
174
+ # print(f"Rephrased question: {rephrased_question}") # Optional: for debugging
175
+ # print(f"Answer to rephrased: {rephrased_answer}") # Optional: for debugging
176
+
177
+ original_answer = answer_question_with_context_vllm(original_question, context, vllm_model, sampling_params, tokenizer)
178
+ # print(f"Original question: {original_question}") # Optional: for debugging
179
+ # print(f"Answer to original: {original_answer}") # Optional: for debugging
180
+
181
+ if not ground_truth_answers:
182
+ print(f"Skipping sample {i+1} due to missing ground truth answers.")
183
+ continue
184
+ print(original_answer)
185
+ successfully_processed_samples += 1
186
+
187
+ # EM counts a hit when the answer matches any gold answer,
188
+ # consistent with the other detection scripts
189
+ em_match_count += max(qa_em_score(rephrased_answer, g) for g in ground_truth_answers)
190
+
191
+ em_match_original_count += max(qa_em_score(original_answer, g) for g in ground_truth_answers)
201
+
202
+ if successfully_processed_samples > 0:
203
+ print(f"Answering vLLM Model: {args.model_name}")
204
+ print(f"Dataset : {args.dataset_name} ({args.dataset_subset})")
205
+ print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
206
+ print(f"Max Answer Tokens : 30") # Reflects SamplingParams
207
+ print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
208
+ print(f"Count of EM with original ground truth (before rephrase): {em_match_original_count}")
209
+ else:
210
+ print("\nNo samples were processed adequately to provide an evaluation summary.")
211
+
212
+ print("vLLM processing complete!")
213
+
214
+ if __name__ == "__main__":
215
+ parser = argparse.ArgumentParser(description="Rephrase with GPT-4o, Answer with local vLLM-hosted Model, then Evaluate.")
216
+ parser.add_argument("--model_name", type=str, default="facebook/opt-125m", help="Name/path of the Hugging Face model for Answering via vLLM (e.g., 'mistralai/Mistral-7B-Instruct-v0.1').")
217
+ parser.add_argument("--dataset_name", type=str, default="THUDM/LongBench", help="Name of the Hugging Face dataset.")
218
+ parser.add_argument("--dataset_subset", type=str, default="2wikimqa", help="Subset of the dataset.")
219
+ parser.add_argument("--sample_count", type=int, default=3, help="Number of samples to process. -1 for all. Default: 3 for quick testing.")
220
+ parser.add_argument("--trust_remote_code", action="store_true", help="Set to true if the Hugging Face model for vLLM requires remote code.")
221
+ parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel size for vLLM.")
222
+ parser.add_argument("--dtype", type=str, default="auto", help="Data type for the model. Examples: 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'. Default is 'auto'.")
223
+ parser.add_argument("--rephrase_type", type=str, default="opposite", choices=["opposite", "similar"], help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.")
224
+
225
+ args = parser.parse_args()
226
+ main(args)
eval/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # from .evaluation import *
2
+
eval/eval_with_api.py ADDED
@@ -0,0 +1,94 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+
6
+ """
7
+ Single-phase evaluator (DeepSeek API) — Calculate EM / F1 only.
8
+
9
+ Usage Example
10
+ --------
11
+ python -m eval.eval_with_api --input data/2wikimqa.jsonl
12
+ """
13
+
14
+ import argparse, time, jsonlines, os
15
+ from pathlib import Path
16
+ from tqdm import tqdm
17
+ from openai import OpenAI
18
+ from utils.metrics import qa_em_score, qa_f1_score
19
+ from utils.llmjudge import judge_answer_with_api
20
+
21
+ # -------------------- CLI --------------------
22
+ p = argparse.ArgumentParser("Single-phase evaluator")
23
+ p.add_argument("--input", required=True, help="Path to the *.jsonl file to evaluate")
24
+ p.add_argument("--model", default="deepseek-r1")
25
+ p.add_argument("--temperature", type=float, default=0.5)
26
+ p.add_argument("--max_tokens", type=int, default=30)
27
+ p.add_argument("--sleep", type=float, default=0.0)
28
+ args = p.parse_args()
29
+
30
+ client = OpenAI(
31
+ base_url=os.environ.get("OPENAI_BASE_URL"),
32
+ api_key=os.environ.get("OPENAI_API_KEY")
33
+ )
34
+
35
+ # -------------------- helper --------------------
36
+ def ask(context: str, question: str) -> str:
37
+ """Call DeepSeek to get answer (return final answer only)"""
38
+ messages = [
39
+ {"role": "system",
40
+ "content": ("You are a QA assistant. "
41
+ "Answer strictly based on the passages; "
42
+ "output only the final answer.")},
43
+ {"role": "user",
44
+ "content": f"Answer the question and output only the final answer without extra words. Passages:\n{context}\n\nQuestion: {question}\nAnswer:"}
45
+ ]
46
+ resp = client.chat.completions.create(
47
+ model=args.model,
48
+ messages=messages,
49
+ temperature=args.temperature,
50
+ max_tokens=args.max_tokens
51
+ )
52
+ if not resp.choices[0].message.content:
53
+ return "None"
54
+
55
+ return resp.choices[0].message.content.strip()
56
+
57
+
58
+ # -------------------- core eval --------------------
59
+ def evaluate_file(path: Path):
60
+ dataset = path.stem
61
+ data = {obj["input"]: obj for obj in jsonlines.open(path)}
62
+
63
+ total = len(data)
64
+ em_hits = 0
65
+ f1_sum = 0.0
66
+
67
+ for q, item in tqdm(data.items(), desc=f"{dataset}"):
68
+ ctx = item["context"]
69
+ golds = item["answers"] if isinstance(item["answers"], list) else [item["answers"]]
70
+
71
+ pred = ask(ctx, q).split('.', 1)[0] # Cut off extra explanations
72
+ if pred == "None":
73
+ continue
74
+ em = max(qa_em_score(pred, g) for g in golds)
75
+ f1 = max(qa_f1_score(pred, g) for g in golds)
76
+
77
+ em_hits += em
78
+ f1_sum += f1
79
+ if args.sleep:
80
+ time.sleep(args.sleep)
81
+
82
+ print(f"\n=== {dataset.upper()} SUMMARY ===")
83
+ print(f"Total samples : {total}")
84
+ print(f"Exact Match : {em_hits}/{total} ({em_hits/total:.2%})")
85
+ print(f"Average F1 : {f1_sum/total:.4f}")
86
+ print("-" * 40 + "\n")
87
+
88
+
89
+ # -------------------- run --------------------
90
+ input_path = Path(args.input)
91
+ if not input_path.exists():
92
+ raise SystemExit(f"File does not exist: {input_path}")
93
+
94
+ evaluate_file(input_path)
eval/evaluation.py ADDED
@@ -0,0 +1,115 @@
1
+
2
+ import argparse, os, jsonlines, torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
4
+ from utils.metrics import qa_f1_score, qa_em_score
5
+ THINK_END_ID = 151668 # "</think>" token id for Qwen3
6
+
7
+ # --------------------------------------------------
8
+ def strip_think(token_ids):
9
+ try:
10
+ cut = len(token_ids) - token_ids[::-1].index(THINK_END_ID)
11
+ return token_ids[cut:]
12
+ except ValueError:
13
+ return token_ids
14
+
15
+ def main():
16
+ # ---------- CLI ----------
17
+ parser = argparse.ArgumentParser(
18
+ description="Evaluate HotpotQA JSONL with Transformers + Qwen3-8B"
19
+ )
20
+ parser.add_argument("-i", "--input", required=True,
21
+ help="Path to input JSONL file")
22
+ parser.add_argument("--model", required=True,
23
+ help="HF model name, e.g. Qwen/Qwen3-8B")
24
+ parser.add_argument("-d", "--devices", default="0",
25
+ help="CUDA_VISIBLE_DEVICES (comma-separated)")
26
+ parser.add_argument("-t", "--temperature", type=float, default=0.5,
27
+ help="Sampling temperature")
28
+ parser.add_argument("-k", "--max_tokens", type=int, default=40,
29
+ help="max_new_tokens")
30
+ args = parser.parse_args()
31
+
32
+
33
+
34
+ tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
35
+ model = AutoModelForCausalLM.from_pretrained(
36
+ args.model,
37
+ torch_dtype="auto",
38
+ device_map="auto",
39
+ trust_remote_code=True
40
+ )
41
+ gen_cfg = GenerationConfig(
42
+ temperature=args.temperature,
43
+ max_new_tokens=args.max_tokens,
44
+ do_sample=args.temperature > 0
45
+ )
46
+
47
+
48
+ with jsonlines.open(args.input) as reader:
49
+ data = list(reader)
50
+
51
+ total_f1 = total_em = 0.0
52
+
53
+ for idx, item in enumerate(data):
54
+ question = item.get("input", "")
55
+ context = item.get("context", "")
56
+ answers = item.get("answers", [])
57
+ if not answers:
58
+ print(f"[{idx}] no gold answer, skip")
59
+ continue
60
+ gold = answers[0]
61
+ print(gold)
62
+
63
+ # ----- Prompt -----
64
+ prompt = (
65
+ "Answer the question based on the given passages. "
66
+ "Only give me your answer and do not output any other words.\n"
67
+ "Passages:\n"
68
+ f"{context}\n"
69
+ f"Question: {question}\n"
70
+ "Answer:"
71
+ )
72
+ messages = [{"role": "user", "content": prompt}]
73
+ chat_text = tokenizer.apply_chat_template(
74
+ messages,
75
+ tokenize=False,
76
+ add_generation_prompt=True,
77
+ enable_thinking=False
78
+ )
79
+ inputs = tokenizer([chat_text], return_tensors="pt").to(model.device)
80
+
81
+
82
+ # ----- Generate -----
83
+ try:
84
+ with torch.no_grad():
85
+ outputs = model.generate(**inputs, max_new_tokens=args.max_tokens)
86
+ except ValueError as e:
87
+ if "position ids exceed" in str(e).lower() or "sequence length" in str(e).lower():
88
+ print(f"[{idx}] prompt too long – skipped")
89
+ continue
90
+ raise
91
+
92
+ new_ids = outputs[0][len(inputs.input_ids[0]):].tolist()
93
+ try:
94
+ index = len(new_ids) - new_ids[::-1].index(THINK_END_ID)
95
+ except ValueError:
96
+ index = 0
97
+ answer = tokenizer.decode(new_ids[index:], skip_special_tokens=True).strip("\n")
98
+ answer = answer.strip()
99
+
100
+ # ----- Score -----
101
+ f1 = qa_f1_score(answer, gold)
102
+ em = qa_em_score(answer, gold)
103
+ total_f1 += f1
104
+ total_em += em
105
+
106
+ print(f"[{idx}] Q: {question}")
107
+ print(f" Resp: {answer!r} | Gold: {gold!r}")
108
+ print(f" F1={f1:.2f}, EM={em:.2f}")
109
+
110
+ n = len(data)
111
+ print(f"\nOverall F1: {total_f1/n:.4f}")
112
+ print(f"Overall EM: {total_em/n:.4f}")
113
+
114
+ if __name__ == "__main__":
115
+ main()
main_gpu.py ADDED
@@ -0,0 +1,427 @@
1
+ import json
2
+ import time
3
+ import re
4
+ import os
5
+ import argparse
6
+ from datasets import load_dataset
7
+ from haystack import Pipeline, Document
8
+ from haystack.utils import Secret
9
+ from haystack.document_stores.in_memory import InMemoryDocumentStore
10
+ from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
12
+ from nltk.tokenize import sent_tokenize
13
+ from utils.util import retriveDoc,compute_best_sentence_f1
14
+ from openai import OpenAI
15
+ import asyncio, json, torch, math
16
+ from typing import List, Tuple
17
+ # Hugging Face transformers related
18
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
19
+ from utils.metrics import qa_f1_score
20
+ from utils.llmjudge import judge_answer_with_api
21
+
22
+ client = OpenAI(
23
+ base_url=os.environ.get("OPENAI_BASE_URL"),
24
+ api_key=os.environ.get("OPENAI_API_KEY")
25
+ )
26
+ # Load models using transformers
27
+
28
+ tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
29
+ model1 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True,device_map="cuda:0",torch_dtype=torch.bfloat16)
30
+
31
+
32
+ tok_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
33
+ model_qwen = AutoModelForCausalLM.from_pretrained(
34
+ "Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True,
35
+ device_map="cuda:1",torch_dtype=torch.bfloat16
36
+ ).eval()
37
+
38
+ def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
39
+ """
40
+ Use transformers model.generate method for inference with retry mechanism,
41
+ strip the input prompt part through token-level slicing,
42
+ and return the newly generated text.
43
+ """
45
+ for attempt in range(retries):
46
+ try:
47
+ # Encode prompt as model input tensor
48
+ model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
49
+ # Call generate, the generated id sequence contains both prompt and subsequent generated text
50
+ generated_ids = model.generate(
51
+ **model_inputs,
52
+ max_new_tokens=max_new_tokens,
53
+ temperature=temperature,
54
+ top_p=top_p
55
+ )
56
+ # Calculate the token count corresponding to the prompt
57
+ input_length = model_inputs.input_ids.shape[1]
58
+ # Strip the prompt part from the front of the output, keeping only the newly added part
59
+ output_ids = generated_ids[0][input_length:]
60
+ # Decode generated text
61
+ answer = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
62
+ return answer
63
+ except Exception as e:
64
+ print(f"Error on attempt {attempt + 1}: {e}")
65
+ if attempt < retries - 1:
66
+ print(f"Retrying in {delay} seconds...")
67
+ time.sleep(delay)
68
+ else:
69
+ print("Max retries reached, skipping this request.")
70
+ return None
71
+
72
+ def truncate_answer(answer):
73
+ """Truncate answer, only take the part before the first period"""
74
+ return answer.split('.')[0].strip() if answer else "No answer"
75
+
76
+ def write_to_log(filename, data):
77
+ """Write data to log file"""
78
+ with open(filename, 'a', encoding='utf-8') as file:
79
+ file.write(data + '\n')
80
+
81
+ def remove_think_tags(text: str) -> str:
82
+ """Remove all <think> ... </think> blocks"""
83
+ return re.sub(r'<think>(.*?)</think>', '', text, flags=re.DOTALL).strip()
84
+
85
+ def build_prompt(context: str, question: str) -> str:
86
+ prompt = (
87
+ f"Answer the question based on the given passages. The following are the passages:\n"
88
+ f"{context}\n"
89
+ f"Answer the question based on the given passages.\n"
90
+ f"Question: {question}.\n"
91
+ f"Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) "
92
+ f"At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. "
93
+ f"Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence. "
94
+ f"Follow this format:\n"
95
+ f"Answer: [Your answer]\n"
96
+ f"Step-by-step Reasoning:\n"
97
+ f"1. [Reasoning step 1]\n"
98
+ f"[replaced by your reference content]\n"
99
+ f"2. [Reasoning step 2]\n"
100
+ f"[replaced by your reference content]\n"
101
+ )
102
+ return prompt
103
+
104
+ def extract_final_bullet_passage(answer_text: str):
105
+ reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
106
+ reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
107
+ if not reasoning_match:
108
+ return None, None
109
+
110
+ reasoning_text = reasoning_match.group(1).strip()
111
+ bullet_pattern = r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)"
112
+ bullets = re.findall(bullet_pattern, reasoning_text, flags=re.DOTALL)
113
+ if not bullets:
114
+ print("No bullet blocks found.")
115
+ return None, None
116
+
117
+ passage_pattern = re.compile(
118
+ r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
119
+ flags=re.DOTALL
120
+ )
121
+
122
+ for bullet in reversed(bullets):
123
+ matches = passage_pattern.findall(bullet)
124
+ if matches:
125
+ last_match = matches[-1]
126
+ passage_number = last_match[0]
127
+ quoted_snippet = last_match[2]
128
+ non_quoted_snippet = last_match[3]
129
+ snippet = non_quoted_snippet.strip() if non_quoted_snippet.strip() else quoted_snippet.strip()
130
+ return passage_number, snippet
131
+
132
+ return None, None
133
+
134
+ def extract_all_bullet_passages(answer_text: str):
135
+ reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
136
+ reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
137
+ if not reasoning_match:
138
+ return []
139
+
140
+ reasoning_text = reasoning_match.group(1).strip()
141
+ bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
142
+ bullets = bullet_pattern.findall(reasoning_text)
143
+ if not bullets:
144
+ return []
145
+
146
+ results = []
147
+ for bullet_index, bullet_text in enumerate(bullets, start=1):
148
+ results.append({
149
+ 'bullet_index': bullet_index,
150
+ 'snippet': bullet_text.strip()
151
+ })
152
+ print(results)
153
+ return results
154
+
155
+ def extract_evidence(answer_text: str):
156
+ reasoning_pattern = r"(?i)Evidence\s*(.*)"
157
+ reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
158
+ if not reasoning_match:
159
+ return []
160
+
161
+ reasoning_text = reasoning_match.group(1).strip()
162
+
163
+ # Extract all bullet segments
164
+ bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
165
+ bullets = bullet_pattern.findall(reasoning_text)
166
+ if not bullets:
167
+ return []
168
+
169
+ # Find the index of the first bullet starting with 1.
170
+ start_index = -1
171
+ for i, bullet in enumerate(bullets):
172
+ if bullet.strip().startswith("1."):
173
+ start_index = i
174
+ break
175
+
176
+ if start_index == -1:
177
+ return [] # No valid starting bullet
178
+
179
+ # Only keep the part starting from the first valid bullet
180
+ bullets = bullets[start_index:]
181
+
182
+ results = []
183
+ for bullet_index, bullet_text in enumerate(bullets, start=1):
184
+ results.append({
185
+ 'bullet_index': bullet_index,
186
+ 'snippet': bullet_text.strip()
187
+ })
188
+ return results
189
+
190
+
191
+ def get_answer_with_retry(model, prompt, retries=3, delay=5):
192
+ """Call the model to get the answer based on the prompt, with retry on failure."""
193
+ for attempt in range(retries):
194
+ try:
195
+ completion = client.chat.completions.create(
196
+ model=model,
197
+ messages=[{'role': 'user', 'content': prompt}]
198
+ )
199
+ return completion.choices[0].message.content.strip()
200
+ except Exception as e:
201
+ print(f"Error on attempt {attempt + 1}: {e}")
202
+ if attempt < retries - 1:
203
+ print(f"Retrying in {delay} seconds...")
204
+ time.sleep(delay)
205
+ else:
206
+ print("Max retries reached, skipping this request.")
207
+ return None
208
+
209
+ @torch.no_grad()
210
+ def qwen_answer_and_ppl(question: str, context: str) -> Tuple[str,float]:
211
+ prompt = f"{context}\n\nQuestion: {question}\nAnswer:"
212
+ inputs = tok_qwen(prompt, return_tensors="pt").to(model_qwen.device)
213
+ gen_ids = model_qwen.generate(**inputs, max_new_tokens=30, eos_token_id=tok_qwen.eos_token_id)
214
+ ans_ids = gen_ids[0][inputs.input_ids.shape[1]:]
215
+ answer = tok_qwen.decode(ans_ids, skip_special_tokens=True).strip()
216
+ # Calculate PPL
217
+ full_ids = torch.cat([inputs.input_ids[0], ans_ids])
218
+ logits = model_qwen(full_ids.unsqueeze(0)).logits[0,:-1]
219
+ tgt = full_ids[1:]
220
+ logp = torch.log_softmax(logits, dim=-1)
221
+ sel = logp[range(len(tgt)), tgt]
222
+ ppl = math.exp(-sel.mean().item())
223
+ return answer, ppl
224
+
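+ # Note on the score above: ppl = exp(-(1/N) * sum_t log p(x_t | x_<t)), averaged over
+ # prompt *and* answer tokens. multi_proposal_pipeline below keeps the rewrite with the
+ # highest ppl, i.e. the context the reader model finds least predictable.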
225
+
226
+
227
+ def extract_json_from_gpt_response(text: str) -> dict | None:
228
+ """
229
+ Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict.
230
+ """
231
+ # Try to find a ```json … ``` block first
232
+ m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
233
+ if not m:
234
+ # Fallback: any ``` … ``` block that looks like JSON
235
+ m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
236
+ if not m:
237
+ # Lastly, maybe the model just spit raw JSON without fences
238
+ m = re.search(r"(\{.*?\})", text, flags=re.DOTALL)
239
+ if not m:
240
+ return None
241
+
242
+ json_str = m.group(1)
243
+ try:
244
+ return json.loads(json_str)
245
+ except json.JSONDecodeError:
246
+ # clean up trailing commas, etc.
247
+ cleaned = re.sub(r",\s*([\]}])", r"\1", json_str)
248
+ try:
249
+ return json.loads(cleaned)
250
+ except json.JSONDecodeError:
251
+ return None
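+ # Illustrative example:
+ # extract_json_from_gpt_response('```json\n{"answer": "x", "revised": []}\n```')
+ # -> {"answer": "x", "revised": []}; raw JSON without code fences is also accepted.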
252
+ async def multi_proposal_pipeline(
253
+ question: str,
254
+ original_context: str, # directly pass in example['context']
255
+ unique_sents: List[str],
256
+ correct_answer: str,
257
+ rounds: int = 3
258
+ ) -> dict:
259
+ best = {"ppl": 0.0, "context": None, "answer": None}
260
+
261
+ for i in range(rounds):
262
+ # Construct GPT-4o prompt
263
+ numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
264
+ prompt = (
265
+ "You are a creative contrarian. Given the question below, and the original answer, first propose a concise alternative answer—that is, a plausible but intentionally misleading answer. "
266
+ "Followed are some sentences supporting the original answer, please rewrite them. When rewriting each sentence, modify only the parts necessary to support the antifact answer. Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. Output must be strictly in the specified JSON format, with no additional text.\n"
267
+ '{\n'
268
+ ' "answer": "<your antifact answer here, just provide the answer phrase, no need for complete sentence>",\n'
269
+ ' "revised": [\n'
270
+ ' "<rewritten sentence 1>",\n'
271
+ ' "<rewritten sentence 2>",\n'
272
+ ' ...\n'
273
+ ' ]\n'
274
+ '}\n\n'
275
+ f"Question:\n{question}\n\n"
276
+ f"Original answer:\n{correct_answer}\n\n"
277
+ f"Sentences to rewrite:\n{numbered}"
278
+ )
279
+ print(f"[Proposal {i+1}] Prompt: {prompt}")
280
+ rsp = client.chat.completions.create(
281
+ model="gpt-4o", temperature=0.7,
282
+ messages=[{"role":"user","content":prompt}]
283
+ )
284
+ js = extract_json_from_gpt_response(rsp.choices[0].message.content)
285
+ if not js:
286
+ print("[Proposal {i+1}] Failed to parse JSON")
287
+ continue
288
+ revised = js["revised"] # List[str]
289
+ proposed = js["answer"] # Answer given by GPT-4o (optional record)
290
+ new_ctx = original_context
291
+ for old, new in zip(unique_sents, revised):
292
+ new_ctx = new_ctx.replace(old, new)
293
+
294
+ # Use Qwen to calculate answer & PPL
295
+ ans_i, ppl_i = qwen_answer_and_ppl(question, new_ctx)
296
+ print(f"[Proposal {i+1}] PPL = {ppl_i:.2f}")
297
+
298
+ if ppl_i > best["ppl"]:
299
+ best.update({"ppl": ppl_i, "context": new_ctx, "answer": proposed})
300
+
301
+ return best
302
+ def main():
303
+ # Parse command line arguments
304
+ parser = argparse.ArgumentParser(description="LastingBench main pipeline for context rewriting")
305
+ parser.add_argument("--output", "-o", type=str, default="output.jsonl",
306
+ help="Output JSONL file path (default: output.jsonl)")
307
+ parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
308
+ help="Dataset repository name (default: THUDM/LongBench)")
309
+ parser.add_argument("--dataset_subset", type=str, default="multifieldqa_en",
310
+ help="Dataset subset name (default: multifieldqa_en)")
311
+ parser.add_argument("--split", type=str, default="test",
312
+ help="Dataset split (default: test)")
313
+ parser.add_argument("--start_idx", type=int, default=0,
314
+ help="Starting index for processing (default: 0)")
315
+ parser.add_argument("--max_samples", type=int, default=-1,
316
+ help="Maximum number of samples to process (-1 for all, default: -1)")
317
+
318
+ args = parser.parse_args()
319
+
320
+ out_file = args.output
321
+ # Load dataset
322
+ longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]
323
+
324
+ print(f"Output file: {out_file}")
325
+ print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
326
+ print(f"Total samples: {len(longbench)}")
329
+ count = 0
330
+
331
+ # Determine processing range
332
+ start_idx = args.start_idx
333
+ end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))
334
+
335
+ print(f"Processing samples from index {start_idx} to {end_idx-1}")
336
+
337
+ for idx in range(start_idx, end_idx):
338
+ example = longbench[idx]
339
+ question = example['input']
340
+ print(f"Question: {question}")
341
+ context = example['context']
342
+ correct_answer = example['answers'][0]
343
+
344
+ print(f"Processing example {idx + 1}:")
345
+ print(f"Correct Answer: {correct_answer}")
346
+
347
+ # Build prompts
348
+ prompt_with_context = build_prompt(context, question)
349
+
350
+
351
+ # Get answers using transformers pipelines
352
+
353
+
354
+ answer_with_context = get_answer_with_retry('deepseek-r1', prompt_with_context)
355
+ # Extract content after "Answer:" from answer_with_context
356
+ answer_with_context_simple = (
357
+ answer_with_context
358
+ .split("Answer:", 1)[-1] # First keep the part after Answer:
359
+ .split("Step-by-step Reasoning", 1)[0] # Then cut before Step-by-step Reasoning
360
+ .strip()
361
+ )
362
+ print(f"Answer with context: {answer_with_context_simple}")
363
+ result = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
364
+ print(f"Answer judge result: {result}")
365
+ if not result:
366
+ continue
367
+
368
+
369
+ answer_with_context = remove_think_tags(answer_with_context or "")
370
+
371
+ evidence = extract_all_bullet_passages(answer_with_context)
372
+
373
+ page_contents = []
374
+ if evidence:
375
+ count += 1
376
+ for ev in evidence:
377
+ snippet = ev['snippet']
378
+ result = retriveDoc(context, snippet)
379
+ # result["context"] is a set of Document objects
380
+ page_contents += [doc.page_content for doc in result]
381
+
382
+ unique_page_contents = list(dict.fromkeys(page_contents))
383
+ aggregated_content = "\n".join(unique_page_contents)
384
+ prompt_final = (
385
+ f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
386
+ f"Please only provide your answer. "
387
+ f"Your Answer:"
388
+ )
389
+ final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)
390
+ if judge_answer_with_api(question, correct_answer, final_answer):
391
+ print("correct")
392
+ else:
393
+ print("incorrect")
394
+ result_query = retriveDoc(context, question)
395
+ page_contents += [doc.page_content for doc in result_query]
396
+ unique_page_contents = list(dict.fromkeys(page_contents))
397
+ best = asyncio.run(
398
+ multi_proposal_pipeline(
399
+ question,
400
+ context,
401
+ unique_page_contents,
402
+ correct_answer
403
+ )
404
+ )
405
+ record = {
406
+ "question": question,
407
+ "answer": best["answer"],
408
+ "context": best["context"]
409
+ }
410
+
411
+ # Append one line of JSON each loop
412
+ with open(out_file, "a", encoding="utf-8") as fout:
413
+ fout.write(json.dumps(record, ensure_ascii=False) + "\n")
414
+
415
+
416
+
417
+
418
+
419
+
420
+
421
+ if __name__ == "__main__":
422
+ main()
423
+
424
+
425
+
426
+
427
+
random_alternative_answer.py ADDED
@@ -0,0 +1,418 @@
1
+ import json
2
+ import time
3
+ import re
4
+ import os
5
+ import argparse
6
+ from datasets import load_dataset
7
+ from nltk.tokenize import sent_tokenize
8
+ from utils.util import retriveDoc,compute_best_sentence_f1
9
+ from openai import OpenAI
10
+ import asyncio, json, torch, math
11
+ from typing import List, Tuple
12
+ # Hugging Face transformers related
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
+ from utils.metrics import qa_f1_score
15
+ from utils.llmjudge import judge_answer_with_api
16
+
17
+
18
+ client = OpenAI(
19
+ base_url=os.environ.get("OPENAI_BASE_URL"),
20
+ api_key=os.environ.get("OPENAI_API_KEY")
21
+ )
22
+ # Load models using transformers
23
+
24
+ tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
25
+ model1 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True,device_map="cuda:0",torch_dtype=torch.bfloat16)
26
+
27
+
28
+ tok_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
29
+ model_qwen = AutoModelForCausalLM.from_pretrained(
30
+ "Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True,
31
+ device_map="cuda:1",torch_dtype=torch.bfloat16
32
+ ).eval()
33
+
34
+ def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
35
+ """
36
+ Use transformers model.generate method for inference with retry mechanism,
37
+ use chat template to format input, and strip the input prompt part through token-level slicing,
38
+ return the newly generated text.
39
+ """
41
+ for attempt in range(retries):
42
+ try:
43
+ # Convert original prompt to message format
44
+ messages = [{"role": "user", "content": prompt}]
45
+
46
+ # Try to use chat template to format input
47
+ try:
48
+ formatted_prompt = tokenizer.apply_chat_template(
49
+ messages,
50
+ tokenize=False,
51
+ add_generation_prompt=True
52
+ )
53
+ except Exception as e:
54
+ print(f"Unable to apply chat template: {e}, falling back to basic text input")
55
+ formatted_prompt = prompt # Fall back to original prompt as input
56
+
57
+ # Encode formatted prompt as model input tensor
58
+ model_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
59
+
60
+ # Call generate, the generated id sequence contains both prompt and subsequent generated text
61
+ generated_ids = model.generate(
62
+ **model_inputs,
63
+ max_new_tokens=max_new_tokens,
64
+ temperature=temperature,
65
+ top_p=top_p
66
+ )
67
+
68
+ # Calculate the token count corresponding to the prompt
69
+ input_length = model_inputs.input_ids.shape[1]
70
+
71
+ # Strip the prompt part from the front of the output, keeping only the newly added part
72
+ output_ids = generated_ids[0][input_length:]
73
+
74
+ # Decode generated text
75
+ answer = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
76
+ return answer
77
+ except Exception as e:
78
+ print(f"Error on attempt {attempt + 1}: {e}")
79
+ if attempt < retries - 1:
80
+ print(f"Retrying in {delay} seconds...")
81
+ time.sleep(delay)
82
+ else:
83
+ print("Max retries reached, skipping this request.")
84
+ return None
85
+
86
+ def truncate_answer(answer):
87
+ """Truncate answer, only take the part before the first period"""
88
+ return answer.split('.')[0].strip() if answer else "No answer"
89
+
90
+ def write_to_log(filename, data):
91
+ """Write data to log file"""
92
+ with open(filename, 'a', encoding='utf-8') as file:
93
+ file.write(data + '\n')
94
+
95
+ def remove_think_tags(text: str) -> str:
96
+ """Remove all <think> ... </think> blocks"""
97
+ return re.sub(r'<think>(.*?)</think>', '', text, flags=re.DOTALL).strip()
98
+
99
+ def build_prompt(context: str, question: str) -> str:
100
+ prompt = (
101
+ f"Answer the question based on the given passages. The following are the passages:\n"
102
+ f"{context}\n"
103
+ f"Answer the question based on the given passages.\n"
104
+ f"Question: {question}.\n"
105
+ f"Answer:\n"
106
+ f"Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) "
107
+ f"At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. "
108
+ f"Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence. "
109
+ f"Follow this format:\n"
110
+ f"Answer: [Your answer]\n"
111
+ f"Step-by-step Reasoning:\n"
112
+ f"1. [Reasoning step 1]\n"
113
+ f"[replaced by your reference content]\n"
114
+ f"2. [Reasoning step 2]\n"
115
+ f"[replaced by your reference content]\n"
116
+ )
117
+ return prompt
118
+
119
+ def extract_final_bullet_passage(answer_text: str):
120
+ reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
121
+ reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
122
+ if not reasoning_match:
123
+ return None, None
124
+
125
+ reasoning_text = reasoning_match.group(1).strip()
126
+ bullet_pattern = r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)"
127
+ bullets = re.findall(bullet_pattern, reasoning_text, flags=re.DOTALL)
128
+ if not bullets:
129
+ print("No bullet blocks found.")
130
+ return None, None
131
+
132
+ passage_pattern = re.compile(
133
+ r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
134
+ flags=re.DOTALL
135
+ )
136
+
137
+ for bullet in reversed(bullets):
138
+ matches = passage_pattern.findall(bullet)
139
+ if matches:
140
+ last_match = matches[-1]
141
+ passage_number = last_match[0]
142
+ quoted_snippet = last_match[2]
143
+ non_quoted_snippet = last_match[3]
144
+ snippet = non_quoted_snippet.strip() if non_quoted_snippet.strip() else quoted_snippet.strip()
145
+ return passage_number, snippet
146
+
147
+ return None, None
148
+
149
+ def extract_all_bullet_passages(answer_text: str):
150
+ reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
151
+ reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
152
+ if not reasoning_match:
153
+ return []
154
+
155
+ reasoning_text = reasoning_match.group(1).strip()
156
+ bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
157
+ bullets = bullet_pattern.findall(reasoning_text)
158
+ if not bullets:
159
+ return []
160
+
161
+ results = []
162
+ for bullet_index, bullet_text in enumerate(bullets, start=1):
163
+ results.append({
164
+ 'bullet_index': bullet_index,
165
+ 'snippet': bullet_text.strip()
166
+ })
167
+ print(results)
168
+ return results
169
+
170
+ def extract_evidence(answer_text: str):
171
+ reasoning_pattern = r"(?i)Evidence\s*(.*)"
172
+ reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
173
+ if not reasoning_match:
174
+ return []
175
+
176
+ reasoning_text = reasoning_match.group(1).strip()
177
+
178
+ # Extract all bullet segments
179
+ bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
180
+ bullets = bullet_pattern.findall(reasoning_text)
181
+ if not bullets:
182
+ return []
183
+
184
+ # Find the index of the first bullet starting with 1.
185
+ start_index = -1
186
+ for i, bullet in enumerate(bullets):
187
+ if bullet.strip().startswith("1."):
188
+ start_index = i
189
+ break
190
+
191
+ if start_index == -1:
192
+ return [] # No valid starting bullet
193
+
194
+ # Only keep the part starting from the first valid bullet
195
+ bullets = bullets[start_index:]
196
+
197
+ results = []
198
+ for bullet_index, bullet_text in enumerate(bullets, start=1):
199
+ results.append({
200
+ 'bullet_index': bullet_index,
201
+ 'snippet': bullet_text.strip()
202
+ })
203
+ return results
204
+
205
+
206
+ def get_answer_with_retry(model, prompt, retries=3, delay=5):
207
+ """Call the model to get the answer based on the prompt, with retry on failure."""
208
+ for attempt in range(retries):
209
+ try:
210
+ completion = client.chat.completions.create(
211
+ model=model,
212
+ messages=[{'role': 'user', 'content': prompt}]
213
+ )
214
+ return completion.choices[0].message.content.strip()
215
+ except Exception as e:
216
+ print(f"Error on attempt {attempt + 1}: {e}")
217
+ if attempt < retries - 1:
218
+ print(f"Retrying in {delay} seconds...")
219
+ time.sleep(delay)
220
+ else:
221
+ print("Max retries reached, skipping this request.")
222
+ return None
223
+
224
+ def extract_json_from_gpt_response(text: str) -> dict | None:
225
+ """
226
+ Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict.
227
+ """
228
+ # Try to find a ```json … ``` block first
229
+ m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
230
+ if not m:
231
+ # Fallback: any ``` … ``` block that looks like JSON
232
+ m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
233
+ if not m:
234
+ # Lastly, maybe the model just spit raw JSON without fences
235
+ m = re.search(r"(\{.*?\})", text, flags=re.DOTALL)
236
+ if not m:
237
+ return None
238
+
239
+ json_str = m.group(1)
240
+ try:
241
+ return json.loads(json_str)
242
+ except json.JSONDecodeError:
243
+ # clean up trailing commas, etc.
244
+ cleaned = re.sub(r",\s*([\]}])", r"\1", json_str)
245
+ try:
246
+ return json.loads(cleaned)
247
+ except json.JSONDecodeError:
248
+ return None
249
+
250
+ async def random_alternative_answer(
251
+ question: str,
252
+ original_context: str,
253
+ unique_sents: List[str],
254
+ correct_answer: str
255
+ ) -> dict:
256
+ """Generate random alternative answer and modified evidence"""
257
+
258
+ # Construct GPT-4o prompt
259
+ numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
260
+ prompt = (
261
+ "You are a creative assistant. Given the question below and the original answer, propose a plausible alternative answer that is **different** from the original but still reasonable. "
262
+ "Then rewrite the provided sentences to support your alternative answer. When rewriting each sentence, modify only the parts necessary to support the alternative answer. "
263
+ "Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. "
264
+ "Output must be strictly in the specified JSON format, with no additional text.\n"
265
+ '{\n'
266
+ ' "answer": "<your alternative answer here, just provide the answer phrase, no need for complete sentence>",\n'
267
+ ' "revised": [\n'
268
+ ' "<rewritten sentence 1>",\n'
269
+ ' "<rewritten sentence 2>",\n'
270
+ ' ...\n'
271
+ ' ]\n'
272
+ '}\n\n'
273
+ f"Question:\n{question}\n\n"
274
+ f"Original answer:\n{correct_answer}\n\n"
275
+ f"Sentences to rewrite:\n{numbered}"
276
+ )
277
+
278
+ print(f"[Alternative Answer] Generating prompt: {prompt}")
279
+
280
+ rsp = client.chat.completions.create(
281
+ model="gpt-4o", temperature=0.7,
282
+ messages=[{"role":"user","content":prompt}]
283
+ )
284
+
285
+ js = extract_json_from_gpt_response(rsp.choices[0].message.content)
286
+ if not js:
287
+ print("[Alternative Answer] Failed to parse JSON")
288
+ return {"context": original_context, "answer": "Failed to generate alternative"}
289
+
290
+ revised = js["revised"] # List[str]
291
+ alternative = js["answer"] # Alternative answer
292
+
293
+ # Create new context
294
+ new_ctx = original_context
295
+ for old, new in zip(unique_sents, revised):
296
+ new_ctx = new_ctx.replace(old, new)
297
+
298
+ return {"context": new_ctx, "answer": alternative}
299
+
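+ # Unlike multi_proposal_pipeline in main_gpu.py, this variant makes a single proposal
+ # and does no perplexity-based selection; the first parseable rewrite is used as-is.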
300
+ def main():
301
+ # Parse command line arguments
302
+ parser = argparse.ArgumentParser(description="LastingBench random alternative answer generation")
303
+ parser.add_argument("--output", "-o", type=str, default="output_random.jsonl",
304
+ help="Output JSONL file path (default: output_random.jsonl)")
305
+ parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
306
+ help="Dataset repository name (default: THUDM/LongBench)")
307
+ parser.add_argument("--dataset_subset", type=str, default="hotpotqa",
308
+ help="Dataset subset name (default: hotpotqa)")
309
+ parser.add_argument("--split", type=str, default="test",
310
+ help="Dataset split (default: test)")
311
+ parser.add_argument("--start_idx", type=int, default=0,
312
+ help="Starting index for processing (default: 0)")
313
+ parser.add_argument("--max_samples", type=int, default=-1,
314
+ help="Maximum number of samples to process (-1 for all, default: -1)")
315
+
316
+ args = parser.parse_args()
317
+
318
+ out_file = args.output
319
+ # Load dataset
320
+ longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]
321
+
322
+ print(f"Output file: {out_file}")
323
+ print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
324
+ print(f"Total samples: {len(longbench)}")
325
+
326
+ count = 0
327
+
328
+ # Determine processing range
329
+ start_idx = args.start_idx
330
+ end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))
331
+
332
+ print(f"Processing samples from index {start_idx} to {end_idx-1}")
333
+
334
+ for idx in range(start_idx, end_idx):
335
+ example = longbench[idx]
336
+ question = example['input']
337
+ print(f"Question: {question}")
338
+ context = example['context']
339
+ correct_answer = example['answers'][0]
340
+
341
+ print(f"Processing example {idx + 1}:")
342
+ print(f"Correct Answer: {correct_answer}")
343
+
344
+ # Build prompts
345
+ prompt_with_context = build_prompt(context, question)
346
+
347
+ # Get answers using transformers pipelines
348
+ answer_with_context = get_answer_with_retry('deepseek-r1', prompt_with_context)
349
+
350
+ # Extract content after "Answer:" from answer_with_context
351
+ answer_with_context_simple = (
352
+ answer_with_context
353
+ .split("Answer:", 1)[-1] # First keep the part after Answer:
354
+ .split("Step-by-step Reasoning", 1)[0] # Then cut before Step-by-step Reasoning
355
+ .strip()
356
+ )
357
+
358
+ print(f"Answer with context: {answer_with_context_simple}")
359
+ result = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
360
+ print(f"Answer judge result: {result}")
361
+
362
+ if not result:
363
+ continue
364
+
365
+ answer_with_context = remove_think_tags(answer_with_context or "")
366
+ evidence = extract_all_bullet_passages(answer_with_context)
367
+
368
+ page_contents = []
369
+ if evidence:
370
+ count += 1
371
+ for ev in evidence:
372
+ snippet = ev['snippet']
373
+ result = retriveDoc(context, snippet)
374
+ # result["context"] is a set of Document objects
375
+ page_contents += [doc.page_content for doc in result]
376
+
377
+ unique_page_contents = list(dict.fromkeys(page_contents))
378
+ aggregated_content = "\n".join(unique_page_contents)
379
+
380
+ prompt_final = (
381
+ f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
382
+ f"Please only provide your answer. "
383
+ f"Your Answer:"
384
+ )
385
+
386
+ final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)
387
+
388
+ if judge_answer_with_api(question, correct_answer, final_answer):
389
+ print("correct")
390
+ else:
391
+ print("incorrect")
392
+ result_query = retriveDoc(context, question)
393
+ page_contents += [doc.page_content for doc in result_query]
394
+
395
+ unique_page_contents = list(dict.fromkeys(page_contents))
396
+
397
+ # Generate random alternative answer instead of selecting the highest ppl answer
398
+ alternative = asyncio.run(
399
+ random_alternative_answer(
400
+ question,
401
+ context,
402
+ unique_page_contents,
403
+ correct_answer
404
+ )
405
+ )
406
+
407
+ record = {
408
+ "question": question,
409
+ "answer": alternative["answer"],
410
+ "context": alternative["context"]
411
+ }
412
+
413
+ # Append one line of JSON each loop
414
+ with open(out_file, "a", encoding="utf-8") as fout:
415
+ fout.write(json.dumps(record, ensure_ascii=False) + "\n")
416
+
417
+ if __name__ == "__main__":
418
+ main()
requirements.txt ADDED
@@ -0,0 +1,38 @@
1
+ # Core ML and NLP libraries
2
+ torch>=2.6.0
3
+ transformers>=4.51.0
4
+ datasets>=2.10.0
5
+ tokenizers>=0.13.0
6
+
7
+ # Haystack for document processing and retrieval
8
+ haystack-ai>=2.0.0
9
+
10
+ # OpenAI API client
11
+ openai==1.84.0
12
+
13
+ # Data processing and analysis
14
+ pandas>=1.5.0
15
+ numpy>=1.24.0
16
+
17
+ # Natural language processing
18
+ nltk>=3.8.0
19
+ jieba
20
+ fuzzywuzzy
21
+ rouge
22
+ rank_bm25
23
+ langchain_text_splitters
24
+ langchain_community
25
+ langchain_openai
26
+
27
+ # Visualization
28
+ matplotlib>=3.6.0
29
+
30
+ # Async processing
31
+ asyncio-throttle>=1.0.0
32
+
33
+ # Optional: Additional ML utilities
34
+ scikit-learn>=1.2.0
35
+ tqdm>=4.64.0
36
+
37
+
38
+
training_result/training_loss_antifact_llama.csv ADDED
@@ -0,0 +1,41 @@
1
+ Step,Loss,Epoch
2
+ 1,1.8891,1
3
+ 2,1.9053,1
4
+ 3,1.8574,1
5
+ 4,1.9114,1
6
+ 5,1.8351,1
7
+ 6,1.8981,2
8
+ 7,1.9048,2
9
+ 8,1.8649,2
10
+ 9,1.877,2
11
+ 10,1.8449,2
12
+ 11,1.8986,3
13
+ 12,1.8799,3
14
+ 13,1.8378,3
15
+ 14,1.9004,3
16
+ 15,1.8675,3
17
+ 16,1.8812,4
18
+ 17,1.8635,4
19
+ 18,1.8977,4
20
+ 19,1.8393,4
21
+ 20,1.9017,4
22
+ 21,1.8482,5
23
+ 22,1.8353,5
24
+ 23,1.8514,5
25
+ 24,1.9189,5
26
+ 25,1.8596,5
27
+ 26,1.8672,6
28
+ 27,1.8421,6
29
+ 28,1.848,6
30
+ 29,1.8762,6
31
+ 30,1.8964,6
32
+ 31,1.8663,7
33
+ 32,1.8491,7
34
+ 33,1.8637,7
35
+ 34,1.8403,7
36
+ 35,1.8842,7
37
+ 36,1.827,8
38
+ 37,1.8486,8
39
+ 38,1.8671,8
40
+ 39,1.8921,8
41
+ 40,1.7564,8
training_result/training_loss_antifact_qwen38.csv ADDED
@@ -0,0 +1,41 @@
1
+ Step,Loss,Epoch
2
+ 1,2.3492,1
3
+ 2,2.3786,1
4
+ 3,2.314,1
5
+ 4,2.3676,1
6
+ 5,2.2953,1
7
+ 6,2.3573,2
8
+ 7,2.3483,2
9
+ 8,2.3144,2
10
+ 9,2.321,2
11
+ 10,2.301,2
12
+ 11,2.326,3
13
+ 12,2.2982,3
14
+ 13,2.2573,3
15
+ 14,2.2941,3
16
+ 15,2.2627,3
17
+ 16,2.2579,4
18
+ 17,2.2519,4
19
+ 18,2.2641,4
20
+ 19,2.2128,4
21
+ 20,2.2421,4
22
+ 21,2.1929,5
23
+ 22,2.1757,5
24
+ 23,2.1914,5
25
+ 24,2.2761,5
26
+ 25,2.2079,5
27
+ 26,2.1893,6
28
+ 27,2.1754,6
29
+ 28,2.174,6
30
+ 29,2.2038,6
31
+ 30,2.2008,6
32
+ 31,2.1921,7
33
+ 32,2.1533,7
34
+ 33,2.1783,7
35
+ 34,2.1534,7
36
+ 35,2.2158,7
37
+ 36,2.1322,8
38
+ 37,2.1752,8
39
+ 38,2.18,8
40
+ 39,2.1953,8
41
+ 40,2.0122,8
training_result/training_loss_llama.csv ADDED
@@ -0,0 +1,41 @@
1
+ Step,Loss,Epoch
2
+ 1,1.858,1
3
+ 2,1.8676,1
4
+ 3,1.8243,1
5
+ 4,1.8815,1
6
+ 5,1.7409,1
7
+ 6,1.8664,2
8
+ 7,1.8736,2
9
+ 8,1.8304,2
10
+ 9,1.8358,2
11
+ 10,1.8119,2
12
+ 11,1.8637,3
13
+ 12,1.8456,3
14
+ 13,1.7986,3
15
+ 14,1.8643,3
16
+ 15,1.8602,3
17
+ 16,1.8451,4
18
+ 17,1.8329,4
19
+ 18,1.8589,4
20
+ 19,1.8064,4
21
+ 20,1.8571,4
22
+ 21,1.8139,5
23
+ 22,1.8059,5
24
+ 23,1.8097,5
25
+ 24,1.886,5
26
+ 25,1.8094,5
27
+ 26,1.8318,6
28
+ 27,1.8085,6
29
+ 28,1.8128,6
30
+ 29,1.842,6
31
+ 30,1.8477,6
32
+ 31,1.8348,7
33
+ 32,1.8133,7
34
+ 33,1.8263,7
35
+ 34,1.8028,7
36
+ 35,1.8589,7
37
+ 36,1.7959,8
38
+ 37,1.8054,8
39
+ 38,1.8296,8
40
+ 39,1.8629,8
41
+ 40,1.7241,8
training_result/training_loss_phi4.csv ADDED
@@ -0,0 +1,41 @@
1
+ Step,Loss,Epoch
2
+ 1,1.8667,1
3
+ 2,1.8693,1
4
+ 3,1.8411,1
5
+ 4,1.9009,1
6
+ 5,1.7747,1
7
+ 6,1.8776,2
8
+ 7,1.8859,2
9
+ 8,1.8414,2
10
+ 9,1.8464,2
11
+ 10,1.8524,2
12
+ 11,1.8784,3
13
+ 12,1.8608,3
14
+ 13,1.8113,3
15
+ 14,1.8754,3
16
+ 15,1.8512,3
17
+ 16,1.8578,4
18
+ 17,1.8542,4
19
+ 18,1.8738,4
20
+ 19,1.8203,4
21
+ 20,1.781,4
22
+ 21,1.8297,5
23
+ 22,1.811,5
24
+ 23,1.8162,5
25
+ 24,1.9074,5
26
+ 25,1.8363,5
27
+ 26,1.8388,6
28
+ 27,1.8351,6
29
+ 28,1.8299,6
30
+ 29,1.8478,6
31
+ 30,1.8644,6
32
+ 31,1.8573,7
33
+ 32,1.8156,7
34
+ 33,1.8426,7
35
+ 34,1.824,7
36
+ 35,1.8796,7
37
+ 36,1.8153,8
38
+ 37,1.8269,8
39
+ 38,1.8404,8
40
+ 39,1.8772,8
41
+ 40,1.7365,8
training_result/training_loss_phi4_antifact.csv ADDED
@@ -0,0 +1,41 @@
1
+ Step,Loss,Epoch
2
+ 1,1.8977,1
3
+ 2,1.9077,1
4
+ 3,1.8738,1
5
+ 4,1.9326,1
6
+ 5,1.8628,1
7
+ 6,1.9098,2
8
+ 7,1.9175,2
9
+ 8,1.876,2
10
+ 9,1.8882,2
11
+ 10,1.8819,2
12
+ 11,1.9128,3
13
+ 12,1.8947,3
14
+ 13,1.851,3
15
+ 14,1.9117,3
16
+ 15,1.8586,3
17
+ 16,1.8941,4
18
+ 17,1.8842,4
19
+ 18,1.9115,4
20
+ 19,1.8528,4
21
+ 20,1.8236,4
22
+ 21,1.8639,5
23
+ 22,1.8396,5
24
+ 23,1.8569,5
25
+ 24,1.9398,5
26
+ 25,1.8856,5
27
+ 26,1.8731,6
28
+ 27,1.8678,6
29
+ 28,1.8652,6
30
+ 29,1.8808,6
31
+ 30,1.914,6
32
+ 31,1.8884,7
33
+ 32,1.851,7
34
+ 33,1.879,7
35
+ 34,1.8604,7
36
+ 35,1.9046,7
37
+ 36,1.8455,8
38
+ 37,1.8689,8
39
+ 38,1.8771,8
40
+ 39,1.9062,8
41
+ 40,1.77,8
training_result/training_loss_qwen38.csv ADDED
@@ -0,0 +1,41 @@
1
+ Step,Loss,Epoch
2
+ 1,2.3167,1
3
+ 2,2.3344,1
4
+ 3,2.2796,1
5
+ 4,2.3293,1
6
+ 5,2.1972,1
7
+ 6,2.3196,2
8
+ 7,2.3094,2
9
+ 8,2.278,2
10
+ 9,2.2759,2
11
+ 10,2.272,2
12
+ 11,2.2851,3
13
+ 12,2.2595,3
14
+ 13,2.2142,3
15
+ 14,2.2586,3
16
+ 15,2.2535,3
17
+ 16,2.2193,4
18
+ 17,2.2181,4
19
+ 18,2.2252,4
20
+ 19,2.1788,4
21
+ 20,2.1908,4
22
+ 21,2.1574,5
23
+ 22,2.1469,5
24
+ 23,2.1484,5
25
+ 24,2.2405,5
26
+ 25,2.1602,5
27
+ 26,2.1534,6
28
+ 27,2.1435,6
29
+ 28,2.1369,6
30
+ 29,2.1685,6
31
+ 30,2.1502,6
32
+ 31,2.1597,7
33
+ 32,2.1169,7
34
+ 33,2.1408,7
35
+ 34,2.1159,7
36
+ 35,2.1897,7
37
+ 36,2.1013,8
38
+ 37,2.1304,8
39
+ 38,2.1441,8
40
+ 39,2.1645,8
41
+ 40,1.9816,8
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .metrics import *
2
+
utils/convert.py ADDED
@@ -0,0 +1,52 @@
1
+ import json
2
+ import argparse
3
+
4
+ def main():
5
+ parser = argparse.ArgumentParser(description='Convert and merge JSONL files with question-answer mappings')
6
+ parser.add_argument('orig_path', help='Path to the original JSONL file')
7
+ parser.add_argument('out_path', help='Path to the output JSONL file')
8
+ parser.add_argument('mapping_paths', nargs='+', help='Path(s) to mapping JSONL file(s)')
9
+
10
+ args = parser.parse_args()
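+ # Example invocation (illustrative paths; positional args are orig, out, then one or more mappings):
+ #   python utils/convert.py data/hotpotqa.jsonl data/hotpotqa_merged.jsonl rewrites.jsonl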
11
+
12
+ # Original data file paths from command line arguments
13
+ orig_path = args.orig_path
14
+ out_path = args.out_path
15
+ mapping_paths = args.mapping_paths
16
+
17
+ # Step 1: Build question -> {context, answers} mapping
18
+ mapping = {}
19
+ for mp in mapping_paths:
20
+ with open(mp, 'r', encoding='utf-8') as f_map:
21
+ for line in f_map:
22
+ obj = json.loads(line)
23
+ q = obj.get("question")
24
+ if q is None:
25
+ continue
26
+ # Ensure we get the context
27
+ ctx = obj.get("context", "")
28
+ # Some files have "answer" field, some have "answers"
29
+ raw_ans = obj.get("answers", obj.get("answer", []))
30
+ # Normalize answer(s) to list format
31
+ if isinstance(raw_ans, list):
32
+ ans = raw_ans
33
+ else:
34
+ ans = [raw_ans]
35
+ # If the same question appears in multiple mapping files, later ones will overwrite earlier ones
36
+ mapping[q] = {"context": ctx, "answers": ans}
37
+
38
+ # Step 2: Read original file, perform replacement and write output
39
+ with open(orig_path, 'r', encoding='utf-8') as f_in, \
40
+ open(out_path, 'w', encoding='utf-8') as f_out:
41
+ for line in f_in:
42
+ item = json.loads(line)
43
+ inp = item.get("input")
44
+ if inp in mapping:
45
+ item["context"] = mapping[inp]["context"]
46
+ item["answers"] = mapping[inp]["answers"]
47
+ f_out.write(json.dumps(item, ensure_ascii=False) + "\n")
48
+
49
+ print(f"Merge completed, output file: {out_path}")
50
+
51
+ if __name__ == "__main__":
52
+ main()
utils/draw.py ADDED
@@ -0,0 +1,82 @@
1
+ import pandas as pd
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib as mpl
4
+ import argparse
5
+
6
+
7
+ mpl.rcParams['font.family'] = 'serif'
8
+ mpl.rcParams['font.serif'] = ['Georgia']
9
+ mpl.rcParams['font.size'] = 20
10
+ mpl.rcParams['axes.titlesize']= 20
11
+ mpl.rcParams['axes.labelsize']= 18
12
+ mpl.rcParams['xtick.labelsize']=16
13
+ mpl.rcParams['ytick.labelsize']=16
14
+ mpl.rcParams['legend.fontsize'] = 16  # legend labels the two datasets
15
+
16
+ def plot_two_loss_curves(
17
+ csv_file1,
18
+ csv_file2,
19
+ title="Loss Comparison on Qwen3-8B",
20
+ dataset1_name="Dataset1",
21
+ dataset2_name="Dataset2"
22
+ ):
23
+ # Read CSV files
24
+ df1 = pd.read_csv(csv_file1)
25
+ df2 = pd.read_csv(csv_file2)
26
+
27
+ # Check columns
28
+ for df, path in ((df1, csv_file1), (df2, csv_file2)):
29
+ if 'Step' not in df.columns or 'Loss' not in df.columns:
30
+ raise ValueError(f"Missing 'Step' or 'Loss' columns in {path}")
31
+
32
+ # Create figure
33
+ plt.figure(figsize=(12, 8))
34
+
35
+ # Plot two lines with softer colors
36
+ plt.plot(df1['Step'], df1['Loss'], label=dataset1_name,
37
+ color='#1f77b4', linewidth=2.5) # steel blue
38
+ plt.plot(df2['Step'], df2['Loss'], label=dataset2_name,
39
+ color='#2ca02c', linewidth=2.5) # medium sea green
+ plt.legend()  # use the dataset name arguments (previously accepted but ignored)
40
+
41
+ # Title and labels
42
+ plt.title(title, fontweight='bold')
43
+ plt.xlabel('Steps', fontweight='bold')
44
+ plt.ylabel('Loss', fontweight='bold')
45
+
46
+ # Grid
47
+ plt.grid(True, linestyle='--', alpha=0.7)
48
+
49
+ # Layout
50
+ plt.tight_layout(pad=3.0)
51
+
52
+ # Save
53
+ plt.savefig('loss_comparison_qwen38b.svg', format='svg')
54
+ plt.savefig('loss_comparison.png', dpi=300)
55
+
56
+ # Display
57
+ plt.show()
58
+
59
+ print("Saved: loss_comparison.svg, loss_comparison.png")
60
+
61
+
62
+ def main():
63
+ parser = argparse.ArgumentParser(description='Plot comparison of two training loss curves')
64
+ parser.add_argument('csv_file1', help='Path to the first CSV file')
65
+ parser.add_argument('csv_file2', help='Path to the second CSV file')
66
+ parser.add_argument('--title', default='Training Loss Comparison', help='Title for the plot')
67
+ parser.add_argument('--dataset1-name', default='Original Dataset', help='Name for the first dataset')
68
+ parser.add_argument('--dataset2-name', default='Revised Dataset', help='Name for the second dataset')
69
+
70
+ args = parser.parse_args()
71
+
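+ # Example invocation with the CSVs shipped in training_result/:
+ #   python utils/draw.py training_result/training_loss_qwen38.csv \
+ #       training_result/training_loss_antifact_qwen38.csv --title "Loss Comparison on Qwen3-8B"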
72
+ plot_two_loss_curves(
73
+ args.csv_file1,
74
+ args.csv_file2,
75
+ title=args.title,
76
+ dataset1_name=args.dataset1_name,
77
+ dataset2_name=args.dataset2_name
78
+ )
79
+
80
+
81
+ if __name__ == "__main__":
82
+ main()
utils/llmjudge.py ADDED
@@ -0,0 +1,49 @@
1
+
2
+
3
+ import logging
4
+ import os
5
+
6
+ def judge_answer_with_api(question, target, answer):
7
+ from openai import OpenAI
8
+
9
+ client = OpenAI(
10
+ base_url=os.environ.get("OPENAI_BASE_URL"),
11
+ api_key=os.environ.get("OPENAI_API_KEY")
12
+ )
13
+
14
+ prompt = (
15
+ "You will be given a question, a target answer (maybe a list of all possible answers), "
16
+ "and a generated answer. Please judge whether the generated answer is correct. "
17
+ "If it is correct, return 'True'. If it is incorrect, return 'False'.\n"
18
+ f"Question: {question}\n"
19
+ f"Target Answer: {target}\n"
20
+ f"Generated Answer: {answer}\n"
21
+ "Please return only 'True' or 'False', without any other text."
22
+ )
23
+
24
+ try:
25
+ response = client.chat.completions.create(
26
+ model="gpt-4o-mini-2024-07-18",
27
+ messages=[{"role": "user", "content": prompt}],
28
+ temperature=0,
29
+ max_tokens=1
30
+ )
31
+ except Exception as e:
32
+ logging.error("API call failed: %s", str(e))
33
+ return 0 # Return default value
34
+
35
+ try:
36
+ result = response.choices[0].message.content.strip()
37
+ except (AttributeError, IndexError) as e:
38
+ logging.error("Error parsing API response: %s", str(e))
39
+ return 0
40
+
41
+ if result == "True":
42
+ return 1
43
+ elif result == "False":
44
+ return 0
45
+ else:
46
+ logging.warning("Abnormal response format: %s", result)
47
+ return 0
48
+
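+ # Illustrative usage: judge_answer_with_api("Who wrote Hamlet?", "William Shakespeare",
+ # "Shakespeare") should return 1; any reply other than the literal "True" maps to 0.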
49
+
utils/metrics.py ADDED
@@ -0,0 +1,152 @@
1
+ import re
2
+ import string
3
+
4
+ import jieba
5
+ from fuzzywuzzy import fuzz
6
+ import difflib
7
+
8
+ from typing import List
9
+ from collections import Counter
10
+ from rouge import Rouge
11
+
12
+ def normalize_answer(s):
13
+ """Lower text and remove punctuation, articles and extra whitespace."""
14
+
15
+ def remove_articles(text):
16
+ return re.sub(r"\b(a|an|the)\b", " ", text)
17
+
18
+ def white_space_fix(text):
19
+ return " ".join(text.split())
20
+
21
+ def remove_punc(text):
22
+ exclude = set(string.punctuation)
23
+ return "".join(ch for ch in text if ch not in exclude)
24
+
25
+ def lower(text):
26
+ return text.lower()
27
+
28
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
29
+
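+ # e.g. normalize_answer("The  Eiffel Tower!") -> "eiffel tower"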
30
+
31
+ def normalize_zh_answer(s):
32
+ """Lower text and remove punctuation, extra whitespace."""
33
+
34
+ def white_space_fix(text):
35
+ return "".join(text.split())
36
+
37
+ def remove_punc(text):
38
+ cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
39
+ all_punctuation = set(string.punctuation + cn_punctuation)
40
+ return "".join(ch for ch in text if ch not in all_punctuation)
41
+
42
+ def lower(text):
43
+ return text.lower()
44
+
45
+ return white_space_fix(remove_punc(lower(s)))
46
+
47
+ def count_score(prediction, ground_truth, **kwargs):
48
+ numbers = re.findall(r"\d+", prediction)
49
+ right_num = 0
50
+ for number in numbers:
51
+ if str(number) == str(ground_truth):
52
+ right_num += 1
53
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
54
+ return float(final_score)
55
+
56
+ def retrieval_score(prediction, ground_truth, **kwargs):
57
+ pattern = r'Paragraph (\d+)'
58
+ matches = re.findall(pattern, ground_truth)
59
+ ground_truth_id = matches[0]
60
+ numbers = re.findall(r"\d+", prediction)
61
+ right_num = 0
62
+ for number in numbers:
63
+ if str(number) == str(ground_truth_id):
64
+ right_num += 1
65
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
66
+ return float(final_score)
67
+
68
+ def retrieval_zh_score(prediction, ground_truth, **kwargs):
69
+ pattern = r'段落(\d+)'
70
+ matches = re.findall(pattern, ground_truth)
71
+ ground_truth_id = matches[0]
72
+ numbers = re.findall(r"\d+", prediction)
73
+ right_num = 0
74
+ for number in numbers:
75
+ if str(number) == str(ground_truth_id):
76
+ right_num += 1
77
+ final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
78
+ return float(final_score)
79
+
80
+ def code_sim_score(prediction, ground_truth, **kwargs):
81
+ all_lines = prediction.lstrip('\n').split('\n')
82
+ prediction = ""
83
+ for line in all_lines:
84
+ if ('`' not in line) and ('#' not in line) and ('//' not in line):
85
+ prediction = line
86
+ break
87
+ return (fuzz.ratio(prediction, ground_truth) / 100)
88
+
89
+ def classification_score(prediction, ground_truth, **kwargs):
90
+ em_match_list = []
91
+ all_classes = kwargs["all_classes"]
92
+ for class_name in all_classes:
93
+ if class_name in prediction:
94
+ em_match_list.append(class_name)
95
+ for match_term in em_match_list:
96
+ if match_term in ground_truth and match_term != ground_truth:
97
+ em_match_list.remove(match_term)
98
+ if ground_truth in em_match_list:
99
+ score = (1.0 / len(em_match_list))
100
+ else:
101
+ score = 0.0
102
+ return score
103
+
104
+ def rouge_score(prediction, ground_truth, **kwargs):
105
+ rouge = Rouge()
106
+ try:
107
+ scores = rouge.get_scores([prediction], [ground_truth], avg=True)
108
+ except:
109
+ return 0.0
110
+ return scores["rouge-l"]["f"]
111
+
112
+ def rouge_zh_score(prediction, ground_truth, **kwargs):
113
+ prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
114
+ ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
115
+ score = rouge_score(prediction, ground_truth)
116
+ return score
117
+
118
+ def f1_score(prediction, ground_truth, **kwargs):
119
+ common = Counter(prediction) & Counter(ground_truth)
120
+ num_same = sum(common.values())
121
+ if num_same == 0:
122
+ return 0
123
+ precision = 1.0 * num_same / len(prediction)
124
+ recall = 1.0 * num_same / len(ground_truth)
125
+ f1 = (2 * precision * recall) / (precision + recall)
126
+ return f1
127
+
128
+ def qa_f1_score(prediction, ground_truth, **kwargs):
129
+ normalized_prediction = normalize_answer(prediction)
130
+ normalized_ground_truth = normalize_answer(ground_truth)
131
+
132
+ prediction_tokens = normalized_prediction.split()
133
+ ground_truth_tokens = normalized_ground_truth.split()
134
+ return f1_score(prediction_tokens, ground_truth_tokens)
135
+
136
+
137
+ def qa_f1_zh_score(prediction, ground_truth, **kwargs):
138
+ prediction_tokens = list(jieba.cut(prediction, cut_all=False))
139
+ ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
140
+ prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
141
+ ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
142
+ prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
143
+ ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
144
+ return f1_score(prediction_tokens, ground_truth_tokens)
145
+
146
+ def qa_em_score(prediction, ground_truth, **kwargs):
147
+ normalized_prediction = normalize_answer(prediction)
148
+ normalized_ground_truth = normalize_answer(ground_truth)
149
+ return 1 if (normalized_prediction in normalized_ground_truth or normalized_ground_truth in normalized_prediction) else 0
150
+
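+ # Note: this "EM" is a lenient containment check rather than strict exact match,
+ # e.g. qa_em_score("Paris, France", "Paris") == 1 because one normalized string
+ # contains the other.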
151
+
152
+
utils/util.py ADDED
@@ -0,0 +1,104 @@
1
+ import re
2
+ import nltk
3
+ nltk.download('punkt_tab')
4
+ from nltk.tokenize import sent_tokenize, word_tokenize
5
+ from rank_bm25 import BM25Okapi
6
+ from langchain_text_splitters import NLTKTextSplitter
7
+ from langchain_community.vectorstores import FAISS
8
+ from langchain_openai import OpenAIEmbeddings
9
+ from collections import Counter
10
+
11
+ def replace_case_insensitive(text: str, old: str, new: str) -> str:
12
+ pattern = re.compile(re.escape(old), re.IGNORECASE)
13
+
14
+ return pattern.sub(new, text)
15
+ def get_word_list(s1):
16
+ # Separate sentences by word, Chinese by word, English by word, numbers by space
17
+ regEx = re.compile(r'[\W]')
18
+ res = re.compile(r"([\u4e00-\u9fa5])") # [\u4e00-\u9fa5] for Chinese
19
+
20
+ p1 = regEx.split(s1.lower())
21
+ str1_list = []
22
+ for tok in p1:
23
+ if res.split(tok) is None:
24
+ str1_list.append(tok)
25
+ else:
26
+ ret = res.split(tok)
27
+ for ch in ret:
28
+ str1_list.append(ch)
29
+
30
+ list_word1 = [w for w in str1_list if len(w.strip()) > 0]
31
+
32
+ return list_word1
33
+ def get_word_len(s1):
34
+ return len(get_word_list(s1))
35
+
36
+ regex = r'([。?!;\n.!?;]\s*)'
37
+ def retriveDoc(text,query,top_k=3):
38
+ import os
39
+ sentences = sent_tokenize(text)
40
+ embeddings = OpenAIEmbeddings(model="text-embedding-3-small", base_url=os.environ.get("OPENAI_BASE_URL"),
41
+ api_key=os.environ.get("OPENAI_API_KEY"))
42
+ # Create vector database through FAISS (built from sentence list)
43
+ vector_store = FAISS.from_texts(sentences, embeddings)
44
+
45
+ retrieved_docs = vector_store.similarity_search(query, k=top_k)
46
+ print("Retrieved sentences:", retrieved_docs)
47
+
48
+ # Return results, can adjust the return structure as needed, here returns a dictionary containing context
49
+ return retrieved_docs
50
+
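+ # Note: a fresh FAISS index is built on every call, so each retriveDoc invocation
+ # re-embeds all sentences of the passage through the OpenAI embedding API.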
51
+
52
+ def most_similar_sentence_bm25(paragraph, target_sentence):
53
+ """
54
+ Use BM25 algorithm to find the most similar sentence to target_sentence in the given paragraph,
55
+ return the most similar sentence.
56
+ """
57
+ # 1. First split the paragraph into a list of sentences
58
+ sentences = sent_tokenize(paragraph)
59
+
60
+ # 2. Tokenize each sentence
61
+ tokenized_sentences = [word_tokenize(sent) for sent in sentences]
62
+
63
+ # 3. Create a retrieval instance using BM25Okapi
64
+ bm25 = BM25Okapi(tokenized_sentences)
65
+
66
+ # 4. Tokenize the target sentence
67
+ target_tokens = word_tokenize(target_sentence)
68
+
69
+ # 5. Use BM25 to calculate similarity scores for each sentence
70
+ scores = bm25.get_scores(target_tokens)
71
+ # scores.shape == (len(sentences),)
72
+
73
+ # 6. Find the index of the sentence with the highest score
74
+ max_idx = scores.argmax()
75
+
76
+ # Return the most similar sentence and its score
77
+ return sentences[max_idx]
78
+
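+ # Illustrative usage:
+ # most_similar_sentence_bm25("Paris is in France. Berlin is in Germany.",
+ #                            "Which country is Berlin in?") should return "Berlin is in Germany."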
79
+
80
+ def f1_score_text(pred, gold):
81
+ pred_tokens = word_tokenize(pred)
82
+ gold_tokens = word_tokenize(gold)
83
+ common = Counter(pred_tokens) & Counter(gold_tokens)
84
+ num_same = sum(common.values())
85
+ if num_same == 0:
86
+ return 0.0
87
+ precision = num_same / len(pred_tokens)
88
+ recall = num_same / len(gold_tokens)
89
+ f1 = 2 * precision * recall / (precision + recall)
90
+ return f1
91
+
92
+ def compute_best_sentence_f1(pred_text, gold_text):
93
+ pred_sentences = sent_tokenize(pred_text)
94
+ gold_sentences = sent_tokenize(gold_text)
95
+ f1_scores = []
96
+ for pred in pred_sentences:
97
+ best_f1 = 0.0
98
+ for gold in gold_sentences:
99
+ f1 = f1_score_text(pred, gold)
100
+ if f1 > best_f1:
101
+ best_f1 = f1
102
+ f1_scores.append(best_f1)
103
+ avg_f1 = sum(f1_scores) / len(pred_sentences) if pred_sentences else 0.0
104
+ return avg_f1
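+
+ # Illustrative example: compute_best_sentence_f1("Paris is big.", "Paris is big. It rains.")
+ # matches each predicted sentence to its best-scoring gold sentence, here yielding 1.0.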