Spaces:
Sleeping
Sleeping
| import base64 | |
| import os | |
| import re | |
| from io import BytesIO | |
| from pathlib import Path | |
| import gradio as gr | |
| import pandas as pd | |
| import json | |
| from langchain.schema.output_parser import OutputParserException | |
| from PIL import Image | |
| from openpyxl import load_workbook | |
| from openpyxl.utils import get_column_letter | |
| import categories | |
| from categories import Category | |
| from main import process_image, process_pdf | |
| from forex_python.converter import CurrencyRates | |
# Hugging Face access token for the dataset savers below, read from the
# environment (None when the variable is not set).
HF_TOKEN = os.environ.get("HF_TOKEN")

# HTML wrapper that embeds a PDF in an iframe; ``{0}`` is replaced with the
# base64-encoded document.
PDF_IFRAME = """
<div style="border-radius: 10px; width: 100%; overflow: hidden;">
<iframe
src="data:application/pdf;base64,{0}"
width="100%"
height="400"
type="application/pdf">
</iframe>
</div>"""

# Flagging callback used for results that are shared after a submission.
hf_writer_normal = gr.HuggingFaceDatasetSaver(
    hf_token=HF_TOKEN,
    dataset_name="automatic-reimbursement-tool-demo",
    separate_dirs=False,
)

# Flagging callback used by the "Flag as incorrect" / "Flag as irrelevant"
# buttons.
hf_writer_incorrect = gr.HuggingFaceDatasetSaver(
    hf_token=HF_TOKEN,
    dataset_name="automatic-reimbursement-tool-demo-incorrect",
    separate_dirs=False,
)
| # with open("examples/example1.pdf", "rb") as pdf_file: | |
| # base64_pdf = base64.b64encode(pdf_file.read()) | |
| # example_paths = [] | |
| # current_file_path = None | |
| # def ignore_examples(function): | |
| # def new_function(*args, **kwargs): | |
| # global example_paths, current_file_path | |
| # if current_file_path not in example_paths: | |
| # return function(*args, **kwargs) | |
def display_file(input_files):
    """Show a preview of the uploaded files.

    The first PDF among the uploads is embedded in the HTML preview. When no
    PDF is present, the first upload is opened as an image and shown in the
    image preview. The module-level ``current_file_paths`` list is refreshed
    with the paths of the current uploads.

    Args:
        input_files: Uploaded file objects exposing a ``name`` path attribute,
            or ``None`` / an empty list when the file input was cleared.

    Returns:
        A pair of updates for the PDF (HTML) preview and the image preview.
    """
    global current_file_paths

    # The file input emits None when it is cleared, so the emptiness check has
    # to happen before anything iterates over the uploads.
    if not input_files:
        current_file_paths = []
        return gr.HTML.update(visible=False), gr.Image.update(visible=False)

    current_file_paths = [file.name for file in input_files]

    # Only the first PDF is previewed; the extension check ignores case so
    # that "RECEIPT.PDF" is treated as a PDF too.
    first_pdf = next(
        (path for path in current_file_paths if path.lower().endswith(".pdf")),
        None,
    )
    if first_pdf is not None:
        with open(first_pdf, "rb") as pdf_file:
            pdf_base64 = base64.b64encode(pdf_file.read()).decode()
        return (
            gr.HTML.update(PDF_IFRAME.format(pdf_base64), visible=True),
            gr.Image.update(visible=False),
        )

    # No PDF uploaded: preview the first file as an image.
    image = Image.open(current_file_paths[0])
    return gr.HTML.update(visible=False), gr.Image.update(image, visible=True)
def show_intermediate_outputs(show_intermediate):
    """Return an accordion update that shows or hides the intermediate outputs."""
    visible = True if show_intermediate else False
    return gr.Accordion.update(visible=visible)
def show_share_contact(share_result):
    """Return a textbox update whose visibility follows the share checkbox."""
    contact_update = gr.Textbox.update(visible=share_result)
    return contact_update
def clear_inputs():
    """Return a file-input update that removes any uploaded files."""
    emptied_input = gr.File.update(value=None)
    return emptied_input
