| """
|
| Attribution: https://github.com/AIPI540/AIPI540-Deep-Learning-Applications/
|
|
|
| Jon Reifschneider
|
| Brinnae Bent
|
|
|
| """
|
|
|
| import streamlit as st
|
| from PIL import Image
|
| import numpy as np
|
| import os
|
| import numpy as np
|
| import pandas as pd
|
| import pandas as pd
|
| import os
|
| import json
|
| import pandas as pd
|
| import torch
|
| import numpy as np
|
| import pandas as pd
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import matplotlib.pyplot as plt
|
| from ultralytics import YOLO
|
| from PIL import Image, ImageDraw, ImageFont
|
| import numpy as np
|
| import cv2
|
| import pytesseract
|
| from PIL import ImageEnhance
|
| import numpy as np
|
| import os
|
| import json
|
| from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
|
| from datasets import load_dataset
|
| from transformers import DataCollatorForLanguageModeling
|
| from PIL import Image, ImageEnhance
|
| from io import StringIO
|
|
|
|
|
def crop_image(model, original_image):
    """
    Crop the region of interest (table) from an image using a YOLO model.

    Inputs:
        model (YOLO): The YOLO model used for object detection.
        original_image (PIL.Image): The image to search for a table.

    Returns:
        PIL.Image or None: The cropped image containing the first detected
        table, or None when no table (class id 3) is found.
    """
    # BUG FIX: the original called np.array(image), reading the module-level
    # global `image` from the Streamlit script instead of the
    # `original_image` parameter — the function only worked by accident.
    image_array = np.array(original_image)
    results = model(image_array)

    for result in results:
        for box in result.boxes:
            # Class id 3 is the "table" class in the trained layout model.
            if box.cls == 3:
                x1, y1, x2, y2 = (int(coord) for coord in box.xyxy[0])
                return original_image.crop((x1, y1, x2, y2))

    # No table detected anywhere in the image.
    return None
