Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import warnings | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import random | |
| import numpy as np | |
warnings.filterwarnings('ignore')


# Seed every random number generator so a given seed reproduces a run.
def setup_seed(seed):
    """Seed Python, NumPy and torch RNGs with ``seed`` and make cuDNN deterministic.

    Seeds ``random``, ``numpy.random``, the torch CPU generator, the current CUDA
    device and all CUDA devices (in that order), then forces cuDNN into
    deterministic mode with benchmark autotuning turned off.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
# Load the tokenizer and model. Replace model_path with your own checkpoint if needed.
model_path = "cmz1024/minimind-zero"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# trust_remote_code lets transformers run the model class shipped with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

# Use the GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)
model.eval()

# Report the trainable parameter count in millions.
_trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f'MiniMind模型参数量: {_trainable_param_count / 1e6:.2f}M(illion)')
# Text generation.
def _select_history(history, history_cnt):
    """Return a new list holding the last ``history_cnt`` messages of ``history``.

    ``history_cnt`` may arrive from a Gradio slider as a float, so it is converted
    to ``int`` before slicing (slicing with a float raises TypeError). A count of
    0 or less, or a missing/empty history, yields an empty list. The result is a
    copy, so callers may append to it without touching ``history``.
    """
    count = int(history_cnt or 0)
    if count <= 0 or not history:
        return []
    return list(history[-count:])


def generate_text(prompt, max_length=512, temperature=0.85, top_p=0.85, history_cnt=0):
    """Generate a model reply to ``prompt``, optionally conditioned on recent chat history.

    Args:
        prompt: user input text. ``None`` or whitespace-only input returns a hint
            message without running the model.
        max_length: maximum number of new tokens to generate (coerced to int).
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.
        history_cnt: number of most recent stored messages from the module-level
            ``chat_history`` to prepend as context (each exchange stores two
            messages: the user turn and the assistant turn).

    Returns:
        The decoded reply with the prompt tokens and special tokens stripped, or
        a hint string when the prompt is empty.

    Side effects:
        Re-seeds all RNGs via ``setup_seed`` and appends the user/assistant turn
        to the module-level ``chat_history`` list when it exists.
    """
    # Empty input: return a hint instead of running the model.
    if not prompt or not prompt.strip():
        return "请输入您的问题或指令..."

    # A fresh random seed per call keeps sampling varied but logged-seed reproducible.
    setup_seed(random.randint(0, 2048))

    # Build the message list: recent history (copied) followed by the new user turn.
    messages = _select_history(globals().get('chat_history'), history_cnt)
    messages.append({"role": "user", "content": prompt})

    # Render the conversation with the tokenizer's chat template.
    new_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(new_prompt, return_tensors="pt").to(device)
    input_length = inputs["input_ids"].shape[1]

    # Some tokenizers define no pad token; fall back to EOS so generate() gets a valid id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            # Slider values may be floats; generate() needs an int token budget.
            max_new_tokens=int(max_length),
            # Without do_sample=True, HF generate decodes greedily and silently
            # ignores temperature/top_p (and the seed set above).
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    # Record this exchange. Look the list up again here in case clear_history
    # rebound it while generation was running.
    history = globals().get('chat_history')
    if history is not None:
        history.append({"role": "user", "content": prompt})
        history.append({"role": "assistant", "content": generated_text})
    return generated_text
# Module-level conversation log: generate_text appends each user/assistant
# exchange here, and clear_history resets it.
chat_history = []


# Reset the stored conversation.
def clear_history():
    """Discard every stored conversation message.

    Rebinds the module-level ``chat_history`` to a brand-new empty list and
    returns the status text shown in the UI status box.
    """
    global chat_history
    chat_history = list()
    return "✅ 对话历史已清除"
# Custom CSS passed to gr.Blocks(css=...). It defines theme colour variables,
# header/card layout, input/output textarea styling, primary/secondary button
# styles, slider/accordion/status-box styling and a narrow-screen (<=768px)
# override. Classes here are attached to components via elem_classes in the UI.
# NOTE(review): .accordion-header and .status-box.success are not referenced by
# any elem_classes visible in this file — confirm whether they are still needed.
custom_css = """
:root {
--primary-color: #4F46E5;
--secondary-color: #6366F1;
--accent-color: #818CF8;
--background-color: #FFFFFF;
--text-color: #1F2937;
--card-bg: #F9FAFB;
--border-radius: 12px;
--shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
}
body {
background-color: var(--background-color);
color: var(--text-color);
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.header-container {
background: #FFFFFF;
border-bottom: 2px solid #E5E7EB;
border-radius: var(--border-radius);
padding: 20px;
margin-bottom: 24px;
box-shadow: var(--shadow);
color: var(--text-color);
text-align: center;
}
.header-container h1 {
font-size: 2.5rem;
margin: 0;
font-weight: 700;
color: var(--primary-color);
}
.header-container p {
opacity: 0.9;
margin-top: 8px;
font-size: 1.1rem;
}
.logo-pulse {
display: inline-block;
margin-right: 12px;
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.1); }
100% { transform: scale(1); }
}
.gradio-container {
margin-top: 0 !important;
}
.main-card {
background-color: var(--card-bg);
border-radius: var(--border-radius);
box-shadow: var(--shadow);
padding: 24px;
margin-bottom: 24px;
}
.input-area textarea, .output-area textarea {
border-radius: 8px !important;
border: 1px solid #E5E7EB !important;
padding: 12px !important;
font-size: 1rem !important;
transition: all 0.3s ease;
}
.input-area textarea:focus, .output-area textarea:focus {
border-color: var(--primary-color) !important;
box-shadow: 0 0 0 2px rgba(79, 70, 229, 0.2) !important;
}
.button-primary {
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color)) !important;
border: none !important;
color: white !important;
padding: 10px 20px !important;
border-radius: 8px !important;
font-weight: 600 !important;
transition: all 0.3s ease !important;
box-shadow: 0 2px 4px rgba(79, 70, 229, 0.3) !important;
}
.button-primary:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 8px rgba(79, 70, 229, 0.4) !important;
}
.button-secondary {
background-color: #F3F4F6 !important;
border: 1px solid #E5E7EB !important;
color: var(--text-color) !important;
padding: 10px 20px !important;
border-radius: 8px !important;
font-weight: 600 !important;
transition: all 0.3s ease !important;
}
.button-secondary:hover {
background-color: #E5E7EB !important;
transform: translateY(-2px) !important;
}
.footer {
text-align: center;
margin-top: 40px;
color: #6B7280;
font-size: 0.9rem;
}
.slider-container label span {
font-weight: 600 !important;
color: var(--text-color) !important;
}
.accordion {
border: 1px solid #E5E7EB !important;
border-radius: var(--border-radius) !important;
overflow: hidden !important;
}
.accordion-header {
background-color: #F9FAFB !important;
padding: 12px 16px !important;
font-weight: 600 !important;
}
.status-box {
background-color: #F3F4F6;
border-radius: 8px;
padding: 12px;
font-size: 0.9rem;
color: #4B5563;
}
.status-box.success {
background-color: #ECFDF5;
color: #065F46;
}
/* 响应式调整 */
@media (max-width: 768px) {
.header-container h1 {
font-size: 2rem;
}
.main-card {
padding: 16px;
}
}
"""
# Build the Gradio interface.
#
# The original footer relied on an inline <script> inside gr.HTML to fill in the
# runtime info and to lift the input box on focus. Gradio inserts gr.HTML content
# as markup, so that script never ran: the runtime span stayed blank, and its
# DOMContentLoaded/focus listeners would not have fired anyway. The focus lift is
# now done in CSS, and the runtime info is rendered on the server.
_input_focus_css = """
#input-box textarea:focus {
transform: translateY(-2px);
}
"""

# Compute device chosen at model-load time ("cuda" or "cpu").
_runtime_label = "GPU (CUDA)" if device == "cuda" else "CPU"

with gr.Blocks(css=custom_css + _input_focus_css) as demo:
    with gr.Column(elem_classes="container"):
        # Page header.
        with gr.Column(elem_classes="header-container"):
            gr.HTML("""
