File size: 11,887 Bytes
a990ce3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
"""
轻量级图向量数据库 - 不依赖 ChromaDB,避免超过 Vercel 250MB 限制
使用纯 Python 实现简单的向量搜索
"""
import os
import json
import math
from typing import List, Dict, Optional
import requests

# 延迟导入 sentence_transformers,避免依赖冲突
HAS_SENTENCE_TRANSFORMERS = False
SentenceTransformer = None

def _try_import_sentence_transformers():
    """尝试导入 sentence_transformers"""
    global HAS_SENTENCE_TRANSFORMERS, SentenceTransformer
    if HAS_SENTENCE_TRANSFORMERS:
        return True
    try:
        from sentence_transformers import SentenceTransformer as ST
        SentenceTransformer = ST
        HAS_SENTENCE_TRANSFORMERS = True
        return True
    except (ImportError, RuntimeError, AttributeError) as e:
        HAS_SENTENCE_TRANSFORMERS = False
        return False

# Simple in-memory graph database
class SimpleGraphDB:
    """A minimal in-memory stand-in for a graph database.

    Nodes live in a dict keyed by node id; edges are a flat list of
    directed (source, target, relationship) records.
    """

    def __init__(self):
        self.nodes = {}  # {node_id: {"type": ..., "properties": ...}}
        self.edges = []  # [{"source": ..., "target": ..., "relationship": ...}]

    def add_node(self, node_id: str, node_type: str, properties: Dict):
        """Register (or overwrite) a node."""
        self.nodes[node_id] = {"type": node_type, "properties": properties}

    def add_edge(self, source: str, target: str, relationship: str):
        """Append a directed edge."""
        self.edges.append(
            {"source": source, "target": target, "relationship": relationship}
        )

    def get_neighbors(self, node_id: str, relationship: Optional[str] = None) -> List[Dict]:
        """Return outgoing neighbors of node_id, optionally filtered by relationship."""
        matched = []
        for edge in self.edges:
            if edge["source"] != node_id:
                continue
            if relationship is not None and edge["relationship"] != relationship:
                continue
            # Unknown targets resolve to empty properties rather than raising.
            target_record = self.nodes.get(edge["target"], {})
            matched.append({
                "node_id": edge["target"],
                "relationship": edge["relationship"],
                "properties": target_record.get("properties", {}),
            })
        return matched

    def find_nodes_by_type(self, node_type: str) -> List[Dict]:
        """List every node whose type equals node_type."""
        hits = []
        for nid, record in self.nodes.items():
            if record["type"] == node_type:
                hits.append({"id": nid, **record})
        return hits

    def find_node_by_property(self, node_type: str, property_name: str, property_value: str) -> Optional[Dict]:
        """Return the first node of node_type whose property matches, else None."""
        for nid, record in self.nodes.items():
            if record["type"] != node_type:
                continue
            if record.get("properties", {}).get(property_name) == property_value:
                return {"id": nid, **record}
        return None

def cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
    """Return the cosine similarity of two vectors (0.0 if either has zero norm)."""
    norm1 = math.sqrt(sum(x * x for x in vec1))
    norm2 = math.sqrt(sum(y * y for y in vec2))
    # A zero-magnitude vector has no direction; define similarity as 0.
    if norm1 == 0 or norm2 == 0:
        return 0.0
    dot = sum(x * y for x, y in zip(vec1, vec2))
    return dot / (norm1 * norm2)

