Spaces:
Running
Running
| """ | |
| Router for knowledge graph endpoints | |
| """ | |
| from fastapi import APIRouter, Depends, HTTPException, status, Path, Query, BackgroundTasks, Response, Request | |
| from sqlalchemy.orm import Session | |
| from fastapi.responses import FileResponse, JSONResponse, StreamingResponse | |
| from typing import List, Dict, Any, Optional | |
| from pydantic import BaseModel | |
| import logging | |
| import os | |
| import json | |
| import tempfile | |
| import time | |
| from datetime import datetime, timezone | |
| from sqlalchemy import text | |
| import shutil | |
| import traceback | |
| import sys | |
| import uuid | |
| import urllib.parse | |
| import math | |
| # Add the project root to the Python path for proper imports | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) | |
| from backend.dependencies import get_db_session | |
| from backend.services import KnowledgeGraphService | |
| from backend.models import KnowledgeGraphResponse, PlatformStatsResponse | |
| from backend.database import get_db | |
| from backend.database.utils import get_knowledge_graph, save_knowledge_graph, save_test_result, update_knowledge_graph_status, delete_knowledge_graph | |
| from backend.database import models | |
| from backend.services.knowledge_graph_service import KnowledgeGraphService | |
| from backend.database.utils import get_knowledge_graph_by_id | |
| from backend.services.reconstruction_service import enrich_knowledge_graph_task | |
| from backend.services.testing_service import perturb_knowledge_graph_task | |
| from backend.services.causal_service import analyze_causal_relationships_task | |
| from backend.services.task_service import create_task | |
| from backend.database.models import PromptReconstruction, PerturbationTest, CausalAnalysis | |
| router = APIRouter(prefix="/api", tags=["knowledge_graphs"]) | |
| logger = logging.getLogger(__name__) | |
async def get_knowledge_graphs(db: Session = Depends(get_db_session)):
    """
    Get all available knowledge graphs.

    Args:
        db: Database session injected by FastAPI.

    Returns:
        A dict of the form ``{"files": ...}`` wrapping whatever
        ``KnowledgeGraphService.get_all_graphs`` returns.

    Raises:
        HTTPException: 500 if the service call fails for any reason.
    """
    try:
        files = KnowledgeGraphService.get_all_graphs(db)
        return {"files": files}
    except Exception as e:
        # Keep the cause in the server log; the client only sees a generic message.
        logger.error(f"Error fetching knowledge graphs: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while fetching knowledge graphs"
        ) from e
async def get_latest_knowledge_graph(db: Session = Depends(get_db_session)):
    """
    Get the latest knowledge graph from the database.

    Args:
        db: Database session injected by FastAPI.

    Returns:
        A summary dict with ``id``, ``filename``, ``status`` and ISO-formatted
        ``creation_timestamp`` / ``update_timestamp`` (``None`` when unset).

    Raises:
        HTTPException: 404 if no knowledge graph exists, 500 on any other error.
    """
    try:
        kg = KnowledgeGraphService.get_latest_graph(db)
        if not kg:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No knowledge graph found"
            )
        # Only the identifying fields and status are returned, not the graph data.
        return {
            "id": kg.id,
            "filename": kg.filename,
            "status": kg.status,
            "creation_timestamp": kg.creation_timestamp.isoformat() if kg.creation_timestamp else None,
            "update_timestamp": kg.update_timestamp.isoformat() if kg.update_timestamp else None
        }
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through untouched.
        raise
    except Exception as e:
        # Keep the cause in the server log; the client only sees a generic message.
        logger.error(f"Error retrieving latest knowledge graph: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while retrieving knowledge graph"
        ) from e
async def download_latest_knowledge_graph(db: Session = Depends(get_db_session)):
    """
    Download the latest knowledge graph from the database.

    Args:
        db: Database session injected by FastAPI.

    Returns:
        The ``graph_data`` attribute of the latest knowledge graph.

    Raises:
        HTTPException: 404 if no knowledge graph exists, 500 on any other error.
    """
    try:
        kg = KnowledgeGraphService.get_latest_graph(db)
        if not kg:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No knowledge graph found"
            )
        return kg.graph_data
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through untouched.
        raise
    except Exception as e:
        # Keep the cause in the server log; the client only sees a generic message.
        logger.error(f"Error downloading latest knowledge graph: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while retrieving knowledge graph"
        ) from e
async def get_knowledge_graph(graph_id: str, db: Session = Depends(get_db_session)):
    """
    Get a specific knowledge graph by ID or filename.

    Args:
        graph_id: Identifier passed straight to
            ``KnowledgeGraphService.get_graph_by_id``.
        db: Database session injected by FastAPI.

    Returns:
        The graph data from the service. If the service raises
        ``FileNotFoundError`` for anything other than ``"latest"``, an empty
        placeholder structure (no entities, no relations, error metadata) is
        returned instead, so the frontend still has something to render.

    Raises:
        HTTPException: 404 when ``graph_id`` is ``"latest"`` (case-insensitive)
            and nothing is found; 500 on any other error.
    """
    try:
        # Graph data comes from the database only.
        return KnowledgeGraphService.get_graph_by_id(db, graph_id)
    except FileNotFoundError:
        # "latest" should normally be served by a dedicated route; if it
        # lands here anyway, a real 404 is the right answer.
        if graph_id.lower() == "latest":
            logger.warning("No latest knowledge graph found")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="No latest knowledge graph found"
            )
        logger.warning(f"Knowledge graph not found: {graph_id} - creating default structure")
        # For named graphs, return an empty structure instead of a 404.
        return {
            "entities": [],
            "relations": [],
            "metadata": {
                "filename": graph_id,
                "error": "Knowledge graph not found in database",
                # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated.
                "created": datetime.now(timezone.utc).isoformat()
            }
        }
    except Exception as e:
        logger.error(f"Database error fetching graph {graph_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while fetching knowledge graph"
        ) from e
async def get_platform_stats(db: Session = Depends(get_db_session)):
    """
    Get platform-wide statistics.

    Args:
        db: Database session injected by FastAPI.

    Returns:
        Whatever ``KnowledgeGraphService.get_platform_stats`` returns.

    Raises:
        HTTPException: 500 if the service call fails for any reason.
    """
    try:
        return KnowledgeGraphService.get_platform_stats(db)
    except Exception as e:
        # Keep the cause in the server log; the client only sees a generic message.
        logger.error(f"Error fetching platform statistics: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while fetching platform statistics"
        ) from e
