# attribution/llm_api/anthropic_vertex_api.py
# Last change by thanhnt-cf in commit 28f26e5: "update to support user schema"
from anthropic import AnthropicVertex as Anthropic
from dotenv import load_dotenv
from llm_api.utils import get_data_format, get_image_data
from .constants import EXTRACT_INFO_HUMAN_MESSAGE, EXTRACT_INFO_SYSTEM_MESSAGE,FOLLOW_SCHEMA_HUMAN_MESSAGE, FOLLOW_SCHEMA_SYSTEM_MESSAGE
# Load variables from a local .env file (override=True lets .env values replace
# ones already set in the process environment). This must run before the client
# below is constructed, presumably so the Vertex client picks up its
# project/region settings from .env — confirm which variables it reads.
load_dotenv(override=True)
# Single module-level client, created once at import time and shared by
# extract_info and follow_structure.
client = Anthropic()
def extract_info(
    img_paths,
    schema,
    *,
    model="claude-3-5-sonnet@20240620",
    max_tokens=2048,
):
    """Extract structured information from images using Claude on Vertex AI.

    The images are sent along with EXTRACT_INFO_HUMAN_MESSAGE (and
    EXTRACT_INFO_SYSTEM_MESSAGE as the system prompt). The model is forced to
    answer by calling a single tool whose input schema is derived from
    ``schema``. The tool's input is then validated into ``schema``.

    Args:
        img_paths: Iterable of image file paths. Each path is passed to
            ``get_data_format`` (used as the ``image/<format>`` media type) and
            to ``get_image_data`` (presumably returns base64-encoded text;
            confirm).
        schema: Pydantic-style model class exposing ``model_json_schema()``
            and ``model_validate()``.
        model: Vertex model identifier to call.
        max_tokens: Maximum number of tokens the model may generate.

    Returns:
        An instance of ``schema`` built from the tool-call input, or ``None``
        if the response contains no ``tool_use`` block.
    """
    print('Extracting info via Anthropic...')
    tool_name = "extract_garment_info"
    tools = [
        {
            "name": tool_name,
            "description": "Extracts key information from the image.",
            "input_schema": schema.model_json_schema(),
        }
    ]
    image_messages = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": f"image/{get_data_format(img_path)}",
                "data": get_image_data(img_path),
            },
        }
        for img_path in img_paths
    ]
    system_message = [
        {
            "type": "text",
            "text": EXTRACT_INFO_SYSTEM_MESSAGE,
        }
    ]
    text_messages = [
        {
            "type": "text",
            "text": EXTRACT_INFO_HUMAN_MESSAGE,
        }
    ]
    messages = [
        {
            "role": "user",
            "content": text_messages + image_messages,
        }
    ]
    response = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        system=system_message,
        tools=tools,
        # Force the answer through the tool. Without this the model may reply
        # in plain text, leaving no tool_use block to parse.
        tool_choice={"type": "tool", "name": tool_name},
        messages=messages,
    )
    for content in response.content:
        if content.type == 'tool_use':
            print('Found tool_use!')
            return schema.model_validate(content.input)
    print('ERROR: No tool_use found!')
    return None
def follow_structure(
    json_info,
    schema,
    *,
    model="claude-3-5-sonnet@20240620",
    max_tokens=2048,
):
    """Reshape previously extracted information to match ``schema`` via Claude.

    ``json_info`` is substituted into FOLLOW_SCHEMA_HUMAN_MESSAGE (via
    ``str.format``). The model is forced to answer by calling a single tool
    whose input schema is derived from ``schema``. The tool's input is then
    validated into ``schema``.

    Args:
        json_info: Data substituted for ``{json_info}`` in the human message;
            presumably a JSON string or dict from a previous extraction
            (confirm against callers).
        schema: Pydantic-style model class exposing ``model_json_schema()``,
            ``model_validate()`` and ``model_validate_json()``.
        model: Vertex model identifier to call.
        max_tokens: Maximum number of tokens the model may generate.

    Returns:
        An instance of ``schema`` built from the tool-call input, or ``None``
        if the response contains no ``tool_use`` block.
    """
    print('Following structure via Anthropic...')
    tool_name = "extract_garment_info"
    tools = [
        {
            "name": tool_name,
            # NOTE(review): the *unformatted* human-message template is used as
            # the tool description; confirm this is intended.
            "description": FOLLOW_SCHEMA_HUMAN_MESSAGE,
            "input_schema": schema.model_json_schema(),
        }
    ]
    # Format the prompt once and reuse the result.
    human_message = FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=json_info)
    text_messages = [
        {
            "type": "text",
            "text": human_message,
        }
    ]
    system_message = [
        {
            "type": "text",
            "text": FOLLOW_SCHEMA_SYSTEM_MESSAGE,
        }
    ]
    messages = [
        {
            "role": "user",
            "content": text_messages,
        }
    ]
    response = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        system=system_message,
        tools=tools,
        # Force the answer through the tool. Without this the model may reply
        # in plain text, leaving no tool_use block to parse.
        tool_choice={"type": "tool", "name": tool_name},
        messages=messages,
    )
    for content in response.content:
        if content.type == 'tool_use':
            print('Found tool_use!')
            payload = content.input
            # The tool input should follow `schema` directly. The model may,
            # however, wrap it under a "json_info" key (the original code always
            # expected that wrapper), so unwrap it when present instead of
            # failing with KeyError when it is absent.
            if isinstance(payload, dict):
                payload = payload.get('json_info', payload)
            # A wrapped value may arrive as a JSON-encoded string.
            if isinstance(payload, str):
                return schema.model_validate_json(payload)
            return schema.model_validate(payload)
    print('ERROR: No tool_use found!')
    return None