# (extraction residue removed: file-viewer metadata — size 16,892 bytes, commit b1e25b1, line-number gutter)
import json
import time
import re
import os
import argparse
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
from utils.util import retriveDoc,compute_best_sentence_f1
from openai import OpenAI
import asyncio, json, torch, math
from typing import List, Tuple
# Hugging Face transformers related
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from utils.metrics import qa_f1_score
from utils.llmjudge import judge_answer_with_api

# OpenAI-compatible API client; endpoint and key are read from the environment
# (OPENAI_BASE_URL / OPENAI_API_KEY), so a missing env var surfaces at call time.
client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY")
)

# Load models using transformers
# Qwen-14B-Chat on GPU 0 (bfloat16): used by main() to answer from retrieved evidence.
tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
model1 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True,device_map="cuda:0",torch_dtype=torch.bfloat16)
# Qwen2.5-7B-Instruct on GPU 1, eval mode.
# NOTE(review): tok_qwen/model_qwen are loaded here but never referenced in this
# file — presumably used by code outside this view or leftover; confirm before removing.
tok_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
model_qwen = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True,
    device_map="cuda:1",torch_dtype=torch.bfloat16
).eval()
def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
    """Generate an answer with a local transformers model, retrying on failure.

    The prompt is formatted through the tokenizer's chat template when one is
    available (falling back to the raw prompt otherwise), and the echoed prompt
    tokens are sliced off so only newly generated text is returned.

    Args:
        prompt: User instruction/question to send to the model.
        tokenizer: Hugging Face tokenizer matching *model*.
        model: Causal LM loaded via AutoModelForCausalLM.
        max_new_tokens: Cap on newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling threshold.
        retries: Attempts before giving up.
        delay: Seconds to sleep between attempts.

    Returns:
        The generated text (str), or None if all attempts fail.
    """
    # Fix: removed the redundant function-local `import time` — the module
    # already imports it at the top of the file.
    for attempt in range(retries):
        try:
            # Wrap the raw prompt as a single-turn chat message.
            messages = [{"role": "user", "content": prompt}]
            # Try to use chat template to format input
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception as e:
                print(f"Unable to apply chat template: {e}, falling back to basic text input")
                formatted_prompt = prompt  # Fall back to original prompt as input
            # Encode formatted prompt as model input tensor
            model_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
            # Fix: do_sample=True is required for temperature/top_p to take
            # effect — without it, HF generate() defaults to greedy decoding
            # and silently ignores both sampling parameters.
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p
            )
            # generate() echoes the prompt tokens; slice them off so only the
            # newly produced continuation is decoded.
            input_length = model_inputs.input_ids.shape[1]
            output_ids = generated_ids[0][input_length:]
            answer = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
            return answer
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")
                return None
def truncate_answer(answer):
    """Return the text before the first period; "No answer" for empty/None input."""
    if not answer:
        return "No answer"
    head, _sep, _rest = answer.partition('.')
    return head.strip()
def write_to_log(filename, data):
    """Append *data* followed by a newline to *filename* (UTF-8, append mode)."""
    with open(filename, 'a', encoding='utf-8') as log_file:
        log_file.write(data + '\n')
def remove_think_tags(text: str) -> str:
    """Strip every <think>...</think> block (tags and contents) and trim whitespace."""
    think_block = re.compile(r'<think>(.*?)</think>', re.DOTALL)
    return think_block.sub('', text).strip()
def build_prompt(context: str, question: str) -> str:
    """Build the QA prompt: answer first, then step-by-step reasoning where each
    step is followed by a verbatim reference sentence from the passages."""
    instructions = (
        "Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) "
        "At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. "
        "Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence. "
    )
    answer_template = (
        "Answer: [Your answer]\n"
        "Step-by-step Reasoning:\n"
        "1. [Reasoning step 1]\n"
        "[replaced by your reference content]\n"
        "2. [Reasoning step 2]\n"
        "[replaced by your reference content]\n"
    )
    return (
        "Answer the question based on the given passages. The following are the passages:\n"
        f"{context}\n"
        "Answer the question based on the given passages.\n"
        f"Question: {question}.\n"
        "Answer:\n"
        + instructions
        + "Follow this format:\n"
        + answer_template
    )
