Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| # 导入 vllm 相关的库 | |
| from vllm import LLM, SamplingParams | |
| # 保持 AutoTokenizer 用于处理聊天模板 | |
| from transformers import AutoTokenizer | |
| from modelscope.hub.snapshot_download import snapshot_download | |
| import os | |
| import traceback | |
| # --- 关键配置 --- | |
| # !!! (重要) 请在此处填入您在 ModelScope 上的 "已合并" 模型的 ID | |
| # vLLM 将直接加载这个完整的模型 | |
| # | |
| # 我使用了您原始的基础模型 ID 作为示例,但您*必须*将其替换为 | |
| # 已经合并了 LoRA (UniVectorSQL) 的那个模型的 ID。 | |
| MERGED_MODEL_ID = "zrwang/UniVectorSQL-7B-LoRA-merged" | |
| # 示例:如果您的合并后模型叫 "risemds/UniVectorSQL-7B-Merged",请修改上面一行 | |
| # ----------------- | |
| # --- vLLM 模型加载 --- | |
| # 1. 下载模型 | |
| # 我们将忽略 .md 文件和 training_args.bin 文件 | |
| print(f"开始下载已合并的模型: {MERGED_MODEL_ID}") | |
| model_dir = snapshot_download( | |
| MERGED_MODEL_ID, | |
| revision='master', | |
| ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"] | |
| ) | |
| print(f"模型下载完成,路径: {model_dir}") | |
| # 2. 加载 Tokenizer | |
| # 我们需要 Tokenizer 来应用聊天模板 | |
| try: | |
| tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) | |
| print("从模型目录加载 Tokenizer 成功。") | |
| except Exception as e: | |
| print(f"从模型目录加载 Tokenizer 失败: {e}") | |
| # 如果失败,程序无法继续,因为 vLLM 需要知道 EOT token | |
| raise e | |
| # --- (保留) 修复 Tokenizer 的 pad_token_id --- | |
| # Qwen1.5 基础模型可能没有设置 pad_token_id | |
| if tokenizer.pad_token_id is None: | |
| print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。") | |
| tokenizer.pad_token_id = tokenizer.eos_token_id | |
| # (注意:我们不再需要修复 model.config,因为 vLLM 会处理) | |
| # ---------------------------------------- | |
| # 3. 加载 vLLM 模型 | |
| print(f"开始加载 vLLM 模型: {model_dir}") | |
| # 自动检测 GPU 数量以设置 tensor_parallel_size | |
| if torch.cuda.is_available(): | |
| gpu_count = torch.cuda.device_count() | |
| print(f"检测到 {gpu_count} 个 CUDA GPU。") | |
| else: | |
| print("警告: 未检测到 CUDA GPU。vLLM 强烈建议在 GPU 上运行。") | |
| gpu_count = 1 # 假设至少为 1,vLLM 0.4.0+ 也支持 CPU (但很慢) | |
| llm = LLM( | |
| model=model_dir, | |
| trust_remote_code=True, | |
| tensor_parallel_size=gpu_count, # 自动使用所有可用的 GPU | |
| dtype="auto" # vLLM 会自动选择 (例如 bfloat16 或 float16) | |
| # max_model_len=4096 # (可选) 如果需要,设置最大上下文长度 | |
| ) | |
| print("vLLM 模型加载完成。") | |
| # 4. 定义推理函数 (*** 已修改为使用 vLLM ***) | |
| # --- 提前准备 EOT token 和 SamplingParams --- | |
| # 1. 获取 EOT (End of Text) token IDs (用于停止生成) | |
| # Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643) | |
| eot_ids = [tokenizer.eos_token_id] | |
| im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") | |
| if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids: | |
| eot_ids.append(im_end_token_id) | |
| print(f"vLLM 将使用 stop_token_ids: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})") | |
| # 2. 创建 vLLM SamplingParams | |
| # (原 HFace generate 参数:max_new_tokens=2048, eos_token_id=eot_ids) | |
| sampling_params = SamplingParams( | |
| max_tokens=2048, # 对应 HFace 的 max_new_tokens | |
| stop_token_ids=eot_ids, # <--- 关键:告诉 vLLM 何时停止 | |
| temperature=0.0, # 对于 SQL 生成,使用贪婪采样 (0.0) 通常是最好的 | |
| top_p=1.0, # (配合 temperature=0.0) | |
| ) | |
| # ---------------------------------------- | |
| def inference(text_input): | |
| print(f"收到输入: {text_input}") | |
| try: | |
| # 1. 构建 Qwen (ChatML) 对话格式 | |
| messages = [ | |
| {"role": "user", "content": text_input} | |
| ] | |
| # 2. 应用对话模板 | |
| # vLLM 需要的是 *字符串*,而不是 token IDs | |
| # tokenize=False 会返回格式化后的字符串 | |
| # add_generation_prompt=True 会在末尾添加 <|im_start|>assistant\n | |
| prompt_str = tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, # <--- 关键:返回字符串 | |
| add_generation_prompt=True # <--- 关键:添加 assistant 提示 | |
| ) | |
| # 3. vLLM 生成 | |
| # llm.generate 接受一个 prompt 列表 | |
| outputs = llm.generate([prompt_str], sampling_params) | |
| # 4. 提取结果 | |
| # outputs 是一个列表,对应输入的 prompt 列表 | |
| # outputs[0] 是第一个 prompt 的 RequestOutput | |
| # outputs[0].outputs[0] 是第一个 (best_of=1) 的生成结果 | |
| # .text 包含 *新生成* 的文本 (不包含 prompt) | |
| result = outputs[0].outputs[0].text.strip() | |
| print(f"生成结果 (仅新内容): {result}") | |
| return result | |
| except Exception as e: | |
| print(f"vLLM 推理时出错: {e}") | |
| # 打印更详细的 vLLM 错误 | |
| traceback.print_exc() | |
| return f"错误: {e}" | |
| # ---------------------------------------------------- | |
| # (示例保持不变) | |
| example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context. | |
| (示例内容和之前一样) ... | |
| Let's think step by step! | |
| """ | |
| # (为了简洁,省略了示例的完整内容,使用您原始的 'example' 变量即可) | |
| examples = [example] | |
| # 6. 创建 Gradio 界面 (保持不变) | |
| iface = gr.Interface( | |
| fn=inference, | |
| inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"), | |
| outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"), | |
| title="UniVectorSQL-7B (vLLM 推理)", | |
| # description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。", | |
| examples=examples | |
| ) | |
| # --------------------- | |
| print("启动 Gradio 界面...") | |
| # 在 Hugging Face 或 ModelScope Space 中,share=True 不是必需的 | |
| iface.launch() |