Spaces:
Running
Running
| """ | |
| Graph Comparison Router | |
| API endpoints for comparing knowledge graphs in the database. | |
| """ | |
| from fastapi import APIRouter, HTTPException, Depends, Query | |
| from fastapi.responses import JSONResponse | |
| from sqlalchemy.orm import Session | |
| from typing import List, Dict, Any, Optional | |
| from pydantic import BaseModel, Field | |
| import logging | |
| from backend.database import get_db | |
| from backend.database.models import KnowledgeGraph | |
| from agentgraph.extraction.graph_utilities import KnowledgeGraphComparator, GraphComparisonMetrics | |
| router = APIRouter(prefix="/api/graph-comparison", tags=["graph-comparison"]) | |
| logger = logging.getLogger(__name__) | |
| class GraphComparisonRequest(BaseModel): | |
| """Request model for graph comparison""" | |
| graph1_id: int | |
| graph2_id: int | |
| similarity_threshold: Optional[float] = Field(0.7, description="Threshold for semantic overlap detection (0.7 = 70%)") | |
| use_cache: Optional[bool] = True | |
| async def list_available_graphs(db: Session = Depends(get_db)): | |
| """ | |
| Get hierarchically organized list of knowledge graphs for comparison. | |
| Returns: | |
| Hierarchically organized graphs with final graphs and their associated chunk graphs | |
| """ | |
| try: | |
| all_graphs = db.query(KnowledgeGraph).order_by( | |
| KnowledgeGraph.trace_id.asc(), | |
| KnowledgeGraph.window_index.asc() | |
| ).all() | |
| # Categorize graphs | |
| final_graphs = [] | |
| chunk_graphs = [] | |
| for graph in all_graphs: | |
| # Final graphs (has window_total but no window_index, or window_index is None, or no trace_id) | |
| if (graph.window_total is not None and | |
| graph.window_index is None) or not graph.trace_id: | |
| final_graphs.append(graph) | |
| # Chunk graphs (has window_index) | |
| elif graph.window_index is not None: | |
| chunk_graphs.append(graph) | |
| else: | |
| # Orphaned graphs - treat as final graphs | |
| final_graphs.append(graph) | |
| # Group chunk graphs by trace_id and processing_run_id | |
| chunks_by_trace = {} | |
| for chunk in chunk_graphs: | |
| trace_key = chunk.trace_id | |
| run_key = chunk.processing_run_id or 'default' | |
| if trace_key not in chunks_by_trace: | |
| chunks_by_trace[trace_key] = {} | |
| if run_key not in chunks_by_trace[trace_key]: | |
| chunks_by_trace[trace_key][run_key] = [] | |
| chunks_by_trace[trace_key][run_key].append(chunk) | |
| # Build hierarchical structure | |
| organized_graphs = { | |
| "final_graphs": [], | |
| "total_count": len(all_graphs) | |
| } | |
| # Process final graphs and associate their chunk graphs | |
| for final_graph in final_graphs: | |
| final_data = _format_graph_data(final_graph) | |
| final_data["graph_type"] = "final" | |
| # Find associated chunk graphs | |
| chunk_list = [] | |
| if final_graph.trace_id in chunks_by_trace: | |
| run_key = final_graph.processing_run_id or 'default' | |
| associated_chunks = chunks_by_trace[final_graph.trace_id].get(run_key, []) | |
| for chunk in sorted(associated_chunks, key=lambda x: x.window_index or 0): | |
| chunk_data = _format_graph_data(chunk) | |
| chunk_data["graph_type"] = "chunk" | |
| chunk_data["window_info"] = { | |
| "index": chunk.window_index, | |
| "total": chunk.window_total, | |
| "start_char": chunk.window_start_char, | |
| "end_char": chunk.window_end_char | |
| } | |
| chunk_list.append(chunk_data) | |
| final_data["chunk_graphs"] = chunk_list | |
| organized_graphs["final_graphs"].append(final_data) | |
| # Ground truth graphs are now treated as final graphs - no separate processing needed | |
| return organized_graphs | |
| except Exception as e: | |
| logger.error(f"Error listing graphs: {str(e)}") | |
| raise HTTPException(status_code=500, detail="An internal error occurred while listing graphs") | |
| def _format_graph_data(graph: KnowledgeGraph) -> Dict[str, Any]: | |
| """Helper function to format graph data consistently""" | |
| graph_data = { | |
| "id": graph.id, | |
| "filename": graph.filename, | |
| "creation_timestamp": graph.creation_timestamp.isoformat() if graph.creation_timestamp else None, | |
| "entity_count": graph.entity_count, | |
| "relation_count": graph.relation_count, | |
| "status": graph.status, | |
| "trace_id": graph.trace_id, | |
| "window_index": graph.window_index, | |
| "window_total": graph.window_total, | |
| "processing_run_id": graph.processing_run_id | |
| } | |
| # Surface human-friendly system_name and summary when available in stored graph_data | |
| try: | |
| gd = graph.graph_data or {} | |
| if isinstance(gd, dict): | |
| sys_name = gd.get("system_name") | |
| sys_summary = gd.get("system_summary") | |
| if sys_name: | |
| graph_data["system_name"] = sys_name | |
| if sys_summary: | |
| graph_data["system_summary"] = sys_summary | |
| except Exception: | |
| # Non-fatal: ignore extraction issues | |
| pass | |
| # Add trace information if available | |
| if graph.trace: | |
| graph_data["trace_title"] = graph.trace.title | |
| graph_data["trace_description"] = graph.trace.description | |
| return graph_data | |
| async def compare_graphs( | |
| request: GraphComparisonRequest, | |
| db: Session = Depends(get_db) | |
| ): | |
| """ | |
| Compare two knowledge graphs and return comprehensive metrics. | |
| Args: | |
| request: Graph comparison request containing graph IDs and settings | |
| Returns: | |
| Comprehensive comparison metrics between the two graphs | |
| """ | |
| try: | |
| # Extract request data | |
| graph1_id = request.graph1_id | |
| graph2_id = request.graph2_id | |
| similarity_threshold = request.similarity_threshold | |
| use_cache = request.use_cache | |
| # Fetch the two graphs | |
| graph1 = db.query(KnowledgeGraph).filter(KnowledgeGraph.id == graph1_id).first() | |
| graph2 = db.query(KnowledgeGraph).filter(KnowledgeGraph.id == graph2_id).first() | |
| if not graph1: | |
| raise HTTPException(status_code=404, detail=f"Graph with ID {graph1_id} not found") | |
| if not graph2: | |
| raise HTTPException(status_code=404, detail=f"Graph with ID {graph2_id} not found") | |
| # Get graph data | |
| graph1_data = graph1.graph_data or {} | |
| graph2_data = graph2.graph_data or {} | |
| if not graph1_data: | |
| raise HTTPException(status_code=400, detail=f"Graph {graph1_id} has no data") | |
| if not graph2_data: | |
| raise HTTPException(status_code=400, detail=f"Graph {graph2_id} has no data") | |
| # Add graph_info to enable same-trace detection | |
| graph1_data = { | |
| **graph1_data, | |
| "graph_info": { | |
| "id": graph1.id, | |
| "trace_id": graph1.trace_id, | |
| "filename": graph1.filename | |
| } | |
| } | |
| graph2_data = { | |
| **graph2_data, | |
| "graph_info": { | |
| "id": graph2.id, | |
| "trace_id": graph2.trace_id, | |
| "filename": graph2.filename | |
| } | |
| } | |
| # Initialize comparator | |
| # Use similarity_threshold as semantic_threshold for overlap detection | |
| # and set similarity_threshold slightly lower for general semantic similarity | |
| semantic_threshold = similarity_threshold # Use the user's threshold for overlap detection | |
| general_threshold = max(0.5, similarity_threshold - 0.1) # Slightly lower for general similarity | |
| comparator = KnowledgeGraphComparator( | |
| similarity_threshold=general_threshold, | |
| semantic_threshold=semantic_threshold, | |
| use_cache=use_cache | |
| ) | |
| # Perform comparison | |
| logger.info(f"Comparing graphs {graph1_id} and {graph2_id} (trace_ids: {graph1.trace_id}, {graph2.trace_id})") | |
| metrics = comparator.compare_graphs(graph1_data, graph2_data) | |
| # Add metadata about the graphs being compared | |
| comparison_result = metrics.to_dict() | |
| comparison_result["metadata"] = { | |
| "graph1": { | |
| "id": graph1.id, | |
| "filename": graph1.filename, | |
| "creation_timestamp": graph1.creation_timestamp.isoformat() if graph1.creation_timestamp else None, | |
| "trace_id": graph1.trace_id, | |
| "trace_title": graph1.trace.title if graph1.trace else None, | |
| "system_name": (graph1.graph_data or {}).get("system_name") if isinstance(graph1.graph_data, dict) else None, | |
| }, | |
| "graph2": { | |
| "id": graph2.id, | |
| "filename": graph2.filename, | |
| "creation_timestamp": graph2.creation_timestamp.isoformat() if graph2.creation_timestamp else None, | |
| "trace_id": graph2.trace_id, | |
| "trace_title": graph2.trace.title if graph2.trace else None, | |
| "system_name": (graph2.graph_data or {}).get("system_name") if isinstance(graph2.graph_data, dict) else None, | |
| }, | |
| "comparison_timestamp": metrics.graph1_stats.get('name', 'Unknown'), # Will be set properly | |
| "similarity_threshold": general_threshold, | |
| "semantic_threshold": semantic_threshold, | |
| "user_requested_threshold": similarity_threshold, | |
| "cache_used": use_cache | |
| } | |
| logger.info(f"Graph comparison completed. Overall similarity: {metrics.overall_similarity:.3f}") | |
| return comparison_result | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error comparing graphs {graph1_id} and {graph2_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail="An internal error occurred while comparing graphs") | |
| async def get_comparison( | |
| graph1_id: int, | |
| graph2_id: int, | |
| similarity_threshold: Optional[float] = Query(0.7, description="Threshold for semantic similarity matching"), | |
| db: Session = Depends(get_db) | |
| ): | |
| """ | |
| GET endpoint for comparing two graphs (alternative to POST). | |
| Args: | |
| graph1_id: ID of the first knowledge graph | |
| graph2_id: ID of the second knowledge graph | |
| similarity_threshold: Threshold for semantic similarity matching | |
| Returns: | |
| Comprehensive comparison metrics between the two graphs | |
| """ | |
| # Create request object for the POST endpoint | |
| request = GraphComparisonRequest( | |
| graph1_id=graph1_id, | |
| graph2_id=graph2_id, | |
| similarity_threshold=similarity_threshold | |
| ) | |
| return await compare_graphs(request, db) | |
| async def get_graph_details(graph_id: int, db: Session = Depends(get_db)): | |
| """ | |
| Get detailed information about a specific knowledge graph. | |
| Args: | |
| graph_id: ID of the knowledge graph | |
| Returns: | |
| Detailed graph information including entities and relations | |
| """ | |
| try: | |
| graph = db.query(KnowledgeGraph).filter(KnowledgeGraph.id == graph_id).first() | |
| if not graph: | |
| raise HTTPException(status_code=404, detail=f"Graph with ID {graph_id} not found") | |
| graph_data = graph.graph_data or {} | |
| # Generate basic statistics | |
| entities = graph_data.get('entities', []) | |
| relations = graph_data.get('relations', []) | |
| # Entity type distribution | |
| entity_types = {} | |
| for entity in entities: | |
| etype = entity.get('type', 'Unknown') | |
| entity_types[etype] = entity_types.get(etype, 0) + 1 | |
| # Relation type distribution | |
| relation_types = {} | |
| for relation in relations: | |
| rtype = relation.get('type', 'Unknown') | |
| relation_types[rtype] = relation_types.get(rtype, 0) + 1 | |
| # Calculate basic metrics | |
| n_entities = len(entities) | |
| n_relations = len(relations) | |
| density = (2 * n_relations) / (n_entities * (n_entities - 1)) if n_entities > 1 else 0.0 | |
| result = { | |
| "graph_info": { | |
| "id": graph.id, | |
| "filename": graph.filename, | |
| "creation_timestamp": graph.creation_timestamp.isoformat() if graph.creation_timestamp else None, | |
| "entity_count": graph.entity_count, | |
| "relation_count": graph.relation_count, | |
| "status": graph.status, | |
| "trace_id": graph.trace_id, | |
| "window_index": graph.window_index, | |
| "window_total": graph.window_total | |
| }, | |
| "statistics": { | |
| "entity_count": n_entities, | |
| "relation_count": n_relations, | |
| "density": density, | |
| "entity_types": entity_types, | |
| "relation_types": relation_types, | |
| "avg_relations_per_entity": n_relations / n_entities if n_entities > 0 else 0.0 | |
| }, | |
| "entities": entities[:50], # Limit to first 50 for preview | |
| "relations": relations[:50], # Limit to first 50 for preview | |
| "has_more_entities": len(entities) > 50, | |
| "has_more_relations": len(relations) > 50 | |
| } | |
| # Add trace information if available | |
| if graph.trace: | |
| result["trace_info"] = { | |
| "title": graph.trace.title, | |
| "description": graph.trace.description, | |
| "character_count": graph.trace.character_count, | |
| "turn_count": graph.trace.turn_count | |
| } | |
| return result | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error getting graph details for {graph_id}: {str(e)}") | |
| raise HTTPException(status_code=500, detail="An internal error occurred while getting graph details") | |
| async def get_cache_info(): | |
| """ | |
| Get information about the embedding cache. | |
| Returns: | |
| Cache information including size and statistics | |
| """ | |
| try: | |
| # Create a temporary comparator to access cache info | |
| comparator = KnowledgeGraphComparator() | |
| cache_info = comparator.get_cache_info() | |
| return { | |
| "status": "success", | |
| "cache_info": cache_info | |
| } | |
| except Exception as e: | |
| logger.error(f"Error getting cache info: {str(e)}") | |
| raise HTTPException(status_code=500, detail="An internal error occurred while getting cache info") | |
| async def clear_cache(): | |
| """ | |
| Clear the embedding cache. | |
| Returns: | |
| Success message if cache was cleared | |
| """ | |
| try: | |
| # Create a temporary comparator to clear cache | |
| comparator = KnowledgeGraphComparator() | |
| success = comparator.clear_embedding_cache() | |
| if success: | |
| return { | |
| "status": "success", | |
| "message": "Embedding cache cleared successfully" | |
| } | |
| else: | |
| raise HTTPException(status_code=500, detail="Failed to clear cache") | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.error(f"Error clearing cache: {str(e)}") | |
| raise HTTPException(status_code=500, detail="An internal error occurred while clearing cache") | |