Spaces:
Sleeping
Sleeping
| from .utils import get_data_format, get_image_data | |
| from langchain_anthropic import ChatAnthropic | |
| from langchain_openai import ChatOpenAI | |
| from dotenv import load_dotenv | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from schema_structured import Garment as StructuredGarment | |
| from schema_free import Garment as FreeGarment | |
# Populate the process environment from a local .env file so the chat clients
# below can read their credentials from environment variables.
load_dotenv()

# Constructor settings shared by both chat clients: temperature 0, timeout=None,
# and up to two automatic retries per request.
_SHARED_LLM_SETTINGS = {"temperature": 0, "timeout": None, "max_retries": 2}

# Anthropic client; responses are capped at 1024 tokens.
anthropic_llm = ChatAnthropic(
    model="claude-3-5-sonnet-20241022",
    max_tokens=1024,
    **_SHARED_LLM_SETTINGS,
)

# OpenAI client; max_tokens=None sets no explicit cap on response length.
# Credentials come from environment variables by default; pass api_key=...,
# base_url=... or organization=... here to supply them directly instead.
openai_llm = ChatOpenAI(
    model="gpt-4o",
    max_tokens=None,
    **_SHARED_LLM_SETTINGS,
)
def extract_info(image_paths, provider, schema):
    """Ask a chat model to describe the product shown in one or more images.

    Every image is inlined as a base64 ``data:image/...`` URL and sent, together
    with a short text instruction, in a single human turn. A system prompt
    precedes it. The model's answer is constrained to ``schema`` via
    ``with_structured_output``.

    Args:
        image_paths: Paths of the image files to send. Must be non-empty.
        provider: ``'anthropic'`` or ``'openai'``; selects the module-level
            ``anthropic_llm`` or ``openai_llm`` client.
        schema: Output schema handed to ``llm.with_structured_output`` (the
            self-tests in this module pass a pydantic ``Garment`` class).

    Returns:
        Whatever the structured runnable returns for ``schema``. The
        self-tests call ``.model_dump_json()`` on it, i.e. an instance of the
        pydantic model class.

    Raises:
        ValueError: If ``provider`` is not recognised or ``image_paths`` is empty.
    """
    # Validate inputs before touching any client, so bad calls fail fast.
    if provider not in ('anthropic', 'openai'):
        raise ValueError('Invalid provider')
    if not image_paths:
        # With no images the model has nothing to ground its description in.
        raise ValueError('image_paths must contain at least one image path')
    llm = anthropic_llm if provider == 'anthropic' else openai_llm

    text_message = [{"type": "text", "text": "describe the product in the set of images"}]
    image_message = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/{get_data_format(image_path)};base64,{get_image_data(image_path)}"},
        }
        for image_path in image_paths
    ]
    system_text = "You are a helpful assistant that describes the product in the set of images."

    # Both ChatOpenAI and ChatAnthropic accept a leading SystemMessage.
    # ChatAnthropic moves it into the request's top-level ``system`` field,
    # so one invoke path serves both providers.
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content=text_message + image_message),
    ]
    structured_llm = llm.with_structured_output(schema)
    return structured_llm.invoke(messages)
def follow_structure(json_info, provider, schema):
    """Re-map an existing attribute description onto a structured output schema.

    The attributes are embedded verbatim in the user prompt. The model is told
    to keep every key and the number of values, replacing only the values
    themselves, and its answer is constrained to ``schema`` via
    ``with_structured_output``.

    Args:
        json_info: The attributes to convert; interpolated into the prompt with
            an f-string. Presumably a JSON string such as the output of
            ``model_dump_json()`` (a dict would be rendered as its Python repr).
        provider: ``'anthropic'`` or ``'openai'``; selects the module-level
            ``anthropic_llm`` or ``openai_llm`` client.
        schema: Output schema handed to ``llm.with_structured_output``.

    Returns:
        Whatever the structured runnable returns for ``schema``.

    Raises:
        ValueError: If ``provider`` is not recognised.
    """
    # Validate before touching any client, so bad calls fail fast.
    if provider not in ('anthropic', 'openai'):
        raise ValueError('Invalid provider')
    llm = anthropic_llm if provider == 'anthropic' else openai_llm

    user_text = (
        "Convert the following attributes to the structured schema. "
        "Keep all the keys and the number of values. "
        "Only replace the values themselves:\n\n"
        f"{json_info}"
    )
    text_message = [{"type": "text", "text": user_text}]
    system_text = (
        "You are an expert at structured data extraction. "
        "You will be given a dictionary of attributes of a product "
        "and should output its properties into the given structure."
    )

    # Both ChatOpenAI and ChatAnthropic accept a leading SystemMessage.
    # ChatAnthropic moves it into the request's top-level ``system`` field,
    # so one invoke path serves both providers.
    messages = [
        SystemMessage(content=system_text),
        HumanMessage(content=text_message),
    ]
    structured_llm = llm.with_structured_output(schema)
    return structured_llm.invoke(messages)
if __name__ == '__main__':
    print('Running tests...')
    # Smoke-test every provider/schema combination on the same two sample
    # images, printing each structured result as JSON.
    scenarios = (
        ("Test 1: Provider Anthropic, Schema Structure", 'anthropic', StructuredGarment),
        ('Test 2: Provider OpenAI, Schema Structure', 'openai', StructuredGarment),
        ('Test 3: Provider Anthropic, Schema Free', 'anthropic', FreeGarment),
        ('Test 4: Provider OpenAI, Schema Free', 'openai', FreeGarment),
    )
    for title, provider_name, garment_schema in scenarios:
        print(title)
        info = extract_info(['1.png', '2.png'], provider=provider_name, schema=garment_schema)
        json_info = info.model_dump_json()
        print(json_info)