Spaces:
Running
Running
| """ | |
| Utility functions for database operations. | |
| """ | |
| import os | |
| import json | |
| import logging | |
| import uuid | |
| from datetime import datetime | |
| from typing import Dict, List, Any, Optional, Union | |
| import hashlib | |
| from sqlalchemy.orm import Session | |
| from sqlalchemy import func | |
| from . import models | |
| from . import get_db, init_db | |
| logger = logging.getLogger(__name__) | |
def initialize_database(clear_all=False):
    """Initialize the database and create tables.

    Args:
        clear_all: If True, drops all existing tables before creating new ones
    """
    if not clear_all:
        from . import init_db
        init_db()
        logger.info("Database initialized (existing data preserved)")
        return
    # Destructive path: wipe everything and rebuild the schema.
    from . import reinit_db
    reinit_db()
    logger.info("Database reinitialized (all previous data cleared)")
def get_knowledge_graph(session: Session, filename: str) -> Optional[models.KnowledgeGraph]:
    """Return the knowledge graph stored under *filename*, or None if absent."""
    query = session.query(models.KnowledgeGraph).filter_by(filename=filename)
    return query.first()
def save_knowledge_graph(
    session,
    filename,
    graph_data,
    trace_id=None,
    window_index=None,
    window_total=None,
    window_start_char=None,
    window_end_char=None,
    is_original=False,
    processing_run_id=None
):
    """
    Save a knowledge graph to the database, creating or updating by filename.

    Args:
        session: Database session
        filename: Filename to save under
        graph_data: Knowledge graph data
        trace_id: Optional ID to group knowledge graphs from the same trace
        window_index: Optional sequential index of window within a trace
        window_total: Optional total number of windows in the trace
        window_start_char: Optional starting character position in the original trace
        window_end_char: Optional ending character position in the original trace
        is_original: Whether this is an original knowledge graph (sets status to "created")
        processing_run_id: Optional ID to distinguish multiple processing runs

    Returns:
        The created or updated KnowledgeGraph object
    """
    # Use the module-level `models` import for consistency with the rest of
    # this file (previously re-imported KnowledgeGraph via an absolute path).
    kg = session.query(models.KnowledgeGraph).filter(
        models.KnowledgeGraph.filename == filename
    ).first()
    if kg is None:
        # Create a new knowledge graph row.
        kg = models.KnowledgeGraph(
            filename=filename,
            trace_id=trace_id,
            window_index=window_index,
            window_total=window_total,
            window_start_char=window_start_char,
            window_end_char=window_end_char,
            status="created" if is_original else None,
            processing_run_id=processing_run_id
        )
    else:
        kg.update_timestamp = datetime.utcnow()
        # Only overwrite trace/window fields that were explicitly provided,
        # so a partial update does not clobber existing values.
        for attr, value in (
            ("trace_id", trace_id),
            ("window_index", window_index),
            ("window_total", window_total),
            ("window_start_char", window_start_char),
            ("window_end_char", window_end_char),
            ("processing_run_id", processing_run_id),
        ):
            if value is not None:
                setattr(kg, attr, value)
        if is_original:
            kg.status = "created"
    # Assign graph_content last in both paths so the model's setter can
    # refresh derived entity/relation counts.
    kg.graph_content = graph_data
    session.add(kg)
    session.commit()
    return kg
def update_knowledge_graph_status(session: Session, kg_id: Union[int, str], status: str) -> models.KnowledgeGraph:
    """
    Update the status of a knowledge graph.

    Args:
        session: Database session
        kg_id: Knowledge graph ID (int) or filename (str)
        status: New status (created, enriched, perturbed, causal)

    Returns:
        Updated knowledge graph

    Raises:
        ValueError: If no knowledge graph matches *kg_id*.
    """
    # A string identifier is treated as a filename, anything else as the PK.
    lookup = {"filename": kg_id} if isinstance(kg_id, str) else {"id": kg_id}
    kg = session.query(models.KnowledgeGraph).filter_by(**lookup).first()
    if kg is None:
        raise ValueError(f"Knowledge graph with ID/filename {kg_id} not found")
    kg.status = status
    session.commit()
    return kg
def extract_entities_and_relations(session: Session, kg: models.KnowledgeGraph):
    """Extract entities and relations from a knowledge graph and save them to the database."""
    data = kg.graph_data
    if not data:
        # Nothing to extract.
        return
    # Replace any previously extracted rows. Relations must go first because
    # they hold foreign keys onto entities.
    session.query(models.Relation).filter_by(graph_id=kg.id).delete()
    session.query(models.Entity).filter_by(graph_id=kg.id).delete()
    session.flush()
    # Maps the graph-level entity id to its persisted Entity row so relations
    # can be wired up below.
    entity_map = {}
    for entity_data in data.get('entities', []):
        try:
            if 'id' not in entity_data:
                # Entities without an id cannot be referenced; skip them.
                continue
            entity = models.Entity.from_dict(entity_data, kg.id)
            session.add(entity)
            session.flush()  # Flush so the row gets its primary key.
            entity_map[entity_data.get('id')] = entity
        except Exception as e:
            logger.error(f"Error extracting entity {entity_data.get('id')}: {str(e)}")
    for relation_data in data.get('relations', []):
        try:
            if 'id' not in relation_data or 'source' not in relation_data or 'target' not in relation_data:
                # A relation needs an id plus both endpoints.
                continue
            source_entity = entity_map.get(relation_data.get('source'))
            target_entity = entity_map.get(relation_data.get('target'))
            if not source_entity or not target_entity:
                # Dangling endpoint: the referenced entity was never persisted.
                logger.warning(f"Skipping relation {relation_data.get('id')}: Source or target entity not found")
                continue
            session.add(models.Relation.from_dict(
                relation_data,
                kg.id,
                source_entity,
                target_entity
            ))
        except Exception as e:
            logger.error(f"Error extracting relation {relation_data.get('id')}: {str(e)}")
    session.commit()
