# fincatch-ocr / src/utils/decompose.py
# (originally uploaded by gnlui — commit "initial", 0bad002)
from sklearn.cluster import DBSCAN
import numpy as np
from itertools import islice
from collections import Counter
import logging

# Set up logging
# NOTE(review): logging.basicConfig at import time configures the root logger
# for the whole process — fine for a script, surprising if imported as a library.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DBSCAN_helper:
    """Cluster text blocks with DBSCAN on (x0, y0, x1, y1, len(text)) features.

    After ``run()`` the instance exposes:
      - ``n_clusters``: number of clusters found (noise excluded)
      - ``labels``: per-block cluster labels as returned by DBSCAN (-1 = noise)
    """

    def __init__(self, blocks, eps=0.5, min_samples=2, metric='euclidean'):
        """Store the blocks and DBSCAN hyper-parameters.

        ``blocks`` is an iterable of ``(x0, y0, x1, y1, text)`` tuples.
        The hyper-parameter defaults match the previously hard-coded values,
        so existing callers are unaffected.
        """
        self.blocks = blocks
        self.eps = eps
        self.min_samples = min_samples
        self.metric = metric

    def run(self):
        """Fit DBSCAN over the blocks, setting ``n_clusters`` and ``labels``.

        Never raises: on any failure it logs and falls back to zero clusters
        (empty ``labels``, or all-noise labels when blocks exist).
        """
        try:
            if not self.blocks:
                logger.warning("No blocks provided to DBSCAN_helper")
                # Return default values
                self.n_clusters = 0
                self.labels = np.array([])
                return
            # One feature row per block: geometry plus text length.
            # NOTE(review): features are unscaled, so with eps=0.5 the euclidean
            # distance is dominated by raw page coordinates; consider
            # normalising the feature matrix if clustering looks degenerate.
            X = np.array(
                [(x0, y0, x1, y1, len(text)) for x0, y0, x1, y1, text in self.blocks]
            )
            if X.size == 0:
                logger.warning("Empty feature array for DBSCAN")
                self.n_clusters = 0
                self.labels = np.array([])
                return
            dbscan = DBSCAN(eps=self.eps, min_samples=self.min_samples, metric=self.metric)
            dbscan.fit(X)
            labels = dbscan.labels_
            # Count the number of clusters; noise points (label -1) are not a cluster.
            unique_labels = set(labels)
            unique_labels.discard(-1)
            self.n_clusters = len(unique_labels)
            self.labels = labels
            logger.info(f"{self.n_clusters} clusters for {len(self.blocks)} blocks")
        except Exception as e:
            logger.error(f"Error in DBSCAN_helper: {str(e)}")
            # Set default values on error
            self.n_clusters = 0
            self.labels = np.array([-1] * len(self.blocks)) if self.blocks else np.array([])
