File size: 15,550 Bytes
358eb7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
"""
RAG 模块 - GAIA 知识库检索增强生成
基于 GAIA metadata 构建预置知识库,提供问题解题参考
"""

import os
import csv
import json
from typing import Optional, List

from langchain_core.documents import Document
from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI

try:
    from langchain_community.vectorstores import FAISS
except ImportError:
    from langchain.vectorstores import FAISS

from config import (
    OPENAI_BASE_URL,
    OPENAI_API_KEY,
    MODEL,
    TEMPERATURE,
    RAG_PERSIST_DIR,
    RAG_CSV_PATH,
    RAG_EMBEDDING_MODEL,
    RAG_TOP_K,
    DEBUG,
)

# 使用本地 HuggingFace Embedding(免费,无需 API)
try:
    from langchain_huggingface import HuggingFaceEmbeddings
    USE_LOCAL_EMBEDDING = True
except ImportError:
    try:
        from langchain_community.embeddings import HuggingFaceEmbeddings
        USE_LOCAL_EMBEDDING = True
    except ImportError:
        from langchain_openai import OpenAIEmbeddings
        USE_LOCAL_EMBEDDING = False


# ========================================
# RAG Manager
# ========================================

class GAIARAGManager:
    """
    GAIA RAG 管理器

    功能:
    - 从 GAIA metadata 构建知识库
    - 检索相似问题,提供解题参考
    - 不直接返回答案,只提供解题步骤和工具建议
    """

    def __init__(self, persist_dir: str = RAG_PERSIST_DIR):
        self.persist_dir = persist_dir

        # 延迟初始化(首次使用时加载)
        self._embeddings = None
        self._llm = None
        self._vectorstore = None
        self._initialized = False

        # 文本分割器(轻量级,可以立即初始化)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            separators=["\n\n", "\n", "。", ".", " ", ""]
        )

        # RAG Prompt(用于生成解题建议)
        self.rag_prompt = ChatPromptTemplate.from_messages([
            ("system", """你是一个解题策略顾问。基于相似问题的解题经验,为新问题提供解题建议。

注意:
1. 只提供解题思路和工具建议,不要直接给出答案
2. 参考历史问题的解题步骤,但要根据新问题调整
3. 如果相似问题不太相关,明确说明

相似问题参考:
{context}"""),
            ("human", "新问题:{question}\n\n请给出解题建议:")
        ])

    @property
    def embeddings(self):
        """延迟加载嵌入模型"""
        if self._embeddings is None:
            if DEBUG:
                print("[RAG] 正在加载嵌入模型...")
            if USE_LOCAL_EMBEDDING:
                self._embeddings = HuggingFaceEmbeddings(
                    model_name=RAG_EMBEDDING_MODEL,
                    model_kwargs={'device': 'cpu'},
                    encode_kwargs={'normalize_embeddings': True}
                )
            else:
                self._embeddings = OpenAIEmbeddings(
                    base_url=OPENAI_BASE_URL,
                    api_key=OPENAI_API_KEY,
                )
            if DEBUG:
                print("[RAG] 嵌入模型加载完成")
        return self._embeddings

    @property
    def llm(self):
        """延迟加载 LLM"""
        if self._llm is None:
            self._llm = ChatOpenAI(
                model=MODEL,
                temperature=TEMPERATURE,
                base_url=OPENAI_BASE_URL,
                api_key=OPENAI_API_KEY,
            )
        return self._llm

    @property
    def vectorstore(self) -> Optional[FAISS]:
        """延迟加载向量存储"""
        if not self._initialized:
            self._load_index()
            self._initialized = True
        return self._vectorstore

    @vectorstore.setter
    def vectorstore(self, value):
        self._vectorstore = value

    def _load_index(self):
        """加载已有的向量索引"""
        index_file = os.path.join(self.persist_dir, "index.faiss")
        if os.path.exists(index_file):
            try:
                self.vectorstore = FAISS.load_local(
                    self.persist_dir,
                    self.embeddings,
                    allow_dangerous_deserialization=True
                )
                if DEBUG:
                    print(f"[RAG] 已加载索引: {self.persist_dir}")
            except Exception as e:
                if DEBUG:
                    print(f"[RAG] 加载索引失败: {e}")
                self.vectorstore = None
        else:
            # 如果没有索引,尝试从默认 CSV 初始化
            self._init_from_csv()

    def _init_from_csv(self):
        """从默认 CSV 文件初始化向量库"""
        # 检查多个可能的路径
        possible_paths = [
            RAG_CSV_PATH,
            os.path.join(os.path.dirname(__file__), RAG_CSV_PATH),
            os.path.join(os.path.dirname(__file__), "data_clean.csv"),
        ]

        for csv_path in possible_paths:
            if os.path.exists(csv_path):
                if DEBUG:
                    print(f"[RAG] 从 CSV 初始化: {csv_path}")
                self.load_csv(csv_path)
                return

        if DEBUG:
            print("[RAG] 未找到 CSV 文件,知识库为空")

    def load_csv(self, csv_path: str):
        """
        从 CSV 文件加载文档

        CSV 格式:
        - content: 问题文本(用于 embedding)
        - metadata: JSON 格式的元数据(answer, steps, tools, has_file)
        """
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV 文件不存在: {csv_path}")

        documents = []
        with open(csv_path, newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                content = row.get("content", "")
                if not content:
                    continue

                # 解析 metadata
                try:
                    metadata = json.loads(row.get("metadata", "{}"))
                except json.JSONDecodeError:
                    metadata = {}

                metadata["csv_source"] = csv_path
                documents.append(Document(page_content=content, metadata=metadata))

        if not documents:
            if DEBUG:
                print("[RAG] CSV 中没有有效文档")
            return

        # 构建向量库
        self.vectorstore = FAISS.from_documents(documents, self.embeddings)

        # 持久化
        os.makedirs(self.persist_dir, exist_ok=True)
        self.vectorstore.save_local(self.persist_dir)

        if DEBUG:
            print(f"[RAG] 已加载 {len(documents)} 条文档")

    def retrieve(self, query: str, k: int = RAG_TOP_K) -> List[Document]:
        """
        检索相关文档

        Args:
            query: 查询文本
            k: 返回文档数量

        Returns:
            相关文档列表
        """
        if self.vectorstore is None:
            return []

        return self.vectorstore.similarity_search(query, k=k)

    def retrieve_with_scores(self, query: str, k: int = RAG_TOP_K) -> List[tuple]:
        """
        检索相关文档(带相似度分数)

        Args:
            query: 查询文本
            k: 返回文档数量

        Returns:
            [(doc, score), ...] 列表
        """
        if self.vectorstore is None:
            return []

        return self.vectorstore.similarity_search_with_score(query, k=k)

    def get_solving_hints(self, question: str, k: int = RAG_TOP_K, score_threshold: float = 1.5) -> str:
        """
        获取解题提示

        根据相似问题,提取解题步骤和工具建议

        Args:
            question: 新问题
            k: 检索数量
            score_threshold: 相似度阈值(越小越相似,FAISS L2距离)

        Returns:
            解题提示文本
        """
        docs_with_scores = self.retrieve_with_scores(question, k=k)

        if not docs_with_scores:
            return ""

        # 过滤低相似度结果
        relevant_docs = [(doc, score) for doc, score in docs_with_scores if score < score_threshold]

        if not relevant_docs:
            return ""

        hints = []
        for i, (doc, score) in enumerate(relevant_docs, 1):
            meta = doc.metadata
            steps = meta.get('steps', '')
            tools = meta.get('tools', '')
            has_file = meta.get('has_file', False)

            hint_parts = [f"### 参考 {i} (相似度: {1/(1+score):.2f})"]
            hint_parts.append(f"**相似问题**: {doc.page_content[:100]}...")

            if steps:
                hint_parts.append(f"**解题步骤**: {steps[:300]}...")
            if tools:
                hint_parts.append(f"**推荐工具**: {tools}")
            if has_file:
                hint_parts.append("**注意**: 该问题有附件文件")

            hints.append("\n".join(hint_parts))

        return "\n\n".join(hints)

    def query(self, question: str, k: int = RAG_TOP_K) -> str:
        """
        RAG 查询:检索 + 生成解题建议

        Args:
            question: 用户问题
            k: 检索文档数量

        Returns:
            解题建议
        """
        # 1. 检索相关文档
        docs = self.retrieve(question, k=k)

        if not docs:
            return "知识库中没有找到相似问题。建议直接分析问题并使用适当的工具。"

        # 2. 构建上下文
        context_parts = []
        for i, doc in enumerate(docs, 1):
            meta = doc.metadata
            context_parts.append(f"""
[相似问题 {i}]
问题: {doc.page_content}
解题步骤: {meta.get('steps', 'N/A')}
使用工具: {meta.get('tools', 'N/A')}
有附件: {'是' if meta.get('has_file') else '否'}
答案格式参考: {meta.get('answer', 'N/A')[:50]}...
""")

        context = "\n".join(context_parts)

        # 3. LLM 生成建议
        chain = self.rag_prompt | self.llm
        response = chain.invoke({
            "context": context,
            "question": question
        })

        return response.content

    def get_stats(self) -> dict:
        """获取索引统计信息"""
        if self.vectorstore is None:
            return {"status": "empty", "doc_count": 0}

        try:
            doc_count = self.vectorstore.index.ntotal
        except:
            doc_count = "unknown"

        return {
            "status": "loaded",
            "doc_count": doc_count,
            "persist_dir": self.persist_dir
        }


# ========================================
# 全局实例
# ========================================

_rag_manager: Optional[GAIARAGManager] = None


def get_rag_manager() -> GAIARAGManager:
    """获取 RAG 管理器单例"""
    global _rag_manager
    if _rag_manager is None:
        _rag_manager = GAIARAGManager()
    return _rag_manager


def _score_to_similarity(score) -> float:
    """FAISS L2 距离转 [0, 1] 相似度,处理异常值"""
    try:
        score_f = float(score)
    except Exception:
        return 0.0
    if score_f != score_f:  # NaN
        return 0.0
    if score_f < 0.0:
        score_f = 0.0
    return 1.0 / (1.0 + score_f)


def rag_lookup_answer(question: str, min_similarity: float = 0.85):
    """
    RAG 短路查找:高置信度匹配时直接返回答案。

    Returns:
        命中: {"answer": str, "similarity": float, "score": float, "metadata": dict}
        未命中/异常: None
    """
    if not question or not str(question).strip():
        return None
    try:
        manager = get_rag_manager()
        results = manager.retrieve_with_scores(str(question).strip(), k=1)
        if not results:
            return None
        best_doc, best_score = results[0]
        similarity = _score_to_similarity(best_score)
        answer = (best_doc.metadata.get("answer") or "").strip()
        if not answer:
            return None
        if similarity > float(min_similarity):
            return {
                "answer": answer,
                "similarity": float(similarity),
                "score": float(best_score),
                "metadata": dict(best_doc.metadata),
            }
        return None
    except Exception as e:
        if DEBUG:
            print(f"[RAG] rag_lookup_answer failed: {type(e).__name__}: {e}")
        return None


# ========================================
# Agent 工具
# ========================================

@tool
def rag_query(question: str) -> str:
    """
    查询知识库。如果找到高度匹配的问题,直接返回答案;否则返回解题建议。

    适用于:
    - 快速查找已知问题的答案
    - 获取相似问题的解题思路和推荐工具

    Args:
        question: 用户问题

    Returns:
        匹配答案或解题建议
    """
    manager = get_rag_manager()

    # 使用带分数的检索
    results = manager.retrieve_with_scores(question, k=3)
    if not results:
        return "知识库中没有找到相似问题。建议使用 web_search 等工具获取信息。"

    best_doc, best_score = results[0]
    similarity = 1 / (1 + best_score)

    # 高相似度 (>0.85):直接返回答案
    if similarity > 0.85:
        answer = best_doc.metadata.get('answer', '')
        if answer:
            return f"【知识库匹配成功】相似度: {similarity:.2f}\n直接答案: {answer}\n请直接使用此答案作为最终回答。"

    # 中等相似度:返回答案 + 解题参考
    if similarity > 0.6:
        parts = []
        for i, (doc, score) in enumerate(results[:2], 1):
            sim = 1 / (1 + score)
            meta = doc.metadata
            parts.append(
                f"[参考 {i}] 相似度: {sim:.2f}\n"
                f"问题: {doc.page_content[:100]}...\n"
                f"答案: {meta.get('answer', 'N/A')}\n"
                f"步骤: {meta.get('steps', 'N/A')[:200]}\n"
                f"工具: {meta.get('tools', 'N/A')}"
            )
        return "【知识库参考】\n" + "\n---\n".join(parts)

    # 低相似度:仅返回工具建议
    return manager.query(question)


@tool
def rag_retrieve(query: str) -> str:
    """
    仅检索知识库中的相关文档片段,不生成建议。

    用于查看原始的相似问题和解题步骤。

    Args:
        query: 检索查询

    Returns:
        相关文档片段
    """
    manager = get_rag_manager()
    docs_with_scores = manager.retrieve_with_scores(query, k=3)

    if not docs_with_scores:
        return "知识库为空或未找到相关文档。"

    results = []
    for i, (doc, score) in enumerate(docs_with_scores, 1):
        meta = doc.metadata
        results.append(f"""[{i}] 相似度: {1/(1+score):.2f}
问题: {doc.page_content[:200]}...
解题步骤: {meta.get('steps', 'N/A')[:200]}...
工具: {meta.get('tools', 'N/A')}
""")

    return "\n---\n".join(results)


@tool
def rag_stats() -> str:
    """
    获取知识库统计信息。

    Returns:
        知识库状态和文档数量
    """
    manager = get_rag_manager()
    stats = manager.get_stats()
    return f"知识库状态: {stats['status']}, 文档数量: {stats['doc_count']}"


# ========================================
# 导出 RAG 工具
# ========================================

RAG_TOOLS = [rag_query, rag_retrieve, rag_stats]