Spaces:
Sleeping
Sleeping
feat: add confluence/slack search tools, chat history, cloud Qdrant support, sync trigger fixes
68af3c5 | from __future__ import annotations | |
| import asyncio | |
| import logging | |
| import time | |
| from fastapi import APIRouter, WebSocket, WebSocketDisconnect | |
| from neo4j import AsyncGraphDatabase | |
| from graph_store.config import settings | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for this module's endpoints, tagged "graph".
router = APIRouter(tags=["graph"])
# After a DNS/connection failure, skip Neo4j for this many seconds before retrying
_NEO4J_COOLDOWN_S = 120
# time.monotonic() reading taken when a Neo4j failure was last recorded, or
# None while no failure is being tracked.
_neo4j_failed_at: float | None = None
def _neo4j_is_available() -> bool:
    """Tell whether Neo4j should be tried for the current request.

    Returns True when no failure is being tracked, or when the tracked
    failure is older than ``_NEO4J_COOLDOWN_S`` seconds (in which case the
    tracked failure is cleared). Returns False while the cooldown is running.
    """
    global _neo4j_failed_at
    failed_at = _neo4j_failed_at
    if failed_at is not None:
        elapsed = time.monotonic() - failed_at
        if elapsed <= _NEO4J_COOLDOWN_S:
            # Still inside the cooldown window: keep skipping Neo4j.
            return False
        # Cooldown elapsed: forget the failure so the next attempt is allowed.
        _neo4j_failed_at = None
    return True
def _mark_neo4j_failed() -> None:
    """Record that a Neo4j operation just failed.

    Stores the current ``time.monotonic()`` reading in ``_neo4j_failed_at``.
    The warning is emitted only when no failure was already being tracked,
    which keeps repeated failures during one outage to a single log line.
    """
    global _neo4j_failed_at
    if _neo4j_failed_at is None:
        logger.warning(
            "graph_stream: Neo4j unreachable (%s) — graph panel disabled for %ds",
            settings.neo4j_uri,
            _NEO4J_COOLDOWN_S,
        )
    # The timestamp is refreshed on every call, so the cooldown is measured
    # from the most recent failure.
    _neo4j_failed_at = time.monotonic()
# ---------------------------------------------------------------------------
# Cypher queries
# ---------------------------------------------------------------------------
# Both queries return one row per (node, outgoing relationship) pair with the
# columns from_label, from_name, rel_type, to_label and to_name, ordered by
# from_name. Only the first label of each node is reported. Nodes labelled
# Chunk or Document, and nodes without a `name` property, are excluded on both
# ends. The relationship is matched with OPTIONAL MATCH, so a node with no
# qualifying relationship still yields one row whose rel_type, to_label and
# to_name are null.

# Snapshot with no channel restriction.
_SNAPSHOT_QUERY_UNFILTERED = """
MATCH (n)
WHERE NOT n:Chunk AND NOT n:Document AND n.name IS NOT NULL
WITH n
OPTIONAL MATCH (n)-[r]->(m)
WHERE NOT m:Chunk AND NOT m:Document AND m.name IS NOT NULL
RETURN
labels(n)[0] AS from_label,
n.name AS from_name,
type(r) AS rel_type,
labels(m)[0] AS to_label,
m.name AS to_name
ORDER BY from_name
"""
# Channel-restricted snapshot; expects a $channel_ids list parameter. A node
# qualifies only if some Chunk points at it through MENTIONS, REFERENCES or
# HAS_CHUNK and that chunk's channel_id is in $channel_ids or is null. The same
# test is applied to the target node of every returned relationship.
_SNAPSHOT_QUERY_FILTERED = """
MATCH (n)
WHERE NOT n:Chunk AND NOT n:Document AND n.name IS NOT NULL
WITH n
MATCH (n)<-[:MENTIONS|REFERENCES|HAS_CHUNK]-(c:Chunk)
WHERE c.channel_id IN $channel_ids OR c.channel_id IS NULL
WITH DISTINCT n
OPTIONAL MATCH (n)-[r]->(m)
WHERE NOT m:Chunk AND NOT m:Document AND m.name IS NOT NULL
AND EXISTS {
MATCH (m)<-[:MENTIONS|REFERENCES|HAS_CHUNK]-(c2:Chunk)
WHERE c2.channel_id IN $channel_ids OR c2.channel_id IS NULL
}
RETURN
labels(n)[0] AS from_label,
n.name AS from_name,
type(r) AS rel_type,
labels(m)[0] AS to_label,
m.name AS to_name
ORDER BY from_name
"""
| # --------------------------------------------------------------------------- | |
# Session helper — resolve gs_session cookie via Redis
| # --------------------------------------------------------------------------- | |
def _parse_session_cookie(websocket: WebSocket) -> str | None:
    """Return the ``gs_session`` cookie value of a WebSocket, if any.

    The already-parsed ``websocket.cookies`` mapping is consulted first. When
    it holds no non-empty value, the raw ``cookie`` header is scanned and the
    text after ``gs_session=`` in the first matching part is returned.
    Returns None when neither source carries the cookie.
    """
    parsed_value = websocket.cookies.get("gs_session")
    if parsed_value:
        return parsed_value

    # Fallback: walk the ";"-separated parts of the raw Cookie header.
    prefix = "gs_session="
    raw_header = websocket.headers.get("cookie", "")
    stripped_parts = (chunk.strip() for chunk in raw_header.split(";"))
    return next(
        (piece[len(prefix):] for piece in stripped_parts if piece.startswith(prefix)),
        None,
    )