<div>
<span class="logo-pulse">🧠</span>
<h1>MiniMind AI</h1>
<p>小参数模型</p>
</div>
""")

        with gr.Column(elem_classes="main-card"):
            with gr.Row():
                # Left column: prompt, action buttons and sampling parameters.
                with gr.Column(scale=1, elem_classes="input-area"):
                    input_text = gr.Textbox(
                        label="您的问题",
                        placeholder="请在此输入您的问题或指令...",
                        lines=5,
                        elem_id="input-box"
                    )
                    with gr.Row():
                        submit_btn = gr.Button("🚀 生成回答", elem_classes="button-primary")
                        clear_btn = gr.Button("🗑️ 清除历史", elem_classes="button-secondary")

                    with gr.Accordion("⚙️ 高级参数设置", open=False, elem_classes="accordion"):
                        max_length = gr.Slider(
                            minimum=10, maximum=2048, value=512, step=1,
                            label="最大生成长度",
                            info="控制生成文本的最大长度",
                            elem_classes="slider-container"
                        )
                        temperature = gr.Slider(
                            minimum=0.1, maximum=1.5, value=0.85, step=0.01,
                            label="温度系数",
                            info="较高的值会使输出更加随机,较低的值使输出更加确定",
                            elem_classes="slider-container"
                        )
                        top_p = gr.Slider(
                            minimum=0.1, maximum=1.0, value=0.85, step=0.01,
                            label="Top-p 采样",
                            info="控制词汇选择的多样性",
                            elem_classes="slider-container"
                        )
                        # Counts stored messages; step=2 keeps user/assistant pairs together.
                        history_cnt = gr.Slider(
                            minimum=0, maximum=10, value=0, step=2,
                            label="历史对话轮数",
                            info="考虑多少轮历史对话作为上下文",
                            elem_classes="slider-container"
                        )

                # Right column: model answer and status line (display-only).
                with gr.Column(scale=1, elem_classes="output-area"):
                    output_text = gr.Textbox(
                        label="AI 回答",
                        lines=25,
                        elem_id="output-box",
                        interactive=False
                    )
                    clear_output = gr.Textbox(
                        label="状态",
                        value="系统就绪,等待您的输入...",
                        elem_classes="status-box",
                        interactive=False
                    )

        # Footer; the runtime info is filled in server-side (see note above).
        with gr.Column(elem_classes="footer"):
            gr.HTML(f"""
<div>
<p>MiniMind 模型演示 | 基于先进的自然语言处理技术</p>
<p>运行环境:<span id="device-info">{_runtime_label}</span></p>
</div>
""")

    # Event wiring: the button and Enter in the textbox trigger the same generation.
    generation_inputs = [input_text, max_length, temperature, top_p, history_cnt]
    submit_btn.click(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text
    )
    input_text.submit(
        fn=generate_text,
        inputs=generation_inputs,
        outputs=output_text
    )
    clear_btn.click(
        fn=clear_history,
        inputs=[],
        outputs=clear_output
    )

# Launch the app.
demo.launch()