Spaces:
Paused
Paused
| import torch | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| from io import BytesIO | |
| import torch | |
| from ultralytics import YOLO | |
| from torchvision import transforms | |
| from transformers import TableTransformerForObjectDetection | |
| from PIL import ImageDraw | |
| import numpy as np | |
| import csv | |
| import easyocr | |
| from tqdm.auto import tqdm | |
| import csv | |
# Pick the compute device: use CUDA when torch reports it is available,
# otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# new v1.1 checkpoints require no timm anymore
structure_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-structure-recognition-v1.1-all"
)
# Move the model weights to the selected device.
structure_model.to(device)
print("")
| from torchvision import transforms | |
class MaxResize:
    """Resize an image so that its longest side equals ``max_size``.

    The aspect ratio is preserved. The image may be scaled up as well as
    down, because the scale factor is always ``max_size / longest_side``.

    Args:
        max_size: Target length, in pixels, of the longest image side.
    """

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        """Return a resized copy of ``image``.

        Args:
            image: Object exposing ``size`` as ``(width, height)`` and a
                ``resize((width, height))`` method (e.g. a PIL image).

        Returns:
            Whatever ``image.resize`` returns for the scaled dimensions.
        """
        width, height = image.size
        scale = self.max_size / max(width, height)
        # Clamp to 1 pixel: on very elongated images the short side would
        # otherwise round down to 0 and produce an invalid resize target.
        new_width = max(1, int(round(scale * width)))
        new_height = max(1, int(round(scale * height)))
        return image.resize((new_width, new_height))
