import os
from heapq import heappush, heappop

import matplotlib.pyplot as plt
import numpy as np
from loguru import logger
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from skimage.graph import route_through_array
from skimage.io import imread


def heuristic(a, b):
    """Return the squared Euclidean distance between points ``a`` and ``b``.

    Both points are indexable pairs ``(row, col)``. The value is *squared*
    (no square root is taken), so it over-estimates the true distance for
    points more than one cell apart.
    """
    return (b[0] - a[0]) ** 2 + (b[1] - a[1]) ** 2


def get_binary(img):
    """Binarize ``img`` with Otsu's threshold, marking dark pixels as 1.

    Args:
        img: Grayscale image array.

    Returns:
        A ``uint8`` array that is 1 where ``img <= threshold`` and 0
        elsewhere. If the mean of ``img`` is exactly 0.0 or 1.0 the input is
        returned unchanged (no thresholding is attempted).
    """
    mean = np.mean(img)
    if mean == 0.0 or mean == 1.0:
        return img
    thresh = threshold_otsu(img)
    return (img <= thresh).astype(np.uint8)


def astar(array, start, goal):
    """Search for a path from ``start`` to ``goal`` with A* on a 2-D grid.

    Cells whose value equals 1 are treated as walls; movement is allowed to
    the 8 surrounding cells. Step cost and the heuristic are both the squared
    distance computed by ``heuristic``.

    Args:
        array: 2-D array exposing ``shape`` and ``array[row][col]`` indexing.
        start: ``(row, col)`` tuple of the starting cell.
        goal: ``(row, col)`` tuple of the target cell.

    Returns:
        List of ``(row, col)`` tuples ordered from ``goal`` back towards
        ``start`` (``start`` itself is not included). An empty list is
        returned when no path exists or when ``start == goal``.
    """
    neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0),
                 (1, 1), (1, -1), (-1, 1), (-1, -1)]
    close_set = set()
    came_from = {}
    gscore = {start: 0}
    fscore = {start: heuristic(start, goal)}
    oheap = []
    # Number of live heap entries per node. This gives an O(1) "is the node
    # currently in the open heap?" test instead of scanning the whole heap.
    open_counts = {start: 1}
    heappush(oheap, (fscore[start], start))

    while oheap:
        current = heappop(oheap)[1]
        open_counts[current] -= 1

        if current == goal:
            data = []
            while current in came_from:
                data.append(current)
                current = came_from[current]
            return data

        close_set.add(current)
        for i, j in neighbors:
            neighbor = current[0] + i, current[1] + j
            if not 0 <= neighbor[0] < array.shape[0]:
                # array bound x walls
                continue
            if not 0 <= neighbor[1] < array.shape[1]:
                # array bound y walls
                continue
            if array[neighbor[0]][neighbor[1]] == 1:
                continue

            tentative_g_score = gscore[current] + heuristic(current, neighbor)
            if neighbor in close_set and tentative_g_score >= gscore.get(neighbor, 0):
                continue

            in_open_heap = open_counts.get(neighbor, 0) > 0
            if tentative_g_score < gscore.get(neighbor, 0) or not in_open_heap:
                came_from[neighbor] = current
                gscore[neighbor] = tentative_g_score
                fscore[neighbor] = tentative_g_score + heuristic(neighbor, goal)
                heappush(oheap, (fscore[neighbor], neighbor))
                open_counts[neighbor] = open_counts.get(neighbor, 0) + 1

    return []


def preprocess_image(img, target_size):
    """Optionally crop an image and convert it to grayscale.

    Args:
        img: Image array, either 2-D (grayscale) or 3-D (channels last).
        target_size: ``None`` for no cropping, otherwise a 4-item sequence
            ``(row_start, row_end, col_start, col_end)`` used to slice the
            first two axes.

    Returns:
        The 2-D grayscale image, or ``None`` if any step raised (the error is
        printed rather than propagated).
    """
    try:
        if target_size is not None:
            # Slice only the first two axes so that 2-D grayscale input is
            # cropped as well; an explicit third-axis slice fails on it.
            img = img[target_size[0]:target_size[1],
                      target_size[2]:target_size[3]]
        if img.ndim == 3 and img.shape[2] == 4:
            # Drop the alpha channel before the RGB -> gray conversion.
            img = img[..., :3]
        if img.ndim > 2:
            img = rgb2gray(img)
        return img
    except Exception as e:
        print(f"Error in preprocessing: {e}")
        return None


