Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import os | |
| from huggingface_hub import login, InferenceClient | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
# --- Page configuration ---
st.set_page_config(
    page_title="MiniMind 聊天机器人",
    page_icon="🤖",
    layout="centered",
)

# --- Title and short description shown above the chat area ---
st.title("🤖 MiniMind 聊天机器人")
st.markdown("这是基于MiniMind模型的聊天应用。输入你的问题,AI将为你解答!")
# --- Sidebar: generation parameters and model information ---
with st.sidebar:
    st.header("模型设置")
    # Sampling temperature: higher = more random output
    temperature = st.slider(
        "温度",
        min_value=0.1,
        max_value=1.0,
        value=0.7,
        step=0.1,
        help="较高的值使输出更随机,较低的值使其更确定",
    )
    # Upper bound on generated tokens per reply
    max_tokens = st.slider(
        "最大生成长度",
        min_value=64,
        max_value=1024,
        value=512,
        step=64,
    )
    st.markdown("---")
    st.markdown("## 关于模型")
    st.markdown("""
    MiniMind是一个轻量级语言模型,可以进行文本生成、问答和聊天等任务。
    [查看模型主页](https://huggingface.co/xingyu1996/minimind)
    """)
# --- Conversation history lives in session state across reruns ---
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay every stored turn so the transcript persists between reruns
for message in st.session_state.messages:
    role, content = message["role"], message["content"]
    with st.chat_message(role):
        st.markdown(content)
# --- Handle a new user message: echo it, run the model, show the reply ---
if prompt := st.chat_input("输入你的问题..."):
    # Record the user turn in the session history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Echo the user message into the transcript
    with st.chat_message("user"):
        st.markdown(prompt)
    # Assistant turn: placeholder is filled once generation finishes
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        try:
            with st.spinner("思考中..."):
                # FIX: the original docstring claimed the loader was cached,
                # but no cache decorator was applied, so the model and
                # tokenizer were re-loaded on every single message.
                # st.cache_resource makes the load happen once per process.
                @st.cache_resource(show_spinner=False)
                def load_model():
                    """Load and cache the MiniMind model and tokenizer."""
                    tokenizer = AutoTokenizer.from_pretrained("xingyu1996/minimind")
                    model = AutoModelForCausalLM.from_pretrained(
                        "xingyu1996/minimind",
                        trust_remote_code=True,
                        # fp16 only when a GPU is present; fp32 on CPU
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto",
                    )
                    return model, tokenizer

                model, tokenizer = load_model()

                # Full conversation (already includes the new user turn
                # appended above) is fed to the chat template
                messages = list(st.session_state.messages)

                # Render the chat template and tokenize onto the model device
                prompt_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

                # FIX: fall back to eos as the pad token — many causal-LM
                # tokenizers define no pad token, and pad_token_id=None makes
                # generate() warn or misbehave.
                pad_id = tokenizer.pad_token_id
                if pad_id is None:
                    pad_id = tokenizer.eos_token_id

                # FIX: pass the attention mask explicitly; relying on the
                # model to infer it from pad tokens is unreliable.
                output_ids = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=pad_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

                # Decode only the newly generated tokens (slice off the prompt)
                full_response = tokenizer.decode(
                    output_ids[0][inputs.input_ids.shape[1]:],
                    skip_special_tokens=True,
                )
                # Replace the placeholder with the final response
                message_placeholder.markdown(full_response)
        except Exception as e:
            # Surface the error both as a banner and in the chat bubble
            st.error(f"发生错误: {str(e)}")
            full_response = f"抱歉,生成回复时出错。错误信息: {str(e)}"
            message_placeholder.markdown(full_response)
    # Record the assistant turn (or the error text) in the session history
    st.session_state.messages.append({"role": "assistant", "content": full_response})
# --- Reset button: clear the conversation and redraw the page ---
if st.button("清空对话"):
    st.session_state.messages = []
    # FIX: st.experimental_rerun() was deprecated and removed in modern
    # Streamlit releases; st.rerun() is the supported replacement.
    st.rerun()
# --- Page footer ---
st.markdown("---")
st.markdown("Made with ❤️ by [xingyu1996](https://huggingface.co/xingyu1996)")