import gradio as gr
import pandas as pd
import re

from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import (
    LayoutLMv3Processor,
    LayoutLMv3ForQuestionAnswering,
    LayoutLMv3ForTokenClassification,
)

processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")

# More traditional approach that works from a token-classification basis
# (rather than the question-answering approach tried earlier).
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Debug -- Using device: {device}")
# BUG FIX: torch.cuda.get_device_name() raises RuntimeError on CPU-only
# machines, so only query the device name when CUDA is actually available.
if torch.cuda.is_available():
    print(f"Debug -- Current Device Name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
model.to(device)

labels = model.config.id2label
print(labels)


# Homemade feature extraction
def extract_features(tokens, labels):
    """Merge consecutive date-fragment tokens into complete date entities.

    This is a label-AGNOSTIC pass: the model's labels are ignored and dates
    are detected purely by regex over the raw token stream.

    Args:
        tokens: decoded token strings, in document order.
        labels: per-token label strings (currently unused by this approach,
            kept for signature compatibility with the label-driven variant).

    Returns:
        List of ``(text, kind)`` tuples where ``kind`` is ``'date'`` or
        ``'non-date'``.
    """
    merged_entities = []
    current_date = ""
    print(f"Debug -- Starting entity extraction")

    # Anchored patterns: a complete date like 01/02/24 or 01/02/2024, and
    # fragments ("12", "12/", "/") that may still grow into a full date.
    date_pattern = r"^\d{2}/\d{2}/\d{2}(\d{2})?$"
    partial_date_pattern = r"^\d{1,2}/?$|^/$"

    for token, label in zip(tokens, labels):
        print(f"Debug -- Processing token: {token}")
        # If we already hold part of a date and this token could still be
        # part of it, keep accumulating.
        if current_date and re.match(partial_date_pattern, token):
            current_date += token
            print(f"Debug -- Potential partial date: {current_date}")
        # If appending this token completes a full date, emit the entity.
        elif re.match(date_pattern, current_date + token):
            current_date += token
            merged_entities.append((current_date, 'date'))
            print(f"Debug -- Complete date added: {current_date}")
            current_date = ""  # Reset for the next entity
        # Token could start a new date (e.g. '14' could be a day or hour).
        elif re.match(partial_date_pattern, token):
            current_date = token
            print(f"Debug -- Potentially starting a new date: {token}")
        else:
            # Plain token. Note: any partially accumulated date is
            # deliberately NOT flushed here (the original author disabled
            # that), so an interrupted fragment keeps accumulating.
            print(f"Debug -- Appending non-date Token: {token}")
            merged_entities.append((token, 'non-date'))

    # Flush any trailing, still-incomplete date fragment.
    if current_date:
        print(f"Debug -- Dangling leftover date added: {current_date}")
        merged_entities.append((current_date, 'date'))

    return merged_entities
# NOTE: An earlier, label-driven merge pass (grouping tokens by
# LABEL_1/LABEL_0 before regex-matching dates) lived here. Labels were never
# being applied properly, so it was removed in favour of the label-agnostic
# extract_features() above; see version control history if it is ever needed.


def parse_ticket_image(image):
    """Run LayoutLMv3 token classification over a grain scale-ticket image.

    Args:
        image: PIL image of the ticket, or None/falsy when nothing was
            uploaded.

    Returns:
        pandas DataFrame with "Field" (model label) and "Value" columns;
        empty when no image is supplied or nothing is extracted.
    """
    # Guard clause: bail out early on a missing upload.
    if not image:
        print(f"Warning - no image or malformed image!")
        return pd.DataFrame()

    document = image.convert("RGB") if image.mode != "RGB" else image

    # Encode the document image and move every tensor onto the model's device.
    encoding = processor(document, return_tensors="pt", truncation=True)
    for k, v in encoding.items():
        encoding[k] = v.to(device)

    # BUG FIX: inference previously ran with autograd enabled, building a
    # computation graph that is never backpropagated; no_grad() avoids that.
    with torch.no_grad():
        outputs = model(**encoding)

    # Per-token predicted class ids, and the decoded token strings.
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    input_ids = encoding.input_ids.squeeze().tolist()
    words = [processor.tokenizer.decode(id) for id in input_ids]

    # Pair every token with its predicted label name. (An earlier variant
    # filtered out LABEL_0 / special tokens; currently everything is kept.)
    extracted_fields = [
        (model.config.id2label[pred], words[idx])
        for idx, pred in enumerate(predictions)
    ]

    if len(extracted_fields) == 0:
        print(f"Warning - no fields were extracted!")
        return pd.DataFrame(columns=["Field", "Value"])

    fields = [field[0] for field in extracted_fields]
    values = [field[1] for field in extracted_fields]

    # Ensure both lists have the same length.
    min_length = min(len(fields), len(values))
    fields = fields[:min_length]
    values = values[:min_length]

    # Homemade feature extraction: merge date fragments.
    # NOTE(review): extract_features returns (text, kind) tuples, so "Value"
    # cells become tuples, and after merging the row count may no longer line
    # up with the "Field" labels — confirm whether that is intended.
    values = extract_features(values, fields)

    # Re-truncate since the merge pass can change the number of values.
    min_length = min(len(fields), len(values))
    fields = fields[:min_length]
    values = values[:min_length]

    return pd.DataFrame({"Field": fields, "Value": values})
# NOTE: Two earlier iterations were removed from here:
#   * a LayoutLMv3ForQuestionAnswering pipeline (process_question) that asked
#     a fixed list of questions per ticket — it did not provide robust
#     answers;
#   * an older question/answer-based parse_ticket_image built on top of it.
# See version control history if either is ever needed again.

# Gradio UI: upload a ticket image, get back the extracted field/value table.
# For customisation options see https://www.gradio.app/docs/interface
demo = gr.Interface(
    fn=parse_ticket_image,
    inputs=[gr.Image(label="Upload your Grain Scale Ticket", type="pil")],
    outputs=[
        gr.Dataframe(
            headers=["Field", "Value"],
            label="Extracted Grain Scale Ticket Data",
        )
    ],
)

if __name__ == "__main__":
    demo.launch()