FeiChat / web_app.py
aifeifei798's picture
Update web_app.py
0fcfa03 verified
import streamlit as st
import os
import json
import re
import datetime
import tempfile
# 导入 LangChain 相关组件
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda
from langchain_core.tools import tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# 导入 YouTube 搜索库
from youtube_search import YoutubeSearch
# ==========================================
# 1. 基础配置与环境初始化
# ==========================================
# 配置 Streamlit 页面标题、图标和布局
st.set_page_config(page_title="FeiChat Final", page_icon="✨", layout="wide")
st.title("✨ FeiChat (Tavily + YouTube 完美版)")
# 设置 API Key
# 注意:实际生产中建议使用 st.secrets 或系统环境变量,不要直接写在代码里
os.environ["OPENAI_API_KEY"] = "lm-studio" # 指向本地 LM Studio,Key 随意填写
os.environ["TAVILY_API_KEY"] = "tvly-dev-xxxx" # Tavily 搜索引擎 Key
# 初始化 Session State (会话状态)
#用于在 Streamlit 页面刷新(rerun)时保存聊天记录和向量数据库
if "messages" not in st.session_state:
st.session_state.messages = [] # 存储对话历史
if "vector_store" not in st.session_state:
st.session_state.vector_store = None # 存储 PDF 向量索引
# ==========================================
# 1.1 模型加载 (使用缓存避免重复加载)
# ==========================================
@st.cache_resource
def get_models():
"""
初始化 LLM 和 Embedding 模型。
使用 @st.cache_resource 装饰器,确保只加载一次,节省资源。
"""
# 1. 路由模型 (Router):温度设为 0.0,要求输出精确,用于判断意图
router = ChatOpenAI(
base_url="http://127.0.0.1:1234/v1", # 连接本地 LM Studio 端口
model="kuaidao-c-suite-v2",
temperature=0.0
)
# 2. 对话模型 (Chat):温度设为 0.7,用于生成流畅、自然的回答,开启流式输出
chat = ChatOpenAI(
base_url="http://127.0.0.1:1234/v1",
model="kuaidao-c-suite-v2",
temperature=0.7,
streaming=True
)
# 3. 嵌入模型 (Embeddings):用于将 PDF 文本转化为向量,这里使用 HuggingFace 的轻量级模型
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
return router, chat, embeddings
llm_router, llm_chat, embeddings = get_models()
# ==========================================
# 2. 侧边栏 RAG (私有知识库处理)
# ==========================================
with st.sidebar:
st.header("📂 私有知识库")
# 文件上传控件
uploaded_file = st.file_uploader("上传 PDF (仅当问及文档内容时使用)", type=["pdf"])
# 如果用户上传了文件,且向量库还未建立,则开始处理
if uploaded_file and st.session_state.vector_store is None:
with st.status("正在学习文档...", expanded=True):
# 1. 创建临时文件保存上传的 PDF (PyPDFLoader 需要本地文件路径)
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
tmp.write(uploaded_file.read())
path = tmp.name
# 2. 加载 PDF
loader = PyPDFLoader(path)
docs = loader.load()
# 3. 文本切分:将长文档切成 500字符的小块,保留 50字符重叠以保持上下文
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
splits = splitter.split_documents(docs)
# 4. 向量化并存储:使用 FAISS 构建向量索引
st.session_state.vector_store = FAISS.from_documents(splits, embeddings)
st.success(f"已索引 {len(splits)} 个片段")
# 5. 删除临时文件,保持整洁
os.remove(path)
# 清除记忆按钮
if st.button("🗑️ 清空记忆"):
st.session_state.messages = []
st.session_state.vector_store = None
st.rerun() # 重新运行脚本以刷新页面状态
# ==========================================
# 3. 工具定义 (Search & RAG)
# ==========================================
# 初始化 Tavily 搜索客户端
tavily_engine = TavilySearchResults(max_results=5)
@tool
def internet_search(query: str) -> str:
"""Tavily 联网搜索工具函数,供 Agent 调用。"""
print(f"🕵️ Tavily 搜索: {query}") # 后台打印日志
try:
results = tavily_engine.invoke({"query": query})
formatted = []
# 格式化搜索结果,包含 URL 和内容摘要
for i, res in enumerate(results):
formatted.append(f"【来源】({res['url']}):\n{res['content']}")
return "\n\n".join(formatted)
except Exception as e:
return f"Error: {e}"
@tool
def knowledge_base_search(query: str) -> str:
"""知识库(RAG)搜索工具函数。"""
# 如果没上传文件,直接返回提示
if st.session_state.vector_store is None: return "用户未上传任何文档。"
# 在向量库中搜索最相似的 3 个片段
docs = st.session_state.vector_store.similarity_search(query, k=3)
return "\n\n".join([f"【文档片段】: {d.page_content}" for d in docs])
# 将工具放入字典,方便 Router 调用
tools = {"internet_search": internet_search, "knowledge_base_search": knowledge_base_search}
def search_youtube(query):
"""
YouTube 搜索辅助函数
注意:这是独立功能,不作为 LLM 的 Tool,而是在 UI 层直接展示结果
"""
try:
# 限制结果为 3 个
return YoutubeSearch(query, max_results=3).to_dict()
except: return []
# ==========================================
# 4. 核心逻辑 (Router & Chain)
# ==========================================
def get_time(): return datetime.datetime.now().strftime("%Y年%m月%d日")
# 4.1 意图识别 Prompt
# 作用:让 LLM 判断用户是想闲聊、查文档还是联网搜索,并提取搜索关键词
intent_prompt = ChatPromptTemplate.from_messages([
("system", """
你是一个智能路由。当前时间:{current_date}。
【工具选择逻辑】:
1. knowledge_base_search: 🔴 仅当用户明确提到“文档”、“PDF”、“上传的文件”等时使用。
2. internet_search: 🟢 默认选项(如果用户问知识性问题)。
3. CHAT: 仅用于纯打招呼、情感交流,不需要外部信息。
【Query生成规则】:
- 强时效性问题(新闻、天气):必须在关键词中加 `{current_date}`。
- 弱时效性问题(歌曲、百科、人物):**禁止加日期**,直接用实体名。
返回 JSON 格式: {{ "intent": "CHAT" 或 "TOOL", "tool_name": "...", "tool_args": {{ "query": "..." }} }}
"""),
("user", "历史:\n{chat_history}\n\n输入:\n{input}")
])
def parse_router(text):
"""解析 Router LLM 返回的 JSON 字符串"""
try:
# 使用正则提取 Markdown 代码块中的 JSON (防止 LLM 输出 ```json ... ```)
if "```" in text: text = re.search(r"```(?:json)?(.*?)```", text, re.DOTALL).group(1)
return json.loads(text.strip())
except:
# 解析失败则默认回退到纯聊天模式
return {"intent": "CHAT"}
# 构建路由链:Prompt -> LLM -> 文本解析 -> JSON解析
intent_chain = intent_prompt | llm_router | StrOutputParser() | RunnableLambda(parse_router)
# 4.2 最终回复润色 Prompt
# 作用:根据搜索结果生成给用户的最终回答
response_prompt = ChatPromptTemplate.from_messages([
("system", """
你是一个严谨的信息整合助手。
请严格基于【搜索结果】回答。
1. 如果结果里没有,诚实说“未找到”。
2. 严禁捏造。
"""),
("user", "问题: {user_input}\n\n搜索结果:\n{tool_result}")
])
response_chain = response_prompt | llm_chat | StrOutputParser()
# 4.3 纯闲聊 Prompt
chat_chain = ChatPromptTemplate.from_messages([
("system", "助手。"), ("user", "{input}")
]) | llm_chat | StrOutputParser()
# ==========================================
# 5. 界面 UI 交互逻辑
# ==========================================
# 5.1 显示历史消息
for msg in st.session_state.messages:
with st.chat_message(msg["role"]):
st.markdown(msg["content"])
# 5.2 处理用户输入
if user_input := st.chat_input("问:Fruits Zipper 最火的歌..."):
# 记录用户输入
st.session_state.messages.append({"role": "user", "content": user_input})
with st.chat_message("user"):
st.markdown(user_input)
with st.chat_message("assistant"):
# 使用 st.status 显示“思考中”状态动画
with st.status("🧠 思考中...", expanded=False) as status:
# 准备上下文
hist = str(st.session_state.messages[:-1])
now = get_time()
# --- 第一步:路由判断 ---
intent_res = intent_chain.invoke({"input": user_input, "chat_history": hist, "current_date": now})
st.json(intent_res) # 调试用:在折叠状态里显示路由结果
final_stream = None # 用于存储最终的流式输出对象
yt_query = None # 用于存储 YouTube 搜索关键词
# --- 第二步:根据意图分支 ---
if intent_res.get("intent") == "TOOL":
tool_name = intent_res.get("tool_name")
query = intent_res.get("tool_args", {}).get("query", user_input)
# 如果是联网搜索,顺便记录一下关键词用于稍后搜 YouTube
if tool_name == "internet_search":
yt_query = query
if tool_name in tools:
try:
# 执行工具 (Tavily 或 向量检索)
tool_res = tools[tool_name].invoke(query)
# 在 UI 中增加一个折叠框,显示搜索到的原始数据(增加可信度)
with st.expander("📄 查看搜索摘要"):
st.text(tool_res)
# 防幻觉逻辑:如果搜索结果太短或包含错误信息
if "未找到有效信息" in tool_res or len(tool_res.strip()) < 50:
final_stream = response_chain.stream({"user_input": user_input, "tool_result": "未找到相关信息。"})
else:
# 正常生成回答
final_stream = response_chain.stream({"user_input": user_input, "tool_result": tool_res})
except Exception as e:
st.error(f"工具执行失败: {e}")
else:
st.error("未找到工具")
else:
# 如果是 CHAT 意图,直接闲聊
final_stream = chat_chain.stream({"input": user_input})
# 更新状态栏为完成
status.update(label="完成", state="complete")
# --- 第三步:流式输出回答 ---
if final_stream:
full_response = st.write_stream(final_stream) # Streamlit 自带的打字机效果
st.session_state.messages.append({"role": "assistant", "content": full_response})
# --- 第四步:展示 YouTube 视频 (仅当触发了联网搜索时) ---
if yt_query:
st.markdown("---") # 分割线
videos = search_youtube(yt_query)
if videos:
cols = st.columns(3) # 三列布局
for i, v in enumerate(videos[:3]):
with cols[i]:
# 处理缩略图:有些 API 返回的是列表,有些是字符串,这里做了兼容处理
thumb = v['thumbnails'][0] if isinstance(v['thumbnails'], list) else v['thumbnails']
st.image(thumb, use_container_width=True)
# 显示标题链接
st.markdown(f"**[{v['title']}](https://www.youtube.com{v['url_suffix']})**")
# 显示观看量
st.caption(f"👀 {v['views']}")
else:
st.caption("未找到相关视频。")