Spaces:
Running
Running
| import json | |
| import traceback | |
| import json | |
| import pickle | |
| import requests | |
# Pickle file in which GPTRequest persists request-messages -> response pairs.
GPT_CACHE_FILE_PATH: str = 'cache.pkl'
# Master switch for that cache: when False, GPTRequest.get_cache/update_cache
# return immediately without touching the file.
USE_CACHE: bool = False
class GPTRequest:
    """Minimal client for an OpenAI-style chat-completions endpoint.

    Responses can optionally be cached on disk. Caching is controlled by the
    module-level ``USE_CACHE`` flag. Entries are stored in a pickle file at
    ``GPT_CACHE_FILE_PATH``, keyed by ``json.dumps(messages)``.
    """

    def __init__(self, model_name, temperature=0.5, tokens=300, frequency_penalty=0,
                 presence_penalty=0, timeout=90,
                 api_url="https://api.openai.com/v1/chat/completions"):
        """
        ARGS:
          model_name: model identifier sent as "model" in the request payload.
          temperature: sampling temperature sent in the payload.
          tokens: value sent as "max_tokens".
          frequency_penalty: forwarded verbatim in the payload.
          presence_penalty: forwarded verbatim in the payload.
          timeout: seconds passed to requests.post.
          api_url: endpoint that generate() POSTs to. The default is the
                   public OpenAI URL that was previously hard-coded.
        """
        self.temperature = temperature
        self.tokens = tokens
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        # NOTE(review): api_base is not used by this class (requests go to
        # api_url). It is kept so external code reading it does not break.
        self.api_base = "https://new-llm.openai.azure.com/"
        self.model_name = model_name
        self.timeout = timeout
        self.api_url = api_url

    @staticmethod
    def _load_cache():
        """Return the on-disk cache dict, or {} if it is missing or unreadable."""
        try:
            with open(GPT_CACHE_FILE_PATH, 'rb') as f:
                # pickle.load is only safe on trusted files; this cache file
                # is expected to be written solely by update_cache().
                cache = pickle.load(f)
        except Exception:
            # Missing file, truncated/corrupt pickle, etc.: start with an
            # empty cache. KeyboardInterrupt/SystemExit still propagate.
            return {}
        return cache if isinstance(cache, dict) else {}

    def get_cache(self, messages: list[dict]):
        """Return the cached response for `messages`.

        Returns None when caching is disabled or there is no cache entry.
        """
        if not USE_CACHE:
            return None
        return self._load_cache().get(json.dumps(messages), None)

    def update_cache(self, messages: list[dict], response: str):
        """Store `response` for `messages` and rewrite the whole cache file.

        Does nothing when caching is disabled.
        """
        if not USE_CACHE:
            return
        cache = self._load_cache()
        cache[json.dumps(messages)] = response
        with open(GPT_CACHE_FILE_PATH, 'wb') as f:
            pickle.dump(cache, f)

    def generate(self, messages: list[dict], openai_api_key: str):
        """Return the assistant reply text for `messages`, or None on failure.

        A cached reply is returned without contacting the API. A successful
        reply is written to the cache (when enabled).
        """
        cached = self.get_cache(messages)
        # `is not None` so a cached empty string still counts as a hit.
        if cached is not None:
            return cached
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {openai_api_key}"
        }
        payload = {
            "model": self.model_name,
            "messages": messages,
            "max_tokens": self.tokens,
            "temperature": self.temperature,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
        }
        try:
            http_response = requests.post(self.api_url, headers=headers, json=payload,
                                          timeout=self.timeout)
            data = http_response.json()
        except (requests.RequestException, ValueError):
            # Connection error, timeout, or a non-JSON body (JSON decode
            # errors subclass ValueError). Report it and signal failure.
            traceback.print_exc()
            return None
        try:
            content = data['choices'][0]['message']['content']
        except (KeyError, IndexError, TypeError):
            # Body is JSON but not a completion, e.g. an {"error": {...}} payload.
            traceback.print_exc()
            return None
        self.update_cache(messages, content)
        return content