# Preprocessing pipeline: cap the longest side at 800 px, convert to a
# tensor, then normalize each channel with the given mean / std values.
detection_transform = transforms.Compose(
    [
        MaxResize(max_size=800),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
| # for output bounding box post-processing | |
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corners.

    Args:
        x: Tensor whose last dimension has size 4 and holds
            ``(x_c, y_c, w, h)``. Any number of leading dimensions is
            accepted (a single box, a list of boxes, a batch of lists).

    Returns:
        Tensor of the same shape as ``x`` holding
        ``(x_min, y_min, x_max, y_max)`` in the last dimension.
    """
    x_c, y_c, w, h = x.unbind(-1)
    corners = [
        x_c - 0.5 * w,
        y_c - 0.5 * h,
        x_c + 0.5 * w,
        y_c + 0.5 * h,
    ]
    # Stack on the same (last) axis that was unbound so the layout is
    # correct for inputs of any rank, not only 2-D ones.
    return torch.stack(corners, dim=-1)
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format scaled by an image size.

    Args:
        out_bbox: Tensor of ``(x_c, y_c, w, h)`` boxes, as accepted by
            ``box_cxcywh_to_xyxy``. Presumably normalized to [0, 1] by the
            model; confirm against the caller.
        size: ``(width, height)`` of the target image.

    Returns:
        Tensor of ``(x_min, y_min, x_max, y_max)`` boxes multiplied by
        ``(width, height, width, height)``.
    """
    img_w, img_h = size
    # Build the scale on the same device as the boxes so the multiplication
    # also works when the boxes have not been moved to the CPU.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h],
        dtype=torch.float32,
        device=out_bbox.device,
    )
    return box_cxcywh_to_xyxy(out_bbox) * scale
# Extend the model's id -> label mapping with a trailing "no object" entry.
# NOTE: this is the config's own dict, so the entry is added to
# structure_model.config.id2label as well.
id2label = structure_model.config.id2label
id2label[len(id2label)] = "no object"
def outputs_to_objects(outputs, img_size, id2label):
    """Turn raw model outputs into a list of labelled detections.

    Only the first item of the batch is read. Predictions whose best class
    maps to ``'no object'`` are dropped.

    Args:
        outputs: Model output exposing ``logits`` and ``['pred_boxes']``.
        img_size: ``(width, height)`` used to rescale the predicted boxes.
        id2label: Mapping from integer class id to label string.

    Returns:
        List of dicts with keys ``'label'``, ``'score'`` and ``'bbox'``
        (a list of four floats).
    """
    best = outputs.logits.softmax(-1).max(-1)
    label_ids = best.indices.detach().cpu().numpy()[0]
    confidences = best.values.detach().cpu().numpy()[0]
    raw_boxes = outputs['pred_boxes'].detach().cpu()[0]
    scaled_boxes = rescale_bboxes(raw_boxes, img_size)

    detections = []
    for label_id, confidence, box in zip(label_ids, confidences, scaled_boxes):
        name = id2label[int(label_id)]
        if name == 'no object':
            continue
        detections.append(
            {
                'label': name,
                'score': float(confidence),
                'bbox': [float(value) for value in box.tolist()],
            }
        )
    return detections
# Preprocessing pipeline for the structure model: cap the longest side at
# 1000 px, convert to a tensor, then normalize each channel.
structure_transform = transforms.Compose(
    [
        MaxResize(max_size=1000),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def TSR(cropped_image):
    """Run table-structure recognition on a cropped table image.

    Args:
        cropped_image: PIL image of a table (``size``, ``copy`` and
            ``convert`` are used).

    Returns:
        Tuple ``(visualized, cells)`` where ``visualized`` is a copy of
        ``cropped_image`` with every detected box outlined in red and
        ``cells`` is the list of detections produced by
        ``outputs_to_objects``.
    """
    # The normalization step is configured for three channels, so feed the
    # model an RGB view; the caller's image itself is left untouched.
    pixel_values = structure_transform(cropped_image.convert("RGB")).unsqueeze(0)
    pixel_values = pixel_values.to(device)
    print(pixel_values.shape)

    # forward pass
    with torch.no_grad():
        outputs = structure_model(pixel_values)

    # Work on a copy of the label map: writing into the model config on
    # every call kept adding a new "no object" key each time. The key used
    # is the last logit index (looks like the DETR-style "no object" slot —
    # confirm against the model docs).
    structure_id2label = dict(structure_model.config.id2label)
    structure_id2label.setdefault(outputs.logits.shape[-1] - 1, "no object")

    cells = outputs_to_objects(outputs, cropped_image.size, structure_id2label)

    cropped_table_visualized = cropped_image.copy()
    draw = ImageDraw.Draw(cropped_table_visualized)
    for cell in cells:
        draw.rectangle(cell["bbox"], outline="red")
    return cropped_table_visualized, cells
| ############# Visualizing rows and columns on cropped image | |
| ## Modified | |
def get_cell_coordinates_by_row(table_data):
    """Build a row-by-row grid of cell boxes from detected rows and columns.

    Entries labelled ``'table row'`` and ``'table column'`` are selected
    from ``table_data``; every (row, column) pair yields one cell whose
    vertical extent comes from the row box and whose horizontal extent
    comes from the column box.

    Args:
        table_data: Iterable of dicts with a ``'label'`` string and a
            ``'bbox'`` of ``[xmin, ymin, xmax, ymax]``.

    Returns:
        List of dicts ordered top to bottom, each with keys ``'row'`` (the
        row bbox), ``'cells'`` (list of ``{'column': bbox, 'cell': bbox}``
        ordered left to right) and ``'cell_count'``.
    """
    ordered_rows = sorted(
        (entry for entry in table_data if entry['label'] == 'table row'),
        key=lambda entry: entry['bbox'][1],
    )
    ordered_columns = sorted(
        (entry for entry in table_data if entry['label'] == 'table column'),
        key=lambda entry: entry['bbox'][0],
    )

    grid = []
    for row_entry in ordered_rows:
        row_box = row_entry['bbox']
        top, bottom = row_box[1], row_box[3]
        # One cell per column: x-range from the column, y-range from the row.
        cells_in_row = sorted(
            (
                {
                    'column': col_entry['bbox'],
                    'cell': [col_entry['bbox'][0], top, col_entry['bbox'][2], bottom],
                }
                for col_entry in ordered_columns
            ),
            key=lambda item: item['column'][0],
        )
        grid.append(
            {'row': row_box, 'cells': cells_in_row, 'cell_count': len(cells_in_row)}
        )

    # Final ordering is top to bottom by the row's ymin.
    return sorted(grid, key=lambda item: item['row'][1])
# Create the shared EasyOCR reader for English. This is done once at module
# level so the model is loaded into memory a single time.
reader = easyocr.Reader(lang_list=['en'])
def apply_ocr(cell_coordinates, cropped_image):
    """Run OCR on every cell of a table and collect the text row by row.

    Args:
        cell_coordinates: Iterable of row dicts, each with a ``"cells"``
            list whose items carry a ``"cell"`` box
            ``[xmin, ymin, xmax, ymax]``.
        cropped_image: PIL image the cell boxes refer to.

    Returns:
        List of rows; each row is a list with one string per cell. Cells
        with no recognized text, or with an empty box, contain ``"NAN"``.
    """
    data = []
    for row in tqdm(cell_coordinates):
        row_text = []
        for cell in row["cells"]:
            # Round to whole pixels first so boxes that collapse to zero
            # width/height (or are inverted) are skipped instead of being
            # cropped and sent to the OCR engine.
            left, upper, right, lower = (int(round(v)) for v in cell["cell"])
            if right <= left or lower <= upper:
                row_text.append("NAN")
                continue

            cell_image = np.array(cropped_image.crop((left, upper, right, lower)))
            result = reader.readtext(cell_image)
            if result:
                # Join the second field of every OCR hit (presumably the
                # recognized text — confirm against the EasyOCR docs).
                row_text.append(" ".join(x[1] for x in result))
            else:
                # No text detected: use the "NAN" placeholder.
                row_text.append("NAN")
        data.append(row_text)
    return data
| # Print the extracted text for each row | |
| """for idx, row_data in enumerate(data): | |
| print(f"Row {idx + 1}: {row_data}")""" | |
def op_csv(data, output_csv_file='extract.csv'):
    """Write rows of extracted text to a CSV file.

    Args:
        data: Iterable of rows, each row an iterable of cell values.
        output_csv_file: Destination path. Defaults to ``'extract.csv'``
            in the current working directory.

    Returns:
        The path that was written, or ``None`` if writing failed (the
        error is printed rather than raised).
    """
    try:
        with open(output_csv_file, mode='w', newline='', encoding='utf-8') as file:
            csv.writer(file).writerows(data)
    except Exception as e:
        # Deliberately best-effort: report the problem and signal failure
        # with None instead of propagating the exception to the caller.
        print(f"An error occurred: {e}")
        return None
    print(f"Data successfully written to {output_csv_file}")
    return output_csv_file