Spaces:
Sleeping
Sleeping
File size: 4,363 Bytes
58da50a 673de00 b78d5b2 673de00 52f2655 673de00 52f2655 58da50a 53bd0b4 71c0668 52f2655 71c0668 52f2655 71c0668 52f2655 71c0668 52f2655 98165dc 52f2655 58da50a ba8e34d 1f9d350 ba8e34d 1f9d350 ba8e34d 52f2655 ba8e34d 52f2655 1f9d350 52f2655 71c0668 1f9d350 71c0668 1f9d350 71c0668 925c77b 71c0668 26e5540 1f9d350 26e5540 71c0668 bbad6e4 71c0668 98165dc 02ae70c 98165dc 71c0668 ba8e34d 1f9d350 673de00 52f2655 1f9d350 52f2655 925c77b 52f2655 673de00 52f2655 53bd0b4 58da50a 1f9d350 7998fe0 58da50a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import logging
import os
import os
from abc import ABC, abstractmethod
from gradio_client import Client
import openai
import tiktoken
from transformers import T5Tokenizer
# Hugging Face Hub access token, read from the environment (may be None if unset).
HUB_TOKEN = os.getenv("HUB_TOKEN")
# Hugging Face Space that serves the FLAN-T5 model used by FlanT5Completor.
FLAN_HF_SPACE="jerpint/i-like-flan"
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class TextCompletion(ABC):
    """Abstract interface for a text-completion backend.

    Concrete subclasses must implement token counting and completion;
    they may optionally override the score multiplier.
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens that *prompt* encodes to."""
        # Fixed: abstract signature previously omitted `self`, inconsistent
        # with every concrete implementation in this module.
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Return the model's completion text for *prompt*."""
        ...

    def get_score_multiplier(self) -> float:
        """Scoring weight applied to this backend; defaults to 1.0."""
        return 1.0
class DebugCompleter(TextCompletion):
    """This is only meant for testing/debugging purposes."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens with the text-davinci-003 encoding."""
        encoder = tiktoken.encoding_for_model("text-davinci-003")
        return len(encoder.encode(prompt))

    def complete(self, prompt: str):
        """Return a canned response chosen by marker substrings in *prompt*."""
        # $CORRECT$ wins over $ERROR$ when both markers are present.
        if "$CORRECT$" in prompt:
            # returns the correct prompt
            return "I have been PWNED"
        if "$ERROR$" in prompt:
            # simulate an error
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        return "Some incorrect completion"
class GPTCompletor(TextCompletion):
    """Backend for OpenAI's legacy text-completion endpoint."""

    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens with the text-davinci-003 encoding."""
        return len(tiktoken.encoding_for_model("text-davinci-003").encode(prompt))

    def complete(self, prompt: str):
        """Send *prompt* to the Completion API and return the generated text."""
        # Configure the API key for this call.
        openai.api_key = self.openai_api_key
        # Deterministic, heavily-penalized sampling settings.
        api_response = openai.Completion.create(
            engine=self.model,
            prompt=prompt,
            max_tokens=1024,
            temperature=0,
            top_p=0,
            frequency_penalty=1,
            presence_penalty=1,
        )
        # Extract the completion text from the first choice.
        return api_response["choices"][0]["text"]
class ChatGPTCompletor(TextCompletion):
    """Backend for OpenAI's chat-completion endpoint."""

    def __init__(self, openai_api_key, model):
        self.openai_api_key = openai_api_key
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count tokens with this model's own encoding."""
        return len(tiktoken.encoding_for_model(self.model).encode(prompt))

    def complete(self, prompt: str):
        """Send *prompt* as a single user message and return the reply text."""
        # Configure the API key for this call.
        openai.api_key = self.openai_api_key
        # Wrap the raw prompt as a one-message chat conversation.
        conversation = [{"role": "user", "content": prompt}]
        api_response = openai.ChatCompletion.create(
            messages=conversation,
            model=self.model,
            temperature=0,
        )
        # Extract the assistant reply from the first choice.
        return api_response["choices"][0]["message"]["content"]

    def get_score_multiplier(self) -> float:
        """Chat models score double relative to the default backend."""
        return 2.0
class FlanT5Completor(TextCompletion):
    """Backend that proxies completions to a hosted FLAN-T5 Hugging Face Space."""

    def __init__(self, model: str):
        self.model = model
        # Remote Gradio client pointed at the FLAN Space, authenticated if a token is set.
        self.client = Client(src=FLAN_HF_SPACE, hf_token=HUB_TOKEN)

    def get_token_count(self, prompt: str) -> int:
        """Count tokens using the flan-t5-xxl tokenizer."""
        tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
        encoded = tokenizer(prompt, return_tensors="pt").input_ids
        return len(encoded[0])

    def complete(self, prompt: str) -> str:
        """Forward *prompt* to the remote Space and return its prediction."""
        return self.client.predict(prompt)
# Registry of user-selectable completer classes, keyed by model name.
# The debug completer is deliberately not listed here so it never appears
# in the app's model choices.
completers = {
    "text-davinci-003": GPTCompletor,
    "FlanT5-XXL": FlanT5Completor,
    "gpt-3.5-turbo": ChatGPTCompletor,
}
def get_completer(model: str, openai_api_key: str = ""):
    """Instantiate the completer registered for *model*.

    Args:
        model: A key of ``completers``, or "gpt-debug" for the hidden
            debugging backend.
        openai_api_key: Forwarded to OpenAI-backed completers; ignored
            by the others.

    Returns:
        A ``TextCompletion`` instance for *model*.

    Raises:
        NotImplementedError: If *model* is not a known completer.
    """
    logger.info(f"Loading completer for {model=}")
    if model in ["text-davinci-003", "gpt-3.5-turbo"]:
        completer = completers[model](model=model, openai_api_key=openai_api_key)
    elif model == "FlanT5-XXL":
        completer = completers[model](model=model)
    elif model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app
        logger.warning("Using debugging completer...")
        completer = DebugCompleter(model=model)
    else:
        # Fixed typo in user-facing error message: "onf of" -> "one of".
        raise NotImplementedError(
            f"{model=} not implemented. Model must be one of {list(completers.keys())}"
        )
    return completer
|