# wu981526092's picture
# 🚀 Deploy AgentGraph: Complete agent monitoring and knowledge graph system
# c2ea5ed
"""
Utility functions for database operations.
"""
import os
import json
import logging
import uuid
from datetime import datetime
from typing import Dict, List, Any, Optional, Union
import hashlib
from sqlalchemy.orm import Session
from sqlalchemy import func
from . import models
from . import get_db, init_db
logger = logging.getLogger(__name__)
def initialize_database(clear_all=False):
    """
    Create the database schema, optionally wiping existing data first.

    Args:
        clear_all: When truthy, call the package's ``reinit_db`` (which, per
            its log message, clears all previous data); otherwise call
            ``init_db``, which preserves existing data.
    """
    if not clear_all:
        from . import init_db
        init_db()
        logger.info("Database initialized (existing data preserved)")
        return

    from . import reinit_db
    reinit_db()
    logger.info("Database reinitialized (all previous data cleared)")
def get_knowledge_graph(session: Session, filename: str) -> Optional[models.KnowledgeGraph]:
    """
    Look up a knowledge graph by its exact filename.

    Returns:
        The first matching KnowledgeGraph row, or None if no row has that filename.
    """
    by_name = session.query(models.KnowledgeGraph).filter_by(filename=filename)
    return by_name.first()
def save_knowledge_graph(
    session,
    filename,
    graph_data,
    trace_id=None,
    window_index=None,
    window_total=None,
    window_start_char=None,
    window_end_char=None,
    is_original=False,
    processing_run_id=None
):
    """
    Insert or update (keyed by filename) a knowledge graph row and commit.

    When a row with ``filename`` already exists, its ``graph_content`` and
    ``update_timestamp`` are always replaced. Each optional trace/window field
    is overwritten only when a non-None value is passed. ``status`` is set to
    "created" only when ``is_original`` is truthy. Otherwise a new row is
    created with all the given fields.

    Args:
        session: Database session
        filename: Filename to save under (used as the lookup key)
        graph_data: Knowledge graph data, assigned to ``graph_content``
        trace_id: Optional ID to group knowledge graphs from the same trace
        window_index: Optional sequential index of window within a trace
        window_total: Optional total number of windows in the trace
        window_start_char: Optional starting character position in the original trace
        window_end_char: Optional ending character position in the original trace
        is_original: Whether this is an original knowledge graph (sets status to "created")
        processing_run_id: Optional ID to distinguish multiple processing runs

    Returns:
        The created or updated KnowledgeGraph object
    """
    from backend.database.models import KnowledgeGraph

    # Optional fields shared by both the insert and the update path.
    optional_fields = {
        "trace_id": trace_id,
        "window_index": window_index,
        "window_total": window_total,
        "window_start_char": window_start_char,
        "window_end_char": window_end_char,
        "processing_run_id": processing_run_id,
    }

    existing = session.query(KnowledgeGraph).filter(KnowledgeGraph.filename == filename).first()

    if not existing:
        kg = KnowledgeGraph(
            filename=filename,
            status="created" if is_original else None,
            **optional_fields
        )
        # Assigned after construction; the original author notes this is what
        # keeps the model's entity/relation counts updated.
        kg.graph_content = graph_data
    else:
        kg = existing
        kg.graph_content = graph_data
        kg.update_timestamp = datetime.utcnow()
        for attr_name, attr_value in optional_fields.items():
            if attr_value is not None:
                setattr(kg, attr_name, attr_value)
        if is_original:
            kg.status = "created"

    session.add(kg)
    session.commit()
    return kg
def update_knowledge_graph_status(session: Session, kg_id: Union[int, str], status: str) -> models.KnowledgeGraph:
    """
    Set a knowledge graph's ``status`` column and commit.

    Args:
        session: Database session
        kg_id: A ``str`` is matched against ``filename``; any other value is
            matched against the primary key ``id``.
        status: New status (e.g. created, enriched, perturbed, causal)

    Returns:
        The updated knowledge graph

    Raises:
        ValueError: If no knowledge graph matches ``kg_id``.
    """
    lookup_column = "filename" if isinstance(kg_id, str) else "id"
    kg = session.query(models.KnowledgeGraph).filter_by(**{lookup_column: kg_id}).first()
    if not kg:
        raise ValueError(f"Knowledge graph with ID/filename {kg_id} not found")

    kg.status = status
    session.commit()
    return kg
def extract_entities_and_relations(session: Session, kg: models.KnowledgeGraph):
    """
    Rebuild the Entity and Relation rows for ``kg`` from its ``graph_data``.

    The function first deletes the existing rows for this graph, relations
    before entities (per the original note, because of foreign keys). It then
    inserts every entity dict that has an ``id``, followed by every relation
    dict that has ``id``/``source``/``target`` and whose endpoints match an
    inserted entity. Per-item failures are logged and skipped. The session is
    committed at the end. Does nothing when ``graph_data`` is empty.

    NOTE(review): if a per-entity ``session.flush()`` raises, the session is
    likely left needing a rollback, so later items and the final commit may
    also fail. Consider a SAVEPOINT per item.
    """
    data = kg.graph_data
    if not data:
        return

    session.query(models.Relation).filter_by(graph_id=kg.id).delete()
    session.query(models.Entity).filter_by(graph_id=kg.id).delete()
    session.flush()

    inserted = {}  # graph-data entity id -> persisted Entity row
    for raw_entity in data.get('entities', []):
        try:
            if 'id' in raw_entity:
                key = raw_entity.get('id')
                row = models.Entity.from_dict(raw_entity, kg.id)
                session.add(row)
                session.flush()  # make the row's database id available to relations
                inserted[key] = row
        except Exception as e:
            logger.error(f"Error extracting entity {raw_entity.get('id')}: {str(e)}")

    required_keys = ('id', 'source', 'target')
    for raw_relation in data.get('relations', []):
        try:
            if not all(k in raw_relation for k in required_keys):
                continue
            src = inserted.get(raw_relation.get('source'))
            dst = inserted.get(raw_relation.get('target'))
            if not src or not dst:
                logger.warning(f"Skipping relation {raw_relation.get('id')}: Source or target entity not found")
                continue
            session.add(models.Relation.from_dict(raw_relation, kg.id, src, dst))
        except Exception as e:
            logger.error(f"Error extracting relation {raw_relation.get('id')}: {str(e)}")

    session.commit()
