Spaces:
Sleeping
Sleeping
| from anthropic import AnthropicVertex as Anthropic | |
| from dotenv import load_dotenv | |
| from llm_api.utils import get_data_format, get_image_data | |
| from .constants import EXTRACT_INFO_HUMAN_MESSAGE, EXTRACT_INFO_SYSTEM_MESSAGE,FOLLOW_SCHEMA_HUMAN_MESSAGE, FOLLOW_SCHEMA_SYSTEM_MESSAGE | |
# Load environment variables from a .env file (python-dotenv's default
# lookup). override=True means values from the file replace any that are
# already set in the process environment.
load_dotenv(override=True)
# Module-wide client used by both extract_info and follow_structure.
# NOTE: "Anthropic" is an alias for AnthropicVertex (see the import), so
# requests go through Google Cloud Vertex AI; presumably the GCP project and
# region are picked up from the environment loaded above — TODO confirm.
client = Anthropic()
def _image_media_type(img_path):
    """Return the ``image/<subtype>`` media type for one image path.

    ``get_data_format`` is defined elsewhere. Its result is lower-cased, and
    ``jpg`` is mapped to ``jpeg``, so values such as ``"JPEG"`` or ``"jpg"``
    become the registered ``image/jpeg`` type that the Messages API accepts.
    Values that are already lower-case and registered (e.g. ``"png"``) pass
    through unchanged.
    """
    fmt = str(get_data_format(img_path)).lower()
    if fmt == "jpg":
        fmt = "jpeg"
    return f"image/{fmt}"


def extract_info(img_paths, schema, *, model="claude-3-5-sonnet@20240620",
                 max_tokens=2048):
    """Extract structured information from one or more images via Claude.

    All images are sent in a single user turn together with
    EXTRACT_INFO_HUMAN_MESSAGE (system prompt: EXTRACT_INFO_SYSTEM_MESSAGE).
    The model is forced to answer by calling a tool whose input schema is
    ``schema.model_json_schema()``. The tool arguments are then validated
    back into ``schema``.

    Args:
        img_paths: Iterable of image paths, or a single path. Each path is
            passed to ``get_data_format`` and ``get_image_data``.
        schema: Model class describing the fields to extract. It must
            provide ``model_json_schema`` and ``model_validate`` (presumably
            a Pydantic v2 ``BaseModel`` subclass).
        model: Vertex AI model identifier to query.
        max_tokens: Upper bound on tokens the model may generate.

    Returns:
        An instance of ``schema`` built from the tool call, or ``None`` if
        the response contains no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` raises when the tool arguments do
        not fit the schema. Errors from ``client.messages.create`` propagate
        unchanged.
    """
    import os  # local: only needed for the single-path check below

    print('Extracting info via Anthropic...')
    if isinstance(img_paths, (str, os.PathLike)):
        # A lone path would otherwise be iterated character by character.
        img_paths = [img_paths]

    tool_name = "extract_garment_info"
    tools = [
        {
            "name": tool_name,
            "description": "Extracts key information from the image.",
            "input_schema": schema.model_json_schema(),
        }
    ]
    image_messages = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": _image_media_type(img_path),
                "data": get_image_data(img_path),
            },
        }
        for img_path in img_paths
    ]
    system_message = [{"type": "text", "text": EXTRACT_INFO_SYSTEM_MESSAGE}]
    text_messages = [{"type": "text", "text": EXTRACT_INFO_HUMAN_MESSAGE}]
    messages = [
        {"role": "user", "content": text_messages + image_messages}
    ]

    response = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        system=system_message,
        tools=tools,
        # Force the answer through the tool. Without this the model may reply
        # in plain text, and the function would fall through to return None.
        tool_choice={"type": "tool", "name": tool_name},
        messages=messages,
    )
    if response.stop_reason == "max_tokens":
        # A truncated tool call usually yields incomplete arguments that
        # then fail schema validation; make that cause visible.
        print('WARNING: response hit max_tokens; tool input may be truncated.')

    for content in response.content:
        if content.type == 'tool_use':
            print('Found tool_use!')
            return schema.model_validate(content.input)
    print('ERROR: No tool_use found!')
    return None
