| import json |
| import os |
| from typing import Any, Dict, List, Type |
|
|
| import openai |
| import weave |
| from openai import AsyncOpenAI |
| from pydantic import BaseModel |
|
|
| from app.utils.converter import product_data_to_str |
| from app.utils.image_processing import get_data_format, get_image_data |
| from app.utils.logger import setup_logger |
|
|
| from ..config import get_settings |
| from ..core import errors |
| from ..core.errors import BadRequestError, VendorError |
| from ..core.prompts import get_prompts |
| from .base import BaseAttributionService |
|
|
# Weave project that traces are logged to, keyed by the DEPLOYMENT env var.
_WEAVE_PROJECTS = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "PROD": "cfai/attribution-prod",
}

# DEPLOYMENT defaults to "LOCAL" when the variable is not set.
deployment = os.getenv("DEPLOYMENT", "LOCAL")
try:
    weave_project_name = _WEAVE_PROJECTS[deployment]
except KeyError:
    # An unrecognised value used to leave `weave_project_name` unassigned and
    # surface as a NameError on the `weave.init` call below; fail with a
    # message that names the bad value and the accepted ones instead.
    raise RuntimeError(
        f"Unsupported DEPLOYMENT value {deployment!r}; "
        f"expected one of {sorted(_WEAVE_PROJECTS)}"
    ) from None

# Module-level singletons used by the service defined in this file.
weave.init(project_name=weave_project_name)
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
|
|
|
|
def get_response_format(json_schema: dict[str, Any]) -> dict[str, Any]:
    """Wrap a JSON schema in a strict ``json_schema`` response-format dict.

    ``additionalProperties`` is set to ``False`` on the root schema and on
    every entry under ``$defs`` (when that key is present). The changes are
    applied to a deep copy, so the schema passed in by the caller is not
    modified.

    Args:
        json_schema: JSON schema as a dict; may contain a ``$defs`` mapping of
            sub-schema dicts.

    Returns:
        A dict of the form ``{"type": "json_schema", "json_schema": {...}}``
        with ``strict`` set to ``True``, the name ``"GarmentSchema"`` and the
        adjusted schema copy under ``"schema"``.
    """
    # Local import keeps this function self-contained; `copy` is stdlib.
    import copy

    # Work on a copy: the caller's schema may be shared or reused elsewhere.
    schema = copy.deepcopy(json_schema)
    schema["additionalProperties"] = False

    for definition in schema.get("$defs", {}).values():
        definition["additionalProperties"] = False

    return {
        "type": "json_schema",
        "json_schema": {"strict": True, "name": "GarmentSchema", "schema": schema},
    }
|
|
|
|
class OpenAIService(BaseAttributionService):
    """Attribution service that sends its requests through the OpenAI client."""

    def __init__(self):
        """Create the async OpenAI client from ``settings.OPENAI_API_KEY``."""
        self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)

    @weave.op
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, str],
        pil_images: List[Any] = None,
        img_paths: List[str] = None,
    ) -> Dict[str, Any]:
        """Ask the model to extract product attributes and return them as a dict.

        The user message holds one text part (the formatted extraction prompt)
        followed by one ``image_url`` part per image.

        Args:
            attributes_model: Pydantic model passed as ``response_format``.
            ai_model: Name of the OpenAI model to call.
            img_urls: Image URLs; used when not ``None`` and takes precedence
                over ``img_paths``.
            product_taxonomy: Taxonomy string inserted into the prompt.
            product_data: Product fields, rendered with ``product_data_to_str``.
            pil_images: Accepted for interface compatibility; not used here.
            img_paths: Image paths, embedded as ``data:`` URLs; used only when
                ``img_urls`` is ``None``.

        Returns:
            The response message content decoded from JSON.

        Raises:
            BadRequestError: If the OpenAI client raises
                ``openai.BadRequestError``.
            VendorError: If the request fails for any other reason, or the
                response content cannot be decoded as JSON.
        """
        logger.info("Extracting info via OpenAI...")
        text_content = [
            {
                "type": "text",
                "text": prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
                    product_taxonomy=product_taxonomy,
                    product_data=product_data_to_str(product_data),
                ),
            },
        ]

        if img_urls is not None:
            image_content = [
                {"type": "image_url", "image_url": {"url": img_url}}
                for img_url in img_urls
            ]
        elif img_paths is not None:
            # presumably get_image_data returns base64 text and get_data_format
            # the image subtype - verify against app.utils.image_processing.
            image_content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{get_data_format(img_path)};base64,{get_image_data(img_path)}",
                    },
                }
                for img_path in img_paths
            ]
        else:
            # Previously `image_content` was left unbound here, which surfaced
            # as a misleading VendorError. Behave like an empty `img_urls` list
            # and send the text part on its own.
            logger.info("No image URLs or paths provided; sending text only")
            image_content = []

        try:
            response = await self.client.beta.chat.completions.parse(
                model=ai_model,
                messages=[
                    {
                        "role": "system",
                        "content": prompts.EXTRACT_INFO_SYSTEM_MESSAGE,
                    },
                    {
                        "role": "user",
                        "content": text_content + image_content,
                    },
                ],
                max_tokens=1000,
                response_format=attributes_model,
                logprobs=False,
                temperature=0.0,
            )
        except openai.BadRequestError as e:
            raise BadRequestError(str(e)) from e
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=str(e))
            ) from e

        try:
            content = response.choices[0].message.content
            parsed_data = json.loads(content)
        except Exception as e:
            # `except Exception` rather than a bare `except`, so that
            # KeyboardInterrupt / SystemExit are not converted to VendorError.
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON) from e

        return parsed_data

    @weave.op
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ask the model to restate ``data`` so that it follows ``schema``.

        The request always uses the ``gpt-4o-2024-11-20`` model and a strict
        response format built from ``schema`` by ``get_response_format``.

        Args:
            schema: JSON schema the response must follow.
            data: Data inserted into the prompt via ``str.format``.

        Returns:
            The response message content decoded from JSON, or
            ``{"status": "refused"}`` when the response message carries a
            refusal.

        Raises:
            VendorError: If building or sending the request fails.
            ValueError: If the response content cannot be decoded as JSON.
        """
        logger.info("Following structure via OpenAI...")
        text_content = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            },
        ]

        try:
            response = await self.client.beta.chat.completions.parse(
                model="gpt-4o-2024-11-20",
                messages=[
                    {
                        "role": "system",
                        "content": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE,
                    },
                    {
                        "role": "user",
                        "content": text_content,
                    },
                ],
                max_tokens=1000,
                response_format=get_response_format(schema),
                logprobs=False,
                temperature=0.0,
            )
        except Exception as e:
            raise VendorError(
                errors.VENDOR_THROW_ERROR.format(error_message=str(e))
            ) from e

        if response.choices[0].message.refusal:
            logger.info("OpenAI refused to respond to the request")
            return {"status": "refused"}

        try:
            content = response.choices[0].message.content
            parsed_data = json.loads(content)
        except Exception as e:
            # NOTE(review): extract_attributes raises VendorError for this same
            # condition; ValueError is kept here so existing callers that catch
            # it are unaffected.
            raise ValueError(errors.VENDOR_ERROR_INVALID_JSON) from e

        return parsed_data
|
|