"""
Graph API routes - network visualization endpoints.

Entities are exposed as graph nodes and relationships as graph edges.
"""
|
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import or_
from sqlalchemy.orm import Session

from app.api.deps import get_scoped_db
from app.models.entity import Entity, Relationship
|
|
|
|
# Router for every endpoint in this module: paths are mounted under the
# "/graph" prefix and tagged "Graph".
router = APIRouter(tags=["Graph"], prefix="/graph")
|
|
|
|
@router.get("")
async def get_graph(
    entity_type: Optional[str] = Query(None, description="Filter by entity type"),
    limit: int = Query(100, ge=1, le=500, description="Maximum number of entities"),
    db: Session = Depends(get_scoped_db),
):
    """
    Get graph data for visualization.

    Returns nodes (entities) and edges (relationships).

    Loads at most ``limit`` entities, optionally only those whose ``type``
    equals ``entity_type``. It also loads the relationships whose source AND
    target are both among the loaded entities, so every edge references a
    node that is present in the response.

    Args:
        entity_type: If given, only entities with this ``type`` are returned.
        limit: Maximum number of entities to load (1-500).
        db: Database session provided by ``get_scoped_db``.

    Returns:
        A dict with ``nodes`` and ``edges``, each a list of ``{"data": {...}}``
        objects, and ``stats`` holding ``total_nodes`` / ``total_edges``.

    Raises:
        HTTPException: 500 if building the graph fails for any reason.
    """
    # NOTE(review): the Session calls below are synchronous and run inside an
    # ``async def`` route, so they block the event loop. Consider a plain
    # ``def`` route (FastAPI then runs it in its threadpool).
    try:
        query = db.query(Entity)
        if entity_type:
            query = query.filter(Entity.type == entity_type)
        entities = query.limit(limit).all()

        entity_ids = {e.id for e in entities}

        relationships = []
        if entity_ids:
            # Restrict BOTH endpoints in SQL. The previous OR-filter fetched
            # every relationship touching the page (unbounded for hub
            # entities) only to discard most rows in Python. Skipping the
            # query when nothing was loaded also avoids an empty IN ().
            relationships = (
                db.query(Relationship)
                .filter(
                    Relationship.source_id.in_(entity_ids),
                    Relationship.target_id.in_(entity_ids),
                )
                .all()
            )

        nodes = [
            {
                "data": {
                    "id": e.id,
                    # Short label for display; the untruncated name is in fullName.
                    "label": e.name[:30] + "..." if len(e.name) > 30 else e.name,
                    "fullName": e.name,
                    "type": e.type,
                    "description": e.description[:100] if e.description else "",
                    "source": e.source or "unknown",
                }
            }
            for e in entities
        ]

        edges = [
            {
                "data": {
                    "id": r.id,
                    "source": r.source_id,
                    "target": r.target_id,
                    "label": r.type,
                    "type": r.type,
                }
            }
            for r in relationships
        ]

        return {
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
            },
        }

    except Exception as e:
        # Keep the traceback server-side; HTTPException itself is not logged.
        logging.getLogger(__name__).exception("Failed to build graph")
        # NOTE(review): the client-facing detail echoes the raw exception text,
        # which may expose internals (e.g. SQL). Kept for compatibility.
        raise HTTPException(status_code=500, detail=f"Failed to get graph: {str(e)}") from e
|
|
|
|
|
@router.get("/entity/{entity_id}")
async def get_entity_graph(
    entity_id: str,
    depth: int = Query(1, ge=1, le=3, description="How many levels of connections to include"),
    db: Session = Depends(get_scoped_db),
):
    """
    Get graph centered on a specific entity.

    Starting from ``entity_id``, follows relationships in either direction
    for up to ``depth`` hops. Every reached entity is returned as a node, and
    every relationship whose two endpoints are both returned nodes is
    returned as an edge.

    Args:
        entity_id: Id of the central entity (path parameter).
        depth: Number of hops to expand from the central entity (1-3).
        db: Database session provided by ``get_scoped_db``.

    Returns:
        A dict containing:
            ``central``: ``id``/``name``/``type`` of the central entity.
            ``nodes`` and ``edges``: lists of ``{"data": {...}}`` objects;
                the central node carries ``isCentral: True``.
            ``stats``: node and edge counts, plus the requested ``depth``.

    Raises:
        HTTPException: 404 if the entity does not exist; 500 on any other
            failure.
    """
    # NOTE(review): the Session calls below are synchronous and run inside an
    # ``async def`` route, so they block the event loop. Consider a plain
    # ``def`` route (FastAPI then runs it in its threadpool).
    try:
        central = db.query(Entity).filter(Entity.id == entity_id).first()
        if not central:
            raise HTTPException(status_code=404, detail="Entity not found")

        # Seed from the id loaded from the DB rather than the raw path string,
        # so later comparisons use the column's own value.
        central_id = central.id

        # Breadth-first expansion: each pass follows relationships in either
        # direction from the current frontier.
        # NOTE(review): there is no cap on how many ids a hub can pull in at
        # depth 2-3; consider bounding it.
        collected_ids = {central_id}
        frontier = {central_id}
        for _ in range(depth):
            if not frontier:
                break  # nothing new was reached; further passes would be empty
            rels = (
                db.query(Relationship)
                .filter(
                    or_(
                        Relationship.source_id.in_(frontier),
                        Relationship.target_id.in_(frontier),
                    )
                )
                .all()
            )
            reached = {r.source_id for r in rels} | {r.target_id for r in rels}
            frontier = reached - collected_ids
            collected_ids |= reached

        entities = db.query(Entity).filter(Entity.id.in_(collected_ids)).all()

        # Only ids that resolved to an actual Entity row may be edge endpoints.
        # Otherwise a relationship to a missing entity would yield an edge
        # whose node is absent from the response. Never empty: central exists.
        node_ids = {e.id for e in entities}

        relationships = (
            db.query(Relationship)
            .filter(
                Relationship.source_id.in_(node_ids),
                Relationship.target_id.in_(node_ids),
            )
            .all()
        )

        nodes = [
            {
                "data": {
                    "id": e.id,
                    # Short label for display; the untruncated name is in fullName.
                    "label": e.name[:30] + "..." if len(e.name) > 30 else e.name,
                    "fullName": e.name,
                    "type": e.type,
                    "description": e.description[:100] if e.description else "",
                    "source": e.source or "unknown",
                    "isCentral": e.id == central_id,
                }
            }
            for e in entities
        ]

        edges = [
            {
                "data": {
                    "id": r.id,
                    "source": r.source_id,
                    "target": r.target_id,
                    "label": r.type,
                    "type": r.type,
                }
            }
            for r in relationships
        ]

        return {
            "central": {
                "id": central.id,
                "name": central.name,
                "type": central.type,
            },
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
                "depth": depth,
            },
        }

    except HTTPException:
        # Let the deliberate 404 through unchanged.
        raise
    except Exception as e:
        # Keep the traceback server-side; HTTPException itself is not logged.
        logging.getLogger(__name__).exception("Failed to build entity graph for %s", entity_id)
        # NOTE(review): the client-facing detail echoes the raw exception text,
        # which may expose internals (e.g. SQL). Kept for compatibility.
        raise HTTPException(status_code=500, detail=f"Failed to get entity graph: {str(e)}") from e
|
|
|
|
|