def get_test_result(session: Session, filename: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the test result dictionary stored under ``filename``.

    Lookup order:
    1. The ``test_result`` key of the JSON in the matching knowledge graph's
       ``content`` column.
    2. ``datasets/test_results/<filename>``, then ``datasets/<filename>``
       (relative to the current working directory), parsed as JSON.

    Returns:
        The test result, or None if no source provides one.
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if kg and kg.content:
        try:
            payload = json.loads(kg.content)
        except json.JSONDecodeError:
            pass
        else:
            if 'test_result' in payload:
                return payload['test_result']

    for candidate in (f"datasets/test_results/{filename}", f"datasets/{filename}"):
        try:
            with open(candidate, 'r') as fh:
                return json.load(fh)
        except (FileNotFoundError, json.JSONDecodeError):
            continue
    return None
def save_test_result(session: Session, filename: str, data: Dict[str, Any]) -> models.KnowledgeGraph:
    """
    Store a test result inside a knowledge graph's JSON ``content`` field.

    The knowledge graph named ``filename`` is created if missing. Undecodable
    existing content is discarded. Besides ``test_result`` itself, the method
    writes ``test_timestamp``, ``model_name``, ``perturbation_type``,
    ``completed`` and, when ``data['knowledge_graph_filename']`` names an
    existing graph, that graph's ``knowledge_graph_id``. Commits the session.

    Returns:
        The knowledge graph holding the result.
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if not kg:
        kg = models.KnowledgeGraph()
        kg.filename = filename
        kg.creation_timestamp = datetime.utcnow()

    try:
        content = json.loads(kg.content) if kg.content else {}
    except json.JSONDecodeError:
        content = {}

    content['test_result'] = data
    content['test_timestamp'] = datetime.utcnow().isoformat()
    content['model_name'] = data.get('model', '')
    content['perturbation_type'] = data.get('perturbation_type', '')
    content['completed'] = data.get('completed', False)

    linked_name = data.get('knowledge_graph_filename')
    if linked_name:
        linked = session.query(models.KnowledgeGraph).filter_by(filename=linked_name).first()
        if linked:
            content['knowledge_graph_id'] = linked.id

    kg.content = json.dumps(content)
    session.add(kg)
    session.commit()
    return kg
def get_test_progress(session: Session, test_filename: str) -> Optional[Dict[str, Any]]:
    """
    Get progress data for a test run.

    Lookup order:
    1. The ``test_progress`` key of the JSON in the ``content`` column of the
       knowledge graph named ``test_filename``.
    2. The JSON file ``datasets/test_results/progress_<test_filename>``
       (relative to the current working directory), which is the location
       ``save_test_progress`` writes to.

    Returns:
        The progress dict, or None if neither source has it.
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=test_filename).first()
    if kg and kg.content:
        try:
            content = json.loads(kg.content)
        except json.JSONDecodeError:
            content = None
        # Guard against valid-but-non-object JSON (e.g. "null" or a list).
        if isinstance(content, dict) and 'test_progress' in content:
            return content['test_progress']

    # Use the same relative directory save_test_progress writes to; no
    # project-root constant is defined in this module.
    progress_path = os.path.join('datasets', 'test_results', f"progress_{test_filename}")
    try:
        with open(progress_path, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return None
def save_test_progress(session: Session, test_filename: str, data: Dict[str, Any]) -> models.KnowledgeGraph:
    """
    Merge ``data`` into the ``test_progress`` entry of a knowledge graph's JSON content.

    The knowledge graph named ``test_filename`` is created if missing.
    Undecodable or non-object content is replaced by an empty dict.

    If ``data`` contains a ``progress`` sub-dict, only the known progress
    fields are copied from it, and fields it omits keep their existing
    values. Otherwise every key of ``data`` is copied as-is. ``timestamp`` is
    always set: to ``data['timestamp']`` when present, else to the current
    UTC time in ISO format.

    After committing, the progress dict is also written to
    ``datasets/test_results/progress_<test_filename>`` for backward
    compatibility. Failures in that write are logged, not raised.

    Returns:
        The saved KnowledgeGraph.
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=test_filename).first()
    if not kg:
        kg = models.KnowledgeGraph()
        kg.filename = test_filename
        kg.creation_timestamp = datetime.utcnow()

    try:
        content = json.loads(kg.content) if kg.content else {}
    except json.JSONDecodeError:
        content = {}
    if not isinstance(content, dict):
        content = {}

    # Reset a missing or malformed (non-dict) progress entry instead of crashing.
    progress = content.get('test_progress')
    if not isinstance(progress, dict):
        progress = {}
        content['test_progress'] = progress

    if 'progress' in data:
        incoming = data['progress']
        for field in (
            'status',
            'current',
            'total',
            'last_tested_relation',
            'overall_progress_percentage',
            'current_jailbreak',
        ):
            progress[field] = incoming.get(field, progress.get(field))
    else:
        progress.update(data)

    # A caller-supplied timestamp is stored verbatim; plain assignment cannot
    # fail, so no exception handling is needed here.
    if 'timestamp' in data:
        progress['timestamp'] = data['timestamp']
    else:
        progress['timestamp'] = datetime.utcnow().isoformat()

    kg.content = json.dumps(content)
    session.add(kg)
    session.commit()

    # Also save to progress file for backward compatibility (best effort).
    try:
        progress_dir = 'datasets/test_results'
        os.makedirs(progress_dir, exist_ok=True)
        progress_path = os.path.join(progress_dir, f"progress_{test_filename}")
        with open(progress_path, 'w') as f:
            json.dump(progress, f)
    except Exception as e:
        logger.warning(f"Failed to save progress file: {str(e)}")

    return kg
def get_all_knowledge_graphs(session: Session) -> List[models.KnowledgeGraph]:
    """Return every KnowledgeGraph row; no ordering is applied."""
    every_graph = session.query(models.KnowledgeGraph)
    return every_graph.all()
def get_all_test_results(session: Session) -> List[Dict[str, Any]]:
    """
    Collect the test results embedded in knowledge graph content.

    For every KnowledgeGraph whose ``content`` JSON has a ``test_result`` key,
    the returned list gets one entry. That entry is a shallow copy of the
    value, or an empty dict if the value is not a dict, annotated with the
    graph's ``filename`` and ``id``. Graphs with empty content or undecodable
    JSON are skipped.
    """
    collected = []
    for kg in session.query(models.KnowledgeGraph).all():
        if not kg.content:
            continue
        try:
            payload = json.loads(kg.content)
        except json.JSONDecodeError:
            continue
        if 'test_result' not in payload:
            continue

        raw_result = payload['test_result']
        entry = raw_result.copy() if isinstance(raw_result, dict) else {}
        entry['filename'] = kg.filename
        entry['id'] = kg.id
        collected.append(entry)
    return collected
def get_standard_dataset(session: Session, filename: str) -> Optional[Dict[str, Any]]:
    """
    Load a JSON dataset (e.g. jailbreak techniques) by name.

    Sources are tried in order, and the first that yields valid JSON wins:
    1. The ``content`` column of the knowledge graph with this filename.
    2. ``datasets/<filename>``, ``datasets/test_results/<filename>`` and
       ``datasets/knowledge_graphs/<filename>``, relative to the CWD.
    3. ``filename`` itself as a path (absolute, or relative to the CWD).

    Returns:
        The parsed JSON, or None when every source is missing or invalid.
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if kg and kg.content:
        try:
            return json.loads(kg.content)
        except json.JSONDecodeError:
            pass

    candidate_paths = [
        f"datasets/{filename}",
        f"datasets/test_results/{filename}",
        f"datasets/knowledge_graphs/{filename}",
        filename,
    ]
    for path in candidate_paths:
        try:
            with open(path, 'r') as fh:
                return json.load(fh)
        except (FileNotFoundError, json.JSONDecodeError):
            continue
    return None
def find_entity_by_id(session: Session, entity_id: str) -> Optional[models.Entity]:
    """
    Return the first Entity whose ``entity_id`` column equals ``entity_id``.

    The search is not restricted to a single knowledge graph, and no ordering
    is applied, so the choice among duplicates is unspecified.

    Returns:
        The Entity, or None if not found.
    """
    return session.query(models.Entity).filter_by(entity_id=entity_id).first()
def find_relation_by_id(session: Session, relation_id: str) -> Optional[models.Relation]:
    """
    Return the first Relation whose ``relation_id`` column equals ``relation_id``.

    The search is not restricted to a single knowledge graph, and no ordering
    is applied, so the choice among duplicates is unspecified.

    Returns:
        The Relation, or None if not found.
    """
    return session.query(models.Relation).filter_by(relation_id=relation_id).first()
def find_entities_by_type(session: Session, entity_type: str) -> List[models.Entity]:
    """
    Return every Entity whose ``type`` column equals ``entity_type``.

    The search spans all knowledge graphs; an empty list means no match.
    """
    matching = session.query(models.Entity).filter_by(type=entity_type).all()
    return matching
def find_relations_by_type(session: Session, relation_type: str) -> List[models.Relation]:
    """
    Return every Relation whose ``type`` column equals ``relation_type``.

    The search spans all knowledge graphs; an empty list means no match.
    """
    matching = session.query(models.Relation).filter_by(type=relation_type).all()
    return matching
def merge_knowledge_graphs(session: Session, output_filename: str, input_filenames: List[str]) -> Optional[models.KnowledgeGraph]:
    """
    Merge several stored knowledge graphs into a new one by concatenation.

    Entities and relations are appended in input order. The first occurrence
    of each ``id`` wins, and later duplicates are dropped. Entities without an
    ``id`` and relations missing any of ``id``/``source``/``target`` are
    skipped. Input filenames that do not exist, or graphs with no data, are
    skipped with a warning.

    Args:
        session: Database session
        output_filename: Output filename for the merged knowledge graph
        input_filenames: List of filenames of knowledge graphs to merge

    Returns:
        The merged KnowledgeGraph; the existing graph (unchanged) if
        ``output_filename`` already exists; or None if no input graph exists.
    """
    existing_kg = get_knowledge_graph(session, output_filename)
    if existing_kg:
        logger.warning(f"Knowledge graph {output_filename} already exists. Returning existing graph.")
        return existing_kg

    knowledge_graphs = []
    for filename in input_filenames:
        kg = get_knowledge_graph(session, filename)
        if not kg:
            logger.warning(f"Knowledge graph {filename} not found. Skipping.")
            continue
        knowledge_graphs.append(kg)

    if not knowledge_graphs:
        logger.error("No valid knowledge graphs to merge.")
        return None

    merged_data = {
        "entities": [],
        "relations": [],
        "metadata": {
            # Copy so later mutation of the caller's list can't alter stored metadata.
            "source_graphs": list(input_filenames),
            # `datetime` is the class (``from datetime import datetime``), so
            # ``utcnow`` is called on it directly.
            "creation_time": datetime.utcnow().isoformat(),
            "merge_method": "concatenate"
        }
    }

    # IDs already emitted, to drop duplicates across input graphs.
    entity_ids = set()
    relation_ids = set()

    for kg in knowledge_graphs:
        graph_data = kg.graph_data
        if not graph_data:
            logger.warning(f"Knowledge graph {kg.filename} has no data. Skipping.")
            continue

        for entity in graph_data.get("entities", []):
            if "id" not in entity or entity["id"] in entity_ids:
                continue
            merged_data["entities"].append(entity)
            entity_ids.add(entity["id"])

        for relation in graph_data.get("relations", []):
            if "id" not in relation or "source" not in relation or "target" not in relation:
                continue
            if relation["id"] in relation_ids:
                continue
            merged_data["relations"].append(relation)
            relation_ids.add(relation["id"])

    return save_knowledge_graph(session, output_filename, merged_data)
def get_knowledge_graph_by_id(session, graph_id):
    """
    Resolve a knowledge graph from an ID, a filename, or the keyword "latest".

    Resolution order:
    1. The string "latest" (case-insensitive) returns the row with the
       highest ``id``.
    2. An ``int``, or a ``str`` for which ``str.isdigit()`` is true, is
       looked up by primary key.
    3. Any ``str`` (including a digit string that matched no ID) is looked
       up by ``filename``.

    Args:
        session: Database session
        graph_id: Either an integer ID, a string filename, or "latest"

    Returns:
        KnowledgeGraph object, or None if not found. Any exception is logged
        and also yields None.
    """
    try:
        logger.info(f"Looking up knowledge graph: {graph_id} (type: {type(graph_id)})")
        is_text = isinstance(graph_id, str)

        if is_text and graph_id.lower() == "latest":
            logger.info("Handling 'latest' special case")
            newest = session.query(models.KnowledgeGraph).order_by(models.KnowledgeGraph.id.desc()).first()
            if not newest:
                logger.warning("No knowledge graphs found in database")
                return None
            logger.info(f"Found latest knowledge graph with ID {newest.id} and filename {newest.filename}")
            return newest

        looks_numeric = isinstance(graph_id, int) or (is_text and graph_id.isdigit())
        if looks_numeric:
            numeric_id = int(graph_id)
            logger.info(f"Looking up knowledge graph by ID: {numeric_id}")
            by_id = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.id == numeric_id).first()
            if by_id:
                logger.info(f"Found knowledge graph by ID {numeric_id}: {by_id.filename}")
                return by_id
            logger.warning(f"Knowledge graph with ID {numeric_id} not found")

        if is_text:
            logger.info(f"Looking up knowledge graph by filename: {graph_id}")
            by_name = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.filename == graph_id).first()
            if by_name:
                logger.info(f"Found knowledge graph by filename {graph_id}: ID {by_name.id}")
                return by_name
            logger.warning(f"Knowledge graph with filename {graph_id} not found")

        logger.error(f"Knowledge graph not found: {graph_id}")
        return None
    except Exception as e:
        logger.error(f"Error retrieving knowledge graph by ID: {str(e)}")
        return None
