File size: 6,524 Bytes
cc37925 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
from typing import List
from chatbot import Chatbot
from ragdb import TextRAG
from prompt import *
class Evaluator:
    """Evaluate a Chatbot on multiple-choice or free-form QA, optionally with
    retrieval-augmented context from a TextRAG retriever.

    Results and per-question traces are appended to ``log.txt``,
    ``log_score.txt`` and ``log_wrong.txt`` when logging is enabled.
    """

    # Class-level defaults; each is overwritten per-instance in __init__ / eval.
    CHATBOT = None          # Chatbot used to answer questions
    RETRIEVER = None        # optional TextRAG retriever
    K = 4                   # number of documents to retrieve (0 disables RAG)
    THRESHOLD = 0.5         # retrieval score threshold
    QA_DIR = None           # path of the QA dataset (informational only)
    SEARCH_TYPE = "similarity"  # retrieval metric passed to TextRAG.search
    CORRECT = 0             # running count of correct answers from last eval()
    IS_LOG = True           # whether to append traces to the log files
    SUPPRESS_ERROR = False  # forwarded to Chatbot.chat

    def __init__(self,
                 chatbot: Chatbot,
                 qa_dir: str = None,
                 rag: TextRAG = None,
                 search_type: str = "bm25",
                 k: int = 4,
                 threshold: float = 0.5,
                 log: bool = True):
        """Configure the evaluator.

        Args:
            chatbot: model wrapper exposing ``chat(prompt, suppress_error=...)``.
            qa_dir: path of the QA dataset file (stored, not read here).
            rag: optional retriever; when None, no documents are injected.
            search_type: retrieval metric name (e.g. "bm25", "similarity").
            k: number of documents to retrieve per question.
            threshold: retrieval score threshold.
            log: enable writing of the log files.
        """
        self.CHATBOT = chatbot
        self.K = k
        self.THRESHOLD = threshold
        self.QA_DIR = qa_dir
        self.SEARCH_TYPE = search_type
        if rag is not None:
            self.RETRIEVER = rag
        self.CORRECT = 0
        self.IS_LOG = log

    def _retrieve_documents(self, query: str) -> str:
        """Return retrieved documents formatted as a context string.

        Empty string when retrieval is disabled (no retriever or K <= 0).
        Retrieved chunks containing "QUESTION" are skipped — presumably these
        are QA items leaked into the corpus (TODO confirm against the corpus).
        """
        # NOTE: original used `K != 0` in one call site and `K > 0` in the
        # other; unified to `K > 0` so a negative K never reaches search().
        if self.RETRIEVER is None or self.K <= 0:
            return ""
        results = self.RETRIEVER.search(query, k=self.K,
                                        threshold=self.THRESHOLD,
                                        metric=self.SEARCH_TYPE)
        document = ""
        if results is not None:
            for i, doc in enumerate(results):
                if doc.page_content.find("QUESTION") != -1:
                    continue
                document += f"\nDocument {i+1}:\n {doc.page_content}"
        return document

    def _extract_answer(self, response: str) -> str:
        """Extract the answer letter following "the answer is " (case-insensitive).

        Returns "" when the marker is missing or ends the string.
        BUG FIX: the original indexed ``response[rfind(...) + 14]`` which,
        when the marker was absent (rfind == -1), silently read character 13
        of the response — or raised IndexError on short responses.
        """
        marker = "the answer is "
        idx = response.lower().rfind(marker)
        if idx == -1:
            return ""
        pos = idx + len(marker)
        return response[pos] if pos < len(response) else ""

    def _multichoice_checker(self, id: str,
                             question: str,
                             choices: List[str],
                             answer: str,
                             rag_query: str = None) -> int:
        """Ask one multiple-choice question; return truthy iff the parsed
        letter equals ``answer``. Retries up to 3 times on unparseable output.
        """
        options = "\n".join(f"{chr(65 + i)}. {choice}"
                            for i, choice in enumerate(choices))
        document = self._retrieve_documents(
            (question + options) if rag_query is None else rag_query)
        prompt = multichoice_qa_prompt.format(question=question,
                                              options=options,
                                              document=document)
        response = self.CHATBOT.chat(prompt, suppress_error=self.SUPPRESS_ERROR)
        res = self._extract_answer(response)
        # Retry a few times when the model did not emit a parseable letter.
        # BUG FIX: the original loop never broke early, so a question whose
        # first response was unparseable was re-queried all 3 times even
        # after a valid letter appeared.
        for _ in range(3):
            if res in ("A", "B", "C", "D", "E"):
                break
            response = self.CHATBOT.chat(prompt,
                                         suppress_error=self.SUPPRESS_ERROR)
            res = self._extract_answer(response)
        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nChoices: {options}\nAnswer: {answer} {res}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('1' if (res == answer) else '0')
            if res != answer:
                with open("log_wrong.txt", "a", encoding="utf-8") as f:
                    f.write(f"ID: {id}\n")
        return res == answer

    def _answer_checker(self, id: str,
                        question: str,
                        answer: str = None,
                        rag_query: str = None) -> int:
        """Ask one free-form question; log the response and return 0
        (free-form answers are not auto-scored).
        """
        document = self._retrieve_documents(
            question if rag_query is None else rag_query)
        response = self.CHATBOT.chat(
            answer_prompt.format(question=question, document=document),
            suppress_error=self.SUPPRESS_ERROR)
        # Responses are always dumped, independent of IS_LOG (original behavior).
        with open("response.txt", "a", encoding="utf-8") as f:
            f.write(f"RESPONSE: {response}\n\n")
        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nAnswer: {answer}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
        return 0

    def eval(self, ids: List[str],
             questions: List[str],
             choices: List[List[str]] = None,
             answers: List[str] = None,
             rag_queries: List[str] = None,
             max_workers: int = 4,
             suppress_error: bool = False,
             k: int = 0,
             threshold: float = 0.5) -> float:
        """Evaluate all questions concurrently and return the accuracy.

        When ``choices`` is given, runs multiple-choice scoring; otherwise
        free-form answering (which always scores 0). ``k``/``threshold``
        override the instance retrieval settings for this run.

        NOTE(review): a single CHATBOT instance is shared across worker
        threads — confirm Chatbot.chat is thread-safe for max_workers > 1.
        """
        self.SUPPRESS_ERROR = suppress_error
        self.THRESHOLD = threshold
        self.K = k
        # BUG FIX: executor.map requires an iterable per argument; the
        # original crashed with TypeError when rag_queries was None.
        if rag_queries is None:
            rag_queries = [None] * len(questions)
        if self.IS_LOG:
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('\n')
        from concurrent.futures import ThreadPoolExecutor
        from tqdm import tqdm

        def run_multichoice(qid, q, c, a, rq):
            return self._multichoice_checker(qid, q, c, a, rq)

        def run_freeform(qid, q, rq):
            # BUG FIX: the original passed rq positionally, which landed in
            # _answer_checker's `answer` parameter instead of `rag_query`.
            return self._answer_checker(qid, q, rag_query=rq)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            if choices is not None:
                results = list(tqdm(executor.map(run_multichoice, ids, questions,
                                                 choices, answers, rag_queries),
                                    total=len(questions)))
            else:
                results = list(tqdm(executor.map(run_freeform, ids, questions,
                                                 rag_queries),
                                    total=len(questions)))
        self.CORRECT = sum(results)
        # Guard against division by zero on an empty question list.
        return self.CORRECT / len(questions) if questions else 0.0
