File size: 11,401 Bytes
ebcc7d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
from skimage.io import imread
from skimage.color import rgb2gray
import matplotlib.pyplot as plt
import numpy as np
from skimage.filters import threshold_otsu
import os
from skimage.graph import route_through_array
from heapq import heappush, heappop
from loguru import logger

def heuristic(a, b):
    """Return the squared Euclidean distance between points ``a`` and ``b``.

    Only the first two coordinates of each point are used. No square root is
    taken, so this is the *squared* distance, not the distance itself.
    """
    row_delta, col_delta = b[0] - a[0], b[1] - a[1]
    return row_delta ** 2 + col_delta ** 2


def get_binary(img):
    """Return a uint8 mask that is 1 where ``img`` is at or below its Otsu threshold.

    If the image mean is exactly 0.0 or 1.0 (e.g. an all-black or all-white
    float image), the input is returned unchanged -- same object, original
    dtype -- without thresholding.
    """
    average = np.mean(img)
    already_flat = average == 0.0 or average == 1.0
    if not already_flat:
        otsu_cutoff = threshold_otsu(img)
        return (img <= otsu_cutoff).astype(np.uint8)
    return img


def astar(array, start, goal):
    """Find a path from ``start`` to ``goal`` through a grid using A* search.

    Cells equal to 1 are walls; any other value is walkable. Moves go to all
    8 neighbours. Both the step cost and the goal estimate use ``heuristic``
    (squared Euclidean distance), so an orthogonal step costs 1 and a diagonal
    step costs 2. The squared distance can overestimate the remaining cost, so
    the returned path is not guaranteed to be the cheapest one.

    Args:
        array: 2-D grid supporting ``.shape`` and ``array[row][col]`` indexing
            (e.g. a numpy array).
        start: ``(row, col)`` tuple to start from.
        goal: ``(row, col)`` tuple to reach.

    Returns:
        list: ``(row, col)`` tuples ordered from ``goal`` back towards
        ``start``; ``start`` itself is not included. An empty list is returned
        both when no path exists and when ``start == goal``.
    """
    neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]
    n_rows, n_cols = array.shape[0], array.shape[1]
    close_set = set()
    came_from = {}
    gscore = {start: 0}
    fscore = {start: heuristic(start, goal)}
    oheap = []

    heappush(oheap, (fscore[start], start))

    while oheap:
        current = heappop(oheap)[1]

        if current == goal:
            # Walk parent links back; start has no parent so it is not appended.
            data = []
            while current in came_from:
                data.append(current)
                current = came_from[current]
            return data

        close_set.add(current)
        for i, j in neighbors:
            neighbor = current[0] + i, current[1] + j

            # Skip cells outside the grid and wall cells.
            if not (0 <= neighbor[0] < n_rows and 0 <= neighbor[1] < n_cols):
                continue
            if array[neighbor[0]][neighbor[1]] == 1:
                continue

            tentative_g_score = gscore[current] + heuristic(current, neighbor)

            # An already-expanded node is reopened only if this route is strictly cheaper.
            if neighbor in close_set and tentative_g_score >= gscore[neighbor]:
                continue

            # gscore holds every node ever pushed; a pushed node not in close_set
            # is still on the heap, so this O(1) membership test is equivalent to
            # scanning the heap for the node.
            if neighbor not in gscore or tentative_g_score < gscore[neighbor]:
                came_from[neighbor] = current
                gscore[neighbor] = tentative_g_score
                fscore[neighbor] = tentative_g_score + heuristic(neighbor, goal)
                heappush(oheap, (fscore[neighbor], neighbor))

    return []


def preprocess_image(img, target_size):
    """Optionally crop an image, then convert it to 2-D grayscale.

    Args:
        img: Image as a numpy array -- 2-D grayscale, or 3-D with 3 (RGB) or
            4 (RGBA; the alpha channel is dropped) channels.
        target_size: ``None`` for no cropping, otherwise a crop box
            ``(row_start, row_end, col_start, col_end)`` applied as
            ``img[row_start:row_end, col_start:col_end]``.

    Returns:
        The (cropped) grayscale image -- the input object itself when it is
        already 2-D and no crop is requested -- or ``None`` if any step raised
        (the error is printed).
    """
    try:
        if target_size is not None:
            # Slice only the two spatial axes so 2-D grayscale input can be
            # cropped too; any channel axis is kept whole implicitly.
            img = img[target_size[0]:target_size[1], target_size[2]:target_size[3]]
        if img.ndim == 3 and img.shape[2] == 4:
            img = img[..., :3]
        if img.ndim > 2:
            img = rgb2gray(img)
        return img
    except Exception as e:
        print(f"Error in preprocessing: {e}")
        return None


