# BillsOCR / app.py
# Hachem's picture
# Create app.py
# 26e7d52 verified
import gradio as gr
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import re
# Both the processor and the model are loaded once, at import time, from the
# same Donut DocVQA checkpoint.
_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"
processor = DonutProcessor.from_pretrained(_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_CHECKPOINT)
def extract_info(image):
    """Extract key billing fields from a bill image by asking Donut questions.

    Each field is obtained by putting one natural-language question to the
    module-level Donut ``model`` / ``processor`` and keeping the text found
    between the ``<s_answer>`` tags of the decoded output.

    Args:
        image: PIL image of the bill (the Gradio input is declared with
            ``type="pil"``).

    Returns:
        A dict with the keys "Billing Date", "Billing Period", "Total Due",
        "During Peak hours (HP) Total Consumption (kWh)" and
        "During Hours Off-Peak (HC) Total Consumption (kWh)". A field whose
        answer could not be parsed holds the string "No answer found.".

    Raises:
        gr.Error: if ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Without this guard a missing upload fails on image.convert() with
        # an opaque AttributeError; gr.Error surfaces a readable message.
        raise gr.Error("Please upload a document image first.")

    # Short field name -> question. The question text is fed to the model
    # verbatim, so its wording must not be "cleaned up" casually.
    questions = {
        "billing_date": "What is the date of the billing?",
        "period_start": "From when the period?",
        "period_end": "until when the period?",
        "total_due": "What is the total amount to be paid?",
        "peak_kwh": "How much electricity in Kwh was consumed during peak hours(HP)?",
        "off_peak_kwh": "How much electricity was consumed during off-peak hours(HC)?",
    }

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # The image is identical for every question: preprocess it and move it to
    # the target device once instead of once per question.
    pixel_values = processor(
        image.convert("RGB"), return_tensors="pt"
    ).pixel_values.to(device)

    tokenizer = processor.tokenizer
    answers = {}
    for field, question in questions.items():
        task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
        decoder_input_ids = tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        )["input_ids"]
        outputs = model.generate(
            pixel_values,
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            num_beams=1,
            bad_words_ids=[[tokenizer.unk_token_id]],
            return_dict_in_generate=True,
            output_scores=True,
        )
        seq = processor.batch_decode(outputs.sequences)[0]
        # The decoded sequence still contains the prompt and tag tokens; keep
        # only the text enclosed by the answer tags.
        answer_match = re.search(r"<s_answer>(.*?)</s_answer>", seq)
        if answer_match:
            answers[field] = answer_match.group(1).strip()
        else:
            answers[field] = "No answer found."

    # Assemble the display-ready result; the two period answers are merged
    # into a single "from ... to ..." entry.
    return {
        "Billing Date": answers["billing_date"],
        "Billing Period": f"from {answers['period_start']} to {answers['period_end']}",
        "Total Due": answers["total_due"],
        "During Peak hours (HP) Total Consumption (kWh)": answers["peak_kwh"],
        "During Hours Off-Peak (HC) Total Consumption (kWh)": answers["off_peak_kwh"],
    }
# Web UI: one image upload (handed to the callback as a PIL image) in, one
# JSON panel with the extracted fields out.
_image_input = gr.components.Image(type="pil")
_json_panel = gr.components.JSON(label="Extraction Results")

iface = gr.Interface(
    fn=extract_info,
    inputs=_image_input,
    outputs=_json_panel,
    title="Document Visual Question Answering with DONUT",
    description="Upload a document image and get structured information in JSON format.",
)

# Start the Gradio server as soon as the script is executed.
iface.launch()