def clear_outputs(input_file):
    """Return empty values for the four output components.

    The handler is wired to four outputs (extracted text, category, chatbot,
    table), so it must return four values on every path. Previously it
    returned a bare ``None`` when nothing was uploaded, which does not match
    the number of outputs.

    Args:
        input_file: Current value of the file input. It is accepted for
            interface compatibility and no longer affects the result.

    Returns:
        A 4-tuple of ``None`` values, one per output component.
    """
    return None, None, None, None
def extract_text(input_file):
    """Extract the text of one uploaded receipt.

    Args:
        input_file: Uploaded file object exposing a ``name`` path attribute.

    Returns:
        Whatever ``process_pdf`` / ``process_image`` return when called with
        ``extract_only=True`` (used by callers as the receipt text).

    Raises:
        gr.Error: If no file was provided.
    """
    if not input_file:
        # gr.Error only reaches the user when it is raised; the original code
        # merely constructed the exception and carried on.
        raise gr.Error("Please upload a file to continue!")

    path = Path(input_file.name)
    # Case-insensitive so that "SCAN.PDF" is routed to the PDF pipeline.
    if input_file.name.lower().endswith(".pdf"):
        return process_pdf(path, extract_only=True)
    return process_image(path, extract_only=True)
# Currency code -> symbols and names that identify it in receipt text.
_CURRENCY_MARKERS = {
    'USD': ['$', 'US$', 'US Dollar', 'United States Dollar'],
    'EUR': ['€', 'Euro'],
    'GBP': ['£', 'British Pound', 'Pound Sterling'],
    'JPY': ['¥', 'Japanese Yen'],
    'AUD': ['A$', 'AU$', 'Australian Dollar'],
    'CAD': ['C$', 'CA$', 'Canadian Dollar'],
    'CHF': ['Swiss Franc'],
    'CNY': ['CN¥', 'Chinese Yuan', 'Renminbi'],
    'HKD': ['HK$', 'Hong Kong Dollar'],
    'NZD': ['NZ$', 'New Zealand Dollar'],
    'SEK': ['Swedish Krona'],
    'KRW': ['₩', 'South Korean Won'],
    'SGD': ['S$', 'Singapore Dollar'],
    'NOK': ['Norwegian Krone'],
    'MXN': ['Mexican Peso'],
    'INR': ['₹', 'Indian Rupee'],
    'RUB': ['₽', 'Russian Ruble'],
    'ZAR': ['South African Rand'],
    'BRL': ['R$', 'Brazilian Real'],
}


def find_currency_symbol(text):
    """Detect the currency mentioned in ``text``.

    Markers are tried from longest to shortest so that a specific marker such
    as ``HK$`` or ``CN¥`` wins over the generic ``$`` or ``¥`` it contains.
    Markers of equal length keep the table's order. Currency names (purely
    alphabetic markers) are matched case-insensitively on word boundaries,
    with an optional plural "s"; symbols are matched as plain substrings.

    Args:
        text: Receipt text to scan. ``None`` or an empty string is allowed.

    Returns:
        The three-letter currency code of the first matching marker, or
        ``None`` when nothing matches.
    """
    if not text:
        return None

    candidates = [
        (marker, code)
        for code, markers in _CURRENCY_MARKERS.items()
        for marker in markers
    ]
    # list.sort is stable, so equally long markers stay in table order.
    candidates.sort(key=lambda pair: len(pair[0]), reverse=True)

    for marker, code in candidates:
        if marker.replace(" ", "").isalpha():
            # A currency name: require whole words so that e.g. "European"
            # does not count as "Euro".
            pattern = rf"\b{re.escape(marker)}s?\b"
            if re.search(pattern, text, re.IGNORECASE):
                return code
        elif marker in text:
            return code
    return None
