File size: 4,060 Bytes
27d9242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bfdad0
 
27d9242
5bfdad0
 
27d9242
5bfdad0
 
 
 
27d9242
5bfdad0
 
 
27d9242
5bfdad0
 
 
 
 
 
 
 
 
27d9242
 
df4ac89
 
 
 
 
4559f18
 
 
df4ac89
4559f18
df4ac89
 
27d9242
df4ac89
 
27d9242
 
 
 
 
 
 
5533eed
 
 
 
 
 
 
2205e98
 
 
 
 
 
5533eed
 
2205e98
5533eed
 
2205e98
 
 
 
 
 
 
 
 
 
 
 
5533eed
2205e98
 
 
 
5533eed
2205e98
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import time
from doctr.file_utils import is_tf_available
import numpy as np
import cv2

if is_tf_available():
    import tensorflow as tf
    from backend.tensorflow import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor
else:
    import torch
    from backend.pytorch import DET_ARCHS, RECO_ARCHS, forward_image, load_predictor

class OCRModel:
    """Thin wrapper around an OCR predictor with automatic backend selection.

    On construction the class picks TensorFlow when ``is_tf_available()``
    reports it, otherwise PyTorch, and stores a matching device handle.

    Attributes:
        predictor: The predictor returned by ``load_predictor``; ``None``
            until :meth:`load_model` has been called.
        device: Backend device object (``tf.device(...)`` or
            ``torch.device(...)``) chosen by :meth:`_init_backend`.
        DET_ARCHS: Detection architectures exposed by the selected backend
            module.
        RECO_ARCHS: Recognition architectures exposed by the selected
            backend module.
    """

    def __init__(self):
        self.predictor = None
        self.device = None
        self._init_backend()

    def _init_backend(self):
        """Select the backend and the device (GPU when available, else CPU)."""
        if is_tf_available():
            import tensorflow as tf
            from backend.tensorflow import DET_ARCHS, RECO_ARCHS

            self.DET_ARCHS = DET_ARCHS
            self.RECO_ARCHS = RECO_ARCHS
            # TensorFlow documents the physical device type as upper-case
            # "GPU"; the filter appears to be matched case-sensitively, so a
            # lower-case "gpu" would never report an accelerator.
            if any(tf.config.experimental.list_physical_devices("GPU")):
                self.device = tf.device("/gpu:0")
            else:
                self.device = tf.device("/cpu:0")
        else:
            # Only import torch if TF is not available
            import torch
            from backend.pytorch import DET_ARCHS, RECO_ARCHS

            self.DET_ARCHS = DET_ARCHS
            self.RECO_ARCHS = RECO_ARCHS
            if torch.cuda.is_available():
                self.device = torch.device("cuda:0")
            else:
                self.device = torch.device("cpu")

    def load_model(self, det_arch, reco_arch, **kwargs):
        """Build the predictor and store it on ``self.predictor``.

        Args:
            det_arch: Detection architecture, forwarded to ``load_predictor``.
            reco_arch: Recognition architecture, forwarded to
                ``load_predictor``.
            **kwargs: Optional overrides for the options listed below; keys
                that are not listed are ignored. Defaults:
                ``assume_straight_pages=True``, ``straighten_pages=False``,
                ``export_as_straight_boxes=False``,
                ``disable_page_orientation=False``,
                ``disable_crop_orientation=False``, ``bin_thresh=0.3``,
                ``box_thresh=0.1``.
        """
        model_params = {
            "det_arch": det_arch,
            "reco_arch": reco_arch,
            "assume_straight_pages": kwargs.get("assume_straight_pages", True),
            "straighten_pages": kwargs.get("straighten_pages", False),
            "export_as_straight_boxes": kwargs.get("export_as_straight_boxes", False),
            "disable_page_orientation": kwargs.get("disable_page_orientation", False),
            "disable_crop_orientation": kwargs.get("disable_crop_orientation", False),
            "bin_thresh": kwargs.get("bin_thresh", 0.3),
            "box_thresh": kwargs.get("box_thresh", 0.1),
        }

        self.predictor = load_predictor(
            **model_params,
            device=self.device,
        )

    def process_page(self, page):
        """Run the loaded predictor on a single page image.

        Args:
            page: Image array; ``page.shape[0]`` and ``page.shape[1]`` are
                used as the target height and width of the returned map.

        Returns:
            A ``(seg_map, out)`` tuple: the squeezed output of
            ``forward_image`` resized to the page's height/width (presumably
            the detection segmentation map), and the predictor output for
            ``[page]``.

        Raises:
            RuntimeError: If :meth:`load_model` has not been called yet.
        """
        if self.predictor is None:
            # Fail early with an actionable message instead of an obscure
            # error from deep inside the backend.
            raise RuntimeError(
                "No predictor loaded; call load_model() before process_page()."
            )
        seg_map = forward_image(self.predictor, page, self.device)
        seg_map = np.squeeze(seg_map)
        seg_map = cv2.resize(
            seg_map, (page.shape[1], page.shape[0]), interpolation=cv2.INTER_LINEAR
        )
        out = self.predictor([page])
        return seg_map, out

    def get_reconstructed_page(self, out, page, margin=10):
        """Return the synthesized first page cropped around its text blocks.

        The block geometries are treated as relative coordinates: they are
        scaled by the synthesized image's width/height before cropping.

        Args:
            out: OCR output whose ``pages[0]`` provides ``export()`` and
                ``synthesize()``.
            page: Unused; kept for interface compatibility with callers.
            margin: Non-negative number of pixels added around the bounding
                box of all blocks (clamped to the image bounds).

        Returns:
            The cropped image, or the full synthesized image when there are
            no blocks, no usable coordinates, or the bounding box is empty.

        Raises:
            ValueError: If ``margin`` is negative.
        """
        if margin < 0:
            raise ValueError("margin must be non-negative")

        first_page = out.pages[0]
        page_export = first_page.export()
        img = first_page.synthesize()

        blocks = page_export["blocks"]
        if not blocks:
            return img  # Nothing detected: keep the whole image.

        height, width = img.shape[0], img.shape[1]

        # Start from an "inverted" box so the first valid block shrinks it.
        x_min, y_min = width, height
        x_max, y_max = 0, 0

        valid_coords = False
        for block in blocks:
            coords = np.asarray(block["geometry"], dtype=float)
            if coords.size == 0:
                continue
            # Accept both nested ((x, y), ...) and flat (x, y, x, y, ...)
            # layouts by normalising to one (x, y) pair per row.
            coords = coords.reshape(-1, 2)
            xs, ys = coords[:, 0], coords[:, 1]
            # A block without a single usable x or y carries no position; the
            # NaN would otherwise clamp to the image border and silently
            # widen the crop to the whole page.
            if np.isnan(xs).all() or np.isnan(ys).all():
                continue

            valid_coords = True
            x_min = min(x_min, max(0, np.nanmin(xs) * width))
            y_min = min(y_min, max(0, np.nanmin(ys) * height))
            x_max = max(x_max, min(width, np.nanmax(xs) * width))
            y_max = max(y_max, min(height, np.nanmax(ys) * height))

        # Return full image if no valid coordinates found
        if not valid_coords or x_min >= x_max or y_min >= y_max:
            return img

        # Add margins and ensure bounds
        x_min = max(0, int(x_min - margin))
        y_min = max(0, int(y_min - margin))
        x_max = min(width, int(x_max + margin))
        y_max = min(height, int(y_max + margin))

        return img[y_min:y_max, x_min:x_max]