"""LastingBench: generate random alternative answers for LongBench QA samples.

Per-sample pipeline in main():
1. Ask an API model ('deepseek-r1') to answer the question with step-by-step
   reasoning that quotes reference sentences from the context.
2. Verify the answer with an LLM judge, retrieve the quoted passages from the
   context, and sanity-check them by answering again with a local HF model.
3. Ask GPT-4o to invent a plausible *alternative* answer and rewrite the
   evidence sentences to support it, then append the modified context plus
   alternative answer to a JSONL output file.
"""

import argparse
import asyncio
import json
import math
import os
import re
import time
from typing import List, Tuple

import torch
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from utils.util import retriveDoc, compute_best_sentence_f1
from utils.metrics import qa_f1_score
from utils.llmjudge import judge_answer_with_api

# OpenAI-compatible API client; endpoint and key come from the environment.
client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY")
)

# Local Hugging Face models (loaded eagerly at import time; requires 2 GPUs).
tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
model1 = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-14B-Chat",
    trust_remote_code=True,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)
tok_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
model_qwen = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    trust_remote_code=True,
    device_map="cuda:1",
    torch_dtype=torch.bfloat16,
).eval()


def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100,
                            temperature=0.7, top_p=0.9, retries=3, delay=5):
    """Generate an answer with a local transformers model, with retries.

    Formats *prompt* via the tokenizer's chat template (falling back to the
    raw prompt if the template fails), calls model.generate, and strips the
    prompt tokens via token-level slicing so only newly generated text is
    returned. Retries up to *retries* times, sleeping *delay* seconds between
    attempts; returns None when all attempts fail.
    """
    for attempt in range(retries):
        try:
            # Wrap the raw prompt in chat-message format.
            messages = [{"role": "user", "content": prompt}]
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
            except Exception as e:
                print(f"Unable to apply chat template: {e}, falling back to basic text input")
                formatted_prompt = prompt  # Fall back to original prompt as input
            model_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
            # generate() returns prompt ids followed by the generated ids.
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p
            )
            # Slice off the prompt tokens; keep only what was generated.
            input_length = model_inputs.input_ids.shape[1]
            output_ids = generated_ids[0][input_length:]
            answer = tokenizer.decode(output_ids, skip_special_tokens=True).strip()
            return answer
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")
                return None


def truncate_answer(answer):
    """Truncate answer, only take the part before the first period."""
    return answer.split('.')[0].strip() if answer else "No answer"


def write_to_log(filename, data):
    """Append *data* as one line to the log file *filename*."""
    with open(filename, 'a', encoding='utf-8') as file:
        file.write(data + '\n')


def remove_think_tags(text: str) -> str:
    """Remove all <think>...</think> blocks emitted by reasoning models."""
    # BUG FIX: the pattern had degraded to r'(.*?)' (a no-op that matches the
    # empty string); restored the tag delimiters implied by the function name.
    return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()


def build_prompt(context: str, question: str) -> str:
    """Build the QA prompt requesting an answer plus step-by-step reasoning
    whose steps each quote verbatim reference sentences from the passages."""
    prompt = (
        f"Answer the question based on the given passages. The following are the passages:\n"
        f"{context}\n"
        f"Answer the question based on the given passages.\n"
        f"Question: {question}.\n"
        f"Answer:\n"
        f"Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) "
        f"At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. "
        f"Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation. Do not use ellipses inside the sentence. "
        f"Follow this format:\n"
        f"Answer: [Your answer]\n"
        f"Step-by-step Reasoning:\n"
        f"1. [Reasoning step 1]\n"
        f"[replaced by your reference content]\n"
        f"2. [Reasoning step 2]\n"
        f"[replaced by your reference content]\n"
    )
    return prompt


def extract_final_bullet_passage(answer_text: str):
    """Return (passage_number, snippet) from the last reasoning bullet that
    cites a "Passage N: ..." reference, or (None, None) when absent."""
    reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
    reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
    if not reasoning_match:
        return None, None
    reasoning_text = reasoning_match.group(1).strip()
    # Each bullet runs from "N. " up to the next "N. " or end of text.
    bullet_pattern = r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)"
    bullets = re.findall(bullet_pattern, reasoning_text, flags=re.DOTALL)
    if not bullets:
        print("No bullet blocks found.")
        return None, None
    # Matches 'passage 3: "quoted"' or 'passage 3: unquoted text', optionally
    # bolded with **, case-insensitively.
    passage_pattern = re.compile(
        r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
        flags=re.DOTALL
    )
    # Walk bullets from last to first so the final citation wins.
    for bullet in reversed(bullets):
        matches = passage_pattern.findall(bullet)
        if matches:
            last_match = matches[-1]
            passage_number = last_match[0]
            quoted_snippet = last_match[2]
            non_quoted_snippet = last_match[3]
            snippet = non_quoted_snippet.strip() if non_quoted_snippet.strip() else quoted_snippet.strip()
            return passage_number, snippet
    return None, None


