Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import random | |
| import tempfile | |
| import time | |
| import vertexai | |
| from google.oauth2 import service_account | |
| from loguru import logger | |
| from vertexai.generative_models import GenerativeModel | |
| from vertexai.generative_models import Part, SafetySetting | |
| from common.enum.ai_service_error import AiServiceError | |
| from common.exceptions import AiServiceException | |
| from common.utils import encode_image_to_webp_base64 | |
| from image_processing_interface import ImageProcessingInterface | |
class VertexAIService(ImageProcessingInterface):
    """Singleton client for Vertex AI generative models, used for receipt recognition.

    Authenticates with a service-account JSON key (from a file path or raw key
    content in an environment variable), initializes the Vertex AI SDK once per
    process, and exposes ``process_image`` to run a prompt + image through a
    Gemini model and return structured JSON.
    """

    _instance = None  # shared singleton instance

    def __new__(cls, *args, **kwargs):
        # Classic singleton: every construction returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, json_key_path=None, json_key_env_var=None, project=None, location=None):
        """
        Initializes the Vertex AI client. Safe to call repeatedly: the
        singleton is configured only on the first call.

        :param json_key_path: Path to the JSON key file (optional; falls back
            to the GOOGLE_VERTEX_KEY_PATH environment variable)
        :param json_key_env_var: Raw JSON key content (optional; falls back to
            the GOOGLE_VERTEX_KEY environment variable)
        :param project: GCP project id (falls back to PROJECT_ID env var)
        :param location: GCP region (falls back to LOCATION env var)
        """
        if hasattr(self, "_initialized"):
            return  # singleton already configured; skip re-initialization
        self.json_key_path = json_key_path or os.getenv('GOOGLE_VERTEX_KEY_PATH')
        self.json_key_env_var = json_key_env_var or os.getenv('GOOGLE_VERTEX_KEY')
        self.project = project or os.getenv('PROJECT_ID')
        self.location = location or os.getenv('LOCATION')
        self.credentials = self._authenticate_vertex_ai()
        vertexai.init(project=self.project, location=self.location, credentials=self.credentials)
        self._initialized = True
        logger.info('VertexAIService initialized')

    def _authenticate_vertex_ai(self):
        """
        Authenticates using the JSON key from a file or environment variable.

        :return: Google Credentials object
        :raises ValueError: if neither a key file nor key content is available
        """
        if self.json_key_path and os.path.isfile(self.json_key_path):
            # Authenticate directly from the key file on disk.
            return service_account.Credentials.from_service_account_file(self.json_key_path)
        if not self.json_key_env_var:
            # BUG FIX: the original raised with the undefined local name
            # `json_key_env_var`, which produced a NameError instead of the
            # intended ValueError.
            raise ValueError("Environment variable GOOGLE_VERTEX_KEY is not set.")
        # The credentials API loads from a file path, so persist the raw key
        # content to a temporary file first.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file.write(self.json_key_env_var.encode('utf-8'))
            temp_file_path = temp_file.name
        try:
            return service_account.Credentials.from_service_account_file(temp_file_path)
        finally:
            # BUG FIX: always remove the temp key file, even when credential
            # creation raises (the original leaked it on failure).
            os.remove(temp_file_path)

    def process_image(self, input_image64, model_name, prompt, system="You are receipt recognizer", temperature=0.0):
        """
        Processes the image using a Vertex AI model, retrying on quota errors.

        :param input_image64: Base64-encoded WebP image data
        :param model_name: Name of the model in Vertex AI
        :param prompt: Text prompt to guide the model
        :param system: System instruction for the model
        :param temperature: Temperature for controlling randomness
        :return: Tuple of (JSON string of the recognized receipt, dict with the
            model input ``system`` and ``prompt``)
        :raises AiServiceException: RETAKE_PHOTO on missing/unparseable/invalid
            input or non-quota errors; RETRY_FETCH when all quota retries fail
        """
        if input_image64 is None:
            raise AiServiceException(AiServiceError.RETAKE_PHOTO, "No objects detected.")
        start_time = time.time()
        max_retries = 5
        retries = 0
        while retries < max_retries:
            try:
                # Load the model with the system instruction.
                model = GenerativeModel(model_name, system_instruction=[system])
                # Prepare the image part.
                image_part = Part.from_data(
                    mime_type="image/webp",
                    data=input_image64
                )
                # Force a JSON response from the model.
                generation_config = {
                    "max_output_tokens": 8192,
                    "temperature": temperature,
                    "response_mime_type": "application/json"
                }
                # Disable safety blocking for all categories — receipts are
                # benign and blocking would break recognition.
                safety_settings = [
                    SafetySetting(
                        category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                        threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                    ),
                    SafetySetting(
                        category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                        threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                    ),
                    SafetySetting(
                        category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                        threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                    ),
                    SafetySetting(
                        category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
                        threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE
                    ),
                ]
                # Generate content using the model.
                response = model.generate_content([image_part, prompt],
                                                  generation_config=generation_config,
                                                  safety_settings=safety_settings)
                end_time = time.time()
                logger.info(f"Recognition spent {end_time - start_time:.2f} seconds.")
                logger.debug(f"Original response: {response}")
                try:
                    json_content = json.loads(response.text)
                    formatted_json = json.dumps(json_content, indent=4)
                    logger.debug(formatted_json)
                except json.JSONDecodeError:
                    error_message = "The receipt could not be recognized. Please retake the photo."
                    logger.error(error_message)
                    raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message)
                if not self._validate_receipt_data(json_content):
                    error_message = "The receipt is empty or contains no valid items. Please ensure the receipt is correctly scanned and try again"
                    logger.error(error_message)
                    raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message)
                # Post-process: derived totals, discounts and rounding items
                # (helpers presumably provided by ImageProcessingInterface —
                # they are not defined in this file).
                json_content = self._add_total_price(json_content)
                json_content = self._add_discount_item(json_content)
                json_content = self._add_rounding_item(json_content)
                # Attach usage/timing metadata for observability.
                json_content['input_tokens'] = response.usage_metadata.prompt_token_count
                json_content['output_tokens'] = response.usage_metadata.candidates_token_count
                json_content['total_tokens'] = response.usage_metadata.total_token_count
                json_content['time'] = end_time - start_time
                model_input = {
                    "system": system,
                    "prompt": prompt
                }
                return json.dumps(json_content, indent=4), model_input
            except AiServiceException:
                # Already classified — propagate unchanged (don't re-wrap).
                raise
            except Exception as error:
                if "429" in str(error):
                    # Quota exceeded: exponential backoff with jitter.
                    retries += 1
                    wait_time = 2 ** retries + random.uniform(0, 1)
                    logger.warning(
                        f"Quota exceeded. Retrying in {wait_time:.2f} seconds... (Attempt {retries}/{max_retries})")
                    time.sleep(wait_time)
                else:
                    error_message = f"An error occurred during image processing: {error}"
                    logger.error(error_message)
                    raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message)
        error_message = f"Failed after {max_retries} attempts due to quota issues."
        logger.error(error_message)
        raise AiServiceException(AiServiceError.RETRY_FETCH, error_message)
# Example usage
if __name__ == '__main__':
    # Project and model details.
    project = "igneous-spanner-441609-h6"
    location = "us-central1"
    model_name = "gemini-1.5-flash"
    key_path = None  # e.g. './secrets/GOOGLE_VERTEX_AI_KEY_435817.json'

    # Initialize the singleton client (falls back to env-var credentials
    # when key_path is None).
    client = VertexAIService(json_key_path=key_path, project=project, location=location)

    # Encode the sample receipt image and load the recognition prompt.
    # (Removed a dead `prompt = "Read the text"` assignment that was
    # immediately overwritten by the file read below.)
    image_path = "./examples_sl/fatlouis.png"
    input_image64 = encode_image_to_webp_base64(image_path)
    system = "You are receipt recognizer"
    with open('./common/prompt_v1.txt', 'r', encoding='utf-8') as file:
        prompt = file.read()

    result_img, model_input = client.process_image(input_image64, model_name, prompt, system, 0.0)
    # NOTE(review): unicode_escape round-trip is lossy for non-Latin-1 text;
    # kept as-is because downstream log formatting may depend on it — confirm.
    decoded_string = result_img.encode('utf-8').decode('unicode_escape')
    logger.info(decoded_string)