import gradio as gr import torch # 导入 vllm 相关的库 from vllm import LLM, SamplingParams # 保持 AutoTokenizer 用于处理聊天模板 from transformers import AutoTokenizer from modelscope.hub.snapshot_download import snapshot_download import os import traceback # --- 关键配置 --- # !!! (重要) 请在此处填入您在 ModelScope 上的 "已合并" 模型的 ID # vLLM 将直接加载这个完整的模型 # # 我使用了您原始的基础模型 ID 作为示例,但您*必须*将其替换为 # 已经合并了 LoRA (UniVectorSQL) 的那个模型的 ID。 MERGED_MODEL_ID = "zrwang/UniVectorSQL-7B-LoRA-merged" # 示例:如果您的合并后模型叫 "risemds/UniVectorSQL-7B-Merged",请修改上面一行 # ----------------- # --- vLLM 模型加载 --- # 1. 下载模型 # 我们将忽略 .md 文件和 training_args.bin 文件 print(f"开始下载已合并的模型: {MERGED_MODEL_ID}") model_dir = snapshot_download( MERGED_MODEL_ID, revision='master', ignore_patterns=["*.md", "training_args.bin", "checkpoint-*"] ) print(f"模型下载完成,路径: {model_dir}") # 2. 加载 Tokenizer # 我们需要 Tokenizer 来应用聊天模板 try: tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) print("从模型目录加载 Tokenizer 成功。") except Exception as e: print(f"从模型目录加载 Tokenizer 失败: {e}") # 如果失败,程序无法继续,因为 vLLM 需要知道 EOT token raise e # --- (保留) 修复 Tokenizer 的 pad_token_id --- # Qwen1.5 基础模型可能没有设置 pad_token_id if tokenizer.pad_token_id is None: print("Tokenizer 未设置 pad_token_id,将其设置为 eos_token_id。") tokenizer.pad_token_id = tokenizer.eos_token_id # (注意:我们不再需要修复 model.config,因为 vLLM 会处理) # ---------------------------------------- # 3. 加载 vLLM 模型 print(f"开始加载 vLLM 模型: {model_dir}") # 自动检测 GPU 数量以设置 tensor_parallel_size if torch.cuda.is_available(): gpu_count = torch.cuda.device_count() print(f"检测到 {gpu_count} 个 CUDA GPU。") else: print("警告: 未检测到 CUDA GPU。vLLM 强烈建议在 GPU 上运行。") gpu_count = 1 # 假设至少为 1,vLLM 0.4.0+ 也支持 CPU (但很慢) llm = LLM( model=model_dir, trust_remote_code=True, tensor_parallel_size=gpu_count, # 自动使用所有可用的 GPU dtype="auto" # vLLM 会自动选择 (例如 bfloat16 或 float16) # max_model_len=4096 # (可选) 如果需要,设置最大上下文长度 ) print("vLLM 模型加载完成。") # 4. 定义推理函数 (*** 已修改为使用 vLLM ***) # --- 提前准备 EOT token 和 SamplingParams --- # 1. 获取 EOT (End of Text) token IDs (用于停止生成) # Qwen 系列使用 <|im_end|> (151645) 和/或 <|endoftext|> (151643) eot_ids = [tokenizer.eos_token_id] im_end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>") if im_end_token_id != tokenizer.unk_token_id and im_end_token_id not in eot_ids: eot_ids.append(im_end_token_id) print(f"vLLM 将使用 stop_token_ids: {eot_ids} (eos_token_id: {tokenizer.eos_token_id})") # 2. 创建 vLLM SamplingParams # (原 HFace generate 参数:max_new_tokens=2048, eos_token_id=eot_ids) sampling_params = SamplingParams( max_tokens=2048, # 对应 HFace 的 max_new_tokens stop_token_ids=eot_ids, # <--- 关键:告诉 vLLM 何时停止 temperature=0.0, # 对于 SQL 生成,使用贪婪采样 (0.0) 通常是最好的 top_p=1.0, # (配合 temperature=0.0) ) # ---------------------------------------- def inference(text_input): print(f"收到输入: {text_input}") try: # 1. 构建 Qwen (ChatML) 对话格式 messages = [ {"role": "user", "content": text_input} ] # 2. 应用对话模板 # vLLM 需要的是 *字符串*,而不是 token IDs # tokenize=False 会返回格式化后的字符串 # add_generation_prompt=True 会在末尾添加 <|im_start|>assistant\n prompt_str = tokenizer.apply_chat_template( messages, tokenize=False, # <--- 关键:返回字符串 add_generation_prompt=True # <--- 关键:添加 assistant 提示 ) # 3. vLLM 生成 # llm.generate 接受一个 prompt 列表 outputs = llm.generate([prompt_str], sampling_params) # 4. 提取结果 # outputs 是一个列表,对应输入的 prompt 列表 # outputs[0] 是第一个 prompt 的 RequestOutput # outputs[0].outputs[0] 是第一个 (best_of=1) 的生成结果 # .text 包含 *新生成* 的文本 (不包含 prompt) result = outputs[0].outputs[0].text.strip() print(f"生成结果 (仅新内容): {result}") return result except Exception as e: print(f"vLLM 推理时出错: {e}") # 打印更详细的 vLLM 错误 traceback.print_exc() return f"错误: {e}" # ---------------------------------------------------- # (示例保持不变) example = """You are a senior SQL engineer. Your task is to generate a single, correct, and executable SQL query to answer the user's question based on the provided database context. ... (示例内容和之前一样) ... Let's think step by step! """ # (为了简洁,省略了示例的完整内容,使用您原始的 'example' 变量即可) examples = [example] # 6. 创建 Gradio 界面 (保持不变) iface = gr.Interface( fn=inference, inputs=gr.Textbox(lines=10, label="输入查询 (Input Query)"), outputs=gr.Textbox(lines=10, label="模型输出 (Model Output)"), title="UniVectorSQL-7B (vLLM 推理)", # description="这是一个 Text-to-SQL 模型。请输入您的问题 (Question) 和数据库模式 (Schema)。点击下方示例尝试。", examples=examples ) # --------------------- print("启动 Gradio 界面...") # 在 Hugging Face 或 ModelScope Space 中,share=True 不是必需的 iface.launch()