import json
import logging
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

from storage import StorageFactory

logger = logging.getLogger(__name__)


class DataAggregator:
    """Combine line, symbol and text detections into one graph-shaped result.

    Detection JSON files are loaded through a storage backend. Line detections
    are validated and normalised. Knowledge-graph ``nodes`` (symbols, texts,
    line junctions) and ``edges`` (lines plus proximity-based associations)
    are then derived from all three detection sets.
    """

    # Inclusive upper bound accepted for any pixel coordinate; adjust as needed.
    MAX_COORDINATE = 10000
    # Endpoint ``type`` values that mark a junction (T- or L-shaped connection).
    JUNCTION_TYPES = ("T", "L")
    # Max Euclidean distance (px) from a line endpoint to a symbol centre for
    # the two to be considered connected.
    SYMBOL_LINE_MAX_DISTANCE = 50
    # Padding (px) added on every side of a symbol bbox when looking for labels.
    SYMBOL_TEXT_MARGIN = 20
    # Max centre-to-centre offset (px) across a line for a text to be
    # associated with that line.
    LINE_TEXT_MAX_OFFSET = 30
    # Substrings which, found in the upper-cased text, mark it as a line id.
    LINE_ID_MARKERS = ("-", "LINE", "PIPE")

    def __init__(self, storage=None):
        """Use *storage* if given, else the default from ``StorageFactory``."""
        self.storage = storage or StorageFactory.get_storage()
        self.logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Parsing and validation
    # ------------------------------------------------------------------

    def _parse_line_data(self, lines_data: dict) -> List[dict]:
        """Parse raw line detections into validated, integer-coordinate dicts.

        Args:
            lines_data: Decoded line-detection JSON. Entries are read from its
                ``"lines"`` key; a missing key yields no lines.

        Returns:
            One dict per accepted line, with keys ``id``, ``start_point``,
            ``end_point``, ``bbox``, ``style`` and ``confidence``. A line with
            invalid coordinates or missing fields is logged and skipped rather
            than aborting the whole parse.
        """
        parsed_lines = []
        for line in lines_data.get("lines", []):
            try:
                start_coords = line["start"]["coords"]
                end_coords = line["end"]["coords"]
                bbox = line["bbox"]

                if not (self._is_valid_point(start_coords)
                        and self._is_valid_point(end_coords)
                        and self._is_valid_bbox(bbox)):
                    self.logger.warning("Invalid coordinates in line: %s", line["id"])
                    continue

                parsed_lines.append({
                    "id": line["id"],
                    "start_point": self._parse_endpoint(line["start"]),
                    "end_point": self._parse_endpoint(line["end"]),
                    "bbox": {
                        "xmin": int(bbox["xmin"]),
                        "ymin": int(bbox["ymin"]),
                        "xmax": int(bbox["xmax"]),
                        "ymax": int(bbox["ymax"]),
                    },
                    "style": line["style"],
                    "confidence": line["confidence"],
                })
            except Exception as e:  # best-effort: one bad record must not drop the rest
                self.logger.error("Error parsing line %s: %s", line.get("id"), e)
        return parsed_lines

    @staticmethod
    def _parse_endpoint(endpoint: dict) -> dict:
        """Flatten a raw endpoint ``{"coords": {x, y}, "type", "confidence"}``."""
        coords = endpoint["coords"]
        return {
            "x": int(coords["x"]),
            "y": int(coords["y"]),
            "type": endpoint["type"],
            "confidence": endpoint["confidence"],
        }

    def _is_valid_coordinate(self, value) -> bool:
        """Return True if *value* is an int/float within ``[0, MAX_COORDINATE]``."""
        return isinstance(value, (int, float)) and 0 <= value <= self.MAX_COORDINATE

    def _is_valid_point(self, point: dict) -> bool:
        """Validate that *point* has numeric, in-range ``x`` and ``y`` values."""
        try:
            x, y = point.get("x"), point.get("y")
        except AttributeError:  # not a mapping
            return False
        return self._is_valid_coordinate(x) and self._is_valid_coordinate(y)

    def _is_valid_bbox(self, bbox: dict) -> bool:
        """Validate bbox corners: numeric, in range, and correctly ordered.

        A zero width *or* a zero height is accepted, because axis-aligned lines
        have degenerate bboxes (a vertical line has ``xmin == xmax``). Only a
        bbox that collapses to a single point is rejected.
        """
        try:
            xmin, ymin, xmax, ymax = (
                bbox.get("xmin"), bbox.get("ymin"), bbox.get("xmax"), bbox.get("ymax")
            )
        except AttributeError:  # not a mapping
            return False
        if not all(self._is_valid_coordinate(v) for v in (xmin, ymin, xmax, ymax)):
            return False
        return xmin <= xmax and ymin <= ymax and (xmin < xmax or ymin < ymax)

    # ------------------------------------------------------------------
    # Graph construction
    # ------------------------------------------------------------------

    @staticmethod
    def _bbox_center(bbox: dict) -> dict:
        """Return the centre of *bbox* as ``{"x": ..., "y": ...}``."""
        return {
            "x": (bbox["xmin"] + bbox["xmax"]) / 2,
            "y": (bbox["ymin"] + bbox["ymax"]) / 2,
        }

    @staticmethod
    def _point_key(point: dict) -> str:
        """Key used to identify junctions that share the same coordinate."""
        return f"{point['x']}_{point['y']}"

    @staticmethod
    def _entity_id(item: dict) -> str:
        """Return *item*'s ``id``, or a fresh UUID string if it has none.

        The id is computed once per entity so that a node and every edge that
        references it agree, even when the detection carried no id.
        """
        existing = item.get("id")
        return existing if existing is not None else str(uuid.uuid4())

    def _create_graph_data(self, lines: List[dict], symbols: List[dict],
                           texts: List[dict]) -> Tuple[List[dict], List[dict]]:
        """Create knowledge-graph nodes and edges in three steps.

        1. Object nodes for symbols and texts, including their centre points.
        2. Junction nodes for T/L line endpoints, one per distinct coordinate.
        3. Edges: the lines themselves, then symbol-line, symbol-text and
           line-text associations based on spatial proximity.

        Args:
            lines: Parsed lines, in the shape returned by ``_parse_line_data``.
            symbols: Symbol detections; each must have a ``"bbox"``.
            texts: Text detections; each must have a ``"bbox"``.

        Returns:
            ``(nodes, edges)``. Nodes are ordered symbols, texts, junctions.
            Edges are ordered line, symbol-line, symbol-text, line-text.
        """
        symbol_ids = [self._entity_id(symbol) for symbol in symbols]
        text_ids = [self._entity_id(text) for text in texts]
        line_ids = [self._entity_id(line) for line in lines]

        # Step 1: object nodes.
        nodes = self._build_symbol_nodes(symbols, symbol_ids)
        nodes += self._build_text_nodes(texts, text_ids)

        # Step 2: junction nodes.
        junction_nodes, junction_map = self._build_junction_nodes(lines)
        nodes += junction_nodes

        # Step 3: edges.
        edges = self._build_line_edges(lines, line_ids, junction_map)
        edges += self._build_symbol_line_edges(symbols, symbol_ids, lines, line_ids)
        edges += self._build_symbol_text_edges(symbols, symbol_ids, texts, text_ids)
        edges += self._build_line_text_edges(lines, line_ids, texts, text_ids)
        return nodes, edges

    def _build_symbol_nodes(self, symbols: List[dict], symbol_ids: List[str]) -> List[dict]:
        """Build one ``"symbol"`` node per symbol detection."""
        nodes = []
        for symbol, node_id in zip(symbols, symbol_ids):
            bbox = symbol["bbox"]
            nodes.append({
                "id": node_id,
                "type": "symbol",
                "category": symbol.get("category", "unknown"),
                "bbox": bbox,
                "center": self._bbox_center(bbox),
                "confidence": symbol.get("confidence", 1.0),
                "properties": {
                    "class": symbol.get("class", ""),
                    "equipment_type": symbol.get("equipment_type", ""),
                    "original_label": symbol.get("original_label", ""),
                },
            })
        return nodes

    def _build_text_nodes(self, texts: List[dict], text_ids: List[str]) -> List[dict]:
        """Build one ``"text"`` node per text detection."""
        nodes = []
        for text, node_id in zip(texts, text_ids):
            bbox = text["bbox"]
            nodes.append({
                "id": node_id,
                "type": "text",
                "content": text.get("text", ""),
                "bbox": bbox,
                "center": self._bbox_center(bbox),
                "confidence": text.get("confidence", 1.0),
                "properties": {
                    "font_size": text.get("font_size"),
                    "rotation": text.get("rotation", 0.0),
                    "text_type": text.get("text_type", "unknown"),
                },
            })
        return nodes

    def _build_junction_nodes(self, lines: List[dict]) -> Tuple[List[dict], Dict[str, str]]:
        """Build junction nodes for T/L endpoints, deduplicated by coordinate.

        Returns:
            ``(nodes, junction_map)``. ``junction_map`` maps ``"x_y"`` to the
            junction node id. Lines meeting at the same coordinate share one
            junction; the first endpoint seen supplies its type and confidence.
        """
        nodes = []
        junction_map: Dict[str, str] = {}
        for line in lines:
            for point_name in ("start_point", "end_point"):
                point = line[point_name]
                if point.get("type") not in self.JUNCTION_TYPES:
                    continue
                key = self._point_key(point)
                if key in junction_map:
                    continue  # already have a junction at this coordinate
                junction_id = str(uuid.uuid4())
                junction_map[key] = junction_id
                nodes.append({
                    "id": junction_id,
                    "type": "junction",
                    "junction_type": point["type"],
                    "coords": {"x": point["x"], "y": point["y"]},
                    "properties": {"confidence": point.get("confidence", 1.0)},
                })
        return nodes, junction_map

    def _build_line_edges(self, lines: List[dict], line_ids: List[str],
                          junction_map: Dict[str, str]) -> List[dict]:
        """Build one ``"line"`` edge per line, between its endpoint junctions."""
        edges = []
        for line, line_id in zip(lines, line_ids):
            start, end = line["start_point"], line["end_point"]
            edges.append({
                "id": line_id,
                "type": "line",
                # NOTE(review): an endpoint that is not a junction gets a fresh
                # UUID, which matches no node in the graph — confirm that
                # consumers expect these dangling endpoints.
                "source": junction_map.get(self._point_key(start), str(uuid.uuid4())),
                "target": junction_map.get(self._point_key(end), str(uuid.uuid4())),
                "properties": {
                    "style": line["style"],
                    "confidence": line.get("confidence", 1.0),
                    "connection_points": {
                        "start": {"x": start["x"], "y": start["y"]},
                        "end": {"x": end["x"], "y": end["y"]},
                    },
                    "bbox": line["bbox"],
                },
            })
        return edges

    def _build_symbol_line_edges(self, symbols: List[dict], symbol_ids: List[str],
                                 lines: List[dict], line_ids: List[str]) -> List[dict]:
        """Link a symbol to each line endpoint lying close to the symbol centre."""
        edges = []
        for symbol, symbol_id in zip(symbols, symbol_ids):
            center = self._bbox_center(symbol["bbox"])
            for line, line_id in zip(lines, line_ids):
                for point_type in ("start_point", "end_point"):
                    point = line[point_type]
                    dist = ((point["x"] - center["x"]) ** 2
                            + (point["y"] - center["y"]) ** 2) ** 0.5
                    if dist < self.SYMBOL_LINE_MAX_DISTANCE:
                        edges.append({
                            "id": str(uuid.uuid4()),
                            "type": "symbol_line_connection",
                            "source": symbol_id,
                            "target": line_id,
                            "properties": {
                                "connection_point": {"x": point["x"], "y": point["y"]},
                                "connection_type": point_type,
                                "distance": dist,
                            },
                        })
        return edges

    def _build_symbol_text_edges(self, symbols: List[dict], symbol_ids: List[str],
                                 texts: List[dict], text_ids: List[str]) -> List[dict]:
        """Link a symbol to each text centred inside its padded bbox."""
        margin = self.SYMBOL_TEXT_MARGIN
        edges = []
        for text, text_id in zip(texts, text_ids):
            text_center = self._bbox_center(text["bbox"])
            for symbol, symbol_id in zip(symbols, symbol_ids):
                box = symbol["bbox"]
                if (box["xmin"] - margin <= text_center["x"] <= box["xmax"] + margin
                        and box["ymin"] - margin <= text_center["y"] <= box["ymax"] + margin):
                    edges.append({
                        "id": str(uuid.uuid4()),
                        "type": "symbol_text_association",
                        "source": symbol_id,
                        "target": text_id,
                        "properties": {
                            "association_type": "label",
                            "confidence": min(symbol.get("confidence", 1.0),
                                              text.get("confidence", 1.0)),
                        },
                    })
        return edges

    def _build_line_text_edges(self, lines: List[dict], line_ids: List[str],
                               texts: List[dict], text_ids: List[str]) -> List[dict]:
        """Link a line to each text that lies alongside it.

        A text is "horizontal" to a line when the bbox centres are within
        ``LINE_TEXT_MAX_OFFSET`` vertically and the x-extents overlap.
        "Vertical" is the same test with the axes swapped. A text containing
        any of ``LINE_ID_MARKERS`` is classified as a line id; any other text
        is classified as a description.
        """
        max_offset = self.LINE_TEXT_MAX_OFFSET
        edges = []
        for text, text_id in zip(texts, text_ids):
            text_bbox = text["bbox"]
            text_center = self._bbox_center(text_bbox)
            content = (text.get("text") or "").upper()
            association_type = (
                "line_id" if any(marker in content for marker in self.LINE_ID_MARKERS)
                else "description"
            )
            for line, line_id in zip(lines, line_ids):
                line_bbox = line["bbox"]
                line_center = self._bbox_center(line_bbox)
                dx = abs(text_center["x"] - line_center["x"])
                dy = abs(text_center["y"] - line_center["y"])

                is_nearby_horizontal = (
                    dy < max_offset
                    and text_bbox["xmin"] <= line_bbox["xmax"]
                    and text_bbox["xmax"] >= line_bbox["xmin"]
                )
                is_nearby_vertical = (
                    dx < max_offset
                    and text_bbox["ymin"] <= line_bbox["ymax"]
                    and text_bbox["ymax"] >= line_bbox["ymin"]
                )
                if not (is_nearby_horizontal or is_nearby_vertical):
                    continue

                edges.append({
                    "id": str(uuid.uuid4()),
                    "type": "line_text_association",
                    "source": line_id,
                    "target": text_id,
                    "properties": {
                        "association_type": association_type,
                        "relative_position": "horizontal" if is_nearby_horizontal else "vertical",
                        "confidence": min(line.get("confidence", 1.0),
                                          text.get("confidence", 1.0)),
                        "distance": dy if is_nearby_horizontal else dx,
                    },
                })
        return edges

    # ------------------------------------------------------------------
    # Standalone validation (not used by the aggregation path above)
    # ------------------------------------------------------------------

    def _validate_coordinates(self, data, data_type):
        """Check that *data* has the coordinate fields expected for *data_type*.

        Args:
            data: A parsed line (``data_type == 'line'``) or a symbol/text
                detection (``'symbol'`` / ``'text'``).
            data_type: One of ``'line'``, ``'symbol'``, ``'text'``. Any other
                value is accepted without checks.

        Returns:
            True if the required fields exist and every bbox has
            ``min <= max``; otherwise False. Each failure is logged.
        """
        if not data:
            return False

        bbox_keys = ['xmin', 'ymin', 'xmax', 'ymax']
        try:
            if data_type == 'line':
                start = data.get('start_point', {})
                end = data.get('end_point', {})
                bbox = data.get('bbox', {})
                required_fields = ['x', 'y', 'type']

                if not all(field in start for field in required_fields):
                    self.logger.warning(f"Missing required fields in start_point: {start}")
                    return False
                if not all(field in end for field in required_fields):
                    self.logger.warning(f"Missing required fields in end_point: {end}")
                    return False
                if not all(key in bbox for key in bbox_keys):
                    self.logger.warning(f"Invalid bbox format: {bbox}")
                    return False
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid bbox coordinates: {bbox}")
                    return False

            elif data_type in ['symbol', 'text']:
                bbox = data.get('bbox', {})
                if not all(key in bbox for key in bbox_keys):
                    self.logger.warning(f"Invalid {data_type} bbox format: {bbox}")
                    return False
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid {data_type} bbox coordinates: {bbox}")
                    return False

            return True

        except Exception as e:
            self.logger.error(f"Validation error for {data_type}: {str(e)}")
            return False

    # ------------------------------------------------------------------
    # Loading and aggregation
    # ------------------------------------------------------------------

    def _load_json(self, path: str):
        """Load *path* through the storage backend and decode it as UTF-8 JSON."""
        return json.loads(self.storage.load_file(path).decode('utf-8'))

    def _load_optional_list(self, path: Optional[str], key: str) -> List[dict]:
        """Return ``data[key]`` from the JSON at *path*, or ``[]`` if unavailable.

        NOTE(review): existence is checked on the *local* filesystem, while
        loading goes through ``self.storage``. With a non-local storage
        backend this may silently skip the file — confirm this is intended.
        """
        if not path or not Path(path).exists():
            return []
        return self._load_json(path).get(key, [])

    def aggregate_data(self, symbols_path: Optional[str], texts_path: Optional[str],
                       lines_path: str) -> dict:
        """Aggregate detection results and build the graph structure.

        Args:
            symbols_path: Symbol-detection JSON (``"symbols"`` key). Falsy or
                missing -> no symbols.
            texts_path: Text-detection JSON (``"texts"`` key). Falsy or
                missing -> no texts.
            lines_path: Line-detection JSON (``"lines"`` key). Required.

        Returns:
            Dict with keys ``lines``, ``symbols``, ``texts``, ``nodes``,
            ``edges`` and ``metadata`` (``timestamp`` and ``version``).

        Raises:
            Any error from loading or parsing, re-raised after being logged
            with its traceback.
        """
        try:
            lines = self._parse_line_data(self._load_json(lines_path))
            symbols = self._load_optional_list(symbols_path, "symbols")
            texts = self._load_optional_list(texts_path, "texts")

            nodes, edges = self._create_graph_data(lines, symbols, texts)

            return {
                "lines": lines,
                "symbols": symbols,
                "texts": texts,
                "nodes": nodes,
                "edges": edges,
                "metadata": {
                    "timestamp": datetime.now().isoformat(),
                    "version": "2.0",
                },
            }
        except Exception as e:
            self.logger.exception("Error during aggregation: %s", e)
            raise


