# rosemariafontana's picture
# changed image path
# eaebdde verified
# raw
# history blame
# 4.44 kB
import traceback

import gradio as gr
import pandas as pd
import torch
from PIL import Image
from transformers import (
    LayoutLMv2Processor,
    LayoutLMv3ForQuestionAnswering,
    LayoutLMv3Processor,
)
# Load the processor/model pair once at module import for extractive document QA.
# BUG FIX: the original paired LayoutLMv2Processor with a LayoutLMv3 checkpoint;
# the processor class must match the checkpoint family. The LayoutLMv3 processor
# applies OCR by default (requires pytesseract) to extract words + boxes from the
# page image, which is what process_question relies on.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
def process_question(question, document):
    """Answer one extractive question about a document image.

    Args:
        question: Natural-language question string.
        document: PIL image of the document page (the processor runs OCR on it).

    Returns:
        The decoded answer span as a string; may be empty when the model
        predicts an empty or inverted span.
    """
    encoding = processor(document, question, return_tensors="pt")
    # BUG FIX: the original called the undefined name `mode`; the module-level
    # model is named `model`. Run inference without gradient tracking.
    with torch.no_grad():
        outputs = model(**encoding)
    # Greedy span selection: most likely start and end token positions.
    predicted_start_idx = outputs.start_logits.argmax(-1).item()
    predicted_end_idx = outputs.end_logits.argmax(-1).item()
    answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
    # skip_special_tokens keeps <s>/</s> markers out of the user-visible answer.
    answer = processor.tokenizer.decode(answer_tokens, skip_special_tokens=True)
    return answer
#def process_question(question, document):
# if not question or document is None:
# return None, None, None
#
# text_value = None
# predictions = run_pipeline(question, document)
#
# for i, p in enumerate(ensure_list(predictions)):
# if i == 0:
# text_value = p["answer"]
# else:
# # Keep the code around to produce multiple boxes, but only show the top
# # prediction for now
# break
#
# return text_value
def parse_ticket_image(image, question=None):
    """Run document QA over a grain scale ticket image and tabulate the answers.

    Args:
        image: The uploaded ticket — a PIL image (the Gradio component uses
            type="pil") or a file path.
        question: Unused; kept with a default for backward compatibility with
            callers that passed a second positional argument.

    Returns:
        A two-column pandas DataFrame (Field, Value) with one row per extracted
        field — matching the headers declared on the output Dataframe
        component — or None when no image was supplied or it could not be
        opened.
    """
    # BUG FIX: the original fell through when `image` was falsy or Image.open
    # raised, leaving `document` undefined and crashing with NameError below.
    if image is None:
        return None
    try:
        # The Gradio Image component delivers a PIL image directly; only treat
        # the input as a path when it is not already one.
        if isinstance(image, Image.Image):
            document = image.convert("RGB")
        else:
            document = Image.open(image).convert("RGB")
    except Exception:
        traceback.print_exc()
        return None

    # Questions to ask the model about the ticket. More fields (date, weights,
    # moisture, damage, units, comment, assembly number, ...) can be added here.
    questions = [
        "What is the ticket number?",
    ]

    # BUG FIX: the original passed the whole question dict to process_question;
    # it expects the question string and the opened document image.
    answers = {q: process_question(q, document) for q in questions}

    # Present answers as Field/Value rows so the table matches the
    # headers=["Field", "Value"] declared on the output component.
    field_names = {
        "What is the ticket number?": "Ticket Number",
    }
    rows = [(field_names[q], answers[q]) for q in questions]
    return pd.DataFrame(rows, columns=["Field", "Value"])
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Build the Gradio UI: one ticket image in, one table of extracted fields out.
ticket_input = gr.Image(label="Upload your Grain Scale Ticket", type="pil")
results_output = gr.Dataframe(
    headers=["Field", "Value"],
    label="Extracted Grain Scale Data",
)

demo = gr.Interface(
    fn=parse_ticket_image,
    inputs=[ticket_input],
    outputs=[results_output],
)

# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()