async def _resolve_allowed_channels(websocket: WebSocket) -> list[str]:
    """Return the channel IDs the connecting user is allowed to see.

    The ``gs_session`` cookie is resolved to a session through
    ``src.auth.router._get_session``, and the list stored under
    ``session["user"]["allowed_channel_ids"]`` is returned.

    Args:
        websocket: Incoming connection whose cookies identify the session.

    Returns:
        The allowed channel IDs. An empty list is returned when there is no
        cookie, the session lookup raises (for example because the session
        store is unavailable), the session is missing or expired, or the
        stored user/channel data does not have the expected shape.

    NOTE(review): ``graph_stream`` treats an empty list as "no restriction",
    so every fallback below is fail-open: a client without a valid session,
    or a user with zero allowed channels, is served the full graph. Confirm
    this is intended before relying on it as an access control.
    """
    from src.auth.router import _get_session  # local import to avoid circular deps

    session_id = _parse_session_cookie(websocket)
    if not session_id:
        logger.warning("graph_stream: no gs_session cookie — serving unfiltered graph")
        return []

    try:
        session = await _get_session(session_id)
    except Exception as exc:
        logger.warning(
            "graph_stream: session resolution failed (%s) — serving unfiltered graph", exc
        )
        return []

    if not session:
        logger.warning("graph_stream: session not found or expired — serving unfiltered graph")
        return []

    # A session whose "user" entry is present but None (or otherwise not a
    # mapping) would make the lookup below raise; honour the documented
    # "fall back to an empty list" contract instead.
    user = session.get("user") or {}
    if not isinstance(user, dict):
        return []

    allowed = user.get("allowed_channel_ids", [])
    if not isinstance(allowed, list):
        return []
    return allowed
| # --------------------------------------------------------------------------- | |
| # WebSocket handler | |
| # --------------------------------------------------------------------------- | |
# Seconds to pause after every node/edge event sent to the client.
_GRAPH_EVENT_DELAY_S = 0.05


async def _send_graph_unavailable(websocket: WebSocket) -> None:
    """Best-effort: send the "Knowledge graph unavailable" error event."""
    try:
        await websocket.send_json({"event": "error", "message": "Knowledge graph unavailable"})
    except Exception:
        # The socket may already be closed; there is nobody left to notify.
        logger.debug("graph_stream: could not deliver error event", exc_info=True)


async def _close_websocket_quietly(websocket: WebSocket) -> None:
    """Best-effort: close the WebSocket, ignoring failures (e.g. already closed)."""
    try:
        await websocket.close()
    except Exception:
        logger.debug("graph_stream: websocket close failed", exc_info=True)


