---
tags:
- text-to-sql
- qwen
- tencent-trac3
- fine-tuned
license: apache-2.0
---

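# wexhi/trac3_sql

## Model Description

This is a **fully fine-tuned model** based on **Qwen**, built specifically for SQL generation (Text-to-SQL).

The training data comes from the Tencent TRAC3 dataset. Training used a **memorization strategy** whose goal is 100% accuracy on the training set.
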
## Model Type

- **Type**: Full Fine-tuned Model
- **Architecture**: Qwen3ForCausalLM
- **Vocabulary size**: 151936
- **Size**: 1152.06 MB

## Usage

### 1. Install dependencies

```bash
# Qwen3 architectures need transformers >= 4.51.0; accelerate is required for device_map="auto"
pip install "transformers>=4.51.0" torch accelerate
```

### 2. Load the model

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "wexhi/trac3_sql",
    torch_dtype="auto",  # use the dtype stored in the checkpoint
    device_map="auto",   # place weights on available GPU(s)/CPU automatically
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(
    "wexhi/trac3_sql",
    trust_remote_code=True,
)
```

### 3. Generate SQL

```python
messages = [
    {"role": "system", "content": "You are a SQL generator. Generate SQL in this format:\n```sql\n...\n```"},
    {"role": "user", "content": "ID: 1\n\nQuestion:\nWhat is the total revenue?"},
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,  # Qwen3 chat-template switch: no <think> reasoning block
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Greedy decoding; temperature=0.0 is rejected by transformers when sampling is enabled
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=False)
# Decode only the newly generated tokens, skipping the prompt
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(response)
```

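The system prompt asks for SQL wrapped in a ```` ```sql ```` fence, so the raw `response` usually needs one post-processing step before the query can be executed. A minimal sketch follows; `extract_sql` is a hypothetical helper (not part of this repository or of transformers), and falling back to the raw text when no fence is found is an assumption about how to handle malformed output.

```python
import re

# Matches the first ```sql ... ``` block; DOTALL lets "." span multi-line queries
SQL_FENCE = re.compile(r"```sql\s*(.*?)\s*```", re.DOTALL)


def extract_sql(response: str) -> str:
    """Return the SQL inside the first ```sql fenced block of a model response.

    Hypothetical helper: if no fenced block is present, the stripped raw text
    is returned (an assumption, not documented model behavior).
    """
    match = SQL_FENCE.search(response)
    if match is None:
        return response.strip()
    return match.group(1).strip()


if __name__ == "__main__":
    # Hypothetical model output, for illustration only
    sample = "```sql\nSELECT SUM(revenue) FROM sales;\n```"
    print(extract_sql(sample))  # SELECT SUM(revenue) FROM sales;
```

Applied to the `response` from step 3, `extract_sql(response)` yields a bare query string that can be passed to a database driver.
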
|
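### 4. Accelerate with vLLM (recommended)

```bash
pip install vllm
```

```python
from vllm import LLM, SamplingParams

llm = LLM(model="wexhi/trac3_sql", trust_remote_code=True)
tokenizer = llm.get_tokenizer()
sampling_params = SamplingParams(temperature=0.0, max_tokens=512)  # temperature=0.0 -> greedy decoding

# Batch prompts built as in step 3 (reuses `messages`; append more message lists to batch more questions)
prompts = [
    tokenizer.apply_chat_template(m, tokenize=False, add_generation_prompt=True, enable_thinking=False)
    for m in [messages]
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```
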
## Training Details

- **Training method**: Supervised Fine-Tuning (SFT)
- **Training strategy**: Memorization
- **Training data**: Tencent TRAC3 dataset (61 samples)
- **Input format**: `ID: {sql_id}\n\nQuestion:\n{question}`
- **Output format**: ```` ```sql\n{sql}\n``` ````
- **Optimization target**: 100% accuracy on the training set

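Because the model was trained to memorize exact input/output pairs, prompts are most likely to hit a memorized answer when they reproduce the training input format exactly. The sketch below renders that format and the matching target fence. The helper names (`build_user_content`, `build_target`, `build_messages`) are hypothetical, and reusing the system prompt from the usage example is an assumption: this card does not say whether training used a system prompt.

```python
from typing import Dict, List, Union

# System prompt copied from the usage example; whether training used it is an assumption
SYSTEM_PROMPT = "You are a SQL generator. Generate SQL in this format:\n```sql\n...\n```"


def build_user_content(sql_id: Union[int, str], question: str) -> str:
    """Render the training input format: ID line, blank line, 'Question:' line, question."""
    return f"ID: {sql_id}\n\nQuestion:\n{question}"


def build_target(sql: str) -> str:
    """Render the training output format: the SQL wrapped in a sql code fence."""
    return f"```sql\n{sql}\n```"


def build_messages(sql_id: Union[int, str], question: str) -> List[Dict[str, str]]:
    """Chat messages shaped like the usage example (system turn + user turn)."""
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": build_user_content(sql_id, question)},
    ]


if __name__ == "__main__":
    print(repr(build_user_content(1, "What is the total revenue?")))
    # 'ID: 1\n\nQuestion:\nWhat is the total revenue?'
```

Keeping the format in one function avoids subtle drift (an extra space or a missing blank line) between the prompts you send and the inputs the model memorized.
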
|
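## Limitations

⚠️ **Important**: This model was deliberately overfit to its training set and is **not suitable for out-of-distribution (OOD) data**.

- ✅ For questions from the training set, it generates the SQL accurately
- ❌ For unseen questions, it may fail to generalize
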
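Given this limitation, a caller may want to query the model only for inputs it was trained on. A minimal sketch, assuming you hold the (`sql_id`, question) pairs of the training samples: `make_guard` and `known_pairs` are hypothetical names, and collapsing whitespace before comparing is an assumption about what counts as the same question, not something this card specifies.

```python
from typing import Callable, Iterable, Tuple, Union

SqlId = Union[int, str]


def _normalize(question: str) -> str:
    """Collapse runs of whitespace so trivial formatting differences still match (assumption)."""
    return " ".join(question.split())


def make_guard(known_pairs: Iterable[Tuple[SqlId, str]]) -> Callable[[SqlId, str], bool]:
    """Return a predicate that is True only for (sql_id, question) pairs seen in training.

    `known_pairs` is hypothetical: fill it from the TRAC3 training data you used.
    """
    known = {(str(sql_id), _normalize(q)) for sql_id, q in known_pairs}

    def is_in_distribution(sql_id: SqlId, question: str) -> bool:
        return (str(sql_id), _normalize(question)) in known

    return is_in_distribution


if __name__ == "__main__":
    # Hypothetical training pair, for illustration only
    guard = make_guard([(1, "What is the total revenue?")])
    print(guard(1, "What is the total revenue?"))  # True
    print(guard(2, "What is the total revenue?"))  # False
```

Exact matching is deliberately conservative: paraphrased questions are treated as out of distribution, which is the case the limitation above warns about.
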
|
## License

Apache 2.0

## Citation

If you use this model, please cite:

```
Tencent TRAC3 Challenge - Text-to-SQL Fine-tuned Model
```

---

*Created: 2025-11-24*