# minimind-zero / app.py
# cmz1024's picture
# Update app.py
# b98c2b1 verified
import gradio as gr
import torch
import warnings
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
# Suppress every warning category for the remainder of the run.
warnings.filterwarnings('ignore')
# Seed all random number generators for reproducibility
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs with ``seed``.

    Also sets the cuDNN flags ``deterministic=True`` and ``benchmark=False``.

    Args:
        seed: Integer seed handed unchanged to every seeding function.
    """
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
# Load the model and tokenizer
model_path = "cmz1024/minimind-zero"  # replace with your own model path

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Load the weights, move them to the selected device and switch to eval mode
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(device).eval()

# Report the number of trainable parameters, in millions
print(f'MiniMind模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.2f}M(illion)')
# Text generation function
def generate_text(prompt, max_length=512, temperature=0.85, top_p=0.85, history_cnt=0):
    """Generate a reply to ``prompt`` using the module-level model and tokenizer.

    Args:
        prompt: The user's message. ``None``, empty or whitespace-only input
            returns a hint string without touching the model.
        max_length: Upper bound on the number of newly generated tokens
            (forwarded as ``max_new_tokens``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability threshold.
        history_cnt: How many trailing entries of the global ``chat_history``
            to prepend as context. Entries are individual messages, not
            user/assistant pairs. ``0`` disables history.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        stripped), or the hint string for empty input.

    Side effects:
        Reseeds the RNGs via ``setup_seed`` and, when the global
        ``chat_history`` exists, appends the user and assistant messages to it.
        NOTE(review): ``chat_history`` is a single module-level list, so every
        caller in this process shares the same conversation.
    """
    # Nothing to answer: return the hint instead of running the model
    if prompt is None or not prompt.strip():
        return "请输入您的问题或指令..."

    # UI sliders may deliver numbers as floats; list slicing and
    # max_new_tokens both need real integers.
    max_length = int(max_length)
    history_cnt = int(history_cnt)

    # Pick a fresh random seed for this call
    setup_seed(random.randint(0, 2048))

    # Collect the requested amount of history (the slice is already a copy,
    # so appending below does not touch chat_history itself)
    messages = []
    if history_cnt > 0 and 'chat_history' in globals():
        messages = chat_history[-history_cnt:]

    # Add the current user input
    messages.append({"role": "user", "content": prompt})

    # Apply the chat template
    new_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Encode the prompt and remember its length so it can be cut off later
    inputs = tokenizer(new_prompt, return_tensors="pt").to(device)
    input_length = inputs["input_ids"].shape[1]

    # Generate. temperature/top_p are sampling parameters: Hugging Face
    # `generate` only applies them when do_sample=True, otherwise decoding is
    # greedy and the sliders (and the reseed above) have no effect.
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated part
    generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)

    # Update the conversation history
    if 'chat_history' in globals():
        chat_history.append({"role": "user", "content": prompt})
        chat_history.append({"role": "assistant", "content": generated_text})

    return generated_text
# Initialise the global conversation history: a flat list of
# {"role": ..., "content": ...} message dicts, appended to by generate_text
# and reset by clear_history.
chat_history = []
# Function that clears the conversation history
def clear_history():
    """Rebind the module-level ``chat_history`` to a fresh empty list.

    Returns:
        A status message confirming that the history was cleared.
    """
    globals()["chat_history"] = []
    return "✅ 对话历史已清除"
# Custom CSS for the Gradio UI (handed to gr.Blocks via its css argument).
# The class names defined here are the ones referenced through elem_classes
# in the layout code.
custom_css = """
:root {
--primary-color: #4F46E5;
--secondary-color: #6366F1;
--accent-color: #818CF8;
--background-color: #FFFFFF;
--text-color: #1F2937;
--card-bg: #F9FAFB;
--border-radius: 12px;
--shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
}
body {
background-color: var(--background-color);
color: var(--text-color);
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.header-container {
background: #FFFFFF;
border-bottom: 2px solid #E5E7EB;
border-radius: var(--border-radius);
padding: 20px;
margin-bottom: 24px;
box-shadow: var(--shadow);
color: var(--text-color);
text-align: center;
}
.header-container h1 {
font-size: 2.5rem;
margin: 0;
font-weight: 700;
color: var(--primary-color);
}
.header-container p {
opacity: 0.9;
margin-top: 8px;
font-size: 1.1rem;
}
.logo-pulse {
display: inline-block;
margin-right: 12px;
animation: pulse 2s infinite;
}
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.1); }
100% { transform: scale(1); }
}
.gradio-container {
margin-top: 0 !important;
}
.main-card {
background-color: var(--card-bg);
border-radius: var(--border-radius);
box-shadow: var(--shadow);
padding: 24px;
margin-bottom: 24px;
}
.input-area textarea, .output-area textarea {
border-radius: 8px !important;
border: 1px solid #E5E7EB !important;
padding: 12px !important;
font-size: 1rem !important;
transition: all 0.3s ease;
}
.input-area textarea:focus, .output-area textarea:focus {
border-color: var(--primary-color) !important;
box-shadow: 0 0 0 2px rgba(79, 70, 229, 0.2) !important;
}
.button-primary {
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color)) !important;
border: none !important;
color: white !important;
padding: 10px 20px !important;
border-radius: 8px !important;
font-weight: 600 !important;
transition: all 0.3s ease !important;
box-shadow: 0 2px 4px rgba(79, 70, 229, 0.3) !important;
}
.button-primary:hover {
transform: translateY(-2px) !important;
box-shadow: 0 4px 8px rgba(79, 70, 229, 0.4) !important;
}
.button-secondary {
background-color: #F3F4F6 !important;
border: 1px solid #E5E7EB !important;
color: var(--text-color) !important;
padding: 10px 20px !important;
border-radius: 8px !important;
font-weight: 600 !important;
transition: all 0.3s ease !important;
}
.button-secondary:hover {
background-color: #E5E7EB !important;
transform: translateY(-2px) !important;
}
.footer {
text-align: center;
margin-top: 40px;
color: #6B7280;
font-size: 0.9rem;
}
.slider-container label span {
font-weight: 600 !important;
color: var(--text-color) !important;
}
.accordion {
border: 1px solid #E5E7EB !important;
border-radius: var(--border-radius) !important;
overflow: hidden !important;
}
.accordion-header {
background-color: #F9FAFB !important;
padding: 12px 16px !important;
font-weight: 600 !important;
}
.status-box {
background-color: #F3F4F6;
border-radius: 8px;
padding: 12px;
font-size: 0.9rem;
color: #4B5563;
}
.status-box.success {
background-color: #ECFDF5;
color: #065F46;
}
/* 响应式调整 */
@media (max-width: 768px) {
.header-container h1 {
font-size: 2rem;
}
.main-card {
padding: 16px;
}
}
"""
# Build the Gradio interface

# Static HTML for the page header
_header_html = """
<div>
<span class="logo-pulse">🧠</span>
<h1>MiniMind AI</h1>
<p>小参数模型</p>
</div>
"""

# Static HTML for the page footer.
# NOTE(review): the embedded <script> is presumably not executed when injected
# through gr.HTML, which would leave #device-info empty — confirm in a browser.
_footer_html = """
<div>
<p>MiniMind 模型演示 | 基于先进的自然语言处理技术</p>
<p>运行环境:<span id="device-info"></span></p>
</div>
<script>
document.getElementById('device-info').innerText =
navigator.userAgent.includes('Mobile') ? '移动设备' : '桌面设备';
// 添加输入动画效果
document.addEventListener('DOMContentLoaded', function() {
const inputBox = document.getElementById('input-box');
const outputBox = document.getElementById('output-box');
if(inputBox) {
inputBox.addEventListener('focus', function() {
this.style.transform = 'translateY(-2px)';
this.style.boxShadow = '0 4px 8px rgba(0, 0, 0, 0.1)';
});
inputBox.addEventListener('blur', function() {
this.style.transform = 'translateY(0)';
this.style.boxShadow = '0 1px 3px rgba(0, 0, 0, 0.1)';
});
}
});
</script>
"""

with gr.Blocks(css=custom_css) as demo:
    with gr.Column(elem_classes="container"):
        # Header banner
        with gr.Column(elem_classes="header-container"):
            gr.HTML(_header_html)

        # Main card: input controls on the left, model output on the right
        with gr.Column(elem_classes="main-card"):
            with gr.Row():
                with gr.Column(scale=1, elem_classes="input-area"):
                    input_text = gr.Textbox(
                        label="您的问题",
                        placeholder="请在此输入您的问题或指令...",
                        lines=5,
                        elem_id="input-box"
                    )
                    with gr.Row():
                        submit_btn = gr.Button("🚀 生成回答", elem_classes="button-primary")
                        clear_btn = gr.Button("🗑️ 清除历史", elem_classes="button-secondary")
                    # Generation parameters, collapsed by default
                    with gr.Accordion("⚙️ 高级参数设置", open=False, elem_classes="accordion"):
                        max_length = gr.Slider(
                            minimum=10,
                            maximum=2048,
                            value=512,
                            step=1,
                            label="最大生成长度",
                            info="控制生成文本的最大长度",
                            elem_classes="slider-container"
                        )
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=1.5,
                            value=0.85,
                            step=0.01,
                            label="温度系数",
                            info="较高的值会使输出更加随机,较低的值使输出更加确定",
                            elem_classes="slider-container"
                        )
                        top_p = gr.Slider(
                            minimum=0.1,
                            maximum=1.0,
                            value=0.85,
                            step=0.01,
                            label="Top-p 采样",
                            info="控制词汇选择的多样性",
                            elem_classes="slider-container"
                        )
                        history_cnt = gr.Slider(
                            minimum=0,
                            maximum=10,
                            value=0,
                            step=2,
                            label="历史对话轮数",
                            info="考虑多少轮历史对话作为上下文",
                            elem_classes="slider-container"
                        )
                with gr.Column(scale=1, elem_classes="output-area"):
                    output_text = gr.Textbox(
                        label="AI 回答",
                        lines=25,
                        elem_id="output-box"
                    )
                    # Status line, updated when the history is cleared
                    clear_output = gr.Textbox(
                        label="状态",
                        value="系统就绪,等待您的输入...",
                        elem_classes="status-box"
                    )

        # Footer
        with gr.Column(elem_classes="footer"):
            gr.HTML(_footer_html)

    # Wire up events: the button click and pressing Enter in the textbox both
    # run generate_text with the same inputs and write to the same output.
    _generation_inputs = [input_text, max_length, temperature, top_p, history_cnt]
    for _trigger in (submit_btn.click, input_text.submit):
        _trigger(
            fn=generate_text,
            inputs=_generation_inputs,
            outputs=output_text
        )

    # The clear button resets the history and reports into the status box
    clear_btn.click(
        fn=clear_history,
        inputs=[],
        outputs=clear_output
    )
# Launch the app (starts the Gradio server with its default settings)
demo.launch()