async def get_entity_relation_data(db: Session = Depends(get_db_session)):
    """
    Get entity-relation data optimized for force-directed graph visualization.

    Builds a node/link structure from the ``entities`` and ``relations``
    tables. When fewer than five nodes (or links) come out of the database,
    the platform statistics distributions are used to add entity-type nodes
    and randomly placed links. If there are still fewer links than
    ``len(nodes) - 1``, weak ``connected_to`` links are added from one hub
    node to every node not yet marked as connected.

    Every data source is fetched best-effort: a failure is logged and treated
    as "no data" rather than failing the request.

    Args:
        db: Database session injected by FastAPI.

    Returns:
        A dict with ``nodes``, ``links`` and ``metadata`` keys. On an
        unexpected error an empty structure with the same keys is returned,
        with ``metadata["is_real_data"]`` set to False and the error text in
        ``metadata["error"]``.
    """
    import random
    from types import SimpleNamespace

    try:
        # Platform stats supply totals and the entity/relation distributions.
        try:
            stats = KnowledgeGraphService.get_platform_stats(db)
            logger.info(f"Successfully fetched platform stats: entities={getattr(stats, 'total_entities', 0)}, relations={getattr(stats, 'total_relations', 0)}")
        except Exception as e:
            logger.warning(f"Error fetching platform stats: {str(e)}")
            # Minimal stand-in so the getattr() calls below keep working.
            stats = SimpleNamespace(
                total_entities=0,
                total_relations=0,
                entity_distribution={},
                relation_distribution={}
            )

        # NOTE(review): the latest graph is fetched but never used below;
        # the call is kept in case it has side effects - confirm and remove.
        latest_kg = None
        try:
            latest_kg = KnowledgeGraphService.get_latest_graph(db)
        except Exception as e:
            logger.warning(f"Error fetching latest knowledge graph: {str(e)}")

        # All entities, together with the filename of the graph they belong to.
        all_entities = []
        try:
            query = text("""
                SELECT e.entity_id, e.type, e.name, kg.filename
                FROM entities e
                LEFT JOIN knowledge_graphs kg ON e.graph_id = kg.id
            """)
            result = db.execute(query)
            all_entities = [{"id": row[0], "type": row[1], "name": row[2], "graph_source": row[3]} for row in result]
            logger.info(f"Successfully fetched {len(all_entities)} entities from database")
        except Exception as e:
            logger.warning(f"Error fetching entities from database: {str(e)}")

        # Keyword lists used to assign an entity to a role cluster by name.
        role_clusters = {
            'governance': ['architect', 'legal', 'political', 'diplomat', 'negotiator', 'ethicist', 'governance'],
            'technical': ['engineer', 'empiricist', 'simulation', 'crisis', 'technical', 'system'],
            'research': ['historian', 'researcher', 'analyst', 'research'],
            'operations': ['executor', 'manager', 'operator', 'operations']
        }

        def determine_cluster(entity):
            """Return the cluster for an entity: role keyword match on the name first, then type."""
            entity_type = (entity.get("type") or "").lower()
            name = (entity.get("name") or "").lower()
            # The first role cluster with a keyword contained in the name wins.
            for cluster, keywords in role_clusters.items():
                if any(keyword in name for keyword in keywords):
                    return cluster
            # Otherwise fall back to the entity type, or "other".
            if entity_type in ("agent", "tool", "task"):
                return entity_type
            return "other"

        nodes = []
        links = []
        node_id = 0
        node_id_map = {}  # entity name / entity type -> node id
        entity_map = {}   # database entity_id -> node id

        # One node per database entity that has both a name and a type.
        for entity in all_entities:
            if not entity.get("name") or not entity.get("type"):
                continue
            cluster = determine_cluster(entity)
            node_id_str = f"entity-{node_id}"
            entity_map[entity.get("id")] = node_id_str
            node_id_map[entity.get("name")] = node_id_str
            nodes.append({
                "id": node_id_str,
                "name": entity.get("name"),
                "type": entity.get("type"),
                "cluster": cluster,
                "description": f"{entity.get('name')} ({entity.get('type')})",
                "importance": 1.0 if "architect" in entity.get("name", "").lower() else 0.8,
                "graph_source": entity.get("graph_source")
            })
            node_id += 1

        # With fewer than five entity nodes, add one node per entity type
        # from the stats distribution.
        if len(nodes) < 5:
            entity_distribution = getattr(stats, 'entity_distribution', {}) or {}
            if not entity_distribution:
                logger.warning("No entity distribution found, no entity type nodes will be created")
            else:
                logger.info(f"Using entity distribution data ({len(entity_distribution)} types) to supplement entity nodes")
                for entity_type, count in entity_distribution.items():
                    if not entity_type:
                        continue
                    lowered_type = entity_type.lower()
                    if "agent" in lowered_type:
                        cluster = "agent"
                    elif "tool" in lowered_type:
                        cluster = "tool"
                    elif "task" in lowered_type:
                        cluster = "task"
                    else:
                        cluster = "other"
                    node_id_str = f"entity-{node_id}"
                    node_id_map[entity_type] = node_id_str
                    nodes.append({
                        "id": node_id_str,
                        "name": entity_type,
                        "type": "EntityType",
                        "cluster": cluster,
                        "count": count,
                        "description": f"{entity_type} entities ({count})"
                    })
                    node_id += 1

        # All relations, with source/target expressed as entity_id values.
        all_relations = []
        try:
            query = text("""
                SELECT r.relation_id, r.type, e1.entity_id as source, e2.entity_id as target, kg.filename
                FROM relations r
                JOIN entities e1 ON r.source_id = e1.id
                JOIN entities e2 ON r.target_id = e2.id
                LEFT JOIN knowledge_graphs kg ON r.graph_id = kg.id
            """)
            result = db.execute(query)
            all_relations = [{"id": row[0], "type": row[1], "source": row[2], "target": row[3], "graph_source": row[4]} for row in result]
            logger.info(f"Successfully fetched {len(all_relations)} relations from database")
        except Exception as e:
            logger.warning(f"Error fetching relations from database: {str(e)}")

        # Link weight by relation type; any other type gets weight 1.
        link_values = {"PERFORMS": 2, "USES": 1.8, "ASSIGNED_TO": 1.5}
        for relation in all_relations:
            if not relation.get("source") or not relation.get("target"):
                continue
            source_id = entity_map.get(relation.get("source"))
            target_id = entity_map.get(relation.get("target"))
            # Skip relations whose endpoints did not become nodes.
            if not source_id or not target_id:
                continue
            links.append({
                "source": source_id,
                "target": target_id,
                "type": relation.get("type", "RELATED"),
                "value": link_values.get(relation.get("type"), 1),
                "graph_source": relation.get("graph_source")
            })

        # With fewer than five links, add randomly placed links per
        # relation type from the stats distribution.
        if len(links) < 5:
            relation_distribution = getattr(stats, 'relation_distribution', {}) or {}
            if relation_distribution:
                logger.info(f"Using relation distribution data ({len(relation_distribution)} types) to supplement relation links")
                # nodes does not change inside this loop, so filter it once.
                entity_nodes = [n for n in nodes if n["type"] != "RelationType"]
                for relation_type, count in relation_distribution.items():
                    if not relation_type or count < 1:
                        continue
                    # At least two nodes are needed to form a link.
                    if len(entity_nodes) < 2:
                        continue
                    # Up to 5 connections for this relation type.
                    conn_count = min(5, count, len(entity_nodes) // 2)
                    for _ in range(conn_count):
                        source_node = random.choice(entity_nodes)
                        target_nodes = [n for n in entity_nodes if n["id"] != source_node["id"]]
                        if not target_nodes:
                            continue
                        target_node = random.choice(target_nodes)
                        links.append({
                            "source": source_node["id"],
                            "target": target_node["id"],
                            "type": relation_type,
                            "value": 1
                        })

        # Fewer links than a spanning tree needs: attach every node not yet
        # marked connected to a single hub with a weak link. Only the first
        # link's endpoints (or the first node) start out as connected, so
        # nodes that already appear in other links get an extra link too.
        if nodes and (len(links) < len(nodes) - 1):
            logger.info("Ensuring all nodes are connected by adding minimum spanning links")
            if links:
                hub_id = links[0]["source"]
                connected_nodes = {links[0]["source"], links[0]["target"]}
            else:
                hub_id = nodes[0]["id"]
                connected_nodes = {hub_id}
            for node in nodes:
                if node["id"] in connected_nodes:
                    continue
                links.append({
                    "source": hub_id,
                    "target": node["id"],
                    "type": "connected_to",
                    "value": 0.5  # Weaker connection
                })
                connected_nodes.add(node["id"])

        logger.info(f"Returning entity-relation data with {len(nodes)} nodes and {len(links)} links")
        return {
            "nodes": nodes,
            "links": links,
            "metadata": {
                "total_entities": getattr(stats, 'total_entities', 0),
                "total_relations": getattr(stats, 'total_relations', 0),
                "entity_types": len({n.get("type") for n in nodes}),
                "relation_types": len({l.get("type") for l in links}),
                "clusters": list({n.get("cluster") for n in nodes if n.get("cluster")}),
                "is_real_data": True  # Always using real data
            }
        }
    except Exception as e:
        logger.error(f"Error generating entity-relation data: {str(e)}", exc_info=True)
        # Empty but valid structure, with the same metadata keys as the
        # success path so clients can read them unconditionally.
        return {
            "nodes": [],
            "links": [],
            "metadata": {
                "total_entities": 0,
                "total_relations": 0,
                "entity_types": 0,
                "relation_types": 0,
                "clusters": [],
                "is_real_data": False,
                "error": str(e)
            }
        }
def get_knowledge_graph(
    kg_id: int,
    db: Session = Depends(get_db_session),
):
    """
    Get a specific knowledge graph by ID.

    NOTE(review): this name is also bound earlier in this file, by an import
    from ``backend.database.utils`` and by an async handler that takes
    ``graph_id``; this definition rebinds the module-level name.

    Args:
        kg_id: Integer ID of the knowledge graph.
        db: Database session injected by FastAPI.

    Returns:
        A summary dict with ``id``, ``filename``, ``status`` and ISO-formatted
        ``creation_timestamp`` / ``update_timestamp`` (``None`` when unset).

    Raises:
        HTTPException: 404 if the graph does not exist, 500 on any other error.
    """
    try:
        kg = KnowledgeGraphService.get_graph_model_by_id(db, kg_id)
        if not kg:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Knowledge graph with ID {kg_id} not found"
            )
        # Only the identifying fields and status are returned, not the graph data.
        return {
            "id": kg.id,
            "filename": kg.filename,
            "status": kg.status,
            "creation_timestamp": kg.creation_timestamp.isoformat() if kg.creation_timestamp else None,
            "update_timestamp": kg.update_timestamp.isoformat() if kg.update_timestamp else None
        }
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through untouched.
        raise
    except Exception as e:
        # Keep the cause in the server log; the client only sees a generic message.
        logger.error(f"Error retrieving knowledge graph {kg_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while retrieving knowledge graph"
        ) from e
