|
|
import streamlit as st |
|
|
import os |
|
|
import json |
|
|
import re |
|
|
import datetime |
|
|
import tempfile |
|
|
|
|
|
from langchain_openai import ChatOpenAI |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_core.runnables import RunnableLambda |
|
|
from langchain_core.tools import tool |
|
|
from langchain_community.tools.tavily_search import TavilySearchResults |
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
|
from langchain_text_splitters import RecursiveCharacterTextSplitter |
|
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
from langchain_community.vectorstores import FAISS |
|
|
|
|
|
from youtube_search import YoutubeSearch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Page setup ---
st.set_page_config(page_title="FeiChat Final", page_icon="✨", layout="wide")

st.title("✨ FeiChat (Tavily + YouTube 完美版)")

# --- API credentials ---
# SECURITY NOTE(review): keys hard-coded in source should move to real
# environment variables or st.secrets before sharing/deploying this file.
# FIX: use setdefault so an externally supplied key is no longer clobbered
# by the in-source placeholder.
os.environ.setdefault("OPENAI_API_KEY", "lm-studio")
os.environ.setdefault("TAVILY_API_KEY", "tvly-dev-xxxx")
|
|
|
|
|
|
|
|
|
|
|
# Seed per-session state on the first run; later reruns keep existing values.
_SESSION_DEFAULTS = {"messages": [], "vector_store": None}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@st.cache_resource
def get_models():
    """Build and cache the LLM and embedding model handles.

    Decorated with @st.cache_resource so the (expensive) clients are
    constructed only once per server process.

    Returns:
        tuple: (router_llm, chat_llm, embeddings) — the router runs at
        temperature 0.0 for deterministic intent JSON; the chat model
        streams at temperature 0.7 for the user-facing answer.
    """
    # Both LLMs point at the same local LM Studio endpoint/model.
    local_endpoint = dict(
        base_url="http://127.0.0.1:1234/v1",
        model="kuaidao-c-suite-v2",
    )

    # Deterministic model for intent routing.
    router = ChatOpenAI(temperature=0.0, **local_endpoint)

    # Creative, streaming model for answer generation.
    chat = ChatOpenAI(temperature=0.7, streaming=True, **local_endpoint)

    # Lightweight sentence-transformer embeddings for the FAISS index.
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

    return router, chat, embeddings


llm_router, llm_chat, embeddings = get_models()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with st.sidebar:
    st.header("📂 私有知识库")

    uploaded_file = st.file_uploader("上传 PDF (仅当问及文档内容时使用)", type=["pdf"])

    # Index the uploaded PDF once per session (skipped while an index exists).
    if uploaded_file and st.session_state.vector_store is None:
        with st.status("正在学习文档...", expanded=True):
            # PyPDFLoader needs a real path, so spill the upload to disk first.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
                tmp.write(uploaded_file.read())
                path = tmp.name

            try:
                loader = PyPDFLoader(path)
                docs = loader.load()

                # Small chunks with overlap keep retrieval precise.
                splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
                splits = splitter.split_documents(docs)

                st.session_state.vector_store = FAISS.from_documents(splits, embeddings)
                st.success(f"已索引 {len(splits)} 个片段")
            finally:
                # FIX: previously the temp file leaked if loading/indexing raised.
                os.remove(path)

    # Reset both the chat history and the document index.
    if st.button("🗑️ 清空记忆"):
        st.session_state.messages = []
        st.session_state.vector_store = None
        st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared Tavily client; one instance is reused across tool calls.
tavily_engine = TavilySearchResults(max_results=5)


@tool
def internet_search(query: str) -> str:
    """Tavily web-search tool exposed to the agent.

    Args:
        query: Search keywords produced by the router.

    Returns:
        Formatted sources joined by blank lines; "未找到有效信息" when the
        search returns nothing; or an "Error: ..." string on failure
        (errors stay strings so the agent loop never crashes).
    """
    print(f"🕵️ Tavily 搜索: {query}")
    try:
        results = tavily_engine.invoke({"query": query})
        # FIX: dropped the unused enumerate() index; a comprehension is clearer.
        formatted = [f"【来源】({res['url']}):\n{res['content']}" for res in results]
        if not formatted:
            # FIX: explicit sentinel instead of "" — the caller already
            # checks for this exact phrase to detect empty results.
            return "未找到有效信息"
        return "\n\n".join(formatted)
    except Exception as e:
        return f"Error: {e}"
