Spaces:
Running
Running
| from collections import defaultdict | |
| from copy import deepcopy | |
| from typing import List, Dict | |
| import torch | |
| from PIL import Image | |
| from surya.model.ordering.encoderdecoder import OrderVisionEncoderDecoderModel | |
| from surya.schema import TableResult, TableCell, Bbox | |
| from surya.settings import settings | |
| from tqdm import tqdm | |
| import numpy as np | |
| from surya.model.table_rec.config import SPECIAL_TOKENS | |
def get_batch_size():
    """Return the batch size to use for table recognition.

    An explicitly configured ``settings.TABLE_REC_BATCH_SIZE`` always wins.
    When it is ``None`` the size is derived from
    ``settings.TORCH_DEVICE_MODEL``: 64 on "cuda", 8 on "mps" and on every
    other device.
    """
    configured = settings.TABLE_REC_BATCH_SIZE
    if configured is not None:
        return configured

    # "mps" shares the generic default of 8; only "cuda" gets a larger batch.
    return 64 if settings.TORCH_DEVICE_MODEL == "cuda" else 8
def sort_bboxes(bboxes, tolerance=1):
    """Order blocks top-to-bottom, then left-to-right within each band.

    Each block is a mapping with a ``"bbox"`` sequence whose element 0 is used
    as the horizontal sort key and element 1 as the vertical one. The vertical
    value is snapped to the nearest multiple of ``tolerance`` so that blocks
    at almost the same height land in the same band.

    Returns a new flat list; the input list is not modified.
    """
    bands = defaultdict(list)
    for item in bboxes:
        top = item["bbox"][1]
        bands[round(top / tolerance) * tolerance].append(item)

    # Walk the bands from top to bottom, sorting each one by its left edge.
    ordered = []
    for band_key in sorted(bands):
        ordered.extend(sorted(bands[band_key], key=lambda entry: entry["bbox"][0]))
    return ordered
def is_rotated(rows, cols):
    """Guess whether a table is rotated from its row and column shapes.

    For each group the summed widths are divided by the summed heights. In an
    upright table rows are expected to give a ratio above 1 and columns a
    ratio below 1. The table is reported as rotated when the column ratio is
    more than twice the row ratio.

    ``rows`` and ``cols`` are iterables of objects exposing ``width`` and
    ``height``.
    """
    def aspect(boxes):
        total_width = sum(box.width for box in boxes)
        # The extra 1 keeps the denominator non-zero when heights sum to 0.
        total_height = sum(box.height for box in boxes) + 1
        return total_width / total_height

    row_ratio = aspect(rows)
    col_ratio = aspect(cols)
    return row_ratio * 2 < col_ratio
def _best_overlap_index(cell_box, candidates):
    """Return the index of the candidate with the largest positive overlap with ``cell_box``, or None."""
    best_idx = None
    best_pct = 0
    for idx, candidate in enumerate(candidates):
        pct = cell_box.intersection_pct(candidate)
        # Strict comparison: zero overlap never matches and the first maximum wins.
        if pct > best_pct:
            best_pct = pct
            best_idx = idx
    return best_idx


def _nearest_assigned_id(cell, cells, id_attr, axis):
    """Return ``id_attr`` of the cell closest to ``cell`` along ``axis`` among cells that have it set, or None."""
    nearest_id = None
    nearest_dist = None
    for other in cells:
        other_id = getattr(other, id_attr)
        if other_id is None:
            continue
        dist = abs(cell.center[axis] - other.center[axis])
        if nearest_dist is None or dist < nearest_dist:
            nearest_id = other_id
            nearest_dist = dist
    return nearest_id


