# graph_show / app/api/graph.py
# Uploaded by huanghehe1223 ("Upload 26 files", commit 3154536, verified)
from fastapi import APIRouter, Query, HTTPException, status
from fastapi.responses import JSONResponse
from app.services.graph_service import graph_service
from app.cache import cache
import time
from app.database.neo4j_driver import Neo4jDriver, get_neo4j_driver
# Every route declared below is mounted under "/api" and grouped under the
# "knowledge graph" tag in the generated OpenAPI schema.
router = APIRouter(tags=["knowledge graph"], prefix="/api")
@router.get("/graph")
async def get_graph(
    limit: int = Query(100, ge=1, le=1000),
    database: str = Query("u0-ss-math-rjb-bx3", description="要查询的Neo4j数据库名称"),
):
    """
    Fetch knowledge-graph data (nodes and relationships).

    Args:
        limit: Maximum number of nodes returned per category (1-1000).
        database: Name of the Neo4j database to query; defaults to
            ``"u0-ss-math-rjb-bx3"``.

    Returns:
        A ``JSONResponse`` whose body is the graph data produced by
        ``graph_service.get_knowledge_graph`` and whose ``X-From-Cache``
        header is ``"true"`` when the service reports a cache hit and
        ``"false"`` otherwise.

    Raises:
        HTTPException: propagated unchanged if the service raises one;
            any other error is wrapped as a 500 with the original exception
            chained as ``__cause__``.
    """
    try:
        result, is_cached = graph_service.get_knowledge_graph(limit, database)
    except HTTPException:
        # Preserve deliberate HTTP errors instead of rewrapping them as 500s.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"获取知识图谱失败: {str(e)}",
        ) from e
    return JSONResponse(
        content=result,
        headers={"X-From-Cache": "true" if is_cached else "false"},
    )
@router.delete("/graph/cache")
async def clear_graph_cache():
    """
    Clear every cache entry whose key starts with ``knowledge_graph``.

    Returns:
        A confirmation message dict.

    Raises:
        HTTPException: 500 if clearing the cache fails; the original
            exception is chained as ``__cause__``.
    """
    # Only the cache call can fail; keep the try body to just that line.
    try:
        cache.clear_by_prefix("knowledge_graph")
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"清除缓存失败: {str(e)}",
        ) from e
    return {"message": "缓存已清除"}
@router.get("/graph/cache-status")
async def get_cache_status(limit: int = Query(100, ge=1, le=1000)):
    """
    Report whether the knowledge-graph cache entry for ``limit`` exists.

    Args:
        limit: Value embedded in the looked-up cache key
            ``knowledge_graph:{limit}``.

    Returns:
        ``{"cached": False, "limit": limit}`` when the entry is absent.
        Otherwise ``cached`` is True and the dict also carries the remaining
        lifetime in seconds, the same lifetime floored to whole minutes, and
        the expiry time formatted in local time.
    """
    # NOTE(review): get_graph also takes a `database` argument, but this key
    # encodes only `limit`; confirm it matches the key graph_service writes.
    info = cache.get_cache_info(f"knowledge_graph:{limit}")
    if not info["exists"]:
        return {
            "cached": False,
            "limit": limit,
        }

    remaining = info["expires_in"]
    expiry_local = time.localtime(info["expires_at"])
    return {
        "cached": True,
        "limit": limit,
        "expires_in_seconds": remaining,
        "expires_in_minutes": remaining // 60,
        "expires_at": time.strftime("%Y-%m-%d %H:%M:%S", expiry_local),
    }
@router.get("/health")
async def health_check():
    """
    Health-check endpoint that verifies the Neo4j connection.

    Returns:
        ``{"status": "ok", ...}`` with HTTP 200 when
        ``Neo4jDriver().check_connection()`` is truthy; otherwise a JSON body
        ``{"status": "error", ...}`` with HTTP 500.
    """
    driver = Neo4jDriver()
    if driver.check_connection():
        return {"status": "ok", "message": "数据库连接正常"}
    # FastAPI does not treat a ``(body, status)`` tuple the way Flask does:
    # it would serialize the tuple as a JSON array and still answer 200.
    # Build the response explicitly so monitors actually see the 500.
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"status": "error", "message": "数据库连接失败"},
    )
@router.get("/graph/node/{node_id}")
async def get_node(node_id: str):
    """
    Look up a single node by its ``id`` property.

    Args:
        node_id: Value matched against the node's ``id`` property. It is
            bound to the query as-is (a string); no type conversion is done.

    Returns:
        A dict with an ``"id"`` key followed by all of the node's properties.

    Raises:
        HTTPException: 404 if no node has this ``id``; 500 for any other
            failure while querying (original exception chained as
            ``__cause__``).
    """
    driver = get_neo4j_driver()
    try:
        with driver.session() as session:
            # node_id is bound as the $node_id parameter rather than
            # interpolated into the Cypher text, which prevents Cypher injection.
            query = """
            MATCH (n)
            WHERE n.id = $node_id
            RETURN n
            """
            result = session.run(query, node_id=node_id)
            record = result.single()
            if not record:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"未找到节点: {node_id}",
                )
            node_data = record["n"]
            # The WHERE clause guarantees the node has an ``id`` property, and
            # unpacking the properties after ``node_data.id`` overwrites the
            # internal Neo4j id with that property's value.
            return {
                "id": node_data.id,
                **dict(node_data),
            }
    except HTTPException:
        # Let the deliberate 404 above propagate; without this clause the
        # generic handler below would rewrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"查询节点失败: {str(e)}",
        ) from e