Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI 后端服务 - 用于 Hugging Face Spaces | |
| """ | |
| import os | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from typing import Optional, List | |
| import json | |
| # 导入数据库和 RAG 引擎 | |
| # 注意:在 HF Spaces 中,这些文件应该在同一个目录下 | |
| from database_setup_lite import setup_databases | |
| from rag_engine import RAGEngine | |
# Create the FastAPI application.
app = FastAPI(title="GraphRAG Backend API")

# Configure CORS: every origin is allowed. Production deployments can restrict
# allow_origins to specific domains, e.g. ["https://your-frontend.vercel.app"].
# NOTE(review): FastAPI's CORS documentation says allow_origins should not be
# ["*"] when allow_credentials=True -- confirm whether credentials are needed
# and, if so, list the frontend origins explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Build the databases and the RAG engine once at import time; they are kept as
# module-level globals so they are not re-initialised on every request.
print("正在初始化数据库...")
graph_db, vector_db = setup_databases()
rag_engine = RAGEngine(graph_db, vector_db)

# Load the mock data that is served to the frontend for display.
with open("mock_data.json", "r", encoding="utf-8") as f:
    mock_data = json.load(f)
# Pydantic request models
class SearchRequest(BaseModel):
    # Request body for the search endpoint.
    # `query` is required; the handler rejects an empty string with HTTP 400.
    query: str
    # Optional filters; both default to an empty string when omitted.
    product_name: Optional[str] = ""
    style_name: Optional[str] = ""
class GenerateRequest(BaseModel):
    # Request body for the copywriting-generation endpoint.
    # All three text fields are required by the schema, and the handler
    # additionally rejects empty strings with HTTP 400.
    query: str
    product_name: str
    style_name: str
    # Forwarded unchanged to the RAG engine; defaults to True.
    use_graph: bool = True
class FeatureSearchRequest(BaseModel):
    # Request body for the feature-search endpoint: free-text keywords that
    # are matched against feature names.
    query: str
@app.get("/")
def root():
    """Return service metadata and the list of available endpoints.

    The handler is registered on ``app`` here; without the route decorator
    the function would be defined but never served.
    """
    return {
        "message": "GraphRAG Backend API",
        "version": "1.0.0",
        "endpoints": [
            "GET /api/products",
            "GET /api/styles",
            "GET /api/graph",
            "GET /api/vector-db",
            "POST /api/search",
            "POST /api/generate",
            "POST /api/features/search",
        ],
    }
@app.get("/api/products")
def get_products():
    """Return the product list.

    Only a single hard-coded demo product is exposed. The route decorator
    registers the handler under the path advertised by the root endpoint.
    """
    demo_product = {
        "id": "P_DEMO",
        "name": "真丝睡眠眼罩",
    }
    return [demo_product]
@app.get("/api/styles")
def get_styles():
    """Return the ``id`` and ``name`` of every style in the loaded mock data.

    The route decorator registers the handler under the path advertised by
    the root endpoint.
    """
    return [
        {"id": style["id"], "name": style["name"]}
        for style in mock_data["styles"]
    ]
@app.get("/api/graph")
def get_graph():
    """Return the graph database contents as node and edge lists.

    Each node carries its id, type, full properties and a display label. The
    label is the node's ``name`` property, falling back to the first 20
    characters of its ``content`` property and finally to the node id.
    """
    nodes = []
    for node_id, node_data in graph_db.nodes.items():
        properties = node_data["properties"]
        # `or ""` guards against a `content` key that is present but None,
        # which would otherwise fail when sliced.
        content_preview = (properties.get("content") or "")[:20]
        nodes.append({
            "id": node_id,
            "type": node_data["type"],
            "label": properties.get("name") or content_preview or node_id,
            "properties": properties,
        })

    edges = [
        {
            "source": edge["source"],
            "target": edge["target"],
            "relationship": edge["relationship"],
        }
        for edge in graph_db.edges
    ]

    return {
        "nodes": nodes,
        "edges": edges,
    }
@app.post("/api/search")
def search(request: SearchRequest):
    """Run the retrieval comparison for a query and return its result.

    The optional product and style names are passed to the RAG engine as
    empty strings when they are missing or null.

    Raises:
        HTTPException: 400 when the query is empty.
    """
    if not request.query:
        raise HTTPException(status_code=400, detail="查询不能为空")

    return rag_engine.compare_retrieval(
        request.query,
        request.product_name or "",
        request.style_name or "",
    )
@app.post("/api/generate")
def generate(request: GenerateRequest):
    """Generate copywriting for a product in a given style.

    The request fields are forwarded unchanged to the RAG engine, including
    the ``use_graph`` flag.

    Raises:
        HTTPException: 400 when the query, product name or style name is
            empty.
    """
    required_fields = (request.query, request.product_name, request.style_name)
    if not all(required_fields):
        raise HTTPException(status_code=400, detail="缺少必要参数")

    return rag_engine.generate_copywriting(
        request.query,
        request.product_name,
        request.style_name,
        request.use_graph,
    )
@app.get("/api/vector-db")
def get_vector_db():
    """Return every document stored in the vector database collection.

    Each entry holds the document id, its content (empty string when the
    collection returned no content for that position) and its metadata
    (empty dict when none was returned for that position).

    Raises:
        HTTPException: 500 carrying the error text when reading the
            collection fails.
    """
    try:
        all_docs = vector_db.collection.get()
        # The "documents"/"metadatas" entries may be missing or None
        # (NOTE(review): presumably when the store omits them -- confirm
        # against the vector DB client); treat both cases as empty lists.
        contents = all_docs.get("documents") or []
        metadatas = all_docs.get("metadatas") or []
        documents = [
            {
                "id": doc_id,
                "content": contents[i] if i < len(contents) else "",
                "metadata": metadatas[i] if i < len(metadatas) else {},
            }
            for i, doc_id in enumerate(all_docs["ids"])
        ]
    except Exception as e:
        # Top-level boundary: report the failure as HTTP 500 and keep the
        # original exception chained for server-side debugging.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "total": len(documents),
        "documents": documents,
    }
@app.post("/api/features/search")
def search_features(request: FeatureSearchRequest):
    """Find Feature nodes whose name contains any keyword of the query.

    Matching is a case-insensitive substring test of each whitespace-separated
    keyword against the feature name (the node id is used when the node has no
    ``name`` property). At most 10 features are returned, in the order the
    graph database yields them. Each one lists the names of the Product nodes
    linked to it by a ``HAS_FEATURE`` edge.

    An empty or whitespace-only query returns no features.
    """
    max_results = 10

    # Checking the individual keywords is sufficient: if the whole query is a
    # substring of a name, so is each of its keywords.
    keywords = request.query.lower().split()
    if not keywords:
        return {"features": []}

    matched_features = []
    for node in graph_db.find_nodes_by_type("Feature"):
        display_name = node["properties"].get("name", node["id"])
        lowered_name = display_name.lower()
        if any(keyword in lowered_name for keyword in keywords):
            matched_features.append({
                "id": node["id"],
                "name": display_name,
                "related_products": [],
            })
            # Stop early; anything beyond the limit would be discarded anyway.
            if len(matched_features) == max_results:
                break

    if not matched_features:
        return {"features": []}

    # Collect product names per matched feature in a single pass over the
    # edges instead of rescanning the whole edge list for every match.
    products_by_feature = {feature["id"]: [] for feature in matched_features}
    for edge in graph_db.edges:
        target = edge["target"]
        if target not in products_by_feature:
            continue
        if edge["relationship"] != "HAS_FEATURE":
            continue
        product_node = graph_db.nodes.get(edge["source"], {})
        if product_node.get("type") == "Product":
            products_by_feature[target].append(
                product_node["properties"].get("name", edge["source"])
            )

    for feature in matched_features:
        feature["related_products"] = list(products_by_feature[feature["id"]])

    return {"features": matched_features}
| # HF Spaces 会自动使用 Dockerfile 中的 CMD 启动 | |
| # 如果需要本地测试,可以取消下面的注释 | |
| # if __name__ == "__main__": | |
| # import uvicorn | |
| # port = int(os.getenv("PORT", 7860)) | |
| # uvicorn.run(app, host="0.0.0.0", port=port) | |