# Preconfigured default chat client (GPT-3.5 Turbo).
llm_kwargs = {
    'model_name': 'gpt-3.5-turbo-1106',
    'presence_penalty': 0.1,
    'tokens': 3000,
    'temperature': 0.2,
    'timeout': 90,
}
CHATGPT = GPTRequest(**llm_kwargs)

# Preconfigured GPT-4 client: same sampling settings, different model.
# Overriding an existing key keeps its position, so the dict matches a literal
# with keys in the same order. GPT_request selects this client when
# model_name == 'gpt4'.
gpt4_kwargs = {**llm_kwargs, 'model_name': 'gpt-4-1106-preview'}
GPT4 = GPTRequest(**gpt4_kwargs)
def GPT_request(prompt, model_name: str, openai_api_key: str):
    """
    Send a single-turn prompt to OpenAI and return the model's reply.

    ARGS:
      prompt: the user message text (a str).
      model_name: 'gpt4' selects the preconfigured GPT4 client. Any other
                  value falls back to the CHATGPT client.
      openai_api_key: API key forwarded to GPTRequest.generate.
    RETURNS:
      the reply text, or None if the request failed.
    """
    gpt_model = GPT4 if model_name == 'gpt4' else CHATGPT
    try:
        return gpt_model.generate(messages=[{"role": "user", "content": prompt}],
                                  openai_api_key=openai_api_key)
    except Exception:
        # Best-effort boundary: print the failure and return None so callers
        # (e.g. safe_generate_response) can retry.
        traceback.print_exc()
        return None
def generate_prompt(curr_input, prompt_lib_file):
    """
    Build a prompt by filling a template file with the given input(s).

    The template may contain numbered placeholders !<INPUT 0>!, !<INPUT 1>!,
    and so on. The i-th element of `curr_input`, converted with str(),
    replaces every !<INPUT i>!. If the filled template contains
    <commentblockmarker>###</commentblockmarker>, only the text between that
    marker and the next occurrence of it (or the end of the text) is
    returned; everything before the marker is dropped.

    ARGS:
      curr_input: a single str, or a list (any iterable) of inputs.
      prompt_lib_file: path to the UTF-8 encoded prompt template file.
    RETURNS:
      the filled-in prompt str.
    """
    # A str is itself iterable; wrap it so it is used as one input, not per-char.
    if isinstance(curr_input, str):
        curr_input = [curr_input]
    curr_input = [str(i) for i in curr_input]
    with open(prompt_lib_file, "r", encoding="utf-8") as f:
        prompt = f.read()
    for count, value in enumerate(curr_input):
        prompt = prompt.replace(f"!<INPUT {count}>!", value)
    marker = "<commentblockmarker>###</commentblockmarker>"
    if marker in prompt:
        prompt = prompt.split(marker)[1]
    return prompt
def safe_generate_response(
    prompt,
    model_name="gpt4",
    openai_api_key="",
    func_validate=None,
    func_clean_up=None,
    repeat=5,
):
    """
    Query the model up to `repeat` times until a response passes validation.

    ARGS:
      prompt: the prompt text passed to GPT_request.
      model_name: forwarded to GPT_request ('gpt4' or another client name).
      openai_api_key: forwarded to GPT_request.
      func_validate: callable(response, prompt=prompt) -> bool. If None, any
                     non-None response is accepted.
      func_clean_up: callable(response, prompt=prompt), applied to the first
                     accepted response. If None, that response is returned
                     unchanged.
      repeat: maximum number of attempts.
    RETURNS:
      the cleaned-up response, or None if no attempt produced a valid one.
    """
    for _ in range(repeat):
        response = GPT_request(prompt, model_name, openai_api_key)
        if response is None:
            # The request failed; retry instead of handing None to the validator.
            continue
        if func_validate is None or func_validate(response, prompt=prompt):
            if func_clean_up is None:
                return response
            return func_clean_up(response, prompt=prompt)
    return None