File size: 2,632 Bytes
612ef39
16dfdad
 
 
 
 
612ef39
 
 
16dfdad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
---
tags:
- text-to-sql
- qwen
- tencent-trac3
- fine-tuned
license: apache-2.0
---

# wexhi/trac3_sql

## 模型描述

这是一个基于 **Qwen** 微调的**全量模型**,专门用于 SQL 生成任务(Text-to-SQL)。

训练数据来自 Tencent TRAC3 数据集,采用**记忆化训练策略**,目标是在训练集上达到 100% 准确率。

## 模型类型

- **类型**: Full Fine-tuned Model
- **架构**: Qwen3ForCausalLM
- **词汇表大小**: 151936
- **大小**: 1152.06 MB

## 使用方法

### 1. 安装依赖

```bash
pip install transformers torch
```

### 2. 加载模型

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "wexhi/trac3_sql",
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(
    "wexhi/trac3_sql",
    trust_remote_code=True,
)
```

### 3. 生成 SQL

```python
messages = [
    {"role": "system", "content": "You are a SQL generator. Generate SQL in this format:\n```sql\n...\n```"},
    {"role": "user", "content": "ID: 1\n\nQuestion:\nWhat is the total revenue?"}
]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False,
)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.0)
response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
print(response)
```

### 4. 使用 vLLM 加速(推荐)

```bash
pip install vllm
```

```python
from vllm import LLM, SamplingParams

llm = LLM(model="wexhi/trac3_sql", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=512)

prompts = [...]  # 批量 prompts
outputs = llm.generate(prompts, sampling_params)
```

## 训练细节

- **训练方法**: Supervised Fine-Tuning (SFT)
- **训练策略**: 记忆化训练(Memorization)
- **训练数据**: Tencent TRAC3 数据集(61 个样本)
- **输入格式**: `ID: {sql_id}\n\nQuestion:\n{question}`
- **输出格式**: ````sql\n{sql}\n```
- **优化目标**: 100% 训练集准确率

## 局限性

⚠️ **重要提示**: 此模型专门针对训练集进行了过拟合优化,**不适用于分布外(OOD)数据**- ✅ 对于训练集中的问题,能够准确生成 SQL
- ❌ 对于未见过的问题,可能无法正确泛化

## License

Apache 2.0

## 引用

如果使用了此模型,请引用:

```
Tencent TRAC3 Challenge - Text-to-SQL Fine-tuned Model
```

---

*Created: 2025-11-24*