File size: 5,916 Bytes
b0c0df0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import copy as cp
import os
import string
import time
from concurrent.futures import ThreadPoolExecutor

import pandas as pd
import requests
from loguru import logger as eval_logger
from tqdm import tqdm


class HRBenchEval:
    """Evaluator for HR-Bench multiple-choice predictions.

    First tries to extract an option letter (A/B/C/D, or "Z" for
    "no option matches") directly from a model's free-form prediction;
    when rule-based extraction fails, asks a ChatGPT-style judge endpoint
    to pick the closest option.
    """

    # Endpoint/credentials are resolved once at class-creation time from
    # environment variables.  NOTE(review): _post_request builds its own
    # Bearer-token headers from self.api_key, so the "azure"-style headers
    # computed here are not used by it — confirm the intended backend.
    API_TYPE = os.getenv("API_TYPE", "openai")

    if API_TYPE == "openai":
        API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
        API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
        }
    elif API_TYPE == "azure":
        API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
        API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
        headers = {
            "api-key": API_KEY,
            "Content-Type": "application/json",
        }

    def __init__(self, api_key, gpt_model="gpt-3.5-turbo", max_workers=12):
        """Store the judge API key, judge model name, and worker-pool size."""
        self.api_key = api_key
        self.gpt_model = gpt_model
        self.max_workers = max_workers

    def _post_request(self, payload):
        """POST *payload* as JSON to ``self.API_URL`` and return the decoded body.

        Raises:
            requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        response = requests.post(self.API_URL, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        return response.json()

    def can_infer_option(self, answer, choices):
        """Try to read an option letter directly out of *answer*.

        Args:
            answer: free-form model output.
            choices: dict keyed by option letters ("A", "B", ...).

        Returns:
            The single matching key, "Z" for a recognized refusal or an
            explicit "no option" reply, or False when nothing can be inferred.
        """
        # Fix: os.environ.get returns strings, so VERBOSE=0 used to count as
        # truthy.  Treat the common "off" spellings as disabled.
        verbose = os.environ.get("VERBOSE", "0") not in ("", "0", "false", "False")

        if "Failed to obtain answer via API" in answer:
            return False

        # Canned refusal messages map straight to the "no answer" bucket.
        reject_to_answer = [
            "Sorry, I can't help with images of people yet.",
            "I can't process this file.",
            "I'm sorry, but without the image provided",
            "Cannot determine the answer",
        ]
        for err in reject_to_answer:
            if err in answer:
                return "Z"

        def count_choice(splits, choices, prefix="", suffix=""):
            # Count how many option keys appear verbatim among the tokens.
            return sum(1 for c in choices if prefix + c + suffix in splits)

        # Strip punctuation so "B." / "(B)" tokenize to a bare "B";
        # one str.translate pass replaces the chained .replace() loop.
        punct = ".()[],:;!*#{}"
        answer_mod = answer.translate(str.maketrans(punct, " " * len(punct)))
        splits = [x.strip() for x in answer_mod.split()]
        count = count_choice(splits, choices)

        if count == 1:
            for ch in choices:
                # Heuristic: a lone "A" buried in a long sentence is too
                # ambiguous to trust when verbose mode is enabled.
                if "A" in splits and len(splits) > 3 and verbose:
                    return False
                if ch in splits:
                    return ch
        elif count == 0 and count_choice(splits, {"Z", ""}) == 1:
            return "Z"
        return False

    def can_infer_text(self, answer, choices):
        """Match *answer* against the option texts, case-insensitively.

        Returns the single key whose option text occurs as a substring of
        *answer*, or False when zero or several options match.
        """
        answer = answer.lower()
        assert isinstance(choices, dict)
        # Fix: build a lowered copy instead of mutating the caller's dict.
        lowered = {}
        for k, v in choices.items():
            assert k in string.ascii_uppercase
            lowered[k] = str(v).lower()
        cands = [k for k, v in lowered.items() if v in answer]
        if len(cands) == 1:
            return cands[0]
        return False

    def can_infer(self, answer, choices):
        """Infer an option letter: try letter matching first, then text matching."""
        answer = str(answer)
        copt = self.can_infer_option(answer, choices)
        return copt if copt else self.can_infer_text(answer, choices)

    def get_chat_response(self, data, temperature=0, max_tokens=256, patience=10, sleep_time=0):
        """Fill ``data["gpt_prediction"]`` with an inferred option letter.

        Rule-based extraction is tried first; on failure, the judge model is
        queried up to *patience* times (sleeping *sleep_time* seconds after an
        API error).  On total failure the key is set to False so downstream
        code can rely on its presence.

        Returns:
            The same *data* dict, with "gpt_prediction" populated.
        """
        question = data["question"]
        options = data["options"]
        prediction = data["prediction"]
        ret = self.can_infer(prediction, options)
        if ret:
            data["gpt_prediction"] = ret
            return data

        prompt = self.build_prompt(question, options, prediction)
        messages = [
            {"role": "user", "content": prompt},
        ]
        payload = {"model": self.gpt_model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "n": 1}

        while patience > 0:
            patience -= 1
            try:
                response = self._post_request(payload)
                prediction = response["choices"][0]["message"]["content"].strip()
                # A non-empty reply that is not the sentinel failure string
                # is accepted as the judge's verdict.
                if prediction and "Failed to obtain answer via API" not in prediction:
                    data["gpt_prediction"] = self.can_infer(prediction, options)
                    return data

            except Exception as e:
                # some model may output repetitive answer, which ChatGPT will throw an error.
                eval_logger.error(e)

                if sleep_time > 0:
                    time.sleep(sleep_time)
        # Fix: guarantee the key exists even when every retry failed, so
        # callers never hit a KeyError on "gpt_prediction".
        data["gpt_prediction"] = False
        return data

    def build_prompt(self, question, options, prediction):
        """Build the few-shot matching prompt sent to the judge model."""
        options_prompt = ""
        for key, item in options.items():
            options_prompt += f"{key}. {item}\n"
        tmpl = (
            "You are an AI assistant who will help me to match "
            "an answer with several options of a single-choice question. "
            "You are provided with a question, several options, and an answer, "
            "and you need to find which option is most similar to the answer. "
            "If the meaning of all options are significantly different from the answer, output Z. "
            "Your should output a single uppercase character in A, B, C, D (if they are valid options), and Z. \n"
            "Example 1: \n"
            "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n"
            "Answer: a cute teddy bear\nYour output: A\n"
            "Example 2: \n"
            "Question: What is the main object in image?\nOptions: A. teddy bear B. rabbit C. cat D. dog\n"
            "Answer: Spider\nYour output: Z\n"
            "Example 3: \n"
            "Question: {}\nOptions: {}\nAnswer: {}\nYour output: "
        )
        return tmpl.format(question, options_prompt, prediction)