# AgentGraph — backend/database/sample_data.py
# Commit 9ad889b: Fix sample data loader to support window KG files for Replay button
#!/usr/bin/env python
"""
Sample data loader for database initialization.
Loads curated examples of traces and knowledge graphs from JSON files for new users.
"""
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Any
logger = logging.getLogger(__name__)
# Get the directory where this file is located
CURRENT_DIR = Path(__file__).parent
SAMPLES_DIR = CURRENT_DIR / "samples"
CONFIG_FILE = SAMPLES_DIR / "samples_config.json"
class SampleDataLoader:
    """Loads curated sample traces and knowledge graphs from JSON files.

    All sample files live under ``SAMPLES_DIR`` and are described by the
    ``samples_config.json`` manifest. Loaded data is cached per instance,
    so each file is read from disk at most once.
    """

    def __init__(self):
        # Lazily-populated caches; None means "not loaded yet".
        self._config = None
        self._traces = None
        self._knowledge_graphs = None

    def _load_json(self, rel_path: str, kind: str) -> Dict[str, Any]:
        """Load one JSON file from the samples directory.

        Args:
            rel_path: Path relative to ``SAMPLES_DIR``.
            kind: Human-readable label for log messages
                  ("Trace" or "Knowledge graph").

        Raises:
            FileNotFoundError: If the file does not exist.
            json.JSONDecodeError: If the file contains invalid JSON.
        """
        path = SAMPLES_DIR / rel_path
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            logger.error(f"{kind} file not found: {path}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in {kind.lower()} file {path}: {e}")
            raise

    def _load_config(self) -> Dict[str, Any]:
        """Load (and cache) the samples configuration manifest.

        Raises:
            FileNotFoundError: If ``CONFIG_FILE`` does not exist.
            json.JSONDecodeError: If the manifest contains invalid JSON.
        """
        if self._config is None:
            try:
                with open(CONFIG_FILE, 'r', encoding='utf-8') as f:
                    self._config = json.load(f)
                logger.info(f"Loaded sample data configuration from {CONFIG_FILE}")
            except FileNotFoundError:
                logger.error(f"Configuration file not found: {CONFIG_FILE}")
                raise
            except json.JSONDecodeError as e:
                logger.error(f"Invalid JSON in configuration file: {e}")
                raise
        return self._config

    def _load_trace(self, trace_file: str) -> Dict[str, Any]:
        """Load a single trace JSON file (path relative to SAMPLES_DIR)."""
        return self._load_json(trace_file, "Trace")

    def _load_knowledge_graph(self, kg_file: str) -> Dict[str, Any]:
        """Load a single knowledge graph JSON file (path relative to SAMPLES_DIR)."""
        return self._load_json(kg_file, "Knowledge graph")

    def get_traces(self) -> List[Dict[str, Any]]:
        """Return all sample traces in the database-insertion format.

        Each entry carries the manifest metadata plus the trace ``content``
        serialized back to a JSON string (the downstream save path expects
        a string, not a parsed object).
        """
        if self._traces is None:
            config = self._load_config()
            self._traces = []
            for sample in config["samples"]:
                trace_data = self._load_trace(sample["trace_file"])
                trace_entry = {
                    # Derive a stable filename from the display name.
                    "filename": sample["name"].replace(" ", "_").lower() + ".json",
                    "title": sample["name"],
                    "description": sample["description"],
                    "trace_type": sample["trace_type"],
                    "trace_source": sample["trace_source"],
                    "tags": sample["tags"],
                    "content": json.dumps(trace_data["content"])  # Convert content back to JSON string
                }
                self._traces.append(trace_entry)
            logger.info(f"Loaded {len(self._traces)} sample traces")
        return self._traces

    def _load_window_kgs(self, sample: Dict[str, Any], trace_index: int) -> List[Dict[str, Any]]:
        """Load the per-window KG entries for one sample.

        Returns an empty list when the sample does not support replay or
        declares no window files. Individual window files that fail to load
        are logged and skipped (best-effort).
        """
        window_info = sample.get("window_info", {})
        window_files = window_info.get("window_files")
        if not (sample.get("supports_replay", False) and window_files):
            return []
        window_count = window_info.get("window_count", len(window_files))
        processing_run_id = window_info.get("processing_run_id")
        logger.debug(f"Loading {len(window_files)} window KGs for {sample['id']}")
        entries = []
        for window_index, window_file in enumerate(window_files):
            try:
                window_kg_data = self._load_knowledge_graph(window_file)
                window_kg_entry = {
                    "filename": window_file.split("/")[-1],  # Get just the filename
                    "trace_index": trace_index,  # Links to same trace
                    "graph_data": window_kg_data["graph_data"],
                    "window_index": window_index,  # Presence of this key marks a window KG
                    "window_total": window_count,
                    "processing_run_id": processing_run_id,
                    "window_start_char": window_kg_data.get("window_start_char"),
                    "window_end_char": window_kg_data.get("window_end_char")
                }
                entries.append(window_kg_entry)
                logger.debug(f"Loaded window KG {window_index}: {window_kg_entry['filename']}")
            except Exception as e:
                logger.error(f"Failed to load window KG {window_file}: {e}")
                continue
        return entries

    def get_knowledge_graphs(self) -> List[Dict[str, Any]]:
        """Return all sample knowledge graphs (final KGs plus window KGs).

        Final KGs carry no ``window_index`` key; window KGs (for Replay
        support) carry ``window_index``/``window_total``. Entries link back
        to their trace via ``trace_index`` (position in the samples list).
        """
        if self._knowledge_graphs is None:
            config = self._load_config()
            self._knowledge_graphs = []
            for i, sample in enumerate(config["samples"]):
                kg_data = self._load_knowledge_graph(sample["knowledge_graph_file"])
                supports_replay = sample.get("supports_replay", False)
                window_info = sample.get("window_info", {})
                kg_entry = {
                    "filename": sample["knowledge_graph_file"].split("/")[-1],  # Get just the filename
                    "trace_index": i,  # Links to trace by index
                    "graph_data": kg_data["graph_data"]
                }
                # Final KG of a replayable sample records how many windows exist
                # (window_index stays absent/None to mark it as the final KG).
                if supports_replay and window_info:
                    kg_entry["window_total"] = window_info.get("window_count", 0)
                    kg_entry["processing_run_id"] = window_info.get("processing_run_id")
                    logger.debug(f"Main KG {kg_entry['filename']} configured with {kg_entry['window_total']} windows")
                self._knowledge_graphs.append(kg_entry)
                # Append per-window KGs right after the final KG for this sample.
                self._knowledge_graphs.extend(self._load_window_kgs(sample, i))
            logger.info(f"Loaded {len(self._knowledge_graphs)} sample knowledge graphs (including window KGs)")
        return self._knowledge_graphs

    def get_sample_info(self) -> Dict[str, Any]:
        """Return summary statistics about the available sample data."""
        config = self._load_config()
        traces = self.get_traces()
        knowledge_graphs = self.get_knowledge_graphs()
        # Union of the feature lists across all samples.
        all_features = set()
        for sample in config["samples"]:
            all_features.update(sample.get("features", []))
        return {
            "traces_count": len(traces),
            "knowledge_graphs_count": len(knowledge_graphs),
            "trace_types": list(set(t["trace_type"] for t in traces)),
            "complexity_levels": list(set(sample.get("complexity", "standard") for sample in config["samples"])),
            "features": list(all_features),
            "description": config["metadata"]["description"],
            "version": config["metadata"]["version"]
        }
