File size: 3,254 Bytes
26e7d52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import gradio as gr
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import re

# Load the model and processor
# Donut checkpoint fine-tuned for document visual question answering (DocVQA).
# NOTE(review): from_pretrained downloads weights from the Hugging Face hub on
# first run; the model stays on CPU here and is moved to a device inside
# extract_info.
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa")

def extract_info(image):
    """Answer a fixed set of billing questions about a document image.

    Runs one Donut DocVQA generation pass per question using the module-level
    ``processor`` and ``model``, then assembles the answers into a flat,
    JSON-serializable dict.

    Args:
        image: PIL.Image of the document (any mode; converted to RGB here).

    Returns:
        dict mapping human-readable field labels to extracted answer strings
        ("No answer found." when the model emits no <s_answer> span).
    """
    # Pair each internal label with its exact question prompt in ONE place.
    # The original duplicated every question string verbatim as a dict key at
    # the bottom of the function, so any drift between the two copies would
    # raise KeyError at runtime.
    questions = {
        "billing_date": "What is the date of the billing?",
        "period_from": "From when the period?",
        "period_until": "until when the period?",
        "total_due": "What is the total amount to be paid?",
        "peak_kwh": "How much electricity in Kwh was consumed during peak hours(HP)?",
        "offpeak_kwh": "How much electricity was consumed during off-peak hours(HC)?",
    }

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Preprocess the image once and move it to the device once; the same
    # pixel_values tensor is reused for every question.
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)

    answers = {label: _ask(pixel_values, question, device)
               for label, question in questions.items()}

    return {
        "Billing Date": answers["billing_date"],
        "Billing Period": f"from {answers['period_from']} to {answers['period_until']}",
        "Total Due": answers["total_due"],
        "During Peak hours (HP) Total Consumption (kWh)": answers["peak_kwh"],
        "During Hours Off-Peak (HC) Total Consumption (kWh)": answers["offpeak_kwh"],
    }


def _ask(pixel_values, question, device):
    """Run one Donut generation pass and return the extracted answer text."""
    task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    )["input_ids"].to(device)
    # NOTE: the original also passed early_stopping=True and output_scores=True.
    # early_stopping only applies to beam search (num_beams > 1) and triggers a
    # transformers warning with greedy decoding; output_scores allocated score
    # tensors that were never read. Both dropped — generation output unchanged.
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    seq = processor.batch_decode(outputs.sequences)[0]
    # Keep only the text between the answer tags; everything else is prompt
    # scaffolding echoed back by the decoder.
    match = re.search(r"<s_answer>(.*?)</s_answer>", seq)
    return match.group(1).strip() if match else "No answer found."

# Wire the extractor into a simple Gradio UI: one image in, JSON out.
document_input = gr.components.Image(type="pil")
result_output = gr.components.JSON(label="Extraction Results")

iface = gr.Interface(
    fn=extract_info,
    inputs=document_input,
    outputs=result_output,
    title="Document Visual Question Answering with DONUT",
    description="Upload a document image and get structured information in JSON format.",
)

iface.launch()