"""Knowledge-graph API routes: graph retrieval, cache management, health check and node lookup."""

import time

from fastapi import APIRouter, Query, HTTPException, status
from fastapi.responses import JSONResponse

from app.services.graph_service import graph_service
from app.cache import cache
from app.database.neo4j_driver import Neo4jDriver, get_neo4j_driver

router = APIRouter(
    prefix="/api",
    tags=["knowledge graph"]
)


@router.get("/graph")
async def get_graph(
    limit: int = Query(100, ge=1, le=1000),
    database: str = Query("u0-ss-math-rjb-bx3", description="要查询的Neo4j数据库名称")
):
    """
    Return knowledge-graph data.

    Args:
        limit: Maximum number of nodes to return per category (1..1000).
        database: Name of the Neo4j database to query; defaults to "u0-ss-math-rjb-bx3".

    Returns:
        JSONResponse containing the graph data produced by ``graph_service``.
        The ``X-From-Cache`` header is "true" when the result came from the cache.

    Raises:
        HTTPException: 500 if the graph service raises any error.
    """
    # NOTE(review): get_knowledge_graph appears to be a synchronous call made from an
    # async handler, which would block the event loop while it runs — confirm.
    try:
        result, is_cached = graph_service.get_knowledge_graph(limit, database)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取知识图谱失败: {str(e)}"
        ) from e
    return JSONResponse(
        content=result,
        headers={"X-From-Cache": "true" if is_cached else "false"}
    )


@router.delete("/graph/cache")
async def clear_graph_cache():
    """
    Clear every cache entry whose key starts with "knowledge_graph".

    Raises:
        HTTPException: 500 if the cache backend raises any error.
    """
    try:
        cache.clear_by_prefix("knowledge_graph")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"清除缓存失败: {str(e)}"
        ) from e
    return {"message": "缓存已清除"}


@router.get("/graph/cache-status")
async def get_cache_status(limit: int = Query(100, ge=1, le=1000)):
    """
    Report whether the knowledge-graph result for ``limit`` is cached, and when it expires.

    Args:
        limit: The ``limit`` value whose cache entry should be inspected.

    Returns:
        ``{"cached": False, "limit": ...}`` when no entry exists. Otherwise the
        remaining lifetime in seconds and whole minutes, plus the local-time
        expiry timestamp.
    """
    # NOTE(review): get_graph also varies by `database`, but this key only includes
    # `limit`. Confirm it matches the key format graph_service actually writes.
    cache_key = f"knowledge_graph:{limit}"
    cache_info = cache.get_cache_info(cache_key)

    if not cache_info["exists"]:
        return {
            "cached": False,
            "limit": limit
        }

    # assumes expires_at is a Unix epoch timestamp (required by time.localtime) — TODO confirm
    return {
        "cached": True,
        "limit": limit,
        "expires_in_seconds": cache_info["expires_in"],
        "expires_in_minutes": cache_info["expires_in"] // 60,
        "expires_at": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(cache_info["expires_at"]))
    }


@router.get("/health")
async def health_check():
    """
    Health check: report whether the Neo4j connection is usable.

    Returns:
        ``{"status": "ok", ...}`` with HTTP 200 when ``check_connection()`` succeeds.
        Otherwise ``{"status": "error", ...}`` with HTTP 500.
    """
    driver = Neo4jDriver()
    if driver.check_connection():
        return {"status": "ok", "message": "数据库连接正常"}
    # FastAPI does not interpret Flask-style ``(body, status)`` tuples. It would
    # serialise the tuple as a JSON list with status 200. Set the code explicitly.
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"status": "error", "message": "数据库连接失败"}
    )


@router.get("/graph/node/{node_id}")
async def get_node(node_id: str):
    """
    Look up a single node by its ``id`` property.

    Args:
        node_id: Value matched against the node's ``id`` property.

    Returns:
        A dict containing ``"id"`` plus all of the node's properties. Because the
        node matched on its ``id`` property, that property overrides the
        driver-level ``node_data.id`` in the merged dict.

    Raises:
        HTTPException: 404 if no node matches; 500 on any other error.
    """
    driver = get_neo4j_driver()
    try:
        with driver.session() as session:
            # Parameterised query: node_id is bound as $node_id, never interpolated,
            # which prevents Cypher injection.
            query = """
            MATCH (n) WHERE n.id = $node_id
            RETURN n
            """
            # NOTE(review): node_id is passed as a string. If n.id is stored as an
            # integer, this never matches — confirm the stored property type.
            record = session.run(query, node_id=node_id).single()

            if not record:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"未找到节点: {node_id}"
                )

            node_data = record["n"]
            return {
                "id": node_data.id,
                **dict(node_data)  # node properties
            }
    except HTTPException:
        # Let the 404 above (and any other deliberate HTTP error) reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"查询节点失败: {str(e)}"
        ) from e