import json
import os

import openai
import requests
import tiktoken

from base_class import ChatbotEngine
from config import MAX_TOKEN_MODEL_MAP
from utils import get_filtered_keys_from_object


class ChatbotWrapper:
    """
    Wrapper around the official ChatGPT API,
    based on https://github.com/ChatGPT-Hackers/revChatGPT
    """

    def __init__(
        self,
        api_key: str,
        engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
        proxy: str = None,
        max_tokens: int = 3000,
        temperature: float = 0.5,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        overhead_token: int = 96,
    ) -> None:
        """
        Initialize the chatbot with an API key (from https://platform.openai.com/account/api-keys).
        """
        self.engine = engine
        self.session = requests.Session()
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.reply_count = reply_count
        self.max_limit = MAX_TOKEN_MODEL_MAP[self.engine]
        self.overhead_token = overhead_token
        if proxy:
            self.session.proxies = {
                "http": proxy,
                "https": proxy,
            }
        self.conversation: dict = {
            "default": [
                {
                    "role": "system",
                    "content": system_prompt,
                },
            ],
        }
        if max_tokens > self.max_limit - self.overhead_token:
            raise Exception(
                f"Max tokens cannot be greater than {self.max_limit - self.overhead_token}")
        if self.get_token_count("default") > self.max_tokens:
            raise Exception("System prompt is too long")

    def add_to_conversation(
        self,
        message: str,
        role: str,
        convo_id: str = "default",
    ) -> None:
        """
        Add a message to the conversation
        """
        self.conversation[convo_id].append({"role": role, "content": message})

    def __truncate_conversation(self, convo_id: str = "default") -> None:
        """
        Truncate the conversation until it fits within the token budget
        """
        # TODO: context condensing with soft prompt tuning
        while True:
            if (
                self.get_token_count(convo_id) > self.max_tokens
                and len(self.conversation[convo_id]) > 1
            ):
                # Keep the system message at index 0; drop the oldest Q/A pair
                self.conversation[convo_id].pop(1)
                self.conversation[convo_id].pop(1)
                # TODO: pop the least relevant pair based on similarity distance
            else:
                break
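
    # Worked truncation example (illustrative, not from the original source):
    # with max_tokens=3000 and a history [system, q1, a1, q2, a2] counting
    # 3200 tokens, the two pop(1) calls drop q1 and a1 and the loop re-checks;
    # the system message at index 0 is never removed.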

    # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def get_token_count(self, convo_id: str = "default") -> int:
        """
        Get token count
        """
        if self.engine not in ["gpt-3.5-turbo", "gpt-3.5-turbo-0301"]:
            raise NotImplementedError(f"Unsupported engine {self.engine}")
        encoding = tiktoken.encoding_for_model(self.engine)
        num_tokens = 0
        for message in self.conversation[convo_id]:
            # every message follows <im_start>{role/name}\n{content}<im_end>\n
            num_tokens += 4
            for key, value in message.items():
                num_tokens += len(encoding.encode(value))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += 1  # role is always required and always 1 token
        num_tokens += 2  # every reply is primed with <im_start>assistant
        return num_tokens
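
    # Worked count (illustrative): for a single system message whose content
    # encodes to 10 tokens and whose role ("system") encodes to 1 token, the
    # method returns 4 + 10 + 1 + 2 = 17 (per-message overhead + content +
    # role + reply priming).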

    def get_max_tokens(self, convo_id: str) -> int:
        """
        Get the number of tokens left for the completion
        """
        return self.max_tokens - self.get_token_count(convo_id)

    def ask_stream(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        dynamic_system_prompt: str = None,
        **kwargs,
    ) -> str:
        """
        Ask a question and stream the reply
        """
        # Make the conversation if it doesn't exist
        if convo_id not in self.conversation:
            self.reset(convo_id=convo_id, system_prompt=dynamic_system_prompt)
        # Adjust the system prompt
        assert dynamic_system_prompt is not None
        self.conversation[convo_id][0]["content"] = dynamic_system_prompt
        self.add_to_conversation(prompt, "user", convo_id=convo_id)
        print(f"total tokens: {self.get_token_count(convo_id)}")
        self.__truncate_conversation(convo_id=convo_id)
        # Get response
        response = self.session.post(
            os.environ.get("API_URL") or "https://api.openai.com/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {kwargs.get('api_key', self.api_key)}"},
            json={
                "model": self.engine,
                "messages": self.conversation[convo_id],
                "stream": True,
                # kwargs
                "temperature": kwargs.get("temperature", self.temperature),
                "top_p": kwargs.get("top_p", self.top_p),
                "presence_penalty": kwargs.get(
                    "presence_penalty",
                    self.presence_penalty,
                ),
                "frequency_penalty": kwargs.get(
                    "frequency_penalty",
                    self.frequency_penalty,
                ),
                "n": kwargs.get("n", self.reply_count),
                "user": role,
                "max_tokens": self.get_max_tokens(convo_id=convo_id),
            },
            stream=True,
        )
        if response.status_code != 200:
            raise Exception(
                f"Error: {response.status_code} {response.reason} {response.text}",
            )
        response_role: str = None
        full_response: str = ""
        for line in response.iter_lines():
            if not line:
                continue
            # Remove the "data: " prefix
            line = line.decode("utf-8")[6:]
            if line == "[DONE]":
                break
            resp: dict = json.loads(line)
            choices = resp.get("choices")
            if not choices:
                continue
            delta = choices[0].get("delta")
            if not delta:
                continue
            if "role" in delta:
                response_role = delta["role"]
            if "content" in delta:
                content = delta["content"]
                full_response += content
                yield content
        self.add_to_conversation(
            full_response, response_role, convo_id=convo_id)

    def ask(
        self,
        prompt: str,
        role: str = "user",
        convo_id: str = "default",
        dynamic_system_prompt: str = None,
        **kwargs,
    ) -> str:
        """
        Non-streaming ask
        """
        response = self.ask_stream(
            prompt=prompt,
            role=role,
            convo_id=convo_id,
            dynamic_system_prompt=dynamic_system_prompt,
            **kwargs,
        )
        full_response: str = "".join(response)
        return full_response

    def rollback(self, n: int = 1, convo_id: str = "default") -> None:
        """
        Roll back the conversation by n messages
        """
        for _ in range(n):
            self.conversation[convo_id].pop()

    def reset(self, convo_id: str = "default", system_prompt: str = None) -> None:
        """
        Reset the conversation
        """
        self.conversation[convo_id] = [
            {"role": "system", "content": system_prompt or self.system_prompt},
        ]

    def save(self, file: str, *keys: str) -> None:
        """
        Save the Chatbot configuration to a JSON file
        """
        with open(file, "w", encoding="utf-8") as f:
            json.dump(
                {
                    key: self.__dict__[key]
                    for key in get_filtered_keys_from_object(self, *keys)
                },
                f,
                indent=2,
                # saves the session.proxies dict as "session"
                default=lambda o: o.__dict__["proxies"],
            )

    def load(self, file: str, *keys: str) -> None:
        """
        Load the Chatbot configuration from a JSON file
        """
        with open(file, encoding="utf-8") as f:
            # load the JSON; if "session" is in keys, restore the proxies
            loaded_config = json.load(f)
            keys = get_filtered_keys_from_object(self, *keys)
            if "session" in keys and loaded_config["session"]:
                self.session.proxies = loaded_config["session"]
            keys = keys - {"session"}
            self.__dict__.update({key: loaded_config[key] for key in keys})
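

# A minimal usage sketch for ChatbotWrapper (illustrative assumption, not part
# of the original file; it expects a real OpenAI API key and an engine that is
# present in config.MAX_TOKEN_MODEL_MAP):
#
#     bot = ChatbotWrapper(api_key=os.environ["OPENAI_API_KEY"], max_tokens=1024)
#     for chunk in bot.ask_stream("Hi!", dynamic_system_prompt="You are terse."):
#         print(chunk, end="", flush=True)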


class OpenAIChatbot(ChatbotEngine):
    def __init__(
        self,
        api_key: str,
        engine: str = os.environ.get("GPT_ENGINE") or "gpt-3.5-turbo",
        proxy: str = None,
        max_tokens: int = 3000,
        temperature: float = 0.5,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        reply_count: int = 1,
        system_prompt: str = "You are ChatGPT, a large language model trained by OpenAI. Respond conversationally",
        overhead_token: int = 96,
    ) -> None:
        openai.api_key = api_key
        self.api_key = api_key
        self.engine = engine
        self.proxy = proxy
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.reply_count = reply_count
        self.system_prompt = system_prompt
        self.bot = ChatbotWrapper(
            api_key=self.api_key,
            engine=self.engine,
            proxy=self.proxy,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            reply_count=self.reply_count,
            system_prompt=self.system_prompt,
            overhead_token=overhead_token,
        )
        self.overhead_token = overhead_token
        self.encoding = tiktoken.encoding_for_model(self.engine)

    def encode_length(self, text: str) -> int:
        return len(self.encoding.encode(text))

    def query(
        self,
        questions: str,
        role: str = "user",
        convo_id: str = "default",
        context: str = None,
        **kwargs,
    ):
        return self.bot.ask(
            prompt=questions,
            role=role,
            convo_id=convo_id,
            dynamic_system_prompt=context,
            **kwargs,
        )
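

# A minimal end-to-end sketch (an assumption about how the module is driven,
# not part of the original file); it needs base_class, config, and utils from
# this repo plus an OpenAI API key in the environment.
if __name__ == "__main__":
    chatbot = OpenAIChatbot(api_key=os.environ["OPENAI_API_KEY"], max_tokens=1024)
    # `context` becomes the dynamic system prompt; ask_stream asserts it is not None.
    answer = chatbot.query(
        "Summarize what a context window is in one sentence.",
        context="You are a concise technical assistant.",
    )
    print(answer)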