# ReceiptSplitAI / app.py
# valentynliubchenko
# added gemini-2.0-flash-thinking-exp-01-21
# cfae62c
import json
import os
from datetime import datetime
import gradio as gr
from PIL import Image
from dotenv import load_dotenv
from google_drive_client import GoogleDriveClient
from openai_service import OpenAIService
from qr_retriever import get_receipt_by_qr
from utils import read_prompt_from_file, process_receipt_json, save_to_excel, \
encode_image_to_webp_base64
from vertex_ai_service import VertexAIService
# Pull variables from a local .env file into the environment before they are
# read below.
load_dotenv()

# Setting COLLECTION_DATA_VERSION=True selects the reduced build; any other
# value (or no value at all) enables the full version.
isFullVersion = os.getenv("COLLECTION_DATA_VERSION") != "True"

_GEMINI_MODELS = [
    "gemini-1.5-flash",
    "gemini-1.5-pro",
    "gemini-flash-experimental",
    "gemini-pro-experimental",
    "gemini-2.0-flash-exp",
    "gemini-2.0-flash-thinking-exp-01-21",
]
_OPENAI_MODELS = ["gpt-4o-mini", "gpt-4o"]

# The GPT models are only offered in the full version; QR processing is
# always the last choice.
model_names = _GEMINI_MODELS + (_OPENAI_MODELS if isFullVersion else []) + ["QR-processing"]
prompt_names = ["prompt_v1", "prompt_v2", "prompt_v3"]


def _list_examples(folder):
    """Return gr.Examples rows (one image path per row) for the files in *folder*.

    Entries are sorted because os.listdir() returns names in arbitrary order,
    and only regular files are kept so sub-directories never show up as
    example images.
    """
    return [
        [f"./{folder}/{name}"]
        for name in sorted(os.listdir(folder))
        if os.path.isfile(os.path.join(folder, name))
    ]


example_list_sl = _list_examples("examples_sl")
example_list_ua = _list_examples("examples_ua")
example_list_us = _list_examples("examples_us")
example_list_canada = _list_examples("examples_canada")
example_france = _list_examples("examples_france")

# Initial values for the prompt and system-instruction text boxes.
prompt_default = read_prompt_from_file("common/prompt_v1.txt")
system_instruction = read_prompt_from_file("system_instruction.txt")
def process_image(input_image, model_name, prompt_name, temperatura, system_instruction=None, current_prompt_text=None):
    """Recognize a receipt image and build the values for the ten UI outputs.

    Args:
        input_image: Path of the uploaded image, or None when nothing is loaded.
        model_name: Selected model. Names starting with "QR" use QR lookup,
            names starting with "gpt" go to the OpenAI client, everything
            else goes to the Vertex AI client.
        prompt_name: Selected prompt name; only logged here. None is treated
            as "prompt_v1".
        temperatura: Sampling temperature forwarded to the model clients.
        system_instruction: System instruction text; None becomes "".
        current_prompt_text: Prompt text sent to the model. None falls back
            to the module-level ``prompt_default``.

    Returns:
        A 10-tuple: model name, result JSON text (or an error message), store
        info, items table, tax table, message, three updates for the
        "Save as ..." buttons, and an empty string for the file-links box.
    """

    def save_buttons(enabled):
        # One fresh update per "Save as ..." button.
        return tuple(gr.update(interactive=enabled) for _ in range(3))

    if system_instruction is None:
        system_instruction = ""
    if input_image is None:
        return (model_name, "Image not found. Load image ", "", [], [], "", *save_buttons(False), "")
    if prompt_name is None:
        prompt_name = "prompt_v1"
    # The original re-tested prompt_name here, which could never be None at
    # this point, so a missing prompt text was passed to the model as None.
    if current_prompt_text is None:
        current_prompt_text = prompt_default
    prompt = current_prompt_text

    print("file name:", input_image)
    print("model_name:", model_name)
    print("prompt_name:", prompt_name)
    print("Temperatura:", temperatura)
    base64_image = encode_image_to_webp_base64(input_image)
    try:
        if model_name.startswith("QR"):
            try:
                original_json, parsed_result = get_receipt_by_qr(input_image)
            except Exception as e:
                print(e)
                return (model_name, "Error get_receipt_by_qr", "", [], [], "", *save_buttons(False), "")
            print("original_json", original_json)
            print("receipt", parsed_result)
            if parsed_result:
                # Replace missing values so every field is displayable.
                parsed_result = clean_value(parsed_result)
                parsed_result["sub_total_amount"] = "unknown"
                for key, value in parsed_result.items():
                    print(f"Key: {key}, Value: {value}")
        elif model_name.startswith("gpt"):
            result, _ = open_ai_client.process_image(base64_image, model_name, prompt, system_instruction,
                                                     temperatura)
            parsed_result = json.loads(result)
        else:
            result, _ = vertex_ai_client.process_image(base64_image, model_name, prompt, system_instruction,
                                                       temperatura)
            parsed_result = json.loads(result)
        # A falsy QR result (e.g. None) fails here and is reported by the
        # generic handler below.
        parsed_result['file_name'] = os.path.basename(input_image)
        result = json.dumps(parsed_result, ensure_ascii=False, indent=4)
        print(result)
    except Exception as e:
        print(f"Exception occurred: {e}")
        result = json.dumps({"error": "Error processing: Check prompt or images"})
        # NOTE(review): the save buttons stay enabled on this path, presumably
        # so a failed recognition can still be flagged - confirm.
        return (model_name, result, "", "", "", "", *save_buttons(True), "")

    try:
        store_info, items_table, taxs_table, message = process_receipt_json(result)
        print(store_info)
        print(items_table)
    except Exception as e:
        print(f"Exception occurred: {e}")
        result = json.dumps({"error": "process_receipt_json"})
        return (model_name, result, "", "", "", "", *save_buttons(False), "")
    return (model_name, result, store_info, items_table, taxs_table, message, *save_buttons(True), "")
