""" Utility functions for database operations. """ import os import json import logging import uuid from datetime import datetime from typing import Dict, List, Any, Optional, Union import hashlib from sqlalchemy.orm import Session from sqlalchemy import func from . import models from . import get_db, init_db logger = logging.getLogger(__name__) def initialize_database(clear_all=False): """ Initialize the database and create tables. Args: clear_all: If True, drops all existing tables before creating new ones """ if clear_all: from . import reinit_db reinit_db() logger.info("Database reinitialized (all previous data cleared)") else: from . import init_db init_db() logger.info("Database initialized (existing data preserved)") def get_knowledge_graph(session: Session, filename: str) -> Optional[models.KnowledgeGraph]: """Get a knowledge graph by filename.""" return session.query(models.KnowledgeGraph).filter_by(filename=filename).first() def save_knowledge_graph( session, filename, graph_data, trace_id=None, window_index=None, window_total=None, window_start_char=None, window_end_char=None, is_original=False, processing_run_id=None ): """ Save a knowledge graph to the database. 
Args: session: Database session filename: Filename to save under graph_data: Knowledge graph data trace_id: Optional ID to group knowledge graphs from the same trace window_index: Optional sequential index of window within a trace window_total: Optional total number of windows in the trace window_start_char: Optional starting character position in the original trace window_end_char: Optional ending character position in the original trace is_original: Whether this is an original knowledge graph (sets status to "created") processing_run_id: Optional ID to distinguish multiple processing runs Returns: The created KnowledgeGraph object """ from backend.database.models import KnowledgeGraph # Check if the knowledge graph already exists kg = session.query(KnowledgeGraph).filter(KnowledgeGraph.filename == filename).first() if kg: # Update the existing knowledge graph using graph_content to ensure counts are updated kg.graph_content = graph_data kg.update_timestamp = datetime.utcnow() # Update trace information if provided if trace_id is not None: kg.trace_id = trace_id if window_index is not None: kg.window_index = window_index if window_total is not None: kg.window_total = window_total if window_start_char is not None: kg.window_start_char = window_start_char if window_end_char is not None: kg.window_end_char = window_end_char if processing_run_id is not None: kg.processing_run_id = processing_run_id # Set status if is_original is True if is_original: kg.status = "created" session.add(kg) session.commit() return kg else: # Create a new knowledge graph kg = KnowledgeGraph( filename=filename, trace_id=trace_id, window_index=window_index, window_total=window_total, window_start_char=window_start_char, window_end_char=window_end_char, status="created" if is_original else None, processing_run_id=processing_run_id ) # Set graph content after creation to ensure counts are updated kg.graph_content = graph_data session.add(kg) session.commit() return kg def 
def update_knowledge_graph_status(session: Session, kg_id: Union[int, str], status: str) -> models.KnowledgeGraph:
    """
    Update the status of a knowledge graph.

    Args:
        session: Database session
        kg_id: Knowledge graph ID or filename
        status: New status (created, enriched, perturbed, causal)

    Returns:
        Updated knowledge graph

    Raises:
        ValueError: If no knowledge graph matches the given ID/filename.
    """
    # A string identifier is treated as a filename, anything else as a primary key.
    if isinstance(kg_id, str):
        kg = session.query(models.KnowledgeGraph).filter_by(filename=kg_id).first()
    else:
        kg = session.query(models.KnowledgeGraph).filter_by(id=kg_id).first()

    if not kg:
        raise ValueError(f"Knowledge graph with ID/filename {kg_id} not found")

    kg.status = status
    session.commit()
    return kg


def extract_entities_and_relations(session: Session, kg: models.KnowledgeGraph):
    """Extract entities and relations from a knowledge graph and save them to the database."""
    data = kg.graph_data
    if not data:
        # Nothing to extract.
        return

    # First, delete existing relations and entities for this knowledge graph.
    # Relations must go first due to foreign key constraints.
    session.query(models.Relation).filter_by(graph_id=kg.id).delete()
    session.query(models.Entity).filter_by(graph_id=kg.id).delete()
    session.flush()

    # Process entities
    entity_map = {}  # Map entity_id to Entity instance
    for entity_data in data.get('entities', []):
        try:
            if 'id' not in entity_data:
                continue
            entity_id = entity_data.get('id')
            entity = models.Entity.from_dict(entity_data, kg.id)
            session.add(entity)
            session.flush()  # Flush to get the ID
            entity_map[entity_id] = entity
        except Exception as e:
            # Best-effort: a malformed entity is logged and skipped, not fatal.
            logger.error(f"Error extracting entity {entity_data.get('id')}: {str(e)}")

    # Process relations
    for relation_data in data.get('relations', []):
        try:
            if 'id' not in relation_data or 'source' not in relation_data or 'target' not in relation_data:
                continue
            source_id = relation_data.get('source')
            target_id = relation_data.get('target')
            source_entity = entity_map.get(source_id)
            target_entity = entity_map.get(target_id)
            # Skip relations whose endpoints were not successfully extracted above.
            if not source_entity or not target_entity:
                logger.warning(f"Skipping relation {relation_data.get('id')}: Source or target entity not found")
                continue
            relation = models.Relation.from_dict(
                relation_data, kg.id, source_entity, target_entity
            )
            session.add(relation)
        except Exception as e:
            logger.error(f"Error extracting relation {relation_data.get('id')}: {str(e)}")

    session.commit()


def get_test_result(session: Session, filename: str) -> Optional[Dict[str, Any]]:
    """
    Get a test result by filename from the knowledge graph.
    This now returns a dictionary with test result data instead of a TestResult model.
    """
    # Try to find a knowledge graph with this test result filename
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if kg and kg.content:
        try:
            data = json.loads(kg.content)
            if 'test_result' in data:
                return data['test_result']
        except json.JSONDecodeError:
            pass

    # Fall back to standard file locations. The paths must interpolate the
    # requested filename — without the placeholder the lookup could never match.
    standard_file_locations = [
        f"datasets/test_results/{filename}",
        f"datasets/{filename}"
    ]
    for file_path in standard_file_locations:
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            continue

    return None