def update_knowledge_graph(session: Session, filename: str, graph_data: dict) -> models.KnowledgeGraph:
    """
    Replace a knowledge graph's data, creating the graph if it does not exist.

    When no graph has ``filename``, this delegates to ``save_knowledge_graph``.
    Otherwise it:
    - assigns ``graph_data``;
    - refreshes ``entity_count`` / ``relation_count`` from the ``entities`` /
      ``relations`` entries when those are lists;
    - bumps ``update_timestamp`` and commits.

    NOTE(review): this assigns ``graph_data`` while ``save_knowledge_graph``
    assigns ``graph_content``; confirm the model treats them equivalently.

    Args:
        session: Database session
        filename: Filename of the knowledge graph to update
        graph_data: New graph data

    Returns:
        Updated (or newly created) KnowledgeGraph instance
    """
    kg = get_knowledge_graph(session, filename)
    if not kg:
        logger.info(f"Knowledge graph {filename} not found. Creating a new one.")
        return save_knowledge_graph(session, filename, graph_data)

    kg.graph_data = graph_data

    if isinstance(graph_data, dict):
        entity_list = graph_data.get('entities')
        if isinstance(entity_list, list):
            kg.entity_count = len(entity_list)
        relation_list = graph_data.get('relations')
        if isinstance(relation_list, list):
            kg.relation_count = len(relation_list)

    kg.update_timestamp = datetime.utcnow()
    session.add(kg)
    session.commit()
    logger.info(f"Updated knowledge graph {filename}")
    return kg
