Spaces:
Runtime error
Runtime error
import gradio as gr
import pandas as pd
import re
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering, LayoutLMv3ForTokenClassification

# Pretrained processor/model pair for document-image token classification.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
# More traditional approach that works from token classification basis (not questions)
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# BUG FIX: torch.cuda.get_device_name() raises RuntimeError on CPU-only hosts
# (e.g. a free Spaces runtime); only query the GPU name when CUDA is present.
print_device_name = (
    torch.cuda.get_device_name(torch.cuda.current_device())
    if torch.cuda.is_available()
    else "CPU"
)
print(f"Debug -- Using device: {device}")
print(f"Debug -- Current Device Name: {print_device_name}")
model.to(device)

# id2label mapping of the (un-fine-tuned) base model, printed for debugging.
labels = model.config.id2label
print(labels)
| # Homemade feature extraction | |
| def extract_features(tokens, labels): | |
| merged_entities = [] | |
| current_date = "" | |
| print(f"Debug -- Starting entity extraction") | |
| #date_pattern = r"\d{1,2}/\d{1,2}/\d{2,4}" # Matches full date formats like MM/DD/YYYY or DD/MM/YYYY | |
| #partial_date_pattern = r"\d{1,2}$|[/-]$" # Matches partial date components like "12" or "/" at the end | |
| #date_pattern = r"\d{1,2}/\d{1,2}/\d{2,4}" # Matches full date formats like MM/DD/YYYY or DD/MM/YYYY | |
| #partial_date_pattern = r"^\d{1,2}/?$|^[/-]$" # Matches partial date components like "12", "/", "02/", etc. | |
| date_pattern = r"^\d{2}/\d{2}/\d{2}(\d{2})?$" | |
| partial_date_pattern = r"^\d{1,2}/?$|^/$" | |
| # This is a label AGNOSTIC approach | |
| for token, label in zip(tokens, labels): | |
| print(f"Debug -- Processing token: {token}") | |
| # If we already have some part of a date and the next token could still be part of it, continue accumulating | |
| if current_date and re.match(partial_date_pattern, token): | |
| current_date += token | |
| print(f"Debug -- Potential partial date: {current_date}") | |
| # If the accumulated entity matches a complete date after appending this token | |
| elif re.match(date_pattern, current_date + token): | |
| current_date += token | |
| merged_entities.append((current_date, 'date')) | |
| print(f"Debug -- Complete date added: {current_date}") | |
| current_date = "" # Reset for next entity | |
| # If the token could start a new date (e.g., '14' could be a day or hour) | |
| elif re.match(partial_date_pattern, token): | |
| current_date = token | |
| print(f"Debug -- Potentially starting a new date: {token}") | |
| else: | |
| # If no patterns are detected and there is any accumulated data | |
| #if current_date: | |
| # # Finalize accumulated partial date | |
| # print(f"Debug -- Date finalized: {current_date}") | |
| # merged_entities.append((current_date, 'date')) | |
| # current_date = "" # Reset for next entity | |
| # Append token as non-date | |
| print(f"Debug -- Appending non-date Token: {token}") | |
| merged_entities.append((token, 'non-date')) | |
| # If there's any leftover accumulated date data, add it to merged_entities | |
| if current_date: | |
| print(f"Debug -- Dangling leftover date added: {current_date}") | |
| merged_entities.append((current_date, 'date')) | |
| return merged_entities | |
| # NOTE: labels aren't being applied properly ... This is the LABEL approach | |
| # | |
| # Loop through tokens and labels | |
| #for token, label in zip(tokens, labels): | |
| # print(f"Debug -- Potentially creating date,, token: {token} label: {label}") | |
| # | |
| # if label == 'LABEL_1': | |
| # # Check for partial date fragments (like '12' or '/') | |
| # if re.match(date_pattern, current_date): | |
| # merged_entities.append((current_date, 'date')) | |
| # print(f"Debug -- Complete date added: {token}") | |
| # current_date = "" # Reset for next entity | |
| # # If the accumulated entity matches a full date | |
| # elif re.match(partial_date_pattern, token): | |
| # print(f"Debug -- Potentially building date: Token Start {token} After Token") | |
| # current_date += token # Append token to the current entity | |
| # else: | |
| # # No partial or completed patterns are detected, but it's still LABEL_1 | |
| # # If there were any accumulated data so far | |
| # if current_date: | |
| # merged_entities.append((current_date, 'date')) | |
| # print(f"Debug -- Date finalized: {current_date}") | |
| # current_date = "" # Reset | |
| # | |
| # merged_entities.append((token, label)) | |
| # else: | |
| # # These are LABEL_0, supposedly trash but keep them for now | |
| # if current_date: # If there's a leftover date fragment, add it first | |
| # merged_entities.append((current_date, 'date')) | |
| # print(f"Debug -- Finalizing leftover date added: {current_date}") | |
| # current_date = "" # Reset | |
| # | |
| # # Append LABEL_0 token | |
| # print(f"Debug -- Appending LABEL_0 Token: Token Start {token} Token After") | |
| # merged_entities.append((token, label)) | |
| # | |
| # if current_date: | |
| # print(f"Debug -- Dangling leftover date added: {current_date}") | |
| # merged_entities.append((current_date, 'date')) | |
| # | |
| # return merged_entities | |
| # process the image in the correct format | |
| # extract token classifications | |
# process the image in the correct format
# extract token classifications
def parse_ticket_image(image):
    """Run LayoutLMv3 token classification over a grain-scale-ticket image.

    Args:
        image: PIL.Image.Image (any mode; converted to RGB), or None.

    Returns:
        pandas.DataFrame with columns "Field" (predicted label) and
        "Value" (token text, with split date fragments merged by
        extract_features). Empty DataFrame when input is missing.
    """
    # Process image
    if image:
        document = image.convert("RGB") if image.mode != "RGB" else image
    else:
        print(f"Warning - no image or malformed image!")
        return pd.DataFrame()

    # Encode document image (processor runs its built-in OCR over the image)
    encoding = processor(document, return_tensors="pt", truncation=True)

    # Move encoding to appropriate device
    for k, v in encoding.items():
        encoding[k] = v.to(device)

    # Perform inference. no_grad() avoids building an autograd graph that
    # is never used, cutting memory use during inference.
    with torch.no_grad():
        outputs = model(**encoding)

    # extract predictions (per-token argmax over the label logits)
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    input_ids = encoding.input_ids.squeeze().tolist()
    # `token_id` instead of `id`: avoid shadowing the builtin.
    words = [processor.tokenizer.decode(token_id) for token_id in input_ids]

    extracted_fields = []
    for idx, pred in enumerate(predictions):
        label = model.config.id2label[pred]
        extracted_fields.append((label, words[idx]))
        # apparently stands for non-entity tokens
        #if label != 'LABEL_0' and '<' not in words[idx]:
        #    extracted_fields.append((label, words[idx]))

    if len(extracted_fields) == 0:
        print(f"Warning - no fields were extracted!")
        return pd.DataFrame(columns=["Field", "Value"])

    # Create lists for fields and values
    fields = [field[0] for field in extracted_fields]
    values = [field[1] for field in extracted_fields]

    # Ensure both lists have the same length
    min_length = min(len(fields), len(values))
    fields = fields[:min_length]
    values = values[:min_length]

    # Homemade feature extraction (merges split date fragments into one entry)
    values = extract_features(values, fields)

    # extract_features can shorten `values`; re-align the two columns so the
    # DataFrame constructor doesn't raise on mismatched lengths.
    min_length = min(len(fields), len(values))
    fields = fields[:min_length]
    values = values[:min_length]

    data = {
        "Field": fields,
        "Value": values,
    }
    df = pd.DataFrame(data)
    return df