async def _fetch_snapshot_records(allowed_channel_ids: list[str]) -> list[dict]:
    """Run the snapshot query and return all rows; the driver is always closed.

    The filtered query is used when ``allowed_channel_ids`` is non-empty,
    otherwise the unfiltered one. Any driver/query exception propagates.
    """
    driver = AsyncGraphDatabase.driver(
        settings.neo4j_uri,
        auth=(settings.neo4j_username, settings.neo4j_password),
        max_connection_lifetime=300,
        connection_acquisition_timeout=10,
        keep_alive=True,
    )
    try:
        async with driver.session(database=settings.neo4j_database) as session:
            if allowed_channel_ids:
                result = await session.run(
                    _SNAPSHOT_QUERY_FILTERED, {"channel_ids": allowed_channel_ids}
                )
            else:
                result = await session.run(_SNAPSHOT_QUERY_UNFILTERED, {})
            return await result.data()
    finally:
        try:
            await driver.close()
        except Exception:
            # A failed close must not mask the query result or its error.
            logger.debug("graph_stream: Neo4j driver close failed", exc_info=True)


async def _stream_snapshot(websocket: WebSocket, records: list[dict]) -> None:
    """Send de-duplicated node and edge events for ``records``, then "done"."""
    seen_nodes: set[str] = set()
    seen_edges: set[tuple] = set()

    async def emit_node(name, label) -> None:
        # Nodes are identified by name alone; each name is sent at most once.
        if not name or name in seen_nodes:
            return
        seen_nodes.add(name)
        await websocket.send_json({
            "event": "node",
            "id": name,
            "label": label,
            "name": name,
        })
        await asyncio.sleep(_GRAPH_EVENT_DELAY_S)

    for record in records:
        from_name = record["from_name"]
        rel_type = record["rel_type"]
        to_name = record["to_name"]

        # Both endpoints are announced before the edge that connects them.
        await emit_node(from_name, record["from_label"])
        await emit_node(to_name, record["to_label"])

        # Rows produced by the OPTIONAL MATCH without a relationship carry
        # nulls here and contribute no edge.
        if rel_type and to_name:
            edge_key = (from_name, rel_type, to_name)
            if edge_key not in seen_edges:
                seen_edges.add(edge_key)
                await websocket.send_json({
                    "event": "edge",
                    "from": from_name,
                    "to": to_name,
                    "rel": rel_type,
                })
                await asyncio.sleep(_GRAPH_EVENT_DELAY_S)

    await websocket.send_json(
        {"event": "done", "nodes": len(seen_nodes), "edges": len(seen_edges)}
    )


async def graph_stream(websocket: WebSocket) -> None:
    """Stream a snapshot of the knowledge graph over a WebSocket.

    The connection is accepted, the caller's allowed channels are resolved,
    the matching snapshot query is run against Neo4j, and the result is sent
    as JSON events:

    * ``{"event": "node", "id", "label", "name"}`` once per distinct node name
    * ``{"event": "edge", "from", "to", "rel"}`` once per distinct edge
    * ``{"event": "done", "nodes", "edges"}`` with the totals at the end
    * ``{"event": "error", "message"}`` when the graph cannot be served

    The socket is always closed before returning. Only failures of the Neo4j
    query phase start the cooldown via ``_mark_neo4j_failed``; errors raised
    while sending to the client are logged but do not disable Neo4j for other
    connections.

    Args:
        websocket: The client connection to accept and stream to.
    """
    await websocket.accept()

    # Resolve RBAC — fall back gracefully on any error
    try:
        allowed_channel_ids = await _resolve_allowed_channels(websocket)
    except Exception as exc:
        logger.warning(
            "graph_stream: failed to resolve RBAC (%s) — serving unfiltered graph", exc
        )
        allowed_channel_ids = []

    # NOTE(review): an empty list selects the unfiltered query, so a failed or
    # missing session results in the full graph being served (fail-open).
    if allowed_channel_ids:
        logger.info(
            "graph_stream: RBAC active — filtering to %d channel(s)", len(allowed_channel_ids)
        )
    else:
        logger.info("graph_stream: no channel restriction — serving full graph")

    try:
        if not _neo4j_is_available():
            await _send_graph_unavailable(websocket)
            return

        try:
            records = await _fetch_snapshot_records(allowed_channel_ids)
        except Exception as exc:
            # Only database-side failures start the cooldown.
            _mark_neo4j_failed()
            logger.warning(
                "graph_stream: Neo4j snapshot query failed: %s", exc, exc_info=True
            )
            await _send_graph_unavailable(websocket)
            return

        await _stream_snapshot(websocket, records)
    except WebSocketDisconnect:
        logger.info("graph_stream: client disconnected")
    except Exception:
        logger.exception("graph_stream: failed while streaming graph snapshot")
        await _send_graph_unavailable(websocket)
    finally:
        await _close_websocket_quietly(websocket)