|
|
|
def process_image(model, image):
    """
    Run layout detection on an image and draw labelled bounding boxes.

    Inputs:
        model (YOLO): The YOLO layout-detection model.
        image (PIL.Image): The image to annotate.

    Returns:
        PIL.Image: A copy of the image with a colored rectangle and a text
        label banner drawn for every detected document region.
    """
    # One color per known document class; unknown classes fall back to white.
    class_colors = {
        'title': (255, 0, 0),
        'text': (0, 255, 0),
        'figure': (0, 0, 255),
        'table': (255, 255, 0),
        'list': (0, 255, 255),
    }

    canvas = np.array(image)  # np.array copies, so the caller's image is untouched
    detections = model(canvas)

    for detection in detections:
        for box in detection.boxes.cpu().numpy():
            coords = box.xyxy[0].astype(int)
            class_name = detection.names[int(box.cls)]
            color = class_colors.get(class_name.lower(), (255, 255, 255))

            # Rectangle around the detected region.
            cv2.rectangle(canvas, coords[:2], coords[2:], color, 2)

            # Filled banner just above the box, with the class name in black.
            text_size, text_baseline = cv2.getTextSize(
                class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
            banner_top_left = (coords[0], coords[1] - text_size[1] - text_baseline)
            banner_bottom_right = (coords[0] + text_size[0], coords[1])
            cv2.rectangle(canvas, banner_top_left, banner_bottom_right, color, cv2.FILLED)
            cv2.putText(canvas, class_name, (coords[0], coords[1] - text_baseline),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)

    return Image.fromarray(canvas)
|
|
|
def improve_ocr_accuracy(img):
    """
    Preprocess an image to make OCR more reliable.

    Upscales the image 4x, doubles its contrast, then applies an inverted
    binary threshold at the midpoint value (127).

    Inputs:
        img (PIL.Image): The input (cropped table) image.

    Returns:
        numpy.ndarray: The inverse-binary-thresholded image.
    """
    # More pixels per glyph gives the OCR engine more to work with.
    upscaled = img.resize((img.width * 4, img.height * 4))

    # Double the contrast to sharpen text/background separation.
    high_contrast = ImageEnhance.Contrast(upscaled).enhance(2)

    # Inverted binary threshold: pixels <= 127 become 255, others 0.
    _, binary = cv2.threshold(np.array(high_contrast), 127, 255, cv2.THRESH_BINARY_INV)
    return binary
|
|
|
def ocr_core(image):
    """
    Perform OCR on an image and lightly reformat the extracted text.

    Pytesseract's word-level output is loaded into a DataFrame. A newline is
    prepended to the first word of every block after the first, and a comma
    is prepended where the horizontal gap to the previous word is wide
    (treating large gaps as column separators).

    Inputs:
        image (numpy.ndarray): The preprocessed (thresholded) image.

    Returns:
        str: The extracted text, with one space appended after every word.
    """
    word_data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    words = pd.DataFrame(word_data)

    # conf == -1 marks structural rows (page/block/line), not actual words.
    words = words[words['conf'] != -1]

    # Horizontal gap between each word and the previous word's right edge.
    words['left_diff'] = words.groupby('block_num')['left'].diff().fillna(0).astype(int)
    words['prev_width'] = words['width'].shift(1).fillna(0).astype(int)
    words['spacing'] = (words['left_diff'] - words['prev_width']).fillna(0).astype(int)

    # Newline before the first word of every block except block 1.
    words['text'] = words.apply(
        lambda row: '\n' + row['text']
        if (row['word_num'] == 1) & (row['block_num'] != 1)
        else row['text'],
        axis=1)
    # A gap wider than 80px is treated as a column boundary.
    words['text'] = words.apply(
        lambda row: ',' + row['text'] if row['spacing'] > 80 else row['text'],
        axis=1)

    # Matches the original accumulation: every word followed by one space.
    return ''.join(word + ' ' for word in words['text'])
|
|
|
def generate_csv_from_text(tokenizer, model, ocr_text):
    """
    Generate CSV-formatted text from OCR output using the GPT model.

    Inputs:
        tokenizer: Tokenizer paired with the GPT model.
        model: The fine-tuned GPT model used for CSV generation.
        ocr_text (str): The text extracted by OCR.

    Returns:
        str: The generated CSV-formatted text.
    """
    token_ids = tokenizer.encode(ocr_text, return_tensors='pt')
    generated = model.generate(token_ids, max_length=1000, num_return_sequences=1)
    return tokenizer.decode(generated[0], skip_special_tokens=True)
|
|
|
if __name__ == '__main__':
    # Point pytesseract at the Tesseract executable.
    # NOTE(review): Windows-specific path — this will fail on Linux/macOS.
    pytesseract.pytesseract.tesseract_cmd = r'C:/Program Files/Tesseract-OCR/tesseract.exe'

    # NOTE(review): `device` is computed but never used — the models are
    # never moved to it and run on their default device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Layout-detection model and the fine-tuned GPT-2 text-to-CSV model.
    model = YOLO(os.getcwd() + '/models/trained_yolov8.pt')
    gpt_model = GPT2LMHeadModel.from_pretrained(os.getcwd() + '/models/gpt_model')
    tokenizer = GPT2Tokenizer.from_pretrained(os.getcwd() + '/models/gpt_model')

    st.header('''
    Intelligent Document Processing: Table Extraction
    ''')

    header_img = Image.open('assets/header_img.png')
    st.image(header_img, use_column_width=True)

    with st.sidebar:
        user_image = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])

    if user_image is not None:
        st.divider()
        image = Image.open(user_image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        st.divider()
        st.subheader("Document Classes:")
        processed_image = process_image(model, image)
        st.image(processed_image, caption='Processed Image', use_column_width=True)

        st.divider()
        st.subheader("Table Cropped Image:")
        cropped_table = crop_image(model, image)
        if cropped_table is None:
            # BUG FIX: crop_image returns None when no table is detected;
            # previously None fell straight through into st.image and the
            # OCR / CSV stages and crashed the app.
            st.warning("No table detected in the uploaded image.")
        else:
            st.image(cropped_table, caption='Cropped Table', use_column_width=True)

            st.divider()
            st.subheader("OCR Text:")
            improved_image = improve_ocr_accuracy(cropped_table)
            ocr_text = ocr_core(improved_image)
            st.write(ocr_text)

            st.divider()
            st.subheader("CSV Output:")
            csv_output = generate_csv_from_text(tokenizer, gpt_model, ocr_text)
            try:
                st.dataframe(pd.read_csv(StringIO(csv_output), sep=",").head())
            except pd.errors.ParserError:
                # The language model does not guarantee well-formed CSV;
                # show the raw output rather than crashing.
                st.warning("Generated output could not be parsed as CSV:")
                st.text(csv_output)
|
|
|
|
|