|
|
|
|
|
@tool
def knowledge_base_search(query: str) -> str:
    """RAG lookup over the user-uploaded PDF index.

    Returns the top-3 most similar document chunks, or a notice when no
    document has been indexed yet.
    """
    store = st.session_state.vector_store
    if store is None:
        return "用户未上传任何文档。"

    # Top-3 chunks by embedding similarity.
    matches = store.similarity_search(query, k=3)
    snippets = [f"【文档片段】: {d.page_content}" for d in matches]
    return "\n\n".join(snippets)
|
|
|
|
|
|
|
|
# Dispatch table: tool name (from the router's JSON) -> callable tool.
tools = {"internet_search": internet_search, "knowledge_base_search": knowledge_base_search}
|
|
|
|
|
def search_youtube(query):
    """Look up the top YouTube results for *query*.

    UI-only helper: results are rendered directly as video cards in the
    chat view; this is deliberately NOT registered as an LLM tool.

    Args:
        query: Free-text search string.

    Returns:
        list[dict]: Up to 3 result dicts from youtube_search, or [] on any
        failure (best-effort — the video panel is optional garnish).
    """
    try:
        return YoutubeSearch(query, max_results=3).to_dict()
    except Exception:
        # FIX: bare `except:` also swallowed SystemExit/KeyboardInterrupt.
        return []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_time():
    """Return today's date formatted as 'YYYY年MM月DD日' (injected into prompts)."""
    today = datetime.datetime.now()
    return today.strftime("%Y年%m月%d日")
|
|
|
|
|
|
|
|
|
|
|
# Router prompt: the LLM must emit a JSON routing decision (intent, tool,
# query).  {current_date} is injected so time-sensitive queries get dated
# keywords while evergreen ones stay date-free.  Double braces escape the
# literal JSON example from LangChain's template substitution.
intent_prompt = ChatPromptTemplate.from_messages([
    ("system", """
你是一个智能路由。当前时间:{current_date}。

【工具选择逻辑】:
1. knowledge_base_search: 🔴 仅当用户明确提到“文档”、“PDF”、“上传的文件”等时使用。
2. internet_search: 🟢 默认选项(如果用户问知识性问题)。
3. CHAT: 仅用于纯打招呼、情感交流,不需要外部信息。

【Query生成规则】:
- 强时效性问题(新闻、天气):必须在关键词中加 `{current_date}`。
- 弱时效性问题(歌曲、百科、人物):**禁止加日期**,直接用实体名。

返回 JSON 格式: {{ "intent": "CHAT" 或 "TOOL", "tool_name": "...", "tool_args": {{ "query": "..." }} }}
"""),
    ("user", "历史:\n{chat_history}\n\n输入:\n{input}")
])
|
|
|
|
|
def parse_router(text):
    """Parse the router LLM's reply into a routing dict.

    Strips an optional ```json ...``` markdown fence before decoding.

    Args:
        text: Raw model output, ideally JSON.

    Returns:
        dict: Parsed routing decision; falls back to {"intent": "CHAT"}
        when the output is not usable JSON (safe default: just chat).
    """
    try:
        # Models often wrap JSON in a markdown fence — unwrap it first.
        if "```" in text:
            fenced = re.search(r"```(?:json)?(.*?)```", text, re.DOTALL)
            # FIX: a lone ``` previously raised AttributeError on .group(1).
            if fenced:
                text = fenced.group(1)
        return json.loads(text.strip())
    except (json.JSONDecodeError, TypeError):
        # FIX: narrowed the bare `except:` so unrelated bugs are not hidden.
        return {"intent": "CHAT"}