def horizontal_projections(sobel_image):
    """Return the per-row sum (sum along axis 1) of ``sobel_image``."""
    return np.sum(sobel_image, axis=1)


def binarize_image(image):
    """Return a boolean mask that is True where ``image`` is below Otsu's threshold."""
    threshold = threshold_otsu(image)
    return image < threshold


def find_peak_regions(hpp, threshold):
    """Return the indices of ``hpp`` whose value is strictly below ``threshold``.

    Args:
        hpp: Horizontal projection profile (one value per image row).
        threshold: Cut-off value.

    Returns:
        List of integer row indices, in increasing order.
    """
    return [i for i, hppv in enumerate(hpp) if hppv < threshold]


def _group_peak_rows(peak_rows, min_peak_group_size):
    """Split sorted row indices into runs of consecutive rows longer than the minimum size."""
    split_points = np.where(np.diff(peak_rows) > 1)[0] + 1
    groups = np.split(peak_rows, split_points)
    return [group for group in groups if len(group) > min_peak_group_size]


def _trace_separating_lines(binary_image, peak_groups):
    """Route one left-to-right path through each row band; return paths in image coordinates."""
    separating_lines = []
    for group in peak_groups:
        try:
            start_row = max(0, group[0])
            end_row = min(binary_image.shape[0], group[-1])
            if end_row <= start_row:
                continue
            nmap = binary_image[start_row:end_row, :]
            if nmap.size == 0:
                continue
            mid_row = nmap.shape[0] // 2
            path, _ = route_through_array(
                nmap, (mid_row, 0), (mid_row, nmap.shape[1] - 1)
            )
            path = np.array(path)
            # Only the row coordinate is relative to the band; columns are
            # already in full-image coordinates and must not be shifted.
            path[:, 0] += start_row
            separating_lines.append(path)
        except Exception as e:
            print(f"Failed to process sub-image: {e}")
            continue
    return separating_lines