def get_exchange_rate_to_inr(currency):
    """Return the exchange rate from ``currency`` to Indian rupees.

    Args:
        currency: Three-letter currency code, or ``None`` when no currency
            was detected (treated as INR).

    Returns:
        ``1`` for ``'INR'`` or ``None``; otherwise the rate reported by
        ``CurrencyRates.get_rate``, or ``None`` when the lookup fails.
    """
    if currency is None or currency == 'INR':
        return 1

    import logging  # local import: only needed on the failure path below

    # Built only when a lookup is actually needed.
    rates = CurrencyRates()
    try:
        return rates.get_rate(currency, 'INR')
    except Exception:
        # Best effort: callers receive None and decide how to report it, but
        # the cause is logged instead of being discarded silently.
        logging.getLogger(__name__).warning(
            "Could not fetch %s to INR exchange rate", currency, exc_info=True
        )
        return None
def categorize_text(text):
    """Return the category that ``categories.categorize_text`` assigns to ``text``."""
    return categories.categorize_text(text)
def query(category, text):
    """Prompt the category's chain with ``text`` and stream chatbot updates.

    Yields two chatbot updates: first the formatted prompt paired with a
    "Generating..." placeholder, then the same prompt paired with the
    generated answer.
    """
    chain = categories.category_modules[category].chain

    # Step 1: construct the prompt exactly as the chain will see it.
    prompt_value = chain.prompt.format_prompt(
        text=text,
        format_instructions=chain.output_parser.get_format_instructions(),
    )
    messages = prompt_value.messages
    if len(messages) > 1:
        system_part = f"**System:**\n{messages[0].content}"
    else:
        system_part = ""
    question = system_part + f"\n\n**Human:**\n{messages[-1].content}"
    yield gr.Chatbot.update([[question, "Generating..."]])

    # Step 2: generate the response.
    generation_input = {
        "text": text,
        "format_instructions": chain.output_parser.get_format_instructions(),
    }
    result = chain.generate(input_list=[generation_input])
    answer = result.generations[0][0].text
    yield gr.Chatbot.update([[question, answer]])
# Splits a rendered chatbot question into its system and human parts.
PARSING_REGEXP = r"\*\*System:\*\*\n([\s\S]+)\n\n\*\*Human:\*\*\n([\s\S]+)"


def parse(category, chatbot):
    """Parse the last chatbot answer into the category's structured output.

    This is a generator: it yields a status dict first and then the parsed
    information (or a dict describing the parsing failure).

    Args:
        category: A ``Category`` member or the name of one.
        chatbot: Sequence of ``[question, answer]`` pairs; may be empty or
            ``None``.

    Yields:
        ``{"status": ...}`` progress/error dicts, followed by the parser
        output serialised with ``.json()`` (``{}`` when the parser returns
        something falsy), or ``{"details", "output"}`` when the answer could
        not be parsed.
    """
    answers = [exchange[1] for exchange in chatbot or []]
    if not answers:
        # A ``return value`` inside a generator never reaches the consumer,
        # so the status has to be yielded.
        yield {"status": "No responses available"}
        return

    # category_modules is indexed with Category members elsewhere in this
    # file, so a name is converted to its member before the lookup.
    if not isinstance(category, Category):
        if category not in Category.__members__:
            yield {"status": f"Unknown category: {category!r}"}
            return
        category = Category[category]

    chain = categories.category_modules[category].chain
    yield {"status": "Parsing response..."}
    try:
        information = chain.output_parser.parse(answers[-1])
        information = information.json() if information else {}
    except OutputParserException as e:
        information = {
            "details": str(e),
            "output": e.llm_output,
        }
    yield information
def activate_flags():
    """Return two button updates that make both flag buttons clickable."""
    incorrect_update = gr.Button.update(interactive=True)
    irrelevant_update = gr.Button.update(interactive=True)
    return incorrect_update, irrelevant_update
def deactivate_flags():
    """Return two button updates that disable both flag buttons."""
    incorrect_update = gr.Button.update(interactive=False)
    irrelevant_update = gr.Button.update(interactive=False)
    return incorrect_update, irrelevant_update