def get_test_result(session: Session, filename: str) -> Optional[Dict[str, Any]]:
    """
    Get a test result by filename from the knowledge graph.

    This now returns a dictionary with test result data instead of a
    TestResult model. Falls back to standard on-disk locations when the
    result is not stored in the database.

    Args:
        session: Database session
        filename: Filename of the test result to look up

    Returns:
        Test result dictionary, or None if not found
    """
    # Try to find a knowledge graph with this test result filename
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if kg and kg.content:
        try:
            data = json.loads(kg.content)
            if 'test_result' in data:
                return data['test_result']
        except json.JSONDecodeError:
            pass
    # BUG FIX: the fallback paths previously contained no placeholder, so
    # every lookup hit the same literal path. Interpolate the filename.
    standard_file_locations = [
        f"datasets/test_results/{filename}",
        f"datasets/{filename}"
    ]
    for file_path in standard_file_locations:
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            continue
    return None
def save_test_result(session: Session, filename: str, data: Dict[str, Any]) -> models.KnowledgeGraph:
    """
    Save a test result to the database.

    Test results are now stored within the KnowledgeGraph content field
    rather than as a separate TestResult model.
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if kg is None:
        # No row yet for this test result; start a fresh one.
        kg = models.KnowledgeGraph()
        kg.filename = filename
        kg.creation_timestamp = datetime.utcnow()
    # Decode whatever content is already stored, tolerating missing/bad JSON.
    content = {}
    if kg.content:
        try:
            content = json.loads(kg.content)
        except json.JSONDecodeError:
            content = {}
    # Record the result plus a few denormalized summary fields.
    content.update({
        'test_result': data,
        'test_timestamp': datetime.utcnow().isoformat(),
        'model_name': data.get('model', ''),
        'perturbation_type': data.get('perturbation_type', ''),
        'completed': data.get('completed', False),
    })
    # Cross-link the knowledge graph this test ran against, if referenced.
    kg_filename = data.get('knowledge_graph_filename')
    if kg_filename:
        related_kg = session.query(models.KnowledgeGraph).filter_by(filename=kg_filename).first()
        if related_kg:
            content['knowledge_graph_id'] = related_kg.id
    kg.content = json.dumps(content)
    session.add(kg)
    session.commit()
    return kg
def get_test_progress(session: Session, test_filename: str) -> Optional[Dict[str, Any]]:
    """
    Get test progress by test filename.

    Now returns progress data as a dictionary instead of a TestProgress
    model. Looks in the KnowledgeGraph content field first, then falls back
    to the legacy on-disk progress file written by save_test_progress().

    Args:
        session: Database session
        test_filename: Filename of the test whose progress is requested

    Returns:
        Progress dictionary, or None if not found
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=test_filename).first()
    if kg and kg.content:
        try:
            content = json.loads(kg.content)
            if 'test_progress' in content:
                return content['test_progress']
        except json.JSONDecodeError:
            pass
    # BUG FIX: previously built the path from an undefined PROJECT_ROOT name
    # (NameError). Use the same relative 'datasets/test_results' directory
    # that save_test_progress() writes to.
    progress_filename = f"progress_{test_filename}"
    progress_path = os.path.join('datasets', 'test_results', progress_filename)
    if os.path.exists(progress_path):
        try:
            with open(progress_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            pass
    return None
def save_test_progress(session: Session, test_filename: str, data: Dict[str, Any]) -> models.KnowledgeGraph:
    """
    Save test progress to the database.

    Test progress is now stored within the KnowledgeGraph content field
    rather than as a separate TestProgress model. A legacy progress file is
    also written to disk for backward compatibility.
    """
    kg = session.query(models.KnowledgeGraph).filter_by(filename=test_filename).first()
    if kg is None:
        # First progress report for this test: create a fresh row.
        kg = models.KnowledgeGraph()
        kg.filename = test_filename
        kg.creation_timestamp = datetime.utcnow()
    # Decode any existing content, tolerating missing or malformed JSON.
    content = {}
    if kg.content:
        try:
            content = json.loads(kg.content)
        except json.JSONDecodeError:
            content = {}
    progress = content.setdefault('test_progress', {})
    if 'progress' in data:
        # Structured payload: copy the known progress fields, keeping the
        # previous value whenever the new payload omits one.
        incoming = data['progress']
        for field in ('status', 'current', 'total', 'last_tested_relation',
                      'overall_progress_percentage', 'current_jailbreak'):
            progress[field] = incoming.get(field, progress.get(field))
    else:
        # Flat payload: merge every key directly.
        progress.update(data)
    # Stamp the update time, preferring a caller-supplied timestamp.
    if 'timestamp' in data:
        try:
            progress['timestamp'] = data['timestamp']
        except (ValueError, TypeError):
            progress['timestamp'] = datetime.utcnow().isoformat()
    else:
        progress['timestamp'] = datetime.utcnow().isoformat()
    kg.content = json.dumps(content)
    session.add(kg)
    session.commit()
    # Mirror the progress to the legacy on-disk file; failures are non-fatal.
    try:
        progress_filename = f"progress_{test_filename}"
        progress_dir = 'datasets/test_results'
        os.makedirs(progress_dir, exist_ok=True)
        progress_path = os.path.join(progress_dir, progress_filename)
        with open(progress_path, 'w') as f:
            json.dump(progress, f)
    except Exception as e:
        logger.warning(f"Failed to save progress file: {str(e)}")
    return kg
def get_all_knowledge_graphs(session: Session) -> List[models.KnowledgeGraph]:
    """Return every knowledge graph row in the database."""
    query = session.query(models.KnowledgeGraph)
    return query.all()
def get_all_test_results(session: Session) -> List[Dict[str, Any]]:
    """
    Get all test results.

    Now returns a list of dictionaries containing test result data
    extracted from knowledge graphs.
    """
    test_results = []
    for kg in session.query(models.KnowledgeGraph).all():
        if not kg.content:
            continue
        try:
            content = json.loads(kg.content)
        except json.JSONDecodeError:
            continue
        if 'test_result' not in content:
            continue
        raw = content['test_result']
        # Non-dict payloads are replaced by an empty record so the
        # filename/id reference fields can still be attached.
        result = raw.copy() if isinstance(raw, dict) else {}
        result['filename'] = kg.filename
        result['id'] = kg.id
        test_results.append(result)
    return test_results
def get_standard_dataset(session: Session, filename: str) -> Optional[Dict[str, Any]]:
    """
    Get a standard dataset by filename (e.g., jailbreak techniques).

    First attempts to load from the database as a knowledge graph,
    then falls back to standard data file locations, and finally tries
    *filename* as a direct path.

    Args:
        session: Database session
        filename: Dataset filename (or path) to load

    Returns:
        Parsed dataset dictionary, or None if not found anywhere
    """
    # Try to get from database as a knowledge graph
    kg = session.query(models.KnowledgeGraph).filter_by(filename=filename).first()
    if kg and kg.content:
        try:
            return json.loads(kg.content)
        except json.JSONDecodeError:
            pass
    # BUG FIX: the fallback paths previously contained no placeholder, so
    # every lookup hit the same literal path. Interpolate the filename.
    standard_file_locations = [
        f"datasets/{filename}",  # Direct in data dir
        f"datasets/test_results/{filename}",
        f"datasets/knowledge_graphs/{filename}"
    ]
    for file_path in standard_file_locations:
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            continue
    # Finally try as an absolute path
    try:
        with open(filename, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        pass
    return None
def find_entity_by_id(session: Session, entity_id: str) -> Optional[models.Entity]:
    """
    Find an entity by its ID.

    Args:
        session: Database session
        entity_id: Entity ID to search for

    Returns:
        Entity or None if not found
    """
    return session.query(models.Entity).filter_by(entity_id=entity_id).first()
def find_relation_by_id(session: Session, relation_id: str) -> Optional[models.Relation]:
    """
    Find a relation by its ID.

    Args:
        session: Database session
        relation_id: Relation ID to search for

    Returns:
        Relation or None if not found
    """
    return session.query(models.Relation).filter_by(relation_id=relation_id).first()
def find_entities_by_type(session: Session, entity_type: str) -> List[models.Entity]:
    """
    Find entities by type.

    Args:
        session: Database session
        entity_type: Entity type to search for

    Returns:
        List of entities
    """
    return session.query(models.Entity).filter_by(type=entity_type).all()
def find_relations_by_type(session: Session, relation_type: str) -> List[models.Relation]:
    """
    Find relations by type.

    Args:
        session: Database session
        relation_type: Relation type to search for

    Returns:
        List of relations
    """
    return session.query(models.Relation).filter_by(type=relation_type).all()
def merge_knowledge_graphs(session: Session, output_filename: str, input_filenames: List[str]) -> Optional[models.KnowledgeGraph]:
    """
    Merge multiple knowledge graphs into a single knowledge graph.

    Entities and relations are concatenated across the input graphs;
    duplicate IDs are skipped (first occurrence wins).

    Args:
        session: Database session
        output_filename: Output filename for the merged knowledge graph
        input_filenames: List of filenames of knowledge graphs to merge

    Returns:
        The merged KnowledgeGraph instance or None if error
    """
    # Check if merged graph already exists
    existing_kg = get_knowledge_graph(session, output_filename)
    if existing_kg:
        logger.warning(f"Knowledge graph {output_filename} already exists. Returning existing graph.")
        return existing_kg
    # Load all input knowledge graphs
    knowledge_graphs = []
    for filename in input_filenames:
        kg = get_knowledge_graph(session, filename)
        if not kg:
            logger.warning(f"Knowledge graph {filename} not found. Skipping.")
            continue
        knowledge_graphs.append(kg)
    if not knowledge_graphs:
        logger.error("No valid knowledge graphs to merge.")
        return None
    merged_data = {
        "entities": [],
        "relations": [],
        "metadata": {
            "source_graphs": input_filenames,
            # BUG FIX: was datetime.datetime.utcnow(), which raises
            # AttributeError under this module's `from datetime import
            # datetime` import.
            "creation_time": datetime.utcnow().isoformat(),
            "merge_method": "concatenate"
        }
    }
    # Track already-seen IDs so duplicates across graphs are dropped.
    entity_ids = set()
    relation_ids = set()
    for kg in knowledge_graphs:
        graph_data = kg.graph_data
        if not graph_data:
            logger.warning(f"Knowledge graph {kg.filename} has no data. Skipping.")
            continue
        for entity in graph_data.get("entities", []):
            # Require an ID; skip duplicates.
            if "id" not in entity or entity["id"] in entity_ids:
                continue
            merged_data["entities"].append(entity)
            entity_ids.add(entity["id"])
        for relation in graph_data.get("relations", []):
            # Require an ID plus both endpoints; skip duplicates.
            if "id" not in relation or "source" not in relation or "target" not in relation:
                continue
            if relation["id"] in relation_ids:
                continue
            merged_data["relations"].append(relation)
            relation_ids.add(relation["id"])
    # Persist the merged knowledge graph.
    return save_knowledge_graph(session, output_filename, merged_data)
def get_knowledge_graph_by_id(session, graph_id):
    """
    Get a knowledge graph by its ID or filename.

    Args:
        session: Database session
        graph_id: Either an integer ID, a string filename, or the literal
            string "latest" (case-insensitive) for the most recent graph

    Returns:
        KnowledgeGraph object or None if not found
    """
    try:
        logger.info(f"Looking up knowledge graph: {graph_id} (type: {type(graph_id)})")
        is_string = isinstance(graph_id, str)
        # Special handling for "latest": highest primary key wins.
        if is_string and graph_id.lower() == "latest":
            logger.info("Handling 'latest' special case")
            latest = (
                session.query(models.KnowledgeGraph)
                .order_by(models.KnowledgeGraph.id.desc())
                .first()
            )
            if latest is None:
                logger.warning("No knowledge graphs found in database")
                return None
            logger.info(f"Found latest knowledge graph with ID {latest.id} and filename {latest.filename}")
            return latest
        # Numeric identifiers (ints or digit-only strings) are tried as the
        # primary key first.
        if isinstance(graph_id, int) or (is_string and graph_id.isdigit()):
            kg_id = int(graph_id)
            logger.info(f"Looking up knowledge graph by ID: {kg_id}")
            match = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.id == kg_id).first()
            if match is not None:
                logger.info(f"Found knowledge graph by ID {kg_id}: {match.filename}")
                return match
            logger.warning(f"Knowledge graph with ID {kg_id} not found")
        # Fall back to a filename lookup for any string identifier.
        if is_string:
            logger.info(f"Looking up knowledge graph by filename: {graph_id}")
            match = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.filename == graph_id).first()
            if match is not None:
                logger.info(f"Found knowledge graph by filename {graph_id}: ID {match.id}")
                return match
            logger.warning(f"Knowledge graph with filename {graph_id} not found")
        logger.error(f"Knowledge graph not found: {graph_id}")
        return None
    except Exception as e:
        logger.error(f"Error retrieving knowledge graph by ID: {str(e)}")
        return None
