Spaces:
Sleeping
Sleeping
| from sklearn.cluster import DBSCAN | |
| import numpy as np | |
| from itertools import islice | |
| from collections import Counter | |
| import logging | |
# Module-wide logging: give the root logger a handler at INFO level
# (basicConfig does nothing if the root logger already has handlers) and
# create the logger shared by the classes defined in this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class DBSCAN_helper:
    """Cluster text blocks with scikit-learn's DBSCAN.

    Each input block is a 5-tuple ``(x0, y0, x1, y1, text)`` and is turned
    into the feature vector ``(x0, y0, x1, y1, len(text))``.

    Attributes (always present, updated by :meth:`run`):
        labels: numpy array with one DBSCAN label per block; ``-1`` marks noise.
        n_clusters: number of distinct non-noise labels.
    """

    def __init__(self, blocks, eps=0.5, min_samples=2):
        """Store the blocks and the DBSCAN parameters.

        Args:
            blocks: Sequence of ``(x0, y0, x1, y1, text)`` tuples.
            eps: DBSCAN neighbourhood radius, in the same units as the
                features (default 0.5, the value previously hard-coded).
            min_samples: DBSCAN core-point threshold (default 2, the value
                previously hard-coded).
        """
        self.blocks = blocks
        self.eps = eps
        self.min_samples = min_samples
        # Initialised here so the attributes exist even if run() is never called.
        self.n_clusters = 0
        self.labels = np.array([], dtype=int)

    def run(self):
        """Fit DBSCAN on the blocks and populate ``labels`` and ``n_clusters``.

        Never raises. On empty input, ``labels`` is empty and ``n_clusters``
        is 0. On any error, every block is labelled ``-1`` and ``n_clusters``
        is 0.
        """
        try:
            if not self.blocks:
                logger.warning("No blocks provided to DBSCAN_helper")
                self.n_clusters = 0
                self.labels = np.array([], dtype=int)
                return

            # One row per block; a non-empty block list always yields 5 columns.
            X = np.array(
                [(x0, y0, x1, y1, len(text)) for x0, y0, x1, y1, text in self.blocks]
            )

            dbscan = DBSCAN(eps=self.eps, min_samples=self.min_samples, metric='euclidean')
            dbscan.fit(X)
            labels = dbscan.labels_

            # Noise points carry the label -1 and do not count as a cluster.
            self.n_clusters = len(set(labels) - {-1})
            self.labels = labels
            logger.info(f"{self.n_clusters} clusters for {len(self.blocks)} blocks")
        except Exception as e:
            # Keep the traceback; the message text matches the previous log line.
            logger.exception("Error in DBSCAN_helper: %s", e)
            self.n_clusters = 0
            self.labels = (
                np.full(len(self.blocks), -1, dtype=int)
                if self.blocks
                else np.array([], dtype=int)
            )
