Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import io | |
| import base64 | |
| from typing import Any, Dict, List, Type, Union, Optional | |
| import google.generativeai as genai | |
| from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold # For safety settings | |
| import weave # Assuming weave is still used | |
| from pydantic import BaseModel, ValidationError # For schema validation | |
| # Assuming these utilities are in the same relative paths or accessible | |
| from app.utils.converter import product_data_to_str | |
| from app.utils.image_processing import ( | |
| get_data_format, # Assuming this returns 'jpeg', 'png' etc. | |
| get_image_base64_and_type, # Assuming this fetches URL and returns (base64_str, type_str) | |
| get_image_data, # Assuming this reads local path and returns base64_str | |
| ) | |
| from app.utils.logger import exception_to_str, setup_logger | |
| # Assuming these are correctly defined and accessible | |
| from ..config import get_settings | |
| from ..core import errors | |
| from ..core.errors import BadRequestError, VendorError # Using your custom errors | |
| from ..core.prompts import get_prompts # Assuming prompts are compatible or adapted | |
| from .base import BaseAttributionService # Assuming this base class exists | |
# Environment and Weave setup (kept unchanged from the previous implementation)
ENV = os.getenv("ENV", "LOCAL")

# Weave project to use for each deployment environment. PROD is deliberately
# absent: no Weave project is used there.
_WEAVE_PROJECTS = {
    "LOCAL": "cfai/attribution-exp",
    "DEV": "cfai/attribution-dev",
    "UAT": "cfai/attribution-uat",
}

# None for PROD and for any unrecognised ENV value. Looking the name up with
# .get() means an unexpected ENV no longer raises NameError at import time.
weave_project_name = _WEAVE_PROJECTS.get(ENV)

if weave_project_name is not None:
    # weave.init(project_name=weave_project_name)  # Assuming weave.init() is called elsewhere or if needed here
    print(f"Weave project name (potentially initialized elsewhere): {weave_project_name}")
# Module-wide singletons: application settings, prompt templates and logger.
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)

# Hand the API key to the Gemini SDK once, at import time. A missing or empty
# key is only logged here; nothing is raised from this block.
try:
    if not settings.GEMINI_API_KEY:
        logger.error("GEMINI_API_KEY not found in settings.")
        # Potentially raise an error or handle this case as per application requirements
    else:
        genai.configure(api_key=settings.GEMINI_API_KEY)
except AttributeError:
    # The settings object does not define GEMINI_API_KEY at all.
    logger.error("Settings object does not have GEMINI_API_KEY attribute.")