def extract_all_bullet_passages(answer_text: str):
    """Return every numbered bullet under "Step-by-step Reasoning:" as
    [{'bullet_index': i, 'snippet': text}, ...]; [] when nothing matches."""
    reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
    reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
    if not reasoning_match:
        return []
    reasoning_text = reasoning_match.group(1).strip()
    bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
    bullets = bullet_pattern.findall(reasoning_text)
    if not bullets:
        return []
    results = []
    for bullet_index, bullet_text in enumerate(bullets, start=1):
        results.append({
            'bullet_index': bullet_index,
            'snippet': bullet_text.strip()
        })
    print(results)  # debug output, kept from original
    return results


def extract_evidence(answer_text: str):
    """Return numbered bullets following an "Evidence" header, starting from
    the first bullet labelled "1.", as [{'bullet_index', 'snippet'}, ...]."""
    reasoning_pattern = r"(?i)Evidence\s*(.*)"
    reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
    if not reasoning_match:
        return []
    reasoning_text = reasoning_match.group(1).strip()
    # Extract all bullet segments
    bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
    bullets = bullet_pattern.findall(reasoning_text)
    if not bullets:
        return []
    # Find the index of the first bullet starting with 1.
    start_index = -1
    for i, bullet in enumerate(bullets):
        if bullet.strip().startswith("1."):
            start_index = i
            break
    if start_index == -1:
        return []  # No valid starting bullet
    # Only keep the part starting from the first valid bullet
    bullets = bullets[start_index:]
    results = []
    for bullet_index, bullet_text in enumerate(bullets, start=1):
        results.append({
            'bullet_index': bullet_index,
            'snippet': bullet_text.strip()
        })
    return results


def get_answer_with_retry(model, prompt, retries=3, delay=5):
    """Call the chat-completions API for *model* with *prompt*; retry up to
    *retries* times on failure and return None when all attempts fail."""
    for attempt in range(retries):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}]
            )
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")
                return None


def extract_json_from_gpt_response(text: str) -> dict | None:
    """Find the first JSON object in *text* — inside ```json ... ``` fences,
    plain ``` ... ``` fences, or bare — and return it as a dict (None on
    failure). Trailing commas are stripped as a last-resort repair.
    """
    # Try to find a ```json … ``` block first
    m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
    if not m:
        # Fallback: any ``` … ``` block that looks like JSON
        m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
    if not m:
        # Lastly, maybe the model just spit raw JSON without fences
        m = re.search(r"(\{.*?\})", text, flags=re.DOTALL)
    if not m:
        return None
    json_str = m.group(1)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # clean up trailing commas, etc.
        cleaned = re.sub(r",\s*([\]}])", r"\1", json_str)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return None


async def random_alternative_answer(
    question: str,
    original_context: str,
    unique_sents: List[str],
    correct_answer: str
) -> dict:
    """Generate a random alternative answer and modified evidence.

    Asks GPT-4o for a plausible answer *different* from *correct_answer*
    together with rewritten versions of *unique_sents* that support it, then
    splices the rewrites into *original_context* via string replacement.
    Returns {"context": ..., "answer": ...}; on any parse/format failure the
    original context is returned with a failure marker as the answer.
    """
    # Construct GPT-4o prompt
    numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
    prompt = (
        "You are a creative assistant. Given the question below and the original answer, propose a plausible alternative answer that is **different** from the original but still reasonable. "
        "Then rewrite the provided sentences to support your alternative answer. When rewriting each sentence, modify only the parts necessary to support the alternative answer. "
        "Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. "
        "Output must be strictly in the specified JSON format, with no additional text.\n"
        '{\n'
        ' "answer": "",\n'
        ' "revised": [\n'
        ' "",\n'
        ' "",\n'
        ' ...\n'
        ' ]\n'
        '}\n\n'
        f"Question:\n{question}\n\n"
        f"Original answer:\n{correct_answer}\n\n"
        f"Sentences to rewrite:\n{numbered}"
    )
    print(f"[Alternative Answer] Generating prompt: {prompt}")
    rsp = client.chat.completions.create(
        model="gpt-4o",
        temperature=0.7,
        messages=[{"role": "user", "content": prompt}]
    )
    js = extract_json_from_gpt_response(rsp.choices[0].message.content)
    # ROBUSTNESS: also treat a JSON object missing the expected keys as a
    # failure instead of raising KeyError below.
    if not js or "revised" not in js or "answer" not in js:
        print("[Alternative Answer] Failed to parse JSON")
        return {"context": original_context, "answer": "Failed to generate alternative"}
    revised = js["revised"]      # List[str]
    alternative = js["answer"]   # Alternative answer
    # Create new context by replacing each original sentence with its rewrite.
    new_ctx = original_context
    for old, new in zip(unique_sents, revised):
        new_ctx = new_ctx.replace(old, new)
    return {"context": new_ctx, "answer": alternative}


