# Grain scale ticket parser -- Hugging Face Space (Gradio + LayoutLMv3).
import gradio as gr
import pandas as pd
import re
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import LayoutLMv3Processor, LayoutLMv3ForQuestionAnswering, LayoutLMv3ForTokenClassification
# Load the LayoutLMv3 processor (handles OCR + tokenization of the image).
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
# More traditional approach that works from token classification basis (not questions)
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Only query CUDA for a device name when CUDA is actually available:
# torch.cuda.get_device_name() raises on CPU-only hosts.
if torch.cuda.is_available():
    print_device_name = torch.cuda.get_device_name(torch.cuda.current_device())
else:
    print_device_name = "cpu"
print(f"Debug -- Using device: {device}")
print(f"Debug -- Current Device Name: {print_device_name}")
model.to(device)
# id -> label mapping of the (base, un-finetuned) classification head.
labels = model.config.id2label
print(labels)
# Homemade feature extraction
def extract_features(tokens, labels):
    """Merge consecutive date-fragment tokens into whole dates.

    Label-agnostic: the model labels are ignored; dates are detected purely
    by regex on the token text (the base model's labels are not trained for
    this document type).

    Args:
        tokens: decoded token strings, in document order.
        labels: per-token labels, same length as ``tokens``; currently unused.

    Returns:
        List of ``(text, kind)`` tuples where kind is ``'date'`` or
        ``'non-date'``. Tokens consumed into a date accumulator are emitted
        once as the merged date, not individually.
    """
    print("Debug -- Starting entity extraction")
    # Full date such as 02/14/24 or 02/14/2024.
    date_re = re.compile(r"^\d{2}/\d{2}/\d{2}(\d{2})?$")
    # Fragment that could extend a date: "12", "02/", or a lone "/".
    partial_re = re.compile(r"^\d{1,2}/?$|^/$")

    merged_entities = []
    current_date = ""  # accumulator for the date currently being assembled
    # This is a label AGNOSTIC approach
    for token, _label in zip(tokens, labels):
        print(f"Debug -- Processing token: {token}")
        if current_date and partial_re.match(token):
            # We already hold part of a date and this token could extend it.
            current_date += token
            print(f"Debug -- Potential partial date: {current_date}")
        elif date_re.match(current_date + token):
            # Appending this token yields a complete date.
            current_date += token
            merged_entities.append((current_date, 'date'))
            print(f"Debug -- Complete date added: {current_date}")
            current_date = ""  # Reset for next entity
        elif partial_re.match(token):
            # Token could start a new date (e.g. '14' could be a day or hour).
            current_date = token
            print(f"Debug -- Potentially starting a new date: {token}")
        else:
            # Not date-like: emit as-is. NOTE(review): a partially accumulated
            # date is deliberately kept, so later tokens may still extend it.
            print(f"Debug -- Appending non-date Token: {token}")
            merged_entities.append((token, 'non-date'))
    # Flush a date that was still being assembled at end of input.
    if current_date:
        print(f"Debug -- Dangling leftover date added: {current_date}")
        merged_entities.append((current_date, 'date'))
    return merged_entities
# NOTE: labels aren't being applied properly ... This is the LABEL approach
#
# Loop through tokens and labels
#for token, label in zip(tokens, labels):
# print(f"Debug -- Potentially creating date,, token: {token} label: {label}")
#
# if label == 'LABEL_1':
# # Check for partial date fragments (like '12' or '/')
# if re.match(date_pattern, current_date):
# merged_entities.append((current_date, 'date'))
# print(f"Debug -- Complete date added: {token}")
# current_date = "" # Reset for next entity
# # If the accumulated entity matches a full date
# elif re.match(partial_date_pattern, token):
# print(f"Debug -- Potentially building date: Token Start {token} After Token")
# current_date += token # Append token to the current entity
# else:
# # No partial or completed patterns are detected, but it's still LABEL_1
# # If there were any accumulated data so far
# if current_date:
# merged_entities.append((current_date, 'date'))
# print(f"Debug -- Date finalized: {current_date}")
# current_date = "" # Reset
#
# merged_entities.append((token, label))
# else:
# # These are LABEL_0, supposedly trash but keep them for now
# if current_date: # If there's a leftover date fragment, add it first
# merged_entities.append((current_date, 'date'))
# print(f"Debug -- Finalizing leftover date added: {current_date}")
# current_date = "" # Reset
#
# # Append LABEL_0 token
# print(f"Debug -- Appending LABEL_0 Token: Token Start {token} Token After")
# merged_entities.append((token, label))
#
# if current_date:
# print(f"Debug -- Dangling leftover date added: {current_date}")
# merged_entities.append((current_date, 'date'))
#
# return merged_entities
# process the image in the correct format
# extract token classifications
def parse_ticket_image(image):
    """Run LayoutLMv3 token classification over a scale-ticket image.

    Args:
        image: PIL image of the ticket, or None.

    Returns:
        pandas.DataFrame with "Field" (model label) and "Value" columns;
        an empty DataFrame when no image is supplied.
    """
    if image is None:
        print("Warning - no image or malformed image!")
        return pd.DataFrame()
    # LayoutLMv3 expects RGB input.
    document = image.convert("RGB") if image.mode != "RGB" else image

    # Encode document image (the processor also runs OCR).
    encoding = processor(document, return_tensors="pt", truncation=True)
    # Move every tensor to the model's device.
    for key, tensor in encoding.items():
        encoding[key] = tensor.to(device)

    # Inference only -- no gradients needed.
    with torch.no_grad():
        outputs = model(**encoding)

    # Per-token predicted class ids and the decoded token strings.
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    input_ids = encoding.input_ids.squeeze().tolist()
    words = [processor.tokenizer.decode(token_id) for token_id in input_ids]

    # LABEL_0 apparently stands for non-entity tokens; all tokens are
    # currently kept (no filtering).
    extracted_fields = [(model.config.id2label[pred], words[idx])
                        for idx, pred in enumerate(predictions)]
    if not extracted_fields:
        print("Warning - no fields were extracted!")
        return pd.DataFrame(columns=["Field", "Value"])

    fields = [field for field, _ in extracted_fields]
    values = [value for _, value in extracted_fields]

    # Homemade feature extraction: merges date fragments, so `values` may
    # shrink; both lists are truncated so the DataFrame stays rectangular.
    # NOTE(review): after merging, fields and values no longer line up
    # one-to-one -- TODO confirm the intended alignment.
    values = extract_features(values, fields)
    min_length = min(len(fields), len(values))
    fields = fields[:min_length]
    values = values[:min_length]

    return pd.DataFrame({"Field": fields, "Value": values})
