File size: 3,077 Bytes
28f26e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from anthropic import AnthropicVertex as Anthropic
from dotenv import load_dotenv
from llm_api.utils import get_data_format, get_image_data
from .constants import EXTRACT_INFO_HUMAN_MESSAGE, EXTRACT_INFO_SYSTEM_MESSAGE,FOLLOW_SCHEMA_HUMAN_MESSAGE, FOLLOW_SCHEMA_SYSTEM_MESSAGE

# Load variables from a local .env file; override=True lets values from .env
# replace variables already present in the process environment.
load_dotenv(override=True)
# Shared client used by the request functions in this module. `Anthropic` is
# the `AnthropicVertex` alias from the imports, so requests go through Vertex
# AI. Presumably its project/region settings come from the environment loaded
# above -- confirm. Because it is constructed at import time, importing this
# module may fail if that configuration is missing.
client = Anthropic()

def extract_info(img_paths, schema, *, model="claude-3-5-sonnet@20240620", max_tokens=2048):
    """Extract structured information from images with Claude (via Vertex AI).

    The images are sent together with EXTRACT_INFO_HUMAN_MESSAGE (under the
    EXTRACT_INFO_SYSTEM_MESSAGE system prompt). Claude is forced to reply
    through a single tool whose ``input_schema`` is ``schema``'s JSON schema.
    The tool arguments are then validated into a ``schema`` instance.

    Args:
        img_paths: Iterable of image paths. Each path is passed to
            ``get_data_format`` (used as the ``image/<format>`` media subtype)
            and to ``get_image_data`` (sent as a base64 image source).
        schema: Model class providing ``model_json_schema()`` and
            ``model_validate()`` -- presumably a pydantic v2 ``BaseModel``.
        model: Vertex AI model id to call.
        max_tokens: Upper bound on tokens in Claude's response.

    Returns:
        A validated ``schema`` instance. Returns ``None`` (after printing an
        error) if the response contains no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` raises when the tool arguments do
        not fit ``schema`` (presumably ``pydantic.ValidationError``).
    """
    print('Extracting info via Anthropic...')
    tool_name = "extract_garment_info"
    tools = [
        {
            "name": tool_name,
            "description": "Extracts key information from the image.",
            "input_schema": schema.model_json_schema(),
        }
    ]

    image_messages = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": f"image/{get_data_format(img_path)}",
                "data": get_image_data(img_path),
            },
        }
        for img_path in img_paths
    ]

    system_message = [{"type": "text", "text": EXTRACT_INFO_SYSTEM_MESSAGE}]
    text_messages = [{"type": "text", "text": EXTRACT_INFO_HUMAN_MESSAGE}]
    messages = [{"role": "user", "content": text_messages + image_messages}]

    response = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        system=system_message,
        tools=tools,
        # Force Claude to answer through the extraction tool. Without this it
        # may reply in plain text, and the loop below would find no tool_use.
        tool_choice={"type": "tool", "name": tool_name},
        messages=messages,
    )

    for content in response.content:
        if content.type == 'tool_use':
            print('Found tool_use!')
            return schema.model_validate(content.input)

    print('ERROR: No tool_use found!')
    return None


def follow_structure(json_info, schema, *, model="claude-3-5-sonnet@20240620", max_tokens=2048):
    """Ask Claude (via Vertex AI) to restate ``json_info`` in ``schema``'s shape.

    ``json_info`` is interpolated into FOLLOW_SCHEMA_HUMAN_MESSAGE and sent
    under the FOLLOW_SCHEMA_SYSTEM_MESSAGE system prompt. Claude is forced to
    reply through a single tool whose ``input_schema`` is ``schema``'s JSON
    schema, and the tool arguments are validated into a ``schema`` instance.

    Args:
        json_info: Previously extracted information. It is substituted for
            ``{json_info}`` via ``str.format``, so its ``str()`` form is what
            the model sees.
        schema: Model class providing ``model_json_schema()`` and
            ``model_validate()`` -- presumably a pydantic v2 ``BaseModel``.
        model: Vertex AI model id to call.
        max_tokens: Upper bound on tokens in Claude's response.

    Returns:
        A validated ``schema`` instance. Returns ``None`` (after printing an
        error) if the response contains no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` raises when the tool arguments do
        not fit ``schema`` (presumably ``pydantic.ValidationError``).
    """
    print('Following structure via Anthropic...')
    tool_name = "extract_garment_info"
    tools = [
        {
            "name": tool_name,
            # NOTE(review): the description is the *unformatted*
            # FOLLOW_SCHEMA_HUMAN_MESSAGE template (presumably still holding
            # its {json_info} placeholder), not a description of the tool.
            # It is kept as-is because the model's output shape may depend on
            # it -- confirm whether this is intended.
            "description": FOLLOW_SCHEMA_HUMAN_MESSAGE,
            "input_schema": schema.model_json_schema(),
        }
    ]

    human_message = FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=json_info)
    text_messages = [{"type": "text", "text": human_message}]
    system_message = [{"type": "text", "text": FOLLOW_SCHEMA_SYSTEM_MESSAGE}]
    messages = [{"role": "user", "content": text_messages}]

    response = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        system=system_message,
        tools=tools,
        # Force Claude to answer through the tool. Without this it may reply
        # in plain text, and the loop below would find no tool_use.
        tool_choice={"type": "tool", "name": tool_name},
        messages=messages,
    )

    for content in response.content:
        if content.type == 'tool_use':
            print('Found tool_use!***********************')
            payload = content.input
            # The tool arguments may arrive wrapped as {"json_info": {...}}
            # (presumably prompted by the placeholder name in the
            # description). Unwrap when present. Otherwise validate the
            # arguments directly, which is what input_schema asks for,
            # instead of failing with KeyError.
            if isinstance(payload, dict) and 'json_info' in payload:
                payload = payload['json_info']
            return schema.model_validate(payload)

    print('ERROR: No tool_use found!')
    return None