def download_knowledge_graph(
    kg_id: int,
    db: Session = Depends(get_db_session),
):
    """
    Download a specific knowledge graph by ID.

    Args:
        kg_id: Integer ID of the knowledge graph.
        db: Database session injected by FastAPI.

    Returns:
        The ``graph_data`` attribute of the matching knowledge graph.

    Raises:
        HTTPException: 404 if the graph does not exist, 500 on any other error.
    """
    try:
        kg = KnowledgeGraphService.get_graph_model_by_id(db, kg_id)
        if not kg:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Knowledge graph with ID {kg_id} not found"
            )
        return kg.graph_data
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through untouched.
        raise
    except Exception as e:
        # Keep the cause in the server log; the client only sees a generic message.
        logger.error(f"Error downloading knowledge graph {kg_id}: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while retrieving knowledge graph"
        ) from e
| # ========================================================== | |
| # NOTE: The stage processing endpoints have been moved to | |
| # server/routers/stage_processor.py for better consistency | |
| # and easier maintenance: | |
| # | |
| # - /knowledge-graphs/{kg_id}/enrich (Prompt Reconstruction) | |
| # - /knowledge-graphs/{kg_id}/perturb (Perturbation Testing) | |
| # - /knowledge-graphs/{kg_id}/analyze (Causal Analysis) | |
| # - /knowledge-graphs/{kg_id}/advance-stage (Chain processing) | |
| # | |
| # Please use the endpoints in stage_processor.py instead. | |
| # ========================================================== | |
async def download_knowledge_graph_by_id_or_filename(
    graph_id: str,
    db: Session = Depends(get_db_session),
):
    """
    Download a knowledge graph by ID or filename.

    Args:
        graph_id: ``"latest"``, an integer ID given as a string, or a filename.
        db: Database session injected by FastAPI.

    Returns:
        The ``graph_data`` attribute of the matching knowledge graph.

    Raises:
        HTTPException: 404 if nothing matches, 500 on any other error.
    """
    try:
        logger.info(f"Attempting to download knowledge graph: {graph_id}")
        if graph_id == "latest":
            kg = KnowledgeGraphService.get_latest_graph(db)
        else:
            # Only the int() parse decides between ID and filename lookup; a
            # ValueError raised by the lookup itself must not be mistaken
            # for "this is not a number".
            try:
                kg_id = int(graph_id)
            except ValueError:
                kg = KnowledgeGraphService.get_graph_by_filename(db, graph_id)
                found_message = f"Found knowledge graph by filename {graph_id}"
            else:
                kg = KnowledgeGraphService.get_graph_model_by_id(db, kg_id)
                found_message = f"Found knowledge graph by ID {kg_id}"
            # Report success only when the lookup actually returned something.
            if kg:
                logger.info(found_message)
        if not kg:
            logger.warning(f"Knowledge graph not found: {graph_id}")
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Knowledge graph {graph_id} not found"
            )
        return kg.graph_data
    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error downloading knowledge graph: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while downloading knowledge graph"
        ) from e
