cmz1024 committed on
Commit
436ce91
·
verified ·
1 Parent(s): 9cc318e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -31
app.py CHANGED
@@ -1,22 +1,56 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
  import torch
 
 
 
 
4
 
5
- # 加载模型和tokenizer
6
- model_name = "cmz1024/minimind-zero" # 替换为你的模型路径
7
- tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
9
- model.eval()
10
 
11
- if torch.cuda.is_available():
12
- model = model.cuda()
 
 
 
 
 
 
 
13
 
14
- def generate_text(prompt, max_length=512, temperature=0.7, top_p=0.9):
15
- # 对输入进行编码
16
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- if torch.cuda.is_available():
19
- inputs = {k: v.cuda() for k, v in inputs.items()}
 
20
 
21
  # 生成文本
22
  with torch.no_grad():
@@ -26,26 +60,62 @@ def generate_text(prompt, max_length=512, temperature=0.7, top_p=0.9):
26
  temperature=temperature,
27
  top_p=top_p,
28
  pad_token_id=tokenizer.pad_token_id,
29
- eos_token_id=tokenizer.eos_token_id,
30
  )
31
 
32
- # 解码输出
33
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
 
 
 
34
  return generated_text
35
 
 
 
 
 
 
 
 
 
 
36
  # 创建Gradio界面
37
- demo = gr.Interface(
38
- fn=generate_text,
39
- inputs=[
40
- gr.Textbox(label="输入提示词", lines=3),
41
- gr.Slider(minimum=1, maximum=1024, value=512, label="最大生成长度"),
42
- gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature"),
43
- gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p"),
44
- ],
45
- outputs=gr.Textbox(label="生成结果", lines=10),
46
- title="MiniMind 文本生成",
47
- description="一个简单的文本生成demo"
48
- )
49
-
50
- if __name__ == "__main__":
51
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
 
2
  import torch
3
+ import warnings
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ import random
6
+ import numpy as np
7
 
8
+ warnings.filterwarnings('ignore')
 
 
 
 
9
 
# 设置可复现的随机种子 (seed every RNG for reproducible sampling)
def setup_seed(seed):
    """Seed all RNGs used by generation (random, numpy, torch).

    Also forces deterministic cuDNN behavior. CUDA seeding is applied only
    when a GPU is present — the rest of this app already branches on
    ``torch.cuda.is_available()``, so this keeps CPU-only hosts (e.g. free
    HF Spaces) from touching the CUDA stack at all.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade a little speed for run-to-run determinism in conv kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
19
 
# Load tokenizer and model weights (fetched from the HF Hub on first run).
model_path = "cmz1024/minimind-zero"  # 替换为你的模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

# Choose the compute device and put the model into inference mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

# Report trainable-parameter count in millions.
_n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'MiniMind模型参数量: {_n_params / 1e6:.2f}M(illion)')
30
+
31
+ # 生成文本函数
32
+ def generate_text(prompt, max_length=512, temperature=0.85, top_p=0.85, history_cnt=0):
33
+ # 设置随机种子
34
+ setup_seed(random.randint(0, 2048))
35
+
36
+ # 处理历史对话
37
+ messages = []
38
+ if history_cnt > 0 and 'chat_history' in globals():
39
+ messages = chat_history[-history_cnt:] if len(chat_history) > 0 else []
40
+
41
+ # 添加当前用户输入
42
+ messages.append({"role": "user", "content": prompt})
43
+
44
+ # 应用聊天模板
45
+ new_prompt = tokenizer.apply_chat_template(
46
+ messages,
47
+ tokenize=False,
48
+ add_generation_prompt=True
49
+ )
50
 
51
+ # 对输入进行编码
52
+ inputs = tokenizer(new_prompt, return_tensors="pt").to(device)
53
+ input_length = inputs["input_ids"].shape[1]
54
 
55
  # 生成文本
56
  with torch.no_grad():
 
60
  temperature=temperature,
61
  top_p=top_p,
62
  pad_token_id=tokenizer.pad_token_id,
63
+ eos_token_id=tokenizer.eos_token_id
64
  )
65
 
66
+ # 解码新生成的部分
67
+ generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
68
+
69
+ # 更新对话历史
70
+ if 'chat_history' in globals():
71
+ chat_history.append({"role": "user", "content": prompt})
72
+ chat_history.append({"role": "assistant", "content": generated_text})
73
+
74
  return generated_text
75
 
# Global in-memory conversation history shared by the Gradio handlers.
# Entries are {"role": ..., "content": ...} dicts in chronological order.
chat_history = []
78
+
# 清除对话历史的函数
def clear_history():
    """Drop the accumulated conversation and report success to the UI."""
    global chat_history
    chat_history = []
    status = "对话历史已清除"
    return status
84
+
# 创建Gradio界面 — two-column Blocks layout: input + controls on the
# left, generated output and status on the right.
with gr.Blocks() as demo:
    gr.Markdown("# MiniMind 模型演示")

    with gr.Row():
        # Left column: prompt box, action buttons, sampling controls.
        with gr.Column():
            input_text = gr.Textbox(label="输入", placeholder="请输入您的问题...", lines=5)

            with gr.Row():
                submit_btn = gr.Button("提交")
                clear_btn = gr.Button("清除历史")

            with gr.Accordion("高级选项", open=False):
                max_length = gr.Slider(minimum=10, maximum=2048, value=512, step=1, label="最大生成长度")
                temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.85, step=0.01, label="温度")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.85, step=0.01, label="Top-p")
                history_cnt = gr.Slider(minimum=0, maximum=10, value=0, step=2, label="历史对话轮数")

        # Right column: model output and the clear-status line.
        with gr.Column():
            output_text = gr.Textbox(label="输出", lines=25)
            clear_output = gr.Textbox(label="状态", visible=True)

    # Wire the buttons to their handlers.
    submit_btn.click(
        fn=generate_text,
        inputs=[input_text, max_length, temperature, top_p, history_cnt],
        outputs=output_text,
    )
    clear_btn.click(fn=clear_history, inputs=[], outputs=clear_output)

# 启动应用
demo.launch()