# Spaces:
# Runtime error
# Runtime error
import json
import logging
import os
import traceback
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

from storage import StorageFactory

# Module-level logger for code that runs outside the DataAggregator instance.
logger = logging.getLogger(__name__)
class DataAggregator:
    """Aggregates symbol, text and line detection results into one dataset.

    Loads the per-detector JSON result files, validates coordinates, builds a
    simple graph representation (nodes for symbols/texts, edges for lines) and
    renders an annotated overview image.
    """

    def __init__(self, storage=None):
        """Initialize the aggregator.

        Args:
            storage: Storage backend exposing ``load_file``; when omitted, the
                project-wide default from ``StorageFactory`` is used.
        """
        # Lazy fallback: the factory is only consulted when nothing is injected.
        self.storage = storage or StorageFactory.get_storage()
        self.logger = logging.getLogger(__name__)

    def _parse_line_data(self, lines_data: dict) -> List[dict]:
        """Parse raw line-detection records, keeping only valid geometry.

        Args:
            lines_data: Decoded detector output; records live under "lines".

        Returns:
            List of normalized line dicts with integer coordinates.
        """
        parsed_lines = []
        for line in lines_data.get("lines", []):
            try:
                start_coords = line["start"]["coords"]
                end_coords = line["end"]["coords"]
                bbox = line["bbox"]
                # Drop records whose geometry fails basic sanity checks.
                if not (self._is_valid_point(start_coords) and
                        self._is_valid_point(end_coords) and
                        self._is_valid_bbox(bbox)):
                    self.logger.warning(f"Invalid coordinates in line: {line['id']}")
                    continue
                parsed_lines.append({
                    "id": line["id"],
                    "start_point": {
                        "x": int(start_coords["x"]),
                        "y": int(start_coords["y"]),
                        "type": line["start"]["type"],
                        "confidence": line["start"]["confidence"],
                    },
                    "end_point": {
                        "x": int(end_coords["x"]),
                        "y": int(end_coords["y"]),
                        "type": line["end"]["type"],
                        "confidence": line["end"]["confidence"],
                    },
                    "bbox": {
                        "xmin": int(bbox["xmin"]),
                        "ymin": int(bbox["ymin"]),
                        "xmax": int(bbox["xmax"]),
                        "ymax": int(bbox["ymax"]),
                    },
                    "style": line["style"],
                    "confidence": line["confidence"],
                })
            except Exception as e:
                # A single malformed record must not abort the whole batch.
                self.logger.error(f"Error parsing line {line.get('id')}: {str(e)}")
                continue
        return parsed_lines

    def _is_valid_point(self, point: dict) -> bool:
        """Return True when *point* carries numeric x/y within the canvas.

        Args:
            point: Mapping expected to hold "x" and "y" keys.
        """
        try:
            x, y = point.get("x"), point.get("y")
            # NOTE(review): 10000 px assumed as max drawing size — confirm.
            return (isinstance(x, (int, float)) and
                    isinstance(y, (int, float)) and
                    0 <= x <= 10000 and 0 <= y <= 10000)
        except (AttributeError, TypeError):
            # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
            # are no longer swallowed; non-mapping inputs are simply invalid.
            return False

    def _is_valid_bbox(self, bbox: dict) -> bool:
        """Return True when *bbox* holds a numeric, well-ordered box.

        Args:
            bbox: Mapping expected to hold "xmin"/"ymin"/"xmax"/"ymax".
        """
        try:
            xmin = bbox.get("xmin")
            ymin = bbox.get("ymin")
            xmax = bbox.get("xmax")
            ymax = bbox.get("ymax")
            return (isinstance(xmin, (int, float)) and
                    isinstance(ymin, (int, float)) and
                    isinstance(xmax, (int, float)) and
                    isinstance(ymax, (int, float)) and
                    xmin < xmax and ymin < ymax and
                    0 <= xmin <= 10000 and 0 <= ymin <= 10000 and
                    0 <= xmax <= 10000 and 0 <= ymax <= 10000)
        except (AttributeError, TypeError):
            # Narrowed from a bare `except:` — see _is_valid_point.
            return False

    def _create_graph_data(self, lines: List[dict], symbols: List[dict], texts: List[dict]) -> Tuple[List[dict], List[dict]]:
        """Build graph nodes (symbols/texts) and edges (lines).

        Args:
            lines: Parsed lines from :meth:`_parse_line_data`.
            symbols: Raw symbol detections; "bbox" is a list [x1, y1, x2, y2].
            texts: Raw text detections; "bbox" is a list [x1, y1, x2, y2].

        Returns:
            Tuple of (nodes, edges).

        Raises:
            Exception: Re-raised after logging the offending record.
        """
        nodes: List[dict] = []
        edges: List[dict] = []
        self.logger.info("Creating graph data with:")
        self.logger.info(f"Lines: {len(lines)}")
        self.logger.info(f"Symbols: {len(symbols)}")
        self.logger.info(f"Texts: {len(texts)}")
        current = None  # last record touched, for error reporting
        try:
            for symbol in symbols:
                current = symbol
                bbox = symbol["bbox"]  # list [x1, y1, x2, y2]
                nodes.append({
                    "id": symbol["symbol_id"],
                    "type": "symbol",
                    "category": symbol.get("category", ""),
                    "label": symbol.get("label", ""),
                    "confidence": symbol.get("confidence", 0.0),
                    "x": (bbox[0] + bbox[2]) / 2,  # bbox centre
                    "y": (bbox[1] + bbox[3]) / 2,
                    # Convert to dict form for consistency with line bboxes.
                    "bbox": {"xmin": bbox[0], "ymin": bbox[1],
                             "xmax": bbox[2], "ymax": bbox[3]},
                })
            for text in texts:
                current = text
                bbox = text["bbox"]  # list [x1, y1, x2, y2]
                nodes.append({
                    "id": str(uuid.uuid4()),
                    "type": "text",
                    "content": text.get("text", ""),
                    "confidence": text.get("confidence", 0.0),
                    "x": (bbox[0] + bbox[2]) / 2,
                    "y": (bbox[1] + bbox[3]) / 2,
                    "bbox": {"xmin": bbox[0], "ymin": bbox[1],
                             "xmax": bbox[2], "ymax": bbox[3]},
                })
            for line in lines:
                current = line
                edges.append({
                    "id": str(uuid.uuid4()),
                    "type": "line",
                    "start_point": line["start_point"],
                    "end_point": line["end_point"],
                })
        except Exception as e:
            self.logger.error(f"Error processing data: {str(e)}")
            # Fixed: the old fallback `symbol if 'symbol' in locals() else text`
            # raised NameError (masking the real error) when the failure
            # happened before either loop had bound its variable.
            self.logger.error("Current symbol/text being processed: %s",
                              json.dumps(current, indent=2, default=str))
            raise
        return nodes, edges

    def _validate_coordinates(self, data, data_type):
        """Validate the coordinate fields of a parsed record.

        Args:
            data: Record to check (line, symbol or text dict).
            data_type: One of 'line', 'symbol' or 'text'; any other value
                skips all checks and returns True (kept for compatibility).

        Returns:
            True when the record's coordinates are structurally valid.
        """
        if not data:
            return False
        try:
            if data_type == 'line':
                start = data.get('start_point', {})
                end = data.get('end_point', {})
                bbox = data.get('bbox', {})
                required_fields = ['x', 'y', 'type']
                if not all(field in start for field in required_fields):
                    self.logger.warning(f"Missing required fields in start_point: {start}")
                    return False
                if not all(field in end for field in required_fields):
                    self.logger.warning(f"Missing required fields in end_point: {end}")
                    return False
                if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
                    self.logger.warning(f"Invalid bbox format: {bbox}")
                    return False
                # Coordinate consistency: min must not exceed max.
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid bbox coordinates: {bbox}")
                    return False
            elif data_type in ['symbol', 'text']:
                bbox = data.get('bbox', {})
                if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
                    self.logger.warning(f"Invalid {data_type} bbox format: {bbox}")
                    return False
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid {data_type} bbox coordinates: {bbox}")
                    return False
            return True
        except Exception as e:
            self.logger.error(f"Validation error for {data_type}: {str(e)}")
            return False

    def _load_detections(self, path: str, kind: str) -> List[dict]:
        """Load a detection JSON file and return its "detections" list.

        Returns [] when *path* is falsy or missing. Debug logging is guarded
        so an absent or empty "detections" list no longer raises
        KeyError/IndexError (the old code indexed [0] unconditionally).

        Args:
            path: Filesystem path to the detection JSON, may be falsy.
            kind: Human-readable label ("Symbol"/"Text") for log messages.
        """
        if not (path and Path(path).exists()):
            return []
        with open(path, 'r') as f:
            data = json.load(f)
        self.logger.info("%s data keys: %s", kind, list(data.keys()))
        detections = data.get("detections", [])
        if detections:
            self.logger.info("First %s in detections: %s",
                             kind.lower(), json.dumps(detections[0], indent=2))
            self.logger.info("First %s keys: %s",
                             kind.lower(), list(detections[0].keys()))
            self.logger.info("First %s bbox: %s",
                             kind.lower(), detections[0].get("bbox"))
        self.logger.info(f"Loaded {len(detections)} {kind.lower()}s from {path}")
        return detections

    def aggregate_data(self, symbols_path: str, texts_path: str, lines_path: str) -> dict:
        """Aggregate detection results and create a graph structure.

        Args:
            symbols_path: Path to the symbol-detection JSON (may be falsy).
            texts_path: Path to the text-detection JSON (may be falsy).
            lines_path: Storage key of the line-detection JSON (required).

        Returns:
            Dict with "lines", "symbols", "texts", "nodes", "edges", "metadata".

        Raises:
            Exception: Propagated after logging when any input fails to load.
        """
        try:
            # Lines come through the storage abstraction while symbols/texts
            # are read from the local filesystem. NOTE(review): inconsistent —
            # presumably intentional; confirm with callers.
            lines_data = json.loads(self.storage.load_file(lines_path).decode('utf-8'))
            lines = self._parse_line_data(lines_data)
            self.logger.info(f"Loaded {len(lines)} lines")

            symbols = self._load_detections(symbols_path, "Symbol")
            texts = self._load_detections(texts_path, "Text")

            nodes, edges = self._create_graph_data(lines, symbols, texts)
            self.logger.info(f"Created graph with {len(nodes)} nodes and {len(edges)} edges")
            return {
                "lines": lines,
                "symbols": symbols,
                "texts": texts,
                "nodes": nodes,
                "edges": edges,
                "metadata": {
                    "timestamp": datetime.now().isoformat(),
                    "version": "2.0",
                },
            }
        except Exception as e:
            self.logger.error(f"Error during aggregation: {str(e)}")
            self.logger.error("Stack trace:", exc_info=True)
            raise

    def _draw_aggregated_view(self, image: np.ndarray, results: dict) -> np.ndarray:
        """Draw lines (green), symbols (cyan) and texts (purple) on a copy of *image*.

        Args:
            image: BGR image to annotate (not modified in place).
            results: Aggregated dict from :meth:`aggregate_data`.

        Returns:
            The annotated copy of the image.
        """
        annotated = image.copy()
        for line in results.get('lines', []):
            try:
                # int() casts: detector JSON may carry floats, cv2 needs ints.
                cv2.line(annotated,
                         (int(line['start_point']['x']), int(line['start_point']['y'])),
                         (int(line['end_point']['x']), int(line['end_point']['y'])),
                         (0, 255, 0), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid line: {str(e)}")
                continue
        for symbol in results.get('symbols', []):
            try:
                x1, y1, x2, y2 = symbol['bbox']  # list [x1, y1, x2, y2]
                cv2.rectangle(annotated,
                              (int(x1), int(y1)), (int(x2), int(y2)),
                              (255, 255, 0), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid symbol: {str(e)}")
                continue
        for text in results.get('texts', []):
            try:
                x1, y1, x2, y2 = text['bbox']  # list [x1, y1, x2, y2]
                cv2.rectangle(annotated,
                              (int(x1), int(y1)), (int(x2), int(y2)),
                              (128, 0, 128), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid text: {str(e)}")
                continue
        return annotated

    def process_data(self, image_path: str, output_dir: str, symbols_path: str, texts_path: str, lines_path: str):
        """Aggregate detections and write the JSON + annotated PNG outputs.

        Args:
            image_path: Original drawing image used for the visualization.
            output_dir: Directory that receives both output files.
            symbols_path: Symbol-detection JSON path (may be falsy).
            texts_path: Text-detection JSON path (may be falsy).
            lines_path: Line-detection storage key.

        Returns:
            {'success': True, 'image_path': ..., 'json_path': ...} on success,
            {'success': False, 'error': message} on failure (never raises).
        """
        try:
            self.logger.info("Processing data with:")
            self.logger.info(f"- Image: {image_path}")
            self.logger.info(f"- Symbols: {symbols_path}")
            self.logger.info(f"- Texts: {texts_path}")
            self.logger.info(f"- Lines: {lines_path}")
            base_name = Path(image_path).stem
            self.logger.info(f"Base name: {base_name}")
            aggregated_json = os.path.join(output_dir, f"{base_name}_aggregated.json")
            self.logger.info(f"Will save aggregated data to: {aggregated_json}")
            results = self.aggregate_data(symbols_path, texts_path, lines_path)
            self.logger.info("Data aggregation completed")
            with open(aggregated_json, 'w') as f:
                json.dump(results, f, indent=2)
            self.logger.info(f"Saved aggregated JSON to: {aggregated_json}")
            # Create visualization using the original image.
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None; fail with a
                # clear message instead of the old AttributeError on .copy().
                raise ValueError(f"Could not read image: {image_path}")
            annotated = self._draw_aggregated_view(image, results)
            aggregated_image = os.path.join(output_dir, f"{base_name}_aggregated.png")
            cv2.imwrite(aggregated_image, annotated)
            # Return paths like the other detectors do.
            return {
                'success': True,
                'image_path': aggregated_image,
                'json_path': aggregated_json,
            }
        except Exception as e:
            self.logger.error(f"Error in data aggregation: {str(e)}")
            return {
                'success': False,
                'error': str(e),
            }
| if __name__ == "__main__": | |
| import os | |
| from pprint import pprint | |
| # Initialize the aggregator | |
| aggregator = DataAggregator() | |
| # Test paths using actual files in results folder | |
| results_dir = "results" | |
| base_name = "002_page_1" | |
| # Input paths | |
| symbols_path = os.path.join(results_dir, f"{base_name}_detected_symbols.json") | |
| texts_path = os.path.join(results_dir, f"{base_name}_detected_texts.json") | |
| lines_path = os.path.join(results_dir, f"{base_name}_detected_lines.json") | |
| # Verify files exist | |
| print(f"\nChecking input files:") | |
| print(f"Symbols file exists: {os.path.exists(symbols_path)}") | |
| print(f"Texts file exists: {os.path.exists(texts_path)}") | |
| print(f"Lines file exists: {os.path.exists(lines_path)}") | |
| try: | |
| # Process the data | |
| print("\nProcessing data...") | |
| result = aggregator.process_data( | |
| image_path=os.path.join(results_dir, f"{base_name}.png"), | |
| output_dir=results_dir, | |
| symbols_path=symbols_path, | |
| texts_path=texts_path, | |
| lines_path=lines_path | |
| ) | |
| # Verify output files | |
| aggregated_json = os.path.join(results_dir, f"{base_name}_aggregated.json") | |
| aggregated_image = os.path.join(results_dir, f"{base_name}_aggregated.png") | |
| print("\nChecking output files:") | |
| print(f"Aggregated JSON exists: {os.path.exists(aggregated_json)}") | |
| print(f"Aggregated image exists: {os.path.exists(aggregated_image)}") | |
| # Load and print statistics from aggregated result | |
| if os.path.exists(aggregated_json): | |
| with open(aggregated_json, 'r') as f: | |
| data = json.load(f) | |
| print("\nAggregation Results:") | |
| print(f"Number of Symbols: {len(data.get('symbols', []))}") | |
| print(f"Number of Texts: {len(data.get('texts', []))}") | |
| print(f"Number of Lines: {len(data.get('lines', []))}") | |
| print(f"Number of Nodes: {len(data.get('nodes', []))}") | |
| print(f"Number of Edges: {len(data.get('edges', []))}") | |
| except Exception as e: | |
| print(f"\nError during testing: {str(e)}") | |
| traceback.print_exc() | |