async def get_knowledge_graph_status(
    graph_id: str,
    db: Session = Depends(get_db_session)
):
    """
    Report the processing status of a knowledge graph.

    Args:
        graph_id: ``"latest"``, an integer ID given as a string, or a
            filename; URL-encoded characters are decoded before lookup.
        db: Database session injected by FastAPI.

    Returns:
        A dict with the graph's identifiers, its status (``"created"`` when
        unset), the derived ``is_*`` stage flags and ISO timestamps; or a
        404 ``JSONResponse`` when no graph matches.

    Raises:
        HTTPException: 500 for any unexpected error.
    """
    try:
        lookup_key = urllib.parse.unquote(graph_id)
        if lookup_key == "latest":
            graph = KnowledgeGraphService.get_latest_graph(db)
            missing_detail = "No latest knowledge graph found"
        else:
            try:
                # Numeric keys are treated as database IDs ...
                graph = KnowledgeGraphService.get_graph_model_by_id(db, int(lookup_key))
            except ValueError:
                # ... anything else as a filename.
                graph = KnowledgeGraphService.get_graph_by_filename(db, lookup_key)
            missing_detail = f"Knowledge graph {lookup_key} not found"

        if not graph:
            # Answer with a direct 404 response rather than raising.
            return JSONResponse(status_code=404, content={"detail": missing_detail})

        current = graph.status
        created = graph.creation_timestamp
        updated = graph.update_timestamp
        # Each stage flag is also true for every later stage.
        return {
            "id": graph.id,
            "filename": graph.filename,
            "trace_id": graph.trace_id,
            "status": current or "created",
            "is_original": current is None or current == "created",
            "is_enriched": current in ("enriched", "perturbed", "analyzed"),
            "is_perturbed": current in ("perturbed", "analyzed"),
            "is_analyzed": current == "analyzed",
            "created_at": created.isoformat() if created else None,
            "updated_at": updated.isoformat() if updated else None
        }
    except Exception as e:
        logger.error(f"Error retrieving knowledge graph status: {str(e)}")
        # Unexpected failures surface as a 500.
        raise HTTPException(status_code=500, detail=str(e))
def _delete_related_records(db, kg_id, label):
    """Delete and commit the rows in dependent tables that reference knowledge graph ``kg_id``."""
    statements = (
        # Causal analyses related to this knowledge graph
        "DELETE FROM causal_analyses WHERE knowledge_graph_id = :kg_id",
        # Perturbation test results related to this knowledge graph
        "DELETE FROM perturbation_tests WHERE knowledge_graph_id = :kg_id",
        # Prompt reconstructions related to this knowledge graph
        "DELETE FROM prompt_reconstructions WHERE knowledge_graph_id = :kg_id",
    )
    try:
        logger.info(f"Cleaning up related records for knowledge graph {label}")
        for statement in statements:
            db.execute(text(statement), {"kg_id": kg_id})
        db.commit()
        logger.info(f"Successfully cleaned up related records for knowledge graph {label}")
    except Exception as cleanup_error:
        db.rollback()
        logger.error(f"Error cleaning up related records: {str(cleanup_error)}")
        raise


async def delete_knowledge_graph_by_id(
    graph_id: str,
    db: Session = Depends(get_db_session),
):
    """
    Delete a knowledge graph by ID or filename.

    Rows in ``causal_analyses``, ``perturbation_tests`` and
    ``prompt_reconstructions`` that reference the graph are deleted and
    committed first; the graph itself is then removed with
    ``delete_knowledge_graph``.

    Args:
        graph_id: Integer ID given as a string, or a filename.
        db: Database session injected by FastAPI.

    Returns:
        ``{"status": "success", "message": ...}`` on success.

    Raises:
        HTTPException: 404 if the graph does not exist or could not be
            deleted, 500 on any other error.
    """
    try:
        logger.info(f"Attempting to delete knowledge graph: {graph_id}")

        # Only the int() parse decides between ID and filename handling; a
        # ValueError raised later (lookup, cleanup, delete) must not be
        # mistaken for "this is not a number".
        try:
            kg_id = int(graph_id)
        except ValueError:
            kg_id = None

        if kg_id is not None:
            kg = KnowledgeGraphService.get_graph_model_by_id(db, kg_id)
            description = f"ID {kg_id}"
            delete_key = kg_id
        else:
            kg = KnowledgeGraphService.get_graph_by_filename(db, graph_id)
            description = f"filename {graph_id}"
            delete_key = graph_id

        # Verify the graph exists before touching any dependent rows.
        if not kg:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Knowledge graph with {description} not found"
            )

        if kg_id is not None:
            _delete_related_records(db, kg_id, description)
        else:
            # The dependent tables are keyed by ID, so resolve it from the record.
            _delete_related_records(db, kg.id, f"{description} (ID: {kg.id})")

        # Now delete the knowledge graph itself, using the same kind of key
        # (ID or filename) that the caller supplied.
        result = delete_knowledge_graph(db, delete_key)
        if result:
            return {"status": "success", "message": f"Knowledge graph with {description} deleted successfully"}
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Knowledge graph with {description} not found or could not be deleted"
        )
    except HTTPException:
        # Deliberate HTTP errors (the 404s above) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Error deleting knowledge graph: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal error occurred while deleting knowledge graph"
        ) from e