def update_knowledge_graph(session: Session, filename: str, graph_data: dict) -> models.KnowledgeGraph:
    """
    Update an existing knowledge graph with new data.

    Creates the knowledge graph if it does not exist yet.

    Args:
        session: Database session
        filename: Filename of the knowledge graph to update
        graph_data: New graph data

    Returns:
        Updated KnowledgeGraph instance
    """
    kg = get_knowledge_graph(session, filename)
    if not kg:
        # BUG FIX: log messages previously contained no placeholder for the
        # filename; interpolate it so the logs are actionable.
        logger.info(f"Knowledge graph {filename} not found. Creating a new one.")
        return save_knowledge_graph(session, filename, graph_data)
    # Update the knowledge graph data
    kg.graph_data = graph_data
    # Refresh denormalized entity and relation counts from the new data.
    if isinstance(graph_data, dict):
        if 'entities' in graph_data and isinstance(graph_data['entities'], list):
            kg.entity_count = len(graph_data['entities'])
        if 'relations' in graph_data and isinstance(graph_data['relations'], list):
            kg.relation_count = len(graph_data['relations'])
    # Update last modified timestamp
    kg.update_timestamp = datetime.utcnow()
    session.add(kg)
    session.commit()
    logger.info(f"Updated knowledge graph {filename}")
    return kg