def main():
    """Run the full pipeline over a LongBench split, appending one JSON record
    per processed sample to the output file."""
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="LastingBench random alternative answer generation")
    parser.add_argument("--output", "-o", type=str, default="output_random.jsonl",
                        help="Output JSONL file path (default: output_random.jsonl)")
    parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
                        help="Dataset repository name (default: THUDM/LongBench)")
    parser.add_argument("--dataset_subset", type=str, default="hotpotqa",
                        help="Dataset subset name (default: hotpotqa)")
    parser.add_argument("--split", type=str, default="test",
                        help="Dataset split (default: test)")
    parser.add_argument("--start_idx", type=int, default=0,
                        help="Starting index for processing (default: 0)")
    parser.add_argument("--max_samples", type=int, default=-1,
                        help="Maximum number of samples to process (-1 for all, default: -1)")
    args = parser.parse_args()

    out_file = args.output
    # Load dataset
    longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]
    print(f"Output file: {out_file}")
    print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
    print(f"Total samples: {len(longbench)}")

    count = 0
    # Determine processing range
    start_idx = args.start_idx
    end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))
    print(f"Processing samples from index {start_idx} to {end_idx-1}")

    for idx in range(start_idx, end_idx):
        example = longbench[idx]
        question = example['input']
        print(f"Question: {question}")
        context = example['context']
        correct_answer = example['answers'][0]
        print(f"Processing example {idx + 1}:")
        print(f"Correct Answer: {correct_answer}")

        # Build prompt and query the reasoning API model
        prompt_with_context = build_prompt(context, question)
        answer_with_context = get_answer_with_retry('deepseek-r1', prompt_with_context)
        # BUG FIX: get_answer_with_retry returns None after exhausting its
        # retries; the original code crashed on .split() in that case.
        if answer_with_context is None:
            print("No answer returned from API, skipping this example.")
            continue

        # Extract content after "Answer:" and before "Step-by-step Reasoning"
        answer_with_context_simple = (
            answer_with_context
            .split("Answer:", 1)[-1]
            .split("Step-by-step Reasoning", 1)[0]
            .strip()
        )
        print(f"Answer with context: {answer_with_context_simple}")
        result = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
        print(f"Answer judge result: {result}")
        if not result:
            continue

        answer_with_context = remove_think_tags(answer_with_context or "")
        evidence = extract_all_bullet_passages(answer_with_context)
        page_contents = []
        if evidence:
            count += 1
            for ev in evidence:
                snippet = ev['snippet']
                # retriveDoc returns Document objects matching the snippet
                result = retriveDoc(context, snippet)
                page_contents += [doc.page_content for doc in result]
            unique_page_contents = list(dict.fromkeys(page_contents))
            aggregated_content = "\n".join(unique_page_contents)
            # Sanity check: can the local model answer from the retrieved evidence?
            prompt_final = (
                f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
                f"Please only provide your answer. "
                f"Your Answer:"
            )
            final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)
            if judge_answer_with_api(question, correct_answer, final_answer):
                print("correct")
            else:
                print("incorrect")

        # NOTE(review): placed at loop level (one record per sample, matching
        # the original "append one line of JSON each loop" comment) — confirm
        # against the pre-collapse source whether this belonged under
        # `if evidence:` instead.
        result_query = retriveDoc(context, question)
        page_contents += [doc.page_content for doc in result_query]
        unique_page_contents = list(dict.fromkeys(page_contents))

        # Generate random alternative answer instead of selecting the highest ppl answer
        alternative = asyncio.run(
            random_alternative_answer(
                question, context, unique_page_contents, correct_answer
            )
        )
        record = {
            "question": question,
            "answer": alternative["answer"],
            "context": alternative["context"]
        }
        # Append one line of JSON each loop
        with open(out_file, "a", encoding="utf-8") as fout:
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")


if __name__ == "__main__":
    main()