Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| from typing import Any, Dict, List, Type, Union | |
| import anthropic | |
| import weave | |
| from anthropic import APIStatusError, AsyncAnthropic | |
| from pydantic import BaseModel | |
| from app.config import get_settings | |
| from app.core import errors | |
| from app.core.errors import BadRequestError, VendorError | |
| from app.core.prompts import get_prompts | |
| from app.services.base import BaseAttributionService | |
| from app.utils.converter import product_data_to_str | |
| from app.utils.image_processing import get_data_format, get_image_data | |
| from app.utils.logger import exception_to_str, setup_logger | |
# Deployment environment name; falls back to "LOCAL" when ENV is not set.
ENV = os.getenv("ENV", "LOCAL")

# Weave project used for tracing in each known environment.
# PROD maps to None: weave is deliberately not initialised there.
_WEAVE_PROJECTS = {
    "LOCAL": "cfai/attribution-exp",  # local or demo
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
    "PROD": None,
}

# An unrecognised ENV previously surfaced as a NameError on
# `weave_project_name`; fail with a message that names the actual problem.
if ENV not in _WEAVE_PROJECTS:
    raise ValueError(
        f"Unsupported ENV value {ENV!r}; expected one of {sorted(_WEAVE_PROJECTS)}"
    )

weave_project_name = _WEAVE_PROJECTS[ENV]
if weave_project_name is not None:
    weave.init(project_name=weave_project_name)

# Module-wide singletons shared by the service below.
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
class AnthropicService(BaseAttributionService):
    """Attribution service that talks to the Anthropic Messages API.

    Both public coroutines obtain structured output the same way: they
    declare a single tool whose ``input_schema`` is the desired JSON schema
    and return the ``input`` of the first ``tool_use`` block in the response.
    No ``tool_choice`` is sent, so the model is not forced to call the tool.
    """

    def __init__(self):
        """Create the async Anthropic client from the configured API key."""
        self.client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)

    @staticmethod
    def _build_image_blocks(
        img_urls: Union[List[str], None],
        img_paths: Union[List[str], None],
    ) -> List[Dict[str, Any]]:
        """Build the image content blocks for a user message.

        ``img_urls`` wins whenever it is not ``None`` (even if empty);
        otherwise ``img_paths`` is used.

        Raises:
            BadRequestError: If both ``img_urls`` and ``img_paths`` are None.
        """
        if img_urls is not None:
            return [
                {"type": "image", "source": {"type": "url", "url": img_url}}
                for img_url in img_urls
            ]
        if img_paths is not None:
            return [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": f"image/{get_data_format(img_path)}",
                        # presumably already base64-encoded by the helper;
                        # verify against app.utils.image_processing
                        "data": get_image_data(img_path),
                    },
                }
                for img_path in img_paths
            ]
        # Previously this case fell through and crashed later with an
        # UnboundLocalError when the message list was assembled.
        raise BadRequestError("Either img_urls or img_paths must be provided.")

    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: Union[List[str], None],
        product_taxonomy: str,
        product_data: Dict[str, Union[str, List[str]]],
        pil_images: Union[List[Any], None] = None,  # do not remove, this is for weave
        img_paths: Union[List[str], None] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes from images plus product metadata.

        Args:
            attributes_model: Pydantic model whose JSON schema becomes the
                tool's ``input_schema``.
            ai_model: Anthropic model identifier passed to the API.
            img_urls: Image URLs, sent as ``url`` image sources. Used
                whenever it is not None.
            product_taxonomy: Inserted into the human prompt template.
            product_data: Converted with ``product_data_to_str`` and
                inserted into the human prompt template.
            pil_images: Not read by this method; kept in the signature for
                weave, per the original author's note.
            img_paths: Image paths, sent as ``base64`` image sources. Only
                used when ``img_urls`` is None.

        Returns:
            The ``input`` of the first ``tool_use`` block in the response.

        Raises:
            BadRequestError: If neither image source is given, or Anthropic
                rejects the request with its own ``BadRequestError``.
            VendorError: If the API call fails for any other reason, the
                tool input is empty, or no ``tool_use`` block is returned.
        """
        logger.info("Extracting info via Anthropic...")

        image_messages = self._build_image_blocks(img_urls, img_paths)

        tools = [
            {
                "name": "extract_garment_info",
                "description": "Extracts key information from the image.",
                "input_schema": attributes_model.model_json_schema(),
                # Marks the tool definition for the prompt-caching beta.
                "cache_control": {"type": "ephemeral"},
            }
        ]
        system_message = [
            {"type": "text", "text": prompts.EXTRACT_INFO_SYSTEM_MESSAGE}
        ]
        text_messages = [
            {
                "type": "text",
                "text": prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
                    product_taxonomy=product_taxonomy,
                    product_data=product_data_to_str(product_data),
                ),
            }
        ]
        messages = [{"role": "user", "content": text_messages + image_messages}]

        try:
            response = await self.client.messages.create(
                model=ai_model,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
                # Sample only from the single most likely token.
                top_k=1,
            )
        except anthropic.BadRequestError as e:
            raise BadRequestError(e.message) from e
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type != "tool_use":
                continue
            if not content.input:
                raise VendorError(
                    errors.VENDOR_THROW_ERROR.format(
                        error_message="content.input is None or content.input is empty"
                    )
                )
            return content.input

        raise VendorError(
            errors.VENDOR_THROW_ERROR.format(error_message="No tool_use found")
        )

    async def follow_schema(self, schema, data):
        """Ask the default Anthropic model to restructure ``data`` per ``schema``.

        Args:
            schema: JSON schema used as the tool's ``input_schema``.
            data: Value inserted into the human prompt as ``json_info``.

        Returns:
            ``input["json_info"]`` of the first ``tool_use`` block, or the
            dict ``{"status": "ERROR: no tool_use found"}`` when the response
            contains no ``tool_use`` block.

        Raises:
            VendorError: If the API call fails.
        """
        logger.info("Following structure via Anthropic...")
        tools = [
            {
                "name": "extract_garment_info",
                # NOTE(review): this is the unformatted human-message
                # template, reused as the tool description; confirm that
                # is intended.
                "description": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE,
                "input_schema": schema,
                "cache_control": {"type": "ephemeral"},
            }
        ]
        text_messages = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            }
        ]
        system_message = [
            {"type": "text", "text": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE}
        ]
        messages = [{"role": "user", "content": text_messages}]

        try:
            response = await self.client.messages.create(
                model=settings.ANTHROPIC_DEFAULT_MODEL,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
            )
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=exception_to_str(e))
            ) from e

        for content in response.content:
            if content.type == "tool_use":
                # NOTE(review): assumes the schema has a top-level
                # "json_info" property; a KeyError propagates otherwise.
                return content.input["json_info"]

        # NOTE(review): unlike extract_attributes, a missing tool_use is
        # reported through the return value rather than a VendorError; kept
        # as is because callers may rely on it.
        return {"status": "ERROR: no tool_use found"}