# File size: 6,964 Bytes
# Revision: b1e25b1
import json
import time
import os
import argparse
from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm
from utils.metrics import qa_f1_score, qa_em_score # Import evaluation functions
# Module-level OpenAI client shared by every API helper in this file.
# Both settings are looked up in the process environment at import time.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_BASE_URL"),
)
def get_openai_response(prompt, model="gpt-4o", retries=3, delay=2, max_tokens=100):
    """Send a single-turn chat prompt to the OpenAI API, retrying on failure.

    Args:
        prompt: Text sent as the only user message.
        model: Name of the chat model to query.
        retries: Maximum number of attempts. With a value <= 0 no request
            is made and the failure sentinel is returned immediately.
        delay: Seconds to sleep between two consecutive attempts.
        max_tokens: Upper bound on generated tokens passed to the API
            (defaults to 100, the value that used to be hard-coded).

    Returns:
        The stripped text of the first choice, or the sentinel string
        "Failed to get response" when no attempt succeeded.
    """
    for attempt in range(retries):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=[{'role': 'user', 'content': prompt}],
                max_tokens=max_tokens,
            )
            # If the message content is None, .strip() raises AttributeError,
            # which is handled below like any other failed attempt.
            return completion.choices[0].message.content.strip()
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("Max retries reached. Skipping this request.")
    # Reached when every attempt failed, or when retries <= 0 and the loop
    # never ran; callers always get a string back, never None.
    return "Failed to get response"
def rephrase_question_api(question, model_name, rephrase_type="opposite"):
    """Ask an OpenAI model to reword *question*.

    ``rephrase_type`` selects the instruction sent to the model:
    "opposite" asks for a question with the opposite meaning, "similar"
    asks for a synonymous rewording. Any other value raises ``ValueError``
    before a request is made. The result is whatever
    ``get_openai_response`` returns for the built prompt.
    """
    # Reject unsupported modes up front so no prompt is built for them.
    if rephrase_type not in ("opposite", "similar"):
        raise ValueError(f"Invalid rephrase_type: {rephrase_type}. Must be 'opposite' or 'similar'.")

    question_line = f"Question: {question}"
    if rephrase_type == "opposite":
        prompt_lines = [
            "Please rephrase the following question to have the exact opposite meaning.",
            question_line,
            "Return only the rephrased question with the opposite meaning, without any explanations or other content.",
        ]
    else:
        prompt_lines = [
            "Please rephrase the following question to be synonymous, maintaining the original meaning but using different wording:",
            question_line,
            "Return only the rephrased question, without any explanations or other content.",
        ]

    return get_openai_response("\n".join(prompt_lines), model=model_name)
def answer_question_with_context_api(question, context, model_name, max_tokens_for_answer=30):
    """Answer *question* from *context* with a single OpenAI chat request.

    The prompt tells the model to output only the answer and to say
    "I don't know" when the context does not contain it. Returns the
    stripped reply text, or the sentinel "Failed to get answer" if the
    request or the reply handling raises any exception. No retry is made.
    """
    prompt = "\n".join([
        "Please answer the question based on the following context:",
        "Context:",
        f"{context}",
        f"Question: {question}",
        'Only output the answer, no any other text. If the answer is not in the context, please say "I don\'t know".',
        "Answer:",
    ])
    user_message = {'role': 'user', 'content': prompt}

    try:
        reply = client.chat.completions.create(
            model=model_name,
            messages=[user_message],
            max_tokens=max_tokens_for_answer,
        )
        answer_text = reply.choices[0].message.content.strip()
    except Exception as e:
        print(f"Answer generation failed for model {model_name}: {e}")
        return "Failed to get answer"
    return answer_text
