"""Data aggregation: merge symbol, text and line detections into one graph JSON
and render a combined annotated overlay image."""

from pathlib import Path
import json
import logging
from datetime import datetime
from typing import List, Dict, Optional, Tuple
from storage import StorageFactory
import uuid
import traceback
import os
import cv2
import numpy as np

logger = logging.getLogger(__name__)

# Upper bound (pixels) accepted for any coordinate. Adjust range as needed.
MAX_COORD = 10000


class DataAggregator:
    """Aggregates per-detector JSON outputs into a single graph structure.

    Inputs are three JSON files (symbols, texts, lines); output is one
    aggregated JSON (nodes + edges) plus an annotated PNG visualization.
    """

    def __init__(self, storage=None):
        # storage: optional backend; defaults to the factory-provided one.
        self.storage = storage or StorageFactory.get_storage()
        self.logger = logging.getLogger(__name__)

    def _parse_line_data(self, lines_data: dict) -> List[dict]:
        """Parse line detection data with coordinate validation.

        Malformed entries are logged and skipped so a single bad line
        cannot abort the whole aggregation.
        """
        parsed_lines = []
        for line in lines_data.get("lines", []):
            try:
                # Extract and validate line coordinates
                start_coords = line["start"]["coords"]
                end_coords = line["end"]["coords"]
                bbox = line["bbox"]

                if not (self._is_valid_point(start_coords) and
                        self._is_valid_point(end_coords) and
                        self._is_valid_bbox(bbox)):
                    self.logger.warning(f"Invalid coordinates in line: {line['id']}")
                    continue

                # Create parsed line with validated, integer coordinates
                parsed_lines.append({
                    "id": line["id"],
                    "start_point": {
                        "x": int(start_coords["x"]),
                        "y": int(start_coords["y"]),
                        "type": line["start"]["type"],
                        "confidence": line["start"]["confidence"],
                    },
                    "end_point": {
                        "x": int(end_coords["x"]),
                        "y": int(end_coords["y"]),
                        "type": line["end"]["type"],
                        "confidence": line["end"]["confidence"],
                    },
                    "bbox": {
                        "xmin": int(bbox["xmin"]),
                        "ymin": int(bbox["ymin"]),
                        "xmax": int(bbox["xmax"]),
                        "ymax": int(bbox["ymax"]),
                    },
                    "style": line["style"],
                    "confidence": line["confidence"],
                })
            except Exception as e:
                self.logger.error(f"Error parsing line {line.get('id')}: {str(e)}")
                continue
        return parsed_lines

    def _is_valid_point(self, point: dict) -> bool:
        """Return True if *point* has numeric x/y within [0, MAX_COORD]."""
        try:
            x, y = point.get("x"), point.get("y")
            return (isinstance(x, (int, float)) and
                    isinstance(y, (int, float)) and
                    0 <= x <= MAX_COORD and 0 <= y <= MAX_COORD)
        except (AttributeError, TypeError):
            # Fix: was a bare except; only a non-mapping *point* can raise here.
            return False

    def _is_valid_bbox(self, bbox: dict) -> bool:
        """Return True if *bbox* has numeric, ordered xmin/ymin/xmax/ymax in range."""
        try:
            corners = (bbox.get("xmin"), bbox.get("ymin"),
                       bbox.get("xmax"), bbox.get("ymax"))
            xmin, ymin, xmax, ymax = corners
            return (all(isinstance(v, (int, float)) for v in corners) and
                    xmin < xmax and ymin < ymax and
                    all(0 <= v <= MAX_COORD for v in corners))
        except (AttributeError, TypeError):
            # Fix: was a bare except; only a non-mapping *bbox* can raise here.
            return False

    def _create_graph_data(self, lines: List[dict], symbols: List[dict],
                           texts: List[dict]) -> Tuple[List[dict], List[dict]]:
        """Create graph nodes and edges from detections.

        Symbols and texts become nodes (centered on their bbox); lines
        become edges. Symbol/text bboxes arrive as lists [x1, y1, x2, y2]
        and are converted to dict form for output consistency.
        """
        nodes = []
        edges = []

        # Debug input data
        self.logger.info("Creating graph data with:")
        self.logger.info(f"Lines: {len(lines)}")
        self.logger.info(f"Symbols: {len(symbols)}")
        self.logger.info(f"Texts: {len(texts)}")

        try:
            # Process symbols
            for symbol in symbols:
                bbox = symbol["bbox"]  # bbox is a list [x1,y1,x2,y2]
                nodes.append({
                    "id": symbol["symbol_id"],
                    "type": "symbol",
                    "category": symbol.get("category", ""),
                    "label": symbol.get("label", ""),
                    "confidence": symbol.get("confidence", 0.0),
                    "x": (bbox[0] + bbox[2]) / 2,
                    "y": (bbox[1] + bbox[3]) / 2,
                    "bbox": {
                        "xmin": bbox[0],
                        "ymin": bbox[1],
                        "xmax": bbox[2],
                        "ymax": bbox[3],
                    },
                })

            # Process texts
            for text in texts:
                bbox = text["bbox"]  # bbox is a list [x1,y1,x2,y2]
                nodes.append({
                    "id": str(uuid.uuid4()),
                    "type": "text",
                    "content": text.get("text", ""),
                    "confidence": text.get("confidence", 0.0),
                    "x": (bbox[0] + bbox[2]) / 2,
                    "y": (bbox[1] + bbox[3]) / 2,
                    "bbox": {
                        "xmin": bbox[0],
                        "ymin": bbox[1],
                        "xmax": bbox[2],
                        "ymax": bbox[3],
                    },
                })

            # Process lines (unchanged)
            for line in lines:
                edges.append({
                    "id": str(uuid.uuid4()),
                    "type": "line",
                    "start_point": line["start_point"],
                    "end_point": line["end_point"],
                })
        except Exception as e:
            self.logger.error(f"Error processing data: {str(e)}")
            # Fix: original evaluated `symbol if 'symbol' in locals() else text`,
            # which raised NameError when neither was bound yet.
            current = locals().get('symbol', locals().get('text'))
            if current is not None:
                self.logger.error("Current symbol/text being processed: %s",
                                  json.dumps(current, indent=2))
            raise

        return nodes, edges

    def _validate_coordinates(self, data, data_type):
        """Validate coordinates in the data.

        data_type is 'line' (checks start/end points + bbox) or
        'symbol'/'text' (checks bbox only). Returns False on any problem,
        logging a warning describing what failed.
        """
        if not data:
            return False

        try:
            if data_type == 'line':
                # Check start and end points
                start = data.get('start_point', {})
                end = data.get('end_point', {})
                bbox = data.get('bbox', {})

                required_fields = ['x', 'y', 'type']
                if not all(field in start for field in required_fields):
                    self.logger.warning(f"Missing required fields in start_point: {start}")
                    return False
                if not all(field in end for field in required_fields):
                    self.logger.warning(f"Missing required fields in end_point: {end}")
                    return False

                # Validate bbox coordinates
                if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
                    self.logger.warning(f"Invalid bbox format: {bbox}")
                    return False

                # Check coordinate consistency
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid bbox coordinates: {bbox}")
                    return False

            elif data_type in ['symbol', 'text']:
                bbox = data.get('bbox', {})
                if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
                    self.logger.warning(f"Invalid {data_type} bbox format: {bbox}")
                    return False

                # Check coordinate consistency
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid {data_type} bbox coordinates: {bbox}")
                    return False

            return True
        except Exception as e:
            self.logger.error(f"Validation error for {data_type}: {str(e)}")
            return False

    def aggregate_data(self, symbols_path: str, texts_path: str,
                       lines_path: str) -> dict:
        """Aggregate detection results and create graph structure.

        Lines are loaded through the storage backend; symbol/text files
        are read directly from the local filesystem and are optional.
        Raises on unrecoverable errors after logging a stack trace.
        """
        try:
            # Load line detection results
            lines_data = json.loads(self.storage.load_file(lines_path).decode('utf-8'))
            lines = self._parse_line_data(lines_data)
            self.logger.info(f"Loaded {len(lines)} lines")

            # Load and debug symbol detections
            symbols = []
            if symbols_path and Path(symbols_path).exists():
                with open(symbols_path, 'r') as f:
                    symbols_data = json.load(f)

                self.logger.info("Symbol data keys: %s", list(symbols_data.keys()))
                symbols = symbols_data.get("detections", [])
                # Fix: only dump the first detection when one exists (original
                # indexed ["detections"][0] unconditionally and could raise).
                if symbols:
                    self.logger.info("First symbol in detections: %s",
                                     json.dumps(symbols[0], indent=2))
                self.logger.info(f"Loaded {len(symbols)} symbols from {symbols_path}")

                # Debug first symbol structure
                if symbols:
                    self.logger.info("First symbol keys: %s", list(symbols[0].keys()))
                    self.logger.info("First symbol bbox: %s", symbols[0]["bbox"])

            # Load and debug text detections
            texts = []
            if texts_path and Path(texts_path).exists():
                with open(texts_path, 'r') as f:
                    texts_data = json.load(f)

                self.logger.info("Text data keys: %s", list(texts_data.keys()))
                texts = texts_data.get("detections", [])
                # Fix: same empty-detections guard as for symbols above.
                if texts:
                    self.logger.info("First text in detections: %s",
                                     json.dumps(texts[0], indent=2))
                self.logger.info(f"Loaded {len(texts)} texts from {texts_path}")

                # Debug first text structure
                if texts:
                    self.logger.info("First text keys: %s", list(texts[0].keys()))
                    self.logger.info("First text bbox: %s", texts[0]["bbox"])

            # Create graph data
            nodes, edges = self._create_graph_data(lines, symbols, texts)
            self.logger.info(f"Created graph with {len(nodes)} nodes and {len(edges)} edges")

            return {
                "lines": lines,
                "symbols": symbols,
                "texts": texts,
                "nodes": nodes,
                "edges": edges,
                "metadata": {
                    "timestamp": datetime.now().isoformat(),
                    "version": "2.0",
                },
            }
        except Exception as e:
            self.logger.error(f"Error during aggregation: {str(e)}")
            self.logger.error("Stack trace:", exc_info=True)  # Add full stack trace
            raise

    def _draw_aggregated_view(self, image: np.ndarray, results: dict) -> np.ndarray:
        """Draw all detections on a copy of *image* and return it.

        Lines are green, symbols cyan, texts purple. Individual bad
        records are skipped with a warning.
        """
        annotated = image.copy()

        # Draw lines (green)
        for line in results.get('lines', []):
            try:
                cv2.line(annotated,
                         (int(line['start_point']['x']), int(line['start_point']['y'])),
                         (int(line['end_point']['x']), int(line['end_point']['y'])),
                         (0, 255, 0), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid line: {str(e)}")
                continue

        # Draw symbols (cyan); bbox is a list [x1,y1,x2,y2], not a dict.
        # Fix: cast to int — cv2 rejects float coordinates.
        for symbol in results.get('symbols', []):
            try:
                bbox = symbol['bbox']
                cv2.rectangle(annotated,
                              (int(bbox[0]), int(bbox[1])),
                              (int(bbox[2]), int(bbox[3])),
                              (255, 255, 0), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid symbol: {str(e)}")
                continue

        # Draw texts (purple); bbox is a list [x1,y1,x2,y2], not a dict.
        for text in results.get('texts', []):
            try:
                bbox = text['bbox']
                cv2.rectangle(annotated,
                              (int(bbox[0]), int(bbox[1])),
                              (int(bbox[2]), int(bbox[3])),
                              (128, 0, 128), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid text: {str(e)}")
                continue

        return annotated

    def process_data(self, image_path: str, output_dir: str, symbols_path: str,
                     texts_path: str, lines_path: str):
        """Run aggregation end-to-end and write JSON + annotated PNG outputs.

        Returns {'success': True, 'image_path': ..., 'json_path': ...} on
        success, or {'success': False, 'error': ...} on any failure.
        """
        try:
            self.logger.info(f"Processing data with:")
            self.logger.info(f"- Image: {image_path}")
            self.logger.info(f"- Symbols: {symbols_path}")
            self.logger.info(f"- Texts: {texts_path}")
            self.logger.info(f"- Lines: {lines_path}")

            base_name = Path(image_path).stem
            self.logger.info(f"Base name: {base_name}")

            aggregated_json = os.path.join(output_dir, f"{base_name}_aggregated.json")
            self.logger.info(f"Will save aggregated data to: {aggregated_json}")

            results = self.aggregate_data(symbols_path, texts_path, lines_path)
            self.logger.info("Data aggregation completed")

            with open(aggregated_json, 'w') as f:
                json.dump(results, f, indent=2)
            self.logger.info(f"Saved aggregated JSON to: {aggregated_json}")

            # Create visualization using original image
            image = cv2.imread(image_path)
            # Fix: cv2.imread returns None for missing/unreadable files;
            # fail with a clear message instead of an AttributeError.
            if image is None:
                raise FileNotFoundError(f"Could not read image: {image_path}")
            annotated = self._draw_aggregated_view(image, results)

            aggregated_image = os.path.join(output_dir, f"{base_name}_aggregated.png")
            cv2.imwrite(aggregated_image, annotated)

            # Return paths like other detectors
            return {
                'success': True,
                'image_path': aggregated_image,
                'json_path': aggregated_json,
            }
        except Exception as e:
            self.logger.error(f"Error in data aggregation: {str(e)}")
            return {
                'success': False,
                'error': str(e),
            }


if __name__ == "__main__":
    # Smoke test against pre-generated detector outputs in ./results.
    # (Fix: dropped duplicate `import os` and unused `from pprint import pprint`.)

    # Initialize the aggregator
    aggregator = DataAggregator()

    # Test paths using actual files in results folder
    results_dir = "results"
    base_name = "002_page_1"

    # Input paths
    symbols_path = os.path.join(results_dir, f"{base_name}_detected_symbols.json")
    texts_path = os.path.join(results_dir, f"{base_name}_detected_texts.json")
    lines_path = os.path.join(results_dir, f"{base_name}_detected_lines.json")

    # Verify files exist
    print(f"\nChecking input files:")
    print(f"Symbols file exists: {os.path.exists(symbols_path)}")
    print(f"Texts file exists: {os.path.exists(texts_path)}")
    print(f"Lines file exists: {os.path.exists(lines_path)}")

    try:
        # Process the data
        print("\nProcessing data...")
        result = aggregator.process_data(
            image_path=os.path.join(results_dir, f"{base_name}.png"),
            output_dir=results_dir,
            symbols_path=symbols_path,
            texts_path=texts_path,
            lines_path=lines_path
        )

        # Verify output files
        aggregated_json = os.path.join(results_dir, f"{base_name}_aggregated.json")
        aggregated_image = os.path.join(results_dir, f"{base_name}_aggregated.png")

        print("\nChecking output files:")
        print(f"Aggregated JSON exists: {os.path.exists(aggregated_json)}")
        print(f"Aggregated image exists: {os.path.exists(aggregated_image)}")

        # Load and print statistics from aggregated result
        if os.path.exists(aggregated_json):
            with open(aggregated_json, 'r') as f:
                data = json.load(f)
            print("\nAggregation Results:")
            print(f"Number of Symbols: {len(data.get('symbols', []))}")
            print(f"Number of Texts: {len(data.get('texts', []))}")
            print(f"Number of Lines: {len(data.get('lines', []))}")
            print(f"Number of Nodes: {len(data.get('nodes', []))}")
            print(f"Number of Edges: {len(data.get('edges', []))}")
    except Exception as e:
        print(f"\nError during testing: {str(e)}")
        traceback.print_exc()