File size: 5,656 Bytes
e14f312
 
 
270c1c7
 
 
 
 
 
 
 
 
 
e14f312
 
 
270c1c7
 
 
 
 
e14f312
 
 
 
270c1c7
e14f312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270c1c7
 
 
 
 
e14f312
 
 
270c1c7
e14f312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270c1c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""

Graph API Routes - Network visualization endpoints

"""
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import Optional, List
from sqlalchemy.orm import Session
from sqlalchemy import or_

from app.api.deps import get_scoped_db
from app.models.entity import Entity, Relationship


# Every endpoint in this module is mounted under the "/graph" path prefix and
# grouped under the "Graph" tag in the generated OpenAPI schema.
router = APIRouter(tags=["Graph"], prefix="/graph")


@router.get("")
async def get_graph(
    entity_type: Optional[str] = Query(None, description="Filter by entity type"),
    limit: int = Query(100, ge=0, le=500, description="Maximum number of entities"),
    db: Session = Depends(get_scoped_db)
):
    """
    Get graph data for visualization.

    Loads up to ``limit`` entities (optionally restricted to one
    ``entity_type``) plus the relationships whose source AND target are both
    among the loaded entities. The result is returned as Cytoscape.js
    ``{"data": {...}}`` elements.

    Args:
        entity_type: When given, only entities whose ``type`` equals this
            value are included.
        limit: Maximum number of entities (nodes) to load, 0..500. The lower
            bound prevents a negative LIMIT, which SQLite treats as "no limit".
        db: Request-scoped SQLAlchemy session (injected dependency).

    Returns:
        Dict with ``nodes``, ``edges`` and ``stats``
        (``total_nodes``, ``total_edges``).

    Raises:
        HTTPException: 500 if querying or formatting fails.

    NOTE(review): this is ``async def`` but performs blocking ``Session``
    calls on the event-loop thread; consider a plain ``def`` so FastAPI runs
    it in its threadpool (check for direct ``await`` callers first).
    """
    try:
        # Get entities
        query = db.query(Entity)
        if entity_type:
            query = query.filter(Entity.type == entity_type)

        entities = query.limit(limit).all()
        # Set membership keeps edge filtering linear rather than O(edges * nodes).
        entity_ids = {e.id for e in entities}

        # Only edges with BOTH endpoints among the loaded nodes can be drawn,
        # so filter on both columns in SQL instead of fetching every edge that
        # touches any node and discarding most of them in Python.
        relationships = []
        if entity_ids:
            relationships = db.query(Relationship).filter(
                Relationship.source_id.in_(entity_ids),
                Relationship.target_id.in_(entity_ids),
            ).all()

        # Format for Cytoscape.js
        nodes = [
            {
                "data": {
                    "id": e.id,
                    "label": e.name[:30] + "..." if len(e.name) > 30 else e.name,
                    "fullName": e.name,
                    "type": e.type,
                    "description": e.description[:100] if e.description else "",
                    "source": e.source or "unknown",
                }
            }
            for e in entities
        ]

        edges = [
            {
                "data": {
                    "id": r.id,
                    "source": r.source_id,
                    "target": r.target_id,
                    "label": r.type,
                    "type": r.type,
                }
            }
            for r in relationships
        ]

        return {
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_nodes": len(nodes),
                "total_edges": len(edges)
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get graph: {str(e)}") from e


@router.get("/entity/{entity_id}")
async def get_entity_graph(
    entity_id: str,
    depth: int = Query(1, ge=1, le=3, description="How many levels of connections to include"),
    db: Session = Depends(get_scoped_db)
):
    """
    Get graph centered on a specific entity.

    Performs a breadth-first expansion over relationships (in either
    direction) for up to ``depth`` hops from ``entity_id``. It then returns
    every collected entity and every relationship between them as
    Cytoscape.js elements. The central node carries ``isCentral: True``.

    Args:
        entity_id: Id of the entity at the center of the graph.
        depth: Number of hops to expand, 1..3.
        db: Request-scoped SQLAlchemy session (injected dependency).

    Returns:
        Dict with ``central`` (id/name/type), ``nodes``, ``edges`` and
        ``stats`` (``total_nodes``, ``total_edges``, ``depth``).

    Raises:
        HTTPException: 404 if the entity does not exist; 500 on any other
            failure.

    NOTE(review): this is ``async def`` but performs blocking ``Session``
    calls on the event-loop thread; consider a plain ``def`` so FastAPI runs
    it in its threadpool (check for direct ``await`` callers first).
    """
    try:
        # Get the central entity
        central = db.query(Entity).filter(Entity.id == entity_id).first()
        if not central:
            raise HTTPException(status_code=404, detail="Entity not found")

        # Breadth-first expansion: `frontier` holds ids first reached in the
        # previous hop; `collected_ids` holds everything reached so far.
        collected_ids = {entity_id}
        frontier = {entity_id}

        for _ in range(depth):
            if not frontier:
                # Connected component exhausted before `depth` hops; further
                # rounds would only issue empty IN () queries.
                break
            # Only the endpoint ids are needed here, not full Relationship rows.
            endpoint_rows = db.query(
                Relationship.source_id, Relationship.target_id
            ).filter(
                or_(
                    Relationship.source_id.in_(frontier),
                    Relationship.target_id.in_(frontier)
                )
            ).all()

            reached = set()
            for source_id, target_id in endpoint_rows:
                reached.add(source_id)
                reached.add(target_id)

            frontier = reached - collected_ids
            collected_ids |= reached

        # Get all entities
        entities = db.query(Entity).filter(Entity.id.in_(collected_ids)).all()
        # Restrict edges to entities that actually loaded. Relationship ids
        # with no matching Entity row (if any can exist) would otherwise yield
        # edges to missing nodes, which Cytoscape rejects.
        loaded_ids = {e.id for e in entities}

        # Get all relationships between the loaded entities
        relationships = db.query(Relationship).filter(
            Relationship.source_id.in_(loaded_ids),
            Relationship.target_id.in_(loaded_ids)
        ).all()

        # Format for Cytoscape
        nodes = [
            {
                "data": {
                    "id": e.id,
                    "label": e.name[:30] + "..." if len(e.name) > 30 else e.name,
                    "fullName": e.name,
                    "type": e.type,
                    "description": e.description[:100] if e.description else "",
                    "source": e.source or "unknown",
                    "isCentral": e.id == entity_id
                }
            }
            for e in entities
        ]

        edges = [
            {
                "data": {
                    "id": r.id,
                    "source": r.source_id,
                    "target": r.target_id,
                    "label": r.type,
                    "type": r.type
                }
            }
            for r in relationships
        ]

        return {
            "central": {
                "id": central.id,
                "name": central.name,
                "type": central.type
            },
            "nodes": nodes,
            "edges": edges,
            "stats": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
                "depth": depth
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get entity graph: {str(e)}") from e