Spaces:
Sleeping
Sleeping
File size: 4,989 Bytes
90d1485 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from typing import List, Dict, Any, Optional
import logging
from config import Config
from vector_store import VectorStore
logger = logging.getLogger(__name__)


class QAEngine:
    """Conversational question-answering engine backed by a vector store.

    Wraps a LangChain ``ConversationalRetrievalChain``. The chain retrieves
    documents from ``vector_store`` by similarity search and answers with an
    OpenAI chat model. The running conversation is kept in a
    ``ConversationBufferMemory`` so follow-up questions see earlier turns.
    """

    # Number of characters of each source document shown in the "sources"
    # preview returned by ask_question().
    _SOURCE_PREVIEW_CHARS = 200

    def __init__(
        self,
        vector_store: VectorStore,
        *,
        temperature: float = 0.7,
        k: int = 4,
    ):
        """Build the chat model, conversation memory, prompt and retrieval chain.

        Args:
            vector_store: project wrapper whose ``.vectorstore`` attribute must
                provide ``as_retriever`` (presumably a LangChain vector store;
                confirm against the VectorStore class).
            temperature: sampling temperature passed to ``ChatOpenAI``.
            k: number of documents the retriever returns per question.

        Raises:
            ValueError: if ``Config.OPENAI_API_KEY`` is empty or unset.
        """
        self.vector_store = vector_store

        # Initialize the large language model.
        if not Config.OPENAI_API_KEY:
            raise ValueError("请设置 OPENAI_API_KEY 环境变量")
        self.llm = ChatOpenAI(
            model_name=Config.OPENAI_MODEL,
            temperature=temperature,
            openai_api_key=Config.OPENAI_API_KEY,
        )

        # Conversation memory. Because the chain below is built with
        # return_source_documents=True, its result has more than one output
        # key ("answer" and "source_documents"). Without an explicit
        # output_key the memory cannot tell which one to store and raises
        # ValueError on every save, which broke every question.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            input_key="question",
            output_key="answer",
            return_messages=True,
        )

        # Custom prompt used to combine the retrieved documents ({context})
        # with the user's question ({question}).
        self.qa_prompt_template = """你是一个专业的AI助手,基于以下上下文信息来回答问题。
上下文信息:
{context}
问题: {question}
请基于上下文信息提供准确、详细的回答。如果上下文中没有相关信息,请明确说明无法从提供的信息中找到答案。
回答:"""
        self.qa_prompt = PromptTemplate(
            template=self.qa_prompt_template,
            input_variables=["context", "question"],
        )

        # Retrieval QA chain: similarity search for the top-k documents,
        # then answer using the custom prompt and the shared memory.
        self.qa_chain = ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.vector_store.vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs={"k": k},
            ),
            memory=self.memory,
            combine_docs_chain_kwargs={"prompt": self.qa_prompt},
            return_source_documents=True,
            verbose=True,
        )
        logger.info("问答引擎初始化完成")

    @classmethod
    def _preview(cls, text: str) -> str:
        """Return a short preview of *text*.

        The preview is the first ``_SOURCE_PREVIEW_CHARS`` characters. The
        "..." marker is appended only when characters were actually cut off.
        """
        if len(text) <= cls._SOURCE_PREVIEW_CHARS:
            return text
        return text[: cls._SOURCE_PREVIEW_CHARS] + "..."

    def ask_question(self, question: str) -> Dict[str, Any]:
        """Ask a question and return the answer together with its sources.

        Returns:
            A dict with three keys:
            ``"answer"`` (str), ``"sources"`` (a list of dicts with ``content``
            preview, ``source`` and ``file_name``) and ``"question"`` (echoed
            back). This method never raises: any exception is logged and
            reported as an apology message in ``"answer"`` with empty
            ``"sources"``.
        """
        try:
            result = self.qa_chain({"question": question})

            # Summarize the retrieved source documents; missing metadata
            # falls back to the "unknown" placeholder.
            sources = [
                {
                    "content": self._preview(doc.page_content),
                    "source": doc.metadata.get("source", "未知"),
                    "file_name": doc.metadata.get("file_name", "未知"),
                }
                for doc in result.get("source_documents") or []
            ]
            response = {
                "answer": result.get("answer", "抱歉,我无法回答这个问题。"),
                "sources": sources,
                "question": question,
            }
            logger.info("问题回答完成: %s", question)
            return response
        except Exception as e:
            # Best-effort boundary: keep the traceback in the log, but return
            # a well-formed response to the caller instead of raising.
            logger.exception("问答过程中出现错误: %s", e)
            return {
                "answer": f"抱歉,处理您的问题时出现了错误: {str(e)}",
                "sources": [],
                "question": question,
            }

    def get_chat_history(self) -> List[Dict[str, str]]:
        """Return the conversation so far as question/answer pairs.

        Stored messages are assumed to alternate human, AI, human, AI, ...
        (one pair per saved turn). A trailing unpaired message is ignored.
        Returns an empty list, after logging, if reading the memory fails.
        """
        try:
            messages = self.memory.chat_memory.messages
            return [
                {"question": asked.content, "answer": replied.content}
                for asked, replied in zip(messages[0::2], messages[1::2])
            ]
        except Exception as e:
            logger.exception("获取对话历史失败: %s", e)
            return []

    def clear_memory(self) -> bool:
        """Erase the conversation memory.

        Returns True on success and False, after logging, on failure.
        """
        try:
            self.memory.clear()
            logger.info("对话记忆已清除")
            return True
        except Exception as e:
            logger.exception("清除对话记忆失败: %s", e)
            return False

    def search_documents(self, query: str, k: int = 4) -> List[Dict[str, Any]]:
        """Run a scored similarity search against the vector store.

        Args:
            query: free-text search query.
            k: maximum number of documents to return.

        Returns:
            A list of dicts with the full ``content``, ``source``,
            ``file_name`` and a float ``score``. The score is whatever the
            underlying store reports; whether lower or higher means "more
            similar" depends on the backend and is not visible here. Returns
            an empty list, after logging, on failure.
        """
        try:
            results = self.vector_store.similarity_search_with_score(query, k=k)
            return [
                {
                    "content": doc.page_content,
                    "source": doc.metadata.get("source", "未知"),
                    "file_name": doc.metadata.get("file_name", "未知"),
                    "score": float(score),
                }
                for doc, score in results
            ]
        except Exception as e:
            logger.exception("搜索文档失败: %s", e)
            return []