def _build_table_result(preds, input_cells, orig_size, out_box_size):
    """Turn one table's decoded predictions and its input cells into a TableResult."""
    img_w, img_h = orig_size
    width_scaler = img_w / out_box_size
    height_scaler = img_h / out_box_size

    # Predictions are [cx, cy, w, h, token]: convert to corner boxes in image
    # coordinates and shift the token so that 0 means row and 1 means column.
    boxes = []
    for pred in preds:
        half_w = pred[2] / 2
        half_h = pred[3] / 2
        boxes.append([
            (pred[0] - half_w) * width_scaler,
            (pred[1] - half_h) * height_scaler,
            (pred[0] + half_w) * width_scaler,
            (pred[1] + half_h) * height_scaler,
            int(pred[4] - SPECIAL_TOKENS),
        ])

    rows = [TableCell(bbox=box[:4], row_id=row_idx)
            for row_idx, box in enumerate(b for b in boxes if b[4] == 0)]
    cols = [TableCell(bbox=box[:4], col_id=col_idx)
            for col_idx, box in enumerate(b for b in boxes if b[4] == 1)]

    # Assign every input cell to the row and the column it overlaps the most.
    cells = []
    for cell in input_cells:
        cell_box = Bbox(bbox=cell["bbox"])  # built once, reused for every candidate
        cells.append(
            TableCell(
                bbox=cell["bbox"],
                text=cell.get("text"),
                row_id=_best_overlap_index(cell_box, rows),
                col_id=_best_overlap_index(cell_box, cols),
            )
        )

    # Cells that overlapped nothing inherit the id of the nearest assigned cell.
    # For a rotated table the axes used to measure the distance are swapped.
    rotated = is_rotated(rows, cols)
    row_axis = 0 if rotated else 1
    col_axis = 1 if rotated else 0
    for cell in cells:
        # Ids filled in here are visible to the cells processed after this one.
        if cell.row_id is None:
            cell.row_id = _nearest_assigned_id(cell, cells, "row_id", row_axis)
        if cell.col_id is None:
            cell.col_id = _nearest_assigned_id(cell, cells, "col_id", col_axis)

    return TableResult(
        cells=cells,
        rows=rows,
        cols=cols,
        image_bbox=[0, 0, img_w, img_h],
    )


