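import streamlit as st
from PIL import Image
import pytesseract
import torch
# LayoutLM (v1) ships a tokenizer, not a processor: transformers has no LayoutLMProcessor
from transformers import LayoutLMTokenizerFast, LayoutLMForTokenClassification
import pandas as pd

# Load the tokenizer and model once per server process
@st.cache_resource
def load_model():
    tokenizer = LayoutLMTokenizerFast.from_pretrained("microsoft/layoutlm-base-uncased")
    # NOTE: the base checkpoint has no fine-tuned token-classification head, so its
    # labels are generic (LABEL_0, LABEL_1). Swap in a checkpoint fine-tuned with
    # B-/I- QUESTION/ANSWER tags to get usable fields.
    model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")
    model.eval()
    return tokenizer, model

tokenizer, model = load_model()

st.title("Document Form Field Extractor")

uploaded_file = st.file_uploader("Upload a document image", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Document", use_column_width=True)

    # OCR extraction: one entry per detected word, with pixel coordinates
    ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)

    width, height = image.size
    words = []
    boxes = []
    for i in range(len(ocr_data["text"])):
        text = ocr_data["text"][i].strip()
        if text:
            words.append(text)
            x, y, w, h = ocr_data["left"][i], ocr_data["top"][i], ocr_data["width"][i], ocr_data["height"][i]
            # LayoutLM expects boxes as [x0, y0, x1, y1] scaled to 0-1000
            box = [
                int(1000 * x / width),
                int(1000 * y / height),
                int(1000 * (x + w) / width),
                int(1000 * (y + h) / height)
            ]
            boxes.append(box)

    if not words:
        st.warning("No text was detected in the image.")
        st.stop()

    # Encoding: tokenize the OCR words, then give every sub-token its word's box
    encoding = tokenizer(words, is_split_into_words=True, return_tensors="pt",
                         truncation=True, padding="max_length", max_length=512)
    word_ids = encoding.word_ids(batch_index=0)
    # Special and padding tokens have no word, so they get a dummy box
    token_boxes = [boxes[idx] if idx is not None else [0, 0, 0, 0] for idx in word_ids]
    bbox = torch.tensor([token_boxes])

    # Prediction (inference only, so skip gradient tracking)
    with torch.no_grad():
        outputs = model(**encoding, bbox=bbox)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=2)
    token_labels = predictions[0].tolist()
    id2label = model.config.id2label

    # Keep one label per word (that of its first sub-token) so labels line up
    # with words; words cut off by truncation simply get no label
    labels = []
    previous_word = None
    for token_index, word_index in enumerate(word_ids):
        if word_index is not None and word_index != previous_word:
            labels.append(token_labels[token_index])
        previous_word = word_index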

    # Extract fields dynamically: pair each run of QUESTION words with the
    # run of ANSWER words that follows it
    fields = []
    current_field = ""
    current_value = ""
    current_label = None

    for word, label_id in zip(words, labels):
        label = id2label[label_id]
        # Tags look like "B-QUESTION" or "I-ANSWER"; anything else (e.g. "O") has no type
        label_type = label.split("-", 1)[1] if label.startswith(("B-", "I-")) else None

        if label_type == "QUESTION":
            # A question that follows an answer starts a new pair
            if current_label == "ANSWER":
                fields.append((current_field.strip(), current_value.strip()))
                current_field = ""
                current_value = ""
            current_field += " " + word
        elif label_type == "ANSWER":
            current_value += " " + word
        else:
            # "O" and other tag types (e.g. HEADER) close the pair being built
            if current_field or current_value:
                fields.append((current_field.strip(), current_value.strip()))
            current_field = ""
            current_value = ""
        current_label = label_type

    if current_field or current_value:
        fields.append((current_field.strip(), current_value.strip()))

    # Display results
    df = pd.DataFrame(fields, columns=["Field", "Value"])
    st.subheader("Extracted Fields and Values")
    st.dataframe(df)

    # Download CSV
    csv = df.to_csv(index=False)
    st.download_button("Download CSV", csv, "fields.csv", "text/csv")
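

# --- Optional sketch, not called by the app above ---------------------------
# The OCR loop above scales Tesseract's pixel boxes to the 0-1000 grid that
# LayoutLM expects. This is a minimal sketch of the same step as a standalone
# helper, assuming boxes arrive as (left, top, width, height) in pixels. The
# name normalize_box and the clamping are assumptions added for illustration:
# clamping guards against OCR boxes that spill slightly past the page edge and
# would otherwise land outside the 0-1000 range.
def normalize_box(left, top, box_width, box_height, page_width, page_height):
    """Return [x0, y0, x1, y1] scaled to 0-1000 and clamped to that range."""
    def scale(value, size):
        return max(0, min(1000, int(1000 * value / size)))

    return [
        scale(left, page_width),
        scale(top, page_height),
        scale(left + box_width, page_width),
        scale(top + box_height, page_height),
    ]


# Example: a 100x50 px word at (200, 100) on an 800x400 px page
# normalize_box(200, 100, 100, 50, 800, 400) returns [250, 250, 375, 375]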