def extract_final_bullet_passage(answer_text: str):
    """Return (passage_number, snippet) for the last "passage N: ..." citation
    inside the "Step-by-step Reasoning:" section, or (None, None) when the
    section, its bullets, or any passage citation is missing.
    """
    # Everything after the "Step-by-step Reasoning:" marker.
    reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
    reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
    if not reasoning_match:
        return None, None
    reasoning_text = reasoning_match.group(1).strip()
    # One match per numbered bullet: from "<n>. " at line start up to (but not
    # including) the next bullet or end of text.
    bullet_pattern = r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)"
    bullets = re.findall(bullet_pattern, reasoning_text, flags=re.DOTALL)
    if not bullets:
        print("No bullet blocks found.")
        return None, None
    # "passage <N>: <snippet>", case-insensitive, optionally bolded with **,
    # where the snippet is either a quoted string (group 3) or free text up to
    # a blank line / end (group 4).
    passage_pattern = re.compile(
        r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
        flags=re.DOTALL
    )
    # Scan bullets from the last to the first so the final citation wins.
    for bullet in reversed(bullets):
        matches = passage_pattern.findall(bullet)
        if matches:
            last_match = matches[-1]
            passage_number = last_match[0]
            quoted_snippet = last_match[2]
            non_quoted_snippet = last_match[3]
            # Prefer the unquoted capture; fall back to the quoted one.
            snippet = non_quoted_snippet.strip() if non_quoted_snippet.strip() else quoted_snippet.strip()
            return passage_number, snippet
    return None, None
def extract_all_bullet_passages(answer_text: str):
    """Split the "Step-by-step Reasoning:" section into numbered bullet blocks.

    Each bullet block runs from "<n>. " at the start of a line up to the next
    bullet or end of text (so a bullet keeps its reference lines).

    Args:
        answer_text: Full model response containing the reasoning section.

    Returns:
        A list of {'bullet_index': int, 'snippet': str} dicts (index counts
        from 1 in order of appearance), or [] when the section or its bullets
        are absent.
    """
    reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
    reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
    if not reasoning_match:
        return []
    reasoning_text = reasoning_match.group(1).strip()
    bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
    bullets = bullet_pattern.findall(reasoning_text)
    if not bullets:
        return []
    results = []
    for bullet_index, bullet_text in enumerate(bullets, start=1):
        results.append({
            'bullet_index': bullet_index,
            'snippet': bullet_text.strip()
        })
    # Fix: removed leftover debug `print(results)` that dumped the full bullet
    # list to stdout on every call.
    return results
def extract_evidence(answer_text: str):
    """Parse the numbered list that follows an "Evidence" marker.

    Bullets before the one literally numbered "1." are discarded; the kept
    bullets are re-numbered from 1. Returns a list of
    {'bullet_index': int, 'snippet': str} dicts, or [] when no marker, no
    bullets, or no bullet starting with "1." exists.
    """
    marker = re.search(r"(?i)Evidence\s*(.*)", answer_text, flags=re.DOTALL)
    if marker is None:
        return []
    body = marker.group(1).strip()
    # A bullet runs from "<n>. " at line start up to the next bullet or EOF.
    bullets = re.findall(
        r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", body, flags=re.MULTILINE | re.DOTALL
    )
    # Locate the first bullet explicitly labelled "1." — everything before it
    # is noise from the surrounding text.
    first = next(
        (i for i, b in enumerate(bullets) if b.strip().startswith("1.")), None
    )
    if first is None:
        return []
    return [
        {'bullet_index': position, 'snippet': bullet.strip()}
        for position, bullet in enumerate(bullets[first:], start=1)
    ]
