Spaces:
Running
Running
| import json | |
| import os | |
| from datetime import datetime | |
| import gradio as gr | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
| from google_drive_client import GoogleDriveClient | |
| from openai_service import OpenAIService | |
| from qr_retriever import get_receipt_by_qr | |
| from utils import read_prompt_from_file, process_receipt_json, save_to_excel, \ | |
| encode_image_to_webp_base64 | |
| from vertex_ai_service import VertexAIService | |
# Pull settings from a local .env file into the process environment.
load_dotenv()

# Any value other than the exact string "True" selects the full version.
isFullVersion = os.getenv("COLLECTION_DATA_VERSION") != "True"

# The GPT entries are offered only in the full version; "QR-processing" is always last.
model_names = [
    "gemini-1.5-flash",
    "gemini-1.5-pro",
    "gemini-flash-experimental",
    "gemini-pro-experimental",
    "gemini-2.0-flash-exp",
    "gemini-2.0-flash-thinking-exp-01-21",
    *(["gpt-4o-mini", "gpt-4o"] if isFullVersion else []),
    "QR-processing",
]

prompt_names = ["prompt_v1", "prompt_v2", "prompt_v3"]

# One single-element row per file found in each examples directory.
# example_list = [["./examples/" + example] for example in os.listdir("examples")]
example_list_sl = [[f"./examples_sl/{name}"] for name in os.listdir("examples_sl")]
example_list_ua = [[f"./examples_ua/{name}"] for name in os.listdir("examples_ua")]
example_list_us = [[f"./examples_us/{name}"] for name in os.listdir("examples_us")]
example_list_canada = [[f"./examples_canada/{name}"] for name in os.listdir("examples_canada")]
example_france = [[f"./examples_france/{name}"] for name in os.listdir("examples_france")]

# Initial texts for the prompt and system-instruction text boxes.
prompt_default = read_prompt_from_file("common/prompt_v1.txt")
system_instruction = read_prompt_from_file("system_instruction.txt")
def _save_button_updates(enabled):
    """Return three separate ``gr.update`` payloads setting ``interactive`` for the save buttons."""
    return tuple(gr.update(interactive=enabled) for _ in range(3))


def _failure_outputs(model_name, text, buttons_enabled=False):
    """Build the ten-value handler result for a failed run.

    The store-info and message fields are empty strings and both table
    fields are empty lists, so every failure path hands the same value
    types to the outputs as the missing-image path does.
    """
    return (model_name, text, "", [], [], "", *_save_button_updates(buttons_enabled), "")


def process_image(input_image, model_name, prompt_name, temperatura, system_instruction=None, current_prompt_text=None):
    """Recognise a receipt image with the selected backend and format the result.

    Model names starting with ``"QR"`` are handled by ``get_receipt_by_qr``,
    names starting with ``"gpt"`` by ``open_ai_client``, and everything else
    by ``vertex_ai_client``.

    Args:
        input_image: Path of the receipt image, or ``None`` when nothing is loaded.
        model_name: Name of the selected model / processing mode.
        prompt_name: Name of the selected prompt; only logged. Defaults to
            ``"prompt_v1"`` when ``None``.
        temperatura: Sampling temperature forwarded to the model client.
        system_instruction: System instruction text; ``None`` becomes ``""``.
        current_prompt_text: Prompt text sent to the model. When it is ``None``
            or blank, ``prompt_default`` is used instead.

    Returns:
        A tuple of ten values: model name, result JSON text, store info,
        items table, tax table, message, three ``gr.update`` payloads for the
        save buttons, and an empty string for the last output.
    """
    if system_instruction is None:
        system_instruction = ""
    if input_image is None:
        return _failure_outputs(model_name, "Image not found. Load image ")
    if prompt_name is None:
        prompt_name = "prompt_v1"

    # The prompt text is what is actually sent to the model. The previous
    # fallback re-tested prompt_name after it had already been defaulted, so
    # it could never run and a missing text reached the model as None. The
    # prompt file read that used to sit here was dropped because its result
    # was always overwritten by this text.
    if current_prompt_text is None or not current_prompt_text.strip():
        current_prompt_text = prompt_default
    prompt = current_prompt_text

    print("file name:", input_image)
    print("model_name:", model_name)
    print("prompt_name:", prompt_name)
    print("Temperatura:", temperatura)

    try:
        if model_name.startswith("QR"):
            try:
                original_json, parsed_result = get_receipt_by_qr(input_image)
            except Exception as e:
                print(e)
                return _failure_outputs(model_name, "Error get_receipt_by_qr")
            print("original_json", original_json)
            print("receipt", parsed_result)
            if parsed_result:
                parsed_result = clean_value(parsed_result)
                parsed_result["sub_total_amount"] = "unknown"
                for key, value in parsed_result.items():
                    print(f"Key: {key}, Value: {value}")
        else:
            # Only the model backends need the encoded image. Encoding inside
            # the try block turns an unreadable file into the regular error
            # result instead of an unhandled exception.
            base64_image = encode_image_to_webp_base64(input_image)
            client = open_ai_client if model_name.startswith("gpt") else vertex_ai_client
            result, _model_input = client.process_image(base64_image, model_name, prompt, system_instruction,
                                                        temperatura)
            parsed_result = json.loads(result)
        parsed_result['file_name'] = os.path.basename(input_image)
        result = json.dumps(parsed_result, ensure_ascii=False, indent=4)
        print(result)
    except Exception as e:
        print(f"Exception occurred: {e}")
        result = json.dumps({"error": "Error processing: Check prompt or images"})
        # The save buttons stay enabled here, as before.
        return _failure_outputs(model_name, result, buttons_enabled=True)

    try:
        store_info, items_table, taxs_table, message = process_receipt_json(result)
        print(store_info)
        print(items_table)
    except Exception as e:
        print(f"Exception occurred: {e}")
        result = json.dumps({"error": "process_receipt_json"})
        return _failure_outputs(model_name, result)

    return (model_name, result, store_info, items_table, taxs_table, message,
            *_save_button_updates(True), "")
