from typing import List

from chatbot import Chatbot
from ragdb import TextRAG
from prompt import *


class Evaluator:
    """Evaluate a Chatbot on QA datasets, optionally augmented with RAG retrieval.

    When logging is enabled, results are appended to files in the working
    directory: ``log.txt`` (full transcripts), ``log_score.txt`` (one '1'/'0'
    per question), ``log_wrong.txt`` (ids of wrong answers) and
    ``response.txt`` (free-form answers).
    """

    # Class-level defaults; each is shadowed by an instance attribute in
    # __init__ / eval(), so distinct Evaluator instances do not share state.
    CHATBOT = None
    RETRIEVER = None
    K = 4
    THRESHOLD = 0.5
    QA_DIR = None
    SEARCH_TYPE = "similarity"
    CORRECT = 0
    IS_LOG = True
    SUPPRESS_ERROR = False

    def __init__(self, chatbot: Chatbot, qa_dir: str = None, rag: TextRAG = None,
                 search_type: str = "bm25", k: int = 4, threshold: float = 0.5,
                 log: bool = True):
        """Configure the evaluator.

        Args:
            chatbot: model wrapper whose ``chat`` method is queried.
            qa_dir: path to the QA dataset (stored for reference only).
            rag: optional TextRAG retriever supplying supporting documents.
            search_type: retrieval metric passed to ``rag.search`` (e.g. "bm25").
            k: number of documents to retrieve; 0 disables retrieval.
            threshold: retrieval score threshold.
            log: whether to append transcripts/scores to the log files.
        """
        self.CHATBOT = chatbot
        self.K = k
        self.THRESHOLD = threshold
        self.QA_DIR = qa_dir
        self.SEARCH_TYPE = search_type
        if rag is not None:
            self.RETRIEVER = rag
        self.CORRECT = 0
        self.IS_LOG = log

    def _retrieve_documents(self, query: str) -> str:
        """Return retrieved documents for *query* formatted as a prompt section.

        Returns an empty string when retrieval is disabled (no retriever, or
        K == 0) or yields nothing. Documents containing "QUESTION" are skipped
        because they look like leaked QA items rather than reference material.
        """
        # NOTE(review): unifies the two original guards (`K != 0` vs `K > 0`);
        # both intended "retrieval enabled".
        document = ""
        if self.RETRIEVER is not None and self.K != 0:
            results = self.RETRIEVER.search(query, k=self.K,
                                            threshold=self.THRESHOLD,
                                            metric=self.SEARCH_TYPE)
            if results is not None:
                for i, doc in enumerate(results):
                    if doc.page_content.find("QUESTION") != -1:
                        continue
                    document += f"\nDocument {i+1}:\n {doc.page_content}"
        return document

    @staticmethod
    def _extract_choice(response: str) -> str:
        """Extract the answer letter following 'the answer is ' (case-insensitive).

        Bug fix: the original indexed ``response[rfind(...) + 14]`` directly,
        so a missing marker (rfind == -1) silently picked the character at
        index 13, and a marker at the very end raised IndexError. Returns ''
        when no letter can be extracted, which triggers the retry loop.
        """
        marker = "the answer is "
        pos = response.lower().rfind(marker)
        if pos == -1:
            return ""
        idx = pos + len(marker)
        return response[idx] if idx < len(response) else ""

    def _multichoice_checker(self, id: str, question: str, choices: List[str],
                             answer: str, rag_query: str = None) -> int:
        """Ask one multiple-choice question and return 1 if answered correctly.

        Retries the model up to 3 extra times when no parseable letter A-E is
        found in the response. Logs transcript/score when IS_LOG is set.
        """
        options = "\n".join(f"{chr(65 + i)}. {choice}"
                            for i, choice in enumerate(choices))
        document = self._retrieve_documents(
            (question + options) if rag_query is None else rag_query)
        prompt = multichoice_qa_prompt.format(question=question,
                                              options=options,
                                              document=document)
        response = self.CHATBOT.chat(prompt, suppress_error=self.SUPPRESS_ERROR)
        res = self._extract_choice(response)
        for _ in range(3):
            if res in ["A", "B", "C", "D", "E"]:
                break
            response = self.CHATBOT.chat(prompt,
                                         suppress_error=self.SUPPRESS_ERROR)
            res = self._extract_choice(response)
        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nChoices: {options}\nAnswer: {answer} {res}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('1' if (res == answer) else '0')
            if res != answer:
                with open("log_wrong.txt", "a", encoding="utf-8") as f:
                    f.write(f"ID: {id}\n")
        return int(res == answer)

    def _answer_checker(self, id: str, question: str, answer: str = None,
                        rag_query: str = None) -> int:
        """Ask one free-form question; log the response and return 0.

        Free-form answers are not auto-graded (always returns 0); responses
        are appended to response.txt for manual inspection.
        """
        document = self._retrieve_documents(
            question if rag_query is None else rag_query)
        response = self.CHATBOT.chat(
            answer_prompt.format(question=question, document=document),
            suppress_error=self.SUPPRESS_ERROR)
        with open("response.txt", "a", encoding="utf-8") as f:
            f.write(f"RESPONSE: {response}\n\n")
        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nAnswer: {answer}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
        return 0

    def eval(self, ids: List[str], questions: List[str],
             choices: List[List[str]] = None, answers: List[str] = None,
             rag_queries: List[str] = None, max_workers: int = 4,
             suppress_error: bool = False, k: int = 0,
             threshold: float = 0.5) -> float:
        """Evaluate all questions concurrently and return the accuracy.

        When *choices* is given, questions are graded as multiple-choice;
        otherwise free-form answers are collected (accuracy is then 0.0
        because _answer_checker always returns 0).
        """
        self.SUPPRESS_ERROR = suppress_error
        self.THRESHOLD = threshold
        self.K = k
        # Bug fix: executor.map(..., None) raised TypeError when no RAG
        # queries were supplied; substitute a matching list of Nones.
        if rag_queries is None:
            rag_queries = [None] * len(questions)
        if self.IS_LOG:
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('\n')
        from concurrent.futures import ThreadPoolExecutor
        from tqdm import tqdm

        if choices is not None:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(
                    executor.map(self._multichoice_checker, ids, questions,
                                 choices, answers, rag_queries),
                    total=len(questions)))
        else:
            # Bug fix: the original passed the RAG query positionally into
            # _answer_checker's `answer` parameter; bind it by keyword.
            def check_answer(qid, q, rq):
                return self._answer_checker(qid, q, rag_query=rq)

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                results = list(tqdm(
                    executor.map(check_answer, ids, questions, rag_queries),
                    total=len(questions)))
        self.CORRECT = sum(results)
        return self.CORRECT / len(questions)


if __name__ == '__main__':
    qa_dir = r"C:\Users\vuvan\Desktop\An_Plaza\ViMedLLM\Vietnamese-Medical-LLM\dataset\QA Data\random.jsonl"
    evaluator = Evaluator(qa_dir=qa_dir,
                          chatbot=Chatbot(model_name="mistral", max_token=1),
                          rag=None, search_type="similarity", log=True)
    import json
    with open(qa_dir, 'r', encoding="utf-8") as file:
        data = [json.loads(line) for line in file]
    questions = [item['question'] for item in data]
    # Bug fix: eval() requires an ids list as its first argument; the original
    # passed `questions` into `ids`, leaving questions=None and crashing.
    ids = [item.get('id', str(i)) for i, item in enumerate(data)]
    evaluator.eval(ids, questions, max_workers=10, suppress_error=True, k=0)