def clean_value(value):
    """Return *value* with every ``None`` replaced by the string "unknown".

    Lists and dicts are rebuilt recursively (dict keys are kept as they are);
    any other value is returned unchanged.
    """
    if value is None:
        return "unknown"
    if isinstance(value, dict):
        return {key: clean_value(item) for key, item in value.items()}
    if isinstance(value, list):
        return list(map(clean_value, value))
    return value
def save_flag_data(save_type, image, model_name, prompt_name, temperatura, current_prompt_text, model_output,
                   json_output,
                   store_info_output, items_list, comments_output, system_instruction,
                   flagging_dir="custom_flagged_data"):
    """Store a flagged recognition result locally and upload it to Google Drive.

    A copy of the image, a JSON record and an Excel export are written to
    *flagging_dir* under a shared, timestamped base name and then uploaded.
    Errors are printed rather than raised, so the UI always gets a result.

    Args:
        save_type: Quality label chosen by the user; stored in the JSON record.
        image: Path of the receipt image.
        model_name: Model name; part of the file name and the JSON record.
        prompt_name: Prompt name; part of the file name and the JSON record.
        temperatura: Unused; kept for interface compatibility.
        current_prompt_text: Prompt text stored in the JSON record.
        model_output: Unused; kept for interface compatibility.
        json_output: Recognition result text; stored and exported to Excel.
        store_info_output: Unused; kept for interface compatibility.
        items_list: Unused; kept for interface compatibility.
        comments_output: User comments stored in the JSON record.
        system_instruction: System instruction stored in the JSON record.
        flagging_dir: Directory that receives the local copies.

    Returns:
        A 4-tuple: three updates disabling the save buttons, and a text with
        the uploaded file links (``None`` for anything that was not uploaded).
    """
    drive_folder_id = '10qtum6ykbGTyu7vvw59i3h1XSY3-lRpo'
    save_button_update = gr.update(interactive=False)
    image_link, json_link, excel_link = None, None, None

    def build_result():
        # Reads the link variables at call time, so it is valid on every exit.
        links = f"Image: {image_link}\nJSON: {json_link}\nExcel: {excel_link} \n shared lofder: https://drive.google.com/drive/folders/{drive_folder_id}?usp=drive_link \n"
        return save_button_update, save_button_update, save_button_update, links

    try:
        # Without this the very first save fails because the directory is missing.
        os.makedirs(flagging_dir, exist_ok=True)

        # List files in the directory (diagnostic output only).
        try:
            files = [f for f in os.listdir(flagging_dir) if os.path.isfile(os.path.join(flagging_dir, f))]
            if files:
                print("Files in directory:", flagging_dir)
                for file in files:
                    print(file)
            else:
                print(f"No files found in directory: {flagging_dir}")
        except Exception as e:
            print(f"Error listing files in directory: {e}")

        image_file_path = image
        print("save_type:", save_type)
        print("Image File Path:", image)
        print("prompt_name:", prompt_name)
        print("Model Name:", model_name)
        print("Result as JSON:", json_output)
        print("comments:", comments_output)
        print("system_instruction:", system_instruction)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        original_filename = os.path.basename(image_file_path)
        filename, file_extension = os.path.splitext(original_filename)
        base_filename = f"{filename}_{model_name}_{prompt_name}_{timestamp}"

        # Save image; context managers close the underlying file handles.
        image_save_path = os.path.join(flagging_dir, f"{base_filename}{file_extension}")
        with Image.open(image_file_path) as source_image:
            source_image.save(image_save_path)
        if not os.path.exists(image_save_path):
            print(f"Failed to save image at: {image_save_path}")
            # Keep the 4-value shape the button handlers expect.
            return build_result()
        with Image.open(image_save_path) as saved_image:
            image_size = saved_image.size
        print(f"Image saved at: {image_save_path}, Size: {image_size}")

        # Save result as JSON
        json_file_path = os.path.join(flagging_dir, f"{base_filename}.json")
        data_to_save = {
            "image_name": f"{base_filename}{file_extension}",
            "prompt_name": prompt_name,
            "system_instruction": system_instruction,
            "prompt": current_prompt_text,
            "model_name": model_name,
            "result_json": json_output,
            "comments": comments_output,
            "save_type": save_type
        }
        data_to_save_encode = json.dumps(data_to_save, ensure_ascii=False, indent=4)
        print("data_to_save_encode: ", data_to_save_encode)
        with open(json_file_path, 'w', encoding='utf-8') as json_file:
            json_file.write(data_to_save_encode)

        # The Excel export is best effort; a failure must not block the upload.
        excel_file_path = os.path.join(flagging_dir, f"{base_filename}.xlsx")
        try:
            save_to_excel(json_output, excel_file_path, image_file_path)
        except Exception as e:
            print(f"Error while saving to excel: {e}")

        # Upload files to Google Drive
        google_drive_client_current = GoogleDriveClient(json_key_path='secrets/GOOGLE_SERVICE_ACCOUNT_KEY.json')
        if google_drive_client_current:
            try:
                image_link = google_drive_client_current.upload_file(image_save_path, drive_folder_id)
                json_link = google_drive_client_current.upload_file(json_file_path, drive_folder_id)
                excel_link = google_drive_client_current.upload_file(excel_file_path, drive_folder_id)
                print(f"Image uploaded to Google Drive. Link: {image_link}")
                print(f"JSON file uploaded to Google Drive. Link: {json_link}")
                print(f"Excel file uploaded to Google Drive. Link: {excel_link}")
            except Exception as e:
                print(f"Error uploading files to Google Drive: {e}")
        else:
            print("Error google_drive_client does not available")
    except Exception as e:
        print(f"Error while saving flag data: {e}")
    return build_result()