def delete_knowledge_graph(session: Session, identifier: Union[int, str]) -> bool:
    """
    Delete a knowledge graph and all its associated entities and relations.

    Args:
        session: Database session
        identifier: Knowledge graph ID (int) or filename (str)

    Returns:
        True if deletion was successful, False otherwise
    """
    try:
        # A string identifier is treated as a filename, anything else as the PK.
        if isinstance(identifier, str):
            kg = session.query(models.KnowledgeGraph).filter_by(filename=identifier).first()
        else:
            kg = session.query(models.KnowledgeGraph).filter_by(id=identifier).first()
        if not kg:
            logger.warning(f"Knowledge graph with identifier {identifier} not found")
            return False
        kg_id = kg.id
        filename = kg.filename
        # Count associated entities and relations for logging
        entity_count = session.query(models.Entity).filter_by(graph_id=kg_id).count()
        relation_count = session.query(models.Relation).filter_by(graph_id=kg_id).count()
        # BUG FIX: log messages previously contained no placeholder for the
        # filename; interpolate it so the logs are actionable.
        logger.info(f"Deleting knowledge graph {filename} (ID: {kg_id}) with {entity_count} entities and {relation_count} relations")
        # Due to the CASCADE setting in the relationships, deleting the
        # knowledge graph would remove associated rows automatically, but we
        # delete them explicitly for clarity and to guarantee cleanup.
        # Relations first, due to foreign key constraints.
        session.query(models.Relation).filter_by(graph_id=kg_id).delete()
        session.query(models.Entity).filter_by(graph_id=kg_id).delete()
        session.delete(kg)
        session.commit()
        logger.info(f"Successfully deleted knowledge graph {filename} (ID: {kg_id}) and its associated data")
        return True
    except Exception as e:
        # Roll back the partial transaction on any failure.
        session.rollback()
        logger.error(f"Error deleting knowledge graph: {str(e)}")
        return False
def get_trace(session: Session, trace_id: str) -> Optional[models.Trace]:
    """
    Get a trace by its ID or filename.

    Tries, in order: the UUID trace_id column, the filename column, and
    finally the integer primary key.

    Args:
        session: Database session
        trace_id: Either a UUID trace_id, a filename, or a numeric ID

    Returns:
        Trace object or None if not found
    """
    for column in ('trace_id', 'filename'):
        match = session.query(models.Trace).filter_by(**{column: trace_id}).first()
        if match:
            return match
    # Last resort: interpret the identifier as the integer primary key.
    try:
        numeric_id = int(trace_id)
    except (ValueError, TypeError):
        return None
    return session.query(models.Trace).filter_by(id=numeric_id).first()
def save_trace(
    session: Session,
    content: str,
    filename: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    trace_type: Optional[str] = None,
    trace_source: str = "user_upload",
    uploader: Optional[str] = None,
    tags: Optional[List[str]] = None,
    trace_metadata: Optional[Dict[str, Any]] = None
) -> models.Trace:
    """
    Save a trace to the database, deduplicating by content hash.

    If a trace with identical content already exists, its descriptive
    fields are refreshed from the provided arguments instead of creating
    a duplicate row.

    Args:
        session: Database session
        content: The content of the trace
        filename: Optional filename
        title: Optional title
        description: Optional description
        trace_type: Optional type of trace
        trace_source: Source of the trace (default: "user_upload")
        uploader: Optional name of the uploader
        tags: Optional list of tags
        trace_metadata: Optional additional metadata

    Returns:
        The created or updated Trace object
    """
    # SHA-256 of the content is the deduplication key.
    content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()
    existing = session.query(models.Trace).filter_by(content_hash=content_hash).first()
    if existing is None:
        trace = models.Trace.from_content(
            content=content,
            filename=filename,
            title=title,
            description=description,
            trace_type=trace_type,
            trace_source=trace_source,
            uploader=uploader,
            tags=tags,
            trace_metadata=trace_metadata
        )
        session.add(trace)
        session.commit()
        logger.info(f"New trace saved to database (ID: {trace.id}, trace_id: {trace.trace_id})")
        return trace
    logger.info(f"Trace with matching content hash already exists (ID: {existing.id})")
    # Refresh only the descriptive fields the caller actually supplied.
    for attr, value in (
        ("filename", filename),
        ("title", title),
        ("description", description),
        ("trace_type", trace_type),
        ("uploader", uploader),
        ("tags", tags),
    ):
        if value:
            setattr(existing, attr, value)
    if trace_metadata:
        # Merge metadata rather than replace
        if existing.trace_metadata:
            existing.trace_metadata.update(trace_metadata)
        else:
            existing.trace_metadata = trace_metadata
    existing.update_timestamp = datetime.utcnow()
    session.add(existing)
    session.commit()
    return existing