| # Helper function to sanitize JSON data | |
def sanitize_json(obj):
    """
    Recursively replace non-finite floats (NaN, +/-Infinity) with ``None``.

    Strict JSON has no representation for those values, so they are mapped
    to ``null``. Dicts, lists and tuples are rebuilt with their values
    sanitized (dict keys are left as they are); every other object is
    returned unchanged.

    Args:
        obj: Any value, typically a structure about to be JSON-serialized.

    Returns:
        A sanitized copy for dict/list/tuple inputs, ``None`` for a
        non-finite float, otherwise ``obj`` itself.
    """
    if isinstance(obj, dict):
        return {k: sanitize_json(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [sanitize_json(item) for item in obj]
    if isinstance(obj, tuple):
        # Tuples serialize as JSON arrays too, so their items need the same treatment.
        return tuple(sanitize_json(item) for item in obj)
    if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
        return None
    return obj
async def enrich_knowledge_graph(kg_id: str, background_tasks: BackgroundTasks, session: Session = Depends(get_db)):
    """
    Schedule a background task that enriches the knowledge graph with prompt reconstructions.

    Returns ``{"status": "success", "task_id": ...}`` once the task is queued,
    a 404 ``JSONResponse`` when the graph is not found, or a 500
    ``JSONResponse`` when scheduling fails.
    """
    try:
        graph = get_knowledge_graph_by_id(session, kg_id)
        if graph:
            # The task ID embeds the graph ID and the current Unix time in seconds.
            started_at = int(time.time())
            task_id = f"enrich_kg_{kg_id}_{started_at}"
            create_task(task_id, "enrich_knowledge_graph", f"Enriching knowledge graph {kg_id}")
            background_tasks.add_task(enrich_knowledge_graph_task, kg_id, task_id)
            return {"status": "success", "task_id": task_id}
        return JSONResponse(
            status_code=404,
            content={"detail": f"Knowledge graph with ID {kg_id} not found"},
        )
    except Exception as e:
        logger.error(f"Error starting knowledge graph enrichment: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"detail": f"Error starting knowledge graph enrichment: {str(e)}"},
        )
| # Pydantic models for perturbation configuration | |
class JailbreakConfigModel(BaseModel):
    """Settings for the jailbreak part of a perturbation run (see ``PerturbationConfigModel.jailbreak``)."""
    # Toggle for this test; on by default.
    enabled: bool = True
    # Presumably the number of jailbreak techniques to apply - confirm
    # against perturb_knowledge_graph_task.
    num_techniques: int = 10
    # Free-form string; the accepted values are not visible in this file.
    prompt_source: str = "standard"
class DemographicModel(BaseModel):
    """One gender/race pair; used as the element type of ``CounterfactualBiasConfigModel.demographics``."""
    # Both fields are required (no defaults).
    gender: str
    race: str
class CounterfactualBiasConfigModel(BaseModel):
    """Settings for the counterfactual-bias part of a perturbation run (see ``PerturbationConfigModel.counterfactual_bias``)."""
    # Toggle for this test; on by default.
    enabled: bool = True
    # Demographic combinations to test; defaults to the four
    # male/female x White/Black pairs.
    demographics: List[DemographicModel] = [
        DemographicModel(gender="male", race="White"),
        DemographicModel(gender="female", race="White"),
        DemographicModel(gender="male", race="Black"),
        DemographicModel(gender="female", race="Black"),
    ]
    # Presumably adds a run without demographic information - confirm
    # against perturb_knowledge_graph_task.
    include_baseline: bool = True
    comparison_mode: str = "both"  # "all_pairs", "vs_baseline", or "both"
class PerturbationConfigModel(BaseModel):
    """Configuration for perturbation testing."""
    # Model names are plain strings and are not validated here.
    model: str = "gpt-4o-mini"
    judge_model: str = "gpt-4o-mini"
    # None means no limit is supplied by the caller.
    max_relations: Optional[int] = None
    # Per-test settings; None when the caller supplies no settings for that test.
    jailbreak: Optional[JailbreakConfigModel] = None
    counterfactual_bias: Optional[CounterfactualBiasConfigModel] = None
async def perturb_knowledge_graph(
    kg_id: str,
    background_tasks: BackgroundTasks,
    config: Optional[PerturbationConfigModel] = None,
    session: Session = Depends(get_db)
):
    """
    Start a background task to perturb the knowledge graph identified by kg_id.

    Accepts optional configuration for customizing the perturbation tests:
    - model: LLM model to use for testing (default: gpt-4o-mini)
    - judge_model: Model for evaluation (default: gpt-4o-mini)
    - max_relations: Limit number of relations to test (default: all)
    - jailbreak: Jailbreak test configuration
    - counterfactual_bias: Bias test configuration

    Returns:
        ``{"status": "success", "task_id": ..., "config": ...}`` once the task
        is queued. A ``JSONResponse`` is returned instead with status 404
        when the graph is not found, 400 when its status is not one of
        ``enriched`` / ``perturbed`` / ``analyzed``, and 500 (body
        ``{"status": "error", "error": ...}``) when scheduling fails.
    """
    try:
        kg = get_knowledge_graph_by_id(session, kg_id)
        if not kg:
            return JSONResponse(status_code=404, content={"detail": f"Knowledge graph with ID {kg_id} not found"})
        if kg.status not in ["enriched", "perturbed", "analyzed"]:
            return JSONResponse(status_code=400, content={"detail": "Knowledge graph must be enriched before perturbation"})
        task_id = f"perturb_kg_{kg_id}_{int(time.time())}"
        create_task(task_id, "perturb_knowledge_graph", f"Processing knowledge graph {kg_id}")
        # The background task receives the config as a plain dict (or None).
        config_dict = config.model_dump() if config else None
        background_tasks.add_task(perturb_knowledge_graph_task, kg_id, task_id, config_dict)
        return {
            "status": "success",
            "task_id": task_id,
            "config": config_dict
        }
    except Exception as e:
        logger.error(f"Error starting perturbation task: {str(e)}")
        # Same body as before, but with a 500 status so the failure is not
        # reported to HTTP clients as a success.
        return JSONResponse(status_code=500, content={"status": "error", "error": str(e)})
async def analyze_knowledge_graph(kg_id: str, background_tasks: BackgroundTasks, session: Session = Depends(get_db)):
    """
    Schedule causal-relationship analysis for a knowledge graph.

    Raises a 404 when the graph is missing and a 400 when its status is
    neither ``perturbed`` nor ``analyzed``. An already analyzed graph
    short-circuits with a ``COMPLETED`` message; otherwise a background task
    is queued and its ID returned. Unexpected errors become a 500.
    """
    try:
        graph = get_knowledge_graph_by_id(session, kg_id)
        if not graph:
            raise HTTPException(status_code=404, detail=f"Knowledge graph with ID {kg_id} not found")

        stage = graph.status
        if stage not in ["perturbed", "analyzed"]:
            raise HTTPException(status_code=400, detail="Knowledge graph must be perturbed before causal analysis")
        if stage == "analyzed":
            # Nothing to schedule; the analysis has already been done.
            return {"message": "Knowledge graph is already analyzed", "status": "COMPLETED"}

        # The task ID embeds the graph ID and the current Unix time in seconds.
        queued_at = int(time.time())
        task_id = f"analyze_kg_{kg_id}_{queued_at}"
        create_task(
            task_id,
            "analyze_causal_relationships",
            f"Analyzing causal relationships for knowledge graph {kg_id}",
        )
        background_tasks.add_task(analyze_causal_relationships_task, kg_id, task_id)
        return {"status": "success", "task_id": task_id, "message": "Causal analysis scheduled"}
    except HTTPException:
        # The 404/400 raised above must reach the client unchanged.
        raise
    except Exception as e:
        logger.error(f"Error scheduling causal analysis: {str(e)}")
        raise HTTPException(status_code=500, detail="An internal error occurred while scheduling causal analysis")
async def get_knowledge_graph_status(kg_id: str, session: Session = Depends(get_db)):
    """
    Get the processing status of a knowledge graph.

    Returns:
        A dict with the graph's id, filename, trace id, raw status (defaulting
        to "created"), one boolean flag per pipeline stage, and ISO-formatted
        creation/update timestamps (None when unset). When the graph does not
        exist, a 404 JSONResponse is returned instead.

    Raises:
        HTTPException: 500 on any unexpected failure. The detail is generic;
            the underlying error is only written to the log so internal
            details are not exposed to the client.
    """
    try:
        kg = get_knowledge_graph_by_id(session, kg_id)
        if not kg:
            return JSONResponse(status_code=404, content={"detail": f"Knowledge graph with ID {kg_id} not found"})

        current = kg.status
        created_at = kg.creation_timestamp
        updated_at = kg.update_timestamp
        return {
            "id": kg.id,
            "filename": kg.filename,
            "trace_id": kg.trace_id,
            "status": current or "created",
            "is_original": current == "created" or current is None,
            # Stages are cumulative: a later stage implies every earlier one.
            "is_enriched": current in ("enriched", "perturbed", "analyzed"),
            "is_perturbed": current in ("perturbed", "analyzed"),
            "is_analyzed": current == "analyzed",
            "created_at": created_at.isoformat() if created_at else None,
            "updated_at": updated_at.isoformat() if updated_at else None,
        }
    except Exception as e:
        logger.error(f"Error retrieving knowledge graph status: {str(e)}")
        raise HTTPException(status_code=500, detail="An internal error occurred while retrieving knowledge graph status")
async def get_stage_results(kg_id: str, stage: str, session: Session = Depends(get_db)):
    """
    Get the results of a specific processing stage for a knowledge graph.

    Args:
        kg_id: Identifier of the knowledge graph.
        stage: One of "enrich", "perturb" or "causal".
        session: Database session.

    Returns:
        The stage-specific result dict passed through ``sanitize_json``. When
        the stage has no stored rows, the dict carries a ``message`` instead
        of result lists.

    Raises:
        HTTPException: 404 when the graph does not exist, 400 for an unknown
            stage, 500 on any unexpected failure.
    """
    try:
        kg = get_knowledge_graph_by_id(session, kg_id)
        if not kg:
            raise HTTPException(status_code=404, detail=f"Knowledge graph with ID {kg_id} not found")

        # Graph data may be stored as a JSON string; decode it once for all stages.
        graph_data = kg.graph_data
        if isinstance(graph_data, str):
            graph_data = json.loads(graph_data)
        # A graph without stored data is treated as an empty graph rather than crashing.
        graph_data = graph_data or {}

        if stage == "enrich":
            result = _build_enrich_stage_results(session, kg, graph_data)
        elif stage == "perturb":
            result = _build_perturb_stage_results(session, kg, graph_data)
        elif stage == "causal":
            result = _build_causal_stage_results(session, kg, graph_data)
        else:
            raise HTTPException(status_code=400, detail=f"Invalid stage: {stage}")

        return sanitize_json(result)
    except HTTPException:
        # Let the deliberate 404/400 responses through instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"Error retrieving stage results: {str(e)}")
        raise HTTPException(status_code=500, detail="An internal error occurred while retrieving stage results")


def _stage_entity_ref(entity):
    """Summarise an entity as id/name/type, with placeholders when the entity is missing."""
    if not entity:
        return {"id": None, "name": "Unknown", "type": "Unknown"}
    return {"id": entity.get("id"), "name": entity.get("name"), "type": entity.get("type")}


def _stage_relation_ref(relation, entities_map):
    """Summarise a relation together with its resolved source and target entities."""
    return {
        "id": relation.get("id"),
        "type": relation.get("type"),
        "source": _stage_entity_ref(entities_map.get(relation.get("source"))),
        "target": _stage_entity_ref(entities_map.get(relation.get("target"))),
    }


def _build_enrich_stage_results(session, kg, graph_data):
    """Build the "enrich" stage payload: prompt reconstructions plus a coverage summary."""
    # REQUIRES_TOOL and NEXT relations are excluded from the coverage total.
    total_relations = sum(
        1 for r in graph_data.get("relations", []) if r.get("type") not in ["REQUIRES_TOOL", "NEXT"]
    )
    prompt_reconstructions = session.query(PromptReconstruction).filter_by(knowledge_graph_id=kg.id).all()

    if not prompt_reconstructions:
        message = "This knowledge graph has not been enriched yet."
        if kg.status != "created" and kg.status is not None:
            message = "No prompt reconstructions found."
        return {
            "message": message,
            "summary": {"total_relations": total_relations, "reconstructed_count": 0, "reconstruction_coverage": "0%"},
        }

    entities_map = {entity["id"]: entity for entity in graph_data.get("entities", [])}
    relations_map = {relation["id"]: relation for relation in graph_data.get("relations", [])}

    reconstructions_data = []
    for pr in prompt_reconstructions:
        reconstruction = {
            "id": pr.id,
            "relation_id": pr.relation_id,
            "reconstructed_prompt": pr.reconstructed_prompt,
            "dependencies": pr.dependencies,
            "created_at": pr.created_at.isoformat() if pr.created_at else None,
            "updated_at": pr.updated_at.isoformat() if pr.updated_at else None,
        }
        relation = relations_map.get(pr.relation_id)
        if relation:
            reconstruction["relation"] = _stage_relation_ref(relation, entities_map)
        reconstructions_data.append(reconstruction)

    reconstructed_count = len(prompt_reconstructions)
    # If the graph lists no countable relations, fall back to the number of reconstructions.
    if total_relations == 0 and reconstructed_count > 0:
        total_relations = reconstructed_count
    coverage = f"{(reconstructed_count/total_relations*100):.1f}%" if total_relations > 0 else "100%"

    return {
        "prompt_reconstructions": reconstructions_data,
        "summary": {
            "total_relations": total_relations,
            "reconstructed_count": reconstructed_count,
            "reconstruction_coverage": coverage,
        },
    }


def _build_perturb_stage_results(session, kg, graph_data):
    """Build the "perturb" stage payload: tests grouped by relation plus score averages."""
    perturbation_tests = session.query(PerturbationTest).filter_by(knowledge_graph_id=kg.id).all()
    if not perturbation_tests:
        not_tested_yet = kg.status in ["created", "enriched"]
        return {
            "message": "This knowledge graph has not been perturbation tested yet."
            if not_tested_yet
            else "No perturbation test results found."
        }

    entities_map = {entity["id"]: entity for entity in graph_data.get("entities", [])}
    relations_map = {relation["id"]: relation for relation in graph_data.get("relations", [])}

    tests_by_relation = {}
    for test in perturbation_tests:
        tests_by_relation.setdefault(test.relation_id, []).append(test)

    perturbation_results = []
    total_score, total_tests_count = 0, 0
    for relation_id, tests in tests_by_relation.items():
        relation = relations_map.get(relation_id)
        if not relation:
            # Tests whose relation is no longer in the graph are left out of the report.
            continue

        test_results_data = []
        relation_score = 0
        for test in tests:
            test_results_data.append({
                "id": test.id,
                "type": test.perturbation_type,
                "score": test.perturbation_score,
                "result": test.test_result,
                "metadata": test.test_metadata,
                "perturbation_set_id": test.perturbation_set_id,
                "created_at": test.created_at.isoformat() if test.created_at else None,
                "updated_at": test.updated_at.isoformat() if test.updated_at else None,
            })
            if test.perturbation_score is not None:
                relation_score += test.perturbation_score
                total_score += test.perturbation_score
                total_tests_count += 1

        # NOTE(review): this average divides by all tests of the relation, including
        # unscored ones, while the overall average below divides by scored tests
        # only - confirm which denominator is intended.
        avg_relation_score = relation_score / len(test_results_data) if test_results_data else 0
        perturbation_results.append({
            "relation_id": relation_id,
            "relation": _stage_relation_ref(relation, entities_map),
            "tests": test_results_data,
            "average_score": avg_relation_score,
        })

    overall_score = total_score / total_tests_count if total_tests_count > 0 else 0
    return {
        "perturbation_results": perturbation_results,
        "summary": {
            "total_relations_tested": len(perturbation_results),
            "total_tests": total_tests_count,
            "average_score": overall_score,
        },
    }


def _build_causal_stage_results(session, kg, graph_data):
    """Build the "causal" stage payload: causal analyses, grouped by perturbation set."""
    causal_relations = session.query(CausalAnalysis).filter_by(knowledge_graph_id=kg.id).all()
    if not causal_relations:
        not_analyzed_yet = kg.status in ["created", "enriched", "perturbed"]
        return {
            "message": "This knowledge graph has not undergone causal analysis yet."
            if not_analyzed_yet
            else "No causal analysis results found."
        }

    entities_map = {entity["id"]: entity for entity in graph_data.get("entities", [])}
    relations_map = {relation["id"]: relation for relation in graph_data.get("relations", [])}

    # Look up the perturbation type and metadata of every set referenced by the analyses.
    perturbation_set_types = {}
    perturbation_set_metadata = {}
    perturbation_set_ids = list({cr.perturbation_set_id for cr in causal_relations if cr.perturbation_set_id})
    if perturbation_set_ids:
        set_rows = session.query(
            PerturbationTest.perturbation_set_id,
            PerturbationTest.perturbation_type,
            PerturbationTest.created_at,
            PerturbationTest.test_metadata
        ).filter(
            PerturbationTest.knowledge_graph_id == kg.id,
            PerturbationTest.perturbation_set_id.in_(perturbation_set_ids)
        ).distinct().all()
        # When a set has several rows, the last row returned wins.
        perturbation_set_types = {row.perturbation_set_id: row.perturbation_type for row in set_rows}
        perturbation_set_metadata = {
            row.perturbation_set_id: {
                "created_at": row.created_at.isoformat() if row.created_at else None,
                "test_metadata": row.test_metadata or {},
            }
            for row in set_rows
        }

    causal_results = []
    for cr in causal_relations:
        analysis_result = sanitize_json(cr.analysis_result or {})
        cause_relation_id = analysis_result.get('cause_relation_id')
        effect_relation_id = analysis_result.get('effect_relation_id')
        cause_relation = relations_map.get(cause_relation_id) if cause_relation_id else None
        effect_relation = relations_map.get(effect_relation_id) if effect_relation_id else None

        # NaN and infinity are not valid JSON numbers; report them as null.
        causal_score = cr.causal_score
        if causal_score is not None and (math.isnan(causal_score) or math.isinf(causal_score)):
            causal_score = None

        causal_result = {
            "id": cr.id,
            "causal_score": causal_score,
            "analysis_method": cr.analysis_method,
            "created_at": cr.created_at.isoformat() if cr.created_at else None,
            "updated_at": cr.updated_at.isoformat() if cr.updated_at else None,
            "perturbation_set_id": cr.perturbation_set_id,
            "metadata": sanitize_json(cr.analysis_metadata),
        }
        if cause_relation and effect_relation:
            causal_result["cause_relation"] = _stage_relation_ref(cause_relation, entities_map)
            causal_result["effect_relation"] = _stage_relation_ref(effect_relation, entities_map)
        else:
            # Without both relations resolved, expose the stored analysis as-is.
            causal_result["raw_analysis"] = analysis_result
        causal_results.append(causal_result)

    causal_results_by_set = {}
    for entry in causal_results:
        set_id = entry.get("perturbation_set_id") or "default"
        causal_results_by_set.setdefault(set_id, []).append(entry)

    return {
        "causal_results": causal_results,
        "causal_results_by_set": causal_results_by_set,
        "perturbation_set_types": perturbation_set_types,
        "perturbation_set_metadata": perturbation_set_metadata,
        "summary": {
            "total_causal_relations": len(causal_results),
            "total_perturbation_sets": len(causal_results_by_set),
        },
    }
async def clear_stage_results(kg_id: str, stage: str, session: Session = Depends(get_db)):
    """
    Clear results for a specific stage and all dependent stages.

    Cascade logic:
    - Clear enrich: Also clears perturb + causal
    - Clear perturb: Also clears causal
    - Clear causal: Only clears causal

    After clearing, the graph's status is moved back to the stage that
    precedes the cleared one and its update timestamp is refreshed.

    Args:
        kg_id: Identifier of the knowledge graph.
        stage: One of "enrich", "perturb" or "causal".
        session: Database session.

    Returns:
        A dict with the outcome, the list of cleared stages and the new status.

    Raises:
        HTTPException: 404 when the graph does not exist, 400 for an unknown
            stage, 500 (after a rollback) on any unexpected failure.
    """
    # Ordered pipeline: (stage name, statement clearing its rows, status to fall back to).
    # Clearing a stage runs its statement and those of every later stage, in this order.
    pipeline = (
        ("enrich", "DELETE FROM prompt_reconstructions WHERE knowledge_graph_id = :kg_id", "created"),
        ("perturb", "DELETE FROM perturbation_tests WHERE knowledge_graph_id = :kg_id", "enriched"),
        ("causal", "DELETE FROM causal_analyses WHERE knowledge_graph_id = :kg_id", "perturbed"),
    )
    try:
        kg = get_knowledge_graph_by_id(session, kg_id)
        if not kg:
            raise HTTPException(status_code=404, detail=f"Knowledge graph with ID {kg_id} not found")

        stage_names = [name for name, _, _ in pipeline]
        if stage not in stage_names:
            raise HTTPException(status_code=400, detail=f"Invalid stage: {stage}")
        first = stage_names.index(stage)

        cleared_stages = []
        for name, delete_sql, _ in pipeline[first:]:
            session.execute(text(delete_sql), {"kg_id": kg.id})
            cleared_stages.append(name)

        kg.status = pipeline[first][2]
        kg.update_timestamp = datetime.now(timezone.utc)
        session.commit()

        logger.info(f"Cleared stages {cleared_stages} for knowledge graph {kg_id}")
        return {
            "status": "success",
            "message": f"Successfully cleared {', '.join(cleared_stages)} stage(s)",
            "cleared_stages": cleared_stages,
            "new_status": kg.status
        }
    except HTTPException:
        # Raised before any write happened; keep the 404/400 instead of masking it as 500.
        raise
    except Exception as e:
        session.rollback()
        logger.error(f"Error clearing stage {stage} for KG {kg_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="An internal error occurred while clearing stage results")
async def update_prompt_reconstruction(kg_id: str, session: Session = Depends(get_db)):
    """
    Update prompt reconstruction metadata for an existing knowledge graph.

    The graph must already be enriched (status "enriched", "perturbed" or
    "analyzed"). The ``prompt_reconstruction`` entry under the graph's
    ``metadata`` is created when missing, updated, and saved with the graph.

    Returns:
        ``{"success": True, "prompt_reconstruction": <updated dict>}``.

    Raises:
        HTTPException: 404 when the graph does not exist, 400 when it has not
            been enriched, 500 (after a rollback) on any unexpected failure.
    """
    try:
        kg = get_knowledge_graph_by_id(session, kg_id)
        if not kg:
            raise HTTPException(status_code=404, detail=f"Knowledge graph with ID {kg_id} not found")
        if kg.status not in ["enriched", "perturbed", "analyzed"]:
            raise HTTPException(status_code=400, detail="Knowledge graph must be enriched before updating")

        graph_data = kg.graph_data
        if isinstance(graph_data, str):
            graph_data = json.loads(graph_data)

        if "metadata" not in graph_data:
            graph_data["metadata"] = {}
        metadata = graph_data["metadata"]
        # Attach the dict to the metadata explicitly: a bare .get(..., {}) returns a
        # detached dict when the key is missing, so the update would never be stored.
        prompt_reconstruction = metadata.get("prompt_reconstruction") or {}
        metadata["prompt_reconstruction"] = prompt_reconstruction

        # ... rest of the logic
        system_prompt, user_prompt = "", ""
        agent_entities = {e["id"]: e for e in graph_data.get("entities", []) if e.get("type") == "Agent"}
        # Find prompts...
        prompt_reconstruction["system_prompt"] = system_prompt
        prompt_reconstruction["user_prompt"] = user_prompt
        # ... and so on

        # NOTE(review): when graph_data is the same object that was loaded from the
        # model, re-assigning it may not be detected as a change by the ORM,
        # depending on the column type - confirm whether flag_modified is needed.
        kg.graph_data = graph_data
        session.commit()
        return {"success": True, "prompt_reconstruction": prompt_reconstruction}
    except HTTPException:
        # Keep the deliberate 404/400 responses instead of masking them as 500.
        raise
    except Exception as e:
        session.rollback()
        logger.error(f"Error updating prompt reconstruction for KG {kg_id}: {str(e)}")
        raise HTTPException(status_code=500, detail="An internal error occurred while updating prompt reconstruction")
async def reset_knowledge_graph(kg_id: str, session: Session = Depends(get_db)):
    """
    Reset a knowledge graph's processing status back to 'created'.

    Only the status field is changed; stored stage results are not touched
    by this function.

    Returns:
        A dict confirming the reset, including the graph id and the new status.

    Raises:
        HTTPException: 404 when the graph does not exist, 500 (after a
            rollback) on any unexpected failure.
    """
    try:
        kg = get_knowledge_graph_by_id(session, kg_id)
        if not kg:
            raise HTTPException(status_code=404, detail=f"Knowledge graph with ID {kg_id} not found")

        kg.status = "created"
        session.commit()
        return {
            "success": True,
            "message": f"Knowledge graph {kg_id} has been reset.",
            "knowledge_graph_id": kg_id,
            "status": "created"
        }
    except HTTPException:
        # Keep the 404 instead of masking it as a 500.
        raise
    except Exception as e:
        # Leave the session usable after a failed commit.
        session.rollback()
        logger.error(f"Error resetting knowledge graph: {str(e)}")
        raise HTTPException(status_code=500, detail="An internal error occurred while resetting knowledge graph")