def horizontal_projections(sobel_image):
    """Return the per-row sum of a 2-D image (its horizontal projection profile).

    Despite the parameter name, this module passes the boolean mask produced
    by ``binarize_image``, so each entry is the count of True pixels in a row.
    """
    row_sums = np.sum(sobel_image, axis=1)
    return row_sums


def binarize_image(image):
    """Return a boolean mask that is True where ``image`` is strictly below its Otsu threshold."""
    otsu_level = threshold_otsu(image)
    below_level = image < otsu_level
    return below_level


def find_peak_regions(hpp, threshold):
    """Return the indices of all entries of ``hpp`` strictly below ``threshold``.

    With a row projection profile these are the rows containing few
    foreground pixels -- the candidate gaps between text lines.
    """
    return [row for row, value in enumerate(hpp) if value < threshold]


def line_segmentation(image, threshold=None, min_peak_group_size=7, target_size=None,
                     ct=0, parent_line_num=None, recursive=False, recursive_count=1,
                     base_key="line"):
    """
    Segment an image into lines using horizontal projections and minimum-cost paths.

    Pipeline:
      1. Crop / grayscale the input with ``preprocess_image``.
      2. Binarize with ``binarize_image`` and sum each row with
         ``horizontal_projections``.
      3. Rows whose projection is below ``threshold`` are candidate gaps;
         runs of consecutive gap rows longer than ``min_peak_group_size``
         become gap bands.
      4. Inside each band, ``route_through_array`` finds a path from the middle
         of the left edge to the middle of the right edge over the
         ``get_binary`` cost image; that path is a separating line.
      5. The rows from the top of one separating line to the bottom of the
         next form one line image.
      6. Line images taller than the 90th-percentile height are segmented once
         more (``recursive=True``) and replaced by the result when it succeeds.

    Args:
        image: Input image (numpy array).
        threshold (float, optional): Projection value below which a row counts
            as a gap. Defaults to half of the projection's range.
        min_peak_group_size (int): A run of gap rows must be longer than this
            to be used.
        target_size (tuple, optional): Crop box forwarded to ``preprocess_image``.
        ct (int): Running counter used for top-level line numbering.
        parent_line_num (str, optional): Set on recursive calls; switches key
            numbering to ``recursive_count``.
        recursive (bool): Whether this is a recursive call (disables further
            recursion).
        recursive_count (int): Counter for recursive segmentation numbering.
        base_key (str): Prefix for dictionary keys.

    Returns:
        tuple: ``(segmented_images_dict, ct, found)``. ``segmented_images_dict``
        maps keys such as ``f"{base_key}_{n}"`` to
        ``{"image": ndarray, "transcription": key}``; ``ct`` is the updated
        counter; ``found`` is False when any stage produced nothing usable.
    """
    segmented_images_dict = {}

    img = preprocess_image(image, target_size)
    if img is None:
        print("Failed to preprocess image")
        return segmented_images_dict, ct, False

    # Binarize (pixels below the Otsu threshold -> True) and count them per row.
    binarized_image = binarize_image(img)
    hpp = horizontal_projections(binarized_image)

    if threshold is None:
        # Half of the projection range (note: not offset by np.min(hpp)).
        threshold = (np.max(hpp) - np.min(hpp)) / 2

    # Rows whose projection falls below the threshold (candidate gaps).
    peaks = find_peak_regions(hpp, threshold)
    if not peaks:
        print("No peaks found in image")
        return segmented_images_dict, ct, False

    peaks_indexes = np.array(peaks).astype(int)

    # Split the sorted gap rows wherever consecutive indices jump by more than
    # one, then keep only runs long enough to be real inter-line gaps.
    diff_between_consec_numbers = np.diff(peaks_indexes)
    indexes_with_larger_diff = np.where(diff_between_consec_numbers > 1)[0]
    peak_groups = np.split(peaks_indexes, indexes_with_larger_diff + 1)
    peak_groups = [item for item in peak_groups if len(item) > min_peak_group_size]

    if not peak_groups:
        print("No valid peak groups found in image")
        return segmented_images_dict, ct, False

    # Cost image for path finding: 1 where img is at or below its Otsu threshold.
    binary_image = get_binary(img)
    segment_separating_lines = []

    for sub_image_index in peak_groups:
        try:
            start_row = max(0, sub_image_index[0])
            end_row = min(binary_image.shape[0], sub_image_index[-1])

            if end_row <= start_row:
                continue

            nmap = binary_image[start_row:end_row, :]

            if nmap.size == 0:
                continue

            start_point = (int(nmap.shape[0] / 2), 0)
            end_point = (int(nmap.shape[0] / 2), nmap.shape[1] - 1)

            path, _ = route_through_array(nmap, start_point, end_point)
            # Path is (row, col) pairs local to the band: shift only the row
            # coordinate back into full-image coordinates.
            path = np.array(path)
            path[:, 0] += start_row
            segment_separating_lines.append(path)
        except Exception as e:
            print(f"Failed to process sub-image: {e}")
            continue

    if not segment_separating_lines:
        print("No valid segment separating lines found in image")
        return segmented_images_dict, ct, False

    # Each line image spans from the topmost row of one separating path to the
    # bottommost row of the next one.
    seperated_images = []
    for index, (upper_path, next_path) in enumerate(
            zip(segment_separating_lines, segment_separating_lines[1:])):
        try:
            top_row = np.min(upper_path[:, 0])
            bottom_row = np.max(next_path[:, 0])

            if top_row < bottom_row and bottom_row <= img.shape[0]:
                line_image = img[top_row:bottom_row]
                if line_image.size > 0:
                    seperated_images.append(line_image)
        except Exception as e:
            print(f"Failed to separate image at index {index}: {e}")
            continue

    if not seperated_images:
        print("No valid separated images found in image")
        return segmented_images_dict, ct, False

    # Lines taller than the 90th-percentile height are re-segmented below.
    try:
        image_heights = [line_image.shape[0] for line_image in seperated_images if line_image.size > 0]
        if not image_heights:
            print("No valid image heights found")
            return segmented_images_dict, ct, False
        height_threshold = np.percentile(image_heights, 90)
    except Exception as e:
        print(f"Failed to calculate height threshold: {e}")
        return segmented_images_dict, ct, False

    # Process each separated image
    for index, line_image in enumerate(seperated_images):
        try:
            if line_image.size == 0 or line_image.shape[0] == 0 or line_image.shape[1] == 0:
                continue

            if parent_line_num is None:
                dict_key = f'{base_key}_{ct + 1}'
            else:
                dict_key = f'{base_key}_{recursive_count}'
                # NOTE(review): in recursive calls only the LAST sub-line is
                # kept; earlier ones are skipped without advancing ct or
                # recursive_count. Confirm this is intended.
                if index < len(seperated_images) - 1:
                    continue

            segmented_images_dict[dict_key] = {
                "image": line_image.copy(),
                "transcription": f"{dict_key}"
            }

            # Re-segment unusually tall lines once; recursive calls never recurse again.
            if line_image.shape[0] > height_threshold and not recursive:
                try:
                    recursive_base_key = f"{base_key}_{ct + 1}"

                    # ct is threaded through the recursive call and updated from its result.
                    recursive_dict, ct, found_valid_separations = line_segmentation(
                        line_image, threshold=threshold,
                        min_peak_group_size=3,
                        parent_line_num=str(ct + 1),
                        recursive=True,
                        ct=ct,
                        recursive_count=1,
                        base_key=recursive_base_key
                    )

                    if found_valid_separations:
                        del segmented_images_dict[dict_key]
                        segmented_images_dict.update(recursive_dict)
                        print(f"Replaced {dict_key} with recursive segmentation results")
                    else:
                        print(f"Keeping original image {dict_key} as no valid separations were found")

                except Exception as e:
                    print(f"Failed during recursive segmentation of {dict_key}: {e}")

            ct += 1
            if recursive:
                recursive_count += 1

        except Exception as e:
            print(f"Failed to process line image at index {index}: {e}")
            continue
    logger.info(f"Total lines segment found: {len(segmented_images_dict)}")
    return segmented_images_dict, ct, len(seperated_images) > 0