def flag_if_shared(flag_method):
    """Wrap ``flag_method`` so that it only runs when sharing is enabled.

    The returned callable takes the share flag first, then the request and
    the data to flag; it forwards everything except the share flag.
    """

    def proxy(share_result, request: gr.Request, *args, **kwargs):
        # Nothing is flagged (and None is returned) when sharing is off.
        if not share_result:
            return None
        return flag_method(request, *args, **kwargs)

    return proxy
def save_df_to_excel_with_autowidth(df, filename):
    """Write ``df`` to an Excel file and size each column to its content.

    Args:
        df: DataFrame to export (written without its index).
        filename: Path of the ``.xlsx`` file to create or overwrite.
    """
    # Save the DataFrame first, then reopen the file to adjust the widths.
    df.to_excel(filename, index=False, engine='openpyxl')
    book = load_workbook(filename)
    sheet = book.active

    for column_cells in sheet.columns:
        if not column_cells:
            continue
        # Measure the text form of every value. The original called len() on
        # the raw value, which failed for numbers and was silently swallowed,
        # so non-string cells never influenced the width.
        max_length = max(
            (
                len(str(cell.value))
                for cell in column_cells
                if cell.value is not None
            ),
            default=0,
        )
        column_letter = get_column_letter(column_cells[0].column)
        # A little extra space so the text does not touch the cell border.
        sheet.column_dimensions[column_letter].width = max_length + 2

    # Save the changes back to the Excel file.
    book.save(filename)
def _parse_receipt_details(answer, source_name):
    """Decode the model's JSON answer for one receipt into a dict."""
    try:
        details = json.loads(answer)
    except (TypeError, json.JSONDecodeError) as err:
        raise gr.Error(
            f"The model response for {source_name} is not valid JSON."
        ) from err
    if not isinstance(details, dict):
        raise gr.Error(
            f"The model response for {source_name} is not a JSON object."
        )
    return details


def _amount_in_inr(details, exchange_rate, source_name):
    """Convert the receipt's ``total`` to rupees with ``exchange_rate``."""
    try:
        return float(details.get("total")) * exchange_rate
    except (TypeError, ValueError) as err:
        raise gr.Error(
            f"No usable total amount was extracted from {source_name}."
        ) from err


def process_and_output_files(input_files):
    """Build the reimbursement table for the uploaded receipts.

    For each file the text is extracted, its currency and category are
    detected, and the category's chain is queried; the JSON answer supplies
    the summary, issue date, invoice ids and total. Totals are converted to
    rupees and a final "Total Amount" row is appended.

    Args:
        input_files: Uploaded file objects exposing a ``name`` path attribute.

    Returns:
        A pair ``(html, path)``: the table rendered as scrollable HTML and the
        path of the generated ``output.xlsx`` workbook.

    Raises:
        gr.Error: If nothing was uploaded, the exchange rate could not be
            fetched, or the model answer lacks valid JSON or a usable total.
    """
    import tempfile  # local import: only this function needs it

    if not input_files:
        raise gr.Error("Please upload a file to continue!")

    rows = []
    total_amount = 0
    for item_no, file in enumerate(input_files, start=1):
        source_name = os.path.basename(file.name)

        # Extract and categorize the text of this file.
        text = extract_text(file)
        currency = find_currency_symbol(text)
        category = categorize_text(text)

        # Resolve the rate before querying the model so that a failed lookup
        # does not cost a model call.
        exchange_rate = get_exchange_rate_to_inr(currency)
        if exchange_rate is None:
            raise gr.Error(
                f"Could not fetch the {currency} to INR exchange rate. "
                "Please try again later."
            )
        # NOTE(review): the rate itself is rounded to two decimals (as
        # before), which loses precision for weak currencies - confirm.
        exchange_rate = round(float(exchange_rate), 2)

        # query() streams chatbot updates; only the last one holds the answer.
        *_, final_update = query(category, text)
        response_dict = _parse_receipt_details(
            final_update["value"][0][1], source_name
        )
        amount = _amount_in_inr(response_dict, exchange_rate, source_name)

        # Cab receipts carry no invoice number.
        if category.name == "TRAVEL_CAB":
            invoice_no = "NA"
        else:
            invoice_no = response_dict.get("uids")

        rows.append(
            {
                "S.No.": item_no,
                "Nature of Expenditure": response_dict.get("summary"),
                "Billing Date": response_dict.get("issue_date"),
                "Bill/Invoice No.": invoice_no,
                "Amount(Rs.)": amount,
            }
        )
        total_amount += amount

    rows.append(
        {
            "S.No.": "",
            "Nature of Expenditure": "Total Amount",
            "Billing Date": "",
            "Bill/Invoice No.": "",
            "Amount(Rs.)": total_amount,
        }
    )

    # Everything is rendered as text, both in the table and in the workbook.
    string_rows = [
        {key: str(value) for key, value in row.items()} for row in rows
    ]
    df = pd.DataFrame(string_rows)

    # Each request gets its own directory: the queue serves several users at
    # once, and a shared "output.xlsx" would be overwritten by whoever
    # finishes last.
    filename = os.path.join(tempfile.mkdtemp(), "output.xlsx")
    save_df_to_excel_with_autowidth(df, filename)

    table_html = df.to_html(classes="table table-bordered", index=True)
    scrollable_table = f'<div style="overflow-x: auto;">{table_html}</div>'
    return scrollable_table, filename
