Update README.md
Browse files
README.md
CHANGED
|
@@ -96,37 +96,91 @@ model, tokenizer = FastLanguageModel.from_pretrained(
|
|
| 96 |
FastLanguageModel.for_inference(model) # Enable faster inference
|
| 97 |
```
|
| 98 |
|
| 99 |
-
###
|
| 100 |
|
| 101 |
```python
|
| 102 |
-
def
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
|
| 105 |
|
| 106 |
### Schema:
|
| 107 |
-
{}
|
| 108 |
|
| 109 |
### Question:
|
| 110 |
-
{}<end_of_turn>
|
| 111 |
<start_of_turn>model
|
| 112 |
"""
|
| 113 |
|
| 114 |
-
|
| 115 |
inputs = tokenizer([input_prompt], return_tensors="pt").to("cuda")
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
result = tokenizer.batch_decode(outputs)[0]
|
|
|
|
| 119 |
|
| 120 |
-
|
| 121 |
-
sql_result = result.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip()
|
| 122 |
-
return sql_result
|
| 123 |
```
|
| 124 |
|
| 125 |
-
### Example
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
```python
|
| 128 |
# E-commerce Database Schema
|
| 129 |
-
|
| 130 |
CREATE TABLE users (
|
| 131 |
user_id INT PRIMARY KEY,
|
| 132 |
username TEXT,
|
|
@@ -157,15 +211,14 @@ CREATE TABLE order_items (
|
|
| 157 |
);
|
| 158 |
"""
|
| 159 |
|
| 160 |
-
# Complex Question
|
| 161 |
-
|
| 162 |
List the usernames and emails of users who have spent more than $500 in total on products
|
| 163 |
in the 'Electronics' category.
|
| 164 |
"""
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
print(sql_query)
|
| 169 |
```
|
| 170 |
|
| 171 |
**Expected Output:**
|
|
@@ -180,6 +233,23 @@ GROUP BY u.user_id, u.username, u.email
|
|
| 180 |
HAVING SUM(oi.quantity * p.price) > 500;
|
| 181 |
```
|
| 182 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
## Training Details
|
| 184 |
|
| 185 |
### Dataset
|
|
|
|
| 96 |
FastLanguageModel.for_inference(model) # Enable faster inference
|
| 97 |
```
|
| 98 |
|
| 99 |
+
### Inference Function
|
| 100 |
|
| 101 |
```python
|
| 102 |
+
def inference_text_to_sql(model, tokenizer, schema, question, max_new_tokens=300):
    """
    Generate a SQL query from a natural-language question and a database schema.

    Args:
        model: Fine-tuned Gemma model (already switched to inference mode).
        tokenizer: Tokenizer matching the model.
        schema: Database schema (e.g. CREATE TABLE statements) as a string.
        question: Natural-language question about the database.
        max_new_tokens: Maximum number of tokens to generate (default 300).

    Returns:
        The generated SQL query as a stripped string.
    """
    # Format the input prompt with the exact Gemma chat markers the model was
    # fine-tuned on; do not alter the <start_of_turn>/<end_of_turn> tokens.
    input_prompt = f"""<start_of_turn>user
You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

### Schema:
{schema}

### Question:
{question}<end_of_turn>
<start_of_turn>model
"""

    # Tokenize and move tensors to the model's own device instead of
    # hard-coding "cuda", so the example also runs on CPU-only machines.
    inputs = tokenizer([input_prompt], return_tensors="pt").to(model.device)

    # Generate without tracking gradients — this is inference only.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=True,
            temperature=0.1,  # low temperature keeps output near-deterministic
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode the full sequence, keep only the model's turn, and strip the
    # end-of-turn marker plus surrounding whitespace.
    result = tokenizer.batch_decode(outputs)[0]
    sql_query = result.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip()

    return sql_query
|
|
|
|
|
|
|
| 148 |
```
|
| 149 |
|
| 150 |
+
### Example Usage
|
| 151 |
+
|
| 152 |
+
#### Example 1: Simple Single-Table Query
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
# Simple employee database
|
| 156 |
+
simple_schema = """
|
| 157 |
+
CREATE TABLE employees (
|
| 158 |
+
employee_id INT PRIMARY KEY,
|
| 159 |
+
name TEXT,
|
| 160 |
+
department TEXT,
|
| 161 |
+
salary DECIMAL,
|
| 162 |
+
hire_date DATE
|
| 163 |
+
);
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
simple_question = "Find all employees in the 'Engineering' department with salary greater than 75000"
|
| 167 |
+
|
| 168 |
+
sql_result = inference_text_to_sql(model, tokenizer, simple_schema, simple_question)
|
| 169 |
+
print(f"Generated SQL:\n{sql_result}")
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
**Expected Output:**
|
| 173 |
+
```sql
|
| 174 |
+
SELECT * FROM employees
|
| 175 |
+
WHERE department = 'Engineering'
|
| 176 |
+
AND salary > 75000;
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
#### Example 2: Multi-Table JOIN Query
|
| 180 |
|
| 181 |
```python
|
| 182 |
# E-commerce Database Schema
|
| 183 |
+
complex_schema = """
|
| 184 |
CREATE TABLE users (
|
| 185 |
user_id INT PRIMARY KEY,
|
| 186 |
username TEXT,
|
|
|
|
| 211 |
);
|
| 212 |
"""
|
| 213 |
|
| 214 |
+
# Complex Question requiring 4-table JOIN
|
| 215 |
+
complex_question = """
|
| 216 |
List the usernames and emails of users who have spent more than $500 in total on products
|
| 217 |
in the 'Electronics' category.
|
| 218 |
"""
|
| 219 |
|
| 220 |
+
sql_result = inference_text_to_sql(model, tokenizer, complex_schema, complex_question)
|
| 221 |
+
print(f"Generated SQL:\n{sql_result}")
|
|
|
|
| 222 |
```
|
| 223 |
|
| 224 |
**Expected Output:**
|
|
|
|
| 233 |
HAVING SUM(oi.quantity * p.price) > 500;
|
| 234 |
```
|
| 235 |
|
| 236 |
+
#### Example 3: Aggregation with GROUP BY
|
| 237 |
+
|
| 238 |
+
```python
|
| 239 |
+
agg_question = "Find the average salary by department for departments with more than 5 employees"
|
| 240 |
+
|
| 241 |
+
sql_result = inference_text_to_sql(model, tokenizer, simple_schema, agg_question)
|
| 242 |
+
print(f"Generated SQL:\n{sql_result}")
|
| 243 |
+
```
|
| 244 |
+
|
| 245 |
+
**Expected Output:**
|
| 246 |
+
```sql
|
| 247 |
+
SELECT department, AVG(salary) AS avg_salary
|
| 248 |
+
FROM employees
|
| 249 |
+
GROUP BY department
|
| 250 |
+
HAVING COUNT(*) > 5;
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
## Training Details
|
| 254 |
|
| 255 |
### Dataset
|