Spaces:
Sleeping
Sleeping
| from fastapi import APIRouter, Query, HTTPException, status | |
| from fastapi.responses import JSONResponse | |
| from app.services.graph_service import graph_service | |
| from app.cache import cache | |
| import time | |
| from app.database.neo4j_driver import Neo4jDriver, get_neo4j_driver | |
# Router for the knowledge-graph endpoints; every path is mounted under "/api"
# and grouped under the "knowledge graph" tag in the OpenAPI docs.
# NOTE(review): no @router.<method>(...) decorators are visible on the handlers
# in this copy of the file — confirm they are actually registered on `router`.
router = APIRouter(prefix="/api", tags=["knowledge graph"])
async def get_graph(
    limit: int = Query(100, ge=1, le=1000),
    database: str = Query("u0-ss-math-rjb-bx3", description="要查询的Neo4j数据库名称"),
):
    """Fetch knowledge-graph data (nodes and relationships).

    Args:
        limit: Maximum number of nodes to return per category (validated to 1-1000).
        database: Name of the Neo4j database to query; defaults to
            "u0-ss-math-rjb-bx3".

    Returns:
        A JSONResponse whose body is the payload returned by
        ``graph_service.get_knowledge_graph`` and whose ``X-From-Cache``
        header is "true" when the service reports a cache hit, else "false".

    Raises:
        HTTPException: 500 if fetching or serialising the graph fails.
    """
    # NOTE(review): get_knowledge_graph is called synchronously inside an
    # async handler, so it blocks the event loop while it runs — consider a
    # threadpool once the service/cache are confirmed thread-safe.
    try:
        result, is_cached = graph_service.get_knowledge_graph(limit, database)
        return JSONResponse(
            content=result,
            headers={"X-From-Cache": "true" if is_cached else "false"},
        )
    except Exception as e:
        # Chain the original error so its traceback survives in server logs.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取知识图谱失败: {str(e)}",
        ) from e
async def clear_graph_cache():
    """Clear the cached knowledge-graph entries.

    Delegates to ``cache.clear_by_prefix("knowledge_graph")``, presumably
    dropping every cache key with that prefix.

    Returns:
        ``{"message": "缓存已清除"}`` on success.

    Raises:
        HTTPException: 500 if the cache backend raises while clearing.
    """
    try:
        cache.clear_by_prefix("knowledge_graph")
    except Exception as e:
        # Chain the cause so the backend's traceback is kept in server logs.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"清除缓存失败: {str(e)}",
        ) from e
    return {"message": "缓存已清除"}
async def get_cache_status(limit: int = Query(100, ge=1, le=1000)):
    """Report whether a knowledge-graph cache entry exists for ``limit``.

    Args:
        limit: The per-category node limit whose cache entry is inspected
            (validated to 1-1000).

    Returns:
        ``{"cached": False, "limit": limit}`` when no entry exists. Otherwise
        a dict that also carries the remaining TTL in seconds, the same TTL
        floor-divided into minutes, and the expiry timestamp formatted as
        "%Y-%m-%d %H:%M:%S" in the server's local time zone.
    """
    # NOTE(review): this key has no database component, while get_graph is
    # parameterised by `database` — confirm it matches the key graph_service
    # actually writes.
    info = cache.get_cache_info(f"knowledge_graph:{limit}")
    if not info["exists"]:
        return {"cached": False, "limit": limit}

    seconds_left = info["expires_in"]
    expiry_local = time.localtime(info["expires_at"])
    return {
        "cached": True,
        "limit": limit,
        "expires_in_seconds": seconds_left,
        "expires_in_minutes": seconds_left // 60,
        "expires_at": time.strftime("%Y-%m-%d %H:%M:%S", expiry_local),
    }
async def health_check():
    """Health-check route: report whether the Neo4j database is reachable.

    Returns:
        ``{"status": "ok", "message": ...}`` with HTTP 200 when
        ``Neo4jDriver().check_connection()`` is truthy; otherwise a
        JSONResponse with ``{"status": "error", "message": ...}`` and HTTP 500.
    """
    # NOTE(review): a fresh Neo4jDriver is constructed on every call and never
    # closed here — confirm Neo4jDriver reuses/pools its underlying connection.
    driver = Neo4jDriver()
    if driver.check_connection():
        return {"status": "ok", "message": "数据库连接正常"}
    # FastAPI does not understand Flask-style ``(body, status)`` tuples: such a
    # return is serialised as a JSON array with HTTP 200, hiding the outage from
    # monitors. Build the response explicitly so the 500 status is sent.
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"status": "error", "message": "数据库连接失败"},
    )
async def get_node(node_id: str):
    """Look up a single node by its ``id`` property.

    Args:
        node_id: Value matched against the node's ``id`` property. It is sent
            as a string, so nodes whose ``id`` is stored as an integer will not
            match — NOTE(review): confirm the stored type.

    Returns:
        A dict with key ``"id"`` first, followed by all node properties. The
        matched node always has an ``id`` property (the WHERE clause requires
        it), and it overrides the driver-level ``node.id`` in the ``"id"`` key.

    Raises:
        HTTPException: 404 if no node matches; 500 if the query itself fails.
    """
    driver = get_neo4j_driver()
    # Parameterised Cypher: node_id is passed as $node_id, never interpolated
    # into the query text, so it cannot inject Cypher.
    # NOTE(review): the label-less MATCH scans all nodes unless an index on
    # `id` covers them.
    query = """
    MATCH (n)
    WHERE n.id = $node_id
    RETURN n
    """
    try:
        with driver.session() as session:
            record = session.run(query, node_id=node_id).single()
            if record is None:
                payload = None
            else:
                node_data = record["n"]
                payload = {
                    "id": node_data.id,
                    **dict(node_data),  # node properties; property "id" wins
                }
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"查询节点失败: {str(e)}",
        ) from e

    # Raised outside the try so the 404 is not swallowed and re-wrapped as a 500.
    if payload is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"未找到节点: {node_id}",
        )
    return payload