def delete_knowledge_graph(session: Session, identifier: Union[int, str]) -> bool:
    """
    Delete a knowledge graph together with its entities and relations.

    Args:
        session: Database session
        identifier: A ``str`` is matched against ``filename``; anything else
            against the primary key ``id``.

    Returns:
        True once the deletion is committed. False if no graph matches, or if
        any error occurs; in that case the session is rolled back and the
        error logged.
    """
    try:
        column = 'filename' if isinstance(identifier, str) else 'id'
        kg = session.query(models.KnowledgeGraph).filter_by(**{column: identifier}).first()
        if not kg:
            logger.warning(f"Knowledge graph with identifier {identifier} not found")
            return False

        kg_id, filename = kg.id, kg.filename
        entity_count = session.query(models.Entity).filter_by(graph_id=kg_id).count()
        relation_count = session.query(models.Relation).filter_by(graph_id=kg_id).count()
        logger.info(f"Deleting knowledge graph {filename} (ID: {kg_id}) with {entity_count} entities and {relation_count} relations")

        # Remove child rows explicitly: relations first (they presumably
        # reference entities via foreign keys), then entities, then the graph.
        for child_model in (models.Relation, models.Entity):
            session.query(child_model).filter_by(graph_id=kg_id).delete()
        session.delete(kg)
        session.commit()

        logger.info(f"Successfully deleted knowledge graph {filename} (ID: {kg_id}) and its associated data")
        return True
    except Exception as e:
        session.rollback()
        logger.error(f"Error deleting knowledge graph: {str(e)}")
        return False
def get_trace(session: Session, trace_id: str) -> Optional[models.Trace]:
    """
    Resolve a trace by its ``trace_id``, its ``filename``, or its integer ``id``.

    The lookups run in that order, and the first hit is returned. The
    integer lookup is attempted only when ``int(trace_id)`` succeeds.

    Args:
        session: Database session
        trace_id: A UUID trace_id, a filename, or a numeric primary key

    Returns:
        Trace object or None if not found
    """
    for column in ('trace_id', 'filename'):
        hit = session.query(models.Trace).filter_by(**{column: trace_id}).first()
        if hit:
            return hit

    try:
        numeric_id = int(trace_id)
        hit = session.query(models.Trace).filter_by(id=numeric_id).first()
    except (ValueError, TypeError):
        hit = None
    return hit if hit else None
def save_trace(
    session: Session,
    content: str,
    filename: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    trace_type: Optional[str] = None,
    trace_source: str = "user_upload",
    uploader: Optional[str] = None,
    tags: Optional[List[str]] = None,
    trace_metadata: Optional[Dict[str, Any]] = None
) -> models.Trace:
    """
    Save a trace, deduplicating by the SHA-256 hash of its content.

    If a trace with the same content hash exists, each truthy argument
    overwrites the corresponding field, ``trace_metadata`` is merged into the
    existing metadata (new keys win), and ``update_timestamp`` is bumped.
    ``trace_source`` is not applied to an existing trace. Otherwise a new
    trace is created via ``models.Trace.from_content``. Commits either way.

    Args:
        session: Database session
        content: The content of the trace
        filename: Optional filename
        title: Optional title
        description: Optional description
        trace_type: Optional type of trace
        trace_source: Source of the trace (default: "user_upload")
        uploader: Optional name of the uploader
        tags: Optional list of tags
        trace_metadata: Optional additional metadata

    Returns:
        The created or updated Trace object
    """
    content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()

    existing_trace = session.query(models.Trace).filter_by(content_hash=content_hash).first()
    if existing_trace:
        logger.info(f"Trace with matching content hash already exists (ID: {existing_trace.id})")

        if filename:
            existing_trace.filename = filename
        if title:
            existing_trace.title = title
        if description:
            existing_trace.description = description
        if trace_type:
            existing_trace.trace_type = trace_type
        if uploader:
            existing_trace.uploader = uploader
        if tags:
            existing_trace.tags = tags
        if trace_metadata:
            # Assign a new dict rather than calling .update() in place. Plain
            # SQLAlchemy JSON columns (without MutableDict) do not track
            # in-place mutation, so the merged metadata could silently fail
            # to be persisted.
            merged_metadata = dict(existing_trace.trace_metadata or {})
            merged_metadata.update(trace_metadata)
            existing_trace.trace_metadata = merged_metadata

        existing_trace.update_timestamp = datetime.utcnow()
        session.add(existing_trace)
        session.commit()
        return existing_trace

    trace = models.Trace.from_content(
        content=content,
        filename=filename,
        title=title,
        description=description,
        trace_type=trace_type,
        trace_source=trace_source,
        uploader=uploader,
        tags=tags,
        trace_metadata=trace_metadata
    )
    session.add(trace)
    session.commit()
    logger.info(f"New trace saved to database (ID: {trace.id}, trace_id: {trace.trace_id})")
    return trace
def get_all_traces(session: Session) -> List[models.Trace]:
    """
    List every trace, newest upload first.

    Returns:
        Trace objects ordered by ``upload_timestamp`` descending.
    """
    newest_first = models.Trace.upload_timestamp.desc()
    return session.query(models.Trace).order_by(newest_first).all()
