from langchain.chat_models import ChatOpenAI from langchain.chains import ConversationalRetrievalChain from langchain.memory import ConversationBufferMemory from langchain.prompts import PromptTemplate from typing import List, Dict, Any, Optional import logging from config import Config from vector_store import VectorStore logger = logging.getLogger(__name__) class QAEngine: def __init__(self, vector_store: VectorStore): self.vector_store = vector_store # 初始化大语言模型 if not Config.OPENAI_API_KEY: raise ValueError("请设置 OPENAI_API_KEY 环境变量") self.llm = ChatOpenAI( model_name=Config.OPENAI_MODEL, temperature=0.7, openai_api_key=Config.OPENAI_API_KEY ) # 初始化对话记忆 self.memory = ConversationBufferMemory( memory_key="chat_history", return_messages=True ) # 自定义提示模板 self.qa_prompt_template = """你是一个专业的AI助手,基于以下上下文信息来回答问题。 上下文信息: {context} 问题: {question} 请基于上下文信息提供准确、详细的回答。如果上下文中没有相关信息,请明确说明无法从提供的信息中找到答案。 回答:""" self.qa_prompt = PromptTemplate( template=self.qa_prompt_template, input_variables=["context", "question"] ) # 初始化检索问答链 self.qa_chain = ConversationalRetrievalChain.from_llm( llm=self.llm, retriever=self.vector_store.vectorstore.as_retriever( search_type="similarity", search_kwargs={"k": 4} ), memory=self.memory, combine_docs_chain_kwargs={"prompt": self.qa_prompt}, return_source_documents=True, verbose=True ) logger.info("问答引擎初始化完成") def ask_question(self, question: str) -> Dict[str, Any]: """提问并获取回答""" try: # 执行问答 result = self.qa_chain({"question": question}) # 提取源文档信息 source_documents = [] if result.get("source_documents"): for doc in result["source_documents"]: source_documents.append({ "content": doc.page_content[:200] + "...", "source": doc.metadata.get("source", "未知"), "file_name": doc.metadata.get("file_name", "未知") }) response = { "answer": result.get("answer", "抱歉,我无法回答这个问题。"), "sources": source_documents, "question": question } logger.info(f"问题回答完成: {question}") return response except Exception as e: logger.error(f"问答过程中出现错误: {e}") return { "answer": f"抱歉,处理您的问题时出现了错误: {str(e)}", "sources": [], "question": question } def get_chat_history(self) -> List[Dict[str, str]]: """获取对话历史""" try: chat_history = self.memory.chat_memory.messages history = [] for i in range(0, len(chat_history), 2): if i + 1 < len(chat_history): history.append({ "question": chat_history[i].content, "answer": chat_history[i + 1].content }) return history except Exception as e: logger.error(f"获取对话历史失败: {e}") return [] def clear_memory(self) -> bool: """清除对话记忆""" try: self.memory.clear() logger.info("对话记忆已清除") return True except Exception as e: logger.error(f"清除对话记忆失败: {e}") return False def search_documents(self, query: str, k: int = 4) -> List[Dict[str, Any]]: """搜索相关文档""" try: results = self.vector_store.similarity_search_with_score(query, k=k) documents = [] for doc, score in results: documents.append({ "content": doc.page_content, "source": doc.metadata.get("source", "未知"), "file_name": doc.metadata.get("file_name", "未知"), "score": float(score) }) return documents except Exception as e: logger.error(f"搜索文档失败: {e}") return []