# Default safety settings passed to every Gemini model created in this module.
# Each listed harm category uses the same threshold; adjust per application
# requirements.
DEFAULT_SAFETY_SETTINGS = dict.fromkeys(
    (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    ),
    HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
)
class GeminiService(BaseAttributionService):
    """Attribution service that requests schema-constrained JSON from Gemini.

    ``extract_attributes`` sends a text prompt plus optional images and
    validates the reply against a Pydantic model. ``follow_schema`` sends a
    JSON payload and asks Gemini to reshape it to a caller-supplied schema.
    Both use ``DEFAULT_SAFETY_SETTINGS`` for every model they create.
    """

    def __init__(self, model_name: str = "gemini-2.5-flash-preview-04-17"):
        """
        Initializes the GeminiService.

        If the model object cannot be created, the error is logged and
        ``self.model`` is set to ``None`` instead of raising. In that state
        ``extract_attributes`` raises ``VendorError`` and ``follow_schema``
        attempts to build a model for the single call.

        Args:
            model_name (str): The name of the Gemini model to use.
        """
        try:
            self.model = genai.GenerativeModel(
                model_name,
                safety_settings=DEFAULT_SAFETY_SETTINGS,
                # system_instruction can be set here if a global system message is always used
            )
            logger.info(f"GeminiService initialized with model: {model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize Gemini GenerativeModel: {exception_to_str(e)}")
            self.model = None

    @staticmethod
    def _canonical_model_name(model_name: str) -> str:
        """Return ``model_name`` in ``models/<name>`` form for comparisons.

        The google-generativeai SDK stores bare names with a ``models/``
        prefix, so comparing ``GenerativeModel.model_name`` with the raw
        caller-supplied name would never match.
        """
        return model_name if "/" in model_name else f"models/{model_name}"

    def _model_for(self, ai_model: str, request_label: str) -> Any:
        """Return the cached model when it matches ``ai_model``, else a new one.

        Args:
            ai_model: Model name requested for this call.
            request_label: Short label used in log and error messages.

        Raises:
            VendorError: If a per-call model cannot be constructed.
        """
        if self.model is not None:
            cached_name = self._canonical_model_name(self.model.model_name)
            if cached_name == self._canonical_model_name(ai_model):
                return self.model
            logger.info(f"Switching to model {ai_model} for this {request_label} request.")
        try:
            return genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
        except Exception as e:
            raise VendorError(
                f"Failed to initialize Gemini model for {request_label}: {exception_to_str(e)}"
            ) from e

    @staticmethod
    def _mime_type_for(img_format: str) -> str:
        """Map an image format name such as ``PNG`` or ``jpg`` to a MIME type."""
        fmt = img_format.lower()
        # "jpg" is a file extension, not a MIME subtype; the registered type is image/jpeg.
        if fmt == "jpg":
            fmt = "jpeg"
        return f"image/{fmt}"

    def _base64_image_part(self, base64_data: str, img_type: str) -> Dict[str, Any]:
        """Build one image part from base64 text; the data is sent as raw bytes."""
        return {
            "mime_type": self._mime_type_for(img_type),
            "data": base64.b64decode(base64_data),
        }

    def _prepare_image_parts(
        self,
        img_urls: Optional[List[str]] = None,
        img_paths: Optional[List[str]] = None,
        pil_images: Optional[List[Any]] = None,  # PIL.Image.Image objects
    ) -> List[Dict[str, Any]]:
        """
        Prepares image data in the format expected by Gemini API.

        Each source is processed independently: a failure on one image is
        logged and that image is skipped, so this method itself does not raise
        for bad images.

        Args:
            img_urls: Image URLs, loaded through ``get_image_base64_and_type``.
            img_paths: Local paths, loaded through ``get_image_data`` and
                ``get_data_format``.
            pil_images: Objects exposing ``format`` and ``save`` (PIL images).

        Returns:
            A list of ``{"mime_type": ..., "data": <bytes>}`` dicts, URLs
            first, then paths, then PIL images.
        """
        image_parts: List[Dict[str, Any]] = []

        # Process image URLs
        for img_url in img_urls or []:
            try:
                base64_data, img_type = get_image_base64_and_type(img_url)
                if base64_data and img_type:
                    image_parts.append(self._base64_image_part(base64_data, img_type))
                else:
                    logger.warning(f"Could not retrieve or identify type for image URL: {img_url}")
            except Exception as e:
                logger.error(f"Error processing image URL {img_url}: {exception_to_str(e)}")

        # Process image paths
        for img_path in img_paths or []:
            try:
                base64_data = get_image_data(img_path)
                img_type = get_data_format(img_path)
                if base64_data and img_type:
                    image_parts.append(self._base64_image_part(base64_data, img_type))
                else:
                    logger.warning(f"Could not retrieve or identify type for image path: {img_path}")
            except Exception as e:
                logger.error(f"Error processing image path {img_path}: {exception_to_str(e)}")

        # Process PIL images
        for i, pil_image in enumerate(pil_images or []):
            try:
                img_format = pil_image.format or 'PNG'  # Default to PNG if format is not available
                with io.BytesIO() as img_byte_arr:
                    pil_image.save(img_byte_arr, format=img_format)
                    image_bytes = img_byte_arr.getvalue()
                image_parts.append(
                    {"mime_type": self._mime_type_for(img_format), "data": image_bytes}
                )
            except Exception as e:
                logger.error(f"Error processing PIL image #{i}: {exception_to_str(e)}")

        return image_parts

    @staticmethod
    def _finish_reason_name(candidate: Any) -> str:
        """Return the candidate's finish reason name, or ``UNKNOWN`` if falsy."""
        return candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"

    @staticmethod
    def _safety_ratings_summary(candidate: Any) -> str:
        """Render the candidate's safety ratings as ``CATEGORY: PROBABILITY`` pairs."""
        return ", ".join(
            f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings
        )

    # NOTE(review): an earlier comment here mentioned applying weave.op as a
    # decorator, but no decorator is applied to this method.
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,  # This will be the Gemini model name, e.g., "gemini-1.5-flash-latest"
        img_urls: Optional[List[str]] = None,
        product_taxonomy: str = "",
        product_data: Optional[Dict[str, Union[str, List[str]]]] = None,
        pil_images: Optional[List[Any]] = None,
        img_paths: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes with Gemini and validate them.

        Args:
            attributes_model: Pydantic model whose JSON schema is sent as the
                response schema and which validates the reply.
            ai_model: Gemini model name for this call; the model created in
                ``__init__`` is reused when the names match.
            img_urls: Optional image URLs to attach.
            product_taxonomy: Text substituted into the human prompt.
            product_data: Product fields rendered with ``product_data_to_str``.
            pil_images: Optional PIL images to attach.
            img_paths: Optional local image paths to attach.

        Returns:
            The validated attributes as a plain dict (``model_dump()``).

        Raises:
            VendorError: Model not initialized, schema generation failed, the
                API call failed, the response was blocked or empty, or the
                response failed validation.
            BadRequestError: Image preparation raised unexpectedly.
        """
        if not self.model:
            raise VendorError("Gemini model not initialized.")
        current_model = self._model_for(ai_model, "extraction")

        # System and human prompts are sent together as one text part.
        system_message = prompts.EXTRACT_INFO_SYSTEM_MESSAGE
        human_message = prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
            product_taxonomy=product_taxonomy,
            product_data=product_data_to_str(product_data if product_data else {}),
        )
        full_prompt_text = f"{system_message}\n\n{human_message}"
        logger.info(f"Gemini Prompt Text: {full_prompt_text[:500]}...")  # Log a snippet
        content_parts: List[Any] = [full_prompt_text]

        # Prepare image parts
        try:
            image_parts = self._prepare_image_parts(img_urls, img_paths, pil_images)
            content_parts.extend(image_parts)
        except Exception as e:
            logger.error(f"Failed during image preparation: {exception_to_str(e)}")
            raise BadRequestError(f"Image processing failed: {e}") from e
        if not image_parts and (img_urls or img_paths or pil_images):
            logger.warning("Image sources provided, but no image parts were successfully prepared.")

        # The Pydantic model's JSON schema is what constrains Gemini's output.
        try:
            schema_for_gemini = attributes_model.model_json_schema()
        except Exception as e:
            logger.error(f"Error generating JSON schema from Pydantic model: {exception_to_str(e)}")
            raise VendorError(f"Could not generate schema for attributes_model: {e}") from e

        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=schema_for_gemini,
            temperature=0.0,  # For deterministic output
            max_output_tokens=2048,  # Adjust as needed
        )

        logger.info(f"Extracting attributes via Gemini model: {current_model.model_name}...")
        try:
            response = await current_model.generate_content_async(
                contents=content_parts,
                generation_config=generation_config,
                # request_options={"timeout": 120}  # Example: set timeout in seconds
            )
        except Exception as e:  # Catches google.api_core.exceptions and others
            error_message = exception_to_str(e)
            logger.error(f"Gemini API call failed: {error_message}")
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message)) from e

        # Stays None until the reply text has been read, so the handlers below
        # can tell whether there is anything to log.
        response_text: Optional[str] = None
        try:
            if not response.candidates:
                # No candidates: report the prompt-level block reason if there is one.
                block_reason_detail = "Unknown reason (no candidates)"
                feedback = response.prompt_feedback
                if feedback and feedback.block_reason:
                    block_reason_detail = f"Blocked due to: {feedback.block_reason.name}"
                    if feedback.block_reason_message:
                        block_reason_detail += f" - {feedback.block_reason_message}"
                logger.error(f"Gemini response was blocked or empty. {block_reason_detail}")
                raise VendorError(f"Gemini response blocked or empty. {block_reason_detail}")

            # Only the first candidate is used.
            candidate = response.candidates[0]
            if candidate.finish_reason not in [1, 2]:  # 1=STOP, 2=MAX_TOKENS
                finish_reason_str = self._finish_reason_name(candidate)
                logger.warning(f"Gemini generation finished with reason: {finish_reason_str}")
                if finish_reason_str == "SAFETY":
                    safety_ratings_str = self._safety_ratings_summary(candidate)
                    raise VendorError(
                        f"Gemini content generation stopped due to safety concerns. Ratings: [{safety_ratings_str}]"
                    )

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content is empty or not in the expected text format.")
                raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + " (empty response text)")

            response_text = candidate.content.parts[0].text
            # Parse and validate the JSON response using the Pydantic model
            parsed_data = attributes_model.model_validate_json(response_text)
            return parsed_data.model_dump()
        except ValidationError as ve:
            logger.error(f"Pydantic validation failed for Gemini response: {ve}")
            logger.debug(f"Invalid JSON received from Gemini: {(response_text or '')[:500]}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {ve}") from ve
        except json.JSONDecodeError as je:
            logger.error(f"JSON decoding failed for Gemini response: {je}")
            logger.debug(f"Non-JSON response received: {(response_text or '')[:500]}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {je}") from je
        except VendorError:  # Re-raise VendorErrors unchanged
            raise
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response: {error_message}")
            raw_response_snippet = response_text[:500] if response_text is not None else "N/A"
            logger.debug(f"Problematic Gemini response snippet: {raw_response_snippet}")
            raise VendorError(f"Failed to process Gemini response: {error_message}") from e

    async def follow_schema(
        self,
        schema: Dict[str, Any],  # This should be an OpenAPI schema dictionary
        data: Dict[str, Any],
        ai_model: str = "gemini-1.5-flash-latest",  # Model for this specific task
    ) -> Dict[str, Any]:
        """Ask Gemini to rewrite ``data`` so that it follows ``schema``.

        Args:
            schema: Schema dictionary passed to Gemini as the response schema.
            data: JSON-serializable payload embedded in the prompt.
            ai_model: Gemini model name for this call.

        Returns:
            The decoded JSON reply, or
            ``{"status": "refused", "reason": ...}`` when the response is
            blocked, stopped for safety, or empty.

        Raises:
            BadRequestError: ``data`` cannot be serialized to JSON.
            VendorError: Model creation failed, the API call failed, or the
                reply could not be decoded or processed.
        """
        if not self.model:
            logger.warning("Main Gemini model not initialized. Attempting to initialize a temporary one for follow_schema.")
        current_model = self._model_for(ai_model, "follow_schema")
        logger.info(f"Following schema via Gemini model: {current_model.model_name}...")

        # The human message embeds the data via the prompt's json_info placeholder.
        system_message = prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE
        try:
            data_as_json_string = json.dumps(data, indent=2)
        except TypeError as te:
            logger.error(f"Could not serialize 'data' to JSON for prompt: {te}")
            raise BadRequestError(f"Input data for schema following is not JSON serializable: {te}") from te
        human_message = prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data_as_json_string)
        full_prompt_text = f"{system_message}\n\n{human_message}"
        content_parts = [full_prompt_text]

        # Define generation config for JSON output using the provided schema
        generation_config = GenerationConfig(
            response_mime_type="application/json",
            response_schema=schema,
            temperature=0.0,  # For deterministic output
            max_output_tokens=2048,  # Adjust as needed
        )

        try:
            response = await current_model.generate_content_async(
                contents=content_parts,
                generation_config=generation_config,
            )
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Gemini API call failed for follow_schema: {error_message}")
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message)) from e

        # Stays None until the reply text has been read (used for debug logging).
        response_text: Optional[str] = None
        try:
            if not response.candidates:
                block_reason_detail = "Unknown reason (no candidates)"
                feedback = response.prompt_feedback
                if feedback and feedback.block_reason:
                    block_reason_detail = f"Blocked due to: {feedback.block_reason.name}"
                logger.error(f"Gemini response was blocked or empty in follow_schema. {block_reason_detail}")
                # Blocks are reported as a refusal payload rather than an exception.
                return {"status": "refused", "reason": block_reason_detail}

            candidate = response.candidates[0]
            if candidate.finish_reason not in [1, 2]:  # 1=STOP, 2=MAX_TOKENS
                finish_reason_str = self._finish_reason_name(candidate)
                logger.warning(f"Gemini generation (follow_schema) finished with reason: {finish_reason_str}")
                if finish_reason_str == "SAFETY":
                    safety_ratings_str = self._safety_ratings_summary(candidate)
                    return {"status": "refused", "reason": f"Safety block. Ratings: [{safety_ratings_str}]"}

            if not candidate.content.parts or not candidate.content.parts[0].text:
                logger.error("Gemini response content (follow_schema) is empty.")
                return {"status": "refused", "reason": "Empty content from Gemini"}

            response_text = candidate.content.parts[0].text
            # Only JSON decoding happens here; the reply is not re-validated against the schema.
            return json.loads(response_text)
        except json.JSONDecodeError as je:
            logger.error(f"JSON decoding failed for Gemini response (follow_schema): {je}")
            logger.debug(f"Non-JSON response received: {(response_text or '')[:500]}...")
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" (follow_schema) Details: {je}") from je
        except Exception as e:
            error_message = exception_to_str(e)
            logger.error(f"Error processing Gemini response (follow_schema): {error_message}")
            raw_response_snippet = response_text[:500] if response_text is not None else "N/A"
            logger.debug(f"Problematic Gemini response snippet (follow_schema): {raw_response_snippet}")
            raise VendorError(f"Failed to process Gemini response (follow_schema): {error_message}") from e