File size: 3,973 Bytes
17fbdc8
44e85d0
17fbdc8
7a38b33
17fbdc8
 
 
 
 
 
 
7a38b33
44e85d0
7a38b33
17fbdc8
 
 
 
 
 
7a38b33
17fbdc8
 
7a38b33
 
 
 
 
17fbdc8
7a38b33
 
 
 
 
 
 
 
 
17fbdc8
 
7a38b33
17fbdc8
7a38b33
 
17fbdc8
7a38b33
17fbdc8
7a38b33
17fbdc8
 
7a38b33
 
 
17fbdc8
 
 
 
 
 
 
 
 
 
28f26e5
 
17fbdc8
 
28f26e5
 
 
17fbdc8
28f26e5
 
17fbdc8
28f26e5
7a38b33
 
17fbdc8
 
28f26e5
7a38b33
17fbdc8
7a38b33
17fbdc8
 
 
7a38b33
 
 
 
 
17fbdc8
7a38b33
 
 
17fbdc8
7a38b33
17fbdc8
7a38b33
17fbdc8
 
7a38b33
 
 
17fbdc8
 
 
 
 
 
 
 
 
 
 
28f26e5
 
17fbdc8
 
28f26e5
 
 
17fbdc8
28f26e5
 
17fbdc8
28f26e5
7a38b33
 
17fbdc8
 
7a38b33
17fbdc8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from anthropic import Anthropic, APIStatusError
from dotenv import load_dotenv

from llm_api.utils import get_data_format, get_image_data

from .constants import (
    EXTRACT_INFO_HUMAN_MESSAGE,
    EXTRACT_INFO_SYSTEM_MESSAGE,
    FOLLOW_SCHEMA_HUMAN_MESSAGE,
    FOLLOW_SCHEMA_SYSTEM_MESSAGE,
)

# Load variables from a .env file into the process environment; override=True
# makes values from .env replace variables that are already set.
load_dotenv(override=True)
# One module-level client, created at import time and shared by
# extract_info() and follow_structure().
# NOTE(review): presumably authenticates via ANTHROPIC_API_KEY from the
# environment / .env loaded above — confirm.
client = Anthropic()
# Model ID used for every request in this module. The "-latest" suffix is an
# alias that presumably moves to newer snapshots over time; pin a dated ID
# (e.g. the commented-out one just below) if reproducible output matters.
# The commented-out assignments are alternative models kept for quick switching.
# claude_model = 'claude-3-5-sonnet-20240620'
claude_model = "claude-3-5-sonnet-latest"
# claude_model = 'claude-3-5-haiku-latest'
# claude_model = 'claude-3-5-haiku-20241022'
# claude_model = 'claude-3-haiku-20240307'


def extract_info(img_paths, schema, known_data=None):
    """Extract structured information from images via a Claude tool call.

    The extraction prompt and every image are sent in a single user message.
    The model is required to answer by calling an ``extract_garment_info``
    tool whose input schema is ``schema``'s JSON schema. The tool input is
    then validated into a ``schema`` instance.

    Args:
        img_paths: Iterable of image file paths. Each image is base64-encoded
            with ``get_image_data`` and sent with media type
            ``image/<get_data_format(path)>``.
        schema: Model class exposing ``model_json_schema()`` and
            ``model_validate()`` (appears to be a Pydantic v2 model).
        known_data: Optional extra information appended to the prompt.

    Returns:
        A ``(status_code, result)`` tuple:

        * ``(200, <schema instance>)`` on success;
        * ``(e.status_code, None)`` if the API returned an error status;
        * ``(502, None)`` if the API answered but the reply contained no
          ``tool_use`` block. This keeps the tuple shape, so callers can
          always unpack the result.

    Raises:
        Whatever ``schema.model_validate`` raises if the tool input does not
        match the schema (presumably ``pydantic.ValidationError``).
    """
    print("Extracting info via Anthropic...")
    tool_name = "extract_garment_info"
    tools = [
        {
            "name": tool_name,
            "description": "Extracts key information from the image.",
            "input_schema": schema.model_json_schema(),
            # Mark the (large, static) tool definition for prompt caching.
            "cache_control": {"type": "ephemeral"},
        }
    ]

    image_messages = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                # NOTE(review): the API expects e.g. "image/jpeg"; confirm that
                # get_data_format never returns "jpg".
                "media_type": f"image/{get_data_format(img_path)}",
                "data": get_image_data(img_path),
            },
        }
        for img_path in img_paths
    ]

    system_message = [{"type": "text", "text": EXTRACT_INFO_SYSTEM_MESSAGE}]

    text_messages = [
        {
            "type": "text",
            "text": EXTRACT_INFO_HUMAN_MESSAGE,
        }
    ]

    if known_data is not None:
        text_messages.append(
            {
                "type": "text",
                "text": f'\nAlso exploit the known data: \n\n"{known_data}"',
            }
        )

    # Text parts come first, followed by the images, all in one user turn.
    messages = [{"role": "user", "content": text_messages + image_messages}]

    try:
        response = client.messages.create(
            model=claude_model,
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
            max_tokens=2048,
            system=system_message,
            tools=tools,
            # Force the model to answer through the tool. Without this it may
            # reply with free text and no tool_use block.
            tool_choice={"type": "tool", "name": tool_name},
            messages=messages,
        )
    except APIStatusError as e:
        print(f"{e=}")
        return e.status_code, None

    for content in response.content:
        if content.type == "tool_use":
            print("Found tool_use!")
            return 200, schema.model_validate(content.input)

    # Previously the function fell off the end here and returned a bare None,
    # which broke `status, result = extract_info(...)` in callers.
    print("ERROR: No tool_use found!")
    return 502, None