def get_all_traces(session: Session) -> List[models.Trace]:
    """
    Get all traces from the database, newest upload first.

    Args:
        session: Database session

    Returns:
        List of Trace objects
    """
    query = session.query(models.Trace).order_by(models.Trace.upload_timestamp.desc())
    return query.all()
def get_traces_by_status(session: Session, status: str) -> List[models.Trace]:
    """
    Get traces by status, newest upload first.

    Args:
        session: Database session
        status: Status to filter by

    Returns:
        List of Trace objects with the specified status
    """
    query = (
        session.query(models.Trace)
        .filter_by(status=status)
        .order_by(models.Trace.upload_timestamp.desc())
    )
    return query.all()
def update_trace_status(session: Session, trace_id: str, status: str) -> models.Trace:
    """
    Update the status of a trace.

    Args:
        session: Database session
        trace_id: ID of the trace to update
        status: New status

    Returns:
        Updated Trace object

    Raises:
        ValueError: If no trace matches *trace_id*.
    """
    trace = get_trace(session, trace_id)
    if trace is None:
        raise ValueError(f"Trace with ID {trace_id} not found")
    trace.status = status
    trace.update_timestamp = datetime.utcnow()
    session.add(trace)
    session.commit()
    return trace
def update_trace_content(session: Session, trace_id: str, content: str) -> models.Trace:
    """
    Update the content of a trace and its derived counters.

    Args:
        session: Database session
        trace_id: ID of the trace to update
        content: New content value

    Returns:
        Updated Trace object

    Raises:
        ValueError: If no trace matches *trace_id*.
    """
    trace = get_trace(session, trace_id)
    if trace is None:
        raise ValueError(f"Trace with ID {trace_id} not found")
    trace.content = content
    trace.character_count = len(content)
    # Turn count = number of non-blank lines in the new content.
    trace.turn_count = sum(1 for line in content.split('\n') if line.strip())
    trace.update_timestamp = datetime.utcnow()
    session.add(trace)
    session.commit()
    return trace
def link_knowledge_graph_to_trace(
    session: Session,
    kg_id: Union[int, str],
    trace_id: str,
    window_index: Optional[int] = None,
    window_total: Optional[int] = None,
    window_start_char: Optional[int] = None,
    window_end_char: Optional[int] = None
) -> models.KnowledgeGraph:
    """
    Link a knowledge graph to a trace.

    Args:
        session: Database session
        kg_id: ID or filename of the knowledge graph
        trace_id: ID of the trace
        window_index: Optional index of the window within the trace
        window_total: Optional total number of windows
        window_start_char: Optional start position in the trace
        window_end_char: Optional end position in the trace

    Returns:
        Updated KnowledgeGraph object

    Raises:
        ValueError: If the knowledge graph or trace cannot be found.
    """
    kg = get_knowledge_graph_by_id(session, kg_id)
    if not kg:
        raise ValueError(f"Knowledge graph with ID {kg_id} not found")
    trace = get_trace(session, trace_id)
    if not trace:
        raise ValueError(f"Trace with ID {trace_id} not found")
    # Point the graph at the trace's canonical UUID.
    kg.trace_id = trace.trace_id
    # Window fields are optional; only overwrite those that were supplied.
    for attr, value in (
        ("window_index", window_index),
        ("window_total", window_total),
        ("window_start_char", window_start_char),
        ("window_end_char", window_end_char),
    ):
        if value is not None:
            setattr(kg, attr, value)
    # Mirror the link inside the graph's own metadata payload.
    graph_data = kg.graph_data or {}
    graph_data.setdefault("metadata", {})["trace_info"] = {
        "trace_id": trace.trace_id,
        "window_index": window_index,
        "window_total": window_total,
        "linked_at": datetime.utcnow().isoformat()
    }
    kg.graph_data = graph_data
    session.add(kg)
    session.commit()
    return kg
def get_knowledge_graphs_for_trace(session: Session, trace_id: str) -> List[models.KnowledgeGraph]:
    """
    Return every knowledge graph linked to the given trace.

    Args:
        session: Database session
        trace_id: ID of the trace

    Returns:
        KnowledgeGraph rows for the trace, ordered by window_index.

    Raises:
        ValueError: If no trace exists for trace_id.
    """
    trace = get_trace(session, trace_id)
    if not trace:
        raise ValueError(f"Trace with ID {trace_id} not found")
    query = session.query(models.KnowledgeGraph).filter_by(trace_id=trace.trace_id)
    return query.order_by(models.KnowledgeGraph.window_index).all()
def check_knowledge_graph_exists(session: Session, trace_id: str, is_original: bool = None) -> Optional[models.KnowledgeGraph]:
    """
    Look up a knowledge graph for a trace, optionally constrained by origin.

    Args:
        session: Database session
        trace_id: ID of the trace
        is_original: True  -> only KGs whose status is 'created' (originals)
                     False -> only KGs with any other status
                     None  -> any KG for the trace

    Returns:
        The first matching KnowledgeGraph, or None when nothing matches.
    """
    query = session.query(models.KnowledgeGraph).filter_by(trace_id=trace_id)
    # Identity comparisons are deliberate: only the exact booleans narrow the
    # query; any other value (including None) leaves it unfiltered.
    if is_original is True:
        query = query.filter_by(status='created')
    if is_original is False:
        query = query.filter(models.KnowledgeGraph.status != 'created')
    return query.first()
def delete_trace(session: Session, trace_id: str, delete_related_kgs: bool = False) -> bool:
    """
    Remove a trace, either deleting or unlinking its knowledge graphs.

    Args:
        session: Database session
        trace_id: ID of the trace to delete
        delete_related_kgs: When True, linked knowledge graphs are deleted
            as well; otherwise they are kept but detached from the trace.

    Returns:
        True when the trace was found and removed, False otherwise.
    """
    trace = get_trace(session, trace_id)
    if not trace:
        return False
    for kg in trace.knowledge_graphs:
        if delete_related_kgs:
            session.delete(kg)
        else:
            # Keep the KG around but drop its association with this trace.
            kg.trace_id = None
            session.add(kg)
    session.delete(trace)
    session.commit()
    return True