def _validate_follow_structure_input(tool_input, schema):
    """Validate the arguments of a ``follow_structure`` tool call into ``schema``.

    The model may nest its answer under a ``json_info`` key (the name of the
    placeholder in FOLLOW_SCHEMA_HUMAN_MESSAGE), or it may return the
    schema's fields directly. Both shapes are accepted:

    * ``{"json_info": ...}`` is unwrapped, unless ``schema`` itself declares
      a ``json_info`` field;
    * anything else is validated as-is.

    If the payload arrives as a JSON string rather than an object, it is
    parsed with ``schema.model_validate_json``.
    """
    payload = tool_input
    if (
        isinstance(payload, dict)
        and "json_info" in payload
        and "json_info" not in schema.model_fields
    ):
        payload = payload["json_info"]
    if isinstance(payload, str):
        return schema.model_validate_json(payload)
    return schema.model_validate(payload)


def follow_structure(json_info, schema, *, model="claude-3-5-sonnet@20240620",
                     max_tokens=2048, debug=False):
    """Ask Claude to restate ``json_info`` so that it conforms to ``schema``.

    ``json_info`` is substituted into FOLLOW_SCHEMA_HUMAN_MESSAGE with
    ``str.format`` and sent with FOLLOW_SCHEMA_SYSTEM_MESSAGE as the system
    prompt. The model is forced to reply through a tool whose input schema
    is ``schema.model_json_schema()``.

    Args:
        json_info: Information to restructure. It is inserted into the prompt
            via ``str.format``, so its ``str()`` form is what the model sees
            (presumably a JSON string or dict from ``extract_info`` —
            confirm against callers).
        schema: Model class the result must follow. It must provide
            ``model_json_schema``, ``model_validate``, ``model_validate_json``
            and ``model_fields`` (presumably a Pydantic v2 ``BaseModel``
            subclass).
        model: Vertex AI model identifier to query.
        max_tokens: Upper bound on tokens the model may generate.
        debug: If true, print the filled prompt and the raw tool input.

    Returns:
        An instance of ``schema``, or ``None`` if the response contains no
        ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` / ``model_validate_json`` raise
        when the tool arguments do not fit. Errors from
        ``client.messages.create`` propagate unchanged.
    """
    print('Following structure via Anthropic...')
    tool_name = "extract_garment_info"
    # Fill the template once and reuse it for the debug dump and the request.
    human_message = FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=json_info)
    tools = [
        {
            "name": tool_name,
            # Use a plain description rather than FOLLOW_SCHEMA_HUMAN_MESSAGE
            # itself. That constant is a format template, and passing it
            # unformatted would leak its raw placeholder into the tool spec.
            "description": (
                "Returns the provided information restructured to "
                "follow the required schema."
            ),
            "input_schema": schema.model_json_schema(),
        }
    ]
    if debug:
        print('DEBUG: human message**************************')
        print(human_message)

    system_message = [{"type": "text", "text": FOLLOW_SCHEMA_SYSTEM_MESSAGE}]
    messages = [
        {"role": "user", "content": [{"type": "text", "text": human_message}]}
    ]

    response = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        system=system_message,
        tools=tools,
        # Force the answer through the tool so a plain-text reply cannot
        # silently produce None.
        tool_choice={"type": "tool", "name": tool_name},
        messages=messages,
    )
    if response.stop_reason == "max_tokens":
        # Truncated tool arguments will most likely fail validation below.
        print('WARNING: response hit max_tokens; tool input may be truncated.')

    for content in response.content:
        if content.type == 'tool_use':
            print('Found tool_use!***********************')
            if debug:
                print(content.input)
            return _validate_follow_structure_input(content.input, schema)
    print('ERROR: No tool_use found!')
    return None