"""Receipt-image recognition service backed by Vertex AI generative models."""

import json
import os
import random
import tempfile
import time

import vertexai
from google.oauth2 import service_account
from loguru import logger
from vertexai.generative_models import GenerativeModel, Part, SafetySetting

from common.enum.ai_service_error import AiServiceError
from common.exceptions import AiServiceException
from common.utils import encode_image_to_webp_base64
from image_processing_interface import ImageProcessingInterface


class VertexAIService(ImageProcessingInterface):
    """Singleton receipt recognizer that calls a Vertex AI generative model.

    The first construction authenticates and calls ``vertexai.init``; every
    later construction returns the same instance and ignores its arguments.
    """

    # Total attempts made when the service reports a quota (HTTP 429) error.
    MAX_RETRIES = 5

    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, json_key_path=None, json_key_env_var=None, project=None, location=None):
        """Initialize the Vertex AI client (only once per process).

        :param json_key_path: Path to the service-account JSON file; falls
            back to the ``GOOGLE_VERTEX_KEY_PATH`` environment variable.
        :param json_key_env_var: Despite its name, this value is used as the
            JSON key *content* itself; falls back to the content of the
            ``GOOGLE_VERTEX_KEY`` environment variable.
        :param project: Project id; falls back to ``PROJECT_ID``.
        :param location: Region; falls back to ``LOCATION``.
        :raises ValueError: If no usable service-account key is available.
        """
        # __init__ runs on every VertexAIService(...) call even though __new__
        # returns the cached instance, so skip re-initialization.
        if hasattr(self, "_initialized"):
            return

        self.json_key_path = json_key_path or os.getenv('GOOGLE_VERTEX_KEY_PATH')
        self.json_key_env_var = json_key_env_var or os.getenv('GOOGLE_VERTEX_KEY')
        self.project = project or os.getenv('PROJECT_ID')
        self.location = location or os.getenv('LOCATION')

        self.credentials = self._authenticate_vertex_ai()
        vertexai.init(project=self.project, location=self.location, credentials=self.credentials)

        # Set last so that a failed authentication is retried on the next construction.
        self._initialized = True
        logger.info('VertexAIService initialized')

    def _authenticate_vertex_ai(self):
        """Build Google credentials from the key file or the in-memory key content.

        The key file is preferred when ``self.json_key_path`` points to an
        existing file; otherwise the JSON text in ``self.json_key_env_var``
        is parsed directly, so the secret is never written to disk.

        :return: Google Credentials object
        :raises ValueError: If neither source is available, or the key
            content is not valid JSON.
        """
        if self.json_key_path and os.path.isfile(self.json_key_path):
            return service_account.Credentials.from_service_account_file(self.json_key_path)

        if not self.json_key_env_var:
            raise ValueError(
                "No service-account key available: set GOOGLE_VERTEX_KEY_PATH to a "
                "key file or GOOGLE_VERTEX_KEY to the JSON key content."
            )

        try:
            key_info = json.loads(self.json_key_env_var)
        except json.JSONDecodeError as error:
            raise ValueError("GOOGLE_VERTEX_KEY does not contain valid JSON.") from error
        return service_account.Credentials.from_service_account_info(key_info)

    @staticmethod
    def _default_safety_settings():
        """Return safety settings that disable blocking for every category used here."""
        categories = (
            SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        )
        return [
            SafetySetting(
                category=category,
                threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
            )
            for category in categories
        ]

    def _request_recognition(self, input_image64, model_name, prompt, system, temperature):
        """Send one generate-content request for the image and return the raw response."""
        model = GenerativeModel(model_name, system_instruction=[system])
        image_part = Part.from_data(
            mime_type="image/webp",
            data=input_image64
        )
        generation_config = {
            "max_output_tokens": 8192,
            "temperature": temperature,
            "response_mime_type": "application/json"
        }
        return model.generate_content(
            [image_part, prompt],
            generation_config=generation_config,
            safety_settings=self._default_safety_settings(),
        )

    def _parse_receipt_response(self, response):
        """Decode the model's JSON answer and reject empty or invalid receipts.

        :raises AiServiceException: ``RETAKE_PHOTO`` when the answer is not
            JSON or fails ``_validate_receipt_data``.
        """
        try:
            json_content = json.loads(response.text)
        except json.JSONDecodeError as error:
            error_message = "The receipt could not be recognized. Please retake the photo."
            logger.error(error_message)
            raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message) from error

        logger.debug(json.dumps(json_content, indent=4))

        # NOTE(review): _validate_receipt_data is not defined in this file;
        # presumably inherited from ImageProcessingInterface - confirm.
        if not self._validate_receipt_data(json_content):
            error_message = "The receipt is empty or contains no valid items. Please ensure the receipt is correctly scanned and try again"
            logger.error(error_message)
            raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message)
        return json_content

    def process_image(self, input_image64, model_name, prompt, system="You are receipt recognizer",
                      temperature=0.0):
        """Recognize a receipt image with a Vertex AI model.

        Quota errors (any exception whose text contains ``"429"``) are retried
        with exponential backoff plus jitter, up to ``MAX_RETRIES`` attempts.

        :param input_image64: Image payload sent as ``image/webp`` data
            (the name suggests a base64 string - TODO confirm with callers).
        :param model_name: Name of the model in Vertex AI
        :param prompt: Text prompt to guide the model
        :param system: System instruction given to the model
        :param temperature: Temperature for controlling randomness
        :return: Tuple ``(json_string, model_input)`` where ``json_string`` is
            the recognized receipt (with token counts and elapsed ``time``
            added) and ``model_input`` holds the system and prompt texts.
        :raises AiServiceException: ``RETAKE_PHOTO`` for a missing image, an
            unreadable/empty receipt or any non-quota failure; ``RETRY_FETCH``
            when every attempt hit the quota limit.
        """
        if input_image64 is None:
            raise AiServiceException(AiServiceError.RETAKE_PHOTO, "No objects detected.")

        # Measured from before the first attempt, so it includes retry waits.
        start_time = time.time()
        max_retries = self.MAX_RETRIES
        last_quota_error = None

        for attempt in range(1, max_retries + 1):
            try:
                response = self._request_recognition(
                    input_image64, model_name, prompt, system, temperature
                )
                elapsed = time.time() - start_time
                logger.info(f"Recognition spent {elapsed:.2f} seconds.")
                logger.debug(f"Original response: {response}")

                json_content = self._parse_receipt_response(response)

                # NOTE(review): these post-processing helpers are not defined in
                # this file; presumably inherited from ImageProcessingInterface.
                json_content = self._add_total_price(json_content)
                json_content = self._add_discount_item(json_content)
                json_content = self._add_rounding_item(json_content)

                usage = response.usage_metadata
                json_content['input_tokens'] = usage.prompt_token_count
                json_content['output_tokens'] = usage.candidates_token_count
                json_content['total_tokens'] = usage.total_token_count
                json_content['time'] = elapsed

                model_input = {
                    "system": system,
                    "prompt": prompt
                }
                return json.dumps(json_content, indent=4), model_input

            except AiServiceException:
                # Already classified; do not wrap or retry.
                raise
            except Exception as error:
                if "429" not in str(error):
                    error_message = f"An error occurred during image processing: {error}"
                    logger.error(error_message)
                    raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message) from error

                last_quota_error = error
                if attempt == max_retries:
                    # No further attempt will be made, so sleeping would only delay the failure.
                    break

                wait_time = 2 ** attempt + random.uniform(0, 1)
                logger.warning(
                    f"Quota exceeded. Retrying in {wait_time:.2f} seconds... (Attempt {attempt}/{max_retries})")
                time.sleep(wait_time)

        error_message = f"Failed after {max_retries} attempts due to quota issues."
        logger.error(error_message)
        raise AiServiceException(AiServiceError.RETRY_FETCH, error_message) from last_quota_error


def _demo():
    """Run a one-off recognition against a sample image (manual smoke test)."""
    # Project and model details
    project = "igneous-spanner-441609-h6"
    location = "us-central1"
    model_name = "gemini-1.5-flash"
    key_path = None  # fall back to GOOGLE_VERTEX_KEY_PATH / GOOGLE_VERTEX_KEY

    # Initialize the client
    client = VertexAIService(json_key_path=key_path, project=project, location=location)

    # Image processing
    image_path = "./examples_sl/fatlouis.png"
    input_image64 = encode_image_to_webp_base64(image_path)

    system = "You are receipt recognizer"
    with open('./common/prompt_v1.txt', 'r', encoding='utf-8') as file:
        prompt = file.read()

    result_img, model_input = client.process_image(input_image64, model_name, prompt, system, 0.0)

    # Re-dump with ensure_ascii=False to show non-ASCII text readably while
    # keeping the output valid JSON (decoding with 'unicode_escape' would also
    # rewrite JSON escapes such as \n and \" and break surrogate pairs).
    readable = json.dumps(json.loads(result_img), ensure_ascii=False, indent=4)
    logger.info(readable)


# Example usage
if __name__ == '__main__':
    _demo()