import json
import pickle
import traceback

import requests

GPT_CACHE_FILE_PATH = 'cache.pkl'
USE_CACHE = False
class GPTRequest:
    def __init__(self, model_name, temperature=0.5, tokens=300, frequency_penalty=0,
                 presence_penalty=0, timeout=90):
        self.temperature = temperature
        self.tokens = tokens
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        # Kept for reference only; generate() below calls the public OpenAI endpoint directly.
        self.api_base = "https://new-llm.openai.azure.com/"
        self.model_name = model_name
        self.timeout = timeout
    def get_cache(self, messages: list[dict]):
        """Return the cached response for `messages`, or None if caching is off or there is no hit."""
        if not USE_CACHE:
            return None
        try:
            with open(GPT_CACHE_FILE_PATH, 'rb') as f:
                cache = pickle.load(f)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            cache = {}
        return cache.get(json.dumps(messages), None)
    def update_cache(self, messages: list[dict], response: str):
        """Persist `response` under the serialized `messages` key."""
        if not USE_CACHE:
            return
        try:
            with open(GPT_CACHE_FILE_PATH, 'rb') as f:
                cache = pickle.load(f)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            cache = {}
        cache[json.dumps(messages)] = response
        with open(GPT_CACHE_FILE_PATH, 'wb') as f:
            pickle.dump(cache, f)
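
    # Cache layout (only relevant when USE_CACHE is True): a single pickled dict that
    # maps json.dumps(messages) to the raw response string, e.g.
    #   {'[{"role": "user", "content": "Hi"}]': 'Hello! How can I help you?'}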
    def generate(self, messages: list[dict], openai_api_key: str):
        """Send a chat-completion request, using the local cache when enabled."""
        response = self.get_cache(messages)
        if response:
            return response
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {openai_api_key}"
        }
        payload = {
            "model": self.model_name,
            "messages": messages,
            "max_tokens": self.tokens,
            "temperature": self.temperature,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }
        try:
            response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload,
                                     timeout=self.timeout)
            response = response.json()
            rtn = response['choices'][0]['message']['content']
        except Exception:
            # Network errors, timeouts, and API error payloads all end up here.
            traceback.print_exc()
            return None
        self.update_cache(messages, rtn)
        return rtn
llm_kwargs = dict(
model_name='gpt-3.5-turbo-1106',
presence_penalty=0.1,
tokens=3000,
temperature=0.2,
timeout=90,
)
CHATGPT = GPTRequest(**llm_kwargs)
gpt4_kwargs = dict(
model_name='gpt-4-1106-preview',
presence_penalty=0.1,
tokens=3000,
temperature=0.2,
timeout=90,
)
GPT4 = GPTRequest(**gpt4_kwargs)
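
# Illustrative sketch (not executed here): the preconfigured clients are called with a
# chat-style message list and an API key supplied by the caller, e.g.
#
#   reply = CHATGPT.generate(
#       messages=[{"role": "user", "content": "Say hello."}],
#       openai_api_key="sk-...",  # placeholder key; supplied by the caller
#   )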
def GPT_request(prompt, model_name: str, openai_api_key: str):
    """
    Given a prompt, make a request to the OpenAI server and return the response.
    ARGS:
        prompt: a str prompt
        model_name: 'gpt4' to use the GPT-4 client; any other value falls back to
                    the ChatGPT (gpt-3.5-turbo) client.
        openai_api_key: the API key passed through with the request.
    RETURNS:
        a str with the model's response, or None if the request failed.
    """
    if model_name == 'gpt4':
        gpt_model = GPT4
    else:
        gpt_model = CHATGPT
    try:
        resp = gpt_model.generate(messages=[{"role": "user", "content": prompt}], openai_api_key=openai_api_key)
        return resp
    except Exception:
        traceback.print_exc()
        return None
def generate_prompt(curr_input, prompt_lib_file):
    """
    Takes in the current input (e.g. a comment that you want to classify) and
    the path to a prompt file. The prompt file contains the raw str prompt that
    will be used, which contains substrings of the form !<INPUT 0>!, !<INPUT 1>!,
    ... -- this function replaces those substrings with the actual curr_input to
    produce the final prompt that will be sent to the GPT server.
    ARGS:
        curr_input: the input we want to feed in (IF THERE IS MORE THAN ONE
                    INPUT, THIS CAN BE A LIST.)
        prompt_lib_file: the path to the prompt file.
    RETURNS:
        a str prompt that will be sent to OpenAI's GPT server.
    """
    if isinstance(curr_input, str):
        curr_input = [curr_input]
    curr_input = [str(i) for i in curr_input]
    with open(prompt_lib_file, "r") as f:
        prompt = f.read()
    for count, i in enumerate(curr_input):
        prompt = prompt.replace(f"!<INPUT {count}>!", i)
    if "<commentblockmarker>###</commentblockmarker>" in prompt:
        prompt = prompt.split("<commentblockmarker>###</commentblockmarker>")[1]
    return prompt
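
# Illustrative prompt-template layout assumed by generate_prompt (the real template files
# are not part of this module). Everything before the marker line is stripped, and
# !<INPUT n>! placeholders are filled positionally from curr_input:
#
#   prompt_template_comment_classification
#   <commentblockmarker>###</commentblockmarker>
#   Classify the following comment as toxic or not toxic: !<INPUT 0>!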
def safe_generate_response(
    prompt,
    model_name="gpt4",
    openai_api_key="",
    func_validate=None,
    func_clean_up=None,
    repeat=5,
):
    """Retry GPT_request up to `repeat` times, returning the cleaned-up response on the
    first attempt that passes `func_validate`, or None if every attempt fails."""
    for _ in range(repeat):
        curr_gpt_response = GPT_request(prompt, model_name, openai_api_key)
        if func_validate(curr_gpt_response, prompt=prompt):
            return func_clean_up(curr_gpt_response, prompt=prompt)
    return None
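
# Minimal usage sketch for safe_generate_response. The validator and clean-up callables
# below are hypothetical examples (the real ones live with the callers of this module),
# and reading the key from the OPENAI_API_KEY environment variable is an assumption
# about how the key is supplied.
if __name__ == "__main__":
    import os

    def _example_validate(response, prompt=""):
        # Hypothetical check: accept any non-empty string response.
        return isinstance(response, str) and len(response.strip()) > 0

    def _example_clean_up(response, prompt=""):
        # Hypothetical clean-up: strip surrounding whitespace.
        return response.strip()

    example_prompt = "Reply with the single word 'ok'."
    result = safe_generate_response(
        example_prompt,
        model_name="chatgpt",  # anything other than 'gpt4' uses the ChatGPT client
        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
        func_validate=_example_validate,
        func_clean_up=_example_clean_up,
        repeat=2,
    )
    print(result)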