# Single module-wide loader instance; every public helper below delegates to it.
_loader = SampleDataLoader()
# Backward-compatible module-level accessors that delegate to the shared loader.
def get_sample_traces() -> List[Dict[str, Any]]:
    """Return the curated sample traces (backward-compatible wrapper)."""
    return _loader.get_traces()
def get_sample_knowledge_graphs() -> List[Dict[str, Any]]:
    """Return the curated sample knowledge graphs (backward-compatible wrapper)."""
    return _loader.get_knowledge_graphs()
# Legacy module-level data for backward compatibility.
#
# NOTE(fix): the previous code defined SAMPLE_TRACES / SAMPLE_KNOWLEDGE_GRAPHS
# as @property-decorated module-level functions. @property only has an effect
# as a class attribute; at module scope those names were bound to inert
# property objects and then immediately shadowed by the assignments below,
# so the decorated functions were dead code and have been removed.
import sys

current_module = sys.modules[__name__]
# Eagerly materialize the sample data at import time (preserves existing
# behavior: this reads the sample JSON files and raises if they are
# missing or invalid).
current_module.SAMPLE_TRACES = _loader.get_traces()
current_module.SAMPLE_KNOWLEDGE_GRAPHS = _loader.get_knowledge_graphs()
def insert_sample_data(session, force_insert=False):
    """
    Insert sample traces and knowledge graphs into the database.

    Args:
        session: Database session
        force_insert: If True, insert even if data already exists

    Returns:
        Dict with insertion results: counts of inserted traces/KGs, number
        of skipped traces, and a list of error messages.

    Raises:
        Exception: Re-raised on fatal errors so the caller can roll back.
    """
    # Imported lazily to avoid a circular import at module load time.
    from backend.database.utils import save_trace, save_knowledge_graph
    from backend.database.models import Trace, KnowledgeGraph

    results = {
        "traces_inserted": 0,
        "knowledge_graphs_inserted": 0,
        "skipped": 0,
        "errors": []
    }

    # Get sample data from loader
    sample_traces = _loader.get_traces()
    sample_knowledge_graphs = _loader.get_knowledge_graphs()

    # Skip entirely if sample data was already seeded (unless forced).
    if not force_insert:
        existing_sample = session.query(Trace).filter(
            Trace.trace_source == "sample_data"
        ).first()
        if existing_sample:
            logger.info("Sample data already exists, skipping insertion")
            results["skipped"] = len(sample_traces)
            return results

    try:
        # Insert sample traces. trace_ids must stay index-aligned with
        # sample_traces because KGs reference their trace by list position
        # ("trace_index"). FIX: a failed insert previously shifted the list,
        # silently attaching later KGs to the wrong trace — now we record
        # None for failures and skip the affected KGs below.
        trace_ids = []
        for i, trace_data in enumerate(sample_traces):
            try:
                trace = save_trace(
                    session=session,
                    content=trace_data["content"],
                    filename=trace_data["filename"],
                    title=trace_data["title"],
                    description=trace_data["description"],
                    trace_type=trace_data["trace_type"],
                    trace_source=trace_data["trace_source"],
                    tags=trace_data["tags"]
                )
                trace_ids.append(trace.trace_id)
                results["traces_inserted"] += 1
                logger.info(f"Inserted sample trace: {trace_data['title']}")
            except Exception as e:
                trace_ids.append(None)  # keep index alignment for trace_index lookups
                error_msg = f"Error inserting trace {i}: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)

        # Insert corresponding knowledge graphs (final KGs and window KGs).
        for kg_data in sample_knowledge_graphs:
            try:
                trace_index = kg_data["trace_index"]
                # Skip KGs whose trace is out of range or failed to insert.
                if trace_index >= len(trace_ids) or trace_ids[trace_index] is None:
                    continue
                # Extract window information from the KG data:
                # window_index is None for a final KG, an int for a window KG.
                window_index = kg_data.get("window_index")
                window_total = kg_data.get("window_total", 1)  # default: single-window
                window_start_char = kg_data.get("window_start_char")
                window_end_char = kg_data.get("window_end_char")
                processing_run_id = kg_data.get("processing_run_id")
                save_knowledge_graph(
                    session=session,
                    filename=kg_data["filename"],
                    graph_data=kg_data["graph_data"],
                    trace_id=trace_ids[trace_index],
                    window_index=window_index,
                    window_total=window_total,
                    window_start_char=window_start_char,
                    window_end_char=window_end_char,
                    processing_run_id=processing_run_id,
                    is_original=True
                )
                results["knowledge_graphs_inserted"] += 1
                # Log different messages for final vs window KGs
                if window_index is None:
                    logger.info(f"Inserted sample knowledge graph: {kg_data['filename']} (final, {window_total} windows)")
                else:
                    logger.info(f"Inserted sample knowledge graph: {kg_data['filename']} (window {window_index})")
            except Exception as e:
                error_msg = f"Error inserting knowledge graph {kg_data['filename']}: {str(e)}"
                logger.error(error_msg)
                results["errors"].append(error_msg)

        logger.info(f"Sample data insertion completed: {results}")
    except Exception as e:
        error_msg = f"Fatal error during sample data insertion: {str(e)}"
        logger.error(error_msg)
        results["errors"].append(error_msg)
        raise  # Re-raise to trigger rollback in calling code
    return results
