import os
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from typing import List, Tuple  # type hints used below
# Option A: read the token from a custom environment variable named "language"
hf_token = os.environ.get("language")
if not hf_token:
    raise EnvironmentError("Environment variable 'language' not found; please add it in the Space settings")

# Option B: use the conventional "HUGGINGFACE_HUB_TOKEN" (update the Space variable accordingly)
# hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
# if not hf_token:
#     raise EnvironmentError("HUGGINGFACE_HUB_TOKEN environment variable not found; please add it in the Space settings")
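# Optional sketch (not required by the code below): registering the token
# process-wide via huggingface_hub lets every from_pretrained call pick it up
# without passing a token argument explicitly.
# from huggingface_hub import login
# login(token=hf_token)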
# Model configuration - public (non-gated) models only
MODELS = {
    "Zephyr 7B Beta": {
        "model_id": "HuggingFaceH4/zephyr-7b-beta",
        "kwargs": {"torch_dtype": torch.float16}
    },
    "Falcon 7B Instruct": {
        "model_id": "tiiuae/falcon-7b-instruct",
        "kwargs": {"torch_dtype": torch.float16, "trust_remote_code": True}
    }
}
# Load a model/tokenizer pair and move the model to the best available device
def load_model(model_name):
    model_config = MODELS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(
        model_config["model_id"],
        token=hf_token  # "token" replaces the deprecated "use_auth_token" argument
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_config["model_id"],
        token=hf_token,
        **model_config["kwargs"]
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return model.to(device), tokenizer, device
# Initialize all models up front
loaded_models = {}
for model_name in MODELS:
    loaded_models[model_name] = load_model(model_name)
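# Note: two 7B models in float16 need roughly 28 GB of memory combined, more
# than most free Space hardware provides (a likely cause of the runtime error).
# A lazy-loading variant (hypothetical sketch, not wired into the code below)
# would defer each load until the model is first requested:
# _model_cache = {}
# def get_model(model_name):
#     if model_name not in _model_cache:
#         _model_cache[model_name] = load_model(model_name)
#     return _model_cache[model_name]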
# Build the conversation prompt in each model's expected format
def build_prompt(message: str, history: List[Tuple[str, str]], system_prompt: str, model_name: str) -> str:
    if "Zephyr" in model_name:
        prompt = f"System: {system_prompt}\n"
        for user_msg, assistant_msg in history:
            prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
        prompt += f"User: {message}\nAssistant:"
    elif "Falcon" in model_name:
        prompt = f"### System:\n{system_prompt}\n\n"
        for user_msg, assistant_msg in history:
            prompt += f"### User:\n{user_msg}\n\n### Assistant:\n{assistant_msg}\n\n"
        prompt += f"### User:\n{message}\n\n### Assistant:"
    else:
        prompt = f"[System] {system_prompt}\n"
        for user_msg, assistant_msg in history:
            prompt += f"[User] {user_msg}\n[Assistant] {assistant_msg}\n"
        prompt += f"[User] {message}\n[Assistant]"
    return prompt
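# Alternative sketch (assumes the tokenizer ships a chat template, as
# zephyr-7b-beta does): tokenizer.apply_chat_template renders the model's own
# special-token format, which is usually more faithful than hand-rolled
# templates like the ones above.
# def build_prompt_from_template(message, history, system_prompt, tokenizer):
#     messages = [{"role": "system", "content": system_prompt}]
#     for user_msg, assistant_msg in history:
#         messages.append({"role": "user", "content": user_msg})
#         messages.append({"role": "assistant", "content": assistant_msg})
#     messages.append({"role": "user", "content": message})
#     return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)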
# Model inference
def generate_response(
    message: str,
    history: List[Tuple[str, str]],
    system_prompt: str,
    model_name: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int
) -> str:
    model, tokenizer, device = loaded_models[model_name]
    full_prompt = build_prompt(message, history, system_prompt, model_name)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id or tokenizer.unk_token_id,
        "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id
    }
    with torch.no_grad():
        output = model.generate(**inputs, **generate_kwargs)
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(full_prompt) is unreliable because decoding may not round-trip the prompt
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
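# Sketch of a streaming variant (hypothetical, not used by the UI below):
# transformers' TextIteratorStreamer yields text chunks as they are generated,
# so partial output can be shown instead of waiting for the full completion.
# from threading import Thread
# from transformers import TextIteratorStreamer
# def generate_stream(model, tokenizer, inputs, generate_kwargs):
#     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#     thread = Thread(target=model.generate, kwargs={**inputs, **generate_kwargs, "streamer": streamer})
#     thread.start()
#     for text_chunk in streamer:
#         yield text_chunk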
# Handle a user turn: generate a reply, append it to the history, and clear the input box
def process_chat(
    message: str,
    history: List[Tuple[str, str]],
    system_prompt: str,
    model_name: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int
) -> Tuple[List[Tuple[str, str]], str]:
    history = history or []  # the Chatbot value is None after a clear
    response = generate_response(message, history, system_prompt, model_name, max_new_tokens, temperature, top_p, top_k)
    history.append((message, response))
    return history, ""
# Speech-to-text (optional; only loaded when an accelerator is available)
asr = None
if torch.cuda.is_available() or (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()):
    try:
        processor = WhisperProcessor.from_pretrained("openai/whisper-base")
        asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to("cuda" if torch.cuda.is_available() else "cpu")
        asr = {"processor": processor, "model": asr_model}
    except Exception as e:
        print(f"Failed to load the speech model: {e}")
        asr = None
def transcribe(audio) -> str:
    if asr is None:
        return "Speech recognition model not loaded"
    if audio is None:
        return ""
    import librosa  # assumes librosa is listed in requirements.txt
    processor, model = asr["processor"], asr["model"]
    # gr.Audio(type="filepath") passes a path, so decode and resample to the
    # 16 kHz mono waveform Whisper expects before calling the processor
    speech, _ = librosa.load(audio, sr=16000)
    input_features = processor(speech, sampling_rate=16000, return_tensors="pt").input_features.to(model.device)
    predicted_ids = model.generate(input_features)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
# Build the Gradio interface
with gr.Blocks(title="Open-Access Language Model Chat Assistant") as demo:
    gr.Markdown("## Chat app for public language models (no gated access required)")
    with gr.Row():
        with gr.Column(scale=1):
            message_input = gr.Textbox(label="Message")
            system_prompt = gr.Textbox(
                label="System prompt",
                value="You are a helpful, knowledgeable AI assistant.",
            )
            model_choice = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Language model"
            )
            with gr.Accordion("Generation parameters", open=False):
                max_new_tokens = gr.Slider(minimum=1, maximum=2048, value=512, label="Max new tokens")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, label="Temperature")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, label="Top-p sampling")
                top_k = gr.Slider(minimum=1, maximum=100, value=50, label="Top-k sampling")
            use_voice = gr.Checkbox(label="Use voice input")
            audio_input = gr.Audio(type="filepath", label="Voice input")
            send_btn = gr.Button("Send", variant="primary")
            clear_btn = gr.Button("Clear conversation")
        with gr.Column(scale=2):
            chat_history = gr.Chatbot(label="Conversation history")
    # Voice input handling
    audio_input.change(
        fn=lambda audio, use: transcribe(audio) if use else "",
        inputs=[audio_input, use_voice],
        outputs=message_input
    )
    # Send a message (the second output clears the input box)
    send_btn.click(
        fn=process_chat,
        inputs=[message_input, chat_history, system_prompt, model_choice, max_new_tokens, temperature, top_p, top_k],
        outputs=[chat_history, message_input]
    )
    # Clear the conversation
    clear_btn.click(fn=lambda: None, outputs=chat_history)
# Launch the app ("share" is unnecessary on Spaces, which handle hosting themselves)
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
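# Hypothetical requirements.txt for this Space, inferred from the imports above
# (package names only; pin versions as needed):
# torch
# transformers
# gradio
# librosa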