hellokawei committed on
Commit e27dbdf · verified · 1 Parent(s): 17a3c9b

Update app.py

Files changed (1)
  1. app.py +64 -111
app.py CHANGED
@@ -4,22 +4,22 @@ import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 from typing import List, Tuple, Dict
 
-# To use a Hugging Face access token, uncomment the two lines below and set the environment variable
-# from huggingface_hub import login
-# login(token=os.getenv("HUGGINGFACE_TOKEN"))
-
-# Model configuration - add more models as needed
+# Model configuration - all public models with no gated access
 MODELS = {
-    "Llama 2 7B Chat": {
-        "model_id": "meta-llama/Llama-2-7b-chat-hf",
+    "Zephyr 7B Beta": {
+        "model_id": "HuggingFaceH4/zephyr-7b-beta",
         "kwargs": {"torch_dtype": torch.float16}
     },
     "Mistral 7B Instruct": {
         "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
         "kwargs": {"torch_dtype": torch.float16}
     },
-    "Zephyr 7B Beta": {
-        "model_id": "HuggingFaceH4/zephyr-7b-beta",
+    "OpenHermes 2.5": {
+        "model_id": "teknium/OpenHermes-2.5-Mistral-7B",
+        "kwargs": {"torch_dtype": torch.float16}
+    },
+    "Falcon 7B Instruct": {
+        "model_id": "tiiuae/falcon-7b-instruct",
         "kwargs": {"torch_dtype": torch.float16}
     }
 }
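Note: the dropped meta-llama/Llama-2-7b-chat-hf repo sits behind a gated license, which is what the new comment ("all public models with no gated access") refers to; every replacement can be downloaded anonymously. A quick way to confirm a repo is not gated, as a sketch (assuming huggingface_hub is installed and that its `ModelInfo` exposes a `gated` field, as recent versions do):

    from huggingface_hub import model_info

    for name, cfg in MODELS.items():
        # `gated` is False for public repos, "auto"/"manual" for gated ones
        print(name, "gated:", model_info(cfg["model_id"]).gated)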
@@ -29,8 +29,8 @@ def load_model(model_name):
     model_config = MODELS[model_name]
     tokenizer = AutoTokenizer.from_pretrained(model_config["model_id"])
 
-    # Check whether the model needs special handling
-    if "Llama-2" in model_name:
+    # Handle model-specific kwargs (where needed)
+    if "Falcon" in model_name:
         model_config["kwargs"]["trust_remote_code"] = True
 
     model = AutoModelForCausalLM.from_pretrained(
@@ -38,7 +38,7 @@ def load_model(model_name):
         **model_config["kwargs"]
     )
 
-    # Move the model to the available device
+    # Move to the available device
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = model.to(device)
 
@@ -49,18 +49,31 @@ loaded_models = {}
 for model_name in MODELS:
     loaded_models[model_name] = load_model(model_name)
 
-# Build the chat prompt (for models like Llama 2 that need a specific format)
-def build_prompt(message, history, system_prompt):
-    prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n"
-
-    # Append the conversation history
-    for user_msg, assistant_msg in history:
-        prompt += f"{user_msg} [/INST] {assistant_msg} [INST] "
-
-    # Append the current user message
-    prompt += f"{message} [/INST]"
-
-    return prompt
+# Build the chat prompt (different models may need different formats)
+def build_prompt(message, history, system_prompt, model_name):
+    # Zephyr/Mistral-style models use a simple format
+    if "Zephyr" in model_name or "Mistral" in model_name:
+        prompt = f"系统提示: {system_prompt}\n"
+        for user_msg, assistant_msg in history:
+            prompt += f"用户: {user_msg}\n助手: {assistant_msg}\n"
+        prompt += f"用户: {message}\n助手:"
+        return prompt
+
+    # Falcon models use a more compact format
+    elif "Falcon" in model_name:
+        prompt = f"### System:\n{system_prompt}\n\n"
+        for user_msg, assistant_msg in history:
+            prompt += f"### User:\n{user_msg}\n\n### Assistant:\n{assistant_msg}\n\n"
+        prompt += f"### User:\n{message}\n\n### Assistant:"
+        return prompt
+
+    # Fall back to a generic format
+    else:
+        prompt = f"[System] {system_prompt}\n"
+        for user_msg, assistant_msg in history:
+            prompt += f"[User] {user_msg}\n[Assistant] {assistant_msg}\n"
+        prompt += f"[User] {message}\n[Assistant]"
+        return prompt
 
 # Model inference function
 def generate_response(
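Note: hand-rolled prompt strings like the ones above tend to drift from the formats the models were actually trained on (Zephyr expects `<|system|>`/`<|user|>` tokens, Mistral-Instruct expects `[INST]` markers). A more robust alternative, sketched here rather than taken from the commit, is to let the tokenizer apply its bundled chat template (supported by recent transformers releases; this assumes the model's template accepts a system role, which Mistral's, for one, does not):

    def build_prompt_via_template(message, history, system_prompt, tokenizer):
        # Collect the conversation as role/content dicts.
        messages = [{"role": "system", "content": system_prompt}]
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            messages.append({"role": "assistant", "content": assistant_msg})
        messages.append({"role": "user", "content": message})
        # tokenize=False returns the formatted string; add_generation_prompt
        # appends the assistant marker so the model starts its reply.
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )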
@@ -73,11 +86,10 @@ def generate_response(
     top_p: float,
     top_k: int
 ):
-    # Fetch the model, tokenizer, and device
     model, tokenizer, device = loaded_models[model_name]
 
-    # Build the full prompt
-    full_prompt = build_prompt(message, history, system_prompt)
+    # Build the prompt
+    full_prompt = build_prompt(message, history, system_prompt, model_name)
 
     # Encode the input
     inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
@@ -89,8 +101,8 @@ def generate_response(
         "top_p": top_p,
         "top_k": top_k,
         "do_sample": True,
-        "eos_token_id": tokenizer.eos_token_id,
-        "pad_token_id": tokenizer.pad_token_id
+        "eos_token_id": tokenizer.eos_token_id or tokenizer.unk_token_id,
+        "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id
     }
 
     # Generate the response
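Note: the `or` fallbacks added above guard against tokenizers that ship without a pad token (Mistral's, for instance). An equivalent one-time fix, as a sketch that would live in load_model rather than in this commit:

    # Give the tokenizer a pad token once at load time instead of
    # falling back on every generate() call.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token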
@@ -103,12 +115,12 @@ def generate_response(
     # Decode the output
     response = tokenizer.decode(output[0], skip_special_tokens=True)
 
-    # Extract the generated part (strip the prompt)
+    # Extract the generated part
     response = response[len(full_prompt):].strip()
 
     return response
 
-# Process user input and generate a reply
+# Process user input
 def process_chat(
     message: str,
     history: List[Tuple[str, str]],
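Note: the `response[len(full_prompt):]` line kept by this hunk assumes the decoded output reproduces the prompt character-for-character, which tokenization round-trips do not guarantee. A safer variant, sketched with the names already in this file, slices at the token level before decoding:

    # Decode only the newly generated tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True).strip()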
@@ -119,138 +131,79 @@ def process_chat(
     top_p: float,
     top_k: int
 ):
-    # Generate a response
     response = generate_response(
         message, history, system_prompt, model_name,
         max_new_tokens, temperature, top_p, top_k
     )
-
-    # Update the conversation history
     history.append((message, response))
     return history, history
 
-# Speech-to-text (using the Whisper model)
+# Speech-to-text
 asr = None
 if torch.cuda.is_available() or torch.backends.mps.is_available():
     try:
         from transformers import WhisperProcessor, WhisperForConditionalGeneration
         processor = WhisperProcessor.from_pretrained("openai/whisper-base")
         asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to("cuda" if torch.cuda.is_available() else "cpu")
-        asr = {
-            "processor": processor,
-            "model": asr_model
-        }
-    except Exception as e:
-        print(f"语音识别模型加载失败: {e}")
+        asr = {"processor": processor, "model": asr_model}
+    except:
         asr = None
 
 def transcribe(audio):
     if asr is None:
         return "语音识别模型未加载"
-
     processor, model = asr["processor"], asr["model"]
     input_features = processor(audio, return_tensors="pt").input_features.to(model.device)
     predicted_ids = model.generate(input_features)
-    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-    return transcription
+    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
 # Build the Gradio UI
-with gr.Blocks(title="语言模型对话助手") as demo:
-    gr.Markdown("## 基于Transformer的语言模型对话应用")
+with gr.Blocks(title="无权限语言模型对话助手") as demo:
+    gr.Markdown("## 公开语言模型对话应用(无需访问权限)")
 
     with gr.Row():
         with gr.Column(scale=1):
-            # Input area
-            message_input = gr.Textbox(
-                label="输入消息",
-                placeholder="请输入您想与AI对话的内容..."
-            )
-
-            # System prompt
+            message_input = gr.Textbox(label="输入消息")
             system_prompt = gr.Textbox(
                 label="系统提示词",
                 value="你是一个 helpful、知识渊博的AI助手。",
-                placeholder="设置AI的角色和行为准则..."
             )
-
-            # Model selection
             model_choice = gr.Dropdown(
                 choices=list(MODELS.keys()),
                 value=list(MODELS.keys())[0],
                 label="选择语言模型"
            )
-
-            # Generation parameters
-            with gr.Accordion("高级生成参数", open=False):
-                max_new_tokens = gr.Slider(
-                    minimum=1, maximum=2048, value=512, step=1,
-                    label="最大生成Token数"
-                )
-                temperature = gr.Slider(
-                    minimum=0.1, maximum=2.0, value=0.7, step=0.1,
-                    label="温度(随机性)"
-                )
-                top_p = gr.Slider(
-                    minimum=0.1, maximum=1.0, value=0.9, step=0.05,
-                    label="Top-p(核采样)"
-                )
-                top_k = gr.Slider(
-                    minimum=1, maximum=100, value=50, step=1,
-                    label="Top-k(采样数)"
-                )
-
-            # Voice input
+            with gr.Accordion("生成参数", open=False):
+                max_new_tokens = gr.Slider(minimum=1, maximum=2048, value=512, label="最大Token数")
+                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="随机性")
+                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p采样")
+                top_k = gr.Slider(minimum=1, maximum=100, value=50, label="Top-k采样")
             use_voice = gr.Checkbox(label="使用语音输入")
-            audio_input = gr.Audio(
-                type="filepath",
-                label="语音输入(录制或上传音频)"
-            )
-
-            # Buttons
+            audio_input = gr.Audio(type="filepath", label="语音输入")
             send_btn = gr.Button("发送消息", variant="primary")
             clear_btn = gr.Button("清空对话")
 
         with gr.Column(scale=2):
-            # Conversation history
-            chat_history = gr.Chatbot(
-                label="对话历史",
-                show_label=True
-            )
+            chat_history = gr.Chatbot(label="对话历史")
 
     # Voice input handling
-    def handle_voice(audio, use_voice):
-        if use_voice and audio:
-            return transcribe(audio)
-        return ""
-
     audio_input.change(
-        fn=handle_voice,
+        fn=lambda audio, use: transcribe(audio) if use else "",
         inputs=[audio_input, use_voice],
         outputs=message_input
     )
 
-    # Send-message handling
+    # Send a message
    send_btn.click(
         fn=process_chat,
-        inputs=[
-            message_input, chat_history, system_prompt, model_choice,
-            max_new_tokens, temperature, top_p, top_k
-        ],
-        outputs=[chat_history, chat_history],
-        show_progress=True
+        inputs=[message_input, chat_history, system_prompt, model_choice,
+                max_new_tokens, temperature, top_p, top_k],
+        outputs=[chat_history, chat_history]
     )
 
     # Clear the conversation
-    clear_btn.click(
-        fn=lambda: None,
-        inputs=None,
-        outputs=chat_history
-    )
+    clear_btn.click(fn=lambda: None, outputs=chat_history)
 
 # Launch the app
 if __name__ == "__main__":
-    demo.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=True
-    )
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
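Note: the startup loop still loads every configured model eagerly, which at float16 is roughly 14 GB per 7B model (about 56 GB for all four) and will not fit typical Space hardware. A lazy-loading variant, sketched here rather than part of the commit, keeps at most one model resident:

    loaded_models = {}

    def get_model(model_name):
        # Load on first use; evict whatever was loaded before.
        if model_name not in loaded_models:
            loaded_models.clear()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            loaded_models[model_name] = load_model(model_name)
        return loaded_models[model_name]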