Spaces:
Sleeping
Sleeping
File size: 6,153 Bytes
a990ce3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 |
"""
FastAPI 后端服务 - 用于 Hugging Face Spaces
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
import json
# 导入数据库和 RAG 引擎
# 注意:在 HF Spaces 中,这些文件应该在同一个目录下
from database_setup_lite import setup_databases
from rag_engine import RAGEngine
# Initialise the FastAPI application.
app = FastAPI(title="GraphRAG Backend API")

# Configure CORS.
# Allowed origins are read from the comma-separated ALLOWED_ORIGINS environment
# variable (e.g. "https://your-frontend.vercel.app"); when it is unset or empty
# every origin is allowed, as before.
_allowed_origins = [
    origin.strip()
    for origin in os.getenv("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allowed_origins,
    # The CORS protocol does not permit a wildcard origin together with
    # credentials, and FastAPI's documentation requires explicit origins when
    # allow_credentials=True. Credentials are therefore only enabled once the
    # origins have been restricted.
    allow_credentials="*" not in _allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialise the databases and the RAG engine once at import time; they are
# kept as module-level globals so that requests do not re-initialise them.
print("正在初始化数据库...")
graph_db, vector_db = setup_databases()
rag_engine = RAGEngine(graph_db, vector_db)

# Load the mock data that is served to the frontend (see /api/styles).
with open("mock_data.json", "r", encoding="utf-8") as f:
    mock_data = json.load(f)
# Pydantic request models
class SearchRequest(BaseModel):
    """Request body for ``POST /api/search``."""

    # Free-text query; the endpoint rejects an empty value with HTTP 400.
    query: str
    # Optional product name; the endpoint passes "" when it is missing or None.
    product_name: Optional[str] = ""
    # Optional style name; the endpoint passes "" when it is missing or None.
    style_name: Optional[str] = ""
class GenerateRequest(BaseModel):
    """Request body for ``POST /api/generate``."""

    # The endpoint requires query, product_name and style_name to all be
    # non-empty and answers HTTP 400 otherwise.
    query: str
    product_name: str
    style_name: str
    # Forwarded unchanged to rag_engine.generate_copywriting; presumably
    # selects graph-based retrieval over plain vector retrieval — TODO confirm.
    use_graph: bool = True
class FeatureSearchRequest(BaseModel):
    """Request body for ``POST /api/features/search``."""

    # Keyword query; it is lower-cased by the endpoint before matching.
    query: str
@app.get("/")
def root():
    """Return basic service information and the list of exposed endpoints."""
    available_endpoints = [
        "GET /api/products",
        "GET /api/styles",
        "GET /api/graph",
        "GET /api/vector-db",
        "POST /api/search",
        "POST /api/generate",
        "POST /api/features/search",
    ]
    service_info = {
        "message": "GraphRAG Backend API",
        "version": "1.0.0",
        "endpoints": available_endpoints,
    }
    return service_info
@app.get("/api/products")
def get_products():
    """Return the product list, which is a single hard-coded demo product."""
    return [
        {
            "id": "P_DEMO",
            "name": "真丝睡眠眼罩",
        },
    ]
@app.get("/api/styles")
def get_styles():
    """Return the id and name of every style listed in the mock data."""
    style_options = []
    for style in mock_data["styles"]:
        # Only id and name are exposed; any other style fields are dropped.
        style_options.append({"id": style["id"], "name": style["name"]})
    return style_options
@app.get("/api/graph")
def get_graph():
    """Return the graph database as ``{"nodes": [...], "edges": [...]}``."""

    def _display_label(node_id, properties):
        # Prefer the node's name, then the first 20 characters of its content,
        # and finally fall back to the node id.
        return (
            properties.get("name")
            or properties.get("content", "")[:20]
            or node_id
        )

    node_list = [
        {
            "id": node_id,
            "type": node_data["type"],
            "label": _display_label(node_id, node_data["properties"]),
            "properties": node_data["properties"],
        }
        for node_id, node_data in graph_db.nodes.items()
    ]

    edge_list = [
        {
            "source": edge["source"],
            "target": edge["target"],
            "relationship": edge["relationship"],
        }
        for edge in graph_db.edges
    ]

    return {"nodes": node_list, "edges": edge_list}
@app.post("/api/search")
def search(request: SearchRequest):
    """Run the retrieval comparison for a query.

    Raises:
        HTTPException: 400 when the query is empty.
    """
    if not request.query:
        raise HTTPException(status_code=400, detail="查询不能为空")

    # Optional fields may be None; the engine is always given strings.
    product_name = request.product_name or ""
    style_name = request.style_name or ""
    return rag_engine.compare_retrieval(request.query, product_name, style_name)
@app.post("/api/generate")
def generate(request: GenerateRequest):
    """Generate copywriting for a query, product and style.

    Raises:
        HTTPException: 400 when query, product_name or style_name is empty.
    """
    required_fields = (request.query, request.product_name, request.style_name)
    if any(not value for value in required_fields):
        raise HTTPException(status_code=400, detail="缺少必要参数")

    return rag_engine.generate_copywriting(
        request.query,
        request.product_name,
        request.style_name,
        request.use_graph,
    )
@app.get("/api/vector-db")
def get_vector_db():
    """Return every document stored in the vector database (traditional RAG).

    Returns:
        ``{"total": <count>, "documents": [{"id", "content", "metadata"}, ...]}``.
        A document without stored text gets ``""`` as content and a document
        without metadata gets ``{}``.

    Raises:
        HTTPException: 500 with the error text when reading the collection fails.
    """
    try:
        all_docs = vector_db.collection.get()

        def _column(key):
            # A column may be absent, or present but None (presumably when the
            # store was asked not to include it — TODO confirm); both cases
            # are treated as an empty column instead of failing on len(None).
            values = all_docs[key] if key in all_docs else None
            return values or []

        contents = _column("documents")
        metadatas = _column("metadatas")

        documents = []
        for i, doc_id in enumerate(all_docs["ids"]):
            content = contents[i] if i < len(contents) else None
            metadata = metadatas[i] if i < len(metadatas) else None
            documents.append({
                "id": doc_id,
                # Normalise None entries to the same defaults used for
                # missing ones so the response shape stays consistent.
                "content": content or "",
                "metadata": metadata or {},
            })

        return {
            "total": len(documents),
            "documents": documents,
        }
    except Exception as e:
        # Top-level boundary: surface any storage error as HTTP 500 and keep
        # the original exception chained for the server-side traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/features/search")
def search_features(request: FeatureSearchRequest):
    """Search feature nodes by keyword and list the products that use them.

    A feature matches when the lower-cased query, or any whitespace-separated
    keyword of it, is a substring of the lower-cased feature name (the node id
    is used when the node has no ``name`` property).

    Returns:
        ``{"features": [...]}`` with at most 10 entries, each holding ``id``,
        ``name`` and ``related_products`` (names of Product nodes linked to
        the feature by a ``HAS_FEATURE`` edge, in edge order).
    """
    max_results = 10

    query = request.query.lower()
    if not query:
        return {"features": []}

    # Split once instead of once per feature node.
    keywords = query.split()

    # Collect matching features, stopping as soon as the result page is full:
    # only the first `max_results` matches are ever returned.
    matched_features = []
    for node in graph_db.find_nodes_by_type("Feature"):
        display_name = node["properties"].get("name", node["id"])
        feature_name = display_name.lower()
        # Simple keyword matching.
        if query in feature_name or any(kw in feature_name for kw in keywords):
            matched_features.append({
                "id": node["id"],
                "name": display_name,
                "related_products": [],
            })
            if len(matched_features) == max_results:
                break

    if matched_features:
        # Attach products in a single pass over the edges rather than
        # rescanning the whole edge list for every matched feature.
        # Assumes feature node ids are unique.
        products_by_feature = {
            feature["id"]: feature["related_products"]
            for feature in matched_features
        }
        for edge in graph_db.edges:
            related_products = products_by_feature.get(edge["target"])
            if related_products is None or edge["relationship"] != "HAS_FEATURE":
                continue
            product_node = graph_db.nodes.get(edge["source"], {})
            if product_node.get("type") == "Product":
                related_products.append(
                    product_node["properties"].get("name", edge["source"])
                )

    return {"features": matched_features}
# HF Spaces 会自动使用 Dockerfile 中的 CMD 启动
# 如果需要本地测试,可以取消下面的注释
# if __name__ == "__main__":
# import uvicorn
# port = int(os.getenv("PORT", 7860))
# uvicorn.run(app, host="0.0.0.0", port=port)
|