# rosemariafontana's picture
# not adding as date if leftover
# afbb6c1 verified
import gradio as gr
import pandas as pd
import re
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering, LayoutLMv3ForTokenClassification
# Load the LayoutLMv3 processor (image + OCR tokenization) and a
# token-classification head; this is the traditional per-token labeling
# approach rather than extractive question answering.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# BUG FIX: torch.cuda.get_device_name() raises RuntimeError on CPU-only
# machines, so only query the GPU name when CUDA is actually available.
print_device_name = (
    torch.cuda.get_device_name(torch.cuda.current_device())
    if torch.cuda.is_available()
    else "CPU"
)
print(f"Debug -- Using device: {device}")
print(f"Debug -- Current Device Name: {print_device_name}")
model.to(device)

# id -> label-name mapping provided by the checkpoint config.
labels = model.config.id2label
print(labels)
# Homemade feature extraction
def extract_features(tokens, labels):
    """Merge consecutive date fragments into full dates and tag tokens.

    OCR/tokenization often splits a date like ``02/14/2024`` into pieces
    such as ``['02/', '14', '/', '2024']``; this re-assembles them.

    Args:
        tokens: decoded token strings in document order.
        labels: per-token model labels. Currently unused -- the merge is
            label-agnostic -- but kept for interface stability.

    Returns:
        list[tuple[str, str]]: ``(text, 'date' | 'non-date')`` pairs.
    """
    merged_entities = []
    current_date = ""  # accumulator for an in-progress date
    print(f"Debug -- Starting entity extraction")

    # Full date: MM/DD/YY or MM/DD/YYYY (strict two-digit month/day).
    date_pattern = re.compile(r"^\d{2}/\d{2}/\d{2}(\d{2})?$")
    # Fragment that could belong to a date: "12", "02/", or a bare "/".
    partial_date_pattern = re.compile(r"^\d{1,2}/?$|^/$")

    # This is a label AGNOSTIC approach
    for token, label in zip(tokens, labels):
        print(f"Debug -- Processing token: {token}")
        if current_date and partial_date_pattern.match(token):
            # We already hold part of a date and this token could extend it.
            current_date += token
            print(f"Debug -- Potential partial date: {current_date}")
        elif date_pattern.match(current_date + token):
            # Appending this token completes a full date.
            current_date += token
            merged_entities.append((current_date, 'date'))
            print(f"Debug -- Complete date added: {current_date}")
            current_date = ""  # Reset for next entity
        elif partial_date_pattern.match(token):
            # Token could start a new date (e.g. '14' could be a day).
            current_date = token
            print(f"Debug -- Potentially starting a new date: {token}")
        else:
            # BUG FIX: previously a stale partial date was neither flushed
            # nor reset here, so it survived past non-date tokens and
            # merged with later fragments (e.g. ['12', 'foo', '14'] became
            # the bogus date '1214').  Flush the incomplete fragment as
            # non-date -- deliberately NOT as a date, per the commit note
            # "not adding as date if leftover" -- and reset.
            if current_date:
                print(f"Debug -- Flushing incomplete fragment: {current_date}")
                merged_entities.append((current_date, 'non-date'))
                current_date = ""
            print(f"Debug -- Appending non-date Token: {token}")
            merged_entities.append((token, 'non-date'))

    # A date still being accumulated at end-of-stream is kept as a date.
    if current_date:
        print(f"Debug -- Dangling leftover date added: {current_date}")
        merged_entities.append((current_date, 'date'))
    return merged_entities
# NOTE: labels aren't being applied properly ... This is the LABEL approach
#
# Loop through tokens and labels
#for token, label in zip(tokens, labels):
# print(f"Debug -- Potentially creating date,, token: {token} label: {label}")
#
# if label == 'LABEL_1':
# # Check for partial date fragments (like '12' or '/')
# if re.match(date_pattern, current_date):
# merged_entities.append((current_date, 'date'))
# print(f"Debug -- Complete date added: {token}")
# current_date = "" # Reset for next entity
# # If the accumulated entity matches a full date
# elif re.match(partial_date_pattern, token):
# print(f"Debug -- Potentially building date: Token Start {token} After Token")
# current_date += token # Append token to the current entity
# else:
# # No partial or completed patterns are detected, but it's still LABEL_1
# # If there were any accumulated data so far
# if current_date:
# merged_entities.append((current_date, 'date'))
# print(f"Debug -- Date finalized: {current_date}")
# current_date = "" # Reset
#
# merged_entities.append((token, label))
# else:
# # These are LABEL_0, supposedly trash but keep them for now
# if current_date: # If there's a leftover date fragment, add it first
# merged_entities.append((current_date, 'date'))
# print(f"Debug -- Finalizing leftover date added: {current_date}")
# current_date = "" # Reset
#
# # Append LABEL_0 token
# print(f"Debug -- Appending LABEL_0 Token: Token Start {token} Token After")
# merged_entities.append((token, label))
#
# if current_date:
# print(f"Debug -- Dangling leftover date added: {current_date}")
# merged_entities.append((current_date, 'date'))
#
# return merged_entities
# process the image in the correct format
# extract token classifications
def parse_ticket_image(image):
    """Run LayoutLMv3 token classification on a scale-ticket image.

    Args:
        image: PIL.Image of the ticket, or None (Gradio passes None when
            no file was uploaded).

    Returns:
        pandas.DataFrame with columns ["Field", "Value"]; empty when the
        image is missing or nothing was extracted.
    """
    if image is None:
        print(f"Warning - no image or malformed image!")
        # BUG FIX: return the same schema as the success path so the
        # gr.Dataframe component always receives the expected columns.
        return pd.DataFrame(columns=["Field", "Value"])
    document = image.convert("RGB") if image.mode != "RGB" else image

    # Encode document image and move every tensor onto the model's device.
    encoding = processor(document, return_tensors="pt", truncation=True)
    encoding = {k: v.to(device) for k, v in encoding.items()}

    # BUG FIX: inference does not need autograd; no_grad() avoids building
    # the computation graph and keeps memory flat across requests.
    with torch.no_grad():
        outputs = model(**encoding)

    # Extract per-token predictions.
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    input_ids = encoding["input_ids"].squeeze().tolist()
    # BUG FIX: squeeze() on a single-token document yields scalars, not
    # lists, which would break the iteration below.
    if not isinstance(predictions, list):
        predictions = [predictions]
        input_ids = [input_ids]

    # Decode each id individually (avoid shadowing the builtin `id`).
    words = [processor.tokenizer.decode(tok_id) for tok_id in input_ids]
    extracted_fields = [
        (model.config.id2label[pred], word)
        for pred, word in zip(predictions, words)
    ]

    if not extracted_fields:
        print(f"Warning - no fields were extracted!")
        return pd.DataFrame(columns=["Field", "Value"])

    fields = [field for field, _ in extracted_fields]
    values = [word for _, word in extracted_fields]

    # Homemade feature extraction: merges split date fragments, so the
    # value list may shrink relative to the label list.
    values = extract_features(values, fields)

    # Keep the two columns aligned after the merge.
    min_length = min(len(fields), len(values))
    return pd.DataFrame({
        "Field": fields[:min_length],
        "Value": values[:min_length],
    })