def clean_value(value):
    """Return a copy of ``value`` with every ``None`` replaced by ``"unknown"``.

    Lists and dicts are walked recursively (dict keys are kept as they are);
    any other value is returned unchanged.
    """
    if value is None:
        return "unknown"
    if isinstance(value, list):
        return list(map(clean_value, value))
    if isinstance(value, dict):
        cleaned = {}
        for name, item in value.items():
            cleaned[name] = clean_value(item)
        return cleaned
    return value
def _flag_links_text(image_link, json_link, excel_link):
    """Format the upload links (or ``None`` placeholders) for the file-links text box."""
    return f"Image: {image_link}\nJSON: {json_link}\nExcel: {excel_link} \n shared lofder: https://drive.google.com/drive/folders/10qtum6ykbGTyu7vvw59i3h1XSY3-lRpo?usp=drive_link \n"


def save_flag_data(save_type, image, model_name, prompt_name, temperatura, current_prompt_text, model_output,
                   json_output,
                   store_info_output, items_list, comments_output, system_instruction,
                   flagging_dir="custom_flagged_data"):
    """Store a rated recognition result locally and upload it to Google Drive.

    The image is re-saved into ``flagging_dir`` together with a JSON record
    and an Excel export; the three files are then uploaded to a fixed Google
    Drive folder. Failures are printed and leave the affected links as ``None``.

    Args:
        save_type: Rating label stored in the JSON record.
        image: Path of the image file that was recognised.
        model_name: Model name; part of the file name and the JSON record.
        prompt_name: Prompt name; part of the file name and the JSON record.
        temperatura: Accepted for signature compatibility; not used here.
        current_prompt_text: Prompt text stored in the JSON record.
        model_output: Accepted for signature compatibility; not used here.
        json_output: Recognition result, stored in the record and passed to
            ``save_to_excel``.
        store_info_output: Accepted for signature compatibility; not used here.
        items_list: Accepted for signature compatibility; not used here.
        comments_output: Reviewer comments stored in the JSON record.
        system_instruction: System instruction stored in the JSON record.
        flagging_dir: Directory that receives the local copies.

    Returns:
        A 4-tuple: the same ``gr.update(interactive=False)`` payload three
        times (one per save button) and the links text.
    """
    save_button_update = gr.update(interactive=False)
    image_link, json_link, excel_link = None, None, None
    try:
        # Create the target directory up front so that the very first save
        # does not fail just because it is missing.
        os.makedirs(flagging_dir, exist_ok=True)

        # List files in the directory (diagnostic output only).
        try:
            files = [f for f in os.listdir(flagging_dir) if os.path.isfile(os.path.join(flagging_dir, f))]
            if files:
                print("Files in directory:", flagging_dir)
                for file in files:
                    print(file)
            else:
                print(f"No files found in directory: {flagging_dir}")
        except Exception as e:
            print(f"Error listing files in directory: {e}")

        image_file_path = image
        print("save_type:", save_type)
        print("Image File Path:", image)
        print("prompt_name:", prompt_name)
        print("Model Name:", model_name)
        print("Result as JSON:", json_output)
        print("comments:", comments_output)
        print("system_instruction:", system_instruction)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        original_filename = os.path.basename(image_file_path)
        filename, file_extension = os.path.splitext(original_filename)
        base_filename = f"{filename}_{model_name}_{prompt_name}_{timestamp}"

        # Save image; the context managers close the file handles that were
        # previously left open.
        image_save_path = os.path.join(flagging_dir, f"{base_filename}{file_extension}")
        with Image.open(image_file_path) as source_image:
            source_image.save(image_save_path)
        if os.path.exists(image_save_path):
            with Image.open(image_save_path) as saved_image:
                image_size = saved_image.size
            print(f"Image saved at: {image_save_path}, Size: {image_size}")
        else:
            print(f"Failed to save image at: {image_save_path}")
            # Return the usual 4-tuple: callers unpack four values, so the
            # bare `0` returned here before made them raise.
            return (save_button_update, save_button_update, save_button_update,
                    _flag_links_text(image_link, json_link, excel_link))

        # Save result as JSON
        json_file_path = os.path.join(flagging_dir, f"{base_filename}.json")
        data_to_save = {
            "image_name": f"{base_filename}{file_extension}",
            "prompt_name": prompt_name,
            "system_instruction": system_instruction,
            "prompt": current_prompt_text,
            "model_name": model_name,
            "result_json": json_output,
            "comments": comments_output,
            "save_type": save_type
        }
        data_to_save_encode = json.dumps(data_to_save, ensure_ascii=False, indent=4)
        print("data_to_save_encode: ", data_to_save_encode)
        with open(json_file_path, 'w', encoding='utf-8') as json_file:
            json_file.write(data_to_save_encode)

        # The Excel export is best-effort: a failure is reported but does not
        # stop the upload of the other files.
        excel_file_path = os.path.join(flagging_dir, f"{base_filename}.xlsx")
        try:
            save_to_excel(json_output, excel_file_path, image_file_path)
        except Exception as e:
            print(f"Error while saving to excel: {e}")

        # Upload files to Google Drive
        google_drive_client_current = GoogleDriveClient(json_key_path='secrets/GOOGLE_SERVICE_ACCOUNT_KEY.json')
        if google_drive_client_current:
            try:
                image_folder_id = '10qtum6ykbGTyu7vvw59i3h1XSY3-lRpo'
                # Uploads run in order; links obtained before a failure are kept.
                image_link = google_drive_client_current.upload_file(image_save_path, image_folder_id)
                json_link = google_drive_client_current.upload_file(json_file_path, image_folder_id)
                excel_link = google_drive_client_current.upload_file(excel_file_path, image_folder_id)
                print(f"Image uploaded to Google Drive. Link: {image_link}")
                print(f"JSON file uploaded to Google Drive. Link: {json_link}")
                print(f"Excel file uploaded to Google Drive. Link: {excel_link}")
            except Exception as e:
                print(f"Error uploading files to Google Drive: {e}")
        else:
            print("Error google_drive_client does not available")
    except Exception as e:
        print(f"Error while saving flag data: {e}")

    links = _flag_links_text(image_link, json_link, excel_link)
    return save_button_update, save_button_update, save_button_update, links
