|
|
import argparse
|
|
|
from ..rag_pipeline import multichoice_qa_prompt
|
|
|
from ..rag_pipeline import ChatAssistant
|
|
|
from ..utils import paralelize, load_qa_dataset, load_prepared_retrieve_docs
|
|
|
|
|
|
from datetime import datetime
|
|
|
from typing import List, Optional
|
|
|
from langchain.schema import Document
|
|
|
|
|
|
def get_answer_from_response(llm_response: str) -> str:
    """
    Get the answer from the LLM response.

    The answer is the single character that immediately follows the last
    case-insensitive occurrence of "the answer is " in the response.

    Args:
        llm_response: Raw text returned by the LLM.

    Returns:
        The character right after the last "the answer is " marker, or an
        empty string when the marker is absent or nothing follows it.
    """
    import re

    # Search the original string case-insensitively rather than indexing it
    # with an offset computed on a lower-cased copy, so the position found
    # always refers to llm_response itself.
    last_match = None
    for last_match in re.finditer("the answer is ", llm_response, flags=re.IGNORECASE):
        pass

    if last_match is None:
        # No marker means no answer. The old rfind()-based code turned the
        # "not found" value -1 into index 13 and returned whatever was there.
        return ""

    # Slicing (instead of indexing) yields "" rather than raising IndexError
    # when the response ends right after the marker.
    answer_pos = last_match.end()
    return llm_response[answer_pos:answer_pos + 1]
|
|
|
|
|
|
def build_multichoice_qa_prompt(question: str, options: str, document: Optional[List[Document]]) -> str:
    """
    Build the prompt for the multichoice QA task.

    Args:
        question: Question text inserted into the prompt template.
        options: Answer options inserted into the prompt template.
        document: Documents to include, or None for no documents. Each one is
            rendered as a "Document <n>:" header line (numbered from 1)
            followed by its ``page_content``; the sections are joined with
            newlines. None is handed to the template unchanged.

    Returns:
        The result of formatting ``multichoice_qa_prompt`` with the question,
        the options and the rendered documents.
    """
    if document is None:
        rendered_docs = None
    else:
        sections = []
        for number, doc in enumerate(document, start=1):
            sections.append(f"Document {number}:\n" + doc.page_content)
        rendered_docs = '\n'.join(sections)

    return multichoice_qa_prompt.format(
        question=question,
        options=options,
        document=rendered_docs,
    )
|
|
|
|
|
|
def process_question(question, prompt, answer, id, args, llm):
    """
    Query the LLM for one question, log the exchange and score the answer.

    The LLM is asked up to ``args.retries`` times. The loop stops at the first
    response whose extracted answer is one of A-E; that exchange is appended
    to ``log.txt``. Any exception raised during an attempt is printed and the
    attempt counts as a failed answer ("#"). The resulting score ("1" or "0")
    is always appended to ``log_score.txt``.

    Args:
        question: Not used by this function; kept for the existing signature.
        prompt: Prompt text passed to ``llm.get_response``.
        answer: Expected answer, compared for equality with the extracted one.
        id: Question identifier written to ``log.txt``.
        args: Object providing the integer attribute ``retries``.
        llm: Object exposing ``get_response``, called as
            ``get_response("", prompt)``.

    Returns:
        1 if the final extracted answer equals ``answer``, otherwise 0.
    """
    valid_answers = ("A", "B", "C", "D", "E")

    # Start from the failure sentinel so `ans` is defined even when the loop
    # body never runs (args.retries <= 0); it used to be unbound in that case
    # and the scoring below crashed.
    ans = "#"

    for _ in range(args.retries):
        try:
            llm_response = llm.get_response("", prompt)
            ans = get_answer_from_response(llm_response)
            if ans in valid_answers:
                with open("log.txt", "a", encoding="utf-8") as f:
                    # Emit the whole record with a single write call.
                    f.write(
                        f"ID: {id}\n"
                        f"{prompt}"
                        f"LLM Response:\n{llm_response}\n"
                        f"Answer: {answer} {ans}\n\n"
                    )
                break
        except Exception as e:
            # Best-effort retry: report the problem, mark this attempt as
            # failed and let the loop try again.
            print(f"Error: {e}")
            ans = "#"

    score = 1 if ans == answer else 0
    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write(str(score))
    return score