# This is how to use questions to find answers in the document
# Less traditional approach, less flexibility, easier to implement/understand (didnt provide robust answers)
#model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
#def process_question(question, document):
# #print(f"Debug - Processing Question: {question}")
#
# encoding = processor(document, question, return_tensors="pt")
# #print(f"Debug - Encoding Input IDs: {encoding.input_ids}")
#
# outputs = model(**encoding)
# #print(f"Debug - Model Outputs: {outputs}")
#
# predicted_start_idx = outputs.start_logits.argmax(-1).item()
# predicted_end_idx = outputs.end_logits.argmax(-1).item()
#
# # Check if indices are valid
# if predicted_start_idx < 0 or predicted_end_idx < 0:
# print(f"Warning - Invalid prediction indices: start={predicted_start_idx}, end={predicted_end_idx}")
# return ""
#
# answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx: predicted_end_idx + 1]
# answer = processor.tokenizer.decode(answer_tokens)
#
# return answer
# Older iteration of the code, retaining for emergencies ?
#def process_question(question, document):
# if not question or document is None:
# return None, None, None
#
# text_value = None
# predictions = run_pipeline(question, document)
#
# for i, p in enumerate(ensure_list(predictions)):
# if i == 0:
# text_value = p["answer"]
# else:
# # Keep the code around to produce multiple boxes, but only show the top
# # prediction for now
# break
#
# return text_value
#def parse_ticket_image(image, question):
# """Basically just runs through these questions for the document"""
# # Processing the image
# if image:
# try:
# if image.mode != "RGB":
# document = image.convert("RGB")
# else:
# document = image
# except Exception as e:
# traceback.print_exc()
# error = str(e)
#
#
# # Define questions you want to ask the model
#
# questions = [
# "What is the ticket number?",
# "What is the type of grain (For example: corn, soybeans, wheat)?",
# "What is the date?",
# "What is the time?",
# "What is the gross weight?",
# "What is the tare weight?",
# "What is the net weight?",
# "What is the moisture (moist) percentage?",
# "What is the damage percentage?",
# "What is the gross units?",
# "What is the dock units?",
# "What is the comment?",
# "What is the assembly number?",
# ]
#
# # Use the model to answer each question
# answers = {}
# for q in questions:
# print(f"Question: {q}")
# answer_text = process_question(q, document)
# print(f"Answer Text extracted here: {answer_text}")
# answers[q] = answer_text
#
#
# ticket_number = answers["What is the ticket number?"]
# grain_type = answers["What is the type of grain (For example: corn, soybeans, wheat)?"]
# date = answers["What is the date?"]
# time = answers["What is the time?"]
# gross_weight = answers["What is the gross weight?"]
# tare_weight = answers["What is the tare weight?"]
# net_weight = answers["What is the net weight?"]
# moisture = answers["What is the moisture (moist) percentage?"]
# damage = answers["What is the damage percentage?"]
# gross_units = answers["What is the gross units?"]
# dock_units = answers["What is the dock units?"]
# comment = answers["What is the comment?"]
# assembly_number = answers["What is the assembly number?"]
#
#
# # Create a structured format (like a table) using pandas
# data = {
# "Ticket Number": [ticket_number],
# "Grain Type": [grain_type],
# "Assembly Number": [assembly_number],
# "Date": [date],
# "Time": [time],
# "Gross Weight": [gross_weight],
# "Tare Weight": [tare_weight],
# "Net Weight": [net_weight],
# "Moisture": [moisture],
# "Damage": [damage],
# "Gross Units": [gross_units],
# "Dock Units": [dock_units],
# "Comment": [comment],
# }
# df = pd.DataFrame(data)
#
# return df
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.Interface(
fn=parse_ticket_image,
inputs=[gr.Image(label= "Upload your Grain Scale Ticket", type="pil")],
outputs=[gr.Dataframe(headers=["Field", "Value"], label="Extracted Grain Scale Ticket Data")],
)
if __name__ == "__main__":
demo.launch() |