wexhi commited on
Commit
605909c
·
verified ·
1 Parent(s): b2c57f4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +0 -118
README.md CHANGED
@@ -1,118 +0,0 @@
1
- ---
2
- tags:
3
- - text-to-sql
4
- - qwen
5
- - tencent-trac3
6
- - fine-tuned
7
- license: apache-2.0
8
- ---
9
-
10
- # wexhi/trac3_sql
11
-
12
- ## 模型描述
13
-
14
- 这是一个基于 **Qwen** 微调的**全量模型**,专门用于 SQL 生成任务(Text-to-SQL)。
15
-
16
- 训练数据来自 Tencent TRAC3 数据集,采用**记忆化训练策略**,目标是在训练集上达到 100% 准确率。
17
-
18
- ## 模型类型
19
-
20
- - **类型**: Full Fine-tuned Model
21
- - **架构**: Qwen3ForCausalLM
22
- - **词汇表大小**: 151936
23
- - **大小**: 1152.06 MB
24
-
25
- ## 使用方法
26
-
27
- ### 1. 安装依赖
28
-
29
- ```bash
30
- pip install transformers torch
31
- ```
32
-
33
- ### 2. 加载模型
34
-
35
- ```python
36
- from transformers import AutoTokenizer, AutoModelForCausalLM
37
-
38
- model = AutoModelForCausalLM.from_pretrained(
39
- "wexhi/trac3_sql",
40
- torch_dtype="auto",
41
- device_map="auto",
42
- trust_remote_code=True,
43
- )
44
-
45
- tokenizer = AutoTokenizer.from_pretrained(
46
- "wexhi/trac3_sql",
47
- trust_remote_code=True,
48
- )
49
- ```
50
-
51
- ### 3. 生成 SQL
52
-
53
- ```python
54
- messages = [
55
- {"role": "system", "content": "You are a SQL generator. Generate SQL in this format:\n```sql\n...\n```"},
56
- {"role": "user", "content": "ID: 1\n\nQuestion:\nWhat is the total revenue?"}
57
- ]
58
-
59
- prompt = tokenizer.apply_chat_template(
60
- messages,
61
- tokenize=False,
62
- add_generation_prompt=True,
63
- enable_thinking=False,
64
- )
65
-
66
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
67
- outputs = model.generate(**inputs, max_new_tokens=512, temperature=0.0)
68
- response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
69
- print(response)
70
- ```
71
-
72
- ### 4. 使用 vLLM 加速(推荐)
73
-
74
- ```bash
75
- pip install vllm
76
- ```
77
-
78
- ```python
79
- from vllm import LLM, SamplingParams
80
-
81
- llm = LLM(model="wexhi/trac3_sql", trust_remote_code=True)
82
- sampling_params = SamplingParams(temperature=0.0, max_tokens=512)
83
-
84
- prompts = [...] # 批量 prompts
85
- outputs = llm.generate(prompts, sampling_params)
86
- ```
87
-
88
- ## 训练细节
89
-
90
- - **训练方法**: Supervised Fine-Tuning (SFT)
91
- - **训练策略**: 记忆化训练(Memorization)
92
- - **训练数据**: Tencent TRAC3 数据集(61 个样本)
93
- - **输入格式**: `ID: {sql_id}\n\nQuestion:\n{question}`
94
- - **输出格式**: ````sql\n{sql}\n```
95
- - **优化目标**: 100% 训练集准确率
96
-
97
- ## 局限性
98
-
99
- ⚠️ **重要提示**: 此模型专门针对训练集进行了过拟合优化,**不适用于分布外(OOD)数据**。
100
-
101
- - ✅ 对于训练集中的问题,能够准确生成 SQL
102
- - ❌ 对于未见过的问题,可能无法正确泛化
103
-
104
- ## License
105
-
106
- Apache 2.0
107
-
108
- ## 引用
109
-
110
- 如果使用了此模型,请引用:
111
-
112
- ```
113
- Tencent TRAC3 Challenge - Text-to-SQL Fine-tuned Model
114
- ```
115
-
116
- ---
117
-
118
- *Created: 2025-11-24*