def get_traces_by_status(session: Session, status: str) -> List[models.Trace]:
    """
    List traces whose ``status`` equals ``status``, newest upload first.

    Returns:
        Matching Trace objects ordered by ``upload_timestamp`` descending.
    """
    newest_first = models.Trace.upload_timestamp.desc()
    matching = session.query(models.Trace).filter_by(status=status)
    return matching.order_by(newest_first).all()
def update_trace_status(session: Session, trace_id: str, status: str) -> models.Trace:
    """
    Set a trace's ``status``, bump its ``update_timestamp``, and commit.

    Args:
        session: Database session
        trace_id: Any identifier accepted by ``get_trace`` (trace_id,
            filename, or numeric id)
        status: New status

    Returns:
        Updated Trace object

    Raises:
        ValueError: If no trace matches ``trace_id``.
    """
    trace = get_trace(session, trace_id)
    if not trace:
        raise ValueError(f"Trace with ID {trace_id} not found")

    trace.status = status
    trace.update_timestamp = datetime.utcnow()
    session.add(trace)
    session.commit()
    return trace
def update_trace_content(session: Session, trace_id: str, content: str) -> models.Trace:
    """
    Replace a trace's content and recompute its derived counters.

    ``character_count`` becomes ``len(content)``. ``turn_count`` becomes the
    number of lines (split on ``'\\n'``) that are not blank after stripping.
    Also bumps ``update_timestamp`` and commits.

    Args:
        session: Database session
        trace_id: Any identifier accepted by ``get_trace``
        content: New content value

    Returns:
        Updated Trace object

    Raises:
        ValueError: If no trace matches ``trace_id``.
    """
    trace = get_trace(session, trace_id)
    if not trace:
        raise ValueError(f"Trace with ID {trace_id} not found")

    trace.content = content
    trace.character_count = len(content)
    trace.turn_count = sum(1 for line in content.split('\n') if line.strip())
    trace.update_timestamp = datetime.utcnow()

    session.add(trace)
    session.commit()
    return trace
def link_knowledge_graph_to_trace(
    session: Session,
    kg_id: Union[int, str],
    trace_id: str,
    window_index: Optional[int] = None,
    window_total: Optional[int] = None,
    window_start_char: Optional[int] = None,
    window_end_char: Optional[int] = None
) -> models.KnowledgeGraph:
    """
    Link a knowledge graph to a trace and record the link in the graph's metadata.

    Sets ``kg.trace_id`` to the trace's ``trace_id``. Each window field is
    overwritten only when a non-None value is passed. Also writes
    ``graph_data["metadata"]["trace_info"]``, which holds the trace_id, the
    window_index/window_total arguments as passed, and a UTC link timestamp.
    Commits the session.

    Args:
        session: Database session
        kg_id: ID, filename, or "latest" (anything ``get_knowledge_graph_by_id`` accepts)
        trace_id: Any identifier accepted by ``get_trace``
        window_index: Optional index of the window within the trace
        window_total: Optional total number of windows
        window_start_char: Optional start position in the trace
        window_end_char: Optional end position in the trace

    Returns:
        Updated KnowledgeGraph object

    Raises:
        ValueError: If the knowledge graph or the trace is not found.
    """
    kg = get_knowledge_graph_by_id(session, kg_id)
    if not kg:
        raise ValueError(f"Knowledge graph with ID {kg_id} not found")

    trace = get_trace(session, trace_id)
    if not trace:
        raise ValueError(f"Trace with ID {trace_id} not found")

    kg.trace_id = trace.trace_id
    if window_index is not None:
        kg.window_index = window_index
    if window_total is not None:
        kg.window_total = window_total
    if window_start_char is not None:
        kg.window_start_char = window_start_char
    if window_end_char is not None:
        kg.window_end_char = window_end_char

    # Work on copies and assign a new object back. If graph_data is a plain
    # JSON column, mutating it in place and re-assigning the same object is
    # not detected as a change. The copy also tolerates "metadata": None.
    graph_data = dict(kg.graph_data or {})
    metadata = dict(graph_data.get("metadata") or {})
    metadata["trace_info"] = {
        "trace_id": trace.trace_id,
        "window_index": window_index,
        "window_total": window_total,
        "linked_at": datetime.utcnow().isoformat()
    }
    graph_data["metadata"] = metadata
    kg.graph_data = graph_data

    session.add(kg)
    session.commit()
    return kg
def get_knowledge_graphs_for_trace(session: Session, trace_id: str) -> List[models.KnowledgeGraph]:
    """
    List the knowledge graphs linked to a trace, ordered by ``window_index``.

    Args:
        session: Database session
        trace_id: Any identifier accepted by ``get_trace``; matching uses the
            resolved trace's ``trace_id``.

    Returns:
        KnowledgeGraph objects linked to the trace.

    Raises:
        ValueError: If the trace is not found.
    """
    trace = get_trace(session, trace_id)
    if not trace:
        raise ValueError(f"Trace with ID {trace_id} not found")

    linked = session.query(models.KnowledgeGraph).filter_by(trace_id=trace.trace_id)
    return linked.order_by(models.KnowledgeGraph.window_index).all()
def check_knowledge_graph_exists(session: Session, trace_id: str, is_original: Optional[bool] = None) -> Optional[models.KnowledgeGraph]:
    """
    Check if a knowledge graph exists for a trace with specific criteria.

    Args:
        session: Database session
        trace_id: Value matched directly against ``KnowledgeGraph.trace_id``
        is_original: If True, only match knowledge graphs with
            status='created'. If False, only match graphs whose status is
            anything else, including NULL/unset. If None, any status matches.

    Returns:
        The first matching KnowledgeGraph, or None.
    """
    from sqlalchemy import or_

    status_col = models.KnowledgeGraph.status
    query = session.query(models.KnowledgeGraph).filter_by(trace_id=trace_id)
    if is_original is True:
        # Original KGs have status 'created'
        query = query.filter(status_col == 'created')
    elif is_original is False:
        # In SQL, `NULL != 'created'` evaluates to NULL, not true, so a bare
        # `!=` would drop KGs whose status was never set. save_knowledge_graph
        # stores non-original KGs with status=None, so include NULL explicitly.
        query = query.filter(or_(status_col != 'created', status_col.is_(None)))
    return query.first()
def delete_trace(session: Session, trace_id: str, delete_related_kgs: bool = False) -> bool:
    """
    Remove a trace, either deleting or detaching its knowledge graphs.

    Args:
        session: Database session
        trace_id: Any identifier accepted by ``get_trace``
        delete_related_kgs: If True, each graph in ``trace.knowledge_graphs``
            is deleted; otherwise each has its ``trace_id`` cleared.

    Returns:
        True after committing; False if no trace matches.
    """
    trace = get_trace(session, trace_id)
    if not trace:
        return False

    for kg in trace.knowledge_graphs:
        if delete_related_kgs:
            session.delete(kg)
        else:
            kg.trace_id = None
            session.add(kg)

    session.delete(trace)
    session.commit()
    return True
