attribution / llm_api /openai_api.py
thanhnt-cf's picture
initial commit
7a38b33
raw
history blame
2.66 kB
import json
from dotenv import load_dotenv
from .utils import get_data_format, get_image_data
from openai import OpenAI, BadRequestError
from .exceptions import RefusalError
from .constants import EXTRACT_INFO_HUMAN_MESSAGE, EXTRACT_INFO_SYSTEM_MESSAGE,FOLLOW_SCHEMA_HUMAN_MESSAGE, FOLLOW_SCHEMA_SYSTEM_MESSAGE
# Load variables from a .env file before the client is constructed.
# override=True lets values from .env replace variables that are already set
# in the process environment.
load_dotenv(override=True)
# Module-level client used by extract_info() and follow_structure(). No
# credentials are passed explicitly, so the SDK is presumably reading them from
# the environment (e.g. OPENAI_API_KEY) -- hence load_dotenv() runs first.
client = OpenAI()
def extract_info(img_paths, schema):
    """Extract structured information from images via the OpenAI chat API.

    Sends one user message containing the ``EXTRACT_INFO_HUMAN_MESSAGE`` text
    followed by every image in ``img_paths`` (inlined as ``data:`` URLs),
    together with ``EXTRACT_INFO_SYSTEM_MESSAGE`` as the system prompt, and
    asks the model to answer in the shape of ``schema``.

    Args:
        img_paths: Iterable of image locations. Each item is handed to the
            ``get_data_format`` and ``get_image_data`` helpers.
        schema: Class passed as ``response_format`` and then used to validate
            the reply through ``schema.model_validate`` (presumably a pydantic
            model class -- confirm against callers).

    Returns:
        The object produced by ``schema.model_validate`` for the model's reply.

    Raises:
        RefusalError: If the returned message carries a refusal.
    """
    print('Extracting info via OpenAI...')
    text_content = [
        {
            "type": "text",
            "text": EXTRACT_INFO_HUMAN_MESSAGE,
        },
    ]
    # A data URL has the form ``data:<media type>;base64,<payload>``. The
    # media subtype must come from get_data_format(); previously the base64
    # payload itself was interpolated into that slot, producing a malformed
    # URL and encoding every image twice.
    # NOTE(review): assumes get_data_format() returns a bare subtype such as
    # "jpeg" or "png" and get_image_data() returns base64 text -- confirm in
    # .utils.
    image_content = [
        {
            "type": "image_url",
            "image_url": {
                "url": (
                    f"data:image/{get_data_format(img_path)};"
                    f"base64,{get_image_data(img_path)}"
                ),
            },
        }
        for img_path in img_paths
    ]
    response = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[
            {
                "role": "system",
                "content": EXTRACT_INFO_SYSTEM_MESSAGE,
            },
            {
                "role": "user",
                # Text instruction first, then the images it refers to.
                "content": text_content + image_content,
            },
        ],
        max_tokens=1000,
        response_format=schema,
        # Log-probabilities are requested from the API, but this function
        # only returns the validated content.
        logprobs=True,
        top_logprobs=2,
        temperature=0.0,
    )
    message = response.choices[0].message
    # Check for a refusal before attempting to parse the content as JSON.
    if message.refusal:
        raise RefusalError('OpenAI refused to respond to the request')
    parsed_data = json.loads(message.content)
    return schema.model_validate(parsed_data)
def follow_structure(json_info, schema):
    """Ask the OpenAI chat API to restate ``json_info`` in the shape of ``schema``.

    ``json_info`` is substituted into ``FOLLOW_SCHEMA_HUMAN_MESSAGE`` (via its
    ``json_info`` placeholder) and sent as the user message, with
    ``FOLLOW_SCHEMA_SYSTEM_MESSAGE`` as the system prompt.

    Args:
        json_info: Value interpolated into the human-message template.
        schema: Class passed as ``response_format`` and then used to validate
            the reply through ``schema.model_validate`` (presumably a pydantic
            model class -- confirm against callers).

    Returns:
        The object produced by ``schema.model_validate`` for the model's reply.

    Raises:
        RefusalError: If the returned message carries a refusal.
    """
    print('Following structure via OpenAI...')

    user_prompt = FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=json_info)
    system_message = {"role": "system", "content": FOLLOW_SCHEMA_SYSTEM_MESSAGE}
    user_message = {
        "role": "user",
        "content": [{"type": "text", "text": user_prompt}],
    }

    completion = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",
        messages=[system_message, user_message],
        max_tokens=1000,
        response_format=schema,
        # Log-probabilities are requested from the API, but only the validated
        # content is returned to the caller.
        logprobs=True,
        top_logprobs=2,
        temperature=0.0,
    )

    reply = completion.choices[0].message
    # A refusal is surfaced as an error before any JSON parsing is attempted.
    if reply.refusal:
        raise RefusalError('OpenAI refused to respond to the request')

    return schema.model_validate(json.loads(reply.content))