class Decomposer:
    """Cluster the text blocks of a PDF and derive simple layout statistics.

    ``pdf_document`` must be iterable over pages. Each page must provide
    ``get_text("blocks")``, returning sequences whose first five items are
    ``(x0, y0, x1, y1, text)``. NOTE(review): this matches PyMuPDF's
    "blocks" output; confirm that callers always pass a PyMuPDF Document.
    """

    def __init__(self, pdf_document=None):
        """Store the document to analyse.

        Raises:
            ValueError: if ``pdf_document`` is ``None``. An explicit ``None``
                check is used because a document object with zero pages can be
                falsy through ``len()`` while still being a valid document.
        """
        if pdf_document is None:
            raise ValueError("PDF document must be provided")
        self.pdf_doc = pdf_document

    def calc_rect_center(self, rect, reverse_y=False):
        """Return the ``(x, y)`` centre of ``rect = (x0, y0, x1, y1)``.

        With ``reverse_y=True`` both y coordinates are negated first, so the
        returned y is the negated centre. Returns ``(0, 0)`` (and logs) if
        ``rect`` cannot be unpacked or combined arithmetically.
        """
        try:
            if reverse_y:
                x0, y0, x1, y1 = rect[0], -rect[1], rect[2], -rect[3]
            else:
                x0, y0, x1, y1 = rect
            return ((x0 + x1) / 2, (y0 + y1) / 2)
        except Exception as e:
            logger.exception("Error calculating rectangle center: %s", e)
            return (0, 0)

    def _collect_blocks(self):
        """Parse all pages into two index-aligned lists: ``rects`` and clustering vectors.

        A block is appended to both lists only after all of its fields have
        been read, so ``rects[i]`` always belongs to ``vectors[i]``. Malformed
        blocks and failing pages are logged and skipped.
        """
        rects = []
        vectors = []
        for page_idx, page in enumerate(self.pdf_doc):
            page_cnt = page_idx + 1
            try:
                blocks = page.get_text("blocks")
                logger.debug(f"=== Start Page {page_cnt}: {len(blocks)} blocks ===")
                for block_cnt, block in enumerate(blocks, start=1):
                    try:
                        block_rect = block[:4]  # (x0, y0, x1, y1)
                        # Unpacking rejects blocks with fewer than four coordinates.
                        x0, y0, x1, y1 = block_rect
                        block_text = block[4]
                        # Guard against text delivered as raw bytes.
                        if isinstance(block_text, bytes):
                            block_text = block_text.decode('utf-8', errors='ignore')
                    except Exception as block_error:
                        logger.warning(f"Error processing block {block_cnt} on page {page_cnt}: {str(block_error)}")
                        continue
                    rects.append(block_rect)
                    vectors.append((x0, y0, x1, y1, block_text))
            except Exception as page_error:
                logger.warning(f"Error processing page {page_cnt}: {str(page_error)}")
        return rects, vectors

    def get_rect_labels(self):
        """Cluster every block in the document and pair each rect with its label.

        Returns:
            A list of ``(rect, label)`` tuples in document order. ``rect`` is
            ``(x0, y0, x1, y1)`` and ``label`` is the DBSCAN label (``-1`` means
            noise). Returns ``[]`` if no usable block is found or on error.
        """
        try:
            rects, categorize_vectors = self._collect_blocks()
            if not categorize_vectors:
                logger.warning("No categorize vectors generated")
                return []

            categorizer = DBSCAN_helper(categorize_vectors)
            categorizer.run()
            labels = categorizer.labels

            # Defensive: DBSCAN_helper should return one label per vector.
            if len(rects) != len(labels):
                logger.warning(f"Length mismatch: rects={len(rects)}, labels={len(labels)}")
                if labels.size == 0:
                    # No labels at all: treat every block as noise.
                    return [(rect, -1) for rect in rects]
            # zip stops at the shorter sequence, i.e. truncates on mismatch.
            return list(zip(rects, labels))
        except Exception as e:
            logger.exception("Error in get_rect_labels: %s", e)
            return []

    def get_page_stats(self, res):
        """Compute ``(min_y, max_y, single_y)`` from ``(rect, label)`` pairs.

        - ``single_y``: the most frequent block height ``y1 - y0``, made
          non-negative.
        - ``threshold``: the smallest ``x0`` among blocks whose label differs
          from the most common label. It falls back to the smallest ``x0``
          overall when every block shares that label.
        - ``min_y`` / ``max_y``: the smallest ``y0`` and largest ``y1`` among
          noise blocks (label ``-1``) with ``x0 <= threshold``. Each is
          ``None`` when no block qualifies.

        Returns ``(None, None, None)`` for empty input or on error.
        """
        try:
            if not res:
                logger.warning("Empty input to get_page_stats")
                return None, None, None

            label_counts = Counter(label for _, label in res)
            height_counts = Counter(rect[3] - rect[1] for rect, _ in res)
            dominant_label = label_counts.most_common(1)[0][0]

            min_x = min(rect[0] for rect, _ in res)
            threshold = min(
                (rect[0] for rect, label in res if label != dominant_label),
                default=min_x,
            )

            selected = [rect for rect, label in res if label == -1 and rect[0] <= threshold]
            min_y = min((rect[1] for rect in selected), default=None)
            max_y = max((rect[3] for rect in selected), default=None)

            single_y = height_counts.most_common(1)[0][0]
            single_y = abs(single_y) if single_y else 0
            return min_y, max_y, single_y
        except Exception as e:
            logger.exception("Error in get_page_stats: %s", e)
            return None, None, None

    def run(self):
        """Cluster the document's blocks and return ``(min_y, max_y, single_y)``.

        Returns ``(None, None, None)`` on error.
        """
        try:
            rect_labels = self.get_rect_labels()
            min_y, max_y, single_y = self.get_page_stats(rect_labels)
            logger.info(f"Page stats: min_y={min_y}, max_y={max_y}, single_y={single_y}")
            return min_y, max_y, single_y
        except Exception as e:
            logger.exception("Error in Decomposer.run: %s", e)
            return None, None, None