def get_prompt_reconstructions_for_kg(session, kg_identifier):
    """
    Map each relation_id to its reconstructed prompt for one knowledge graph.

    ``kg_identifier`` is treated as the primary key when it is an ``int``,
    and as a filename otherwise. If several rows share a relation_id, the
    last one returned by the query wins.

    Returns:
        ``{relation_id: reconstructed_prompt}``, or ``{}`` if the graph is not found.
    """
    from backend.database.models import KnowledgeGraph, PromptReconstruction

    lookup = {'id': kg_identifier} if isinstance(kg_identifier, int) else {'filename': kg_identifier}
    kg = session.query(KnowledgeGraph).filter_by(**lookup).first()
    if not kg:
        return {}

    mapping = {}
    for pr in session.query(PromptReconstruction).filter_by(knowledge_graph_id=kg.id).all():
        mapping[pr.relation_id] = pr.reconstructed_prompt
    return mapping
def get_prompt_reconstruction_for_relation(session, kg_identifier, relation_id):
    """
    Fetch the reconstructed prompt for one relation of one knowledge graph.

    ``kg_identifier`` is treated as the primary key when it is an ``int``,
    and as a filename otherwise.

    Returns:
        The ``reconstructed_prompt`` value, or None if either the graph or
        the reconstruction is not found.
    """
    from backend.database.models import KnowledgeGraph, PromptReconstruction

    key = 'id' if isinstance(kg_identifier, int) else 'filename'
    kg = session.query(KnowledgeGraph).filter_by(**{key: kg_identifier}).first()
    if not kg:
        return None

    match = (
        session.query(PromptReconstruction)
        .filter_by(knowledge_graph_id=kg.id, relation_id=relation_id)
        .first()
    )
    if match:
        return match.reconstructed_prompt
    return None
def save_causal_analysis(
    session: Session,
    knowledge_graph_id: int,
    perturbation_set_id: str,
    analysis_method: str,
    analysis_result: dict = None,
    causal_score: float = None,
    analysis_metadata: dict = None
):
    """
    Insert a new CausalAnalysis row, commit it, and return it refreshed.

    A new row is always created; no existing analysis with the same keys is
    looked up or replaced.
    """
    from backend.database import models

    record_fields = dict(
        knowledge_graph_id=knowledge_graph_id,
        perturbation_set_id=perturbation_set_id,
        analysis_method=analysis_method,
        analysis_result=analysis_result,
        causal_score=causal_score,
        analysis_metadata=analysis_metadata,
    )
    causal_analysis = models.CausalAnalysis(**record_fields)

    session.add(causal_analysis)
    session.commit()
    session.refresh(causal_analysis)  # reload the row's attributes from the database after commit
    return causal_analysis
def get_causal_analysis(
    session: Session,
    knowledge_graph_id: int,
    perturbation_set_id: str,
    analysis_method: str
) -> Optional[models.CausalAnalysis]:
    """
    Fetch the first CausalAnalysis matching all three keys.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph
        perturbation_set_id: ID of the perturbation set
        analysis_method: Method used for analysis

    Returns:
        CausalAnalysis object or None if not found
    """
    criteria = {
        'knowledge_graph_id': knowledge_graph_id,
        'perturbation_set_id': perturbation_set_id,
        'analysis_method': analysis_method,
    }
    return session.query(models.CausalAnalysis).filter_by(**criteria).first()
def get_all_causal_analyses(
    session: Session,
    knowledge_graph_id: Optional[int] = None,
    perturbation_set_id: Optional[str] = None,
    analysis_method: Optional[str] = None
) -> List[models.CausalAnalysis]:
    """
    List causal analyses, optionally narrowed by any combination of filters.

    A filter argument left as None is ignored; all supplied filters are
    combined with AND.

    Args:
        session: Database session
        knowledge_graph_id: Optional knowledge graph ID to match
        perturbation_set_id: Optional perturbation set ID to match
        analysis_method: Optional analysis method to match

    Returns:
        List of matching CausalAnalysis objects
    """
    candidate_filters = (
        ("knowledge_graph_id", knowledge_graph_id),
        ("perturbation_set_id", perturbation_set_id),
        ("analysis_method", analysis_method),
    )
    active_filters = {
        column: value for column, value in candidate_filters if value is not None
    }

    query = session.query(models.CausalAnalysis)
    if active_filters:
        query = query.filter_by(**active_filters)
    return query.all()
def get_causal_analysis_for_perturbation(session: Session, perturbation_set_id: str) -> List[Dict[str, Any]]:
    """
    Collect causal analyses for a perturbation set, joined with their graph and prompt.

    Only analyses that have both a matching knowledge graph and a matching
    prompt reconstruction are returned, because both joins are inner joins.

    Args:
        session: Database session
        perturbation_set_id: ID of the perturbation set

    Returns:
        One dict per analysis with keys ``analysis`` and ``knowledge_graph``
        (each the row's ``to_dict()``), plus ``prompt_reconstruction`` (a dict
        of id, relation_id, reconstructed_prompt and dependencies).
    """
    from backend.database.models import CausalAnalysis, KnowledgeGraph, PromptReconstruction

    # NOTE(review): save_causal_analysis in this module never sets
    # prompt_reconstruction_id, so analyses saved through it presumably have
    # NULL here and are dropped by the inner join. TODO confirm against the model.
    rows = (
        session.query(CausalAnalysis, KnowledgeGraph, PromptReconstruction)
        .join(KnowledgeGraph, CausalAnalysis.knowledge_graph_id == KnowledgeGraph.id)
        .join(PromptReconstruction, CausalAnalysis.prompt_reconstruction_id == PromptReconstruction.id)
        .filter(CausalAnalysis.perturbation_set_id == perturbation_set_id)
        .all()
    )

    combined = []
    for analysis_row, graph_row, prompt_row in rows:
        combined.append({
            'analysis': analysis_row.to_dict(),
            'knowledge_graph': graph_row.to_dict(),
            'prompt_reconstruction': {
                'id': prompt_row.id,
                'relation_id': prompt_row.relation_id,
                'reconstructed_prompt': prompt_row.reconstructed_prompt,
                'dependencies': prompt_row.dependencies,
            },
        })
    return combined
