# ReceiptSplitAI / vertex_ai_service.py
# valentynliubchenko
# fixed logger
# 60cf71d
import json
import os
import random
import tempfile
import time
import vertexai
from google.oauth2 import service_account
from loguru import logger
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import Part, SafetySetting
from common.enum.ai_service_error import AiServiceError
from common.exceptions import AiServiceException
from common.utils import encode_image_to_webp_base64
from image_processing_interface import ImageProcessingInterface
class VertexAIService(ImageProcessingInterface):
    """Process-wide singleton that sends receipt images to a Vertex AI generative model.

    ``__new__`` caches the first instance in ``_instance`` and ``__init__`` only
    runs its body once, so constructor arguments passed on later calls are ignored.

    The helpers ``_validate_receipt_data``, ``_add_total_price``,
    ``_add_discount_item`` and ``_add_rounding_item`` used by ``process_image``
    are not defined in this class; presumably they are inherited from
    ``ImageProcessingInterface`` -- confirm against that module.
    """

    # The single shared instance, created lazily by __new__.
    _instance = None

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        # Constructor arguments are consumed by __init__, not here.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, json_key_path=None, json_key_env_var=None, project=None, location=None):
        """
        Initializes the Vertex AI client (only on the first construction).

        :param json_key_path: Path to the service-account JSON file; falls back to
            the GOOGLE_VERTEX_KEY_PATH environment variable.
        :param json_key_env_var: Used as the JSON key *content* when given; otherwise
            the content of the GOOGLE_VERTEX_KEY environment variable is used.
            NOTE(review): the parameter name suggests an environment variable
            name, but the value is never looked up in the environment -- confirm
            the intended contract with callers.
        :param project: Project id; falls back to the PROJECT_ID environment variable.
        :param location: Location; falls back to the LOCATION environment variable.
        :raises ValueError: If no usable key file or key content is available.
        """
        if hasattr(self, "_initialized"):
            # Singleton already configured; ignore the new arguments.
            return

        self.json_key_path = json_key_path or os.getenv('GOOGLE_VERTEX_KEY_PATH')
        self.json_key_env_var = json_key_env_var or os.getenv('GOOGLE_VERTEX_KEY')
        self.project = project or os.getenv('PROJECT_ID')
        self.location = location or os.getenv('LOCATION')

        self.credentials = self._authenticate_vertex_ai()
        vertexai.init(project=self.project, location=self.location, credentials=self.credentials)
        self._initialized = True
        logger.info('VertexAIService initialized')

    def _authenticate_vertex_ai(self):
        """
        Authenticates using the JSON key from a file or from in-memory key content.

        The key file wins when ``json_key_path`` points at an existing file;
        otherwise the JSON text held in ``json_key_env_var`` is parsed directly.

        :return: Google Credentials object
        :raises ValueError: If neither a key file nor key content is available,
            or if the key content is not valid JSON.
        """
        if self.json_key_path and os.path.isfile(self.json_key_path):
            return service_account.Credentials.from_service_account_file(self.json_key_path)

        if not self.json_key_env_var:
            raise ValueError(
                "No service account key available: provide json_key_path / json_key_env_var "
                "or set GOOGLE_VERTEX_KEY_PATH / GOOGLE_VERTEX_KEY."
            )

        # Build the credentials straight from the parsed JSON so the secret is
        # never written to disk (and cannot be left behind if loading fails).
        key_info = json.loads(self.json_key_env_var)
        return service_account.Credentials.from_service_account_info(key_info)

    @staticmethod
    def _permissive_safety_settings():
        """Build safety settings that disable blocking for every category used here."""
        categories = (
            SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
            SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
            SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
            SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        )
        return [
            SafetySetting(
                category=category,
                threshold=SafetySetting.HarmBlockThreshold.BLOCK_NONE,
            )
            for category in categories
        ]

    def _request_recognition(self, input_image64, model_name, prompt, system, temperature):
        """Send one image + prompt request to the model and return its raw response."""
        model = GenerativeModel(model_name, system_instruction=[system])
        image_part = Part.from_data(
            mime_type="image/webp",
            data=input_image64,
        )
        generation_config = {
            "max_output_tokens": 8192,
            "temperature": temperature,
            "response_mime_type": "application/json",
        }
        return model.generate_content(
            [image_part, prompt],
            generation_config=generation_config,
            safety_settings=self._permissive_safety_settings(),
        )

    def _build_receipt_payload(self, response, elapsed):
        """Parse, validate and enrich the model response; return it as a dict."""
        try:
            json_content = json.loads(response.text)
            logger.debug(json.dumps(json_content, indent=4))
        except json.JSONDecodeError as decode_error:
            error_message = "The receipt could not be recognized. Please retake the photo."
            logger.error(error_message)
            raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message) from decode_error

        if not self._validate_receipt_data(json_content):
            error_message = "The receipt is empty or contains no valid items. Please ensure the receipt is correctly scanned and try again"
            logger.error(error_message)
            raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message)

        json_content = self._add_total_price(json_content)
        json_content = self._add_discount_item(json_content)
        json_content = self._add_rounding_item(json_content)

        usage = response.usage_metadata
        json_content['input_tokens'] = usage.prompt_token_count
        json_content['output_tokens'] = usage.candidates_token_count
        json_content['total_tokens'] = usage.total_token_count
        json_content['time'] = elapsed
        return json_content

    def process_image(self, input_image64, model_name, prompt, system="You are receipt recognizer", temperature=0.0):
        """
        Processes the image using Vertex AI model.

        Errors whose text contains "429" are treated as quota errors and retried
        with exponential back-off plus jitter, up to five attempts in total.

        :param input_image64: Image payload handed to ``Part.from_data`` as
            ``image/webp`` data. The original docs call it a base64 encoded
            string -- TODO confirm the exact form expected by the SDK.
        :param model_name: Name of the model in Vertex AI
        :param prompt: Text prompt to guide the model
        :param system: System instruction given to the model
        :param temperature: Temperature for controlling randomness
        :return: Tuple ``(json_string, model_input)`` where ``json_string`` is the
            enriched receipt JSON (indent=4) and ``model_input`` is a dict with
            the ``system`` and ``prompt`` that were used.
        :raises AiServiceException: RETAKE_PHOTO when there is no image, the
            response is not valid JSON, validation fails or a non-quota error
            occurs; RETRY_FETCH when every attempt hit a quota error.
        """
        if input_image64 is None:
            raise AiServiceException(AiServiceError.RETAKE_PHOTO, "No objects detected.")

        start_time = time.time()
        max_retries = 5
        retries = 0
        while retries < max_retries:
            try:
                response = self._request_recognition(input_image64, model_name, prompt, system, temperature)
                elapsed = time.time() - start_time
                logger.info(f"Recognition spent {elapsed:.2f} seconds.")
                logger.debug(f"Original response: {response}")

                json_content = self._build_receipt_payload(response, elapsed)
                model_input = {
                    "system": system,
                    "prompt": prompt,
                }
                return json.dumps(json_content, indent=4), model_input
            except AiServiceException:
                # Already a domain error with the right code; pass it through.
                raise
            except Exception as error:
                if "429" not in str(error):
                    error_message = f"An error occurred during image processing: {error}"
                    logger.error(error_message)
                    raise AiServiceException(AiServiceError.RETAKE_PHOTO, error_message) from error

                retries += 1
                if retries >= max_retries:
                    # Out of attempts: do not sleep through another back-off
                    # interval just to report the failure afterwards.
                    break
                wait_time = 2 ** retries + random.uniform(0, 1)
                logger.warning(
                    f"Quota exceeded. Retrying in {wait_time:.2f} seconds... (Attempt {retries}/{max_retries})")
                time.sleep(wait_time)

        error_message = f"Failed after {max_retries} attempts due to quota issues."
        logger.error(error_message)
        raise AiServiceException(AiServiceError.RETRY_FETCH, error_message)
