demo2usage committed on
Commit
b648759
·
verified ·
1 Parent(s): 40be955

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -0
app.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import pytesseract
4
+ import torch
5
+ from transformers import LayoutLMProcessor, LayoutLMForTokenClassification
6
+ import pandas as pd
7
+ import io
8
+
9
+ # Load the processor and model
10
@st.cache_resource
def load_model():
    """Load the LayoutLM processor and token-classification model.

    The result is wrapped in ``st.cache_resource`` so the objects are reused
    instead of being rebuilt each time the script body runs.

    Returns:
        A ``(processor, model)`` tuple, both created from the
        ``microsoft/layoutlm-base-uncased`` checkpoint.
    """
    # NOTE(review): confirm that the installed transformers release actually
    # exports ``LayoutLMProcessor``; the LayoutLM (v1) documentation appears to
    # list only tokenizer classes.
    # NOTE(review): this looks like a base checkpoint rather than one
    # fine-tuned for form understanding, so the predicted labels are
    # presumably not QUESTION/ANSWER tags -- verify before relying on output.
    checkpoint = "microsoft/layoutlm-base-uncased"
    return (
        LayoutLMProcessor.from_pretrained(checkpoint),
        LayoutLMForTokenClassification.from_pretrained(checkpoint),
    )
15
+
16
def _normalize_box(left, top, box_width, box_height, page_width, page_height):
    """Scale a pixel-space OCR box to integer coordinates in the 0-1000 range.

    Args:
        left, top: Pixel position of the box's top-left corner.
        box_width, box_height: Pixel size of the box.
        page_width, page_height: Pixel size of the whole image.

    Returns:
        ``[x0, y0, x1, y1]`` with every coordinate clamped to ``0..1000``.
    """

    def scale(value, extent):
        # Clamp so a box that spills past the page edge cannot leave the range.
        return max(0, min(1000, int(1000 * value / extent)))

    return [
        scale(left, page_width),
        scale(top, page_height),
        scale(left + box_width, page_width),
        scale(top + box_height, page_height),
    ]


def _word_level_labels(encoding, token_label_ids, word_count):
    """Collapse per-token predictions to one label id per OCR word.

    The model predicts a label for every token position (special tokens,
    sub-word pieces and padding included), so token positions do not line up
    with the word list. The prediction of each word's first token is used.

    Args:
        encoding: The encoding returned by the processor call.
        token_label_ids: Predicted label id for every token position.
        word_count: Number of OCR words that were encoded.

    Returns:
        A list of length ``word_count``; entries are ``None`` for words that
        have no token (for example, words removed by truncation).
    """
    # NOTE(review): ``word_ids`` is only available when the encoding comes from
    # a fast (Rust-backed) tokenizer -- confirm for the processor in use.
    per_word = [None] * word_count
    for token_index, word_index in enumerate(encoding.word_ids(batch_index=0)):
        if word_index is not None and per_word[word_index] is None:
            per_word[word_index] = token_label_ids[token_index]
    return per_word


def _group_fields(words, word_label_ids, id2label):
    """Group BIO-tagged words into ``(field, value)`` rows.

    A run of QUESTION words forms the field and the ANSWER words that directly
    follow it form its value, so both end up in the same row. Any other label
    (or a word without a prediction) closes the current row.

    Args:
        words: OCR words in reading order.
        word_label_ids: One label id (or ``None``) per word.
        id2label: Mapping from label id to label name such as ``"B-QUESTION"``.

    Returns:
        A list of ``(field, value)`` string tuples.
    """
    fields = []
    field_words = []
    value_words = []
    active_entity = None

    def flush():
        if field_words or value_words:
            fields.append((" ".join(field_words), " ".join(value_words)))
        field_words.clear()
        value_words.clear()

    for word, label_id in zip(words, word_label_ids):
        label = "O" if label_id is None else id2label[label_id]
        prefix, _, entity = label.partition("-")

        if prefix not in ("B", "I") or entity not in ("QUESTION", "ANSWER"):
            flush()
            active_entity = None
            continue

        if entity == "QUESTION":
            # A new question (or an explicit B- tag) starts a new row.
            if active_entity != "QUESTION" or prefix == "B":
                flush()
            field_words.append(word)
        else:
            # Keep the pending question so the answer lands in the same row.
            value_words.append(word)
        active_entity = entity

    flush()
    return fields


processor, model = load_model()

st.title("Document Form Field Extractor")

uploaded_file = st.file_uploader("Upload a document image", type=["png", "jpg", "jpeg"])

if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Document", use_column_width=True)

    # OCR extraction
    ocr_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    page_width, page_height = image.size

    words = []
    boxes = []
    for raw_text, left, top, box_width, box_height in zip(
        ocr_data["text"],
        ocr_data["left"],
        ocr_data["top"],
        ocr_data["width"],
        ocr_data["height"],
    ):
        text = raw_text.strip()
        if text:
            words.append(text)
            boxes.append(
                _normalize_box(left, top, box_width, box_height, page_width, page_height)
            )

    if not words:
        # Nothing to encode; avoid calling the processor with empty input.
        st.warning("No text was detected in the uploaded image.")
    else:
        # Encoding
        encoding = processor(
            images=image,
            words=words,
            boxes=boxes,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
        )

        # Prediction (inference only, so gradient tracking is switched off)
        with torch.no_grad():
            outputs = model(**encoding)
        token_label_ids = torch.argmax(outputs.logits, dim=2)[0].tolist()

        # Extract fields dynamically
        word_label_ids = _word_level_labels(encoding, token_label_ids, len(words))
        fields = _group_fields(words, word_label_ids, model.config.id2label)

        # Display results
        df = pd.DataFrame(fields, columns=["Field", "Value"])
        st.subheader("Extracted Fields and Values")
        st.dataframe(df)

        # Download CSV
        csv = df.to_csv(index=False)
        st.download_button("Download CSV", csv, "fields.csv", "text/csv")