def update_prompt_from_radio(prompt_name):
    """Return the prompt text stored for the selected prompt name.

    ``"prompt_v2"`` and ``"prompt_v3"`` load their own files; ``"prompt_v1"``
    and any other value load ``common/prompt_v1.txt``.
    """
    alternatives = (
        ("prompt_v2", "common/prompt_v2.txt"),
        ("prompt_v3", "common/prompt_v3.txt"),
    )
    for candidate, prompt_path in alternatives:
        if prompt_name == candidate:
            return read_prompt_from_file(prompt_path)
    return read_prompt_from_file("common/prompt_v1.txt")
def _read_key_file(path):
    """Return the stripped content of ``path``; ``None`` if it is missing or unreadable."""
    if not os.path.exists(path):
        return None
    try:
        with open(path, 'r') as handle:
            return handle.read().strip()
    except Exception as e:
        print(f"Error reading file: {e}")
        return None


# Earlier construction with explicit key files, kept for reference:
#google_drive_client = GoogleDriveClient(json_key_path='secrets/GOOGLE_SERVICE_DRIVE_KEY_435817.json')
#vertex_ai_client = VertexAIService(json_key_path='secrets/GOOGLE_VERTEX_AI_KEY_435817.json')
google_drive_client = GoogleDriveClient()
vertex_ai_client = VertexAIService()

