Spaces:
Sleeping
Sleeping
| from anthropic import Anthropic, APIStatusError | |
| from dotenv import load_dotenv | |
| from llm_api.utils import get_data_format, get_image_data | |
| from .constants import ( | |
| EXTRACT_INFO_HUMAN_MESSAGE, | |
| EXTRACT_INFO_SYSTEM_MESSAGE, | |
| FOLLOW_SCHEMA_HUMAN_MESSAGE, | |
| FOLLOW_SCHEMA_SYSTEM_MESSAGE, | |
| ) | |
# Load variables from a .env file before the client is built; override=True
# lets values from the file replace variables already set in the process
# environment.
load_dotenv(override=True)

# Module-wide client shared by every request below. It is constructed without
# arguments, so it presumably picks its credentials up from the environment
# populated above (TODO confirm against the deployment configuration).
client = Anthropic()

# Model used for all requests in this module. The commented-out assignments
# are alternative model identifiers kept for quick switching.
# claude_model = 'claude-3-5-sonnet-20240620'
claude_model = "claude-3-5-sonnet-latest"
# claude_model = 'claude-3-5-haiku-latest'
# claude_model = 'claude-3-5-haiku-20241022'
# claude_model = 'claude-3-haiku-20240307'
def extract_info(img_paths, schema, known_data=None):
    """Extract structured information from images through an Anthropic tool call.

    Every path in ``img_paths`` is attached to a single user message as a
    base64 image block, and ``schema`` is exposed to the model as the input
    schema of the ``extract_garment_info`` tool. The arguments of the first
    ``tool_use`` block in the response are validated against ``schema``.

    Args:
        img_paths: Iterable of image paths; each one is handed to
            ``get_data_format`` and ``get_image_data``.
        schema: Class providing ``model_json_schema()`` and
            ``model_validate()`` (presumably a pydantic model class).
        known_data: Optional extra context. When not ``None`` it is appended
            to the prompt as an additional text block.

    Returns:
        A ``(status_code, result)`` tuple:

        * ``(200, validated_instance)`` when a ``tool_use`` block was found.
        * ``(e.status_code, None)`` when the API raised ``APIStatusError``.
        * ``(500, None)`` when the response contained no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` raises when the tool arguments do
        not satisfy the schema; that error is intentionally not swallowed.
    """
    print("Extracting info via Anthropic...")

    tools = [
        {
            "name": "extract_garment_info",
            "description": "Extracts key information from the image.",
            "input_schema": schema.model_json_schema(),
            "cache_control": {"type": "ephemeral"},
        }
    ]
    system_message = [{"type": "text", "text": EXTRACT_INFO_SYSTEM_MESSAGE}]

    text_messages = [{"type": "text", "text": EXTRACT_INFO_HUMAN_MESSAGE}]
    if known_data is not None:
        text_messages.append(
            {
                "type": "text",
                "text": f'\nAlso exploit the known data: \n\n"{known_data}"',
            }
        )

    image_messages = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": f"image/{get_data_format(img_path)}",
                "data": get_image_data(img_path),
            },
        }
        for img_path in img_paths
    ]

    # Text blocks come first, followed by one image block per path.
    messages = [{"role": "user", "content": text_messages + image_messages}]

    try:
        response = client.messages.create(
            model=claude_model,
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
            max_tokens=2048,
            system=system_message,
            tools=tools,
            messages=messages,
        )
    except APIStatusError as e:
        print(f"{e=}")
        return e.status_code, None

    tool_call = next(
        (block for block in response.content if block.type == "tool_use"),
        None,
    )
    if tool_call is None:
        print("ERROR: No tool_use found!")
        # Previously this path fell off the end and returned a bare None,
        # which breaks callers that unpack ``status, result``. Report it as a
        # non-200 status so it is handled like any other failed request.
        return 500, None

    print("Found tool_use!")
    return 200, schema.model_validate(tool_call.input)
def follow_structure(json_info, schema, known_data=None):
    """Coerce already-extracted information into ``schema`` via an Anthropic tool call.

    ``json_info`` is substituted into ``FOLLOW_SCHEMA_HUMAN_MESSAGE`` and sent
    as the user message; ``schema`` is exposed to the model as the input
    schema of the ``extract_garment_info`` tool. The arguments of the first
    ``tool_use`` block in the response are validated against ``schema``.

    Args:
        json_info: Value formatted into the human message template under the
            ``json_info`` name.
        schema: Class providing ``model_json_schema()`` and
            ``model_validate()`` (presumably a pydantic model class).
        known_data: Optional extra context. When not ``None`` it is appended
            to the prompt as an additional text block.

    Returns:
        A ``(status_code, result)`` tuple:

        * ``(200, validated_instance)`` when a ``tool_use`` block was found.
        * ``(e.status_code, None)`` when the API raised ``APIStatusError``.
        * ``(500, None)`` when the response contained no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` raises when the tool arguments do
        not satisfy the schema; that error is intentionally not swallowed.
    """
    print("Following structure via Anthropic...")

    # NOTE(review): the tool description is the *unformatted* human message
    # template, so ``json_info`` is not substituted into it. Kept as-is
    # because changing the prompt may change model output -- confirm intent.
    tools = [
        {
            "name": "extract_garment_info",
            "description": FOLLOW_SCHEMA_HUMAN_MESSAGE,
            "input_schema": schema.model_json_schema(),
            "cache_control": {"type": "ephemeral"},
        }
    ]

    # Format once and reuse for both the debug output and the request.
    human_message = FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=json_info)
    print("DEBUG: human message**************************")
    print(human_message)

    text_messages = [{"type": "text", "text": human_message}]
    if known_data is not None:
        text_messages.append(
            {
                "type": "text",
                "text": f'\nAlso exploit the known data: \n\n"{known_data}"',
            }
        )

    system_message = [{"type": "text", "text": FOLLOW_SCHEMA_SYSTEM_MESSAGE}]
    messages = [{"role": "user", "content": text_messages}]

    try:
        response = client.messages.create(
            model=claude_model,
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
            max_tokens=2048,
            system=system_message,
            tools=tools,
            messages=messages,
        )
    except APIStatusError as e:
        print(f"{e=}")
        return e.status_code, None

    tool_call = next(
        (block for block in response.content if block.type == "tool_use"),
        None,
    )
    if tool_call is None:
        print("ERROR: No tool_use found!")
        # Previously this path fell off the end and returned a bare None,
        # which breaks callers that unpack ``status, result``. Report it as a
        # non-200 status so it is handled like any other failed request.
        return 500, None

    print("Found tool_use!***********************")
    tool_input = tool_call.input
    print(tool_input)

    # The arguments have been read from a "json_info" wrapper key so far, and
    # that still takes precedence. The tool's input schema is ``schema``
    # itself, though, so the model may return the fields unwrapped; validate
    # the arguments directly in that case instead of raising KeyError.
    if isinstance(tool_input, dict) and "json_info" in tool_input:
        payload = tool_input["json_info"]
    else:
        payload = tool_input
    return 200, schema.model_validate(payload)