attribution / llm_api /anthropic_api.py
thanhnt-cf's picture
update known data
17fbdc8 verified
from anthropic import Anthropic, APIStatusError
from dotenv import load_dotenv
from llm_api.utils import get_data_format, get_image_data
from .constants import (
EXTRACT_INFO_HUMAN_MESSAGE,
EXTRACT_INFO_SYSTEM_MESSAGE,
FOLLOW_SCHEMA_HUMAN_MESSAGE,
FOLLOW_SCHEMA_SYSTEM_MESSAGE,
)
# Load variables from a .env file before the client is created.
# override=True lets values from .env replace variables that are already set
# in the process environment.
load_dotenv(override=True)
# Module-level client shared by every function in this file. It is built with
# no arguments, so credentials presumably come from the environment
# (e.g. ANTHROPIC_API_KEY loaded above) -- TODO confirm.
client = Anthropic()
# Model identifier passed to every client.messages.create() call below.
# The commented-out lines are alternative identifiers kept for quick switching.
# claude_model = 'claude-3-5-sonnet-20240620'
claude_model = "claude-3-5-sonnet-latest"
# claude_model = 'claude-3-5-haiku-latest'
# claude_model = 'claude-3-5-haiku-20241022'
# claude_model = 'claude-3-haiku-20240307'
def extract_info(img_paths, schema, known_data=None):
    """Extract structured information from images via an Anthropic tool call.

    The JSON schema of ``schema`` is offered to the model as the input schema
    of a single tool. The first ``tool_use`` block in the reply is validated
    against ``schema`` and returned.

    Args:
        img_paths: Iterable of image paths. Each path is passed to
            ``get_data_format`` and ``get_image_data`` to build one base64
            image content block.
        schema: Class exposing ``model_json_schema()`` and
            ``model_validate()`` (presumably a pydantic model -- confirm).
        known_data: Optional already-known data. When not ``None`` it is
            appended to the prompt as an additional text block.

    Returns:
        A ``(status_code, result)`` tuple:

        * ``(200, <validated schema instance>)`` when a tool call was found.
        * ``(e.status_code, None)`` when the API raised ``APIStatusError``.
        * ``(500, None)`` when the request succeeded but the reply contained
          no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` raises when the tool input does not
        satisfy the schema; that error is not caught here.
    """
    print("Extracting info via Anthropic...")

    tools = [
        {
            "name": "extract_garment_info",
            "description": "Extracts key information from the image.",
            "input_schema": schema.model_json_schema(),
            # Marks the tool definition as a prompt-cache breakpoint.
            "cache_control": {"type": "ephemeral"},
        }
    ]

    # NOTE(review): the media type is built as "image/<format>"; confirm that
    # get_data_format() returns API-accepted names (e.g. "jpeg", not "jpg").
    image_messages = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": f"image/{get_data_format(img_path)}",
                "data": get_image_data(img_path),
            },
        }
        for img_path in img_paths
    ]

    system_message = [{"type": "text", "text": EXTRACT_INFO_SYSTEM_MESSAGE}]

    text_messages = [
        {
            "type": "text",
            "text": EXTRACT_INFO_HUMAN_MESSAGE,
        }
    ]
    if known_data is not None:
        text_messages.append(
            {
                "type": "text",
                "text": f'\nAlso exploit the known data: \n\n"{known_data}"',
            }
        )

    # Text instructions come first, followed by the images they refer to.
    messages = [{"role": "user", "content": text_messages + image_messages}]

    try:
        response = client.messages.create(
            model=claude_model,
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
            max_tokens=2048,
            system=system_message,
            tools=tools,
            messages=messages,
        )
    except APIStatusError as e:
        print(f"{e=}")
        return e.status_code, None

    for content in response.content:
        if content.type == "tool_use":
            print("Found tool_use!")
            return 200, schema.model_validate(content.input)

    print("ERROR: No tool_use found!")
    # Keep the (status_code, result) contract on this path too. Previously the
    # function fell off the end and returned a bare None, which breaks callers
    # that unpack the result. The HTTP call itself succeeded, so 500 is used
    # here only as a generic "unusable response" status.
    return 500, None
def follow_structure(json_info, schema, known_data=None):
    """Re-shape already extracted information so that it follows ``schema``.

    ``json_info`` is inserted into ``FOLLOW_SCHEMA_HUMAN_MESSAGE`` and sent to
    the model together with a single tool whose input schema is the JSON
    schema of ``schema``. The first ``tool_use`` block in the reply is
    validated against ``schema`` and returned.

    Args:
        json_info: Previously extracted information; substituted into the
            ``{json_info}`` placeholder of the human message template.
        schema: Class exposing ``model_json_schema()`` and
            ``model_validate()`` (presumably a pydantic model -- confirm).
        known_data: Optional already-known data. When not ``None`` it is
            appended to the prompt as an additional text block.

    Returns:
        A ``(status_code, result)`` tuple:

        * ``(200, <validated schema instance>)`` when a tool call was found.
        * ``(e.status_code, None)`` when the API raised ``APIStatusError``.
        * ``(500, None)`` when the request succeeded but the reply contained
          no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` raises when the tool input does not
        satisfy the schema; that error is not caught here.
    """
    print("Following structure via Anthropic...")

    tools = [
        {
            "name": "extract_garment_info",
            # NOTE(review): the *unformatted* template (still containing the
            # "{json_info}" placeholder) is used as the tool description.
            # Kept as-is because changing it alters what the model sees;
            # confirm whether this is intentional.
            "description": FOLLOW_SCHEMA_HUMAN_MESSAGE,
            "input_schema": schema.model_json_schema(),
            # Marks the tool definition as a prompt-cache breakpoint.
            "cache_control": {"type": "ephemeral"},
        }
    ]

    # Format the prompt once and reuse it for both the debug output and the
    # request payload.
    human_message = FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=json_info)
    print("DEBUG: human message**************************")
    print(human_message)

    text_messages = [
        {
            "type": "text",
            "text": human_message,
        }
    ]
    if known_data is not None:
        text_messages.append(
            {
                "type": "text",
                "text": f'\nAlso exploit the known data: \n\n"{known_data}"',
            }
        )

    system_message = [{"type": "text", "text": FOLLOW_SCHEMA_SYSTEM_MESSAGE}]
    messages = [{"role": "user", "content": text_messages}]

    try:
        response = client.messages.create(
            model=claude_model,
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
            max_tokens=2048,
            system=system_message,
            tools=tools,
            messages=messages,
        )
    except APIStatusError as e:
        print(f"{e=}")
        return e.status_code, None

    for content in response.content:
        if content.type == "tool_use":
            print("Found tool_use!***********************")
            print(content.input)
            # The model sometimes wraps its answer in a "json_info" key, which
            # is what the original code relied on. When it instead fills the
            # tool input directly according to the schema, fall back to the
            # input itself rather than failing with KeyError.
            payload = content.input
            if isinstance(payload, dict) and "json_info" in payload:
                payload = payload["json_info"]
            return 200, schema.model_validate(payload)

    print("ERROR: No tool_use found!")
    # Keep the (status_code, result) contract on this path too. Previously the
    # function fell off the end and returned a bare None, which breaks callers
    # that unpack the result. The HTTP call itself succeeded, so 500 is used
    # here only as a generic "unusable response" status.
    return 500, None