File size: 16,892 Bytes
b1e25b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
import json
import time
import re
import os
import argparse
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
from utils.util import retriveDoc,compute_best_sentence_f1
from openai import OpenAI
import asyncio, json, torch, math
from typing import List, Tuple
# Hugging Face transformers related
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from utils.metrics import qa_f1_score
from utils.llmjudge import judge_answer_with_api


# OpenAI-compatible client; endpoint and key come from the environment
# (OPENAI_BASE_URL / OPENAI_API_KEY), so a proxy/self-hosted gateway works too.
client = OpenAI(
    base_url=os.environ.get("OPENAI_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY")
)
# Load models using transformers

# Local answer-verification model (used by get_transformers_answer in main),
# pinned to GPU 0 in bfloat16.
tokenizer1 = AutoTokenizer.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True)
model1 = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-14B-Chat", trust_remote_code=True,device_map="cuda:0",torch_dtype=torch.bfloat16)


# Second local model pinned to GPU 1. NOTE(review): tok_qwen/model_qwen are not
# referenced anywhere in this file's visible code — confirm they are used
# elsewhere before removing.
tok_qwen = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True)
model_qwen = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True,
    device_map="cuda:1",torch_dtype=torch.bfloat16
).eval()

def get_transformers_answer(prompt, tokenizer, model, max_new_tokens=100, temperature=0.7, top_p=0.9, retries=3, delay=5):
    """
    Run chat-style inference with a local transformers model, with retries.

    Formats *prompt* through the tokenizer's chat template (falling back to the
    raw prompt if the template fails), generates up to *max_new_tokens* tokens,
    and returns only the newly generated text — the prompt tokens are sliced
    off at the token level before decoding. Returns None after *retries*
    failed attempts.
    """
    for attempt in range(retries):
        try:
            # Convert the raw prompt to chat-message format.
            messages = [{"role": "user", "content": prompt}]

            # Prefer the model's chat template; fall back to the bare prompt.
            try:
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            except Exception as e:
                print(f"Unable to apply chat template: {e}, falling back to basic text input")
                formatted_prompt = prompt

            # Encode the formatted prompt and move it to the model's device.
            model_inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)

            # Fix: without do_sample=True, transformers decodes greedily and
            # silently ignores temperature/top_p — the sampling parameters
            # declared in this signature were dead.
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p
            )

            # generate() returns prompt + continuation; strip the prompt part
            # by token count and decode only the new tokens.
            input_length = model_inputs.input_ids.shape[1]
            output_ids = generated_ids[0][input_length:]
            return tokenizer.decode(output_ids, skip_special_tokens=True).strip()
        except Exception as e:
            print(f"Error on attempt {attempt + 1}: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")
                return None

def truncate_answer(answer):
    """Return the text before the first period; "No answer" for empty/None input."""
    if not answer:
        return "No answer"
    head, _sep, _rest = answer.partition('.')
    return head.strip()

def write_to_log(filename, data):
    """Append *data* as one line to the UTF-8 log file at *filename*."""
    with open(filename, 'a', encoding='utf-8') as log_file:
        log_file.write(f"{data}\n")

def remove_think_tags(text: str) -> str:
    """Delete every <think>...</think> block (contents included) and trim whitespace."""
    cleaned = re.sub(r'<think>(.*?)</think>', '', text, flags=re.DOTALL)
    return cleaned.strip()

def build_prompt(context: str, question: str) -> str:
    """Compose the QA prompt: passages, the question, and the required
    Answer / Step-by-step Reasoning output format with reference-content lines."""
    segments = [
        "Answer the question based on the given passages. The following are the passages:\n",
        f"{context}\n",
        "Answer the question based on the given passages.\n",
        f"Question: {question}.\n",
        "Answer:\n",
        "Please first provide your answer in the format of Answer:[Your answer]. Then provide your reasoning process step-by-step.(Only include explicit clues) ",
        "At the end of each reasoning step, include a new line that specifies the key information or reference content used in that step. ",
        "Please ensure that the [reference content] you include is the complete original sentence or consecutive sentences from the text. Please do not change the punctuation.  Do not use ellipses inside the sentence. ",
        "Follow this format:\n",
        "Answer: [Your answer]\n",
        "Step-by-step Reasoning:\n",
        "1. [Reasoning step 1]\n",
        "[replaced by your reference content]\n",
        "2. [Reasoning step 2]\n",
        "[replaced by your reference content]\n",
    ]
    return "".join(segments)

