Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| import json | |
| import os | |
| from typing import List, Dict, Any | |
| import pandas as pd | |
| from datetime import datetime | |
| # 页面配置 | |
| st.set_page_config( | |
| page_title="知识库大模型", | |
| page_icon="🤖", | |
| layout="wide", | |
| initial_sidebar_state="expanded" | |
| ) | |
| # API配置 | |
| API_BASE_URL = "http://localhost:8000" | |
| def check_api_health(): | |
| """检查API服务状态""" | |
| try: | |
| response = requests.get(f"{API_BASE_URL}/health", timeout=5) | |
| return response.status_code == 200 | |
| except: | |
| return False | |
| def upload_files(files): | |
| """上传文件到API""" | |
| try: | |
| files_data = [] | |
| for file in files: | |
| files_data.append(('files', (file.name, file.getvalue(), file.type))) | |
| response = requests.post(f"{API_BASE_URL}/upload", files=files_data) | |
| return response.json() if response.status_code == 200 else None | |
| except Exception as e: | |
| st.error(f"上传失败: {str(e)}") | |
| return None | |
| def ask_question(question: str): | |
| """发送问题到API""" | |
| try: | |
| response = requests.post( | |
| f"{API_BASE_URL}/ask", | |
| json={"question": question} | |
| ) | |
| return response.json() if response.status_code == 200 else None | |
| except Exception as e: | |
| st.error(f"提问失败: {str(e)}") | |
| return None | |
| def search_documents(query: str, k: int = 4): | |
| """搜索文档""" | |
| try: | |
| response = requests.get(f"{API_BASE_URL}/search", params={"query": query, "k": k}) | |
| return response.json() if response.status_code == 200 else None | |
| except Exception as e: | |
| st.error(f"搜索失败: {str(e)}") | |
| return None | |
| def get_chat_history(): | |
| """获取对话历史""" | |
| try: | |
| response = requests.get(f"{API_BASE_URL}/chat-history") | |
| return response.json() if response.status_code == 200 else None | |
| except Exception as e: | |
| st.error(f"获取对话历史失败: {str(e)}") | |
| return None | |
| def clear_chat_history(): | |
| """清除对话历史""" | |
| try: | |
| response = requests.delete(f"{API_BASE_URL}/chat-history") | |
| return response.status_code == 200 | |
| except Exception as e: | |
| st.error(f"清除对话历史失败: {str(e)}") | |
| return False | |
| def get_stats(): | |
| """获取系统统计信息""" | |
| try: | |
| response = requests.get(f"{API_BASE_URL}/stats") | |
| return response.json() if response.status_code == 200 else None | |
| except Exception as e: | |
| st.error(f"获取统计信息失败: {str(e)}") | |
| return None | |
| def reset_knowledge_base(): | |
| """重置知识库""" | |
| try: | |
| response = requests.delete(f"{API_BASE_URL}/reset") | |
| return response.status_code == 200 | |
| except Exception as e: | |
| st.error(f"重置知识库失败: {str(e)}") | |
| return False | |
| # 主界面 | |
| def main(): | |
| st.title("🤖 知识库大模型系统") | |
| st.markdown("---") | |
| # 检查API状态 | |
| if not check_api_health(): | |
| st.error("⚠️ API服务未运行,请先启动后端服务") | |
| st.code("python api.py") | |
| return | |
| # 侧边栏 | |
| with st.sidebar: | |
| st.header("📊 系统状态") | |
| # 获取统计信息 | |
| stats = get_stats() | |
| if stats: | |
| st.metric("文档总数", stats.get("total_documents", 0)) | |
| st.info(f"嵌入模型: {stats.get('embedding_model', 'N/A')}") | |
| st.markdown("---") | |
| # 操作按钮 | |
| st.header("🔧 系统操作") | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| if st.button("🗑️ 清除对话历史"): | |
| if clear_chat_history(): | |
| st.success("对话历史已清除") | |
| st.rerun() | |
| with col2: | |
| if st.button("🔄 重置知识库"): | |
| if st.checkbox("确认重置知识库?"): | |
| if reset_knowledge_base(): | |
| st.success("知识库已重置") | |
| st.rerun() | |
| # 主内容区域 | |
| tab1, tab2, tab3, tab4 = st.tabs(["💬 智能问答", "📁 文档管理", "🔍 文档搜索", "📋 对话历史"]) | |
| with tab1: | |
| st.header("💬 智能问答") | |
| # 问题输入 | |
| question = st.text_area("请输入您的问题:", height=100, placeholder="例如:请介绍一下人工智能的发展历史...") | |
| col1, col2 = st.columns([1, 4]) | |
| with col1: | |
| if st.button("🚀 提问", type="primary"): | |
| if question.strip(): | |
| with st.spinner("正在思考中..."): | |
| result = ask_question(question) | |
| if result: | |
| st.session_state.current_answer = result | |
| st.success("回答完成!") | |
| st.rerun() | |
| else: | |
| st.warning("请输入问题") | |
| # 显示回答 | |
| if 'current_answer' in st.session_state: | |
| result = st.session_state.current_answer | |
| st.markdown("### 🤖 AI回答") | |
| st.markdown(result.get("answer", "")) | |
| # 显示来源 | |
| sources = result.get("sources", []) | |
| if sources: | |
| st.markdown("### 📚 参考来源") | |
| for i, source in enumerate(sources, 1): | |
| with st.expander(f"来源 {i}: {source.get('file_name', '未知文件')}"): | |
| st.text(source.get("content", "")) | |
| st.caption(f"文件路径: {source.get('source', '未知')}") | |
| with tab2: | |
| st.header("📁 文档管理") | |
| # 文件上传 | |
| st.subheader("上传文档") | |
| uploaded_files = st.file_uploader( | |
| "选择要上传的文档", | |
| type=['txt', 'pdf', 'docx', 'md'], | |
| accept_multiple_files=True, | |
| help="支持的文件格式:TXT, PDF, DOCX, MD" | |
| ) | |
| if uploaded_files: | |
| if st.button("📤 上传文档", type="primary"): | |
| with st.spinner("正在上传和处理文档..."): | |
| result = upload_files(uploaded_files) | |
| if result: | |
| st.success(f"✅ {result.get('message', '上传成功')}") | |
| st.info(f"处理了 {len(result.get('processed_files', []))} 个文件,生成了 {result.get('total_chunks', 0)} 个文档块") | |
| st.rerun() | |
| # 目录上传 | |
| st.subheader("批量上传目录") | |
| directory_path = st.text_input("输入目录路径:", placeholder="/path/to/documents") | |
| if directory_path and st.button("📁 上传目录"): | |
| with st.spinner("正在处理目录..."): | |
| try: | |
| response = requests.post(f"{API_BASE_URL}/upload-directory", json={"directory_path": directory_path}) | |
| if response.status_code == 200: | |
| result = response.json() | |
| st.success(f"✅ {result.get('message', '目录上传成功')}") | |
| st.info(f"生成了 {result.get('total_chunks', 0)} 个文档块") | |
| st.rerun() | |
| else: | |
| st.error("目录上传失败") | |
| except Exception as e: | |
| st.error(f"目录上传失败: {str(e)}") | |
| with tab3: | |
| st.header("🔍 文档搜索") | |
| # 搜索输入 | |
| search_query = st.text_input("输入搜索关键词:", placeholder="搜索相关文档...") | |
| k_results = st.slider("返回结果数量:", min_value=1, max_value=10, value=4) | |
| if st.button("🔍 搜索", type="primary"): | |
| if search_query.strip(): | |
| with st.spinner("正在搜索..."): | |
| result = search_documents(search_query, k_results) | |
| if result: | |
| st.success(f"找到 {len(result.get('results', []))} 个相关文档") | |
| # 显示搜索结果 | |
| for i, doc in enumerate(result.get('results', []), 1): | |
| with st.expander(f"文档 {i} (相似度: {doc.get('score', 0):.3f})"): | |
| st.markdown(f"**文件名:** {doc.get('file_name', '未知')}") | |
| st.markdown(f"**来源:** {doc.get('source', '未知')}") | |
| st.markdown("**内容:**") | |
| st.text(doc.get('content', '')) | |
| else: | |
| st.warning("请输入搜索关键词") | |
| with tab4: | |
| st.header("📋 对话历史") | |
| # 获取对话历史 | |
| history = get_chat_history() | |
| if history and history.get('history'): | |
| for i, chat in enumerate(history['history'], 1): | |
| with st.expander(f"对话 {i} - {datetime.now().strftime('%H:%M:%S')}"): | |
| st.markdown("**问题:**") | |
| st.text(chat.get('question', '')) | |
| st.markdown("**回答:**") | |
| st.text(chat.get('answer', '')) | |
| else: | |
| st.info("暂无对话历史") | |
| if __name__ == "__main__": | |
| main() |