KirkHan's picture
Upload 8 files
a990ce3 verified
"""
FastAPI 后端服务 - 用于 Hugging Face Spaces
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
import json
# 导入数据库和 RAG 引擎
# 注意:在 HF Spaces 中,这些文件应该在同一个目录下
from database_setup_lite import setup_databases
from rag_engine import RAGEngine
# Initialize the FastAPI application.
app = FastAPI(title="GraphRAG Backend API")

# CORS configuration. Allowed origins are read from the comma-separated
# CORS_ALLOW_ORIGINS environment variable so a deployment can pin the frontend
# domain (e.g. "https://your-frontend.vercel.app") without a code change.
# When the variable is unset or blank, every origin is allowed ("*").
# NOTE(review): FastAPI's CORS docs say wildcard origins/methods/headers should
# not be combined with allow_credentials=True; set CORS_ALLOW_ORIGINS
# explicitly in production if credentialed requests are actually needed.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Build the graph store, the vector store and the RAG engine once at import
# time; the request handlers below share these module-level instances.
print("正在初始化数据库...")
graph_db, vector_db = setup_databases()
rag_engine = RAGEngine(graph_db, vector_db)

# Raw mock dataset served to the frontend (the style list is read from it).
with open("mock_data.json", mode="r", encoding="utf-8") as mock_file:
    mock_data = json.loads(mock_file.read())
# Pydantic request models
class SearchRequest(BaseModel):
    """Request body for ``POST /api/search``.

    ``product_name`` and ``style_name`` are optional; the endpoint treats a
    missing or null value as an empty string.
    """

    query: str  # free-text query; an empty string is rejected with HTTP 400
    product_name: Optional[str] = ""
    style_name: Optional[str] = ""
class GenerateRequest(BaseModel):
    """Request body for ``POST /api/generate``.

    ``query``, ``product_name`` and ``style_name`` must all be non-empty; the
    endpoint returns HTTP 400 otherwise.
    """

    query: str
    product_name: str
    style_name: str
    # Forwarded to RAGEngine.generate_copywriting; presumably toggles
    # graph-based retrieval on/off — confirm against rag_engine.
    use_graph: bool = True
class FeatureSearchRequest(BaseModel):
    """Request body for ``POST /api/features/search``."""

    query: str  # matched case-insensitively against Feature node names
@app.get("/")
def root():
    """Describe the API: its name, version and the routes it exposes."""
    available_routes = [
        "GET /api/products",
        "GET /api/styles",
        "GET /api/graph",
        "GET /api/vector-db",
        "POST /api/search",
        "POST /api/generate",
        "POST /api/features/search",
    ]
    return {
        "message": "GraphRAG Backend API",
        "version": "1.0.0",
        "endpoints": available_routes,
    }
@app.get("/api/products")
def get_products():
    """Return the product list.

    Currently this is a single hard-coded demo product, not data from
    ``mock_data``.
    """
    return [dict(id="P_DEMO", name="真丝睡眠眼罩")]
@app.get("/api/styles")
def get_styles():
    """Return every style in ``mock_data["styles"]`` as an ``{id, name}`` dict."""
    return [
        {field: style[field] for field in ("id", "name")}
        for style in mock_data["styles"]
    ]
@app.get("/api/graph")
def get_graph():
    """Return the whole graph as ``{"nodes": [...], "edges": [...]}``.

    Each node carries its id, type, a display label and its raw properties.
    Each edge carries only its source, target and relationship.
    """

    def node_label(node_id, props):
        # Prefer the node's name, then the first 20 characters of its
        # content, and finally fall back to the node id itself.
        name = props.get("name")
        if name:
            return name
        return props.get("content", "")[:20] or node_id

    nodes = [
        {
            "id": node_id,
            "type": data["type"],
            "label": node_label(node_id, data["properties"]),
            "properties": data["properties"],
        }
        for node_id, data in graph_db.nodes.items()
    ]
    edges = [
        {key: edge[key] for key in ("source", "target", "relationship")}
        for edge in graph_db.edges
    ]
    return {"nodes": nodes, "edges": edges}
@app.post("/api/search")
def search(request: SearchRequest):
    """Compare retrieval results for a query via ``rag_engine.compare_retrieval``.

    Args:
        request: the query plus optional product/style names. Missing or null
            names are passed to the engine as empty strings.

    Returns:
        Whatever ``rag_engine.compare_retrieval`` returns, serialized by FastAPI.

    Raises:
        HTTPException: 400 when the query is empty or whitespace-only.
    """
    # A whitespace-only query carries no search terms; reject it like "".
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="查询不能为空")
    return rag_engine.compare_retrieval(
        request.query,
        request.product_name or "",
        request.style_name or "",
    )
@app.post("/api/generate")
def generate(request: GenerateRequest):
    """Generate copywriting via ``rag_engine.generate_copywriting``.

    Args:
        request: query, product name, style name and the ``use_graph`` flag,
            all forwarded positionally to the engine.

    Returns:
        Whatever ``rag_engine.generate_copywriting`` returns, serialized by
        FastAPI.

    Raises:
        HTTPException: 400 when any of query/product_name/style_name is empty
            or whitespace-only.
    """
    required = (request.query, request.product_name, request.style_name)
    # Whitespace-only values are as useless to the engine as empty ones.
    if not all(value.strip() for value in required):
        raise HTTPException(status_code=400, detail="缺少必要参数")
    return rag_engine.generate_copywriting(
        request.query,
        request.product_name,
        request.style_name,
        request.use_graph,
    )
@app.get("/api/vector-db")
def get_vector_db():
    """Return every document stored in the traditional-RAG vector collection.

    Returns:
        ``{"total": <int>, "documents": [{"id", "content", "metadata"}, ...]}``.
        A document with no content gets ``""``. A document with no metadata
        gets ``{}``.

    Raises:
        HTTPException: 500 carrying the underlying error message if reading
            or serializing the collection fails.
    """
    try:
        all_docs = vector_db.collection.get()

        # The content/metadata columns may be missing or None, and individual
        # entries may be None; normalize all of these to the same defaults.
        contents = (all_docs["documents"] if "documents" in all_docs else None) or []
        metadatas = (all_docs["metadatas"] if "metadatas" in all_docs else None) or []

        documents = []
        for i, doc_id in enumerate(all_docs["ids"]):
            content = contents[i] if i < len(contents) else None
            metadata = metadatas[i] if i < len(metadatas) else None
            documents.append({
                "id": doc_id,
                "content": content if content is not None else "",
                "metadata": metadata if metadata is not None else {},
            })
        return {
            "total": len(documents),
            "documents": documents,
        }
    except Exception as e:
        # Top-level API boundary: report the failure as a 500 and keep the
        # original exception chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/features/search")
def search_features(request: FeatureSearchRequest):
    """Find Feature nodes whose name matches the query, plus their products.

    Matching is a case-insensitive substring test. A feature matches when any
    whitespace-separated keyword of the query occurs in its name. If the whole
    query occurs in the name, every keyword does too. At most 10 features are
    returned, in the order ``graph_db.find_nodes_by_type`` yields them.

    Args:
        request: carries the free-text ``query``.

    Returns:
        ``{"features": [{"id", "name", "related_products"}, ...]}``.
        ``related_products`` lists the names of Product nodes linked to the
        feature by a ``HAS_FEATURE`` edge, in edge order. An empty or
        whitespace-only query yields ``{"features": []}``.
    """
    max_results = 10
    keywords = request.query.lower().split()
    if not keywords:
        return {"features": []}

    # Collect matches and stop at the limit instead of scanning every feature
    # and truncating afterwards.
    matched_features = []
    for node in graph_db.find_nodes_by_type("Feature"):
        name = node["properties"].get("name", node["id"])
        lowered = name.lower()
        if any(keyword in lowered for keyword in keywords):
            matched_features.append(
                {"id": node["id"], "name": name, "related_products": []}
            )
            if len(matched_features) == max_results:
                break

    if not matched_features:
        return {"features": []}

    # Index the matches by id and walk the edge list once, instead of once per
    # matched feature. A list per id keeps duplicate-id nodes each populated.
    features_by_id = {}
    for feature in matched_features:
        features_by_id.setdefault(feature["id"], []).append(feature)

    for edge in graph_db.edges:
        targets = features_by_id.get(edge["target"])
        if not targets or edge["relationship"] != "HAS_FEATURE":
            continue
        product_node = graph_db.nodes.get(edge["source"], {})
        if product_node.get("type") == "Product":
            product_name = product_node["properties"].get("name", edge["source"])
            for feature in targets:
                feature["related_products"].append(product_name)

    return {"features": matched_features}
# HF Spaces 会自动使用 Dockerfile 中的 CMD 启动
# 如果需要本地测试,可以取消下面的注释
# if __name__ == "__main__":
# import uvicorn
# port = int(os.getenv("PORT", 7860))
# uvicorn.run(app, host="0.0.0.0", port=port)