if __name__ == "__main__":
    import os
    from pprint import pprint

    # Initialize the aggregator
    aggregator = DataAggregator()

    # Test paths (adjust these to match your results folder)
    results_dir = "results/"
    symbols_path = os.path.join(results_dir, "0_text_detected_symbols.json")
    texts_path = os.path.join(results_dir, "0_text_detected_texts.json")
    lines_path = os.path.join(results_dir, "0_text_detected_lines.json")

    try:
        aggregated_data = aggregator.aggregate_data(
            symbols_path=symbols_path,
            texts_path=texts_path,
            lines_path=lines_path,
        )

        # Save the aggregated result
        output_path = os.path.join(results_dir, "0_aggregated_test.json")
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(aggregated_data, f, indent=2)

        # Print some statistics
        print("\nAggregation Results:")
        print(f"Number of Symbols: {len(aggregated_data['symbols'])}")
        print(f"Number of Texts: {len(aggregated_data['texts'])}")
        print(f"Number of Lines: {len(aggregated_data['lines'])}")
        print(f"Number of Nodes: {len(aggregated_data['nodes'])}")
        print(f"Number of Edges: {len(aggregated_data['edges'])}")

        # Print sample of each type
        print("\nSample Node:")
        if aggregated_data['nodes']:
            pprint(aggregated_data['nodes'][0])

        print("\nSample Edge:")
        if aggregated_data['edges']:
            pprint(aggregated_data['edges'][0])

        print(f"\nAggregated data saved to: {output_path}")

    except Exception as e:
        print(f"Error during testing: {str(e)}")
        traceback.print_exc()