Spaces:
Running
Running
| #!/usr/bin/env python | |
| """ | |
| Sample data loader for database initialization. | |
| Loads curated examples of traces and knowledge graphs from JSON files for new users. | |
| """ | |
| import json | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Any | |
logger = logging.getLogger(__name__)

# Get the directory where this file is located
CURRENT_DIR = Path(__file__).parent
SAMPLES_DIR = CURRENT_DIR / "samples"
CONFIG_FILE = SAMPLES_DIR / "samples_config.json"


class SampleDataLoader:
    """Load curated sample traces and knowledge graphs from JSON files.

    The configuration and every sample file are read lazily on first use and
    cached on the instance. A cache is only populated once its load has
    finished successfully: if a file is missing or malformed, the error
    propagates and the next call retries, instead of returning a partially
    filled list.

    Args:
        samples_dir: Optional directory containing ``samples_config.json``
            and the files it references (paths in the config are relative to
            this directory). Defaults to ``SAMPLES_DIR`` next to this module.
    """

    def __init__(self, samples_dir=None):
        if samples_dir is None:
            self._samples_dir = SAMPLES_DIR
            self._config_file = CONFIG_FILE
        else:
            self._samples_dir = Path(samples_dir)
            self._config_file = self._samples_dir / CONFIG_FILE.name
        self._config = None
        self._traces = None
        self._knowledge_graphs = None

    def _load_config(self) -> Dict[str, Any]:
        """Load (once) and return the parsed samples configuration.

        Raises:
            FileNotFoundError: If the configuration file does not exist.
            json.JSONDecodeError: If the configuration file is not valid JSON.
        """
        if self._config is None:
            try:
                with open(self._config_file, 'r', encoding='utf-8') as f:
                    self._config = json.load(f)
            except FileNotFoundError:
                logger.error(f"Configuration file not found: {self._config_file}")
                raise
            except json.JSONDecodeError as e:
                logger.error(f"Invalid JSON in configuration file: {e}")
                raise
            logger.info(f"Loaded sample data configuration from {self._config_file}")
        return self._config

    def _load_json(self, relative_path: str, label: str) -> Dict[str, Any]:
        """Read and parse a JSON file given relative to the samples directory.

        ``label`` (e.g. ``"Trace file"``) only feeds the error log messages.

        Raises:
            FileNotFoundError: If the file does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        path = self._samples_dir / relative_path
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            logger.error(f"{label} not found: {path}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in {label.lower()} {path}: {e}")
            raise

    def _load_trace(self, trace_file: str) -> Dict[str, Any]:
        """Load a single trace from JSON file."""
        return self._load_json(trace_file, "Trace file")

    def _load_knowledge_graph(self, kg_file: str) -> Dict[str, Any]:
        """Load a single knowledge graph from JSON file."""
        return self._load_json(kg_file, "Knowledge graph file")

    @staticmethod
    def _basename(relative_path: str) -> str:
        """Return the last ``/``-separated component of a config path."""
        return relative_path.split("/")[-1]

    def get_traces(self) -> List[Dict[str, Any]]:
        """Return all sample traces in the expected format (cached).

        Each entry has ``filename``, ``title``, ``description``,
        ``trace_type``, ``trace_source``, ``tags`` and ``content`` (the
        trace's ``content`` re-serialized to a JSON string).
        """
        if self._traces is None:
            config = self._load_config()
            traces = []
            for sample in config["samples"]:
                trace_data = self._load_trace(sample["trace_file"])
                traces.append({
                    "filename": sample["name"].replace(" ", "_").lower() + ".json",
                    "title": sample["name"],
                    "description": sample["description"],
                    "trace_type": sample["trace_type"],
                    "trace_source": sample["trace_source"],
                    "tags": sample["tags"],
                    # Convert content back to a JSON string
                    "content": json.dumps(trace_data["content"]),
                })
            # Cache only after every trace loaded, so a failure is retried
            # next time rather than leaving a truncated list behind.
            self._traces = traces
            logger.info(f"Loaded {len(self._traces)} sample traces")
        return self._traces

    def _kg_entries_for_sample(self, trace_index: int, sample: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Build the main KG entry for one sample, followed by its window KG entries.

        Window KGs are best-effort: a window file that fails to load is
        logged and skipped. A failure on the main KG propagates.
        """
        kg_data = self._load_knowledge_graph(sample["knowledge_graph_file"])

        # Replay samples carry per-window KGs described by ``window_info``.
        supports_replay = sample.get("supports_replay", False)
        window_info = sample.get("window_info") or {}
        window_files = window_info.get("window_files") or []
        # One default shared by the main KG and its windows, so they agree
        # on the total when the config omits ``window_count``.
        window_count = window_info.get("window_count", len(window_files))
        processing_run_id = window_info.get("processing_run_id")

        main_entry = {
            "filename": self._basename(sample["knowledge_graph_file"]),
            "trace_index": trace_index,  # Links to trace by index
            "graph_data": kg_data["graph_data"],
        }
        # Window metadata for the final KG (window_index absent, window_total=count)
        if supports_replay and window_info:
            main_entry["window_total"] = window_count
            main_entry["processing_run_id"] = processing_run_id
            logger.debug(f"Main KG {main_entry['filename']} configured with {main_entry['window_total']} windows")
        entries = [main_entry]

        if supports_replay and window_files:
            logger.debug(f"Loading {len(window_files)} window KGs for {sample['id']}")
            for window_index, window_file in enumerate(window_files):
                try:
                    window_kg_data = self._load_knowledge_graph(window_file)
                    window_entry = {
                        "filename": self._basename(window_file),
                        "trace_index": trace_index,  # Links to same trace
                        "graph_data": window_kg_data["graph_data"],
                        "window_index": window_index,  # This makes it a window KG
                        "window_total": window_count,
                        "processing_run_id": processing_run_id,
                        "window_start_char": window_kg_data.get("window_start_char"),
                        "window_end_char": window_kg_data.get("window_end_char"),
                    }
                    entries.append(window_entry)
                    logger.debug(f"Loaded window KG {window_index}: {window_entry['filename']}")
                except Exception as e:
                    # Deliberate best-effort: one bad window must not drop the sample.
                    logger.error(f"Failed to load window KG {window_file}: {e}")
                    continue
        return entries

    def get_knowledge_graphs(self) -> List[Dict[str, Any]]:
        """Return all sample knowledge graphs in the expected format (cached).

        Includes the main KG of every sample plus any window KGs of samples
        that support replay. ``trace_index`` links each entry to the trace at
        the same position in :meth:`get_traces`.
        """
        if self._knowledge_graphs is None:
            config = self._load_config()
            knowledge_graphs = []
            for i, sample in enumerate(config["samples"]):
                knowledge_graphs.extend(self._kg_entries_for_sample(i, sample))
            # Cache only after a complete successful load (see get_traces).
            self._knowledge_graphs = knowledge_graphs
            logger.info(f"Loaded {len(self._knowledge_graphs)} sample knowledge graphs (including window KGs)")
        return self._knowledge_graphs

    def get_sample_info(self) -> Dict[str, Any]:
        """Return summary statistics about the available sample data.

        List-valued fields are sorted (by string form) so the output is
        deterministic across runs.
        """
        config = self._load_config()
        traces = self.get_traces()
        knowledge_graphs = self.get_knowledge_graphs()
        samples = config["samples"]

        all_features = set()
        for sample in samples:
            all_features.update(sample.get("features", []))

        return {
            "traces_count": len(traces),
            "knowledge_graphs_count": len(knowledge_graphs),
            "trace_types": sorted({t["trace_type"] for t in traces}, key=str),
            "complexity_levels": sorted({s.get("complexity", "standard") for s in samples}, key=str),
            "features": sorted(all_features, key=str),
            "description": config["metadata"]["description"],
            "version": config["metadata"]["version"],
        }
# One shared loader instance, so every accessor below reuses the same cache.
_loader = SampleDataLoader()


# Module-level accessors kept so existing callers keep working.
def get_sample_traces() -> List[Dict[str, Any]]:
    """Return the shared loader's sample traces (backward-compatible accessor)."""
    traces = _loader.get_traces()
    return traces


def get_sample_knowledge_graphs() -> List[Dict[str, Any]]:
    """Return the shared loader's sample knowledge graphs (backward-compatible accessor)."""
    graphs = _loader.get_knowledge_graphs()
    return graphs
# Legacy module-level names kept for backward compatibility.
#
# ``SAMPLE_TRACES`` and ``SAMPLE_KNOWLEDGE_GRAPHS`` used to be computed eagerly
# at import time, which read every sample JSON file (and raised if one was
# missing) merely by importing this module. They are now resolved lazily on
# first attribute access through a module-level ``__getattr__`` (PEP 562).
# ``module.SAMPLE_TRACES`` and ``from ... import SAMPLE_TRACES`` still yield
# the shared loader's cached lists.
_LEGACY_ATTRIBUTES = {
    "SAMPLE_TRACES": "get_traces",
    "SAMPLE_KNOWLEDGE_GRAPHS": "get_knowledge_graphs",
}


def __getattr__(name: str) -> Any:
    """Lazily resolve the legacy ``SAMPLE_*`` module attributes.

    Raises:
        AttributeError: For any name that is not a legacy attribute.
    """
    method_name = _LEGACY_ATTRIBUTES.get(name)
    if method_name is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return getattr(_loader, method_name)()
def insert_sample_data(session, force_insert=False):
    """
    Insert sample traces and knowledge graphs into the database.

    Each knowledge graph is attached to the trace at its ``trace_index``
    (the sample's position in the configuration). Failures on individual
    traces or knowledge graphs are collected in ``errors`` rather than
    aborting the run. A knowledge graph whose trace was not inserted is
    skipped and reported in ``errors``.

    Args:
        session: Database session
        force_insert: If True, insert even if data already exists (any
            trace with ``trace_source == "sample_data"``)

    Returns:
        Dict with insertion results: ``traces_inserted``,
        ``knowledge_graphs_inserted``, ``skipped`` and ``errors``.

    Raises:
        Exception: Any unexpected error outside the per-item handlers is
            re-raised so the calling code can roll back.
    """
    from backend.database.utils import save_trace, save_knowledge_graph
    from backend.database.models import Trace

    results = {
        "traces_inserted": 0,
        "knowledge_graphs_inserted": 0,
        "skipped": 0,
        "errors": []
    }

    # Get sample data from loader
    sample_traces = _loader.get_traces()
    sample_knowledge_graphs = _loader.get_knowledge_graphs()

    # Check if sample data already exists
    if not force_insert:
        existing_sample = session.query(Trace).filter(
            Trace.trace_source == "sample_data"
        ).first()
        if existing_sample:
            logger.info("Sample data already exists, skipping insertion")
            results["skipped"] = len(sample_traces)
            return results

    try:
        # Map sample position -> inserted trace_id. A dict, rather than a list
        # appended only on success, keeps positions aligned when an earlier
        # trace fails, so knowledge graphs never attach to the wrong trace.
        trace_ids = {}
        for i, trace_data in enumerate(sample_traces):
            try:
                trace = save_trace(
                    session=session,
                    content=trace_data["content"],
                    filename=trace_data["filename"],
                    title=trace_data["title"],
                    description=trace_data["description"],
                    trace_type=trace_data["trace_type"],
                    trace_source=trace_data["trace_source"],
                    tags=trace_data["tags"]
                )
                trace_ids[i] = trace.trace_id
                results["traces_inserted"] += 1
                logger.info(f"Inserted sample trace: {trace_data['title']}")
            except Exception as e:
                error_msg = f"Error inserting trace {i}: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)

        # Insert corresponding knowledge graphs
        for kg_data in sample_knowledge_graphs:
            try:
                trace_index = kg_data["trace_index"]
                if trace_index not in trace_ids:
                    # The owning trace failed or does not exist: report it
                    # instead of silently dropping the knowledge graph.
                    skip_msg = (
                        f"Skipping knowledge graph {kg_data['filename']}: "
                        f"no inserted trace at index {trace_index}"
                    )
                    logger.warning(skip_msg)
                    results["errors"].append(skip_msg)
                    continue

                # None for the final KG, an index for a window KG
                window_index = kg_data.get("window_index")
                # Use provided window_total or default to 1
                window_total = kg_data.get("window_total", 1)

                save_knowledge_graph(
                    session=session,
                    filename=kg_data["filename"],
                    graph_data=kg_data["graph_data"],
                    trace_id=trace_ids[trace_index],
                    window_index=window_index,
                    window_total=window_total,
                    window_start_char=kg_data.get("window_start_char"),
                    window_end_char=kg_data.get("window_end_char"),
                    processing_run_id=kg_data.get("processing_run_id"),
                    is_original=True
                )
                results["knowledge_graphs_inserted"] += 1

                # Log different messages for final vs window KGs
                if window_index is None:
                    logger.info(f"Inserted sample knowledge graph: {kg_data['filename']} (final, {window_total} windows)")
                else:
                    logger.info(f"Inserted sample knowledge graph: {kg_data['filename']} (window {window_index})")
            except Exception as e:
                error_msg = f"Error inserting knowledge graph {kg_data['filename']}: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)

        logger.info(f"Sample data insertion completed: {results}")
    except Exception as e:
        error_msg = f"Fatal error during sample data insertion: {str(e)}"
        logger.error(error_msg)
        results["errors"].append(error_msg)
        raise  # Re-raise to trigger rollback in calling code

    return results
def get_sample_data_info():
    """
    Summarize the bundled sample data.

    Returns:
        Dict with sample data statistics, as produced by the shared
        loader's ``get_sample_info()``.
    """
    summary = _loader.get_sample_info()
    return summary
# Additional utility functions for managing samples
def add_sample(sample_id: str, name: str, description: str, trace_file: str,
               knowledge_graph_file: str, tags: List[str], trace_type: str = "custom",
               trace_source: str = "sample_data", complexity: str = "standard",
               features: List[str] = None):
    """
    Placeholder for registering a new sample in the configuration.

    NOTE: currently a no-op apart from one info log line. Nothing is written
    to the samples configuration yet; how that should work depends on
    requirements.

    Args:
        sample_id: Unique identifier for the sample
        name: Human-readable name
        description: Description of the sample
        trace_file: Path to trace JSON file relative to samples directory
        knowledge_graph_file: Path to KG JSON file relative to samples directory
        tags: List of tags
        trace_type: Type of trace
        trace_source: Source of trace
        complexity: Complexity level
        features: List of features demonstrated
    """
    logger.info(f"Add sample feature called for: {sample_id}")
def list_available_samples() -> List[Dict[str, Any]]:
    """List all available samples with their metadata.

    Returns a deep copy of the configuration's ``samples`` entries. Callers
    can therefore modify the result without corrupting the loader's cached
    configuration, which later trace and knowledge-graph loads read from.
    """
    import copy

    config = _loader._load_config()
    return copy.deepcopy(config["samples"])
if __name__ == "__main__":
    # Quick smoke test of the loader when this file is run as a script.
    try:
        info = get_sample_data_info()
        print("Sample Data Info:", json.dumps(info, indent=2))

        traces = get_sample_traces()
        print(f"Loaded {len(traces)} traces")

        kgs = get_sample_knowledge_graphs()
        print(f"Loaded {len(kgs)} knowledge graphs")
    except Exception as e:
        # SystemExit with a string prints it to stderr and exits with status 1,
        # so a failing smoke test is visible to shells/CI instead of exiting 0.
        raise SystemExit(f"Error testing sample data loader: {e}") from e