def get_prompt_reconstructions_for_kg(session, kg_identifier):
    """
    Collect every prompt reconstruction belonging to a knowledge graph.

    The graph may be identified either by its integer primary key or by
    its filename.

    Returns:
        Dict mapping relation_id -> reconstructed_prompt; empty when the
        knowledge graph does not exist.
    """
    from backend.database.models import KnowledgeGraph, PromptReconstruction
    # Integer identifiers address the primary key, anything else the filename.
    lookup = {"id": kg_identifier} if isinstance(kg_identifier, int) else {"filename": kg_identifier}
    kg = session.query(KnowledgeGraph).filter_by(**lookup).first()
    if kg is None:
        return {}
    rows = session.query(PromptReconstruction).filter_by(knowledge_graph_id=kg.id).all()
    return {row.relation_id: row.reconstructed_prompt for row in rows}
def get_prompt_reconstruction_for_relation(session, kg_identifier, relation_id):
    """
    Fetch one reconstructed prompt for a knowledge graph and relation.

    The graph may be identified either by its integer primary key or by
    its filename.

    Returns:
        The reconstructed_prompt string, or None when either the graph or
        the reconstruction is missing.
    """
    from backend.database.models import KnowledgeGraph, PromptReconstruction
    lookup = {"id": kg_identifier} if isinstance(kg_identifier, int) else {"filename": kg_identifier}
    kg = session.query(KnowledgeGraph).filter_by(**lookup).first()
    if kg is None:
        return None
    row = session.query(PromptReconstruction).filter_by(
        knowledge_graph_id=kg.id, relation_id=relation_id
    ).first()
    return None if row is None else row.reconstructed_prompt
def save_causal_analysis(
    session: Session,
    knowledge_graph_id: int,
    perturbation_set_id: str,
    analysis_method: str,
    analysis_result: dict = None,
    causal_score: float = None,
    analysis_metadata: dict = None
):
    """Persist a causal analysis row and return the refreshed instance."""
    from backend.database import models
    record = models.CausalAnalysis(
        knowledge_graph_id=knowledge_graph_id,
        perturbation_set_id=perturbation_set_id,
        analysis_method=analysis_method,
        analysis_result=analysis_result,
        causal_score=causal_score,
        analysis_metadata=analysis_metadata,
    )
    session.add(record)
    session.commit()
    # Refresh so DB-generated columns (e.g. the primary key) are populated.
    session.refresh(record)
    return record
def get_causal_analysis(
    session: Session,
    knowledge_graph_id: int,
    perturbation_set_id: str,
    analysis_method: str
) -> Optional[models.CausalAnalysis]:
    """
    Look up a single causal analysis result.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph
        perturbation_set_id: ID of the perturbation set
        analysis_method: Method used for analysis

    Returns:
        The matching CausalAnalysis, or None when absent.
    """
    criteria = {
        "knowledge_graph_id": knowledge_graph_id,
        "perturbation_set_id": perturbation_set_id,
        "analysis_method": analysis_method,
    }
    return session.query(models.CausalAnalysis).filter_by(**criteria).first()
def get_all_causal_analyses(
    session: Session,
    knowledge_graph_id: Optional[int] = None,
    perturbation_set_id: Optional[str] = None,
    analysis_method: Optional[str] = None
) -> List[models.CausalAnalysis]:
    """
    List causal analysis results, narrowed by any combination of filters.

    Args:
        session: Database session
        knowledge_graph_id: Optional filter by knowledge graph ID
        perturbation_set_id: Optional filter by perturbation set ID
        analysis_method: Optional filter by analysis method

    Returns:
        All CausalAnalysis rows matching every provided (non-None) filter.
    """
    query = session.query(models.CausalAnalysis)
    optional_filters = {
        "knowledge_graph_id": knowledge_graph_id,
        "perturbation_set_id": perturbation_set_id,
        "analysis_method": analysis_method,
    }
    for column, value in optional_filters.items():
        if value is not None:
            query = query.filter_by(**{column: value})
    return query.all()
def get_causal_analysis_for_perturbation(session: Session, perturbation_set_id: str) -> List[Dict[str, Any]]:
    """
    Fetch all causal analyses for a perturbation set, joined with their
    knowledge graph and prompt reconstruction.

    Args:
        session: Database session
        perturbation_set_id: ID of the perturbation set

    Returns:
        List of dicts with 'analysis', 'knowledge_graph', and
        'prompt_reconstruction' entries.
    """
    from backend.database.models import CausalAnalysis, KnowledgeGraph, PromptReconstruction
    rows = (
        session.query(CausalAnalysis, KnowledgeGraph, PromptReconstruction)
        .join(KnowledgeGraph, CausalAnalysis.knowledge_graph_id == KnowledgeGraph.id)
        .join(PromptReconstruction, CausalAnalysis.prompt_reconstruction_id == PromptReconstruction.id)
        .filter(CausalAnalysis.perturbation_set_id == perturbation_set_id)
        .all()
    )
    output = []
    for analysis, kg, pr in rows:
        output.append({
            'analysis': analysis.to_dict(),
            'knowledge_graph': kg.to_dict(),
            'prompt_reconstruction': {
                'id': pr.id,
                'relation_id': pr.relation_id,
                'reconstructed_prompt': pr.reconstructed_prompt,
                'dependencies': pr.dependencies
            }
        })
    return output