# Example usage: run one receipt image through the service and log the result.
if __name__ == '__main__':
    # Project and model details
    # project = "receiptsai-435817"
    project = "igneous-spanner-441609-h6"
    location = "us-central1"
    model_name = "gemini-1.5-flash"
    # With no explicit path the service falls back to the GOOGLE_VERTEX_KEY_PATH /
    # GOOGLE_VERTEX_KEY environment variables.
    key_path = None
    # key_path = './secrets/GOOGLE_VERTEX_AI_KEY_435817.json'

    # Initialize the client and generate content
    client = VertexAIService(json_key_path=key_path, project=project, location=location)

    # Image processing
    image_path = "./examples_sl/fatlouis.png"
    input_image64 = encode_image_to_webp_base64(image_path)
    system = "You are receipt recognizer"
    with open('./common/prompt_v1.txt', 'r', encoding='utf-8') as file:
        prompt = file.read()

    result_img, model_input = client.process_image(input_image64, model_name, prompt, system, 0.0)

    # Re-serialize with ensure_ascii=False to show non-ASCII text readably.
    # Decoding the JSON text with 'unicode_escape' would also un-escape quotes
    # and newlines inside strings and split non-BMP characters into lone
    # surrogates, so the logged text would no longer be valid JSON.
    readable_json = json.dumps(json.loads(result_img), ensure_ascii=False, indent=4)
    logger.info(readable_json)