|
|
from typing import List |
|
|
import openai |
|
|
from tenacity import retry, stop_after_attempt, wait_random_exponential |
|
|
|
|
|
from models.Base import BaseModel |
|
|
|
|
|
|
|
|
class OpenAIModel(BaseModel):
    """Chat-completion client for an Azure-hosted OpenAI deployment.

    Routes requests through an API-management gateway that authenticates via
    an ``Ocp-Apim-Subscription-Key`` header rather than the SDK's own
    ``api_key`` field.
    """

    def __init__(self,
                 model_id="GPT4o",
                 model_api_version='2024-06-01',
                 api_key=None):
        """Initialize the Azure OpenAI client for one deployment.

        Args:
            model_id: Name of the Azure deployment to target (e.g. "GPT4o").
            model_api_version: Azure OpenAI REST API version string.
            api_key: Gateway subscription key; required.

        Raises:
            ValueError: If ``api_key`` is not provided.
        """
        # `assert` is stripped under `python -O`, so validate explicitly.
        if api_key is None:
            raise ValueError("no api key is provided.")

        self.model_id = model_id
        self.model_api_version = model_api_version

        url = 'https://llm-api.amd.com'
        # The gateway authenticates via this header; the SDK-level api_key
        # below is therefore a non-empty placeholder only.
        headers = {
            'Ocp-Apim-Subscription-Key': api_key
        }

        self.client = openai.AzureOpenAI(
            api_key='dummy',
            api_version=self.model_api_version,
            base_url=url,
            default_headers=headers
        )
        # Point the client directly at the deployment-specific endpoint so the
        # SDK does not append its own /deployments/<model> path handling.
        self.client.base_url = '{0}/openai/deployments/{1}'.format(url, self.model_id)

    @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(5))
    def generate(self,
                 messages: List,
                 temperature=0,
                 presence_penalty=0,
                 frequency_penalty=0,
                 max_tokens=5000) -> str:
        """Run one non-streaming chat completion and return the reply text.

        Retries up to 5 times with randomized exponential backoff (1-60 s)
        on any exception, via tenacity.

        Args:
            messages: Chat history in OpenAI message-dict format
                (``[{"role": ..., "content": ...}, ...]``).
            temperature: Sampling temperature (0 = deterministic-ish).
            presence_penalty: Penalty for tokens already present.
            frequency_penalty: Penalty proportional to token frequency.
            max_tokens: Upper bound on generated tokens.

        Returns:
            The content of the first (and only, ``n=1``) completion choice.

        Raises:
            ValueError: If the API returns no choices.
        """
        response = self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=temperature,
            n=1,
            stream=False,
            stop=None,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=None,
            user=None
        )

        # Guard against gateway responses that are empty or malformed.
        if not response or not hasattr(response, 'choices') or len(response.choices) == 0:
            raise ValueError("No response choices returned from the API.")

        return response.choices[0].message.content
|
|
|