| # This is how to use questions to find answers in the document | |
| # Less traditional approach, less flexibility, easier to implement/understand (didnt provide robust answers) | |
| #model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base") | |
| #def process_question(question, document): | |
| # #print(f"Debug - Processing Question: {question}") | |
| # | |
| # encoding = processor(document, question, return_tensors="pt") | |
| # #print(f"Debug - Encoding Input IDs: {encoding.input_ids}") | |
| # | |
| # outputs = model(**encoding) | |
| # #print(f"Debug - Model Outputs: {outputs}") | |
| # | |
| # predicted_start_idx = outputs.start_logits.argmax(-1).item() | |
| # predicted_end_idx = outputs.end_logits.argmax(-1).item() | |
| # | |
| # # Check if indices are valid | |
| # if predicted_start_idx < 0 or predicted_end_idx < 0: | |
| # print(f"Warning - Invalid prediction indices: start={predicted_start_idx}, end={predicted_end_idx}") | |
| # return "" | |
| # | |
| # answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx: predicted_end_idx + 1] | |
| # answer = processor.tokenizer.decode(answer_tokens) | |
| # | |
| # return answer | |
| # Older iteration of the code, retaining for emergencies ? | |
| #def process_question(question, document): | |
| # if not question or document is None: | |
| # return None, None, None | |
| # | |
| # text_value = None | |
| # predictions = run_pipeline(question, document) | |
| # | |
| # for i, p in enumerate(ensure_list(predictions)): | |
| # if i == 0: | |
| # text_value = p["answer"] | |
| # else: | |
| # # Keep the code around to produce multiple boxes, but only show the top | |
| # # prediction for now | |
| # break | |
| # | |
| # return text_value | |
| #def parse_ticket_image(image, question): | |
| # """Basically just runs through these questions for the document""" | |
| # # Processing the image | |
| # if image: | |
| # try: | |
| # if image.mode != "RGB": | |
| # document = image.convert("RGB") | |
| # else: | |
| # document = image | |
| # except Exception as e: | |
| # traceback.print_exc() | |
| # error = str(e) | |
| # | |
| # | |
| # # Define questions you want to ask the model | |
| # | |
| # questions = [ | |
| # "What is the ticket number?", | |
| # "What is the type of grain (For example: corn, soybeans, wheat)?", | |
| # "What is the date?", | |
| # "What is the time?", | |
| # "What is the gross weight?", | |
| # "What is the tare weight?", | |
| # "What is the net weight?", | |
| # "What is the moisture (moist) percentage?", | |
| # "What is the damage percentage?", | |
| # "What is the gross units?", | |
| # "What is the dock units?", | |
| # "What is the comment?", | |
| # "What is the assembly number?", | |
| # ] | |
| # | |
| # # Use the model to answer each question | |
| # answers = {} | |
| # for q in questions: | |
| # print(f"Question: {q}") | |
| # answer_text = process_question(q, document) | |
| # print(f"Answer Text extracted here: {answer_text}") | |
| # answers[q] = answer_text | |
| # | |
| # | |
| # ticket_number = answers["What is the ticket number?"] | |
| # grain_type = answers["What is the type of grain (For example: corn, soybeans, wheat)?"] | |
| # date = answers["What is the date?"] | |
| # time = answers["What is the time?"] | |
| # gross_weight = answers["What is the gross weight?"] | |
| # tare_weight = answers["What is the tare weight?"] | |
| # net_weight = answers["What is the net weight?"] | |
| # moisture = answers["What is the moisture (moist) percentage?"] | |
| # damage = answers["What is the damage percentage?"] | |
| # gross_units = answers["What is the gross units?"] | |
| # dock_units = answers["What is the dock units?"] | |
| # comment = answers["What is the comment?"] | |
| # assembly_number = answers["What is the assembly number?"] | |
| # | |
| # | |
| # # Create a structured format (like a table) using pandas | |
| # data = { | |
| # "Ticket Number": [ticket_number], | |
| # "Grain Type": [grain_type], | |
| # "Assembly Number": [assembly_number], | |
| # "Date": [date], | |
| # "Time": [time], | |
| # "Gross Weight": [gross_weight], | |
| # "Tare Weight": [tare_weight], | |
| # "Net Weight": [net_weight], | |
| # "Moisture": [moisture], | |
| # "Damage": [damage], | |
| # "Gross Units": [gross_units], | |
| # "Dock Units": [dock_units], | |
| # "Comment": [comment], | |
| # } | |
| # df = pd.DataFrame(data) | |
| # | |
| # return df | |
| """ | |
| For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface | |
| """ | |
# Gradio UI: one uploaded ticket image in, a two-column extracted-fields table out.
demo = gr.Interface(
    fn=parse_ticket_image,
    # type="pil" so parse_ticket_image receives a PIL image and can use .mode/.convert()
    inputs=[gr.Image(label= "Upload your Grain Scale Ticket", type="pil")],
    outputs=[gr.Dataframe(headers=["Field", "Value"], label="Extracted Grain Scale Ticket Data")],
)

if __name__ == "__main__":
    demo.launch()