# RenAI / utils / line_segmentation.py
# Arsh124 — commit ebcc7d1, "Initial RenAI app"
from skimage.io import imread
from skimage.color import rgb2gray
import matplotlib.pyplot as plt
import numpy as np
from skimage.filters import threshold_otsu
import os
from skimage.graph import route_through_array
from heapq import heappush, heappop
from loguru import logger
def heuristic(a, b):
    """Return the squared Euclidean distance between grid points *a* and *b*.

    Only the first two coordinates of each point are read, and no square root
    is taken, so the result grows quadratically with distance.
    """
    row_delta = b[0] - a[0]
    col_delta = b[1] - a[1]
    return row_delta ** 2 + col_delta ** 2
def get_binary(img):
    """Binarize *img* with Otsu's method, marking pixels at or below the threshold.

    Returns a uint8 array holding 1 where ``img <= threshold`` and 0 elsewhere.
    If the mean of *img* is exactly 0.0 or 1.0, *img* itself is returned
    unchanged and no threshold is computed.
    """
    average = np.mean(img)
    # An all-0 or all-1 input is treated as already binary.
    if average in (0.0, 1.0):
        return img
    cutoff = threshold_otsu(img)
    return (img <= cutoff).astype(np.uint8)
def astar(array, start, goal):
    """Find a path from *start* to *goal* on an 8-connected grid using A* search.

    A cell is blocked when its value equals 1 or when it lies outside
    ``array.shape``; every other cell may be entered.

    Args:
        array: 2-D grid with a ``shape`` attribute, indexed as ``array[row][col]``
            (e.g. a binary numpy array).
        start: ``(row, col)`` tuple where the search begins.
        goal: ``(row, col)`` tuple to reach.

    Returns:
        A list of ``(row, col)`` tuples ordered from *goal* back toward *start*;
        *start* itself is not included. Returns ``[]`` when no path exists, and
        also when ``start == goal``.

    Note:
        :func:`heuristic` (squared distance) is used both as the step cost and as
        the estimate of remaining cost. Straight steps therefore cost 1, diagonal
        steps cost 2, and the estimate can exceed the true remaining cost, so the
        returned path is not guaranteed to be the cheapest one.
    """
    neighbors = [(0, 1), (0, -1), (1, 0), (-1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]
    n_rows, n_cols = array.shape[0], array.shape[1]
    close_set = set()
    came_from = {}
    gscore = {start: 0}
    fscore = {start: heuristic(start, goal)}
    oheap = []
    # Number of entries each node currently has in ``oheap``. A node can be pushed
    # more than once (whenever its g-score improves), so a count is kept rather
    # than a set. This answers "is the node in the heap?" in O(1) instead of
    # rebuilding a list of every heap entry for each neighbour, and gives exactly
    # the same answer.
    open_counts = {start: 1}
    heappush(oheap, (fscore[start], start))
    while oheap:
        current = heappop(oheap)[1]
        open_counts[current] -= 1
        if current == goal:
            # Follow parent links back toward start; start has no parent, so it
            # is not appended.
            data = []
            while current in came_from:
                data.append(current)
                current = came_from[current]
            return data
        close_set.add(current)
        for d_row, d_col in neighbors:
            neighbor = current[0] + d_row, current[1] + d_col
            row, col = neighbor
            # Outside the grid, or a wall cell.
            if not (0 <= row < n_rows and 0 <= col < n_cols):
                continue
            if array[row][col] == 1:
                continue
            tentative_g_score = gscore[current] + heuristic(current, neighbor)
            # Revisit an already-expanded node only if this route is strictly cheaper.
            if neighbor in close_set and tentative_g_score >= gscore.get(neighbor, 0):
                continue
            if tentative_g_score < gscore.get(neighbor, 0) or not open_counts.get(neighbor, 0):
                came_from[neighbor] = current
                gscore[neighbor] = tentative_g_score
                fscore[neighbor] = tentative_g_score + heuristic(neighbor, goal)
                heappush(oheap, (fscore[neighbor], neighbor))
                open_counts[neighbor] = open_counts.get(neighbor, 0) + 1
    return []
def preprocess_image(img, target_size):
    """Optionally crop an image array and convert it to 2-D grayscale.

    Args:
        img: image as a numpy array; expected to be 2-D (grayscale) or 3-D with
            3 (RGB) or 4 (RGBA) channels in the last axis.
        target_size: ``None`` for no crop, otherwise a 4-sequence
            ``(row_start, row_end, col_start, col_end)`` of slice bounds.
            Despite the name, these are crop bounds, not an output size.

    Returns:
        The (cropped) grayscale image. A 2-D input is returned without
        conversion. Returns ``None`` if any step raises, after printing the error.
    """
    try:
        if target_size is not None:
            # Slice only the two spatial axes. This lets 2-D grayscale inputs be
            # cropped too (a trailing ``, :`` raised IndexError on them). Any
            # channel axis is kept whole.
            img = img[target_size[0]:target_size[1], target_size[2]:target_size[3]]
        if img.ndim == 3 and img.shape[2] == 4:
            # Drop the alpha channel so rgb2gray receives RGB data.
            img = img[..., :3]
        if img.ndim > 2:
            img = rgb2gray(img)
        return img
    except Exception as e:
        print(f"Error in preprocessing: {e}")
        return None
def horizontal_projections(sobel_image):
    """Sum *sobel_image* across its columns, giving one total per row.

    In this module it is applied to a boolean (binarized) image, where each
    total is the number of True pixels in that row.
    """
    per_row_totals = np.sum(sobel_image, axis=1)
    return per_row_totals
def binarize_image(image):
    """Return a boolean mask that is True where *image* lies strictly below its Otsu threshold."""
    return image < threshold_otsu(image)
def find_peak_regions(hpp, threshold):
    """Return, in order, the indices of entries of *hpp* strictly below *threshold*.

    In this module *hpp* is a horizontal projection profile, so the result lists
    the rows whose projection is low.
    """
    return [row for row, value in enumerate(hpp) if value < threshold]
def line_segmentation(image, threshold=None, min_peak_group_size=7, target_size=None,
                      ct=0, parent_line_num=None, recursive=False, recursive_count=1,
                      base_key="line"):
    """
    Segment an image into text-line images using a horizontal projection profile
    and minimum-cost separating paths.

    Rows whose projection is strictly below ``threshold`` are treated as gap rows.
    Each run of consecutive gap rows longer than ``min_peak_group_size`` gets a
    left-to-right separating path from ``skimage.graph.route_through_array``. The
    module's ``astar`` helper is not used here. The image is then cut between
    consecutive separating paths. A cut taller than the 90th percentile of cut
    heights is segmented once more; there is only a single level of recursion.

    Args:
        image: Input image (numpy array); ``preprocess_image`` turns it into a
            2-D grayscale array.
        threshold (float, optional): Rows whose projection is strictly below this
            value count as gap rows. Defaults to ``(max - min) / 2`` of the
            projection. The resolved value is reused for the recursive pass.
        min_peak_group_size (int): A run of gap rows is used only if it has MORE
            than this many rows. The recursive pass uses 3.
        target_size (tuple, optional): Forwarded to ``preprocess_image``, which
            uses it as crop bounds ``(row_start, row_end, col_start, col_end)``.
        ct (int): Running line counter. Top-level keys are ``f"{base_key}_{ct + 1}"``.
        parent_line_num (str, optional): Set by the recursive call. When not None,
            keys use ``recursive_count`` instead of ``ct``.
        recursive (bool): True inside the recursive call; prevents deeper recursion.
        recursive_count (int): Key suffix counter used when ``parent_line_num`` is set.
        base_key (str): Prefix for dictionary keys.

    Returns:
        tuple: ``(segmented_images_dict, ct, found)``.
        ``segmented_images_dict`` maps keys to
        ``{"image": <copy of the line image>, "transcription": <the key>}``.
        ``ct`` is the updated counter. ``found`` is False on every early exit.
        On the normal path ``found`` is ``len(seperated_images) > 0``, which is
        always True there because an empty list returns early.
    """
    segmented_images_dict = {}
    img = preprocess_image(image, target_size)
    if img is None:
        print(f"Failed to preprocess image")
        return segmented_images_dict, ct, False
    # Binarize image and get projections
    # binarize_image marks pixels below the Otsu threshold as True, so hpp[row]
    # counts those (darker) pixels in each row.
    binarized_image = binarize_image(img)
    hpp = horizontal_projections(binarized_image)
    if threshold is None:
        # NOTE(review): this is half the range, not the midpoint min + (max - min) / 2.
        # If every row has many dark pixels (large min), the threshold can fall
        # below min and no gap rows are found. Confirm intended.
        threshold = (np.max(hpp) - np.min(hpp)) / 2
    # Find peaks
    # "Peaks" are rows with a LOW dark-pixel count, presumably the gaps between
    # text lines (assumes dark ink on a light background).
    peaks = find_peak_regions(hpp, threshold)
    if not peaks:
        print(f"No peaks found in image")
        return segmented_images_dict, ct, False
    peaks_indexes = np.array(peaks).astype(int)
    # NOTE(review): segmented_img is never read after this loop. Apart from the
    # `r, c` unpack (which requires a 2-D img), this loop has no effect on the
    # result and looks like dead code.
    segmented_img = np.copy(img)
    r, c = segmented_img.shape
    for ri in range(r):
        if ri in peaks_indexes:
            segmented_img[ri, :] = 0
    # Group peaks
    # Split the ascending gap-row indexes into runs of consecutive rows, and keep
    # only runs longer than min_peak_group_size.
    diff_between_consec_numbers = np.diff(peaks_indexes)
    indexes_with_larger_diff = np.where(diff_between_consec_numbers > 1)[0].flatten()
    peak_groups = np.split(peaks_indexes, indexes_with_larger_diff + 1)
    peak_groups = [item for item in peak_groups if len(item) > min_peak_group_size]
    if not peak_groups:
        print(f"No valid peak groups found in image")
        return segmented_images_dict, ct, False
    # get_binary marks pixels <= its Otsu threshold as 1. If img's mean is exactly
    # 0.0 or 1.0, img is used as-is. These values are the path costs below.
    binary_image = get_binary(img)
    segment_separating_lines = []
    for sub_image_index in peak_groups:
        try:
            start_row = sub_image_index[0]
            end_row = sub_image_index[-1]
            start_row = max(0, start_row)
            end_row = min(binary_image.shape[0], end_row)
            if end_row <= start_row:
                continue
            # The slice excludes end_row itself (the last row of the run).
            nmap = binary_image[start_row:end_row, :]
            if nmap.size == 0:
                continue
            # Minimum-cost path through nmap (cost = pixel value), from the middle
            # row at the left edge to the middle row at the right edge. It avoids
            # 1-valued pixels where it can.
            start_point = (int(nmap.shape[0] / 2), 0)
            end_point = (int(nmap.shape[0] / 2), nmap.shape[1] - 1)
            path, _ = route_through_array(nmap, start_point, end_point)
            # Shift rows from band coordinates back to image coordinates.
            # NOTE(review): the offset is also added to the column coordinates.
            # Only column 0 (rows) is read later, so this is currently harmless.
            path = np.array(path) + start_row
            segment_separating_lines.append(path)
        except Exception as e:
            print(f"Failed to process sub-image: {e}")
            continue
    if not segment_separating_lines:
        print(f"No valid segment separating lines found in image")
        return segmented_images_dict, ct, False
    # Separate images based on line segments
    # Each line image spans from the topmost row of one separating path to the
    # bottommost row of the next, so it can include rows belonging to neighbouring
    # lines. Content above the first path and below the last path is not emitted.
    seperated_images = []
    for index in range(len(segment_separating_lines) - 1):
        try:
            lower_line = np.min(segment_separating_lines[index][:, 0])
            upper_line = np.max(segment_separating_lines[index + 1][:, 0])
            if lower_line < upper_line and upper_line <= img.shape[0]:
                line_image = img[lower_line:upper_line]
                if line_image.size > 0:
                    seperated_images.append(line_image)
        except Exception as e:
            print(f"Failed to separate image at index {index}: {e}")
            continue
    if not seperated_images:
        print(f"No valid separated images found in image")
        return segmented_images_dict, ct, False
    # Calculate height threshold
    # Line images taller than the 90th percentile of heights are re-segmented below.
    try:
        image_heights = [line_image.shape[0] for line_image in seperated_images if line_image.size > 0]
        if not image_heights:
            print(f"No valid image heights found")
            return segmented_images_dict, ct, False
        height_threshold = np.percentile(image_heights, 90)
    except Exception as e:
        print(f"Failed to calculate height threshold: {e}")
        return segmented_images_dict, ct, False
    # Process each separated image
    for index, line_image in enumerate(seperated_images):
        try:
            if line_image.size == 0 or line_image.shape[0] == 0 or line_image.shape[1] == 0:
                continue
            if parent_line_num is None:
                dict_key = f'{base_key}_{ct + 1}'
            else:
                dict_key = f'{base_key}_{recursive_count}'
                # NOTE(review): indentation was lost in the copy this block was
                # reconstructed from, and this check is placed in the recursive
                # branch. As placed, a recursive pass keeps only its LAST line image,
                # keyed f"{base_key}_1", because recursive_count is not advanced for
                # skipped images. If the check belonged one level out, every call
                # would keep only its last line. Confirm which is intended.
                if index < len(seperated_images) - 1:
                    continue
            segmented_images_dict[dict_key] = {
                "image": line_image.copy(),
                "transcription": f"{dict_key}"
            }
            # print(f"Added line image to dictionary with key: {dict_key}")
            # Handle recursive segmentation
            if line_image.shape[0] > height_threshold and not recursive:
                try:
                    # Create recursive base key
                    recursive_base_key = f"{base_key}_{ct + 1}"
                    # Do recursive segmentation
                    # NOTE(review): the counter returned by the recursive call
                    # replaces ct here, and ct is incremented again below. When
                    # sub-lines were stored, later top-level keys can therefore skip
                    # numbers. Confirm intended.
                    recursive_dict, ct, found_valid_separations = line_segmentation(
                        line_image, threshold=threshold,
                        min_peak_group_size=3,
                        parent_line_num=str(ct + 1),
                        recursive=True,
                        ct=ct,
                        recursive_count=1,
                        base_key=recursive_base_key
                    )
                    if found_valid_separations:
                        # Swap the tall line for the sub-lines found inside it.
                        del segmented_images_dict[dict_key]
                        segmented_images_dict.update(recursive_dict)
                        print(f"Replaced {dict_key} with recursive segmentation results")
                    else:
                        print(f"Keeping original image {dict_key} as no valid separations were found")
                except Exception as e:
                    print(f"Failed during recursive segmentation of {dict_key}: {e}")
            ct += 1
            if recursive:
                recursive_count += 1
        except Exception as e:
            print(f"Failed to process line image at index {index}: {e}")
            continue
    logger.info(f"Total lines segment found: {len(segmented_images_dict)}")
    return segmented_images_dict, ct, len(seperated_images) > 0
def segment_image_to_lines(image_array, **kwargs):
    """
    Run ``line_segmentation`` on an image and return only its dictionary.

    Args:
        image_array: Input image as numpy array.
        **kwargs: Forwarded unchanged to ``line_segmentation``.
    Returns:
        dict: The dictionary produced by ``line_segmentation``, mapping line keys
        (e.g. ``"line_1"`` with the default ``base_key``) to
        ``{"image": ..., "transcription": ...}`` entries. Returns ``{}`` if an
        exception escapes, after logging it.
    """
    try:
        logger.info("Starting line segmentation...")
        lines_by_key, _counter, succeeded = line_segmentation(image_array, **kwargs)
        if not succeeded:
            return lines_by_key
        logger.info(f"Line segmentation successful.....")
        return lines_by_key
    except Exception as e:
        logger.error(f"Line segmentation failed: {e}")
        return {}
# if __name__ == "__main__":
# # Example usage
# image_path = "./renAI-deploy/1.png"
# image = imread(image_path)
# segmented_lines = segment_image_to_lines(image, threshold=None, min_peak_group_size=10)
# print(len(segmented_lines.values()))
# for key, value in segmented_lines.items():
# print(f"{key}: {value['image'].shape}")
# print(f"{key}: {value['transcription']}")
# # plt.imshow(img, cmap='gray')
# # plt.title(key)
# # plt.show()