def get_answer_with_retry(model, prompt, retries=3, delay=5):
    """Request a chat completion from the API client, retrying on failure.

    Returns the stripped response text, or None once *retries* attempts have
    all failed.
    """
    attempt = 0
    while attempt < retries:
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}]
            )
            return response.choices[0].message.content.strip()
        except Exception as exc:
            print(f"Error on attempt {attempt + 1}: {exc}")
            attempt += 1
            if attempt < retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")
                return None
def extract_json_from_gpt_response(text: str) -> dict | None:
"""
Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict.
"""
# Try to find a ```json … ``` block first
m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
if not m:
# Fallback: any ``` … ``` block that looks like JSON
m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
if not m:
# Lastly, maybe the model just spit raw JSON without fences
m = re.search(r"(\{.*?\})", text, flags=re.DOTALL)
if not m:
return None
json_str = m.group(1)
try:
return json.loads(json_str)
except json.JSONDecodeError:
# clean up trailing commas, etc.
cleaned = re.sub(r",\s*([\]}])", r"\1", json_str)
try:
return json.loads(cleaned)
except json.JSONDecodeError:
return None
async def random_alternative_answer(
    question: str,
    original_context: str,
    unique_sents: List[str],
    correct_answer: str
) -> dict:
    """Generate a random alternative answer plus evidence rewritten to support it.

    Asks GPT-4o for an answer different from *correct_answer* and for rewritten
    versions of *unique_sents* that back that answer, then splices each
    rewritten sentence into *original_context*.

    Args:
        question: The question being answered.
        original_context: Full passage text the evidence comes from.
        unique_sents: Deduplicated evidence sentences to rewrite.
        correct_answer: The ground-truth answer to diverge from.

    Returns:
        {"context": <modified context>, "answer": <alternative answer>};
        on any parse failure, the original context with a sentinel answer.

    NOTE(review): despite the async signature, this uses the blocking OpenAI
    client (no await). Callers invoke it via asyncio.run, so nothing breaks,
    but an AsyncOpenAI client would be needed for real concurrency — confirm.
    """
    # Number the sentences so the model can rewrite them one-by-one.
    numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
    prompt = (
        "You are a creative assistant. Given the question below and the original answer, propose a plausible alternative answer that is **different** from the original but still reasonable. "
        "Then rewrite the provided sentences to support your alternative answer. When rewriting each sentence, modify only the parts necessary to support the alternative answer. "
        "Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. "
        "Output must be strictly in the specified JSON format, with no additional text.\n"
        '{\n'
        '  "answer": "<your alternative answer here, just provide the answer phrase, no need for complete sentence>",\n'
        '  "revised": [\n'
        '    "<rewritten sentence 1>",\n'
        '    "<rewritten sentence 2>",\n'
        '    ...\n'
        '  ]\n'
        '}\n\n'
        f"Question:\n{question}\n\n"
        f"Original answer:\n{correct_answer}\n\n"
        f"Sentences to rewrite:\n{numbered}"
    )
    print(f"[Alternative Answer] Generating prompt: {prompt}")
    rsp = client.chat.completions.create(
        model="gpt-4o", temperature=0.7,
        messages=[{"role":"user","content":prompt}]
    )
    js = extract_json_from_gpt_response(rsp.choices[0].message.content)
    # Fix: also treat JSON that parses but lacks the required keys as a
    # failure — previously js["revised"]/js["answer"] raised KeyError and
    # bypassed the intended fallback path.
    if not js or "revised" not in js or "answer" not in js:
        print("[Alternative Answer] Failed to parse JSON")
        return {"context": original_context, "answer": "Failed to generate alternative"}
    revised = js["revised"]      # List[str] of rewritten evidence sentences
    alternative = js["answer"]   # Alternative answer phrase
    # Splice each rewritten sentence into the context in place of the original.
    # zip() stops at the shorter list, so a short "revised" list degrades
    # gracefully instead of raising.
    new_ctx = original_context
    for old, new in zip(unique_sents, revised):
        new_ctx = new_ctx.replace(old, new)
    return {"context": new_ctx, "answer": alternative}