def follow_structure(json_info, schema, known_data=None):
    """Re-shape previously extracted information so that it conforms to ``schema``.

    ``json_info`` is inserted into ``FOLLOW_SCHEMA_HUMAN_MESSAGE`` (via
    ``str.format(json_info=...)``). The model is required to answer by calling
    an ``extract_garment_info`` tool whose input schema is ``schema``'s JSON
    schema.

    Args:
        json_info: Information to restructure; it is interpolated into the
            prompt template with ``str.format``.
        schema: Model class exposing ``model_json_schema()``,
            ``model_validate()`` and ``model_validate_json()`` (appears to be
            a Pydantic v2 model).
        known_data: Optional extra information appended to the prompt.

    Returns:
        A ``(status_code, result)`` tuple:

        * ``(200, <schema instance>)`` on success;
        * ``(e.status_code, None)`` if the API returned an error status;
        * ``(502, None)`` if the reply contained no ``tool_use`` block.

    Raises:
        Whatever ``schema.model_validate`` / ``model_validate_json`` raise
        if the tool input does not match the schema.
    """
    print("Following structure via Anthropic...")
    tool_name = "extract_garment_info"
    tools = [
        {
            "name": tool_name,
            # NOTE(review): this is the *unformatted* human-message template
            # (it is formatted with json_info below). If it contains a
            # "{json_info}" placeholder, the model presumably sees it unfilled
            # here. Kept as-is because changing it alters the prompt; confirm
            # whether a short plain description was intended.
            "description": FOLLOW_SCHEMA_HUMAN_MESSAGE,
            "input_schema": schema.model_json_schema(),
            # Mark the tool definition for prompt caching.
            "cache_control": {"type": "ephemeral"},
        }
    ]

    print("DEBUG: human message**************************")
    # Format the template once and reuse it for both the debug print and the
    # request.
    human_message = FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=json_info)
    print(human_message)
    text_messages = [
        {
            "type": "text",
            "text": human_message,
        }
    ]

    if known_data is not None:
        text_messages.append(
            {
                "type": "text",
                "text": f'\nAlso exploit the known data: \n\n"{known_data}"',
            }
        )

    system_message = [{"type": "text", "text": FOLLOW_SCHEMA_SYSTEM_MESSAGE}]

    messages = [{"role": "user", "content": text_messages}]
    try:
        response = client.messages.create(
            model=claude_model,
            extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
            max_tokens=2048,
            system=system_message,
            tools=tools,
            # Force a tool call so the reply contains a tool_use block rather
            # than free text.
            tool_choice={"type": "tool", "name": tool_name},
            messages=messages,
        )
    except APIStatusError as e:
        print(f"{e=}")
        return e.status_code, None

    for content in response.content:
        if content.type == "tool_use":
            print("Found tool_use!***********************")
            print(content.input)
            tool_input = content.input
            # Accept the payload either nested under a "json_info" key (what
            # this function originally required) or as the whole tool input
            # (what input_schema actually declares). The latter used to raise
            # KeyError.
            if isinstance(tool_input, dict) and "json_info" in tool_input:
                tool_input = tool_input["json_info"]
            # Tolerate a nested payload that arrives as a JSON string.
            if isinstance(tool_input, str):
                return 200, schema.model_validate_json(tool_input)
            return 200, schema.model_validate(tool_input)

    # Previously the function fell off the end here and returned a bare None,
    # which broke tuple unpacking in callers.
    print("ERROR: No tool_use found!")
    return 502, None