import streamlit as st import os import json import re import datetime import tempfile # 导入 LangChain 相关组件 from langchain_openai import ChatOpenAI from langchain_core.prompts import ChatPromptTemplate from langchain_core.output_parsers import StrOutputParser from langchain_core.runnables import RunnableLambda from langchain_core.tools import tool from langchain_community.tools.tavily_search import TavilySearchResults from langchain_community.document_loaders import PyPDFLoader from langchain_text_splitters import RecursiveCharacterTextSplitter from langchain_huggingface import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS # 导入 YouTube 搜索库 from youtube_search import YoutubeSearch # ========================================== # 1. 基础配置与环境初始化 # ========================================== # 配置 Streamlit 页面标题、图标和布局 st.set_page_config(page_title="FeiChat Final", page_icon="✨", layout="wide") st.title("✨ FeiChat (Tavily + YouTube 完美版)") # 设置 API Key # 注意:实际生产中建议使用 st.secrets 或系统环境变量,不要直接写在代码里 os.environ["OPENAI_API_KEY"] = "lm-studio" # 指向本地 LM Studio,Key 随意填写 os.environ["TAVILY_API_KEY"] = "tvly-dev-xxxx" # Tavily 搜索引擎 Key # 初始化 Session State (会话状态) #用于在 Streamlit 页面刷新(rerun)时保存聊天记录和向量数据库 if "messages" not in st.session_state: st.session_state.messages = [] # 存储对话历史 if "vector_store" not in st.session_state: st.session_state.vector_store = None # 存储 PDF 向量索引 # ========================================== # 1.1 模型加载 (使用缓存避免重复加载) # ========================================== @st.cache_resource def get_models(): """ 初始化 LLM 和 Embedding 模型。 使用 @st.cache_resource 装饰器,确保只加载一次,节省资源。 """ # 1. 路由模型 (Router):温度设为 0.0,要求输出精确,用于判断意图 router = ChatOpenAI( base_url="http://127.0.0.1:1234/v1", # 连接本地 LM Studio 端口 model="kuaidao-c-suite-v2", temperature=0.0 ) # 2. 对话模型 (Chat):温度设为 0.7,用于生成流畅、自然的回答,开启流式输出 chat = ChatOpenAI( base_url="http://127.0.0.1:1234/v1", model="kuaidao-c-suite-v2", temperature=0.7, streaming=True ) # 3. 嵌入模型 (Embeddings):用于将 PDF 文本转化为向量,这里使用 HuggingFace 的轻量级模型 embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2") return router, chat, embeddings llm_router, llm_chat, embeddings = get_models() # ========================================== # 2. 侧边栏 RAG (私有知识库处理) # ========================================== with st.sidebar: st.header("📂 私有知识库") # 文件上传控件 uploaded_file = st.file_uploader("上传 PDF (仅当问及文档内容时使用)", type=["pdf"]) # 如果用户上传了文件,且向量库还未建立,则开始处理 if uploaded_file and st.session_state.vector_store is None: with st.status("正在学习文档...", expanded=True): # 1. 创建临时文件保存上传的 PDF (PyPDFLoader 需要本地文件路径) with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp: tmp.write(uploaded_file.read()) path = tmp.name # 2. 加载 PDF loader = PyPDFLoader(path) docs = loader.load() # 3. 文本切分:将长文档切成 500字符的小块,保留 50字符重叠以保持上下文 splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) splits = splitter.split_documents(docs) # 4. 向量化并存储:使用 FAISS 构建向量索引 st.session_state.vector_store = FAISS.from_documents(splits, embeddings) st.success(f"已索引 {len(splits)} 个片段") # 5. 删除临时文件,保持整洁 os.remove(path) # 清除记忆按钮 if st.button("🗑️ 清空记忆"): st.session_state.messages = [] st.session_state.vector_store = None st.rerun() # 重新运行脚本以刷新页面状态 # ========================================== # 3. 工具定义 (Search & RAG) # ========================================== # 初始化 Tavily 搜索客户端 tavily_engine = TavilySearchResults(max_results=5) @tool def internet_search(query: str) -> str: """Tavily 联网搜索工具函数,供 Agent 调用。""" print(f"🕵️ Tavily 搜索: {query}") # 后台打印日志 try: results = tavily_engine.invoke({"query": query}) formatted = [] # 格式化搜索结果,包含 URL 和内容摘要 for i, res in enumerate(results): formatted.append(f"【来源】({res['url']}):\n{res['content']}") return "\n\n".join(formatted) except Exception as e: return f"Error: {e}" @tool def knowledge_base_search(query: str) -> str: """知识库(RAG)搜索工具函数。""" # 如果没上传文件,直接返回提示 if st.session_state.vector_store is None: return "用户未上传任何文档。" # 在向量库中搜索最相似的 3 个片段 docs = st.session_state.vector_store.similarity_search(query, k=3) return "\n\n".join([f"【文档片段】: {d.page_content}" for d in docs]) # 将工具放入字典,方便 Router 调用 tools = {"internet_search": internet_search, "knowledge_base_search": knowledge_base_search} def search_youtube(query): """ YouTube 搜索辅助函数 注意:这是独立功能,不作为 LLM 的 Tool,而是在 UI 层直接展示结果 """ try: # 限制结果为 3 个 return YoutubeSearch(query, max_results=3).to_dict() except: return [] # ========================================== # 4. 核心逻辑 (Router & Chain) # ========================================== def get_time(): return datetime.datetime.now().strftime("%Y年%m月%d日") # 4.1 意图识别 Prompt # 作用:让 LLM 判断用户是想闲聊、查文档还是联网搜索,并提取搜索关键词 intent_prompt = ChatPromptTemplate.from_messages([ ("system", """ 你是一个智能路由。当前时间:{current_date}。 【工具选择逻辑】: 1. knowledge_base_search: 🔴 仅当用户明确提到“文档”、“PDF”、“上传的文件”等时使用。 2. internet_search: 🟢 默认选项(如果用户问知识性问题)。 3. CHAT: 仅用于纯打招呼、情感交流,不需要外部信息。 【Query生成规则】: - 强时效性问题(新闻、天气):必须在关键词中加 `{current_date}`。 - 弱时效性问题(歌曲、百科、人物):**禁止加日期**,直接用实体名。 返回 JSON 格式: {{ "intent": "CHAT" 或 "TOOL", "tool_name": "...", "tool_args": {{ "query": "..." }} }} """), ("user", "历史:\n{chat_history}\n\n输入:\n{input}") ]) def parse_router(text): """解析 Router LLM 返回的 JSON 字符串""" try: # 使用正则提取 Markdown 代码块中的 JSON (防止 LLM 输出 ```json ... ```) if "```" in text: text = re.search(r"```(?:json)?(.*?)```", text, re.DOTALL).group(1) return json.loads(text.strip()) except: # 解析失败则默认回退到纯聊天模式 return {"intent": "CHAT"} # 构建路由链:Prompt -> LLM -> 文本解析 -> JSON解析 intent_chain = intent_prompt | llm_router | StrOutputParser() | RunnableLambda(parse_router) # 4.2 最终回复润色 Prompt # 作用:根据搜索结果生成给用户的最终回答 response_prompt = ChatPromptTemplate.from_messages([ ("system", """ 你是一个严谨的信息整合助手。 请严格基于【搜索结果】回答。 1. 如果结果里没有,诚实说“未找到”。 2. 严禁捏造。 """), ("user", "问题: {user_input}\n\n搜索结果:\n{tool_result}") ]) response_chain = response_prompt | llm_chat | StrOutputParser() # 4.3 纯闲聊 Prompt chat_chain = ChatPromptTemplate.from_messages([ ("system", "助手。"), ("user", "{input}") ]) | llm_chat | StrOutputParser() # ========================================== # 5. 界面 UI 交互逻辑 # ========================================== # 5.1 显示历史消息 for msg in st.session_state.messages: with st.chat_message(msg["role"]): st.markdown(msg["content"]) # 5.2 处理用户输入 if user_input := st.chat_input("问:Fruits Zipper 最火的歌..."): # 记录用户输入 st.session_state.messages.append({"role": "user", "content": user_input}) with st.chat_message("user"): st.markdown(user_input) with st.chat_message("assistant"): # 使用 st.status 显示“思考中”状态动画 with st.status("🧠 思考中...", expanded=False) as status: # 准备上下文 hist = str(st.session_state.messages[:-1]) now = get_time() # --- 第一步:路由判断 --- intent_res = intent_chain.invoke({"input": user_input, "chat_history": hist, "current_date": now}) st.json(intent_res) # 调试用:在折叠状态里显示路由结果 final_stream = None # 用于存储最终的流式输出对象 yt_query = None # 用于存储 YouTube 搜索关键词 # --- 第二步:根据意图分支 --- if intent_res.get("intent") == "TOOL": tool_name = intent_res.get("tool_name") query = intent_res.get("tool_args", {}).get("query", user_input) # 如果是联网搜索,顺便记录一下关键词用于稍后搜 YouTube if tool_name == "internet_search": yt_query = query if tool_name in tools: try: # 执行工具 (Tavily 或 向量检索) tool_res = tools[tool_name].invoke(query) # 在 UI 中增加一个折叠框,显示搜索到的原始数据(增加可信度) with st.expander("📄 查看搜索摘要"): st.text(tool_res) # 防幻觉逻辑:如果搜索结果太短或包含错误信息 if "未找到有效信息" in tool_res or len(tool_res.strip()) < 50: final_stream = response_chain.stream({"user_input": user_input, "tool_result": "未找到相关信息。"}) else: # 正常生成回答 final_stream = response_chain.stream({"user_input": user_input, "tool_result": tool_res}) except Exception as e: st.error(f"工具执行失败: {e}") else: st.error("未找到工具") else: # 如果是 CHAT 意图,直接闲聊 final_stream = chat_chain.stream({"input": user_input}) # 更新状态栏为完成 status.update(label="完成", state="complete") # --- 第三步:流式输出回答 --- if final_stream: full_response = st.write_stream(final_stream) # Streamlit 自带的打字机效果 st.session_state.messages.append({"role": "assistant", "content": full_response}) # --- 第四步:展示 YouTube 视频 (仅当触发了联网搜索时) --- if yt_query: st.markdown("---") # 分割线 videos = search_youtube(yt_query) if videos: cols = st.columns(3) # 三列布局 for i, v in enumerate(videos[:3]): with cols[i]: # 处理缩略图:有些 API 返回的是列表,有些是字符串,这里做了兼容处理 thumb = v['thumbnails'][0] if isinstance(v['thumbnails'], list) else v['thumbnails'] st.image(thumb, use_container_width=True) # 显示标题链接 st.markdown(f"**[{v['title']}](https://www.youtube.com{v['url_suffix']})**") # 显示观看量 st.caption(f"👀 {v['views']}") else: st.caption("未找到相关视频。")