Spaces:
Build error
Build error
| from pathlib import Path | |
| import json | |
| import logging | |
| from datetime import datetime | |
| from typing import List, Dict, Optional, Tuple | |
| from storage import StorageFactory | |
| import uuid | |
| import traceback | |
| logger = logging.getLogger(__name__) | |
class DataAggregator:
    """Merge line, symbol and text detections into one node/edge graph.

    Detection results are read as JSON through ``self.storage`` and turned
    into a dictionary holding the raw detections plus ``nodes`` and ``edges``
    lists describing how they relate spatially.
    """

    # Coordinates outside [0, MAX_COORDINATE] are rejected by the validators
    # (adjust the range as needed).
    MAX_COORDINATE = 10000
    # Line endpoint types that are materialised as junction nodes.
    JUNCTION_TYPES = ("T", "L")
    # A line endpoint closer than this to a symbol centre is linked to it.
    SYMBOL_LINE_MAX_DISTANCE = 50
    # Margin added around a symbol bbox when matching text centres against it.
    SYMBOL_TEXT_MARGIN = 20
    # Maximum centre-to-centre offset for a text to count as lying on a line.
    LINE_TEXT_MAX_OFFSET = 30
    # Substrings (matched against the upper-cased text) marking a line id.
    LINE_ID_PATTERNS = ("-", "LINE", "PIPE")

    def __init__(self, storage=None):
        """Use ``storage`` if given, otherwise the ``StorageFactory`` default."""
        self.storage = storage or StorageFactory.get_storage()
        self.logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Line parsing and coordinate validation
    # ------------------------------------------------------------------
    def _parse_line_data(self, lines_data: dict) -> List[dict]:
        """Parse line detection data with coordinate validation.

        Args:
            lines_data: Mapping with a ``"lines"`` list of raw line records.

        Returns:
            One normalised dict per valid record (``id``, ``start_point``,
            ``end_point``, ``bbox``, ``style``, ``confidence``). Records that
            are malformed or fail validation are logged and skipped.
        """
        parsed_lines = []
        # ``or []`` tolerates an explicit ``"lines": null`` payload.
        for line in lines_data.get("lines") or []:
            # Resolve the id up front so that logging can never raise, even
            # when the record is not a dict at all.
            line_id = line.get("id") if isinstance(line, dict) else None
            try:
                start_coords = line["start"]["coords"]
                end_coords = line["end"]["coords"]
                bbox = line["bbox"]

                if not (self._is_valid_point(start_coords) and
                        self._is_valid_point(end_coords) and
                        self._is_valid_bbox(bbox)):
                    self.logger.warning(f"Invalid coordinates in line: {line_id}")
                    continue

                parsed_lines.append({
                    "id": line["id"],
                    "start_point": {
                        "x": int(start_coords["x"]),
                        "y": int(start_coords["y"]),
                        "type": line["start"]["type"],
                        "confidence": line["start"]["confidence"],
                    },
                    "end_point": {
                        "x": int(end_coords["x"]),
                        "y": int(end_coords["y"]),
                        "type": line["end"]["type"],
                        "confidence": line["end"]["confidence"],
                    },
                    "bbox": {
                        "xmin": int(bbox["xmin"]),
                        "ymin": int(bbox["ymin"]),
                        "xmax": int(bbox["xmax"]),
                        "ymax": int(bbox["ymax"]),
                    },
                    "style": line["style"],
                    "confidence": line["confidence"],
                })
            except Exception as e:
                # Deliberate per-record boundary: one bad record must not
                # abort parsing of the remaining lines.
                self.logger.error(f"Error parsing line {line_id}: {str(e)}")
                continue
        return parsed_lines

    def _is_coordinate(self, value) -> bool:
        """Return True if ``value`` is a number within [0, MAX_COORDINATE]."""
        return isinstance(value, (int, float)) and 0 <= value <= self.MAX_COORDINATE

    def _is_valid_point(self, point: dict) -> bool:
        """Validate point coordinates.

        Returns False (never raises) when ``point`` is not a mapping or when
        ``x`` / ``y`` are missing, non-numeric or out of range.
        """
        try:
            x, y = point.get("x"), point.get("y")
        except (AttributeError, TypeError):
            return False
        return self._is_coordinate(x) and self._is_coordinate(y)

    def _is_valid_bbox(self, bbox: dict) -> bool:
        """Validate bbox coordinates.

        All four corners must be in-range numbers and the box must have a
        strictly positive width and height.
        """
        try:
            xmin, ymin, xmax, ymax = [
                bbox.get(key) for key in ("xmin", "ymin", "xmax", "ymax")
            ]
        except (AttributeError, TypeError):
            return False
        return (all(self._is_coordinate(v) for v in (xmin, ymin, xmax, ymax)) and
                xmin < xmax and ymin < ymax)

    # ------------------------------------------------------------------
    # Graph construction
    # ------------------------------------------------------------------
    @staticmethod
    def _point_key(point: dict) -> str:
        """Build the coordinate key used to look up junction nodes."""
        return f"{point['x']}_{point['y']}"

    @staticmethod
    def _bbox_center(item) -> Optional[Tuple[float, float]]:
        """Return the bbox centre of ``item``, or None if the bbox is unusable."""
        try:
            bbox = item["bbox"]
            return ((bbox["xmin"] + bbox["xmax"]) / 2,
                    (bbox["ymin"] + bbox["ymax"]) / 2)
        except (KeyError, TypeError, IndexError):
            return None

    def _collect_objects(self, items, kind: str) -> List[tuple]:
        """Pair each usable detection with a stable id and its bbox centre.

        Returns ``(object_id, item, (center_x, center_y))`` tuples. The id is
        resolved exactly once so that nodes and edges always agree on it, even
        when the detection carries no ``"id"`` of its own. Detections without
        a usable bbox are logged and skipped.
        """
        records = []
        for index, item in enumerate(items or []):
            center = self._bbox_center(item)
            if center is None:
                self.logger.warning(
                    f"Skipping {kind} #{index}: missing or non-numeric bbox")
                continue
            object_id = item["id"] if "id" in item else str(uuid.uuid4())
            records.append((object_id, item, center))
        return records

    def _create_graph_data(self, lines: List[dict], symbols: List[dict], texts: List[dict]) -> Tuple[List[dict], List[dict]]:
        """Create nodes and edges for the knowledge graph following the three-step process.

        1. Object nodes for symbols and texts (with bbox centre).
        2. Junction nodes for line endpoints typed ``T`` or ``L``; endpoints
           sharing the same coordinates share one node (first one wins).
        3. Edges: one per line, plus symbol-line, symbol-text and line-text
           associations derived from spatial proximity.

        Args:
            lines: Parsed lines as produced by ``_parse_line_data``.
            symbols: Raw symbol detections, each with a ``bbox``.
            texts: Raw text detections, each with a ``bbox``.

        Returns:
            ``(nodes, edges)``.
        """
        nodes = []
        edges = []

        symbol_records = self._collect_objects(symbols, "symbol")
        text_records = self._collect_objects(texts, "text")
        # Resolve every line id once so all edge types reference the same id.
        line_records = [
            (line["id"] if "id" in line else str(uuid.uuid4()), line)
            for line in lines
        ]

        # Step 1a: symbol nodes
        for symbol_id, symbol, (center_x, center_y) in symbol_records:
            nodes.append({
                "id": symbol_id,
                "type": "symbol",
                "category": symbol.get("category", "unknown"),
                "bbox": symbol["bbox"],
                "center": {"x": center_x, "y": center_y},
                "confidence": symbol.get("confidence", 1.0),
                "properties": {
                    "class": symbol.get("class", ""),
                    "equipment_type": symbol.get("equipment_type", ""),
                    "original_label": symbol.get("original_label", ""),
                },
            })

        # Step 1b: text nodes
        for text_id, text, (center_x, center_y) in text_records:
            nodes.append({
                "id": text_id,
                "type": "text",
                "content": text.get("text", ""),
                "bbox": text["bbox"],
                "center": {"x": center_x, "y": center_y},
                "confidence": text.get("confidence", 1.0),
                "properties": {
                    "font_size": text.get("font_size"),
                    "rotation": text.get("rotation", 0.0),
                    "text_type": text.get("text_type", "unknown"),
                },
            })

        # Step 2: junction nodes (T/L connections), one per distinct coordinate
        junction_map = {}
        for _, line in line_records:
            for point_name in ("start_point", "end_point"):
                point = line[point_name]
                if point.get("type") not in self.JUNCTION_TYPES:
                    continue
                key = self._point_key(point)
                if key in junction_map:
                    # Another line already created the node at this position.
                    continue
                junction_id = str(uuid.uuid4())
                nodes.append({
                    "id": junction_id,
                    "type": "junction",
                    "junction_type": point["type"],
                    "coords": {"x": point["x"], "y": point["y"]},
                    "properties": {
                        "confidence": point.get("confidence", 1.0),
                    },
                })
                junction_map[key] = junction_id

        # Step 3a: one edge per line, attached to junctions where they exist
        for line_id, line in line_records:
            start, end = line["start_point"], line["end_point"]
            # NOTE(review): an endpoint that is not a junction gets a fresh
            # UUID with no matching node (original behaviour, kept as is).
            source = junction_map.get(self._point_key(start)) or str(uuid.uuid4())
            target = junction_map.get(self._point_key(end)) or str(uuid.uuid4())
            edges.append({
                "id": line_id,
                "type": "line",
                "source": source,
                "target": target,
                "properties": {
                    "style": line["style"],
                    "confidence": line.get("confidence", 1.0),
                    "connection_points": {
                        "start": {"x": start["x"], "y": start["y"]},
                        "end": {"x": end["x"], "y": end["y"]},
                    },
                    "bbox": line["bbox"],
                },
            })

        # Step 3b: symbol-line connections (endpoint close to symbol centre)
        for symbol_id, _, (center_x, center_y) in symbol_records:
            for line_id, line in line_records:
                for point_type in ("start_point", "end_point"):
                    point = line[point_type]
                    dist = ((point["x"] - center_x) ** 2 +
                            (point["y"] - center_y) ** 2) ** 0.5
                    if dist < self.SYMBOL_LINE_MAX_DISTANCE:
                        edges.append({
                            "id": str(uuid.uuid4()),
                            "type": "symbol_line_connection",
                            "source": symbol_id,
                            "target": line_id,
                            "properties": {
                                "connection_point": {"x": point["x"], "y": point["y"]},
                                "connection_type": point_type,
                                "distance": dist,
                            },
                        })

        # Step 3c: symbol-text associations (text centre inside padded bbox)
        margin = self.SYMBOL_TEXT_MARGIN
        for text_id, text, (text_x, text_y) in text_records:
            for symbol_id, symbol, _ in symbol_records:
                symbol_bbox = symbol["bbox"]
                if (symbol_bbox["xmin"] - margin <= text_x <= symbol_bbox["xmax"] + margin and
                        symbol_bbox["ymin"] - margin <= text_y <= symbol_bbox["ymax"] + margin):
                    edges.append({
                        "id": str(uuid.uuid4()),
                        "type": "symbol_text_association",
                        "source": symbol_id,
                        "target": text_id,
                        "properties": {
                            "association_type": "label",
                            "confidence": min(symbol.get("confidence", 1.0),
                                              text.get("confidence", 1.0)),
                        },
                    })

        # Step 3d: line-text associations (proximity and bbox overlap)
        max_offset = self.LINE_TEXT_MAX_OFFSET
        for text_id, text, (text_x, text_y) in text_records:
            text_bbox = text["bbox"]
            # Tolerate missing / non-string content when classifying.
            content = str(text.get("text") or "").upper()
            if any(pattern in content for pattern in self.LINE_ID_PATTERNS):
                association_type = "line_id"
            else:
                association_type = "description"

            for line_id, line in line_records:
                line_bbox = line["bbox"]
                line_x = (line_bbox["xmin"] + line_bbox["xmax"]) / 2
                line_y = (line_bbox["ymin"] + line_bbox["ymax"]) / 2

                # Close in y and overlapping in x.
                is_nearby_horizontal = (
                    abs(text_y - line_y) < max_offset and
                    text_bbox["xmin"] <= line_bbox["xmax"] and
                    text_bbox["xmax"] >= line_bbox["xmin"]
                )
                # Close in x and overlapping in y.
                is_nearby_vertical = (
                    abs(text_x - line_x) < max_offset and
                    text_bbox["ymin"] <= line_bbox["ymax"] and
                    text_bbox["ymax"] >= line_bbox["ymin"]
                )
                if not (is_nearby_horizontal or is_nearby_vertical):
                    continue

                if is_nearby_horizontal:
                    relative_position = "horizontal"
                    distance = abs(text_y - line_y)
                else:
                    relative_position = "vertical"
                    distance = abs(text_x - line_x)

                edges.append({
                    "id": str(uuid.uuid4()),
                    "type": "line_text_association",
                    "source": line_id,
                    "target": text_id,
                    "properties": {
                        "association_type": association_type,
                        "relative_position": relative_position,
                        "confidence": min(line.get("confidence", 1.0),
                                          text.get("confidence", 1.0)),
                        "distance": distance,
                    },
                })

        return nodes, edges

    # ------------------------------------------------------------------
    # Structural validation of already-parsed records
    # ------------------------------------------------------------------
    def _check_bbox(self, bbox, label: str) -> bool:
        """Check that ``bbox`` has all four keys and is not inverted.

        ``label`` is inserted into the warning text (``""`` for lines,
        ``"symbol "`` / ``"text "`` otherwise).
        """
        if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
            self.logger.warning(f"Invalid {label}bbox format: {bbox}")
            return False
        if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
            self.logger.warning(f"Invalid {label}bbox coordinates: {bbox}")
            return False
        return True

    def _validate_coordinates(self, data, data_type):
        """Validate coordinates in the data.

        Args:
            data: A parsed line (``start_point`` / ``end_point`` / ``bbox``)
                or a symbol/text record (``bbox``).
            data_type: ``'line'``, ``'symbol'`` or ``'text'``. Any other value
                is accepted without checks.

        Returns:
            True if the record passes the checks for its type, else False.
            Empty ``data`` and malformed structures yield False.
        """
        if not data:
            return False
        try:
            if data_type == 'line':
                for point_name in ('start_point', 'end_point'):
                    point = data.get(point_name, {})
                    if not all(field in point for field in ['x', 'y', 'type']):
                        self.logger.warning(
                            f"Missing required fields in {point_name}: {point}")
                        return False
                return self._check_bbox(data.get('bbox', {}), "")
            if data_type in ['symbol', 'text']:
                return self._check_bbox(data.get('bbox', {}), f"{data_type} ")
            return True
        except (AttributeError, KeyError, TypeError) as e:
            self.logger.error(f"Validation error for {data_type}: {str(e)}")
            return False

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------
    def _load_json(self, path: str):
        """Load ``path`` through the storage backend and decode it as JSON."""
        return json.loads(self.storage.load_file(path).decode('utf-8'))

    def _load_optional_list(self, path: str, key: str) -> list:
        """Return the list stored under ``key`` in ``path``, or ``[]``.

        NOTE(review): existence is checked on the local filesystem while the
        content is read through ``self.storage``; confirm this is intended
        for non-local storage backends.
        """
        if not path or not Path(path).exists():
            return []
        return self._load_json(path).get(key) or []

    def aggregate_data(self, symbols_path: str, texts_path: str, lines_path: str) -> dict:
        """Aggregate detection results and create graph structure.

        Args:
            symbols_path: Optional path to the symbol detections JSON.
            texts_path: Optional path to the text detections JSON.
            lines_path: Path to the line detections JSON (required).

        Returns:
            Dict with ``lines``, ``symbols``, ``texts``, ``nodes``, ``edges``
            and ``metadata`` (ISO timestamp and format version).

        Raises:
            Exception: Any loading or processing error is logged and re-raised.
        """
        try:
            lines = self._parse_line_data(self._load_json(lines_path))
            symbols = self._load_optional_list(symbols_path, "symbols")
            texts = self._load_optional_list(texts_path, "texts")

            nodes, edges = self._create_graph_data(lines, symbols, texts)

            return {
                "lines": lines,
                "symbols": symbols,
                "texts": texts,
                "nodes": nodes,
                "edges": edges,
                "metadata": {
                    "timestamp": datetime.now().isoformat(),
                    "version": "2.0",
                },
            }
        except Exception as e:
            self.logger.error(f"Error during aggregation: {str(e)}")
            raise