class VectorDB:
    """Lightweight in-memory vector store - no ChromaDB dependency.

    When LLM_API_KEY is configured, documents are embedded via an
    OpenAI-compatible /embeddings endpoint and ranked by cosine
    similarity; otherwise a plain keyword-overlap match is used.
    """

    def __init__(self):
        # Document store: {id: {"content", "metadata", "embedding"}}
        self.documents: Dict[str, Dict] = {}

        # Embedding endpoint configuration, all environment-driven.
        self.embedding_api_base = os.getenv("LLM_API_BASE", "https://api.ai-gaochao.cn/v1")
        self.embedding_api_key = os.getenv("LLM_API_KEY", "")
        self.embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
        self.use_openai_embedding = bool(self.embedding_api_key)

        if self.use_openai_embedding:
            print(f"✅ 使用 OpenAI Embeddings API: {self.embedding_model}")
        else:
            print("ℹ️  使用简单文本匹配(关键词搜索)")

    def _get_openai_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Fetch embedding vectors for texts from the OpenAI-compatible API.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        response = requests.post(
            f"{self.embedding_api_base}/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.embedding_api_key}",
            },
            json={"input": texts, "model": self.embedding_model},
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
        return [entry["embedding"] for entry in payload["data"]]

    def _simple_text_match(self, query: str, document: str) -> float:
        """Score a document by the fraction of query words it contains (0..1)."""
        query_words = set(query.lower().split())
        if not query_words:
            return 0.0
        overlap = query_words & set(document.lower().split())
        return len(overlap) / len(query_words)

    def add_documents(self, documents: List[str], ids: List[str], metadatas: List[Dict]):
        """Store documents, embedding them first when the API is enabled.

        On embedding failure, documents are kept with embedding=None so
        keyword search still works.
        """
        embeddings = None
        if self.use_openai_embedding:
            try:
                embeddings = self._get_openai_embeddings(documents)
            except Exception as e:
                print(f"⚠️  OpenAI Embeddings API 调用失败: {e}")
        if embeddings is None:
            embeddings = [None] * len(documents)
        for content, doc_id, meta, vector in zip(documents, ids, metadatas, embeddings):
            self.documents[doc_id] = {
                "content": content,
                "metadata": meta,
                "embedding": vector,
            }

    def search(self, query: str, n_results: int = 5) -> List[Dict]:
        """Return up to n_results documents ranked by relevance to query."""
        if not self.use_openai_embedding:
            return self._text_search(query, n_results)
        try:
            query_vector = self._get_openai_embeddings([query])[0]
        except Exception as e:
            # Degrade gracefully to keyword matching on any API failure.
            print(f"⚠️  向量搜索失败,回退到文本匹配: {e}")
            return self._text_search(query, n_results)
        ranked = []
        for doc_id, record in self.documents.items():
            # Documents stored without an embedding are skipped here.
            if record["embedding"]:
                similarity = cosine_similarity(query_vector, record["embedding"])
                ranked.append({
                    "content": record["content"],
                    "metadata": record["metadata"],
                    "distance": 1 - similarity,  # smaller distance = more similar
                    "id": doc_id,
                })
        ranked.sort(key=lambda item: item["distance"])
        return ranked[:n_results]

    def _text_search(self, query: str, n_results: int) -> List[Dict]:
        """Keyword-overlap search over all stored documents."""
        ranked = []
        for doc_id, record in self.documents.items():
            score = self._simple_text_match(query, record["content"])
            if score > 0:
                ranked.append({
                    "content": record["content"],
                    "metadata": record["metadata"],
                    "distance": 1 - score,  # smaller distance = more similar
                    "id": doc_id,
                })
        ranked.sort(key=lambda item: item["distance"])
        return ranked[:n_results]

    @property
    def collection(self):
        """Compatibility shim mimicking ChromaDB's collection interface."""
        class MockCollection:
            def __init__(self, vector_db):
                self.vector_db = vector_db

            def get(self):
                """Return every stored document in ChromaDB's result shape."""
                store = self.vector_db.documents
                doc_ids = list(store)
                return {
                    "ids": doc_ids,
                    "documents": [store[d]["content"] for d in doc_ids],
                    "metadatas": [store[d]["metadata"] for d in doc_ids],
                }

        return MockCollection(self)

def setup_databases():
    """Build and populate the graph and vector databases from mock_data.json.

    Returns:
        tuple: (SimpleGraphDB, VectorDB), both fully populated.
    """
    with open("mock_data.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    graph_db = SimpleGraphDB()

    # Product nodes
    for product in data["products"]:
        graph_db.add_node(
            product["id"],
            "Product",
            {
                "name": product["name"],
                "type": product["type"],
                "keywords": product["keywords"],
                "features": product["features"],
            },
        )

    # Style nodes
    for style in data["styles"]:
        graph_db.add_node(
            style["id"],
            "Style",
            {
                "name": style["name"],
                "description": style["description"],
                "characteristics": style["characteristics"],
            },
        )

    # Copywriting nodes
    for item in data["copywritings"]:
        graph_db.add_node(
            item["id"],
            "Copywriting",
            {
                "content": item["content"],
                "tag": item["tag"],
                "target_audience": item["target_audience"],
            },
        )

    # One Feature node per distinct feature across all products
    distinct_features = {
        feature
        for product in data["products"]
        for feature in product.get("features", [])
    }
    for feature in distinct_features:
        graph_db.add_node(feature, "Feature", {"name": feature})

    # Relationships become directed edges
    for rel in data["relationships"]:
        graph_db.add_edge(rel["source"], rel["target"], rel["relationship"])

    # Lightweight vector store over the copywriting corpus
    vector_db = VectorDB()
    copywritings = data["copywritings"]
    documents = [c["content"] for c in copywritings]
    ids = [c["id"] for c in copywritings]
    metadatas = [
        {
            "product_id": c["product_id"],
            "style_id": c["style_id"],
            "tag": c["tag"],
            "target_audience": c["target_audience"],
        }
        for c in copywritings
    ]
    vector_db.add_documents(documents, ids, metadatas)
    print(f"✅ 向量数据库已更新,包含 {len(documents)} 个文案")

    print("数据库初始化完成!")
    print(f"- 图数据库节点数: {len(graph_db.nodes)}")
    print(f"- 图数据库边数: {len(graph_db.edges)}")
    print(f"- 向量数据库文档数: {len(documents)}")

    return graph_db, vector_db