|
|
|
|
|
|
def evaluate_qa(questions, prompts, answers, ids, args, llm):
    """
    Score all questions with a thread pool and return the accuracy.

    One ``process_question`` task is submitted per question, using the
    entries at the same index in ``prompts``, ``answers`` and ``ids``.
    Progress is displayed with tqdm as tasks complete.

    Args:
        questions: Sequence of questions; its length defines how many tasks
            are submitted and is the denominator of the accuracy.
        prompts: Sequence of prompts, indexed like ``questions``.
        answers: Sequence of expected answers, indexed like ``questions``.
        ids: Sequence of question identifiers, indexed like ``questions``.
        args: Object providing ``max_workers`` (thread pool size); it is also
            forwarded to ``process_question``.
        llm: LLM client forwarded to ``process_question``.

    Returns:
        Sum of the per-question scores divided by the number of questions,
        or 0.0 when there are no questions.
    """
    total = len(questions)
    if total == 0:
        # Nothing to evaluate: return early instead of dividing by zero, and
        # before any pool or progress-bar setup.
        return 0.0

    import concurrent.futures

    from tqdm import tqdm

    correct = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=args.max_workers) as executor:
        futures = [
            executor.submit(process_question, questions[i], prompts[i], answers[i], ids[i], args, llm)
            for i in range(total)
        ]
        # Results are consumed in completion order; only their sum matters.
        for future in tqdm(concurrent.futures.as_completed(futures), total=total):
            correct += future.result()

    return correct / total
|
|
|
|
|
|
|
|
|
def main(args):
    """
    Run the multiple-choice QA evaluation and print the accuracy.

    Loads the QA dataset from ``args.qa_file``, optionally attaches the first
    ``args.num_docs`` prepared documents to each question, builds one prompt
    per question, appends a timestamped header with the run arguments to
    ``log_score.txt``, evaluates the prompts with a ``ChatAssistant`` and
    prints the accuracy.

    Args:
        args: Parsed command-line arguments. This function reads ``qa_file``,
            ``num_docs``, ``prepared_retrieve_docs_path``, ``model_name`` and
            ``provider``; the whole object is forwarded to ``evaluate_qa``.

    Raises:
        ValueError: If the dataset has no ids, if documents are requested
            (``num_docs > 0``) but no prepared-documents path is given, or if
            the prepared file holds fewer document lists than questions.
    """
    ids, questions, options, answers = load_qa_dataset(args.qa_file)

    if ids is None:
        raise ValueError(f"No id field in {args.qa_file}.")

    if args.num_docs > 0:
        if args.prepared_retrieve_docs_path is None:
            raise ValueError("No prepared retrieve docs found.")

        documents = load_prepared_retrieve_docs(args.prepared_retrieve_docs_path)
        # Keep only the first num_docs documents of each entry.
        docs = [d[:args.num_docs] for d in documents]

        # Fail early with a clear message instead of an IndexError while
        # building prompts. Extra entries are tolerated, as before.
        if len(docs) < len(questions):
            raise ValueError(
                f"{args.prepared_retrieve_docs_path} holds documents for {len(docs)} "
                f"questions but {args.qa_file} has {len(questions)}."
            )
    else:
        # No retrieval requested: every prompt is built without documents.
        docs = [None] * len(questions)

    prompts = [
        build_multichoice_qa_prompt(question, options[i], docs[i])
        for i, question in enumerate(questions)
    ]

    llm = ChatAssistant(args.model_name, args.provider)

    # Header line for this run; per-question scores are appended after it.
    with open("log_score.txt", "a", encoding="utf-8") as f:
        f.write(f"\n{datetime.now()} {args}\n")

    acc = evaluate_qa(questions, prompts, answers, ids, args, llm)
    print(f"Accuracy: {acc}")
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line options as (flag, type, default). A default of None leaves
    # the option unset unless it is passed explicitly.
    option_specs = (
        ("--qa_file", str, "dataset/QA Data/MedAB/MedABv2.jsonl"),
        ("--prepared_retrieve_docs_path", str, "dataset/QA Data/MedAB/prepared_retrieve_docs_full.pkl"),
        ("--model_name", str, "gemma3:12b"),
        ("--provider", str, "ollama"),
        ("--max_workers", int, 1),
        ("--num_docs", int, 0),
        ("--retries", int, 4),
        ("--dataset_path", str, None),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_specs:
        parser.add_argument(flag, type=value_type, default=default_value)

    args = parser.parse_args()

    # Echo the effective configuration before starting the evaluation.
    print(f"Log:{args}")

    main(args)