| from typing import List, Union, Optional, Literal |
| import dataclasses |
|
|
| from tenacity import ( |
| retry, |
| stop_after_attempt, |
| wait_random_exponential, |
| ) |
| import openai |
|
|
# Speaker roles a chat ``Message`` may carry.
MessageRole = Literal[
    "system",
    "user",
    "assistant",
]


@dataclasses.dataclass
class Message:
    """One chat turn: who is speaking (``role``) and what was said (``content``)."""

    role: MessageRole  # one of the MessageRole literals
    content: str  # the text of this turn
|
|
|
|
def message_to_str(message: Message) -> str:
    """Render *message* as a single ``"<role>: <content>"`` line."""
    template = "{}: {}"
    return template.format(message.role, message.content)
|
|
|
|
def messages_to_str(messages: List[Message]) -> str:
    """Render a conversation as newline-separated ``"<role>: <content>"`` lines.

    Each element is formatted with ``message_to_str``; an empty sequence
    yields the empty string.
    """
    rendered_lines = map(message_to_str, messages)
    return "\n".join(rendered_lines)
|
|
|
|
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
    model: str,
    prompt: str,
    max_tokens: int = 1024,
    stop_strs: Optional[List[str]] = None,
    temperature: float = 0.0,
    num_comps=1,
) -> Union[List[str], str]:
    """Request text completion(s) for *prompt* from ``openai.Completion``.

    The whole call is wrapped by tenacity: up to 6 attempts, waiting a
    randomized exponential backoff between 1 and 60 seconds. NOTE(review):
    no ``reraise=True`` is passed, so after the last attempt tenacity
    presumably raises ``tenacity.RetryError`` rather than the original
    exception — confirm against the tenacity docs.

    Args:
        model: Name of the completion model to query.
        prompt: Prompt text to complete.
        max_tokens: Upper bound on generated tokens, forwarded as ``max_tokens``.
        stop_strs: Optional stop sequences, forwarded as ``stop``.
        temperature: Sampling temperature.
        num_comps: Number of completions requested, forwarded as ``n``.

    Returns:
        The text of the first choice when ``num_comps == 1``; otherwise a list
        holding the text of every returned choice.
    """
    request = dict(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
        n=num_comps,
    )
    choices = openai.Completion.create(**request).choices
    if num_comps != 1:
        return [choice.text for choice in choices]
    return choices[0].text
|
|
|
|
def _serialize_chat_messages(messages: List) -> List[dict]:
    """Convert chat messages into the list-of-dicts payload sent to the API.

    Dataclass messages (e.g. ``Message``) are converted with
    ``dataclasses.asdict``; plain dicts are passed through unchanged so callers
    that already hold the serialized form can use it directly.

    Raises:
        TypeError: if an element is neither a dict nor a dataclass instance
            (raised by ``dataclasses.asdict``).
    """
    return [
        message if isinstance(message, dict) else dataclasses.asdict(message)
        for message in messages
    ]


@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def _chat_completion_with_retry(
    model: str,
    payload: List[dict],
    max_tokens: int,
    temperature: float,
    num_comps: int,
) -> Union[List[str], str]:
    """Make one ``openai.ChatCompletion`` call; tenacity retries any exception.

    Up to 6 attempts with randomized exponential backoff between 1 and 180
    seconds. Each failed attempt is printed before being re-raised.
    """
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=payload,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            n=num_comps,
        )
        if num_comps == 1:
            return response.choices[0].message.content
        return [choice.message.content for choice in response.choices]
    except Exception as e:
        # Report every failed attempt; re-raise so tenacity can decide to retry.
        print(f"An error occurred while calling OpenAI: {e}")
        raise


def gpt_chat(
    model: str,
    messages: List,
    max_tokens: int = 1024,
    temperature: float = 0.0,
    num_comps=1,
) -> Union[List[str], str]:
    """Request chat completion(s) for *messages* from ``openai.ChatCompletion``.

    Args:
        model: Name of the chat model to query.
        messages: Conversation as a list of ``Message`` dataclass instances;
            already-serialized ``{"role": ..., "content": ...}`` dicts are
            also accepted.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        num_comps: Number of completions requested, forwarded as ``n``.

    Returns:
        The content of the first choice when ``num_comps == 1``; otherwise a
        list with the content of every returned choice.

    Raises:
        TypeError: immediately (no retries) if a message cannot be serialized.
    """
    # Serialize once, outside the retry loop. A malformed message is a caller
    # bug: it should fail fast instead of being retried with backoff and
    # reported as an OpenAI error. Serializing once also means a one-shot
    # iterable is not exhausted by the first attempt.
    payload = _serialize_chat_messages(messages)
    return _chat_completion_with_retry(model, payload, max_tokens, temperature, num_comps)
|
|
class ModelBase:
    """Common interface shared by the language-model wrappers in this module.

    Subclasses override whichever of ``generate_chat`` / ``generate`` they
    support; the base implementations raise ``NotImplementedError``.
    """

    def __init__(self, name: str):
        # Model identifier; also what ``repr`` shows.
        self.name = name
        # Defaults to False; subclasses using the chat API set it to True.
        self.is_chat = False

    def __repr__(self) -> str:
        return "{}".format(self.name)

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        """Generate reply text for a chat conversation (subclass hook)."""
        raise NotImplementedError

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps=1) -> Union[List[str], str]:
        """Generate completion text for a plain prompt (subclass hook)."""
        raise NotImplementedError
|
|
|
|
class GPTChat(ModelBase):
    """Wrapper for an OpenAI chat model; requests go through ``gpt_chat``."""

    def __init__(self, model_name: str):
        # Let the base class set ``name`` (and its default ``is_chat``), then
        # mark this wrapper as chat-capable.
        super().__init__(model_name)
        self.is_chat = True

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        """Send *messages* to this chat model and return the reply text(s)."""
        return gpt_chat(
            model=self.name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            num_comps=num_comps,
        )
|
|
|
|
class GPT4(GPTChat):
    """Chat wrapper pinned to the ``gpt-4`` model name."""

    def __init__(self):
        super(GPT4, self).__init__(model_name="gpt-4")
|
|
|
|
class GPT35(GPTChat):
    """Chat wrapper pinned to the ``gpt-3.5-turbo`` model name."""

    def __init__(self):
        super(GPT35, self).__init__(model_name="gpt-3.5-turbo")
|
|
|
|
class GPTDavinci(ModelBase):
    """Wrapper for a (non-chat) OpenAI completion model; uses ``gpt_completion``."""

    def __init__(self, model_name: str):
        # Delegate to ModelBase, which sets both ``name`` and ``is_chat``
        # (False). Assigning only ``name`` here would leave ``is_chat``
        # undefined, so any ``model.is_chat`` check would raise AttributeError.
        super().__init__(model_name)

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0, num_comps=1) -> Union[List[str], str]:
        """Complete *prompt* with this model and return the completion text(s).

        Returns a single string when ``num_comps == 1``, otherwise a list of
        strings (as produced by ``gpt_completion``).
        """
        return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)