def update_prompt_from_radio(prompt_name):
    """Return the prompt text for the selected prompt name.

    The text is read from ``common/<prompt_name>.txt`` for the three known
    prompt names; any other value falls back to ``common/prompt_v1.txt``.
    """
    for known_name in ("prompt_v1", "prompt_v2", "prompt_v3"):
        if prompt_name == known_name:
            return read_prompt_from_file(f"common/{known_name}.txt")
    return read_prompt_from_file("common/prompt_v1.txt")
def _read_openai_key(path):
    """Return the stripped contents of *path*, or None if it is missing or unreadable."""
    if not os.path.exists(path):
        return None
    try:
        with open(path, 'r') as key_file:
            return key_file.read().strip()
    except Exception as e:
        print(f"Error reading file: {e}")
        return None


# Both clients are created without an explicit key file. Earlier variants
# passed one instead:
#   GoogleDriveClient(json_key_path='secrets/GOOGLE_SERVICE_DRIVE_KEY_435817.json')
#   VertexAIService(json_key_path='secrets/GOOGLE_VERTEX_AI_KEY_435817.json')
google_drive_client = GoogleDriveClient()
vertex_ai_client = VertexAIService()

# The OpenAI key file is optional; without it the service receives api_key=None.
key_file_path = 'secrets/OPENAI_AI_KEY.txt'
key = _read_openai_key(key_file_path)
open_ai_client = OpenAIService(api_key=key)
# Gradio UI: inputs in the left column, recognition results in the right one,
# example galleries underneath.
with gr.Blocks() as iface:
    gr.Markdown("# ReceiptAI")
    gr.Markdown("ReceiptAI")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath")
            model_radio = gr.Radio(model_names, label="Choose model/QR-processing(Slovakia)", value=model_names[0])
            # The tuning controls below are hidden outside the full version.
            prompt_radio = gr.Radio(prompt_names, label="Choose prompt", value=prompt_names[0],
                                    visible=isFullVersion)
            temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperatura", value=0.0,
                                           visible=isFullVersion)
            # Rebinds the module-level name: from here on `system_instruction`
            # is the text box, pre-filled with the text loaded earlier.
            system_instruction = gr.Textbox(label="System Instruction", visible=isFullVersion,
                                            value=system_instruction)
            custom_prompt = gr.Textbox(label="prompt text", visible=isFullVersion, value=prompt_default)
            with gr.Row():
                submit_button = gr.Button("Receipt recognizing ")
        with gr.Column(scale=2):
            model_output = gr.Textbox(label="MODEL/QR-processing(Slovakia)", lines=1, interactive=isFullVersion)
            json_output = gr.Textbox(label="Result as json")
            store_info_output = gr.Textbox(label="Store Information", lines=4)
            items_list = gr.Dataframe(
                headers=["Item Name", "Category", "Unit Price", "Quantity", "Unit", "Total Price", "Discount",
                         "Item price with tax", "Grand Total"],
                label="Items List")
            taxes_list = gr.Dataframe(
                headers=["Tax Name", "%", "tax from amount", "tax", "total", "tax included"],
                label="Tax List")
            comments_output = gr.Textbox(label="Comments", visible=True, lines=4, interactive=True)
            # The save buttons start disabled; the handlers' returned updates
            # toggle them.
            with gr.Row():
                save_good_button = gr.Button(value="Save as Good", interactive=False)
                save_average_button = gr.Button(value="Save as Average", interactive=False)
                save_poor_button = gr.Button(value="Save as Poor", interactive=False)
            file_links_output = gr.Textbox(label="File Links", interactive=False, visible=True)

    submit_button.click(
        fn=process_image,
        inputs=[image_input, model_radio, prompt_radio, temperature_slider, system_instruction, custom_prompt],
        outputs=[model_output, json_output, store_info_output, items_list, taxes_list, comments_output,
                 save_good_button, save_average_button, save_poor_button, file_links_output],
    )

    # Values handed to every save handler, in the order of
    # save_flag_data_wrapper's parameters after `save_type`.
    common_inputs = [image_input, model_radio, prompt_radio, temperature_slider, custom_prompt, model_output,
                     json_output, store_info_output, items_list, comments_output, system_instruction]

    def save_flag_data_wrapper(save_type, image, model_name, prompt_name, temperatura, custom_prompt, model_output,
                               json_output, store_info_output, items_list, comments_output, system_instruction):
        """Forward the current UI values to save_flag_data and return its four outputs.

        Gradio passes component values (``image`` is a file path string), so
        they are handed on unchanged.
        """
        return save_flag_data(
            save_type, image, model_name, prompt_name, temperatura, custom_prompt, model_output, json_output,
            store_info_output, items_list, comments_output, system_instruction
        )

    # All three buttons share inputs and outputs; only the quality label differs.
    save_good_button.click(
        fn=lambda *args: save_flag_data_wrapper("Good", *args),
        inputs=common_inputs,
        outputs=[save_good_button, save_average_button, save_poor_button, file_links_output],
    )
    save_average_button.click(
        fn=lambda *args: save_flag_data_wrapper("Average", *args),
        inputs=common_inputs,
        outputs=[save_good_button, save_average_button, save_poor_button, file_links_output],
    )
    save_poor_button.click(
        fn=lambda *args: save_flag_data_wrapper("Poor", *args),
        inputs=common_inputs,
        outputs=[save_good_button, save_average_button, save_poor_button, file_links_output],
    )

    # Selecting another prompt loads its text into the prompt text box.
    prompt_radio.change(fn=update_prompt_from_radio, inputs=[prompt_radio], outputs=[custom_prompt])

    # Slovak examples are always shown; the other galleries only in the full version.
    _example_sets = [(example_list_sl, "Examples for Slovakia")]
    if isFullVersion:
        _example_sets += [
            (example_list_ua, "Examples for Ukrainian"),
            (example_list_us, "Examples for US"),
            (example_list_canada, "Examples for Canada"),
            (example_france, "Examples for France"),
        ]
    for _examples, _label in _example_sets:
        gr.Examples(examples=_examples,
                    inputs=[image_input, model_radio, prompt_radio, temperature_slider, custom_prompt],
                    label=_label)

# Listen on all interfaces on port 7860.
iface.launch(server_name="0.0.0.0", server_port=7860)