if __name__ == "__main__":
    # Manual smoke test: aggregate one set of detection files from the
    # results folder, write the combined JSON next to them and print a summary.
    import os
    from pprint import pprint

    # Initialize the aggregator (uses the default storage backend)
    aggregator = DataAggregator()

    # Test paths (adjust these to match your results folder)
    results_dir = "results/"
    symbols_path = os.path.join(results_dir, "0_text_detected_symbols.json")
    texts_path = os.path.join(results_dir, "0_text_detected_texts.json")
    lines_path = os.path.join(results_dir, "0_text_detected_lines.json")

    try:
        # Aggregate the data
        aggregated_data = aggregator.aggregate_data(
            symbols_path=symbols_path,
            texts_path=texts_path,
            lines_path=lines_path
        )

        # Save the aggregated result; the encoding is pinned so the output
        # does not depend on the platform's default text encoding.
        output_path = os.path.join(results_dir, "0_aggregated_test.json")
        with open(output_path, 'w', encoding="utf-8") as f:
            json.dump(aggregated_data, f, indent=2)

        # Print some statistics
        print("\nAggregation Results:")
        print(f"Number of Symbols: {len(aggregated_data['symbols'])}")
        print(f"Number of Texts: {len(aggregated_data['texts'])}")
        print(f"Number of Lines: {len(aggregated_data['lines'])}")
        print(f"Number of Nodes: {len(aggregated_data['nodes'])}")
        print(f"Number of Edges: {len(aggregated_data['edges'])}")

        # Print sample of each type
        print("\nSample Node:")
        if aggregated_data['nodes']:
            pprint(aggregated_data['nodes'][0])
        print("\nSample Edge:")
        if aggregated_data['edges']:
            pprint(aggregated_data['edges'][0])

        print(f"\nAggregated data saved to: {output_path}")
    except Exception as e:
        # Top-level boundary of the manual test: report and show the traceback.
        print(f"Error during testing: {str(e)}")
        traceback.print_exc()