File size: 6,153 Bytes
a990ce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
FastAPI backend service for Hugging Face Spaces.
"""
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
import json

# Import the database setup helper and the RAG engine.
# Note: on HF Spaces these modules are expected to live in the same directory as this file.
from database_setup_lite import setup_databases
from rag_engine import RAGEngine

# Create the FastAPI application.
app = FastAPI(title="GraphRAG Backend API")

# Configure CORS: every origin, method and header is allowed.
# In production the origin list can be narrowed to specific domains.
# NOTE(review): wildcard origins are combined with allow_credentials=True here;
# FastAPI's CORS documentation advises listing origins explicitly when
# credentials are enabled - confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production this can be e.g. ["https://your-frontend.vercel.app"]
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Build the databases and the RAG engine once, at import time, as module-level
# globals so that request handlers share them instead of re-initialising.
print("正在初始化数据库...")
graph_db, vector_db = setup_databases()
rag_engine = RAGEngine(graph_db, vector_db)

# Load the mock data that is served to the frontend.
# The path is relative, so it is resolved against the current working directory.
with open("mock_data.json", "r", encoding="utf-8") as f:
    mock_data = json.load(f)

# Pydantic request models.

# Request body for POST /api/search.
# `query` is required; `product_name` and `style_name` are optional and
# default to an empty string.
class SearchRequest(BaseModel):
    query: str
    product_name: Optional[str] = ""
    style_name: Optional[str] = ""

# Request body for POST /api/generate.
# `query`, `product_name` and `style_name` are all required; `use_graph`
# defaults to True and is forwarded unchanged to the RAG engine.
class GenerateRequest(BaseModel):
    query: str
    product_name: str
    style_name: str
    use_graph: bool = True

# Request body for POST /api/features/search: a single required `query` string.
class FeatureSearchRequest(BaseModel):
    query: str

@app.get("/")
def root():
    """Return a short service descriptor: name, version and the route list."""
    available_routes = [
        "GET /api/products",
        "GET /api/styles",
        "GET /api/graph",
        "GET /api/vector-db",
        "POST /api/search",
        "POST /api/generate",
        "POST /api/features/search",
    ]
    descriptor = {"message": "GraphRAG Backend API", "version": "1.0.0"}
    descriptor["endpoints"] = available_routes
    return descriptor

@app.get("/api/products")
def get_products():
    """Return the product list, which holds a single hard-coded demo product."""
    return [dict(id="P_DEMO", name="真丝睡眠眼罩")]

@app.get("/api/styles")
def get_styles():
    """Return the id and name of every entry under "styles" in the mock data."""
    style_summaries = []
    for style in mock_data["styles"]:
        style_summaries.append({"id": style["id"], "name": style["name"]})
    return style_summaries

@app.get("/api/graph")
def get_graph():
    """Return the graph database contents as node and edge lists.

    Each node carries its id, type, a display label and its raw properties;
    each edge carries its source, target and relationship.
    """

    def _display_label(node_id, props):
        # First truthy value wins: the name, then the first 20 characters of
        # the content, and finally the node id itself.
        return props.get("name") or props.get("content", "")[:20] or node_id

    node_list = [
        {
            "id": node_id,
            "type": node_data["type"],
            "label": _display_label(node_id, node_data["properties"]),
            "properties": node_data["properties"],
        }
        for node_id, node_data in graph_db.nodes.items()
    ]

    edge_list = [
        {field: edge[field] for field in ("source", "target", "relationship")}
        for edge in graph_db.edges
    ]

    return {"nodes": node_list, "edges": edge_list}

@app.post("/api/search")
def search(request: SearchRequest):
    """Run the retrieval comparison for a query.

    Forwards the query, product name and style name to
    ``rag_engine.compare_retrieval`` and returns its result unchanged.
    Missing or ``None`` product/style names are passed as empty strings.

    Raises:
        HTTPException: 400 when the query is empty or contains only whitespace.
    """
    # Test the stripped value so that a whitespace-only query is rejected like
    # an empty one; the query itself is still forwarded exactly as received.
    if not request.query.strip():
        raise HTTPException(status_code=400, detail="查询不能为空")

    return rag_engine.compare_retrieval(
        request.query,
        request.product_name or "",
        request.style_name or "",
    )

@app.post("/api/generate")
def generate(request: GenerateRequest):
    """Generate copywriting for a query, product and style.

    Forwards the request fields to ``rag_engine.generate_copywriting`` and
    returns its result unchanged.

    Raises:
        HTTPException: 400 when the query, product name or style name is
            empty or contains only whitespace.
    """
    required_fields = (request.query, request.product_name, request.style_name)
    # Strip before testing so a whitespace-only value counts as missing; the
    # values are still forwarded exactly as received.
    if not all(value.strip() for value in required_fields):
        raise HTTPException(status_code=400, detail="缺少必要参数")

    return rag_engine.generate_copywriting(
        request.query,
        request.product_name,
        request.style_name,
        request.use_graph,
    )

@app.get("/api/vector-db")
def get_vector_db():
    """Return every document stored in the vector database collection.

    Returns:
        A dict with ``total`` (number of documents) and ``documents``, a list
        of ``{"id", "content", "metadata"}`` entries. ``content`` falls back
        to ``""`` and ``metadata`` to ``{}`` when the collection result has no
        value at that position.

    Raises:
        HTTPException: 500 with the underlying error message when reading the
            collection fails.
    """
    try:
        all_docs = vector_db.collection.get()

        # A field may be absent from the result or present with a None value;
        # treat both the same so len() below never sees None.
        # NOTE(review): the result looks like a Chroma-style get() payload -
        # confirm against the vector_db implementation.
        contents = all_docs["documents"] if "documents" in all_docs else None
        if contents is None:
            contents = []
        metadatas = all_docs["metadatas"] if "metadatas" in all_docs else None
        if metadatas is None:
            metadatas = []

        documents = [
            {
                "id": doc_id,
                "content": contents[i] if i < len(contents) else "",
                "metadata": metadatas[i] if i < len(metadatas) else {},
            }
            for i, doc_id in enumerate(all_docs["ids"])
        ]
    except Exception as e:
        # Endpoint boundary: surface any failure as a 500 and keep the
        # original exception chained for the server-side traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"total": len(documents), "documents": documents}

@app.post("/api/features/search")
def search_features(request: FeatureSearchRequest):
    """Find Feature nodes whose name matches the query, with related products.

    Matching is a case-insensitive substring test: a feature matches when any
    whitespace-separated keyword of the query occurs in its name (or in its id
    when it has no name). For each match, the names of the Product nodes
    linked to it by a ``HAS_FEATURE`` edge are collected.

    Returns:
        ``{"features": [...]}`` with at most 10 entries, each holding ``id``,
        ``name`` and ``related_products``. An empty or whitespace-only query
        yields an empty list.
    """
    max_results = 10

    # Strip first so a whitespace-only query is treated as empty rather than
    # matching every feature whose name happens to contain that whitespace.
    query = request.query.strip().lower()
    if not query:
        return {"features": []}

    # The query is non-empty after stripping, so there is at least one keyword.
    # Testing the keywords alone is sufficient: if the whole query occurs in a
    # name, then so does its first keyword.
    keywords = query.split()

    matched_features = []
    for node in graph_db.find_nodes_by_type("Feature"):
        display_name = node["properties"].get("name", node["id"])
        feature_name = display_name.lower()
        if not any(keyword in feature_name for keyword in keywords):
            continue

        # Collect the products that point at this feature.
        related_products = []
        for edge in graph_db.edges:
            if edge["target"] != node["id"] or edge["relationship"] != "HAS_FEATURE":
                continue
            product_node = graph_db.nodes.get(edge["source"], {})
            if product_node.get("type") == "Product":
                related_products.append(
                    product_node["properties"].get("name", edge["source"])
                )

        matched_features.append({
            "id": node["id"],
            "name": display_name,
            "related_products": related_products,
        })

        # Only the first 10 matches are returned, so stop scanning once we
        # have them instead of walking the edges for results to be discarded.
        if len(matched_features) >= max_results:
            break

    return {"features": matched_features}

# HF Spaces starts the app automatically via the CMD in the Dockerfile.
# For local testing, uncomment the lines below.
# if __name__ == "__main__":
#     import uvicorn
#     port = int(os.getenv("PORT", 7860))
#     uvicorn.run(app, host="0.0.0.0", port=port)