def get_causal_analysis_by_method(session: Session, knowledge_graph_id: int, method: str) -> List[Dict[str, Any]]:
    """
    Fetch causal analyses for a knowledge graph filtered by analysis method,
    each paired with its originating perturbation test.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph
        method: Analysis method (e.g., 'graph', 'component', 'dowhy')

    Returns:
        List of dicts with 'analysis' and 'perturbation_test' entries.
    """
    from backend.database.models import CausalAnalysis, PerturbationTest
    rows = (
        session.query(CausalAnalysis, PerturbationTest)
        .join(PerturbationTest, CausalAnalysis.perturbation_test_id == PerturbationTest.id)
        .filter(
            CausalAnalysis.knowledge_graph_id == knowledge_graph_id,
            CausalAnalysis.analysis_method == method,
        )
        .all()
    )
    output = []
    for analysis, pt in rows:
        output.append({
            'analysis': analysis.to_dict(),
            'perturbation_test': {
                'id': pt.id,
                'perturbation_type': pt.perturbation_type,
                'test_result': pt.test_result,
                'perturbation_score': pt.perturbation_score
            }
        })
    return output
def get_causal_analysis_summary(session: Session, knowledge_graph_id: int) -> Dict[str, Any]:
    """
    Get a summary of causal analysis results for a knowledge graph.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph

    Returns:
        Dictionary with:
            total_analyses: number of stored analyses for the graph
            methods: per-method dict of count / average_score / min_score /
                max_score (score stats are None when no scores are recorded)
            average_scores: mapping of method -> average causal_score
                (None when that method has no recorded scores)
    """
    from backend.database.models import CausalAnalysis
    # Get all analyses for this knowledge graph
    analyses = session.query(CausalAnalysis).filter_by(
        knowledge_graph_id=knowledge_graph_id
    ).all()
    if not analyses:
        return {
            'total_analyses': 0,
            'methods': {},
            'average_scores': {}
        }
    # Group by method
    method_results = {}
    for analysis in analyses:
        method_results.setdefault(analysis.analysis_method, []).append(analysis)
    # Calculate statistics
    summary = {
        'total_analyses': len(analyses),
        'methods': {},
        'average_scores': {}
    }
    for method, results in method_results.items():
        scores = [r.causal_score for r in results if r.causal_score is not None]
        average = sum(scores) / len(scores) if scores else None
        summary['methods'][method] = {
            'count': len(results),
            'average_score': average,
            'min_score': min(scores) if scores else None,
            'max_score': max(scores) if scores else None
        }
        # Bug fix: 'average_scores' was declared in the result shape (and in
        # the docstring) but never populated.
        summary['average_scores'][method] = average
    return summary
# ---- Perturbation test helpers ----
def save_perturbation_test(session,
                          knowledge_graph_id: int,
                          prompt_reconstruction_id: int,
                          relation_id: str,
                          perturbation_type: str,
                          perturbation_set_id: str,
                          test_result: dict = None,
                          perturbation_score: float = None,
                          test_metadata: dict = None) -> int:
    """
    Persist a single perturbation test and return its new primary key.

    Args:
        session: Database session
        knowledge_graph_id: ID of the knowledge graph
        prompt_reconstruction_id: ID of the prompt reconstruction
        relation_id: ID of the relation
        perturbation_type: Type of perturbation
        perturbation_set_id: ID of the perturbation set
        test_result: Test result dictionary (defaults to empty dict)
        perturbation_score: Perturbation score
        test_metadata: Test metadata dictionary (defaults to empty dict)

    Returns:
        int: ID of the saved perturbation test
    """
    from backend.database.models import PerturbationTest
    record = PerturbationTest(
        knowledge_graph_id=knowledge_graph_id,
        prompt_reconstruction_id=prompt_reconstruction_id,
        relation_id=relation_id,
        perturbation_type=perturbation_type,
        perturbation_set_id=perturbation_set_id,
        test_result=test_result or {},
        perturbation_score=perturbation_score,
        test_metadata=test_metadata or {},
    )
    session.add(record)
    session.commit()
    return record.id
def delete_perturbation_test(session, test_id: int) -> bool:
    """
    Remove a perturbation test by primary key.

    Args:
        session: Database session
        test_id: ID of the perturbation test to delete

    Returns:
        bool: True when a row was found and deleted, False otherwise.
    """
    from backend.database.models import PerturbationTest
    record = session.query(PerturbationTest).filter_by(id=test_id).first()
    if record is None:
        return False
    session.delete(record)
    session.commit()
    return True
def delete_perturbation_tests_by_set(session, perturbation_set_id: str) -> int:
    """
    Remove every perturbation test belonging to one perturbation set.

    Args:
        session: Database session
        perturbation_set_id: ID of the perturbation set

    Returns:
        int: Number of tests deleted.
    """
    from backend.database.models import PerturbationTest
    matching = session.query(PerturbationTest).filter_by(
        perturbation_set_id=perturbation_set_id
    ).all()
    # Delete row-by-row (rather than a bulk query delete) so ORM-level
    # cascade/event behavior is preserved.
    for record in matching:
        session.delete(record)
    session.commit()
    return len(matching)
def get_context_document_stats(session: Session, trace_id: str) -> Dict[str, Any]:
    """
    Summarize the context documents stored in a trace's metadata.

    Args:
        session: Database session
        trace_id: Trace ID to get stats for

    Returns:
        Dict with 'total_count', 'active_count' (documents whose
        'is_active' is not explicitly False), and 'by_type' counts keyed by
        'document_type' ('unknown' when missing).
    """
    trace = get_trace(session, trace_id)
    metadata = trace.trace_metadata if trace else None
    if not metadata or "context_documents" not in metadata:
        return {"total_count": 0, "active_count": 0, "by_type": {}}
    docs = metadata["context_documents"]
    by_type = {}
    active_count = 0
    for doc in docs:
        kind = doc.get("document_type", "unknown")
        by_type[kind] = by_type.get(kind, 0) + 1
        if doc.get("is_active", True):
            active_count += 1
    return {
        "total_count": len(docs),
        "active_count": active_count,
        "by_type": by_type
    }
def get_context_documents_from_trace(session: Session, trace_id: str) -> List[Dict[str, Any]]:
    """
    Return the active context documents recorded in a trace's metadata.

    Args:
        session: Database session
        trace_id: ID of the trace

    Returns:
        Documents whose 'is_active' flag is not explicitly False; empty
        list when the trace or its context documents are missing.
    """
    trace = get_trace(session, trace_id)
    if not trace:
        return []
    metadata = trace.trace_metadata or {}
    if "context_documents" not in metadata:
        return []
    return [doc for doc in metadata["context_documents"] if doc.get("is_active", True)]