def extract_final_bullet_passage(answer_text: str):
    """Find the last reasoning bullet that cites a passage.

    Scans the "Step-by-step Reasoning:" section from the last bullet backwards
    and returns (passage_number, snippet) from the final "Passage N: ..."
    citation found, or (None, None) when the section, bullets, or citation
    are absent.
    """
    section = re.search(r"Step-by-step Reasoning:\s*(.*)", answer_text, flags=re.DOTALL)
    if section is None:
        return None, None

    body = section.group(1).strip()
    bullets = re.findall(r"(?m)^(\d+\.\s.*?)(?=(?:\n\d+\.\s)|\Z)", body, flags=re.DOTALL)
    if not bullets:
        print("No bullet blocks found.")
        return None, None

    # Matches "Passage <N>: <quoted or free-text snippet>" (bold markers and
    # case ignored); the snippet runs to a blank line or end of bullet.
    citation_re = re.compile(
        r'(?i)(?:\*\*)?passage\s+(\d+)(?:\*\*)?\s*:\s*("([^"]*)"|(.+?))(?=\Z|\n\s*\n|$)',
        flags=re.DOTALL
    )

    for block in reversed(bullets):
        hits = citation_re.findall(block)
        if not hits:
            continue
        number, _whole, quoted, unquoted = hits[-1]
        # Prefer the unquoted capture; fall back to the quoted one.
        snippet = unquoted.strip() or quoted.strip()
        return number, snippet

    return None, None

def extract_all_bullet_passages(answer_text: str):
    reasoning_pattern = r"Step-by-step Reasoning:\s*(.*)"
    reasoning_match = re.search(reasoning_pattern, answer_text, flags=re.DOTALL)
    if not reasoning_match:
        return []

    reasoning_text = reasoning_match.group(1).strip()
    bullet_pattern = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
    bullets = bullet_pattern.findall(reasoning_text)
    if not bullets:
        return []

    results = []
    for bullet_index, bullet_text in enumerate(bullets, start=1):
        results.append({
            'bullet_index': bullet_index,
            'snippet': bullet_text.strip()
        })
    print(results)
    return results

def extract_evidence(answer_text: str):
    """Parse numbered evidence bullets that follow an "Evidence" marker.

    Finds "Evidence" (case-insensitive), splits what follows into numbered
    bullet chunks, discards anything before the first chunk starting with
    "1.", and returns [{'bullet_index', 'snippet'}, ...] with 1-based
    renumbering — or [] when no marker or no valid bullets exist.
    """
    section = re.search(r"(?i)Evidence\s*(.*)", answer_text, flags=re.DOTALL)
    if section is None:
        return []

    body = section.group(1).strip()

    # Split into bullet chunks: "N. " up to the next "N. " line or the end.
    bullet_re = re.compile(r"^(\d+\.\s.*?)(?=^\d+\.\s|\Z)", re.MULTILINE | re.DOTALL)
    chunks = bullet_re.findall(body)
    if not chunks:
        return []

    # Locate the first chunk labeled "1." — everything before it is noise.
    first = next(
        (pos for pos, chunk in enumerate(chunks) if chunk.strip().startswith("1.")),
        -1,
    )
    if first == -1:
        return []

    return [
        {'bullet_index': index, 'snippet': chunk.strip()}
        for index, chunk in enumerate(chunks[first:], start=1)
    ]


def get_answer_with_retry(model, prompt, retries=3, delay=5):
    """Query the chat-completions API for *prompt*, retrying on failure.

    Sleeps *delay* seconds between attempts; returns the stripped message
    content, or None once all *retries* attempts have failed.
    """
    for attempt in range(1, retries + 1):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}]
            )
            return response.choices[0].message.content.strip()
        except Exception as exc:
            print(f"Error on attempt {attempt}: {exc}")
            if attempt < retries:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached, skipping this request.")
                return None

