import linecache
import re
from typing import Dict, List, Optional, Union

import openai
class ChatCompletion:
    """Small stateful wrapper around the legacy ``openai.ChatCompletion`` API.

    The instance stores system prompts ("settings") and user messages and
    resends all of them on every request: all system prompts first, then all
    user messages, each group in insertion order. Assistant replies are not
    stored, so they are never sent back to the model.
    """

    def __init__(self, model: str = 'gpt-3.5-turbo',
                 api_key: Optional[str] = None,
                 api_key_path: str = './openai_api_key',
                 max_context_chars: int = 2048):
        """Configure the default model and the global ``openai.api_key``.

        Args:
            model: Default model name used when a call passes no ``model``.
            api_key: OpenAI API key. When ``None``, the key is read from the
                second line of ``api_key_path``.
            api_key_path: File holding the key on its second line (line 1 is
                presumably a label/header -- confirm against the key file).
            max_context_chars: Stored-history size, in characters (not tokens),
                above which ``chat`` clears the whole history before adding
                the next message.

        Raises:
            EnvironmentError: if no key was given and none could be read from
                ``api_key_path`` (missing file, missing or blank second line).
        """
        if api_key is None:
            # linecache.getline returns '' for a missing file or line; strip()
            # also removes '\r' / stray spaces that would corrupt the key.
            api_key = linecache.getline(api_key_path, 2).strip()
            if not api_key:
                raise EnvironmentError(
                    f'No OpenAI API key given and none found on line 2 of '
                    f'{api_key_path!r}')
        # Module-level global: shared by every instance and every openai call.
        openai.api_key = api_key
        self.model = model
        self.max_context_chars = max_context_chars
        self.system_messages: List[str] = []
        self.user_messages: List[str] = []

    def chat(self, msg: str, setting: Optional[str] = None,
             model: Optional[str] = None) -> Union[str, Exception]:
        """Record a user message (and optional system prompt), then query the model.

        Args:
            msg: User message. It is not appended if it equals the most recent
                user message, so sending the same text twice re-runs the same
                history.
            setting: Optional system prompt; added only if not already stored.
            model: Model override for this call only.

        Returns:
            The reply text with leading newlines removed, or the
            ``openai.error.OpenAIError`` instance if the API call failed.
        """
        # Crude size guard on characters, not tokens. It drops *everything*,
        # system prompts included, once the stored history is too long.
        if self._context_length() > self.max_context_chars:
            self.reset()
        if setting is not None and setting not in self.system_messages:
            self.system_messages.append(setting)
        if not self.user_messages or msg != self.user_messages[-1]:
            self.user_messages.append(msg)
        return self._run(model)

    def retry(self, model: Optional[str] = None) -> Union[str, Exception]:
        """Re-send the current history unchanged; same return contract as ``chat``."""
        return self._run(model)

    def reset(self) -> None:
        """Forget all stored system prompts and user messages."""
        self.system_messages.clear()
        self.user_messages.clear()

    def _make_message(self) -> List[Dict[str, str]]:
        """Build the ``messages`` payload: system prompts first, then user messages."""
        sys_messages = [{'role': 'system', 'content': m} for m in self.system_messages]
        user_messages = [{'role': 'user', 'content': m} for m in self.user_messages]
        return sys_messages + user_messages

    def _context_length(self) -> int:
        """Return the total number of characters in the stored history."""
        return (sum(len(m) for m in self.system_messages)
                + sum(len(m) for m in self.user_messages))

    def _run(self, model: Optional[str] = None) -> Union[str, Exception]:
        """Send the stored history and return the first choice's text.

        ``openai.error.OpenAIError`` raised by the request is returned, not
        raised. Any other error (e.g. an unexpected response shape) propagates
        to the caller instead of being printed and masked.
        """
        if model is None:
            model = self.model
        try:
            response = openai.ChatCompletion.create(
                model=model, messages=self._make_message())
        except openai.error.OpenAIError as e:
            return e
        content = response['choices'][0]['message']['content']
        # Equivalent to re.sub(r'^\n+', '', content); tolerate a null content.
        return (content or '').lstrip('\n')

    def __call__(self, msg: str, setting: Optional[str] = None,
                 model: Optional[str] = None) -> Union[str, Exception]:
        """Alias for ``chat``."""
        return self.chat(msg, setting, model)