""" # Find or create a knowledge graph for this test result kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first() if not kg: # Create new knowledge graph for this test result kg = models.KnowledgeGraph() kg.filename = filename kg.creation_timestamp = datetime.utcnow() # Get existing content or initialize empty dict try: if kg.content: content = json.loads(kg.content) else: content = {} except json.JSONDecodeError: content = {} # Update test result data content['test_result'] = data content['test_timestamp'] = datetime.utcnow().isoformat() content['model_name'] = data.get('model', '') content['perturbation_type'] = data.get('perturbation_type', '') content['completed'] = data.get('completed', False) # Find the related knowledge graph if referenced kg_filename = data.get('knowledge_graph_filename') if kg_filename: related_kg = session.query(models.KnowledgeGraph).filter_by(filename=kg_filename).first() if related_kg: content['knowledge_graph_id'] = related_kg.id # Save updated content back to knowledge graph kg.content = json.dumps(content) # Save to database session.add(kg) session.commit() return kg def get_test_progress(session: Session, test_filename: str) -> Optional[Dict[str, Any]]: """ Get test progress by test filename. Now returns progress data as a dictionary instead of a TestProgress model. 
""" kg = session.query(models.KnowledgeGraph).filter_by(filename=test_filename).first() if kg and kg.content: try: content = json.loads(kg.content) if 'test_progress' in content: return content['test_progress'] except json.JSONDecodeError: pass # Try to find a progress file progress_filename = f"progress_{test_filename}" progress_path = str(PROJECT_ROOT / 'datasets' / 'test_results' / progress_filename) if os.path.exists(progress_path): try: with open(progress_path, 'r') as f: return json.load(f) except (FileNotFoundError, json.JSONDecodeError): pass return None def save_test_progress(session: Session, test_filename: str, data: Dict[str, Any]) -> models.KnowledgeGraph: """ Save test progress to the database. Test progress is now stored within the KnowledgeGraph content field rather than as a separate TestProgress model. """ # Find the knowledge graph for this test kg = session.query(models.KnowledgeGraph).filter_by(filename=test_filename).first() if not kg: # Create new knowledge graph for this test kg = models.KnowledgeGraph() kg.filename = test_filename kg.creation_timestamp = datetime.utcnow() # Get existing content or initialize empty dict try: if kg.content: content = json.loads(kg.content) else: content = {} except json.JSONDecodeError: content = {} # Initialize test_progress if it doesn't exist if 'test_progress' not in content: content['test_progress'] = {} # Update progress data if 'progress' in data: progress_data = data['progress'] content['test_progress']['status'] = progress_data.get('status', content['test_progress'].get('status')) content['test_progress']['current'] = progress_data.get('current', content['test_progress'].get('current')) content['test_progress']['total'] = progress_data.get('total', content['test_progress'].get('total')) content['test_progress']['last_tested_relation'] = progress_data.get('last_tested_relation', content['test_progress'].get('last_tested_relation')) content['test_progress']['overall_progress_percentage'] = 
def save_test_progress(session: Session, test_filename: str, data: Dict[str, Any]) -> models.KnowledgeGraph:
    """
    Save test progress to the database.
    Test progress is now stored within the KnowledgeGraph content field
    rather than as a separate TestProgress model.
    """
    # Find or create the knowledge graph for this test
    kg = session.query(models.KnowledgeGraph).filter_by(filename=test_filename).first()
    if not kg:
        kg = models.KnowledgeGraph()
        kg.filename = test_filename
        kg.creation_timestamp = datetime.utcnow()

    # Get existing content or initialize empty dict
    try:
        if kg.content:
            content = json.loads(kg.content)
        else:
            content = {}
    except json.JSONDecodeError:
        content = {}

    # Initialize test_progress if it doesn't exist
    if 'test_progress' not in content:
        content['test_progress'] = {}

    if 'progress' in data:
        # Structured update: merge known progress fields, keeping the stored
        # value for any field the caller did not supply.
        progress_data = data['progress']
        stored = content['test_progress']
        stored['status'] = progress_data.get('status', stored.get('status'))
        stored['current'] = progress_data.get('current', stored.get('current'))
        stored['total'] = progress_data.get('total', stored.get('total'))
        stored['last_tested_relation'] = progress_data.get(
            'last_tested_relation', stored.get('last_tested_relation'))
        stored['overall_progress_percentage'] = progress_data.get(
            'overall_progress_percentage', stored.get('overall_progress_percentage'))
        stored['current_jailbreak'] = progress_data.get(
            'current_jailbreak', stored.get('current_jailbreak'))
    else:
        # Direct update of progress data
        for key, value in data.items():
            content['test_progress'][key] = value

    # Timestamp: prefer the caller-supplied value, otherwise stamp "now".
    if 'timestamp' in data:
        try:
            content['test_progress']['timestamp'] = data['timestamp']
        except (ValueError, TypeError):
            content['test_progress']['timestamp'] = datetime.utcnow().isoformat()
    else:
        content['test_progress']['timestamp'] = datetime.utcnow().isoformat()

    # Save updated content back to knowledge graph
    kg.content = json.dumps(content)
    session.add(kg)
    session.commit()

    # Also save to progress file for backward compatibility
    try:
        progress_filename = f"progress_{test_filename}"
        progress_dir = 'datasets/test_results'
        os.makedirs(progress_dir, exist_ok=True)
        progress_path = os.path.join(progress_dir, progress_filename)
        with open(progress_path, 'w') as f:
            json.dump(content['test_progress'], f)
    except Exception as e:
        # Best-effort legacy file write; failure must not fail the DB save.
        logger.warning(f"Failed to save progress file: {str(e)}")

    return kg


def get_all_knowledge_graphs(session: Session) -> List[models.KnowledgeGraph]:
    """Get all knowledge graphs."""
    return session.query(models.KnowledgeGraph).all()


def get_all_test_results(session: Session) -> List[Dict[str, Any]]:
    """
    Get all test results.
    Now returns a list of dictionaries containing test result data
    extracted from knowledge graphs.
    """
    test_results = []
    knowledge_graphs = session.query(models.KnowledgeGraph).all()
    for kg in knowledge_graphs:
        if kg.content:
            try:
                content = json.loads(kg.content)
                if 'test_result' in content:
                    # Add filename and ID for reference
                    result = content['test_result'].copy() if isinstance(content['test_result'], dict) else {}
                    result['filename'] = kg.filename
                    result['id'] = kg.id
                    test_results.append(result)
            except json.JSONDecodeError:
                continue
    return test_results


def get_standard_dataset(session: Session, filename: str) -> Optional[Dict[str, Any]]:
    """
    Get a standard dataset by filename (e.g., jailbreak techniques).
    First attempts to load from the database as a knowledge graph,
    then falls back to standard data file locations.
    """
    # Try to get from database as a knowledge graph
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if kg and kg.content:
        try:
            return json.loads(kg.content)
        except json.JSONDecodeError:
            pass

    # If not in database, try standard file locations. The paths must
    # interpolate the requested filename — without the placeholder none of
    # these candidates could ever exist.
    standard_file_locations = [
        f"datasets/{filename}",  # Direct in data dir
        f"datasets/test_results/{filename}",
        f"datasets/knowledge_graphs/{filename}"
    ]
    for file_path in standard_file_locations:
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            continue

    # Finally try as an absolute path
    try:
        with open(filename, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        pass

    return None


def find_entity_by_id(session: Session, entity_id: str) -> Optional[models.Entity]:
    """
    Find an entity by its ID.

    Args:
        session: Database session
        entity_id: Entity ID to search for

    Returns:
        Entity or None if not found
    """
    query = session.query(models.Entity).filter_by(entity_id=entity_id)
    return query.first()
def find_relation_by_id(session: Session, relation_id: str) -> Optional[models.Relation]:
    """
    Find a relation by its ID.

    Args:
        session: Database session
        relation_id: Relation ID to search for

    Returns:
        Relation or None if not found
    """
    query = session.query(models.Relation).filter_by(relation_id=relation_id)
    return query.first()


def find_entities_by_type(session: Session, entity_type: str) -> List[models.Entity]:
    """
    Find entities by type.

    Args:
        session: Database session
        entity_type: Entity type to search for

    Returns:
        List of entities
    """
    query = session.query(models.Entity).filter_by(type=entity_type)
    return query.all()


def find_relations_by_type(session: Session, relation_type: str) -> List[models.Relation]:
    """
    Find relations by type.

    Args:
        session: Database session
        relation_type: Relation type to search for

    Returns:
        List of relations
    """
    query = session.query(models.Relation).filter_by(type=relation_type)
    return query.all()


def merge_knowledge_graphs(session: Session, output_filename: str, input_filenames: List[str]) -> Optional[models.KnowledgeGraph]:
    """
    Merge multiple knowledge graphs into a single knowledge graph.

    Args:
        session: Database session
        output_filename: Output filename for the merged knowledge graph
        input_filenames: List of filenames of knowledge graphs to merge

    Returns:
        The merged KnowledgeGraph instance or None if error
    """
    # Check if merged graph already exists
    existing_kg = get_knowledge_graph(session, output_filename)
    if existing_kg:
        logger.warning(f"Knowledge graph {output_filename} already exists. Returning existing graph.")
        return existing_kg

    # Load all input knowledge graphs
    knowledge_graphs = []
    for filename in input_filenames:
        kg = get_knowledge_graph(session, filename)
        if not kg:
            logger.warning(f"Knowledge graph {filename} not found. Skipping.")
            continue
        knowledge_graphs.append(kg)

    if not knowledge_graphs:
        logger.error("No valid knowledge graphs to merge.")
        return None

    # Create a new merged knowledge graph. Note: the module imports the
    # `datetime` class (`from datetime import datetime`), so the previous
    # `datetime.datetime.utcnow()` raised AttributeError here.
    merged_data = {
        "entities": [],
        "relations": [],
        "metadata": {
            "source_graphs": input_filenames,
            "creation_time": datetime.utcnow().isoformat(),
            "merge_method": "concatenate"
        }
    }

    # Keep track of entity and relation IDs to avoid duplicates
    entity_ids = set()
    relation_ids = set()

    # Add entities and relations from each graph
    for kg in knowledge_graphs:
        graph_data = kg.graph_data
        if not graph_data:
            logger.warning(f"Knowledge graph {kg.filename} has no data. Skipping.")
            continue

        # Process entities: first occurrence of each ID wins.
        for entity in graph_data.get("entities", []):
            if "id" not in entity:
                continue
            if entity["id"] in entity_ids:
                continue
            merged_data["entities"].append(entity)
            entity_ids.add(entity["id"])

        # Process relations: require id/source/target, dedupe by ID.
        for relation in graph_data.get("relations", []):
            if "id" not in relation or "source" not in relation or "target" not in relation:
                continue
            if relation["id"] in relation_ids:
                continue
            merged_data["relations"].append(relation)
            relation_ids.add(relation["id"])

    # Save the merged knowledge graph
    return save_knowledge_graph(session, output_filename, merged_data)
def get_knowledge_graph_by_id(session, graph_id):
    """
    Get a knowledge graph by its ID or filename

    Args:
        session: Database session
        graph_id: Either an integer ID or a string filename
                  (the string "latest" returns the most recently created graph)

    Returns:
        KnowledgeGraph object or None if not found
    """
    try:
        logger.info(f"Looking up knowledge graph: {graph_id} (type: {type(graph_id)})")

        # Special handling for "latest"
        if isinstance(graph_id, str) and graph_id.lower() == "latest":
            logger.info("Handling 'latest' special case")
            kg = session.query(models.KnowledgeGraph).order_by(models.KnowledgeGraph.id.desc()).first()
            if kg:
                logger.info(f"Found latest knowledge graph with ID {kg.id} and filename {kg.filename}")
                return kg
            logger.warning("No knowledge graphs found in database")
            return None

        # Try as integer ID first
        if isinstance(graph_id, int) or (isinstance(graph_id, str) and graph_id.isdigit()):
            kg_id = int(graph_id)
            logger.info(f"Looking up knowledge graph by ID: {kg_id}")
            kg = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.id == kg_id).first()
            if kg:
                logger.info(f"Found knowledge graph by ID {kg_id}: {kg.filename}")
                return kg
            logger.warning(f"Knowledge graph with ID {kg_id} not found")

        # If not found by ID or not an integer, try as filename
        if isinstance(graph_id, str):
            logger.info(f"Looking up knowledge graph by filename: {graph_id}")
            kg = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.filename == graph_id).first()
            if kg:
                logger.info(f"Found knowledge graph by filename {graph_id}: ID {kg.id}")
                return kg
            logger.warning(f"Knowledge graph with filename {graph_id} not found")

        logger.error(f"Knowledge graph not found: {graph_id}")
        return None
    except Exception as e:
        # Lookup is best-effort: callers receive None on any failure.
        logger.error(f"Error retrieving knowledge graph by ID: {str(e)}")
        return None


def update_knowledge_graph(session: Session, filename: str, graph_data: dict) -> models.KnowledgeGraph:
    """
    Update an existing knowledge graph with new data.

    Args:
        session: Database session
        filename: Filename of the knowledge graph to update
        graph_data: New graph data

    Returns:
        Updated KnowledgeGraph instance
    """
    kg = get_knowledge_graph(session, filename)
    if not kg:
        # Create a new knowledge graph if it doesn't exist
        logger.info(f"Knowledge graph {filename} not found. Creating a new one.")
        return save_knowledge_graph(session, filename, graph_data)

    # Update the knowledge graph data
    kg.graph_data = graph_data

    # Update entity and relation counts
    if isinstance(graph_data, dict):
        if 'entities' in graph_data and isinstance(graph_data['entities'], list):
            kg.entity_count = len(graph_data['entities'])
        if 'relations' in graph_data and isinstance(graph_data['relations'], list):
            kg.relation_count = len(graph_data['relations'])

    # Update last modified timestamp
    kg.update_timestamp = datetime.utcnow()

    session.add(kg)
    session.commit()

    logger.info(f"Updated knowledge graph {filename}")
    return kg


def delete_knowledge_graph(session: Session, identifier: Union[int, str]) -> bool:
    """
    Delete a knowledge graph and all its associated entities and relations.

    Args:
        session: Database session
        identifier: Knowledge graph ID or filename

    Returns:
        True if deletion was successful, False otherwise
    """
    try:
        # Find the knowledge graph: strings are filenames, everything else IDs.
        if isinstance(identifier, str):
            kg = session.query(models.KnowledgeGraph).filter_by(filename=identifier).first()
        else:
            kg = session.query(models.KnowledgeGraph).filter_by(id=identifier).first()

        if not kg:
            logger.warning(f"Knowledge graph with identifier {identifier} not found")
            return False

        kg_id = kg.id
        filename = kg.filename

        # Count associated entities and relations for logging
        entity_count = session.query(models.Entity).filter_by(graph_id=kg_id).count()
        relation_count = session.query(models.Relation).filter_by(graph_id=kg_id).count()

        logger.info(f"Deleting knowledge graph {filename} (ID: {kg_id}) with {entity_count} entities and {relation_count} relations")

        # Due to the CASCADE setting in the relationships, deleting the knowledge graph
        # will automatically delete all associated entities and relations.
        # However, we'll delete them explicitly for clarity and to ensure proper cleanup.
        # Delete relations first (due to foreign key constraints)
        session.query(models.Relation).filter_by(graph_id=kg_id).delete()
        # Delete entities
        session.query(models.Entity).filter_by(graph_id=kg_id).delete()
        # Delete the knowledge graph
        session.delete(kg)

        session.commit()
        logger.info(f"Successfully deleted knowledge graph {filename} (ID: {kg_id}) and its associated data")
        return True
    except Exception as e:
        # Rollback on error
        session.rollback()
        logger.error(f"Error deleting knowledge graph: {str(e)}")
        return False
def get_trace(session: Session, trace_id: str) -> Optional[models.Trace]:
    """
    Get a trace by its ID or filename.

    Args:
        session: Database session
        trace_id: Either a UUID trace_id or a filename

    Returns:
        Trace object or None if not found
    """
    # Try as UUID trace_id first
    trace = session.query(models.Trace).filter_by(trace_id=trace_id).first()
    if trace:
        return trace

    # If not found, try as filename
    trace = session.query(models.Trace).filter_by(filename=trace_id).first()
    if trace:
        return trace

    # If not found, try as numeric primary key
    try:
        id_value = int(trace_id)
        trace = session.query(models.Trace).filter_by(id=id_value).first()
        if trace:
            return trace
    except (ValueError, TypeError):
        pass

    return None


def save_trace(
    session: Session,
    content: str,
    filename: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    trace_type: Optional[str] = None,
    trace_source: str = "user_upload",
    uploader: Optional[str] = None,
    tags: Optional[List[str]] = None,
    trace_metadata: Optional[Dict[str, Any]] = None
) -> models.Trace:
    """
    Save a trace to the database.

    Args:
        session: Database session
        content: The content of the trace
        filename: Optional filename
        title: Optional title
        description: Optional description
        trace_type: Optional type of trace
        trace_source: Source of the trace (default: "user_upload")
        uploader: Optional name of the uploader
        tags: Optional list of tags
        trace_metadata: Optional additional metadata

    Returns:
        The created or updated Trace object
    """
    # Generate content hash for deduplication
    content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()

    # Check if trace already exists with this content hash
    existing_trace = session.query(models.Trace).filter_by(content_hash=content_hash).first()
    if existing_trace:
        logger.info(f"Trace with matching content hash already exists (ID: {existing_trace.id})")
        # Update fields if provided
        if filename:
            existing_trace.filename = filename
        if title:
            existing_trace.title = title
        if description:
            existing_trace.description = description
        if trace_type:
            existing_trace.trace_type = trace_type
        if uploader:
            existing_trace.uploader = uploader
        if tags:
            existing_trace.tags = tags
        if trace_metadata:
            # Merge metadata rather than replace. Reassign a new dict instead
            # of mutating in place: in-place mutation of a JSON-typed column is
            # not change-tracked by SQLAlchemy unless MutableDict is configured
            # (presumed not to be here — TODO confirm against models.Trace).
            if existing_trace.trace_metadata:
                existing_trace.trace_metadata = {**existing_trace.trace_metadata, **trace_metadata}
            else:
                existing_trace.trace_metadata = trace_metadata

        existing_trace.update_timestamp = datetime.utcnow()
        session.add(existing_trace)
        session.commit()
        return existing_trace

    # Create new trace
    trace = models.Trace.from_content(
        content=content,
        filename=filename,
        title=title,
        description=description,
        trace_type=trace_type,
        trace_source=trace_source,
        uploader=uploader,
        tags=tags,
        trace_metadata=trace_metadata
    )
    session.add(trace)
    session.commit()
    logger.info(f"New trace saved to database (ID: {trace.id}, trace_id: {trace.trace_id})")
    return trace
def get_all_traces(session: Session) -> List[models.Trace]:
    """
    Get all traces from the database.

    Args:
        session: Database session

    Returns:
        List of Trace objects, newest upload first
    """
    ordered = session.query(models.Trace).order_by(models.Trace.upload_timestamp.desc())
    return ordered.all()


def get_traces_by_status(session: Session, status: str) -> List[models.Trace]:
    """
    Get traces by status.

    Args:
        session: Database session
        status: Status to filter by

    Returns:
        List of Trace objects with the specified status, newest upload first
    """
    matching = session.query(models.Trace).filter_by(status=status)
    return matching.order_by(models.Trace.upload_timestamp.desc()).all()


def update_trace_status(session: Session, trace_id: str, status: str) -> models.Trace:
    """
    Update the status of a trace.

    Args:
        session: Database session
        trace_id: ID of the trace to update
        status: New status

    Returns:
        Updated Trace object

    Raises:
        ValueError: If the trace cannot be found.
    """
    trace = get_trace(session, trace_id)
    if trace is None:
        raise ValueError(f"Trace with ID {trace_id} not found")

    trace.status = status
    trace.update_timestamp = datetime.utcnow()
    session.add(trace)
    session.commit()
    return trace


def update_trace_content(session: Session, trace_id: str, content: str) -> models.Trace:
    """
    Update the content of a trace.

    Args:
        session: Database session
        trace_id: ID of the trace to update
        content: New content value

    Returns:
        Updated Trace object

    Raises:
        ValueError: If the trace cannot be found.
    """
    trace = get_trace(session, trace_id)
    if trace is None:
        raise ValueError(f"Trace with ID {trace_id} not found")

    trace.content = content
    trace.character_count = len(content)
    # Recalculate turn count: one turn per non-blank line of content.
    trace.turn_count = sum(1 for line in content.split('\n') if line.strip())
    trace.update_timestamp = datetime.utcnow()

    session.add(trace)
    session.commit()
    return trace


def link_knowledge_graph_to_trace(
    session: Session,
    kg_id: Union[int, str],
    trace_id: str,
    window_index: Optional[int] = None,
    window_total: Optional[int] = None,
    window_start_char: Optional[int] = None,
    window_end_char: Optional[int] = None
) -> models.KnowledgeGraph:
    """
    Link a knowledge graph to a trace.

    Args:
        session: Database session
        kg_id: ID or filename of the knowledge graph
        trace_id: ID of the trace
        window_index: Optional index of the window within the trace
        window_total: Optional total number of windows
        window_start_char: Optional start position in the trace
        window_end_char: Optional end position in the trace

    Returns:
        Updated KnowledgeGraph object

    Raises:
        ValueError: If the knowledge graph or the trace cannot be found.
    """
    kg = get_knowledge_graph_by_id(session, kg_id)
    if not kg:
        raise ValueError(f"Knowledge graph with ID {kg_id} not found")

    trace = get_trace(session, trace_id)
    if not trace:
        raise ValueError(f"Trace with ID {trace_id} not found")

    # Attach the trace and any window bounds the caller supplied.
    kg.trace_id = trace.trace_id
    window_fields = {
        'window_index': window_index,
        'window_total': window_total,
        'window_start_char': window_start_char,
        'window_end_char': window_end_char,
    }
    for attr, value in window_fields.items():
        if value is not None:
            setattr(kg, attr, value)

    # Mirror the link inside the graph's own metadata.
    graph_data = kg.graph_data or {}
    graph_data.setdefault("metadata", {})
    graph_data["metadata"]["trace_info"] = {
        "trace_id": trace.trace_id,
        "window_index": window_index,
        "window_total": window_total,
        "linked_at": datetime.utcnow().isoformat()
    }
    kg.graph_data = graph_data

    session.add(kg)
    session.commit()
    return kg
def get_knowledge_graphs_for_trace(session: Session, trace_id: str) -> List[models.KnowledgeGraph]:
    """
    Get all knowledge graphs associated with a trace.

    Args:
        session: Database session
        trace_id: ID of the trace

    Returns:
        List of KnowledgeGraph objects linked to the trace, ordered by window index

    Raises:
        ValueError: If the trace cannot be found.
    """
    trace = get_trace(session, trace_id)
    if trace is None:
        raise ValueError(f"Trace with ID {trace_id} not found")

    linked = session.query(models.KnowledgeGraph).filter_by(trace_id=trace.trace_id)
    return linked.order_by(models.KnowledgeGraph.window_index).all()


def check_knowledge_graph_exists(session: Session, trace_id: str, is_original: bool = None) -> Optional[models.KnowledgeGraph]:
    """
    Check if a knowledge graph exists for a trace with specific criteria.

    Args:
        session: Database session
        trace_id: ID of the trace
        is_original: If True, only return knowledge graphs with status='created'
                     If False, only return knowledge graphs with other statuses
                     If None, match any status

    Returns:
        KnowledgeGraph object if found, None otherwise
    """
    candidates = session.query(models.KnowledgeGraph).filter_by(trace_id=trace_id)
    # Original KGs carry status 'created'; everything else is derived.
    if is_original is True:
        candidates = candidates.filter_by(status='created')
    elif is_original is False:
        candidates = candidates.filter(models.KnowledgeGraph.status != 'created')
    return candidates.first()


def delete_trace(session: Session, trace_id: str, delete_related_kgs: bool = False) -> bool:
    """
    Delete a trace from the database.

    Args:
        session: Database session
        trace_id: ID of the trace to delete
        delete_related_kgs: Whether to also delete related knowledge graphs

    Returns:
        True if successful, False otherwise
    """
    trace = get_trace(session, trace_id)
    if trace is None:
        return False

    for graph in trace.knowledge_graphs:
        if delete_related_kgs:
            # Remove the dependent graphs entirely.
            session.delete(graph)
        else:
            # Keep the graphs but sever their link to this trace.
            graph.trace_id = None
            session.add(graph)

    session.delete(trace)
    session.commit()
    return True


def get_prompt_reconstructions_for_kg(session, kg_identifier):
    """
    Fetch all prompt reconstructions for a given knowledge graph (by ID or filename).
    Returns a dict mapping relation_id to reconstructed_prompt.
    """
    from backend.database.models import KnowledgeGraph, PromptReconstruction

    # Integers are primary keys, anything else is treated as a filename.
    lookup = {'id': kg_identifier} if isinstance(kg_identifier, int) else {'filename': kg_identifier}
    kg = session.query(KnowledgeGraph).filter_by(**lookup).first()
    if kg is None:
        return {}

    rows = session.query(PromptReconstruction).filter_by(knowledge_graph_id=kg.id).all()
    return {row.relation_id: row.reconstructed_prompt for row in rows}


def get_prompt_reconstruction_for_relation(session, kg_identifier, relation_id):
    """
    Fetch a single reconstructed prompt for a given knowledge graph and relation_id.
    Returns the reconstructed_prompt string or None.
    """
    from backend.database.models import KnowledgeGraph, PromptReconstruction

    lookup = {'id': kg_identifier} if isinstance(kg_identifier, int) else {'filename': kg_identifier}
    kg = session.query(KnowledgeGraph).filter_by(**lookup).first()
    if kg is None:
        return None

    row = session.query(PromptReconstruction).filter_by(
        knowledge_graph_id=kg.id, relation_id=relation_id
    ).first()
    return row.reconstructed_prompt if row else None


def save_causal_analysis(
    session: Session,
    knowledge_graph_id: int,
    perturbation_set_id: str,
    analysis_method: str,
    analysis_result: dict = None,
    causal_score: float = None,
    analysis_metadata: dict = None
):
    """Save a causal analysis result to the database."""
    from backend.database import models

    record = models.CausalAnalysis(
        knowledge_graph_id=knowledge_graph_id,
        perturbation_set_id=perturbation_set_id,
        analysis_method=analysis_method,
        analysis_result=analysis_result,
        causal_score=causal_score,
        analysis_metadata=analysis_metadata
    )
    session.add(record)
    session.commit()
    # Refresh so server-populated fields (e.g. the primary key) are loaded.
    session.refresh(record)
    return record
""" Get causal analysis results from the database. Args: session: Database session knowledge_graph_id: ID of the knowledge graph perturbation_set_id: ID of the perturbation set analysis_method: Method used for analysis Returns: CausalAnalysis object or None if not found """ return session.query(models.CausalAnalysis).filter_by( knowledge_graph_id=knowledge_graph_id, perturbation_set_id=perturbation_set_id, analysis_method=analysis_method ).first() def get_all_causal_analyses( session: Session, knowledge_graph_id: Optional[int] = None, perturbation_set_id: Optional[str] = None, analysis_method: Optional[str] = None ) -> List[models.CausalAnalysis]: """ Get all causal analysis results from the database with optional filters. Args: session: Database session knowledge_graph_id: Optional filter by knowledge graph ID perturbation_set_id: Optional filter by perturbation set ID analysis_method: Optional filter by analysis method Returns: List of CausalAnalysis objects """ query = session.query(models.CausalAnalysis) if knowledge_graph_id is not None: query = query.filter_by(knowledge_graph_id=knowledge_graph_id) if perturbation_set_id is not None: query = query.filter_by(perturbation_set_id=perturbation_set_id) if analysis_method is not None: query = query.filter_by(analysis_method=analysis_method) return query.all() def get_causal_analysis_for_perturbation(session: Session, perturbation_set_id: str) -> List[Dict[str, Any]]: """ Get all causal analysis results for a specific perturbation set. 
Args: session: Database session perturbation_set_id: ID of the perturbation set Returns: List of causal analysis results with their associated data """ from backend.database.models import CausalAnalysis, KnowledgeGraph, PromptReconstruction results = session.query( CausalAnalysis, KnowledgeGraph, PromptReconstruction ).join( KnowledgeGraph, CausalAnalysis.knowledge_graph_id == KnowledgeGraph.id ).join( PromptReconstruction, CausalAnalysis.prompt_reconstruction_id == PromptReconstruction.id ).filter( CausalAnalysis.perturbation_set_id == perturbation_set_id ).all() return [{ 'analysis': analysis.to_dict(), 'knowledge_graph': kg.to_dict(), 'prompt_reconstruction': { 'id': pr.id, 'relation_id': pr.relation_id, 'reconstructed_prompt': pr.reconstructed_prompt, 'dependencies': pr.dependencies } } for analysis, kg, pr in results] def get_causal_analysis_by_method(session: Session, knowledge_graph_id: int, method: str) -> List[Dict[str, Any]]: """ Get causal analysis results for a specific knowledge graph and analysis method. Args: session: Database session knowledge_graph_id: ID of the knowledge graph method: Analysis method (e.g., 'graph', 'component', 'dowhy') Returns: List of causal analysis results """ from backend.database.models import CausalAnalysis, PerturbationTest results = session.query( CausalAnalysis, PerturbationTest ).join( PerturbationTest, CausalAnalysis.perturbation_test_id == PerturbationTest.id ).filter( CausalAnalysis.knowledge_graph_id == knowledge_graph_id, CausalAnalysis.analysis_method == method ).all() return [{ 'analysis': analysis.to_dict(), 'perturbation_test': { 'id': pt.id, 'perturbation_type': pt.perturbation_type, 'test_result': pt.test_result, 'perturbation_score': pt.perturbation_score } } for analysis, pt in results] def get_causal_analysis_summary(session: Session, knowledge_graph_id: int) -> Dict[str, Any]: """ Get a summary of causal analysis results for a knowledge graph. 
Args: session: Database session knowledge_graph_id: ID of the knowledge graph Returns: Dictionary containing summary statistics and results by method """ from backend.database.models import CausalAnalysis from sqlalchemy import func # Get all analyses for this knowledge graph analyses = session.query(CausalAnalysis).filter_by( knowledge_graph_id=knowledge_graph_id ).all() if not analyses: return { 'total_analyses': 0, 'methods': {}, 'average_scores': {} } # Group by method method_results = {} for analysis in analyses: method = analysis.analysis_method if method not in method_results: method_results[method] = [] method_results[method].append(analysis) # Calculate statistics summary = { 'total_analyses': len(analyses), 'methods': {}, 'average_scores': {} } for method, results in method_results.items(): scores = [r.causal_score for r in results if r.causal_score is not None] summary['methods'][method] = { 'count': len(results), 'average_score': sum(scores) / len(scores) if scores else None, 'min_score': min(scores) if scores else None, 'max_score': max(scores) if scores else None } return summary # Add these functions to handle perturbation tests def save_perturbation_test(session, knowledge_graph_id: int, prompt_reconstruction_id: int, relation_id: str, perturbation_type: str, perturbation_set_id: str, test_result: dict = None, perturbation_score: float = None, test_metadata: dict = None) -> int: """ Save a perturbation test to the database. 
Args: session: Database session knowledge_graph_id: ID of the knowledge graph prompt_reconstruction_id: ID of the prompt reconstruction relation_id: ID of the relation perturbation_type: Type of perturbation perturbation_set_id: ID of the perturbation set test_result: Test result dictionary perturbation_score: Perturbation score test_metadata: Test metadata dictionary Returns: int: ID of the saved perturbation test """ from backend.database.models import PerturbationTest # Create new perturbation test test = PerturbationTest( knowledge_graph_id=knowledge_graph_id, prompt_reconstruction_id=prompt_reconstruction_id, relation_id=relation_id, perturbation_type=perturbation_type, perturbation_set_id=perturbation_set_id, test_result=test_result or {}, perturbation_score=perturbation_score, test_metadata=test_metadata or {} ) # Add to session and commit session.add(test) session.commit() return test.id def delete_perturbation_test(session, test_id: int) -> bool: """ Delete a perturbation test from the database. Args: session: Database session test_id: ID of the perturbation test to delete Returns: bool: True if successful, False otherwise """ from backend.database.models import PerturbationTest # Query the test test = session.query(PerturbationTest).filter_by(id=test_id).first() if test: # Delete and commit session.delete(test) session.commit() return True return False def delete_perturbation_tests_by_set(session, perturbation_set_id: str) -> int: """ Delete all perturbation tests in a set. 
Args: session: Database session perturbation_set_id: ID of the perturbation set Returns: int: Number of tests deleted """ from backend.database.models import PerturbationTest # Query all tests in the set tests = session.query(PerturbationTest).filter_by(perturbation_set_id=perturbation_set_id).all() # Delete all tests deleted_count = 0 for test in tests: session.delete(test) deleted_count += 1 # Commit changes session.commit() return deleted_count def get_context_document_stats(session: Session, trace_id: str) -> Dict[str, Any]: """ Get statistics about context documents for a trace. Args: session: Database session trace_id: Trace ID to get stats for Returns: Dictionary with context document statistics """ trace = get_trace(session, trace_id) if not trace or not trace.trace_metadata or "context_documents" not in trace.trace_metadata: return {"total_count": 0, "active_count": 0, "by_type": {}} docs = trace.trace_metadata["context_documents"] # Count by type by_type = {} active_count = 0 for doc in docs: doc_type = doc.get("document_type", "unknown") if doc_type not in by_type: by_type[doc_type] = 0 by_type[doc_type] += 1 if doc.get("is_active", True): active_count += 1 return { "total_count": len(docs), "active_count": active_count, "by_type": by_type } def get_context_documents_from_trace(session: Session, trace_id: str) -> List[Dict[str, Any]]: """ Get context documents from a trace's metadata. 
Args: session: Database session trace_id: ID of the trace Returns: List of context documents, or empty list if none found """ trace = get_trace(session, trace_id) if not trace or not trace.trace_metadata or "context_documents" not in trace.trace_metadata: return [] docs = trace.trace_metadata["context_documents"] # Filter to only active documents active_docs = [doc for doc in docs if doc.get("is_active", True)] return active_docs def get_temporal_windows_by_trace_id(session: Session, trace_id: str, processing_run_id: Optional[str] = None) -> Dict[str, Any]: """ Get all knowledge graph windows for a specific trace, ordered by window_index. Also returns the full/merged version if available. Used for temporal force-directed graph visualization. This function handles cases where KGs exist with trace_id but no trace record exists. Args: session: Database session trace_id: ID of the trace processing_run_id: Optional ID to filter by specific processing run Returns: Dict containing windowed KGs and full KG information """ logger.info(f"Looking up temporal windows for trace_id: {trace_id}") if processing_run_id: logger.info(f"Filtering by processing_run_id: {processing_run_id}") # First check if we can find the trace trace = get_trace(session, trace_id) # Get all knowledge graphs for this trace_id (even if trace record doesn't exist) query = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.trace_id == trace_id) # Filter by processing_run_id if provided if processing_run_id: query = query.filter(models.KnowledgeGraph.processing_run_id == processing_run_id) all_kgs = query.all() logger.info(f"Found {len(all_kgs)} total knowledge graphs for trace_id {trace_id}") if processing_run_id: logger.info(f"(filtered by processing_run_id: {processing_run_id})") if not all_kgs: logger.warning(f"No knowledge graphs found for trace_id: {trace_id}") return {"windows": [], "full_kg": None, "trace_info": None} # If no trace record exists, create minimal trace info from KGs if 
not trace: logger.warning(f"No trace record found for trace_id {trace_id}, but {len(all_kgs)} KGs exist") trace_info = { "trace_id": trace_id, "title": f"Trace {trace_id[:8]}...", "description": "Knowledge graphs exist but no trace record found", "upload_timestamp": min([kg.creation_timestamp for kg in all_kgs if kg.creation_timestamp]) } else: logger.info(f"Found trace record {trace.trace_id}") trace_info = { "trace_id": trace.trace_id, "title": trace.title, "description": trace.description, "upload_timestamp": trace.upload_timestamp } # Separate windowed KGs from the full/merged KG windowed_kgs = [] full_kg = None for kg in all_kgs: # Full/merged KG: has window_total but no window_index, null start/end chars if (kg.window_total is not None and kg.window_index is None and kg.window_start_char is None and kg.window_end_char is None): full_kg = kg logger.info(f"Found full/merged KG: {kg.filename} with window_total={kg.window_total}") # Windowed KG: has window_index elif kg.window_index is not None: windowed_kgs.append(kg) # KG without proper window info - try to assign window_index else: logger.info(f"Found KG without proper window info: {kg.filename}") # If we don't have a full KG yet and this looks like it could be one if (full_kg is None and kg.window_total is None and kg.window_start_char is None and kg.window_end_char is None): # Check if this KG has significantly more entities than others (indicating it's merged) if kg.graph_data and len(kg.graph_data.get("entities", [])) > 10: kg.window_total = len(windowed_kgs) + 1 # Set based on current windowed KGs full_kg = kg session.add(kg) logger.info(f"Assigned {kg.filename} as full KG with window_total={kg.window_total}") # If we have windowed KGs but some are missing window_index, assign them kgs_without_index = [kg for kg in all_kgs if kg.window_index is None and kg != full_kg] if kgs_without_index and windowed_kgs: logger.info("Assigning window_index to knowledge graphs based on creation order") # Sort by creation 
timestamp and assign window_index starting from the highest existing + 1 max_window_index = max([kg.window_index for kg in windowed_kgs], default=-1) kgs_without_index.sort(key=lambda kg: kg.creation_timestamp or datetime.utcnow()) for i, kg in enumerate(kgs_without_index): kg.window_index = max_window_index + 1 + i session.add(kg) windowed_kgs.append(kg) session.commit() logger.info(f"Assigned window_index to {len(kgs_without_index)} knowledge graphs") # Sort windowed KGs by window_index windowed_kgs.sort(key=lambda kg: kg.window_index) logger.info(f"Found {len(windowed_kgs)} windowed KGs and {'1' if full_kg else '0'} full KG for trace_id {trace_id}") # Update entity_count and relation_count if they're 0 or None but graph_data has content updated_count = 0 for kg in windowed_kgs + ([full_kg] if full_kg else []): if kg and kg.graph_data: needs_update = False if kg.entity_count is None or kg.entity_count == 0: entities = kg.graph_data.get("entities", []) if entities: kg.entity_count = len(entities) needs_update = True if kg.relation_count is None or kg.relation_count == 0: relations = kg.graph_data.get("relations", []) if relations: kg.relation_count = len(relations) needs_update = True if needs_update: session.add(kg) updated_count += 1 if updated_count > 0: session.commit() logger.info(f"Updated entity/relation counts for {updated_count} knowledge graphs") return { "windows": windowed_kgs, "full_kg": full_kg, "trace_info": trace_info }