def extract_json_from_gpt_response(text: str) -> dict | None:
    """
    Finds the first JSON block inside ```json ... ``` or ``` … ``` and returns it as a dict.
    """
    # Try to find a ```json … ``` block first
    m = re.search(r"```json\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
    if not m:
        # Fallback: any ``` … ``` block that looks like JSON
        m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, flags=re.DOTALL)
    if not m:
        # Lastly, maybe the model just spit raw JSON without fences
        m = re.search(r"(\{.*?\})", text, flags=re.DOTALL)
    if not m:
        return None

    json_str = m.group(1)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # clean up trailing commas, etc.
        cleaned = re.sub(r",\s*([\]}])", r"\1", json_str)
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            return None

async def random_alternative_answer(
    question: str,
    original_context: str,
    unique_sents: List[str],
    correct_answer: str
) -> dict:
    """Generate a plausible alternative answer and context rewritten to support it.

    Asks GPT-4o for a different-but-reasonable answer plus rewritten versions
    of *unique_sents*, then splices those rewrites into *original_context* by
    string replacement. Returns {"context": ..., "answer": ...}; on any
    parse/shape failure the original context and a failure marker come back.
    """
    # Number the evidence sentences so the model rewrites them one-by-one.
    numbered = "\n\n".join(f"{j+1}. {s}" for j, s in enumerate(unique_sents))
    prompt = (
        "You are a creative assistant. Given the question below and the original answer, propose a plausible alternative answer that is **different** from the original but still reasonable. "
        "Then rewrite the provided sentences to support your alternative answer. When rewriting each sentence, modify only the parts necessary to support the alternative answer. "
        "Parts unrelated to the answer must keep their original meaning. Be sure that the modified evidence sentences are sufficient to answer the original question. "
        "Output must be strictly in the specified JSON format, with no additional text.\n"
        '{\n'
        '  "answer": "<your alternative answer here, just provide the answer phrase, no need for complete sentence>",\n'
        '  "revised": [\n'
        '    "<rewritten sentence 1>",\n'
        '    "<rewritten sentence 2>",\n'
        '    ...\n'
        '  ]\n'
        '}\n\n'
        f"Question:\n{question}\n\n"
        f"Original answer:\n{correct_answer}\n\n"
        f"Sentences to rewrite:\n{numbered}"
    )

    print(f"[Alternative Answer] Generating prompt: {prompt}")

    # NOTE(review): synchronous SDK call inside an async def — it blocks the
    # event loop. Harmless under the asyncio.run() usage in main, but worth
    # confirming if this ever runs concurrently.
    rsp = client.chat.completions.create(
        model="gpt-4o", temperature=0.7,
        messages=[{"role":"user","content":prompt}]
    )

    js = extract_json_from_gpt_response(rsp.choices[0].message.content)
    # Fix: the old code indexed js["revised"] / js["answer"] directly, so valid
    # JSON with a missing or mis-typed field raised KeyError and killed the
    # whole run; treat that like a parse failure instead.
    if not js or not isinstance(js.get("revised"), list) or "answer" not in js:
        print("[Alternative Answer] Failed to parse JSON")
        return {"context": original_context, "answer": "Failed to generate alternative"}

    revised = js["revised"]     # List[str] of rewritten sentences
    alternative = js["answer"]  # Alternative answer phrase

    # Splice each rewrite into the context. zip() silently truncates if the
    # model returned a different sentence count than requested.
    new_ctx = original_context
    for old, new in zip(unique_sents, revised):
        new_ctx = new_ctx.replace(old, new)

    return {"context": new_ctx, "answer": alternative}

