hellokawei committed on
Commit 17a3c9b · verified · 1 Parent(s): f2d0fc3

Update app.py

Files changed (1)
  1. app.py +241 -127
app.py CHANGED
@@ -1,142 +1,256 @@
 
 
  import gradio as gr
- from huggingface_hub import InferenceClient, HfHubHTTPError, InferenceTimeoutError
- import httpx  # make sure this library is listed in requirements.txt
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs:
- https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
-
- # Define several models and their corresponding InferenceClients.
- # !!IMPORTANT!! If needed, replace 'hf_YOUR_TOKEN_HERE' with your Hugging Face API token.
- # API tokens can be generated at https://huggingface.co/settings/tokens.
- # Using an API token helps with access limits, especially for popular models.
- # Also make sure you have accepted the terms on the Hugging Face site for gated models such as Mistral 7B Instruct v0.2.
- MODEL_CLIENTS = {
-     "Zephyr 7B Beta": InferenceClient("HuggingFaceH4/zephyr-7b-beta"),  # uses the default API token
-     "Mistral 7B Instruct v0.2": InferenceClient("mistralai/Mistral-7B-Instruct-v0.2"),  # uses the default API token
-     # To pass an API token explicitly, write it like this:
-     # "Zephyr 7B Beta": InferenceClient("HuggingFaceH4/zephyr-7b-beta", token="hf_YOUR_TOKEN_HERE"),
-     # "Mistral 7B Instruct v0.2": InferenceClient("mistralai/Mistral-7B-Instruct-v0.2", token="hf_YOUR_TOKEN_HERE"),
-     # More model examples (add or remove as needed):
-     # "Llama 2 7B Chat": InferenceClient("meta-llama/Llama-2-7b-chat-hf", token="hf_YOUR_TOKEN_HERE"),  # Llama 2 is gated and usually needs a token
-     # "OpenHermes 2.5 Mistral 7B": InferenceClient("teknium/OpenHermes-2.5-Mistral-7B"),
  }
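Editorial note on the removed setup: hard-coding 'hf_YOUR_TOKEN_HERE' risks leaking a credential. A safer pattern, assuming a Space secret exposed as an environment variable named HF_TOKEN (the name is an assumption, not from the commit), is to read the token at runtime:

import os
from huggingface_hub import InferenceClient

# Read the token from a Space secret / environment variable instead of hard-coding it
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=os.getenv("HF_TOKEN"))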
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     # New parameter used to select the model
-     selected_model_name,
- ):
-     # Look up the client for the selected model name
-     client = MODEL_CLIENTS.get(selected_model_name)
-     if not client:
-         # Invalid model name: return an error message directly
-         yield "Error: no client found for the selected model. Check that the model name is correct and has been added to the list."
-         return
-
-     messages = [{"role": "system", "content": system_message}]
-
-     # Build the full conversation history
-     for val in history:
-         if val[0]:  # user message
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:  # assistant message
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})  # append the current user message
-
-     response = ""
-     try:
-         # Run inference with the selected client.
-         # client.chat_completion() returns a generator when streaming.
-         for message_chunk in client.chat_completion(
-             messages,
-             max_tokens=max_tokens,
-             stream=True,  # enable streaming
-             temperature=temperature,
-             top_p=top_p,
-         ):
-             # Make sure the chunk and its content exist, in case the API response is malformed
-             if message_chunk.choices and message_chunk.choices[0].delta and message_chunk.choices[0].delta.content is not None:
-                 token = message_chunk.choices[0].delta.content
-                 response += token
-                 yield response  # yield the text generated so far
-             else:
-                 # End of stream, or an empty content chunk
-                 pass
-
-     # Error handling: catch the various exceptions that can occur
-     except HfHubHTTPError as e:
-         error_message = ""
-         if e.response.status_code == 402:
-             error_message = "Sorry, this model requires paid access or your Hugging Face account quota is exhausted. Check your account settings or the Space logs."
-         elif e.response.status_code == 429:
-             error_message = "Sorry, too many requests triggered the rate limit. Try again later, or use an API token for a higher quota."
-         elif e.response.status_code == 401 or e.response.status_code == 403:
-             error_message = "Sorry, insufficient model access or an invalid/missing API token. Make sure you are logged in on Hugging Face, have accepted the model terms, and have configured the token correctly."
-         elif e.response.status_code == 503:
-             error_message = "The model service is currently unavailable; it may be loading or under maintenance. Please try again later."
-         else:
-             error_message = f"The model service returned an HTTP error ({e.response.status_code}): {e.response.text}. Check the Hugging Face Space logs."
-         print(f"HfHubHTTPError: {e}")  # log to the console for debugging
-         yield error_message  # surface the error to the user
-
-     except InferenceTimeoutError as e:
-         error_message = "The model timed out; the request may be too complex or the server busy. Try lowering 'Max new tokens' or retry later."
-         print(f"InferenceTimeoutError: {e}")
-         yield error_message
-
-     except httpx.HTTPStatusError as e:
-         # An HTTP error raised by the httpx library, which can occur outside of HfHubHTTPError
-         error_message = f"HTTP error while communicating with the Hugging Face service ({e.response.status_code}): {e.response.text}."
-         print(f"HTTPStatusError: {e}")
-         yield error_message
      except Exception as e:
-         # Catch any other unexpected errors
-         error_message = f"Unexpected error: {type(e).__name__} - {e}. See the Hugging Face Space logs for details."
-         print(f"General Error: {e}")
-         yield error_message
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs:
- https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly AI assistant; do your best to help.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-         # New dropdown for selecting the model
-         gr.Dropdown(
-             list(MODEL_CLIENTS.keys()),  # options are the MODEL_CLIENTS keys (model names)
-             value=list(MODEL_CLIENTS.keys())[0],  # default to the first model
-             label="Select Model",
-             interactive=True,  # let the user change it
-         ),
-     ],
-     title="Multi-Model AI Chat Assistant",  # give the interface a title
-     description="Pick a language model and start chatting. You can tune the parameters or switch models to compare.",  # add a description
-     submit_btn="Send",
-     stop_btn="Stop",
-     clear_btn="Clear chat",
- )
  if __name__ == "__main__":
-     demo.launch()
+ import os
+ import torch
  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from typing import List, Tuple
+
+ # To use a Hugging Face access token, uncomment the two lines below and set the environment variable
+ # from huggingface_hub import login
+ # login(token=os.getenv("HUGGINGFACE_TOKEN"))
+
+ # Model configuration - add more models as needed
+ MODELS = {
+     "Llama 2 7B Chat": {
+         "model_id": "meta-llama/Llama-2-7b-chat-hf",
+         "kwargs": {"torch_dtype": torch.float16}
+     },
+     "Mistral 7B Instruct": {
+         "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
+         "kwargs": {"torch_dtype": torch.float16}
+     },
+     "Zephyr 7B Beta": {
+         "model_id": "HuggingFaceH4/zephyr-7b-beta",
+         "kwargs": {"torch_dtype": torch.float16}
+     }
  }
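Editorial note: three fp16 7B checkpoints are roughly 14 GB apiece, so this table as written assumes a lot of GPU memory. One common mitigation, offered as a sketch rather than part of this commit, is 4-bit quantization via the bitsandbytes integration in transformers:

import torch
from transformers import BitsAndBytesConfig

# Sketch: a hypothetical 4-bit entry for the MODELS table (requires the bitsandbytes package)
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
MODELS["Mistral 7B Instruct (4-bit)"] = {
    "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
    "kwargs": {"quantization_config": quant_config},
}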
+ # Load a model and its tokenizer
+ def load_model(model_name):
+     model_config = MODELS[model_name]
+     tokenizer = AutoTokenizer.from_pretrained(model_config["model_id"])
+
+     # Check whether the model needs special handling.
+     # Test the model_id, not the display name: the key is "Llama 2 7B Chat", without the hyphen.
+     if "Llama-2" in model_config["model_id"]:
+         model_config["kwargs"]["trust_remote_code"] = True
+
+     model = AutoModelForCausalLM.from_pretrained(
+         model_config["model_id"],
+         **model_config["kwargs"]
+     )
+
+     # Move the model to the available device
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = model.to(device)
+
+     return model, tokenizer, device
+
+ # Initialize the models
+ loaded_models = {}
+ for model_name in MODELS:
+     loaded_models[model_name] = load_model(model_name)
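Editorial note: eagerly loading every model at startup multiplies memory use and slows boot. A minimal alternative sketch, not part of the commit, is to load on first request and cache (the cache name is illustrative):

# Sketch: lazy-load models on first use instead of all at startup.
# Assumes load_model() as defined above.
_model_cache = {}

def get_model(model_name):
    if model_name not in _model_cache:
        _model_cache[model_name] = load_model(model_name)
    return _model_cache[model_name]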
 
+ # Build the conversation prompt (for models such as Llama 2 that need a specific format)
+ def build_prompt(message, history, system_prompt):
+     prompt = f"[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
+
+     # Add the conversation history; Llama 2 closes each turn with </s> and reopens with <s>[INST]
+     for user_msg, assistant_msg in history:
+         prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
+
+     # Add the current user message
+     prompt += f"{message} [/INST]"
+
+     return prompt
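For reference, with a one-turn history the builder above produces a string like the following (inputs illustrative):

# build_prompt("How far is the Moon?", [("Hi", "Hello!")], "Be brief.") returns roughly:
# [INST] <<SYS>>
# Be brief.
# <</SYS>>
#
# Hi [/INST] Hello! </s><s>[INST] How far is the Moon? [/INST]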
 
+ # Model inference function
+ def generate_response(
+     message: str,
+     history: List[Tuple[str, str]],
+     system_prompt: str,
+     model_name: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     top_k: int
+ ):
+     # Get the model, tokenizer and device
+     model, tokenizer, device = loaded_models[model_name]
+
+     # Build the full prompt
+     full_prompt = build_prompt(message, history, system_prompt)
+
+     # Encode the input
+     inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
+
+     # Generation parameters
+     generate_kwargs = {
+         "max_new_tokens": max_new_tokens,
+         "temperature": temperature,
+         "top_p": top_p,
+         "top_k": top_k,
+         "do_sample": True,
+         "eos_token_id": tokenizer.eos_token_id,
+         # these tokenizers may not define a pad token, so fall back to EOS
+         "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id
+     }
+
+     # Generate the response
+     with torch.no_grad():
+         output = model.generate(
+             **inputs,
+             **generate_kwargs
+         )
+
+     # Decode only the newly generated tokens; slicing the decoded string by
+     # len(full_prompt) is unreliable once special tokens have been stripped
+     new_tokens = output[0][inputs["input_ids"].shape[1]:]
+     response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
+
+     return response
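Editorial note: the removed InferenceClient version streamed tokens to the UI, while this rewrite returns the reply in one piece. If streaming matters, transformers' TextIteratorStreamer can restore it. A minimal sketch under the same model/tokenizer assumptions (the function name is illustrative):

from threading import Thread
from transformers import TextIteratorStreamer

def generate_stream(model, tokenizer, full_prompt, **gen_kwargs):
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a thread and read chunks from the streamer
    thread = Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, **gen_kwargs})
    thread.start()
    partial = ""
    for text in streamer:
        partial += text
        yield partial  # yield the text generated so far, as the old respond() did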
 
+ # Handle user input and generate a reply
+ def process_chat(
+     message: str,
+     history: List[Tuple[str, str]],
+     system_prompt: str,
+     model_name: str,
+     max_new_tokens: int,
+     temperature: float,
+     top_p: float,
+     top_k: int
+ ):
+     # Generate the response
+     response = generate_response(
+         message, history, system_prompt, model_name,
+         max_new_tokens, temperature, top_p, top_k
+     )
+
+     # Update the chat history; the second return value clears the input box
+     # (returning the same component twice, as outputs=[chat_history, chat_history] did, is an error in Gradio)
+     history.append((message, response))
+     return history, ""
+ # Speech-to-text (using a Whisper model)
+ asr = None
+ if torch.cuda.is_available() or torch.backends.mps.is_available():
+     try:
+         from transformers import WhisperProcessor, WhisperForConditionalGeneration
+         processor = WhisperProcessor.from_pretrained("openai/whisper-base")
+         asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to("cuda" if torch.cuda.is_available() else "cpu")
+         asr = {
+             "processor": processor,
+             "model": asr_model
+         }
      except Exception as e:
+         print(f"Failed to load the speech recognition model: {e}")
+         asr = None
+
+ def transcribe(audio):
+     if asr is None:
+         return "Speech recognition model not loaded"
+
+     # gr.Audio(type="filepath") passes a file path, so load and resample the
+     # audio first; Whisper expects 16 kHz waveforms (librosa must be in requirements.txt)
+     import librosa
+     speech, _ = librosa.load(audio, sr=16000)
+
+     processor, model = asr["processor"], asr["model"]
+     input_features = processor(speech, sampling_rate=16000, return_tensors="pt").input_features.to(model.device)
+     predicted_ids = model.generate(input_features)
+     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+     return transcription
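Editorial note rather than part of the commit: the transformers ASR pipeline accepts a file path directly and handles resampling itself, which would make the transcription path a one-liner. A hedged sketch (function name illustrative):

from transformers import pipeline

# Sketch: pipeline-based transcription; accepts a file path and resamples internally
asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base")

def transcribe_simple(audio_path):
    return asr_pipe(audio_path)["text"]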
 
+ # Build the Gradio interface
+ with gr.Blocks(title="Language Model Chat Assistant") as demo:
+     gr.Markdown("## Transformer-based language model chat app")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Input area
+             message_input = gr.Textbox(
+                 label="Message",
+                 placeholder="Type what you want to say to the AI..."
+             )
+
+             # System prompt
+             system_prompt = gr.Textbox(
+                 label="System prompt",
+                 value="You are a helpful, knowledgeable AI assistant.",
+                 placeholder="Set the AI's role and behavior..."
+             )
+
+             # Model selection
+             model_choice = gr.Dropdown(
+                 choices=list(MODELS.keys()),
+                 value=list(MODELS.keys())[0],
+                 label="Select a language model"
+             )
+
+             # Generation parameters
+             with gr.Accordion("Advanced generation parameters", open=False):
+                 max_new_tokens = gr.Slider(
+                     minimum=1, maximum=2048, value=512, step=1,
+                     label="Max new tokens"
+                 )
+                 temperature = gr.Slider(
+                     minimum=0.1, maximum=2.0, value=0.7, step=0.1,
+                     label="Temperature (randomness)"
+                 )
+                 top_p = gr.Slider(
+                     minimum=0.1, maximum=1.0, value=0.9, step=0.05,
+                     label="Top-p (nucleus sampling)"
+                 )
+                 top_k = gr.Slider(
+                     minimum=1, maximum=100, value=50, step=1,
+                     label="Top-k"
+                 )
+
+             # Voice input
+             use_voice = gr.Checkbox(label="Use voice input")
+             audio_input = gr.Audio(
+                 type="filepath",
+                 label="Voice input (record or upload audio)"
+             )
+
+             # Buttons
+             send_btn = gr.Button("Send", variant="primary")
+             clear_btn = gr.Button("Clear chat")
+
+         with gr.Column(scale=2):
+             # Chat history
+             chat_history = gr.Chatbot(
+                 label="Chat history",
+                 show_label=True
+             )
+
+     # Voice input handling
+     def handle_voice(audio, use_voice):
+         if use_voice and audio:
+             return transcribe(audio)
+         return ""
+
+     audio_input.change(
+         fn=handle_voice,
+         inputs=[audio_input, use_voice],
+         outputs=message_input
+     )
+
+     # Send-message handling; the second output clears the input box,
+     # matching process_chat's (history, "") return value
+     send_btn.click(
+         fn=process_chat,
+         inputs=[
+             message_input, chat_history, system_prompt, model_choice,
+             max_new_tokens, temperature, top_p, top_k
+         ],
+         outputs=[chat_history, message_input],
+         show_progress=True
+     )
+
+     # Clear the chat
+     clear_btn.click(
+         fn=lambda: None,
+         inputs=None,
+         outputs=chat_history
+     )
 
+ # Launch the app
  if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         share=True  # note: share links have no effect inside a Hugging Face Space
+     )
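For completeness, a plausible requirements.txt for the rewritten app; the package list is an assumption (the commit does not include one), and librosa is only needed for the audio path:

# requirements.txt (assumed, not part of the commit)
gradio
torch
transformers
accelerate
sentencepiece
librosa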