def line_segmentation(image, threshold=None, min_peak_group_size=7, target_size=None,
                      ct=0, parent_line_num=None, recursive=False, recursive_count=1,
                      base_key="line"):
    """
    Segment an image into lines using horizontal projections and path routing.

    Rows whose projection value is below ``threshold`` are grouped into bands
    of consecutive rows; a minimum-cost path is routed through each band with
    ``route_through_array`` and the image is cut between successive paths.
    Segments taller than the 90th percentile of segment heights are segmented
    once more (one level of recursion only).

    Args:
        image: Input image (numpy array)
        threshold (float, optional): Threshold for peak detection. Defaults
            to half the range of the horizontal projection profile.
        min_peak_group_size (int): A band is kept only if it has more rows
            than this value.
        target_size (tuple, optional): Crop box passed to
            ``preprocess_image`` as (row_start, row_end, col_start, col_end)
        ct (int): Counter for line numbering
        parent_line_num (str, optional): Parent line number for recursive
            segmentation; ``None`` selects top-level key numbering.
        recursive (bool): Whether this is a recursive call
        recursive_count (int): Counter for recursive segmentation numbering
        base_key (str): Base key for dictionary entries

    Returns:
        tuple: (segmented_images_dict, counter value, bool indicating if
        valid separations were found). Each dictionary value is a dict with
        keys ``"image"`` (copy of the line image) and ``"transcription"``
        (the dictionary key itself).
    """
    segmented_images_dict = {}

    img = preprocess_image(image, target_size)
    if img is None:
        print("Failed to preprocess image")
        return segmented_images_dict, ct, False

    # Binarize image and get projections
    binarized_image = binarize_image(img)
    hpp = horizontal_projections(binarized_image)
    if threshold is None:
        threshold = (np.max(hpp) - np.min(hpp)) / 2

    # Find peaks
    peaks = find_peak_regions(hpp, threshold)
    if not peaks:
        print("No peaks found in image")
        return segmented_images_dict, ct, False
    peaks_indexes = np.array(peaks).astype(int)

    # Group peaks
    peak_groups = _group_peak_rows(peaks_indexes, min_peak_group_size)
    if not peak_groups:
        print("No valid peak groups found in image")
        return segmented_images_dict, ct, False

    binary_image = get_binary(img)
    segment_separating_lines = _trace_separating_lines(binary_image, peak_groups)
    if not segment_separating_lines:
        print("No valid segment separating lines found in image")
        return segmented_images_dict, ct, False

    # Separate images based on line segments
    seperated_images = []
    line_pairs = zip(segment_separating_lines, segment_separating_lines[1:])
    for index, (current_line, next_line) in enumerate(line_pairs):
        try:
            lower_line = np.min(current_line[:, 0])
            upper_line = np.max(next_line[:, 0])
            if lower_line < upper_line and upper_line <= img.shape[0]:
                line_image = img[lower_line:upper_line]
                if line_image.size > 0:
                    seperated_images.append(line_image)
        except Exception as e:
            print(f"Failed to separate image at index {index}: {e}")
            continue

    if not seperated_images:
        print("No valid separated images found in image")
        return segmented_images_dict, ct, False

    # Calculate height threshold
    try:
        image_heights = [line_image.shape[0] for line_image in seperated_images
                         if line_image.size > 0]
        if not image_heights:
            print("No valid image heights found")
            return segmented_images_dict, ct, False
        height_threshold = np.percentile(image_heights, 90)
    except Exception as e:
        print(f"Failed to calculate height threshold: {e}")
        return segmented_images_dict, ct, False

    # Process each separated image
    last_index = len(seperated_images) - 1
    for index, line_image in enumerate(seperated_images):
        try:
            if line_image.size == 0 or line_image.shape[0] == 0 or line_image.shape[1] == 0:
                continue

            if parent_line_num is None:
                dict_key = f'{base_key}_{ct + 1}'
            else:
                dict_key = f'{base_key}_{recursive_count}'
                # NOTE(review): in a recursive call only the final separated
                # segment is kept; earlier ones are skipped before either
                # counter advances. Confirm this is the intended behavior.
                if index < last_index:
                    continue

            segmented_images_dict[dict_key] = {
                "image": line_image.copy(),
                "transcription": f"{dict_key}"
            }
            # print(f"Added line image to dictionary with key: {dict_key}")

            # Handle recursive segmentation
            if line_image.shape[0] > height_threshold and not recursive:
                try:
                    # Create recursive base key
                    recursive_base_key = f"{base_key}_{ct + 1}"

                    # Do recursive segmentation
                    recursive_dict, ct, found_valid_separations = line_segmentation(
                        line_image,
                        threshold=threshold,
                        min_peak_group_size=3,
                        parent_line_num=str(ct + 1),
                        recursive=True,
                        ct=ct,
                        recursive_count=1,
                        base_key=recursive_base_key
                    )

                    # Replace the original only when the recursion actually
                    # produced entries, so a line is never dropped outright.
                    if found_valid_separations and recursive_dict:
                        del segmented_images_dict[dict_key]
                        segmented_images_dict.update(recursive_dict)
                        print(f"Replaced {dict_key} with recursive segmentation results")
                    else:
                        print(f"Keeping original image {dict_key} as no valid separations were found")
                except Exception as e:
                    print(f"Failed during recursive segmentation of {dict_key}: {e}")

            ct += 1
            if recursive:
                recursive_count += 1
        except Exception as e:
            print(f"Failed to process line image at index {index}: {e}")
            continue

    logger.info(f"Total lines segment found: {len(segmented_images_dict)}")
    return segmented_images_dict, ct, len(seperated_images) > 0


def segment_image_to_lines(image_array, **kwargs):
    """
    Convenience function to segment an image into lines.

    Args:
        image_array: Input image as numpy array
        **kwargs: Additional arguments for line_segmentation

    Returns:
        dict: Dictionary mapping line keys to dicts with an ``"image"`` array
        and a ``"transcription"`` string. The dictionary returned by
        ``line_segmentation`` is passed through even when it reports no
        success; an empty dict is returned if an exception is raised.
    """
    try:
        logger.info("Starting line segmentation...")
        segmented_dict, _, success = line_segmentation(image_array, **kwargs)
        if success:
            logger.info("Line segmentation successful.....")
        return segmented_dict
    except Exception as e:
        logger.error(f"Line segmentation failed: {e}")
        return {}


# if __name__ == "__main__":
#     # Example usage
#     image_path = "./renAI-deploy/1.png"
#     image = imread(image_path)
#     segmented_lines = segment_image_to_lines(image, threshold=None, min_peak_group_size=10)
#     print(len(segmented_lines.values()))
#     for key, value in segmented_lines.items():
#         print(f"{key}: {value['image'].shape}")
#         print(f"{key}: {value['transcription']}")
#         # plt.imshow(img, cmap='gray')
#         # plt.title(key)
#         # plt.show()