File size: 4,989 Bytes
90d1485
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from typing import List, Dict, Any, Optional
import logging
from config import Config
from vector_store import VectorStore

logger = logging.getLogger(__name__)

class QAEngine:
    """Conversational question-answering engine backed by a vector store.

    Wraps a LangChain ``ConversationalRetrievalChain`` with buffered chat
    memory so that follow-up questions can reference earlier turns, and
    exposes helpers for raw similarity search and history management.
    """

    # Max characters of a source document shown in answer citations.
    SNIPPET_LENGTH = 200

    def __init__(self, vector_store: VectorStore):
        """Initialize the LLM, conversation memory and retrieval chain.

        Args:
            vector_store: Project vector store; must expose a LangChain
                ``vectorstore`` attribute and ``similarity_search_with_score``.

        Raises:
            ValueError: If ``Config.OPENAI_API_KEY`` is not set.
        """
        self.vector_store = vector_store

        # Fail fast if the API key is missing instead of erroring on first call.
        if not Config.OPENAI_API_KEY:
            raise ValueError("请设置 OPENAI_API_KEY 环境变量")

        self.llm = ChatOpenAI(
            model_name=Config.OPENAI_MODEL,
            temperature=0.7,
            openai_api_key=Config.OPENAI_API_KEY
        )

        # Conversation memory. output_key="answer" is REQUIRED here: the chain
        # below returns two keys (answer + source_documents) and without an
        # explicit output_key ConversationBufferMemory raises
        # "ValueError: One output key expected" on every call.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer"
        )

        # Custom QA prompt (Chinese-language assistant instructions; the
        # template text is user-facing model input and is kept as-is).
        self.qa_prompt_template = """你是一个专业的AI助手,基于以下上下文信息来回答问题。

上下文信息:
{context}

问题: {question}

请基于上下文信息提供准确、详细的回答。如果上下文中没有相关信息,请明确说明无法从提供的信息中找到答案。

回答:"""

        self.qa_prompt = PromptTemplate(
            template=self.qa_prompt_template,
            input_variables=["context", "question"]
        )

        # Retrieval QA chain: top-4 similarity hits feed the custom prompt.
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vector_store.vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs={"k": 4}
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.qa_prompt},
            return_source_documents=True,
            verbose=True
        )

        logger.info("问答引擎初始化完成")

    def ask_question(self, question: str) -> Dict[str, Any]:
        """Answer *question* via the retrieval chain.

        Returns:
            Dict with keys ``answer`` (str), ``sources`` (list of snippet
            dicts) and ``question`` (echo of the input). Errors are caught
            and reported inside ``answer`` rather than raised, so callers
            always receive the same shape.
        """
        try:
            result = self.qa_chain({"question": question})

            # Summarize the retrieved documents for citation display;
            # only append an ellipsis when the content was actually cut.
            source_documents = []
            for doc in result.get("source_documents") or []:
                content = doc.page_content
                if len(content) > self.SNIPPET_LENGTH:
                    content = content[:self.SNIPPET_LENGTH] + "..."
                source_documents.append({
                    "content": content,
                    "source": doc.metadata.get("source", "未知"),
                    "file_name": doc.metadata.get("file_name", "未知")
                })

            response = {
                "answer": result.get("answer", "抱歉,我无法回答这个问题。"),
                "sources": source_documents,
                "question": question
            }

            logger.info("问题回答完成: %s", question)
            return response

        except Exception as e:
            # Deliberate broad catch at the API boundary: surface the error
            # text in the answer instead of crashing the caller.
            logger.error("问答过程中出现错误: %s", e)
            return {
                "answer": f"抱歉,处理您的问题时出现了错误: {str(e)}",
                "sources": [],
                "question": question
            }

    def get_chat_history(self) -> List[Dict[str, str]]:
        """Return the conversation as question/answer pairs.

        Assumes memory messages strictly alternate human/AI (which the
        chain above guarantees); a trailing unanswered question is dropped.
        """
        try:
            messages = self.memory.chat_memory.messages
            history = []
            # Pair consecutive messages: even index = question, odd = answer.
            for i in range(0, len(messages) - 1, 2):
                history.append({
                    "question": messages[i].content,
                    "answer": messages[i + 1].content
                })
            return history
        except Exception as e:
            logger.error("获取对话历史失败: %s", e)
            return []

    def clear_memory(self) -> bool:
        """Erase all conversation memory; return True on success."""
        try:
            self.memory.clear()
            logger.info("对话记忆已清除")
            return True
        except Exception as e:
            logger.error("清除对话记忆失败: %s", e)
            return False

    def search_documents(self, query: str, k: int = 4) -> List[Dict[str, Any]]:
        """Run a raw similarity search, bypassing the LLM.

        Args:
            query: Free-text search string.
            k: Number of documents to return.

        Returns:
            List of dicts with full ``content``, ``source``, ``file_name``
            and the similarity ``score`` (coerced to float); empty list on
            failure.
        """
        try:
            results = self.vector_store.similarity_search_with_score(query, k=k)

            documents = []
            for doc, score in results:
                documents.append({
                    "content": doc.page_content,
                    "source": doc.metadata.get("source", "未知"),
                    "file_name": doc.metadata.get("file_name", "未知"),
                    "score": float(score)
                })

            return documents
        except Exception as e:
            logger.error("搜索文档失败: %s", e)
            return []