def main():
    """Pipeline driver.

    For each dataset sample: build a reasoning prompt, get an answer from the
    API model, verify it with an LLM judge, retrieve the documents behind the
    cited reasoning bullets, sanity-check them with the local model, then
    generate a random alternative answer/context pair and append it as one
    JSONL record to the output file.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="LastingBench random alternative answer generation")
    parser.add_argument("--output", "-o", type=str, default="output_random.jsonl",
                       help="Output JSONL file path (default: output_random.jsonl)")
    parser.add_argument("--dataset_repo", type=str, default="THUDM/LongBench",
                       help="Dataset repository name (default: THUDM/LongBench)")
    parser.add_argument("--dataset_subset", type=str, default="hotpotqa",
                       help="Dataset subset name (default: hotpotqa)")
    parser.add_argument("--split", type=str, default="test",
                       help="Dataset split (default: test)")
    parser.add_argument("--start_idx", type=int, default=0,
                       help="Starting index for processing (default: 0)")
    parser.add_argument("--max_samples", type=int, default=-1,
                       help="Maximum number of samples to process (-1 for all, default: -1)")

    args = parser.parse_args()

    out_file = args.output
    # Load dataset
    longbench = load_dataset(args.dataset_repo, args.dataset_subset)[args.split]

    print(f"Output file: {out_file}")
    print(f"Dataset: {args.dataset_repo}/{args.dataset_subset}[{args.split}]")
    print(f"Total samples: {len(longbench)}")

    count = 0  # number of samples that yielded usable evidence

    # Determine processing range
    start_idx = args.start_idx
    end_idx = len(longbench) if args.max_samples == -1 else min(start_idx + args.max_samples, len(longbench))

    print(f"Processing samples from index {start_idx} to {end_idx-1}")

    for idx in range(start_idx, end_idx):
        example = longbench[idx]
        question = example['input']
        print(f"Question: {question}")
        context = example['context']
        correct_answer = example['answers'][0]

        print(f"Processing example {idx + 1}:")
        print(f"Correct Answer: {correct_answer}")

        # Build prompts
        prompt_with_context = build_prompt(context, question)

        # Get the reasoned answer from the API model
        answer_with_context = get_answer_with_retry('deepseek-r1', prompt_with_context)

        # Fix: get_answer_with_retry returns None after exhausting retries;
        # the old code then crashed on None.split(...). Skip the sample.
        if answer_with_context is None:
            print("Failed to obtain an answer; skipping this example.")
            continue

        # Extract the content between "Answer:" and "Step-by-step Reasoning"
        answer_with_context_simple = (
            answer_with_context
            .split("Answer:", 1)[-1]          # First keep the part after Answer:
            .split("Step-by-step Reasoning", 1)[0]  # Then cut before Step-by-step Reasoning
            .strip()
        )

        print(f"Answer with context: {answer_with_context_simple}")
        result = judge_answer_with_api(question, correct_answer, answer_with_context_simple)
        print(f"Answer judge result: {result}")

        # Only keep samples where the model answered correctly.
        if not result:
            continue

        answer_with_context = remove_think_tags(answer_with_context or "")
        evidence = extract_all_bullet_passages(answer_with_context)

        page_contents = []
        if evidence:
            count += 1
            # Retrieve the source documents behind each reasoning bullet.
            for ev in evidence:
                snippet = ev['snippet']
                result = retriveDoc(context, snippet)
                # result is a collection of Document objects
                page_contents += [doc.page_content for doc in result]

            # De-duplicate while preserving first-seen order.
            unique_page_contents = list(dict.fromkeys(page_contents))
            aggregated_content = "\n".join(unique_page_contents)

            prompt_final = (
                f"Please answer the question based on the context.\nContext: {aggregated_content}.\n Question: {question}.\n"
                f"Please only provide your answer. "
                f"Your Answer:"
            )

            # Sanity check: can the local model answer from the retrieved
            # evidence alone? If not, add question-based retrieval results.
            final_answer = get_transformers_answer(prompt_final, tokenizer1, model1)

            if judge_answer_with_api(question, correct_answer, final_answer):
                print("correct")
            else:
                print("incorrect")
                result_query = retriveDoc(context, question)
                page_contents += [doc.page_content for doc in result_query]

            unique_page_contents = list(dict.fromkeys(page_contents))

            # Generate random alternative answer instead of selecting the highest ppl answer
            alternative = asyncio.run(
                random_alternative_answer(
                    question,
                    context,
                    unique_page_contents,
                    correct_answer
                )
            )

            record = {
                "question": question,
                "answer": alternative["answer"],
                "context": alternative["context"]
            }

            # Append one line of JSON each loop
            with open(out_file, "a", encoding="utf-8") as fout:
                fout.write(json.dumps(record, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    main()