def segment_image_to_lines(image_array, **kwargs):
    """
    Run ``line_segmentation`` on an image and return only its line dictionary.

    Args:
        image_array: Input image as numpy array.
        **kwargs: Forwarded unchanged to ``line_segmentation``.

    Returns:
        dict: Maps line keys (``"line_1"``, ``"line_2"``, ... with the default
        ``base_key``) to ``{"image": ..., "transcription": ...}`` entries.
        If segmentation raises, the error is logged and ``{}`` is returned.
    """
    try:
        logger.info("Starting line segmentation...")
        lines_by_key, _counter, succeeded = line_segmentation(image_array, **kwargs)
        if succeeded:
            logger.info("Line segmentation successful.....")
        return lines_by_key
    except Exception as err:
        logger.error(f"Line segmentation failed: {err}")
        return {}
    
# if __name__ == "__main__":
#     # Example usage
#     image_path = "./renAI-deploy/1.png"
#     image = imread(image_path)
#     segmented_lines = segment_image_to_lines(image, threshold=None, min_peak_group_size=10)
    

#     print(len(segmented_lines.values()))

#     for key, value in segmented_lines.items():
#         print(f"{key}: {value['image'].shape}")
#         print(f"{key}: {value['transcription']}")
#         # plt.imshow(img, cmap='gray')
#         # plt.title(key)
#         # plt.show()