"""Gradio demo: extract electricity-bill fields with Donut DocVQA.

Each field is obtained by asking the fine-tuned Donut model one natural-language
question about the uploaded document image; the answers are assembled into a
flat JSON object.
"""

import re
from typing import Optional

import gradio as gr
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Load the model and processor once at import time; every request reuses them.
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-docvqa"
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

# Pick the device and move the model there once, rather than on every request.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# The DocVQA checkpoint is prompted with its task token plus the question wrapped
# in <s_question> tags, ending with <s_answer> so the model continues with the
# answer. A bare question string does not follow this prompt format.
DOCVQA_PROMPT = "<s_docvqa><s_question>{question}</s_question><s_answer>"

# The generated sequence echoes the prompt, so the answer is the text following
# <s_answer>. Stop at the closing tag, or at end-of-text when generation ended
# before the model emitted it.
_ANSWER_RE = re.compile(r"<s_answer>(.*?)(?:</s_answer>|$)", re.DOTALL)

NO_ANSWER = "No answer found."

# Internal label -> question asked of the model. Keeping the labels separate from
# the question text means the output mapping below never has to repeat (and
# possibly mistype) the full question strings.
QUESTIONS = {
    "billing_date": "What is the date of the billing?",
    "period_start": "From when the period?",
    "period_end": "until when the period?",
    "total_due": "What is the total amount to be paid?",
    "peak_kwh": "How much electricity in Kwh was consumed during peak hours(HP)?",
    "off_peak_kwh": "How much electricity was consumed during off-peak hours(HC)?",
}


def _parse_answer(sequence: str) -> str:
    """Return the answer text from a decoded Donut sequence, or NO_ANSWER."""
    tokenizer = processor.tokenizer
    # Drop end-of-sequence / padding markers so they cannot leak into the answer.
    cleaned = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    match = _ANSWER_RE.search(cleaned)
    answer = match.group(1).strip() if match else ""
    return answer or NO_ANSWER


def _ask(pixel_values: torch.Tensor, question: str) -> str:
    """Ask one DocVQA question about already-preprocessed, on-device pixel values."""
    prompt = DOCVQA_PROMPT.format(question=question)
    decoder_input_ids = processor.tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids.to(device)
    # Greedy decoding. early_stopping is omitted because it only applies to beam
    # search, and output_scores is omitted because the scores were never used.
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    return _parse_answer(sequence)


def extract_info(image: Optional[Image.Image]) -> dict:
    """Extract billing fields from a document image.

    Args:
        image: The uploaded document as a PIL image (Gradio passes None when
            the user submits without uploading anything).

    Returns:
        A flat dict of field name -> extracted answer text. Fields the model
        could not answer are set to "No answer found.".

    Raises:
        gr.Error: If no image was provided. Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload a document image.")

    # Preprocess the image once and move it to the device once; every question
    # reuses the same pixel tensor.
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)
    answers = {label: _ask(pixel_values, question) for label, question in QUESTIONS.items()}

    return {
        "Billing Date": answers["billing_date"],
        "Billing Period": f"from {answers['period_start']} to {answers['period_end']}",
        "Total Due": answers["total_due"],
        "During Peak hours (HP) Total Consumption (kWh)": answers["peak_kwh"],
        "During Hours Off-Peak (HC) Total Consumption (kWh)": answers["off_peak_kwh"],
    }


iface = gr.Interface(
    fn=extract_info,
    inputs=gr.components.Image(type="pil"),
    outputs=gr.components.JSON(label="Extraction Results"),
    title="Document Visual Question Answering with DONUT",
    description="Upload a document image and get structured information in JSON format.",
)

iface.launch()