# vimedllm / notebook / An / new / evaluator.py
# Uploaded by VuvanAn via huggingface_hub (commit cc37925, verified)
from typing import List
from chatbot import Chatbot
from ragdb import TextRAG
from prompt import *
class Evaluator:
    """Evaluate a chatbot on multiple-choice or free-form QA.

    Optionally augments each prompt with documents retrieved from a RAG
    store (``TextRAG``).  Results and per-question traces are appended to
    ``log.txt`` / ``log_score.txt`` / ``log_wrong.txt`` when logging is on.
    """

    # Class-level defaults; instances override them in __init__ / eval().
    CHATBOT = None              # Chatbot instance used to answer questions
    RETRIEVER = None            # optional TextRAG retriever (None disables RAG)
    K = 4                       # number of documents to retrieve (<=0 disables RAG)
    THRESHOLD = 0.5             # retrieval score threshold
    QA_DIR = None               # path to the QA dataset (informational)
    SEARCH_TYPE = "similarity"  # retrieval metric passed to RETRIEVER.search
    CORRECT = 0                 # correct-answer count from the last eval() run
    IS_LOG = True               # whether to append to the log files
    SUPPRESS_ERROR = False      # forwarded to Chatbot.chat

    # The model is instructed (via the prompt) to end with "the answer is X".
    _ANSWER_MARKER = "the answer is "
    _VALID_CHOICES = ("A", "B", "C", "D", "E")

    def __init__(self,
                 chatbot: Chatbot,
                 qa_dir: str = None,
                 rag: TextRAG = None,
                 search_type: str = "bm25",
                 k: int = 4,
                 threshold: float = 0.5,
                 log: bool = True):
        """Configure the evaluator.

        Args:
            chatbot: model wrapper exposing ``chat(prompt, suppress_error=...)``.
            qa_dir: path to the QA dataset (stored for reference only).
            rag: optional TextRAG retriever; None disables retrieval.
            search_type: retrieval metric (e.g. "bm25", "similarity").
            k: number of documents to retrieve per question.
            threshold: retrieval score threshold.
            log: write per-question traces to the log files.
        """
        self.CHATBOT = chatbot
        self.K = k
        self.THRESHOLD = threshold
        self.QA_DIR = qa_dir
        self.SEARCH_TYPE = search_type
        if rag is not None:
            self.RETRIEVER = rag
        self.CORRECT = 0
        self.IS_LOG = log

    def _retrieve_documents(self, query: str) -> str:
        """Retrieve up to K documents for *query*, formatted for the prompt.

        Documents whose text contains "QUESTION" are skipped — presumably to
        avoid leaking stored QA pairs into the context (TODO confirm intent).
        Returns "" when retrieval is disabled or yields nothing.
        """
        document = ""
        # Consolidates the original's inconsistent guards (K != 0 vs K > 0).
        if self.RETRIEVER is None or self.K <= 0:
            return document
        results = self.RETRIEVER.search(query, k=self.K,
                                        threshold=self.THRESHOLD,
                                        metric=self.SEARCH_TYPE)
        if results is None:
            return document
        for i, doc in enumerate(results):
            if doc.page_content.find("QUESTION") != -1:
                continue
            document += f"\nDocument {i+1}:\n {doc.page_content}"
        return document

    def _extract_choice(self, response: str) -> str:
        """Return the letter after the last answer marker, or "" on a miss.

        BUG FIX: the original indexed ``rfind(...) + 14`` unconditionally, so
        a missing marker (rfind == -1) read an arbitrary character at index 13
        and a marker at the end of the string raised IndexError.
        """
        idx = response.lower().rfind(self._ANSWER_MARKER)
        if idx == -1:
            return ""
        pos = idx + len(self._ANSWER_MARKER)
        return response[pos] if pos < len(response) else ""

    def _multichoice_checker(self, id: str,
                             question: str,
                             choices: List[str],
                             answer: str,
                             rag_query: str = None) -> int:
        """Ask one multiple-choice question; return whether the model is right.

        Retries up to 3 times when no valid choice letter can be parsed from
        the response.  Appends traces to the log files when IS_LOG is set.
        """
        options = "\n".join(f"{chr(65 + i)}. {choice}"
                            for i, choice in enumerate(choices))
        document = self._retrieve_documents(
            (question + options) if rag_query is None else rag_query)
        prompt = multichoice_qa_prompt.format(question=question,
                                              options=options,
                                              document=document)
        response = self.CHATBOT.chat(prompt, suppress_error=self.SUPPRESS_ERROR)
        res = self._extract_choice(response)
        for _ in range(3):
            if res in self._VALID_CHOICES:
                break  # BUG FIX: original kept re-querying even after success
            response = self.CHATBOT.chat(prompt,
                                         suppress_error=self.SUPPRESS_ERROR)
            res = self._extract_choice(response)
        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nChoices: {options}\nAnswer: {answer} {res}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('1' if (res == answer) else '0')
            if res != answer:
                with open("log_wrong.txt", "a", encoding="utf-8") as f:
                    f.write(f"ID: {id}\n")
        return res == answer

    def _answer_checker(self, id: str,
                        question: str,
                        answer: str = None,
                        rag_query: str = None) -> int:
        """Ask one free-form question and log the response.

        No automatic scoring is done for free-form answers; always returns 0.
        """
        document = self._retrieve_documents(
            question if rag_query is None else rag_query)
        response = self.CHATBOT.chat(
            answer_prompt.format(question=question, document=document),
            suppress_error=self.SUPPRESS_ERROR)
        with open("response.txt", "a", encoding="utf-8") as f:
            f.write(f"RESPONSE: {response}\n\n")
        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nAnswer: {answer}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
        return 0

    def eval(self, ids: List[str],
             questions: List[str],
             choices: List[List[str]] = None,
             answers: List[str] = None,
             rag_queries: List[str] = None,
             max_workers: int = 4,
             suppress_error: bool = False,
             k: int = 0,
             threshold: float = 0.5) -> float:
        """Evaluate all questions concurrently; return the accuracy.

        When *choices* is given the multiple-choice path is scored; otherwise
        the free-form path only logs responses (accuracy is always 0.0).
        """
        self.SUPPRESS_ERROR = suppress_error
        self.THRESHOLD = threshold
        self.K = k
        if not questions:  # avoid ZeroDivisionError on an empty dataset
            self.CORRECT = 0
            return 0.0
        # BUG FIX: executor.map cannot zip over None (the default).
        if rag_queries is None:
            rag_queries = [None] * len(questions)
        if self.IS_LOG:
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('\n')
        from concurrent.futures import ThreadPoolExecutor
        from tqdm import tqdm
        if choices is not None:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(
                    executor.map(self._multichoice_checker,
                                 ids, questions, choices, answers, rag_queries),
                    total=len(questions)))
        else:
            # BUG FIX: the original forwarded the rag query positionally into
            # _answer_checker's `answer` parameter instead of `rag_query`.
            def check_answer(id, q, rq):
                return self._answer_checker(id, q, rag_query=rq)
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(
                    executor.map(check_answer, ids, questions, rag_queries),
                    total=len(questions)))
        self.CORRECT = sum(results)
        return self.CORRECT / len(questions)
if __name__ == '__main__':
    qa_dir = r"C:\Users\vuvan\Desktop\An_Plaza\ViMedLLM\Vietnamese-Medical-LLM\dataset\QA Data\random.jsonl"
    evaluator = Evaluator(qa_dir=qa_dir,
                          chatbot=Chatbot(model_name="mistral", max_token=1),
                          rag=None,
                          search_type="similarity",
                          log=True)

    import json
    # The dataset is JSONL: one JSON object per line.
    with open(qa_dir, 'r', encoding="utf-8") as file:
        data = [json.loads(line) for line in file]
    questions = [item['question'] for item in data]
    # Fall back to the line index when a record carries no 'id' field
    # (assumes the schema may include 'id' — TODO confirm against the data).
    ids = [item.get('id', str(i)) for i, item in enumerate(data)]
    # For multiple-choice scoring additionally pass:
    #   answers = [item['answer'] for item in data]
    #   choices = [[item['A'], item['B'], item['C'], item['D'], item['E']] for item in data]
    # BUG FIX: eval() takes (ids, questions, ...); the original call passed
    # `questions` into the `ids` slot and omitted `questions` entirely.
    evaluator.eval(ids, questions, max_workers=10, suppress_error=True, k=0)