def main(args):
    """Run the rephrase-then-answer evaluation and print an EM summary.

    For each selected sample of the dataset's "test" split, the question
    is rephrased with GPT-4o according to ``args.rephrase_type``. The
    rephrased question is then answered by ``args.model_name`` using the
    sample's context, and the answer is compared by exact match against
    the sample's original ground-truth answers.

    Args:
        args: Parsed CLI namespace providing ``dataset_name``,
            ``dataset_subset``, ``sample_count``, ``model_name`` and
            ``rephrase_type``.

    Returns:
        None. Results are printed. Returns early if the dataset cannot
        be loaded.
    """
    # Single source of truth for the answer length limit used below.
    max_answer_tokens = 30
    # Sentinel strings that signal a failed rephrasing request.
    rephrase_failures = ("Failed to get response", "Failed to rephrase question")

    # Load dataset
    print(f"Loading dataset {args.dataset_name}, subset {args.dataset_subset}...")
    try:
        dataset = load_dataset(args.dataset_name, args.dataset_subset)["test"]
        print(f"Successfully loaded dataset with {len(dataset)} samples.")
    except Exception as e:
        print(f"Failed to load dataset: {e}")
        return

    em_match_count = 0  # Samples whose answer exactly matched a ground truth
    successfully_processed_samples = 0  # Samples that reached the scoring step

    if args.sample_count == -1:
        num_samples_to_process = len(dataset)
    else:
        num_samples_to_process = min(args.sample_count, len(dataset))

    # Report the rephrase mode actually in use rather than always "opposite".
    print(
        f"Processing {num_samples_to_process} samples. "
        f"Rephrasing with GPT-4o ({args.rephrase_type} meaning). "
        f"Answering with {args.model_name} (max {max_answer_tokens} tokens for answer)..."
    )

    for i in tqdm(range(num_samples_to_process), desc="Processing samples"):
        example = dataset[i]
        original_question = example['input']
        context = example['context']
        ground_truth_answers = example['answers']

        # Check this before any API call: a sample without ground truth can
        # never be scored, so spending two requests on it is pure waste.
        if not ground_truth_answers:
            print(f"Skipping sample {i+1} due to missing ground truth answers.")
            continue

        print(f"Original question: {original_question}")

        # Rephrasing is always done by gpt-4o, independent of the answering model.
        rephrased_question = rephrase_question_api(original_question, "gpt-4o", args.rephrase_type)
        print(f"Rephrased question ({args.rephrase_type}): {rephrased_question}")
        if rephrased_question in rephrase_failures:
            print(f"Skipping sample {i+1} due to rephrasing failure.")
            continue

        # Answer the rephrased question from the original context.
        rephrased_answer = answer_question_with_context_api(
            rephrased_question,
            context,
            args.model_name,
            max_tokens_for_answer=max_answer_tokens,
        )
        if rephrased_answer == "Failed to get answer":
            print(f"Skipping sample {i+1} due to answer generation failure.")
            continue

        successfully_processed_samples += 1

        # A sample counts once if the answer exactly matches any ground truth.
        if any(qa_em_score(rephrased_answer, gt_ans) > 0 for gt_ans in ground_truth_answers):
            em_match_count += 1

    if successfully_processed_samples > 0:
        print("\n--- Evaluation Summary ---")
        print(f"Answering Model : {args.model_name}")
        print(f"Dataset : {args.dataset_name} ({args.dataset_subset})")
        print(f"Successfully Processed Samples for Evaluation: {successfully_processed_samples}")
        print(f"Max Answer Tokens: {max_answer_tokens}")
        print(f"Count of EM with original ground truth (after rephrase): {em_match_count}")
    else:
        print("\nNo samples were processed adequately to provide an evaluation summary.")
    print("Processing complete!")
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run the evaluation.
    parser = argparse.ArgumentParser(
        description="Rephrase questions to opposite meaning with GPT-4o, answer with specified OpenAI model, then count EM against original GT.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="gpt-4o",
        help="Name of the OpenAI model to use for Answering.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="THUDM/LongBench",
        help="Name of the Hugging Face dataset.",
    )
    parser.add_argument(
        "--dataset_subset",
        type=str,
        default="2wikimqa",
        help="Subset of the dataset.",
    )
    parser.add_argument(
        "--sample_count",
        type=int,
        default=-1,
        help="Number of samples to process. -1 for all samples.",
    )
    parser.add_argument(
        "--rephrase_type",
        type=str,
        default="opposite",
        choices=["opposite", "similar"],
        help="Type of rephrasing: 'opposite' for opposite meaning or 'similar' for similar meaning.",
    )
    args = parser.parse_args()
    # The stray trailing "|" that followed this call was removed; it made
    # the statement a syntax error.
    main(args)