def get_causal_analysis_by_method(session: Session, knowledge_graph_id: int, method: str) -> List[Dict[str, Any]]:
    """
    Collect one knowledge graph's causal analyses for a given method, with their perturbation tests.

    Analyses without a matching perturbation test are excluded, because the
    join is an inner join.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph
        method: Analysis method name (e.g., 'graph', 'component', 'dowhy')

    Returns:
        One dict per analysis with ``analysis`` (the row's ``to_dict()``) and
        ``perturbation_test`` (a dict of id, perturbation_type, test_result and
        perturbation_score).
    """
    from backend.database.models import CausalAnalysis, PerturbationTest

    # NOTE(review): save_causal_analysis in this module never sets
    # perturbation_test_id. Rows saved through it presumably cannot match this
    # join. TODO confirm against the model.
    rows = (
        session.query(CausalAnalysis, PerturbationTest)
        .join(PerturbationTest, CausalAnalysis.perturbation_test_id == PerturbationTest.id)
        .filter(
            CausalAnalysis.knowledge_graph_id == knowledge_graph_id,
            CausalAnalysis.analysis_method == method,
        )
        .all()
    )

    combined = []
    for analysis_row, test_row in rows:
        combined.append({
            'analysis': analysis_row.to_dict(),
            'perturbation_test': {
                'id': test_row.id,
                'perturbation_type': test_row.perturbation_type,
                'test_result': test_row.test_result,
                'perturbation_score': test_row.perturbation_score,
            },
        })
    return combined
def get_causal_analysis_summary(session: Session, knowledge_graph_id: int) -> Dict[str, Any]:
    """
    Summarize the causal analysis results stored for a knowledge graph.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph

    Returns:
        Dictionary with:
            total_analyses: number of CausalAnalysis rows for the graph
            methods: per analysis_method, a dict containing ``count`` plus
                ``average_score`` / ``min_score`` / ``max_score`` computed
                over the non-null ``causal_score`` values. These are None
                when the method has no scores.
            average_scores: mapping of analysis_method -> average causal
                score. Only methods with at least one non-null score appear.
    """
    from collections import defaultdict

    from backend.database.models import CausalAnalysis

    analyses = session.query(CausalAnalysis).filter_by(
        knowledge_graph_id=knowledge_graph_id
    ).all()

    summary = {
        'total_analyses': len(analyses),
        'methods': {},
        'average_scores': {}
    }
    if not analyses:
        return summary

    # Group analyses by the method that produced them.
    by_method = defaultdict(list)
    for analysis in analyses:
        by_method[analysis.analysis_method].append(analysis)

    for method, results in by_method.items():
        scores = [r.causal_score for r in results if r.causal_score is not None]
        average = sum(scores) / len(scores) if scores else None
        summary['methods'][method] = {
            'count': len(results),
            'average_score': average,
            'min_score': min(scores) if scores else None,
            'max_score': max(scores) if scores else None
        }
        # 'average_scores' is part of the returned shape but was previously
        # never filled in. Expose the per-method average where one exists.
        if average is not None:
            summary['average_scores'][method] = average

    return summary
# Perturbation test persistence helpers (save / delete)
def save_perturbation_test(session,
                           knowledge_graph_id: int,
                           prompt_reconstruction_id: int,
                           relation_id: str,
                           perturbation_type: str,
                           perturbation_set_id: str,
                           test_result: Optional[dict] = None,
                           perturbation_score: Optional[float] = None,
                           test_metadata: Optional[dict] = None) -> int:
    """
    Insert a perturbation test row, commit it, and return its ID.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph
        prompt_reconstruction_id: ID of the prompt reconstruction
        relation_id: ID of the relation that was perturbed
        perturbation_type: Type of perturbation applied
        perturbation_set_id: ID of the perturbation set this test belongs to
        test_result: Test result dictionary. A falsy value is stored as {}
        perturbation_score: Perturbation score
        test_metadata: Test metadata dictionary. A falsy value is stored as {}

    Returns:
        int: ID of the newly saved perturbation test
    """
    from backend.database.models import PerturbationTest

    row = PerturbationTest(
        knowledge_graph_id=knowledge_graph_id,
        prompt_reconstruction_id=prompt_reconstruction_id,
        relation_id=relation_id,
        perturbation_type=perturbation_type,
        perturbation_set_id=perturbation_set_id,
        # Never persist None for the JSON-like payload fields.
        test_result=test_result if test_result else {},
        perturbation_score=perturbation_score,
        test_metadata=test_metadata if test_metadata else {},
    )

    session.add(row)
    session.commit()
    return row.id
def delete_perturbation_test(session, test_id: int) -> bool:
    """
    Remove a single perturbation test by ID and commit.

    Args:
        session: Database session
        test_id: ID of the perturbation test to delete

    Returns:
        bool: True if a test was found and deleted, False if no such test exists
    """
    from backend.database.models import PerturbationTest

    target = session.query(PerturbationTest).filter_by(id=test_id).first()
    if not target:
        return False

    session.delete(target)
    session.commit()
    return True
def delete_perturbation_tests_by_set(session, perturbation_set_id: str) -> int:
    """
    Remove every perturbation test belonging to a perturbation set.

    Rows are deleted one by one through the session (not as a bulk query),
    and a single commit is issued afterwards, even when nothing matched.

    Args:
        session: Database session
        perturbation_set_id: ID of the perturbation set

    Returns:
        int: Number of tests deleted
    """
    from backend.database.models import PerturbationTest

    doomed = (
        session.query(PerturbationTest)
        .filter_by(perturbation_set_id=perturbation_set_id)
        .all()
    )
    for row in doomed:
        session.delete(row)

    session.commit()
    return len(doomed)
def get_context_document_stats(session: Session, trace_id: str) -> Dict[str, Any]:
    """
    Count the context documents stored in a trace's metadata.

    Documents without a "document_type" are counted under "unknown".
    Documents without an "is_active" flag count as active.

    Args:
        session: Database session
        trace_id: Trace ID to get stats for

    Returns:
        Dictionary with "total_count", "active_count" and "by_type"
        (document type -> count). All counts are zero when the trace or its
        context documents are missing.
    """
    trace = get_trace(session, trace_id)
    metadata = trace.trace_metadata if trace else None
    if not metadata or "context_documents" not in metadata:
        return {"total_count": 0, "active_count": 0, "by_type": {}}

    documents = metadata["context_documents"]

    by_type = {}
    for document in documents:
        doc_type = document.get("document_type", "unknown")
        by_type[doc_type] = by_type.get(doc_type, 0) + 1

    active_count = sum(1 for document in documents if document.get("is_active", True))

    return {
        "total_count": len(documents),
        "active_count": active_count,
        "by_type": by_type
    }
def get_context_documents_from_trace(session: Session, trace_id: str) -> List[Dict[str, Any]]:
    """
    Return the active context documents stored in a trace's metadata.

    A document counts as active unless its "is_active" value is falsy.
    A missing flag counts as active.

    Args:
        session: Database session
        trace_id: ID of the trace

    Returns:
        List of active context documents, or an empty list if the trace or its
        context documents are missing
    """
    trace = get_trace(session, trace_id)
    metadata = trace.trace_metadata if trace else None
    if not metadata or "context_documents" not in metadata:
        return []

    return list(
        filter(lambda document: document.get("is_active", True), metadata["context_documents"])
    )
