File size: 3,812 Bytes
7a38b33
 
 
44e85d0
7a38b33
 
 
 
 
44e85d0
7a38b33
 
 
28f26e5
7a38b33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from .utils import get_data_format, get_image_data
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from schema_structured import Garment as StructuredGarment
from schema_free import Garment as FreeGarment


# Load variables from a .env file into the process environment before the chat
# clients below are constructed.
# NOTE(review): presumably this is where the provider API keys come from, since
# no key is passed explicitly to either client — confirm.
load_dotenv()


# Module-level Anthropic chat client; selected when provider == 'anthropic'.
anthropic_llm = ChatAnthropic(
    model="claude-3-5-sonnet-20241022",
    temperature=0,
    max_tokens=1024,
    timeout=None,
    max_retries=2,
    # other params...
)

# Module-level OpenAI chat client; selected when provider == 'openai'.
openai_llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    # api_key="...",  # if you prefer to pass the API key in directly instead of using env vars
    # base_url="...",
    # organization="...",
    # other params...
)


def extract_info(image_paths, provider, schema):
    """Describe the product shown in a set of images as structured output.

    Every image is embedded as a base64 data URL in a single human message,
    which is sent together with a system prompt to the selected chat model.
    The model is wrapped with ``with_structured_output(schema)`` so the reply
    follows ``schema``.

    Args:
        image_paths: Iterable of image paths. A single path passed as a plain
            string is treated as a one-image list instead of being iterated
            character by character.
        provider: ``'anthropic'`` or ``'openai'``.
        schema: Output schema handed to ``with_structured_output``.

    Returns:
        The result of invoking the structured-output runnable for ``schema``.

    Raises:
        ValueError: If ``provider`` is not supported, or if ``image_paths``
            contains no paths.
    """
    if provider not in ('anthropic', 'openai'):
        raise ValueError('Invalid provider')

    # Normalise the input first so the checks below see a real list.
    if isinstance(image_paths, str):
        image_paths = [image_paths]
    else:
        image_paths = list(image_paths)
    if not image_paths:
        # The prompt asks about "the set of images"; sending none is a caller bug.
        raise ValueError('image_paths must contain at least one image path')

    llm = anthropic_llm if provider == 'anthropic' else openai_llm

    content = [{"type": "text", "text": "describe the product in the set of images"}]
    content.extend(
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/{get_data_format(image_path)};base64,{get_image_data(image_path)}"},
        }
        for image_path in image_paths
    )

    # A leading SystemMessage is understood by both providers, so no
    # provider-specific invoke arguments are needed.
    messages = [
        SystemMessage(content="You are a helpful assistant that describes the product in the set of images."),
        HumanMessage(content=content),
    ]
    return llm.with_structured_output(schema).invoke(messages)


def follow_structure(json_info, provider, schema):
    """Re-express already extracted product attributes in a target schema.

    ``json_info`` is interpolated into the user prompt as text, and the
    selected chat model is asked to convert those attributes into ``schema``
    while keeping the keys and the number of values.

    Args:
        json_info: Attributes to convert; formatted into the prompt with an
            f-string, so any value with a useful string form is accepted.
        provider: ``'anthropic'`` or ``'openai'``.
        schema: Output schema handed to ``with_structured_output``.

    Returns:
        The result of invoking the structured-output runnable for ``schema``.

    Raises:
        ValueError: If ``provider`` is not supported.
    """
    if provider == 'anthropic':
        llm = anthropic_llm
    elif provider == 'openai':
        llm = openai_llm
    else:
        raise ValueError('Invalid provider')

    prompt_parts = [{"type": "text", "text": f"Convert following attributes to structured schema. Keep all the keys and number of values. Only replace the values themselves. :\n\n{json_info}"}]

    # A leading SystemMessage is understood by both providers, so no
    # provider-specific invoke arguments are needed.
    messages = [
        SystemMessage(content="You are an expert at structured data extraction. You will be given an dictionary of attributes of a product and should output the its properties into the given structure."),
        HumanMessage(content=prompt_parts),
    ]
    return llm.with_structured_output(schema).invoke(messages)


if __name__ == '__main__':
    # Manual smoke test: run every provider/schema combination against the
    # two sample images and print each result as JSON.
    print('Running tests...')
    smoke_cases = (
        ("Test 1: Provider Anthropic, Schema Structure", 'anthropic', StructuredGarment),
        ('Test 2: Provider OpenAI, Schema Structure', 'openai', StructuredGarment),
        ('Test 3: Provider Anthropic, Schema Free', 'anthropic', FreeGarment),
        ('Test 4: Provider OpenAI, Schema Free', 'openai', FreeGarment),
    )
    for banner, provider_name, garment_schema in smoke_cases:
        print(banner)
        garment = extract_info(['1.png', '2.png'], provider=provider_name, schema=garment_schema)
        print(garment.model_dump_json())