File size: 4,363 Bytes
58da50a
 
673de00
b78d5b2
673de00
52f2655
 
 
 
 
673de00
 
 
52f2655
58da50a
 
 
53bd0b4
71c0668
 
52f2655
71c0668
52f2655
71c0668
52f2655
71c0668
52f2655
98165dc
 
 
52f2655
58da50a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba8e34d
1f9d350
ba8e34d
1f9d350
ba8e34d
52f2655
 
 
 
 
ba8e34d
 
52f2655
 
 
1f9d350
52f2655
 
 
 
 
 
 
 
 
 
 
 
 
71c0668
1f9d350
71c0668
1f9d350
71c0668
 
925c77b
71c0668
 
 
 
 
 
 
 
 
 
 
26e5540
 
1f9d350
26e5540
 
71c0668
 
 
bbad6e4
71c0668
 
98165dc
02ae70c
98165dc
71c0668
ba8e34d
1f9d350
 
673de00
52f2655
1f9d350
52f2655
 
925c77b
52f2655
 
 
673de00
 
 
52f2655
53bd0b4
58da50a
 
 
 
 
1f9d350
 
7998fe0
58da50a
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import logging
import os
import os
from abc import ABC, abstractmethod
from gradio_client import Client

import openai
import tiktoken
from transformers import T5Tokenizer

# Token read from the HUB_TOKEN environment variable (None when unset);
# FlanT5Completor passes it to the gradio Client as ``hf_token``.
HUB_TOKEN = os.getenv("HUB_TOKEN")
# Source identifier that FlanT5Completor hands to the gradio Client as ``src``.
FLAN_HF_SPACE="jerpint/i-like-flan"


# Import-time side effect: configures logging at INFO level, then creates the
# logger used by this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextCompletion(ABC):
    """Abstract interface shared by every text-completion backend.

    Concrete subclasses must implement :meth:`get_token_count` and
    :meth:`complete`; :meth:`get_score_multiplier` has a default they may
    override.
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return how many tokens ``prompt`` occupies for this backend.

        Args:
            prompt: The text to tokenize.

        Returns:
            The token count as an integer.
        """
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Produce the completion for ``prompt``.

        Args:
            prompt: The text sent to the backend.

        Returns:
            The completion produced by the backend.
        """
        ...

    def get_score_multiplier(self) -> float:
        """Return the score multiplier for this backend (1.0 by default)."""
        return 1.0


class DebugCompleter(TextCompletion):
    """Canned-response completer intended only for testing and debugging."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of ``prompt`` using the text-davinci-003 encoding."""
        tokens = tiktoken.encoding_for_model("text-davinci-003").encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        """Return a fixed reply chosen by marker substrings in ``prompt``.

        ``$CORRECT$`` yields the winning phrase and takes precedence over
        ``$ERROR$``, which raises ``RuntimeError``; any other prompt yields a
        generic wrong answer.
        """
        # Success marker is checked first, so it wins when both are present.
        if "$CORRECT$" in prompt:
            return "I have been PWNED"

        # Explicitly requested failure, used to simulate a backend error.
        if "$ERROR$" in prompt:
            raise RuntimeError(f"User raised an error explicitly in {self.model}")

        return "Some incorrect completion"


class GPTCompletor(TextCompletion):
    """Completer that calls ``openai.Completion`` with the configured model."""

    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of ``prompt`` using the text-davinci-003 encoding."""
        tokens = tiktoken.encoding_for_model("text-davinci-003").encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        """Send ``prompt`` to the completion endpoint and return its text.

        Note that this assigns the module-level ``openai.api_key`` before
        issuing the request.
        """
        openai.api_key = self.openai_api_key

        # Request settings, gathered in one place before the call.
        request = dict(
            engine=self.model,
            prompt=prompt,
            max_tokens=1024,
            temperature=0,
            top_p=0,
            frequency_penalty=1,
            presence_penalty=1,
        )
        result = openai.Completion.create(**request)

        # Only the first choice is used.
        first_choice = result["choices"][0]
        return first_choice["text"]


class ChatGPTCompletor(TextCompletion):
    """Completer that calls ``openai.ChatCompletion`` with the configured model."""

    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of ``prompt`` using the encoding for ``self.model``."""
        tokens = tiktoken.encoding_for_model(self.model).encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        """Send ``prompt`` as a single user message and return the reply text.

        Note that this assigns the module-level ``openai.api_key`` before
        issuing the request.
        """
        openai.api_key = self.openai_api_key

        # The whole prompt travels as one user-role message.
        user_message = {"role": "user", "content": prompt}
        result = openai.ChatCompletion.create(
            messages=[user_message],
            model=self.model,
            temperature=0,
        )

        # Only the first choice is used.
        first_choice = result["choices"][0]
        return first_choice["message"]["content"]

    def get_score_multiplier(self) -> float:
        """Return the score multiplier for this backend (2.0)."""
        return 2.0


class FlanT5Completor(TextCompletion):
    """Completer that forwards prompts to a gradio ``Client``.

    The client is created from ``FLAN_HF_SPACE`` and ``HUB_TOKEN``. Token
    counting uses the ``google/flan-t5-xxl`` tokenizer, which is loaded on
    first use and then reused for the lifetime of the instance.
    """

    def __init__(self, model: str):
        self.model = model
        self.client = Client(src=FLAN_HF_SPACE, hf_token=HUB_TOKEN)
        # Populated lazily by _get_tokenizer(); loading it is expensive.
        self._tokenizer = None

    def _get_tokenizer(self):
        """Return the cached tokenizer, loading it on the first call."""
        # getattr keeps this working for instances built without __init__.
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
            self._tokenizer = tokenizer
        return tokenizer

    def get_token_count(self, prompt: str) -> int:
        """Return the number of token ids the tokenizer yields for ``prompt``.

        Args:
            prompt: The text to tokenize.

        Returns:
            Length of the first (only) row of the encoded ``input_ids``.
        """
        tokenizer = self._get_tokenizer()

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        token_count = len(input_ids[0])
        return token_count

    def complete(self, prompt: str) -> str:
        """Return the result of ``self.client.predict(prompt)``."""
        response_text = self.client.predict(prompt)
        return response_text


# Registry mapping each model name to its completer class; get_completer()
# looks classes up here. The "gpt-debug" model is handled separately in
# get_completer() and is deliberately not listed.
completers = {
    "text-davinci-003": GPTCompletor,
    "FlanT5-XXL": FlanT5Completor,
    "gpt-3.5-turbo": ChatGPTCompletor,
}


def get_completer(model: str, openai_api_key: str = ""):
    """Instantiate the completer that serves ``model``.

    Args:
        model: Model name. Either a key of ``completers`` or ``"gpt-debug"``,
            which selects the debugging completer.
        openai_api_key: API key forwarded to the OpenAI-backed completers;
            ignored for the other models.

    Returns:
        A newly constructed completer instance for ``model``.

    Raises:
        NotImplementedError: If ``model`` is not a supported name.
    """
    logger.info(f"Loading completer for {model=}")

    # OpenAI-backed completers are the only ones that take an API key.
    if model in ["text-davinci-003", "gpt-3.5-turbo"]:
        return completers[model](model=model, openai_api_key=openai_api_key)

    if model == "FlanT5-XXL":
        return completers[model](model=model)

    if model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app
        logger.warning("Using debugging completer...")
        return DebugCompleter(model=model)

    raise NotImplementedError(f"{model=} not implemented. Model must be one of {list(completers.keys())}")