def _build_temporal_trace_info(trace, trace_id: str, kgs: List[Any]) -> Dict[str, Any]:
    """Describe the trace for the temporal view, synthesizing the info from *kgs* when no trace record exists."""
    if not trace:
        logger.warning(f"No trace record found for trace_id {trace_id}, but {len(kgs)} KGs exist")
        timestamps = [kg.creation_timestamp for kg in kgs if kg.creation_timestamp]
        return {
            "trace_id": trace_id,
            "title": f"Trace {trace_id[:8]}...",
            "description": "Knowledge graphs exist but no trace record found",
            # default=None: a bare min() raises ValueError when no KG has a timestamp.
            "upload_timestamp": min(timestamps, default=None)
        }

    logger.info(f"Found trace record {trace.trace_id}")
    return {
        "trace_id": trace.trace_id,
        "title": trace.title,
        "description": trace.description,
        "upload_timestamp": trace.upload_timestamp
    }


def _backfill_kg_counts(session: Session, kgs: List[Any]) -> int:
    """Fill missing/zero entity_count and relation_count from graph_data. Return how many KGs changed (no commit)."""
    updated_count = 0
    for kg in kgs:
        if not kg or not kg.graph_data:
            continue
        needs_update = False
        if kg.entity_count is None or kg.entity_count == 0:
            entities = kg.graph_data.get("entities", [])
            if entities:
                kg.entity_count = len(entities)
                needs_update = True
        if kg.relation_count is None or kg.relation_count == 0:
            relations = kg.graph_data.get("relations", [])
            if relations:
                kg.relation_count = len(relations)
                needs_update = True
        if needs_update:
            session.add(kg)
            updated_count += 1
    return updated_count


def get_temporal_windows_by_trace_id(session: Session, trace_id: str, processing_run_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Get all knowledge graph windows for a trace, ordered by window_index, plus the full/merged KG.

    Used for temporal force-directed graph visualization. Works even when KGs
    reference ``trace_id`` but no trace record exists. In that case a minimal
    ``trace_info`` is synthesized from the KGs.

    Side effects: this function repairs KG rows in place and commits them when
    anything changed.
      - A KG without window info and with more than 10 entities may be
        promoted to the full KG (its ``window_total`` is set).
      - KGs lacking a ``window_index`` are given one after the existing
        windows, ordered by ``creation_timestamp``.
      - Zero/None ``entity_count`` / ``relation_count`` are backfilled from
        ``graph_data``.

    Args:
        session: Database session
        trace_id: ID of the trace
        processing_run_id: Optional ID to filter by specific processing run

    Returns:
        Dict with "windows" (windowed KGs sorted by window_index), "full_kg"
        (the full/merged KG or None) and "trace_info" (dict, or None when no
        KGs exist for the trace).
    """
    logger.info(f"Looking up temporal windows for trace_id: {trace_id}")
    if processing_run_id:
        logger.info(f"Filtering by processing_run_id: {processing_run_id}")

    trace = get_trace(session, trace_id)

    # Query KGs by trace_id directly, so this works even without a trace record.
    query = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.trace_id == trace_id)
    if processing_run_id:
        query = query.filter(models.KnowledgeGraph.processing_run_id == processing_run_id)
    all_kgs = query.all()

    logger.info(f"Found {len(all_kgs)} total knowledge graphs for trace_id {trace_id}")
    if processing_run_id:
        logger.info(f"(filtered by processing_run_id: {processing_run_id})")

    if not all_kgs:
        logger.warning(f"No knowledge graphs found for trace_id: {trace_id}")
        return {"windows": [], "full_kg": None, "trace_info": None}

    trace_info = _build_temporal_trace_info(trace, trace_id, all_kgs)

    windowed_kgs = []
    full_kg = None
    # Identities of every KG classified as full/merged. None of them may later
    # be demoted to a window, even if another full KG is the one returned.
    full_kg_ids = set()
    needs_commit = False

    for kg in all_kgs:
        # Full/merged KG: has window_total but no window_index, null start/end chars
        if (kg.window_total is not None and
                kg.window_index is None and
                kg.window_start_char is None and
                kg.window_end_char is None):
            if full_kg is not None:
                logger.warning(f"Multiple full/merged KGs for trace_id {trace_id}; using {kg.filename}")
            full_kg = kg
            full_kg_ids.add(id(kg))
            logger.info(f"Found full/merged KG: {kg.filename} with window_total={kg.window_total}")
        # Windowed KG: has window_index
        elif kg.window_index is not None:
            windowed_kgs.append(kg)
        # KG without proper window info
        else:
            logger.info(f"Found KG without proper window info: {kg.filename}")
            if (full_kg is None and
                    kg.window_total is None and
                    kg.window_start_char is None and
                    kg.window_end_char is None):
                # Heuristic: a KG with many entities is assumed to be the merged one.
                if kg.graph_data and len(kg.graph_data.get("entities", [])) > 10:
                    kg.window_total = len(windowed_kgs) + 1  # based on windows seen so far
                    full_kg = kg
                    full_kg_ids.add(id(kg))
                    session.add(kg)
                    needs_commit = True
                    logger.info(f"Assigned {kg.filename} as full KG with window_total={kg.window_total}")

    # Give remaining non-full KGs a window_index after the existing windows.
    kgs_without_index = [
        kg for kg in all_kgs
        if kg.window_index is None and id(kg) not in full_kg_ids
    ]
    if kgs_without_index and windowed_kgs:
        logger.info("Assigning window_index to knowledge graphs based on creation order")
        max_window_index = max(kg.window_index for kg in windowed_kgs)
        # KGs without a timestamp sort as "now", i.e. after timestamped ones.
        kgs_without_index.sort(key=lambda kg: kg.creation_timestamp or datetime.utcnow())
        for next_index, kg in enumerate(kgs_without_index, start=max_window_index + 1):
            kg.window_index = next_index
            session.add(kg)
            windowed_kgs.append(kg)
        needs_commit = True
        logger.info(f"Assigned window_index to {len(kgs_without_index)} knowledge graphs")

    windowed_kgs.sort(key=lambda kg: kg.window_index)
    logger.info(f"Found {len(windowed_kgs)} windowed KGs and {'1' if full_kg else '0'} full KG for trace_id {trace_id}")

    updated_count = _backfill_kg_counts(session, windowed_kgs + ([full_kg] if full_kg else []))
    if updated_count > 0:
        needs_commit = True
        logger.info(f"Updated entity/relation counts for {updated_count} knowledge graphs")

    # Single commit covering every repair above, including a heuristic
    # full-KG promotion, which was previously left uncommitted when nothing
    # else changed.
    if needs_commit:
        session.commit()

    return {
        "windows": windowed_kgs,
        "full_kg": full_kg,
        "trace_info": trace_info
    }