def main():
    """Run the LastingBench alternative-answer pipeline over a LongBench split.

    For each example: ask the API model to answer with cited step-by-step
    reasoning, keep only examples the LLM judge marks correct, retrieve the
    cited evidence from the context, sanity-check a local model on that
    evidence, then ask GPT-4o for a plausible *different* answer with rewritten
    evidence and append the (question, answer, context) record to a JSONL file.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="LastingBench random alternative answer generation")
    parser.add_argument("--output", "-o", type=str, default="output_random.jsonl",
                        help="Output JSONL file path (default: output_random.jsonl)")
    parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
                        help="Dataset repository name (default: THUDM/LongBench)")
    parser.add_argument("--dataset_subset", type=str, default="hotpotqa",
                        help="Dataset subset name (default: hotpotqa)")
    parser.add_argument("--split", type=str, default="test",
                        help="Dataset split (default: test)")
    parser.add_argument("--start_idx", type=int, default=0,
                        help="Starting index for processing (default: 0)")
    parser.add_argument("--max_samples", type=int, default=-1,
                        help="Maximum number of samples to process (-1 for all, default: -1)")
    args = parser.parse_args()
    out_file = args.output
    # Load dataset
    longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]
    print(f"Output file: {out_file}")
    print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
    print(f"Total samples: {len(longbench)}")
    count = 0
    # Determine processing range
    start_idx = args.start_idx
    end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))
    print(f"Processing samples from index {start_idx} to {end_idx-1}")
    for idx in range(start_idx, end_idx):
        example = longbench[idx]
        question = example['input']
        print(f"Question: {question}")
        context = example['context']
        correct_answer = example['answers'][0]
        print(f"Processing example {idx + 1}:")
        print(f"Correct Answer: {correct_answer}")
        # Build prompts
        prompt_with_context = build_prompt(context, question)
        answer_with_context = get_answer_with_retry('deepseek-r1', prompt_with_context)
        # BUG FIX: get_answer_with_retry returns None once its retries are
        # exhausted; the .split() chain below would then raise AttributeError.
        if answer_with_context is None:
            print("No answer returned, skipping this example.")
            continue
        # Extract content after "Answer:" from answer_with_context
        answer_with_context_simple = (
            answer_with_context
            .split("Answer:", 1)[-1]  # First keep the part after Answer:
            .split("Step-by-step Reasoning", 1)[0]  # Then cut before Step-by-step Reasoning
            .strip()
        )
        print(f"Answer with context: {answer_with_context_simple}")
        result = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
        print(f"Answer judge result: {result}")
        if not result:
            continue
        answer_with_context = remove_think_tags(answer_with_context or "")
        evidence = extract_all_bullet_passages(answer_with_context)
        page_contents = []
        if evidence:
            count += 1
            for ev in evidence:
                snippet = ev['snippet']
                # Renamed from `result` to avoid shadowing the judge verdict above.
                retrieved = retriveDoc(context, snippet)
                page_contents += [doc.page_content for doc in retrieved]
            # Deduplicate while preserving first-seen order.
            unique_page_contents = list(dict.fromkeys(page_contents))
            aggregated_content = "\n".join(unique_page_contents)
            prompt_final = (
                f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
                f"Please only provide your answer. "
                f"Your Answer:"
            )
            # Sanity check: can the local model answer from the evidence alone?
            final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)
            if judge_answer_with_api(question, correct_answer, final_answer):
                print("correct")
            else:
                print("incorrect")
        # Also retrieve passages for the question itself.
        result_query = retriveDoc(context, question)
        page_contents += [doc.page_content for doc in result_query]
        unique_page_contents = list(dict.fromkeys(page_contents))
        # Generate random alternative answer instead of selecting the highest ppl answer
        alternative = asyncio.run(
            random_alternative_answer(
                question,
                context,
                unique_page_contents,
                correct_answer
            )
        )
        record = {
            "question": question,
            "answer": alternative["answer"],
            "context": alternative["context"]
        }
        # Append one line of JSON each loop
        with open(out_file, "a", encoding="utf-8") as fout:
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
# Script entry point: run the generation pipeline only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()