class Decomposer:
    """Derives page-layout statistics for a PDF by clustering its text blocks.

    ``pdf_document`` is assumed to be an iterable of pages whose
    ``get_text("blocks")`` returns PyMuPDF-style tuples
    ``(x0, y0, x1, y1, text, block_no, block_type)`` — TODO confirm against
    the caller.
    """

    def __init__(self, pdf_document=None):
        if not pdf_document:
            raise ValueError("PDF document must be provided")
        self.pdf_doc = pdf_document

    def calc_rect_center(self, rect, reverse_y=False):
        """Return the (x, y) center of ``rect`` = (x0, y0, x1, y1).

        With ``reverse_y`` the y coordinates are negated first, flipping the
        vertical axis. Returns (0, 0) if the rectangle is malformed.
        """
        try:
            if reverse_y:
                x0, y0, x1, y1 = rect[0], -rect[1], rect[2], -rect[3]
            else:
                x0, y0, x1, y1 = rect
            return ((x0 + x1) / 2, (y0 + y1) / 2)
        except Exception as e:
            logger.error(f"Error calculating rectangle center: {str(e)}")
            return (0, 0)  # Return default values on error

    def get_rect_labels(self):
        """Cluster every text block in the document.

        Returns a list of ``(rect, label)`` pairs where ``rect`` is the
        block's (x0, y0, x1, y1) and ``label`` is its DBSCAN cluster id
        (-1 for noise). Returns [] when nothing could be processed.
        """
        try:
            rect_centers = []
            rects = []
            visual_label_texts = []
            categorize_vectors = []
            # Plain enumerate suffices; islice(enumerate(doc), len(doc))
            # was a no-op wrapper.
            for page_idx, page in enumerate(self.pdf_doc):
                try:
                    blocks = page.get_text("blocks")
                    page_cnt = page_idx + 1
                    logger.debug(f"=== Start Page {page_cnt}: {len(blocks)} blocks ===")
                    block_cnt = 0
                    for block in blocks:
                        try:
                            block_rect = block[:4]  # (x0,y0,x1,y1)
                            rects.append(block_rect)
                            # Handle possible encoding issues with block text
                            block_text = block[4]
                            if isinstance(block_text, bytes):
                                block_text = block_text.decode('utf-8', errors='ignore')
                            block_cnt = block[5] + 1  # 1-based block number
                            rect_centers.append(
                                self.calc_rect_center(block_rect, reverse_y=True)
                            )
                            visual_label_texts.append(f"({page_cnt}.{block_cnt})")
                            categorize_vectors.append((*block_rect, block_text))
                        except Exception as block_error:
                            logger.warning(f"Error processing block {block_cnt} on page {page_cnt}: {str(block_error)}")
                            continue
                except Exception as page_error:
                    logger.warning(f"Error processing page {page_idx + 1}: {str(page_error)}")
                    continue
            if not categorize_vectors:
                logger.warning("No categorize vectors generated")
                return []
            categorizer = DBSCAN_helper(categorize_vectors)
            categorizer.run()
            # Make sure the lengths match
            if len(rects) != len(categorizer.labels):
                logger.warning(f"Length mismatch: rects={len(rects)}, labels={len(categorizer.labels)}")
                if categorizer.labels.size == 0:
                    # No labels at all: treat every rect as noise (-1).
                    return [(rect, -1) for rect in rects]
            # zip truncates to the shorter sequence, so this single expression
            # covers both the matched and mismatched-but-nonempty cases.
            return list(zip(rects, categorizer.labels))
        except Exception as e:
            logger.error(f"Error in get_rect_labels: {str(e)}")
            return []  # Return empty result on error

    def get_page_stats(self, res):
        """Compute vertical layout statistics from ``(rect, label)`` pairs.

        Returns ``(min_y, max_y, single_y)``: the vertical extent of
        noise-labelled blocks in the left-most column, and the most common
        block height. min_y/max_y may be None and single_y 0 when they
        cannot be determined.
        """
        try:
            if not res:
                logger.warning("Empty input to get_page_stats")
                return None, None, None  # Handle empty input
            label_counter = Counter(label for _, label in res)
            height_counter = Counter(rect[3] - rect[1] for rect, _ in res)
            if not label_counter or not height_counter:
                logger.warning("Empty counters in get_page_stats")
                return None, None, None
            dominant_label = label_counter.most_common(1)[0][0]
            # threshold = smallest left edge among blocks whose label differs
            # from the dominant one; falls back to the global minimum x0.
            threshold = float('inf')
            min_x = float('inf')
            for rect, label in res:
                min_x = min(min_x, rect[0])
                if label != dominant_label and rect[0] < threshold:
                    threshold = rect[0]
            if threshold == float('inf'):  # Fallback
                threshold = min_x
            # Vertical extent of noise blocks (-1) at or left of the threshold.
            min_y, max_y = float('inf'), -float('inf')
            for rect, label in res:
                if label == -1 and rect[0] <= threshold:
                    min_y = min(min_y, rect[1])
                    max_y = max(max_y, rect[-1])
            single_y = height_counter.most_common(1)[0][0] if height_counter else 0
            # Normalise unresolved extrema to None for the caller.
            if min_y == float('inf'):
                min_y = None
            if max_y == -float('inf'):
                max_y = None
            # Ensure single_y is positive
            single_y = abs(single_y) if single_y else 0
            return min_y, max_y, single_y
        except Exception as e:
            logger.error(f"Error in get_page_stats: {str(e)}")
            return None, None, None  # Return default values on error

    def run(self):
        """Run the full pipeline; returns (min_y, max_y, single_y)."""
        try:
            rect_labels = self.get_rect_labels()
            stats = self.get_page_stats(rect_labels)
            logger.info(f"Page stats: min_y={stats[0]}, max_y={stats[1]}, single_y={stats[2]}")
            return stats
        except Exception as e:
            logger.error(f"Error in Decomposer.run: {str(e)}")
            return None, None, None  # Return default values on error