def get_temporal_windows_by_trace_id(session: Session, trace_id: str, processing_run_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Get all knowledge graph windows for a specific trace, ordered by window_index.
    Also returns the full/merged version if available.
    Used for temporal force-directed graph visualization.
    This function handles cases where KGs exist with trace_id but no trace record exists.

    NOTE: despite being a getter, this function may WRITE to the database:
    it can promote one KG to the full/merged role, back-fill missing
    window_index values, and repair zero/None entity/relation counts.

    Args:
        session: Database session
        trace_id: ID of the trace
        processing_run_id: Optional ID to filter by specific processing run
    Returns:
        Dict with keys:
            "windows": list of windowed KnowledgeGraph rows sorted by window_index
            "full_kg": the merged KnowledgeGraph row, or None
            "trace_info": dict of trace metadata (synthesized from the KGs
                when no trace record exists), or None when no KGs are found
    """
    logger.info(f"Looking up temporal windows for trace_id: {trace_id}")
    if processing_run_id:
        logger.info(f"Filtering by processing_run_id: {processing_run_id}")
    # First check if we can find the trace
    trace = get_trace(session, trace_id)
    # Get all knowledge graphs for this trace_id (even if trace record doesn't exist)
    query = session.query(models.KnowledgeGraph).filter(models.KnowledgeGraph.trace_id == trace_id)
    # Filter by processing_run_id if provided
    if processing_run_id:
        query = query.filter(models.KnowledgeGraph.processing_run_id == processing_run_id)
    all_kgs = query.all()
    logger.info(f"Found {len(all_kgs)} total knowledge graphs for trace_id {trace_id}")
    if processing_run_id:
        logger.info(f"(filtered by processing_run_id: {processing_run_id})")
    if not all_kgs:
        logger.warning(f"No knowledge graphs found for trace_id: {trace_id}")
        return {"windows": [], "full_kg": None, "trace_info": None}
    # If no trace record exists, create minimal trace info from KGs
    if not trace:
        logger.warning(f"No trace record found for trace_id {trace_id}, but {len(all_kgs)} KGs exist")
        trace_info = {
            "trace_id": trace_id,
            "title": f"Trace {trace_id[:8]}...",
            "description": "Knowledge graphs exist but no trace record found",
            # NOTE(review): min() raises ValueError if every KG has a None
            # creation_timestamp — confirm that cannot happen here.
            "upload_timestamp": min([kg.creation_timestamp for kg in all_kgs if kg.creation_timestamp])
        }
    else:
        logger.info(f"Found trace record {trace.trace_id}")
        trace_info = {
            "trace_id": trace.trace_id,
            "title": trace.title,
            "description": trace.description,
            "upload_timestamp": trace.upload_timestamp
        }
    # Separate windowed KGs from the full/merged KG
    windowed_kgs = []
    full_kg = None
    for kg in all_kgs:
        # Full/merged KG: has window_total but no window_index, null start/end chars
        if (kg.window_total is not None and
            kg.window_index is None and
            kg.window_start_char is None and
            kg.window_end_char is None):
            full_kg = kg
            logger.info(f"Found full/merged KG: {kg.filename} with window_total={kg.window_total}")
        # Windowed KG: has window_index
        elif kg.window_index is not None:
            windowed_kgs.append(kg)
        # KG without proper window info - try to assign window_index
        else:
            logger.info(f"Found KG without proper window info: {kg.filename}")
            # If we don't have a full KG yet and this looks like it could be one
            if (full_kg is None and
                kg.window_total is None and
                kg.window_start_char is None and
                kg.window_end_char is None):
                # Check if this KG has significantly more entities than others (indicating it's merged)
                # NOTE(review): the ">10 entities" heuristic is a guess at
                # "merged"; it may mislabel a large single window.
                if kg.graph_data and len(kg.graph_data.get("entities", [])) > 10:
                    kg.window_total = len(windowed_kgs) + 1  # Set based on current windowed KGs
                    full_kg = kg
                    # Added to the session here; persisted only if one of the
                    # commits further down runs (no commit on this path alone).
                    session.add(kg)
                    logger.info(f"Assigned {kg.filename} as full KG with window_total={kg.window_total}")
    # If we have windowed KGs but some are missing window_index, assign them
    kgs_without_index = [kg for kg in all_kgs if kg.window_index is None and kg != full_kg]
    if kgs_without_index and windowed_kgs:
        logger.info("Assigning window_index to knowledge graphs based on creation order")
        # Sort by creation timestamp and assign window_index starting from the highest existing + 1
        max_window_index = max([kg.window_index for kg in windowed_kgs], default=-1)
        # KGs lacking a creation_timestamp sort by the current time, i.e. last.
        kgs_without_index.sort(key=lambda kg: kg.creation_timestamp or datetime.utcnow())
        for i, kg in enumerate(kgs_without_index):
            kg.window_index = max_window_index + 1 + i
            session.add(kg)
            windowed_kgs.append(kg)
        session.commit()
        logger.info(f"Assigned window_index to {len(kgs_without_index)} knowledge graphs")
    # Sort windowed KGs by window_index
    windowed_kgs.sort(key=lambda kg: kg.window_index)
    logger.info(f"Found {len(windowed_kgs)} windowed KGs and {'1' if full_kg else '0'} full KG for trace_id {trace_id}")
    # Update entity_count and relation_count if they're 0 or None but graph_data has content
    updated_count = 0
    for kg in windowed_kgs + ([full_kg] if full_kg else []):
        if kg and kg.graph_data:
            needs_update = False
            if kg.entity_count is None or kg.entity_count == 0:
                entities = kg.graph_data.get("entities", [])
                if entities:
                    kg.entity_count = len(entities)
                    needs_update = True
            if kg.relation_count is None or kg.relation_count == 0:
                relations = kg.graph_data.get("relations", [])
                if relations:
                    kg.relation_count = len(relations)
                    needs_update = True
            if needs_update:
                session.add(kg)
                updated_count += 1
    if updated_count > 0:
        session.commit()
        logger.info(f"Updated entity/relation counts for {updated_count} knowledge graphs")
    return {
        "windows": windowed_kgs,
        "full_kg": full_kg,
        "trace_info": trace_info
    }