def get_sample_data_info():
    """Return summary statistics about the available sample data.

    Returns:
        Dict with sample data statistics
    """
    info = _loader.get_sample_info()
    return info
# Additional utility functions for managing samples
def add_sample(sample_id: str, name: str, description: str, trace_file: str,
               knowledge_graph_file: str, tags: List[str], trace_type: str = "custom",
               trace_source: str = "sample_data", complexity: str = "standard",
               features: List[str] = None):
    """Register a new sample in the configuration (placeholder).

    Currently only logs the request; actually persisting the entry to the
    config file is left for a future implementation.

    Args:
        sample_id: Unique identifier for the sample.
        name: Human-readable name.
        description: Description of the sample.
        trace_file: Trace JSON path, relative to the samples directory.
        knowledge_graph_file: KG JSON path, relative to the samples directory.
        tags: List of tags.
        trace_type: Type of trace.
        trace_source: Source of trace.
        complexity: Complexity level.
        features: List of features demonstrated.
    """
    logger.info(f"Add sample feature called for: {sample_id}")
def list_available_samples() -> List[Dict[str, Any]]:
    """Return the metadata entries for every configured sample."""
    return _loader._load_config()["samples"]
if __name__ == "__main__":
    # Smoke-test the loader when this module is run directly.
    try:
        print("Sample Data Info:", json.dumps(get_sample_data_info(), indent=2))
        print(f"Loaded {len(get_sample_traces())} traces")
        print(f"Loaded {len(get_sample_knowledge_graphs())} knowledge graphs")
    except Exception as e:
        print(f"Error testing sample data loader: {e}")