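# Text detection pipeline (from the surya library): pages are split into
# processor-sized chunks, run through the segmentation model in batches, and
# the predicted heatmaps are post-processed into text boxes and vertical lines.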
from typing import List, Tuple, Generator
from concurrent.futures import ProcessPoolExecutor

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
from tqdm import tqdm

from surya.model.detection.model import EfficientViTForSemanticSegmentation
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
from surya.input.processing import prepare_image_detection, split_image, get_total_splits, convert_if_not_rgb
from surya.schema import TextDetectionResult
from surya.settings import settings


def get_batch_size():
    batch_size = settings.DETECTOR_BATCH_SIZE
    if batch_size is None:
        batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "mps":
            batch_size = 8
        if settings.TORCH_DEVICE_MODEL == "cuda":
            batch_size = 36
    return batch_size
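

# Note: the batch size is counted in image *splits*, not whole images.
# batch_detection slices tall pages into processor-sized chunks and packs
# those chunks into batches, so one very long page can exceed the budget
# of a single batch on its own.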
def batch_detection(
    images: List,
    model: EfficientViTForSemanticSegmentation,
    processor,
    batch_size=None
) -> Generator[Tuple[List[List[np.ndarray]], List[Tuple[int, int]]], None, None]:
    assert all([isinstance(image, Image.Image) for image in images])
    if batch_size is None:
        batch_size = get_batch_size()
    heatmap_count = model.config.num_labels

    orig_sizes = [image.size for image in images]
    splits_per_image = [get_total_splits(size, processor) for size in orig_sizes]

    batches = []
    current_batch_size = 0
    current_batch = []
    for i in range(len(images)):
        if current_batch_size + splits_per_image[i] > batch_size:
            if len(current_batch) > 0:
                batches.append(current_batch)
            current_batch = []
            current_batch_size = 0
        current_batch.append(i)
        current_batch_size += splits_per_image[i]

    if len(current_batch) > 0:
        batches.append(current_batch)
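
    # Illustration (hypothetical numbers): with batch_size=36 and
    # splits_per_image=[20, 20, 10], the greedy packing above yields
    # batches=[[0], [1, 2]]: image 1's 20 splits will not fit alongside
    # image 0's 20, so image 0 is flushed and images 1 and 2 share the
    # next batch (20 + 10 <= 36).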
    for batch_idx in tqdm(range(len(batches)), desc="Detecting bboxes"):
        batch_image_idxs = batches[batch_idx]
        batch_images = [images[j].convert("RGB") for j in batch_image_idxs]

        split_index = []
        split_heights = []
        image_splits = []
        for image_idx, image in enumerate(batch_images):
            image_parts, split_height = split_image(image, processor)
            image_splits.extend(image_parts)
            split_index.extend([image_idx] * len(image_parts))
            split_heights.extend(split_height)

        image_splits = [prepare_image_detection(image, processor) for image in image_splits]
        # Batch images in dim 0
        batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)
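
        # batch now holds one tensor, presumably of shape
        # (num_splits_in_batch, channels, height, width), on the model's
        # device and in its dtype.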
        with torch.inference_mode():
            pred = model(pixel_values=batch)

        logits = pred.logits
        correct_shape = [processor.size["height"], processor.size["width"]]
        current_shape = list(logits.shape[2:])
        if current_shape != correct_shape:
            logits = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)

        logits = logits.cpu().detach().numpy().astype(np.float32)
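
        # logits: one (height, width) float32 heatmap per label for every
        # split; the loop below stitches splits from the same source image
        # back together vertically, trimming padded rows where a chunk is
        # shorter than the processor height.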
        preds = []
        for i, (idx, height) in enumerate(zip(split_index, split_heights)):
            # If our current prediction length is below the image idx, that means we have a new image
            # Otherwise, we need to add to the current image
            if len(preds) <= idx:
                preds.append([logits[i][k] for k in range(heatmap_count)])
            else:
                heatmaps = preds[idx]
                pred_heatmaps = [logits[i][k] for k in range(heatmap_count)]

                if height < processor.size["height"]:
                    # Cut off padding to get original height
                    pred_heatmaps = [pred_heatmap[:height, :] for pred_heatmap in pred_heatmaps]

                for k in range(heatmap_count):
                    heatmaps[k] = np.vstack([heatmaps[k], pred_heatmaps[k]])
                preds[idx] = heatmaps

        yield preds, [orig_sizes[j] for j in batch_image_idxs]
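

# Each image's preds entry is a list with one heatmap per label; for the
# two-label detection model this is [text_heatmap, affinity_map], which is
# exactly what parallel_get_lines unpacks below.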
def parallel_get_lines(preds, orig_sizes):
    heatmap, affinity_map = preds
    heat_img = Image.fromarray((heatmap * 255).astype(np.uint8))
    aff_img = Image.fromarray((affinity_map * 255).astype(np.uint8))
    affinity_size = list(reversed(affinity_map.shape))
    heatmap_size = list(reversed(heatmap.shape))
    bboxes = get_and_clean_boxes(heatmap, heatmap_size, orig_sizes)
    vertical_lines = get_vertical_lines(affinity_map, affinity_size, orig_sizes)

    result = TextDetectionResult(
        bboxes=bboxes,
        vertical_lines=vertical_lines,
        heatmap=heat_img,
        affinity_map=aff_img,
        image_bbox=[0, 0, orig_sizes[0], orig_sizes[1]]
    )
    return result
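

# parallel_get_lines must stay a module-level function: ProcessPoolExecutor
# pickles the callable in order to ship it to worker processes.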
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    detection_generator = batch_detection(images, model, processor, batch_size=batch_size)

    results = []
    max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
    parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH

    if parallelize:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            for preds, orig_sizes in detection_generator:
                batch_results = list(executor.map(parallel_get_lines, preds, orig_sizes))
                results.extend(batch_results)
    else:
        for preds, orig_sizes in detection_generator:
            for pred, orig_size in zip(preds, orig_sizes):
                results.append(parallel_get_lines(pred, orig_size))

    return results
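

# --- Usage sketch ---
# A minimal example of driving batch_text_detection; this is illustrative and
# not part of the original module. It assumes load_model and load_processor
# are exported from surya.model.detection.model alongside
# EfficientViTForSemanticSegmentation, and that "page.png" exists on disk;
# adjust both to your surya version and data.
if __name__ == "__main__":
    from surya.model.detection.model import load_model, load_processor

    det_model = load_model()
    det_processor = load_processor()
    page = Image.open("page.png")

    [detection] = batch_text_detection([page], det_model, det_processor)
    for box in detection.bboxes:
        # box.bbox is expected to be [x1, y1, x2, y2] in original image coordinates
        print(box.bbox)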