def batch_table_recognition(images: List, table_cells: List[List[Dict]], model: OrderVisionEncoderDecoderModel, processor, batch_size=None) -> List[TableResult]:
    """Predict row/column structure for a list of table images.

    Args:
        images: PIL images, one per table.
        table_cells: For each image, a list of dicts with a ``"bbox"`` entry
            and an optional ``"text"`` entry. The input is deep-copied and not
            modified.
        model: Model exposing ``encoder``, ``text_encoder``, ``decoder``,
            ``config``, ``device`` and ``dtype``.
        processor: Callable producing ``pixel_values``, ``input_boxes``,
            ``input_boxes_mask`` and ``input_boxes_counts``; its ``tokenizer``
            supplies ``eos_id`` and ``pad_id``.
        batch_size: Images per forward pass; defaults to ``get_batch_size()``.

    Returns:
        One ``TableResult`` per image, in input order. The cells inside each
        result follow the order produced by ``sort_bboxes``.

    Raises:
        AssertionError: If an image is not a PIL image or the two input lists
            differ in length.
    """
    assert all(isinstance(image, Image.Image) for image in images)
    assert len(images) == len(table_cells)
    if batch_size is None:
        batch_size = get_batch_size()

    output_order = []
    for batch_start in tqdm(range(0, len(images), batch_size), desc="Recognizing tables"):
        batch_end = batch_start + batch_size
        batch_table_cells = deepcopy(table_cells[batch_start:batch_end])
        batch_table_cells = [sort_bboxes(page_bboxes) for page_bboxes in batch_table_cells]  # Sort bboxes before passing in
        batch_list_bboxes = [[block["bbox"] for block in page] for page in batch_table_cells]

        batch_images = [image.convert("RGB") for image in images[batch_start:batch_end]]  # also copies the images
        current_batch_size = len(batch_images)
        orig_sizes = [image.size for image in batch_images]

        model_inputs = processor(images=batch_images, boxes=deepcopy(batch_list_bboxes))
        batch_bboxes = torch.from_numpy(np.array(model_inputs["input_boxes"], dtype=np.int32)).to(model.device)
        batch_bbox_mask = torch.from_numpy(np.array(model_inputs["input_boxes_mask"], dtype=np.int32)).to(model.device)
        batch_pixel_values = torch.tensor(np.array(model_inputs["pixel_values"]), dtype=model.dtype).to(model.device)
        batch_bbox_counts = torch.tensor(np.array(model_inputs["input_boxes_counts"]), dtype=torch.long).to(model.device)

        # Setup inputs for the decoder: one BOS step of five tokens per table.
        batch_decoder_input = [[[model.config.decoder.bos_token_id] * 5] for _ in range(current_batch_size)]
        batch_decoder_input = torch.tensor(np.stack(batch_decoder_input, axis=0), dtype=torch.long, device=model.device)
        inference_token_count = batch_decoder_input.shape[1]

        max_tokens = min(batch_bbox_counts[:, 1].max().item(), settings.TABLE_REC_MAX_BOXES)
        decoder_position_ids = torch.ones_like(batch_decoder_input[0, :, 0], dtype=torch.int64, device=model.device).cumsum(0) - 1
        # NOTE(review): the caches are sized with the nominal batch_size, not
        # current_batch_size, even for a short final batch - confirm intended.
        model.decoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)
        model.text_encoder.model._setup_cache(model.config, batch_size, model.device, model.dtype)

        batch_predictions = [[] for _ in range(current_batch_size)]

        with torch.inference_mode():
            encoder_hidden_states = model.encoder(pixel_values=batch_pixel_values).last_hidden_state
            text_encoder_hidden_states = model.text_encoder(
                input_boxes=batch_bboxes,
                input_boxes_counts=batch_bbox_counts,
                cache_position=None,
                attention_mask=batch_bbox_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=None,
                use_cache=False
            ).hidden_states

            token_count = 0
            all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

            while token_count < max_tokens:
                is_prefill = token_count == 0
                return_dict = model.decoder(
                    input_ids=batch_decoder_input,
                    encoder_hidden_states=text_encoder_hidden_states,
                    cache_position=decoder_position_ids,
                    use_cache=True,
                    prefill=is_prefill
                )

                decoder_position_ids = decoder_position_ids[-1:] + 1
                box_logits = return_dict["bbox_logits"][:, -1, :].detach()
                rowcol_logits = return_dict["class_logits"][:, -1, :].detach()

                rowcol_preds = torch.argmax(rowcol_logits, dim=-1)
                box_preds = torch.argmax(box_logits, dim=-1)

                # A table stays finished once it has emitted EOS or PAD.
                done = (rowcol_preds == processor.tokenizer.eos_id) | (rowcol_preds == processor.tokenizer.pad_id)
                all_done = all_done | done

                if all_done.all():
                    break

                batch_decoder_input = torch.cat([box_preds.unsqueeze(1), rowcol_preds.unsqueeze(1).unsqueeze(1)], dim=-1)

                # Convert flags and tokens to Python once per step instead of
                # once per table, then record tokens of unfinished tables.
                finished_flags = all_done.tolist()
                step_tokens = batch_decoder_input[:, 0].tolist()
                for table_idx, (tokens, finished) in enumerate(zip(step_tokens, finished_flags)):
                    if not finished:
                        batch_predictions[table_idx].append(tokens)

                token_count += inference_token_count
                inference_token_count = batch_decoder_input.shape[1]

        out_box_size = model.config.decoder.out_box_size
        for preds, input_cells, orig_size in zip(batch_predictions, batch_table_cells, orig_sizes):
            output_order.append(_build_table_result(preds, input_cells, orig_size, out_box_size))

    return output_order