import json
from concurrent.futures import ThreadPoolExecutor
from typing import List

from tqdm import tqdm

from chatbot import Chatbot
from ragdb import TextRAG
from prompt import multichoice_qa_prompt, answer_prompt
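
# NOTE: `multichoice_qa_prompt` and `answer_prompt` are assumed to be
# str.format-style templates from prompt.py exposing {question} and {document}
# slots (the multiple-choice one also {options}), matching the .format()
# calls below.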


class Evaluator:
    """QA-evaluation harness around a Chatbot and an optional TextRAG
    retriever. Multiple-choice questions are scored automatically; free-form
    answers are only logged for manual review."""

    CHATBOT = None
    RETRIEVER = None
    K = 4
    THRESHOLD = 0.5
    QA_DIR = None
    SEARCH_TYPE = "similarity"
    CORRECT = 0
    IS_LOG = True
    SUPPRESS_ERROR = False

    def __init__(self,
                 chatbot: Chatbot,
                 qa_dir: str = None,
                 rag: TextRAG = None,
                 search_type: str = "bm25",
                 k: int = 4,
                 threshold: float = 0.5,
                 log: bool = True):
        self.CHATBOT = chatbot
        self.K = k
        self.THRESHOLD = threshold
        self.QA_DIR = qa_dir
        self.SEARCH_TYPE = search_type
        if rag is not None:
            self.RETRIEVER = rag
        self.CORRECT = 0
        self.IS_LOG = log

    def _multichoice_checker(self, id: str,
                             question: str,
                             choices: List[str],
                             answer: str,
                             rag_query: str = None) -> bool:
        """Ask one multiple-choice question and return whether the model's
        chosen letter matches `answer`."""
        # Render choices as "A. ...", "B. ...", etc.
        options = "\n".join(f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices))
        document = ""
        results = None
        if self.RETRIEVER is not None and self.K != 0:
            results = self.RETRIEVER.search(rag_query if rag_query is not None else question + options,
                                            k=self.K, threshold=self.THRESHOLD, metric=self.SEARCH_TYPE)
        if results is not None:
            for i, doc in enumerate(results):
                # Skip retrieved chunks that are themselves QA items.
                if "QUESTION" in doc.page_content:
                    continue
                document += f"\nDocument {i + 1}:\n {doc.page_content}"

        def _extract_choice(text: str) -> str:
            # Pull the letter immediately after the last "the answer is ";
            # return "" if the marker is missing or ends the string.
            idx = text.lower().rfind("the answer is ")
            if idx == -1:
                return ""
            pos = idx + len("the answer is ")
            return text[pos] if pos < len(text) else ""

        response = self.CHATBOT.chat(
            multichoice_qa_prompt.format(question=question, options=options, document=document),
            suppress_error=self.SUPPRESS_ERROR)
        res = _extract_choice(response)

        # Retry up to three times if the model did not produce a parseable letter.
        for _ in range(3):
            if res in ["A", "B", "C", "D", "E"]:
                break
            response = self.CHATBOT.chat(
                multichoice_qa_prompt.format(question=question, options=options, document=document),
                suppress_error=self.SUPPRESS_ERROR)
            res = _extract_choice(response)

        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nChoices: {options}\nAnswer: {answer} {res}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('1' if res == answer else '0')
            if res != answer:
                with open("log_wrong.txt", "a", encoding="utf-8") as f:
                    f.write(f"ID: {id}\n")
        return res == answer

    def _answer_checker(self, id: str,
                        question: str,
                        answer: str = None,
                        rag_query: str = None) -> int:
        """Ask one free-form question; the response is logged, not auto-scored."""
        document = ""
        results = None
        if self.K > 0 and self.RETRIEVER is not None:
            results = self.RETRIEVER.search(question if rag_query is None else rag_query,
                                            k=self.K, threshold=self.THRESHOLD, metric=self.SEARCH_TYPE)
        if results is not None:
            for i, doc in enumerate(results):
                # Skip retrieved chunks that are themselves QA items.
                if "QUESTION" in doc.page_content:
                    continue
                document += f"\nDocument {i + 1}:\n {doc.page_content}"

        response = self.CHATBOT.chat(answer_prompt.format(question=question, document=document),
                                     suppress_error=self.SUPPRESS_ERROR)
        with open("response.txt", "a", encoding="utf-8") as f:
            f.write(f"RESPONSE: {response}\n\n")

        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nAnswer: {answer}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
        # Free-form answers always contribute 0 to the accuracy score.
        return 0

    def eval(self, ids: List[str],
             questions: List[str],
             choices: List[List[str]] = None,
             answers: List[str] = None,
             rag_queries: List[str] = None,
             max_workers: int = 4,
             suppress_error: bool = False,
             k: int = 0,
             threshold: float = 0.5) -> float:
        """Run the evaluation over all questions and return accuracy
        (always 0.0 for the free-form path, which is not auto-scored)."""
        self.SUPPRESS_ERROR = suppress_error
        self.THRESHOLD = threshold
        self.K = k

        if self.IS_LOG:
            # Start a fresh line of 0/1 marks in the score log for this run.
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('\n')

        # executor.map needs one iterable per parameter, so pad the optional
        # per-question inputs to the same length as `questions`.
        if answers is None:
            answers = [None] * len(questions)
        if rag_queries is None:
            rag_queries = [None] * len(questions)

        if choices is not None:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(executor.map(self._multichoice_checker,
                                                 ids, questions, choices,
                                                 answers, rag_queries),
                                    total=len(questions)))
        else:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(executor.map(self._answer_checker,
                                                 ids, questions,
                                                 answers, rag_queries),
                                    total=len(questions)))

        self.CORRECT = sum(results)
        return self.CORRECT / len(questions)


if __name__ == '__main__':
    qa_dir = r"C:\Users\vuvan\Desktop\An_Plaza\ViMedLLM\Vietnamese-Medical-LLM\dataset\QA Data\random.jsonl"
    evaluator = Evaluator(qa_dir=qa_dir,
                          chatbot=Chatbot(model_name="mistral", max_token=1),
                          rag=None,
                          search_type="similarity",
                          log=True)

    with open(qa_dir, 'r', encoding="utf-8") as file:
        data = [json.loads(line) for line in file]
    questions = [item['question'] for item in data]
    # eval() takes ids before questions; no id field is read from the JSONL
    # here, so fall back to row indices.
    ids = [str(i) for i in range(len(questions))]

    evaluator.eval(ids, questions, max_workers=10, suppress_error=True, k=0)
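
    # A minimal sketch of the multiple-choice path with retrieval enabled,
    # for reference. The JSONL field names ('choices', 'answer') and the
    # TextRAG construction are assumptions, not confirmed against this repo:
    #
    # rag = TextRAG(...)  # build or load the retriever per ragdb's API
    # mc_eval = Evaluator(chatbot=Chatbot(model_name="mistral", max_token=1),
    #                     rag=rag, search_type="bm25")
    # choices = [item['choices'] for item in data]  # assumed field name
    # answers = [item['answer'] for item in data]   # assumed field name
    # accuracy = mc_eval.eval(ids, questions, choices=choices, answers=answers,
    #                         max_workers=10, k=4, threshold=0.5)
    # print(f"Accuracy: {accuracy:.2%}")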