# Page layout and event wiring for the demo.
# NOTE(review): the nesting of the layout contexts below was reconstructed
# from a source whose indentation had been stripped - confirm it matches the
# intended layout.
with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
    gr.Markdown("<center><h1>Automatic Reimbursement Tool Demo</h1></center>")
    gr.Markdown("<h2>Description</h2>")
    gr.Markdown(
        "The reimbursement filing process can be time-consuming and cumbersome, causing "
        "frustration for faculty members and finance departments. Our project aims to "
        "automate the information extraction involved in the process by feeding "
        "extracted text to language models such as ChatGPT. This demo showcases the "
        "categorization and extraction parts of the pipeline. Categorization is done "
        "to identify the relevant details associated with the text, after which "
        "extraction is done for those details using a language model."
    )
    gr.Markdown("<h2>Try it out!</h2>")
    with gr.Box() as demo:
        with gr.Row():
            # Left panel: upload, preview and submission options.
            with gr.Column(variant="panel"):
                gr.HTML(
                    '<div><center style="color:rgb(200, 200, 200);">Input</center></div>'
                )
                pdf_preview = gr.HTML(label="Preview", show_label=True, visible=False)
                image_preview = gr.Image(
                    label="Preview", show_label=True, visible=False, height=350
                )
                input_file = gr.File(
                    label="Input receipt",
                    show_label=True,
                    type="file",
                    file_count="multiple",
                    file_types=["image", ".pdf"],
                )
                input_file.change(
                    display_file, input_file, [pdf_preview, image_preview]
                )
                with gr.Row():
                    clear = gr.Button("Clear", variant="secondary")
                    submit_button = gr.Button("Submit", variant="primary")
                show_intermediate = gr.Checkbox(
                    False,
                    label="Show intermediate outputs",
                    info="There are several intermediate steps in the process such as "
                    "preprocessing, OCR, chatbot interaction. You can choose to "
                    "show their results here.",
                    visible=False,  # Shortcut for removal
                )
                share_result = gr.Checkbox(
                    True,
                    label="Share results",
                    info="Sharing your result with us will help us improve this tool.",
                    interactive=True,
                )
                contact = gr.Textbox(
                    type="email",
                    label="Contact",
                    interactive=True,
                    placeholder="Enter your email address",
                    info="Optionally, enter your email address to allow us to contact "
                    "you regarding your result.",
                    visible=True,
                )
                share_result.change(show_share_contact, share_result, [contact])
            # Right panel: recognized category, intermediate outputs, results.
            with gr.Column(variant="panel"):
                gr.HTML(
                    '<div><center style="color:rgb(200, 200, 200);">Output</center></div>'
                )
                category = gr.Dropdown(
                    value=None,
                    choices=Category.__members__.keys(),
                    label=f"Recognized category ({', '.join(Category.__members__.keys())})",
                    show_label=True,
                    interactive=False,
                )
                intermediate_outputs = gr.Accordion(
                    "Intermediate outputs", open=True, visible=False
                )
                with intermediate_outputs:
                    extracted_text = gr.Textbox(
                        label="Extracted text",
                        show_label=True,
                        max_lines=5,
                        show_copy_button=True,
                        lines=5,
                        interactive=False,
                    )
                    chatbot = gr.Chatbot(
                        None,
                        label="Chatbot interaction",
                        show_label=True,
                        interactive=False,
                        height=240,
                    )
                #information = gr.JSON(label="Extracted information")
                table_display = gr.HTML(label="Table Display")
                excel_download = gr.File(label="Download Excel", type="file")
                with gr.Row():
                    flag_incorrect_button = gr.Button(
                        "Flag as incorrect", variant="stop", interactive=True
                    )
                    flag_irrelevant_button = gr.Button(
                        "Flag as irrelevant", variant="stop", interactive=True
                    )

    show_intermediate.change(
        show_intermediate_outputs, show_intermediate, [intermediate_outputs]
    )

    # Clearing the input also disables the flag buttons; they are re-enabled
    # at the end of the submit chain below.
    clear.click(clear_inputs, None, [input_file]).then(
        deactivate_flags,
        None,
        [flag_incorrect_button, flag_irrelevant_button],
    )

    hf_writer_normal.setup(
        [input_file, extracted_text, category, chatbot, table_display, contact],
        flagging_dir="flagged",
    )
    flag_method = gr.flagging.FlagMethod(
        hf_writer_normal, "", "", visual_feedback=False
    )

    # Submit: reset the outputs, process the receipts, share the result if the
    # user agreed, then make the flag buttons usable again.
    submit_button.click(
        clear_outputs,
        [input_file],
        [extracted_text, category, chatbot, table_display],
    ).then(
        process_and_output_files,
        [input_file],
        [table_display, excel_download],  # Adding excel_download here
    ).then(
        flag_if_shared(flag_method),
        [
            share_result,
            input_file,
            extracted_text,
            category,
            chatbot,
            table_display,
            contact,
        ],
        None,
        preprocess=False,
    ).then(
        # Without this step the buttons stayed disabled forever once "Clear"
        # had been pressed, because nothing ever called activate_flags.
        activate_flags,
        None,
        [flag_incorrect_button, flag_irrelevant_button],
    )

    hf_writer_incorrect.setup(
        [input_file, extracted_text, category, chatbot, table_display, contact],
        flagging_dir="flagged_incorrect",
    )

    # "Flag as incorrect": show a saving state, then store the flagged sample.
    flag_incorrect_method = gr.flagging.FlagMethod(
        hf_writer_incorrect,
        "Flag as incorrect",
        "Incorrect",
        visual_feedback=True,
    )
    flag_incorrect_button.click(
        lambda: gr.Button.update(value="Saving...", interactive=False),
        None,
        flag_incorrect_button,
        queue=False,
    )
    flag_incorrect_button.click(
        flag_incorrect_method,
        inputs=[
            input_file,
            extracted_text,
            category,
            chatbot,
            table_display,
            contact,
        ],
        outputs=[flag_incorrect_button],
        preprocess=False,
        queue=False,
    )

    # "Flag as irrelevant": same flow, stored with the "Irrelevant" option.
    flag_irrelevant_method = gr.flagging.FlagMethod(
        hf_writer_incorrect,
        "Flag as irrelevant",
        "Irrelevant",
        visual_feedback=True,
    )
    flag_irrelevant_button.click(
        lambda: gr.Button.update(value="Saving...", interactive=False),
        None,
        flag_irrelevant_button,
        queue=False,
    )
    flag_irrelevant_button.click(
        flag_irrelevant_method,
        inputs=[
            input_file,
            extracted_text,
            category,
            chatbot,
            table_display,
            contact,
        ],
        outputs=[flag_irrelevant_button],
        preprocess=False,
        queue=False,
    )

page.queue(
    concurrency_count=20,
    max_size=1,
)
page.launch(show_api=True, show_error=True, debug=True)