"""
Graph API Routes - Network visualization endpoints
"""
import logging
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import or_
from sqlalchemy.orm import Session

from app.api.deps import get_scoped_db
from app.models.entity import Entity, Relationship

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/graph", tags=["Graph"])

# Display limits for node labels / descriptions in the visualization payload.
_LABEL_MAX_LEN = 30
_DESCRIPTION_MAX_LEN = 100


def _entity_node_data(entity: Entity) -> Dict[str, Any]:
    """Build the Cytoscape.js ``data`` dict for one entity node.

    ``label`` is the name truncated to ``_LABEL_MAX_LEN`` characters (with a
    trailing "..." when cut), ``fullName`` is the untruncated name, and
    ``description`` is cut to ``_DESCRIPTION_MAX_LEN`` characters (empty string
    when missing). ``source`` falls back to "unknown".
    """
    # Guard against a NULL name so a single bad row cannot fail the request.
    name = entity.name or ""
    if len(name) > _LABEL_MAX_LEN:
        label = name[:_LABEL_MAX_LEN] + "..."
    else:
        label = name
    return {
        "id": entity.id,
        "label": label,
        "fullName": entity.name,
        "type": entity.type,
        "description": (
            entity.description[:_DESCRIPTION_MAX_LEN] if entity.description else ""
        ),
        "source": entity.source or "unknown",
    }


def _relationship_to_edge(rel: Relationship) -> Dict[str, Any]:
    """Build the Cytoscape.js element for one relationship edge."""
    return {
        "data": {
            "id": rel.id,
            "source": rel.source_id,
            "target": rel.target_id,
            "label": rel.type,
            "type": rel.type,
        }
    }


@router.get("")
async def get_graph(
    entity_type: Optional[str] = Query(None, description="Filter by entity type"),
    # ge=1: a non-positive LIMIT is an error on some databases and means
    # "no limit" on others (e.g. SQLite), which would bypass the le=500 cap.
    limit: int = Query(100, ge=1, le=500, description="Maximum number of entities"),
    db: Session = Depends(get_scoped_db),
):
    """
    Get graph data for visualization.

    Returns nodes (entities) and edges (relationships) formatted for
    Cytoscape.js. At most ``limit`` entities are returned, optionally filtered
    by ``entity_type``. Only relationships whose source AND target are both
    among the returned entities become edges.

    Raises:
        HTTPException: 500 if loading or formatting the graph fails.
    """
    # NOTE(review): the Session calls below are synchronous and execute on the
    # event loop inside this ``async def``; a plain ``def`` endpoint would let
    # FastAPI run them in its threadpool instead.
    try:
        query = db.query(Entity)
        if entity_type:
            query = query.filter(Entity.type == entity_type)
        entities = query.limit(limit).all()
        entity_ids = [ent.id for ent in entities]

        relationships: List[Relationship] = []
        if entity_ids:
            # Filter both endpoints in SQL rather than loading every
            # relationship that merely touches the set and discarding most
            # of them in Python.
            relationships = (
                db.query(Relationship)
                .filter(
                    Relationship.source_id.in_(entity_ids),
                    Relationship.target_id.in_(entity_ids),
                )
                .all()
            )

        nodes = [{"data": _entity_node_data(ent)} for ent in entities]
        edges = [_relationship_to_edge(rel) for rel in relationships]

        return {
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
            },
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to get graph")
        # NOTE(review): the raw exception text is sent to the client; consider
        # a generic message if internal error details must not leak.
        raise HTTPException(
            status_code=500, detail=f"Failed to get graph: {str(exc)}"
        ) from exc


@router.get("/entity/{entity_id}")
async def get_entity_graph(
    entity_id: str,
    depth: int = Query(1, ge=1, le=3, description="How many levels of connections to include"),
    db: Session = Depends(get_scoped_db),
):
    """
    Get graph centered on a specific entity.

    Expands breadth-first from ``entity_id`` over relationships in either
    direction for up to ``depth`` hops. Returns the reached entities as nodes
    (the start entity flagged with ``isCentral``) and every relationship
    between them as edges, in Cytoscape.js format.

    Raises:
        HTTPException: 404 if the central entity does not exist; 500 on any
            other failure.
    """
    # NOTE(review): same sync-Session-in-async-def concern as get_graph.
    try:
        central = db.query(Entity).filter(Entity.id == entity_id).first()
        if not central:
            raise HTTPException(status_code=404, detail="Entity not found")

        # Seed with the stored id rather than the raw path string so set
        # membership and equality use the column's own value/type.
        collected_ids = {central.id}
        frontier = {central.id}
        for _ in range(depth):
            if not frontier:
                # Nothing new was reached last round; further queries would
                # only return already-collected ids.
                break
            rels = (
                db.query(Relationship)
                .filter(
                    or_(
                        Relationship.source_id.in_(frontier),
                        Relationship.target_id.in_(frontier),
                    )
                )
                .all()
            )
            neighbours = set()
            for rel in rels:
                neighbours.add(rel.source_id)
                neighbours.add(rel.target_id)
            frontier = neighbours - collected_ids
            collected_ids |= neighbours

        entities = db.query(Entity).filter(Entity.id.in_(collected_ids)).all()
        # Ids reached through relationships may have no Entity row (possible
        # if the schema does not enforce foreign keys). Only link entities
        # that were actually loaded, since Cytoscape.js rejects edges whose
        # endpoints are not nodes. Always contains central.id, so non-empty.
        loaded_ids = {ent.id for ent in entities}

        relationships = (
            db.query(Relationship)
            .filter(
                Relationship.source_id.in_(loaded_ids),
                Relationship.target_id.in_(loaded_ids),
            )
            .all()
        )

        nodes = [
            {"data": {**_entity_node_data(ent), "isCentral": ent.id == central.id}}
            for ent in entities
        ]
        edges = [_relationship_to_edge(rel) for rel in relationships]

        return {
            "central": {
                "id": central.id,
                "name": central.name,
                "type": central.type,
            },
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
                "depth": depth,
            },
        }
    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to get entity graph for %s", entity_id)
        # NOTE(review): the raw exception text is sent to the client; consider
        # a generic message if internal error details must not leak.
        raise HTTPException(
            status_code=500, detail=f"Failed to get entity graph: {str(exc)}"
        ) from exc