# This is how to use questions to find answers in the document
# Less traditional approach, less flexibility, easier to implement/understand (didnt provide robust answers)
#model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
#def process_question(question, document):
# #print(f"Debug - Processing Question: {question}")
#
# encoding = processor(document, question, return_tensors="pt")
# #print(f"Debug - Encoding Input IDs: {encoding.input_ids}")
#
# outputs = model(**encoding)
# #print(f"Debug - Model Outputs: {outputs}")
#
# predicted_start_idx = outputs.start_logits.argmax(-1).item()
# predicted_end_idx = outputs.end_logits.argmax(-1).item()
#
# # Check if indices are valid
# if predicted_start_idx < 0 or predicted_end_idx < 0:
# print(f"Warning - Invalid prediction indices: start={predicted_start_idx}, end={predicted_end_idx}")
# return ""
#
# answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx: predicted_end_idx + 1]
# answer = processor.tokenizer.decode(answer_tokens)
#
# return answer
# Older iteration of the code, retaining for emergencies ?
#def process_question(question, document):
# if not question or document is None:
# return None, None, None
#
# text_value = None
# predictions = run_pipeline(question, document)
#
# for i, p in enumerate(ensure_list(predictions)):
# if i == 0:
# text_value = p["answer"]
# else:
# # Keep the code around to produce multiple boxes, but only show the top
# # prediction for now
# break
#
# return text_value
#def parse_ticket_image(image, question):
# """Basically just runs through these questions for the document"""
# # Processing the image
# if image:
# try:
# if image.mode != "RGB":
# document = image.convert("RGB")
# else:
# document = image
# except Exception as e:
# traceback.print_exc()
# error = str(e)
#
#
# # Define questions you want to ask the model
#
# questions = [
# "What is the ticket number?",
# "What is the type of grain (For example: corn, soybeans, wheat)?",
# "What is the date?",
# "What is the time?",
# "What is the gross weight?",
# "What is the tare weight?",
# "What is the net weight?",
# "What is the moisture (moist) percentage?",
# "What is the damage percentage?",
# "What is the gross units?",
# "What is the dock units?",
# "What is the comment?",
# "What is the assembly number?",
# ]
#
# # Use the model to answer each question
# answers = {}
# for q in questions:
# print(f"Question: {q}")
# answer_text = process_question(q, document)
# print(f"Answer Text extracted here: {answer_text}")
# answers[q] = answer_text
#
#
# ticket_number = answers["What is the ticket number?"]
# grain_type = answers["What is the type of grain (For example: corn, soybeans, wheat)?"]
# date = answers["What is the date?"]
# time = answers["What is the time?"]
# gross_weight = answers["What is the gross weight?"]
# tare_weight = answers["What is the tare weight?"]
# net_weight = answers["What is the net weight?"]
# moisture = answers["What is the moisture (moist) percentage?"]
# damage = answers["What is the damage percentage?"]
# gross_units = answers["What is the gross units?"]
# dock_units = answers["What is the dock units?"]
# comment = answers["What is the comment?"]
# assembly_number = answers["What is the assembly number?"]
#
#
# # Create a structured format (like a table) using pandas
# data = {
# "Ticket Number": [ticket_number],
# "Grain Type": [grain_type],
# "Assembly Number": [assembly_number],
# "Date": [date],
# "Time": [time],
# "Gross Weight": [gross_weight],
# "Tare Weight": [tare_weight],
# "Net Weight": [net_weight],
# "Moisture": [moisture],
# "Damage": [damage],
# "Gross Units": [gross_units],
# "Dock Units": [dock_units],
# "Comment": [comment],
# }
# df = pd.DataFrame(data)
#
# return df
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Wire the parser into a simple Gradio UI: one image upload in, one
# table of extracted (Field, Value) rows out.
demo = gr.Interface(
    fn=parse_ticket_image,
    inputs=[gr.Image(label="Upload your Grain Scale Ticket", type="pil")],
    outputs=[
        gr.Dataframe(
            headers=["Field", "Value"],
            label="Extracted Grain Scale Ticket Data",
        )
    ],
)

if __name__ == "__main__":
    demo.launch()