File size: 6,524 Bytes
cc37925
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from typing import List
from chatbot import Chatbot
from ragdb import TextRAG
from prompt import *

class Evaluator:
    """Evaluate a Chatbot on multiple-choice or open-ended QA datasets.

    Optionally augments each prompt with documents retrieved from a
    TextRAG index.  When logging is enabled, per-question details are
    appended to ``log.txt``, per-question 1/0 scores to ``log_score.txt``,
    wrong-answer IDs to ``log_wrong.txt``, and raw open-ended responses
    to ``response.txt`` in the current working directory.
    """

    # Class-level defaults; all are overwritten per-instance in __init__.
    CHATBOT = None              # Chatbot instance used to answer questions
    RETRIEVER = None            # optional TextRAG retriever
    K = 4                       # number of documents to retrieve (0 disables RAG)
    THRESHOLD = 0.5             # retrieval score threshold
    QA_DIR = None               # path to the QA dataset (informational only)
    SEARCH_TYPE = "similarity"  # metric name forwarded to TextRAG.search
    CORRECT = 0                 # number of correct answers from the last eval()
    IS_LOG = True               # whether to append per-question log entries
    SUPPRESS_ERROR = False      # forwarded to Chatbot.chat

    def __init__(self,
                 chatbot: Chatbot,
                 qa_dir: str = None,
                 rag: TextRAG = None,
                 search_type: str = "bm25",
                 k: int = 4,
                 threshold: float = 0.5,
                 log: bool = True):
        """Store evaluation configuration.

        Args:
            chatbot: model wrapper whose ``chat`` method answers prompts.
            qa_dir: path of the QA dataset file (kept for reference).
            rag: optional TextRAG retriever; ``None`` disables retrieval.
            search_type: retrieval metric passed to ``TextRAG.search``.
            k: number of documents to retrieve per question.
            threshold: retrieval score threshold.
            log: enable per-question file logging.
        """
        self.CHATBOT = chatbot
        self.K = k
        self.THRESHOLD = threshold
        self.QA_DIR = qa_dir
        self.SEARCH_TYPE = search_type
        if rag is not None:
            self.RETRIEVER = rag
        self.CORRECT = 0
        self.IS_LOG = log

    @staticmethod
    def _extract_choice(response: str) -> str:
        """Return the character following the last ``"the answer is "``
        marker in *response* (case-insensitive), or ``""`` when the marker
        is absent or sits at the very end of the string.

        BUG FIX: the original indexed ``rfind(...) + 14`` directly, so a
        missing marker (``rfind`` returns -1) silently read index 13 of an
        arbitrary response, and a marker at end-of-string raised IndexError.
        """
        marker = "the answer is "
        pos = response.lower().rfind(marker)
        if pos == -1:
            return ""
        idx = pos + len(marker)
        return response[idx] if idx < len(response) else ""

    def _multichoice_checker(self, id: str,
                             question: str,
                             choices: List[str],
                             answer: str,
                             rag_query: str = None) -> int:
        """Ask the chatbot one multiple-choice question.

        Args:
            id: question identifier (used only for logging).
            question: question text.
            choices: option texts; lettered A, B, C, ... in order.
            answer: expected letter (e.g. ``"A"``).
            rag_query: optional retrieval query; defaults to
                question + rendered options.

        Returns:
            True-ish int (bool) — whether the extracted letter equals *answer*.
        """
        options = "\n".join(f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices))
        document = ""
        if self.RETRIEVER is not None and self.K != 0:
            query = (question + options) if rag_query is None else rag_query
            results = self.RETRIEVER.search(query, k=self.K,
                                            threshold=self.THRESHOLD,
                                            metric=self.SEARCH_TYPE)
            if results is not None:
                for i, doc in enumerate(results):
                    # Skip retrieved chunks that are themselves QA items.
                    if doc.page_content.find("QUESTION") != -1:
                        continue
                    document += f"\nDocument {i+1}:\n {doc.page_content}"

        prompt_text = multichoice_qa_prompt.format(question=question,
                                                   options=options,
                                                   document=document)
        response = ""
        res = ""
        # Up to 4 attempts (1 initial + 3 retries, as in the original) until
        # the model emits a parseable letter; stop early on success.
        for _ in range(4):
            response = self.CHATBOT.chat(prompt_text,
                                         suppress_error=self.SUPPRESS_ERROR)
            res = self._extract_choice(response)
            if res in ("A", "B", "C", "D", "E"):
                break

        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nChoices: {options}\nAnswer: {answer}  {res}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('1' if (res == answer) else '0')
        if res != answer:
            with open("log_wrong.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\n")
        return res == answer

    def _answer_checker(self, id: str,
                        question: str,
                        answer: str = None,
                        rag_query: str = None) -> int:
        """Ask the chatbot one open-ended question and log the raw response.

        No automatic scoring is performed; always returns 0 so that
        ``eval`` reports an accuracy of 0.0 for open-ended runs.

        Args:
            id: question identifier (used only for logging).
            question: question text.
            answer: reference answer, logged for manual review only.
            rag_query: optional retrieval query; defaults to the question.
        """
        document = ""
        if self.K > 0 and self.RETRIEVER is not None:
            query = question if rag_query is None else rag_query
            results = self.RETRIEVER.search(query, k=self.K,
                                            threshold=self.THRESHOLD,
                                            metric=self.SEARCH_TYPE)
            if results is not None:
                for i, doc in enumerate(results):
                    # Skip retrieved chunks that are themselves QA items.
                    if doc.page_content.find("QUESTION") != -1:
                        continue
                    document += f"\nDocument {i+1}:\n {doc.page_content}"

        response = self.CHATBOT.chat(answer_prompt.format(question=question,
                                                          document=document),
                                     suppress_error=self.SUPPRESS_ERROR)
        with open("response.txt", "a", encoding="utf-8") as f:
            f.write(f"RESPONSE: {response}\n\n")

        if self.IS_LOG:
            with open("log.txt", "a", encoding="utf-8") as f:
                f.write(f"ID: {id}\nQuestion: {question}\nAnswer: {answer}\nResponse: {response}\n")
                f.write(f"Document: {document}\n")
        return 0

    def eval(self, ids: List[str],
             questions: List[str],
             choices: List[List[str]] = None,
             answers: List[str] = None,
             rag_queries: List[str] = None,
             max_workers: int = 4,
             suppress_error: bool = False,
             k: int = 0,
             threshold: float = 0.5) -> float:
        """Run the evaluation over all questions in parallel.

        Multiple-choice mode when *choices* is given, open-ended otherwise.

        Returns:
            Fraction of correct answers (always 0.0 for open-ended runs).
        """
        self.SUPPRESS_ERROR = suppress_error
        self.THRESHOLD = threshold
        self.K = k

        # Guard against an empty dataset (original divided by zero).
        if not questions:
            self.CORRECT = 0
            return 0.0

        if self.IS_LOG:
            with open("log_score.txt", "a", encoding="utf-8") as f:
                f.write('\n')

        from concurrent.futures import ThreadPoolExecutor
        from tqdm import tqdm

        # executor.map stops at the shortest iterable and rejects None;
        # the original crashed whenever rag_queries was left at its default.
        if rag_queries is None:
            rag_queries = [None] * len(questions)

        def check_qa_answer(qid, q, c, a, rq):
            return self._multichoice_checker(qid, q, c, a, rq)

        def check_answer(qid, q, rq):
            # BUG FIX: the original passed rq positionally, so it landed in
            # _answer_checker's `answer` parameter instead of `rag_query`.
            return self._answer_checker(qid, q, rag_query=rq)

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            if choices is not None:
                futures = executor.map(check_qa_answer, ids, questions,
                                       choices, answers, rag_queries)
            else:
                futures = executor.map(check_answer, ids, questions, rag_queries)
            results = list(tqdm(futures, total=len(questions)))

        self.CORRECT = sum(results)
        return self.CORRECT / len(questions)

if __name__ == '__main__':
    # Path to a JSONL dataset: one JSON object per line with at least a
    # 'question' field (and optionally 'id', 'answer', 'A'..'E').
    qa_dir = r"C:\Users\vuvan\Desktop\An_Plaza\ViMedLLM\Vietnamese-Medical-LLM\dataset\QA Data\random.jsonl"
    evaluator = Evaluator(qa_dir=qa_dir,
                          chatbot=Chatbot(model_name="mistral", max_token=1),
                          rag=None,
                          search_type="similarity",
                          log=True)

    import json
    with open(qa_dir, 'r', encoding="utf-8") as file:
        data = [json.loads(line) for line in file]
    # Fall back to the line index when a record carries no explicit id.
    ids = [str(item.get('id', i)) for i, item in enumerate(data)]
    questions = [item['question'] for item in data]
    # answers = [item['answer'] for item in data]
    # choices = [[item['A'], item['B'], item['C'], item['D'], item['E']] for item in data]

    # BUG FIX: eval() takes `ids` as its first parameter; the original call
    # passed `questions` into the ids slot and omitted the questions list
    # entirely, which raises TypeError before any evaluation runs.
    evaluator.eval(ids, questions, max_workers=10, suppress_error=True, k=0)