# The OpenAI client is created even without a key file; it then receives api_key=None.
key_file_path = 'secrets/OPENAI_AI_KEY.txt'
key = _read_key_file(key_file_path)
open_ai_client = OpenAIService(api_key=key)
# Gradio UI: inputs on the left, recognition results and rating buttons on the right.
# NOTE(review): the container nesting below and the extent of the
# `if isFullVersion:` block were reconstructed from a copy of the file that
# had lost its indentation - confirm against the running layout.
with gr.Blocks() as iface:
    gr.Markdown("# ReceiptAI")
    gr.Markdown("ReceiptAI")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath")
            model_radio = gr.Radio(model_names, label="Choose model/QR-processing(Slovakia)", value=model_names[0])
            prompt_radio = gr.Radio(prompt_names, label="Choose prompt", value=prompt_names[0], visible=isFullVersion)
            temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperatura", value=0.0,
                                           visible=isFullVersion)
            # Named separately so the module-level `system_instruction` text is
            # no longer overwritten by the component object.
            system_instruction_input = gr.Textbox(label="System Instruction", visible=isFullVersion,
                                                  value=system_instruction)
            custom_prompt = gr.Textbox(label="prompt text", visible=isFullVersion, value=prompt_default)
            with gr.Row():
                submit_button = gr.Button("Receipt recognizing ")
        with gr.Column(scale=2):
            model_output = gr.Textbox(label="MODEL/QR-processing(Slovakia)", lines=1, interactive=isFullVersion)
            json_output = gr.Textbox(label="Result as json")
            store_info_output = gr.Textbox(label="Store Information", lines=4)
            items_list = gr.Dataframe(
                headers=["Item Name", "Category", "Unit Price", "Quantity", "Unit", "Total Price", "Discount",
                         "Item price with tax", "Grand Total"],
                label="Items List")
            taxes_list = gr.Dataframe(
                headers=["Tax Name", "%", "tax from amount", "tax", "total", "tax included"],
                label="Tax List")
            comments_output = gr.Textbox(label="Comments", visible=True, lines=4, interactive=True)
            with gr.Row():
                save_good_button = gr.Button(value="Save as Good", interactive=False)
                save_average_button = gr.Button(value="Save as Average", interactive=False)
                save_poor_button = gr.Button(value="Save as Poor", interactive=False)
            file_links_output = gr.Textbox(label="File Links", interactive=False, visible=True)

    # Output order must match the ten values returned by process_image.
    submit_button.click(fn=process_image,
                        inputs=[image_input, model_radio, prompt_radio, temperature_slider, system_instruction_input,
                                custom_prompt],
                        outputs=[model_output, json_output, store_info_output, items_list, taxes_list, comments_output,
                                 save_good_button, save_average_button, save_poor_button, file_links_output])

    # Component order must match the parameters of save_flag_data_wrapper after `save_type`.
    common_inputs = [image_input, model_radio, prompt_radio, temperature_slider, custom_prompt, model_output,
                     json_output, store_info_output, items_list, comments_output, system_instruction_input]
    save_outputs = [save_good_button, save_average_button, save_poor_button, file_links_output]

    def save_flag_data_wrapper(save_type, image, model_name, prompt_name, temperatura, custom_prompt, model_output,
                               json_output, store_info_output, items_list, comments_output, system_instruction):
        """Forward the current component values and the chosen rating to ``save_flag_data``.

        ``image`` is passed on as received; the image input above is
        configured with ``type="filepath"``. Returns the three button
        updates and the links text produced by ``save_flag_data``.
        """
        return save_flag_data(
            save_type, image, model_name, prompt_name, temperatura, custom_prompt, model_output, json_output,
            store_info_output, items_list, comments_output, system_instruction
        )

    # All three buttons share inputs and outputs; only the rating label differs.
    save_good_button.click(
        fn=lambda *args: save_flag_data_wrapper("Good", *args),
        inputs=common_inputs,
        outputs=save_outputs
    )
    save_average_button.click(
        fn=lambda *args: save_flag_data_wrapper("Average", *args),
        inputs=common_inputs,
        outputs=save_outputs
    )
    save_poor_button.click(
        fn=lambda *args: save_flag_data_wrapper("Poor", *args),
        inputs=common_inputs,
        outputs=save_outputs
    )

    # Selecting a prompt name loads that prompt's text into the editable text box.
    prompt_radio.change(fn=update_prompt_from_radio, inputs=[prompt_radio], outputs=[custom_prompt])

    # Each gr.Examples call gets its own copy of the inputs list, as before.
    example_inputs = [image_input, model_radio, prompt_radio, temperature_slider, custom_prompt]
    gr.Examples(examples=example_list_sl,
                inputs=list(example_inputs),
                label="Examples for Slovakia")
    if isFullVersion:
        gr.Examples(examples=example_list_ua,
                    inputs=list(example_inputs),
                    label="Examples for Ukrainian")
        gr.Examples(examples=example_list_us,
                    inputs=list(example_inputs),
                    label="Examples for US")
        gr.Examples(examples=example_list_canada,
                    inputs=list(example_inputs),
                    label="Examples for Canada")
        gr.Examples(examples=example_france,
                    inputs=list(example_inputs),
                    label="Examples for France")

iface.launch(server_name="0.0.0.0", server_port=7860)