|
|
|
|
|
|
|
|
# Routing pipeline: prompt -> deterministic router LLM -> raw text -> parsed dict.
intent_chain = intent_prompt | llm_router | StrOutputParser() | RunnableLambda(parse_router)
|
|
|
|
|
|
|
|
|
|
|
# Answer-synthesis prompt: grounds the streaming chat model strictly in the
# tool's search results to curb hallucination.
response_prompt = ChatPromptTemplate.from_messages([
    ("system", """
你是一个严谨的信息整合助手。
请严格基于【搜索结果】回答。
1. 如果结果里没有,诚实说“未找到”。
2. 严禁捏造。
"""),
    ("user", "问题: {user_input}\n\n搜索结果:\n{tool_result}")
])

# Streams the grounded answer back to the UI.
response_chain = response_prompt | llm_chat | StrOutputParser()
|
|
|
|
|
|
|
|
# Plain small-talk chain used when the router picks the CHAT intent (no tools).
chat_chain = ChatPromptTemplate.from_messages([
    ("system", "助手。"), ("user", "{input}")
]) | llm_chat | StrOutputParser()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Replay the stored conversation so history survives Streamlit reruns.
for past_turn in st.session_state.messages:
    with st.chat_message(past_turn["role"]):
        st.markdown(past_turn["content"])
|
|
|
|
|
|
|
|
# --- Main chat loop: runs once per user message (walrus binds the input) ---
if user_input := st.chat_input("问:Fruits Zipper 最火的歌..."):

    # Record and echo the user's turn.
    st.session_state.messages.append({"role": "user", "content": user_input})
    with st.chat_message("user"):
        st.markdown(user_input)

    with st.chat_message("assistant"):

        with st.status("🧠 思考中...", expanded=False) as status:

            # History for the router excludes the message just appended.
            hist = str(st.session_state.messages[:-1])
            now = get_time()

            # Ask the router which tool (if any) to use.
            intent_res = intent_chain.invoke({"input": user_input, "chat_history": hist, "current_date": now})
            st.json(intent_res)

            final_stream = None
            yt_query = None  # set only for web searches; drives the video panel below

            if intent_res.get("intent") == "TOOL":
                tool_name = intent_res.get("tool_name")
                # Fall back to the raw user input if the router omitted a query.
                query = intent_res.get("tool_args", {}).get("query", user_input)

                # Web searches also trigger the YouTube cards.
                if tool_name == "internet_search":
                    yt_query = query

                if tool_name in tools:
                    try:
                        tool_res = tools[tool_name].invoke(query)

                        with st.expander("📄 查看搜索摘要"):
                            st.text(tool_res)

                        # Treat the sentinel phrase or very short output as "nothing found".
                        if "未找到有效信息" in tool_res or len(tool_res.strip()) < 50:
                            final_stream = response_chain.stream({"user_input": user_input, "tool_result": "未找到相关信息。"})
                        else:
                            final_stream = response_chain.stream({"user_input": user_input, "tool_result": tool_res})
                    except Exception as e:
                        st.error(f"工具执行失败: {e}")
                else:
                    st.error("未找到工具")
            else:
                # CHAT intent: answer directly, no external information.
                final_stream = chat_chain.stream({"input": user_input})

            status.update(label="完成", state="complete")

        # Stream the answer to the UI and persist it in the history.
        if final_stream:
            full_response = st.write_stream(final_stream)
            st.session_state.messages.append({"role": "assistant", "content": full_response})

        # Optional YouTube cards for web-search turns.
        if yt_query:
            st.markdown("---")
            videos = search_youtube(yt_query)
            if videos:
                cols = st.columns(3)
                for i, v in enumerate(videos[:3]):
                    with cols[i]:
                        # 'thumbnails' may be a list of URLs or a single URL —
                        # NOTE(review): assumed from this branch; confirm upstream shape.
                        thumb = v['thumbnails'][0] if isinstance(v['thumbnails'], list) else v['thumbnails']
                        st.image(thumb, use_container_width=True)
                        st.markdown(f"**[{v['title']}](https://www.youtube.com{v['url_suffix']})**")
                        st.caption(f"👀 {v['views']}")
            else:
                st.caption("未找到相关视频。")
|
|
|