if __name__ == '__main__':
    # Smoke-run the evaluator over a local JSONL QA file without RAG.
    qa_dir = r"C:\Users\vuvan\Desktop\An_Plaza\ViMedLLM\Vietnamese-Medical-LLM\dataset\QA Data\random.jsonl"
    evaluator = Evaluator(qa_dir=qa_dir,
                          chatbot=Chatbot(model_name="mistral", max_token=1),
                          rag=None,
                          search_type="similarity",
                          log=True)
    import json
    with open(qa_dir, 'r', encoding="utf-8") as file:
        data = [json.loads(line) for line in file]
    questions = [item['question'] for item in data]
    # BUG FIX: eval() requires both `ids` and `questions`; the original call
    # passed `questions` into the `ids` slot and omitted `questions` entirely,
    # raising TypeError. Use each record's id when present, else its index
    # (records may carry an 'id' field — TODO confirm against the dataset).
    ids = [str(item.get('id', i)) for i, item in enumerate(data)]
    # answers = [item['answer'] for item in data]
    # choices = [[item['A'], item['B'], item['C'], item['D'], item['E']] for item in data]
    # Explicit rag_queries list so executor.map always receives an iterable.
    evaluator.eval(ids, questions,
                   rag_queries=[None] * len(questions),
                   max_workers=10, suppress_error=True, k=0)