---
license: llama2
inference:
  parameters:
    do_sample: false
    max_length: 200
widget:
  - text: "CREATE TABLE stadium (\n stadium_id number,\n location text,\n name text,\n capacity number,\n)\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- how many stadiums in total?\n\nSELECT"
    example_title: "Number stadiums"
  - text: "CREATE TABLE work_orders ( ID NUMBER, CREATED_AT TEXT, COST FLOAT, INVOICE_AMOUNT FLOAT, IS_DUE BOOLEAN, IS_OPEN BOOLEAN, IS_OVERDUE BOOLEAN, COUNTRY_NAME TEXT, )\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- how many work orders are open?\n\nSELECT"
    example_title: "Open work orders"
  - text: "CREATE TABLE stadium ( stadium_id number, location text, name text, capacity number, highest number, lowest number, average number )\n\nCREATE TABLE singer ( singer_id number, name text, country text, song_name text, song_release_year text, age number, is_male others )\n\nCREATE TABLE concert ( concert_id number, concert_name text, theme text, stadium_id text, year text )\n\nCREATE TABLE singer_in_concert ( concert_id number, singer_id text )\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- What is the maximum, the average, and the minimum capacity of stadiums ?\n\nSELECT"
    example_title: "Stadium capacity"
---
# NSQL-Llama-2-7B
## Model Description
NSQL is a family of autoregressive open-source large foundation models (FMs) designed specifically for SQL generation tasks.
This repository introduces a new member of the NSQL family, NSQL-Llama-2-7B. It is based on Meta's original [Llama-2 7B model](https://huggingface.co/meta-llama/Llama-2-7b), further pre-trained on a dataset of general SQL queries, and then fine-tuned on a dataset of text-to-SQL pairs.
## Training Data
The general SQL queries are the SQL subset of [The Stack](https://huggingface.co/datasets/bigcode/the-stack), containing 1M training samples. The labeled text-to-SQL pairs come from more than 20 public sources on the web, drawn from standard text-to-SQL datasets. We hold out the Spider and GeoQuery datasets for evaluation.
## Evaluation Data
We evaluate our models on two text-to-SQL benchmarks: Spider and GeoQuery.
## Evaluation Results
### Spider Benchmark (Text-to-SQL Standard Evaluation)
NSQL-llama-2-7B was evaluated on the Spider benchmark, the standard academic evaluation for Text-to-SQL systems.
#### Overall Performance
| Model | Size | Execution Accuracy | Matching Accuracy |
|-------|------|-------------------|-------------------|
| **NSQL-llama-2-7B** | 7B | 75.0% | **66.3%** |
| GPT-4 | ~1.8T (unofficial estimate) | 76.2% | 41.9% |
| GPT-3.5 Chat | — | 72.8% | 44.2% |
| Llama-2-7B (base) | 7B | 29.1% | 19.3% |
| Llama-2-70B | 70B | 61.5% | 35.4% |
#### Performance by Query Complexity
| Query Type | NSQL-llama-2-7B | GPT-4 | NSQL Advantage |
|------------|-----------------|-------|----------------|
| **Join Queries** | **53.7%** | ~37.6% | **+43% relative** |
| **Nested Queries** | **57.2%** | ~37.1% | **+54% relative** |
| Simple Queries | 91.4% | Higher | GPT-4 advantage |
#### Key Findings
1. **Complex Query Performance:** NSQL-llama-2-7B significantly outperforms GPT-4 on complex queries:
   - +43% relative improvement on Join queries (53.7% vs. ~37.6%, i.e. +16.1 points)
   - +54% relative improvement on Nested queries (57.2% vs. ~37.1%, i.e. +20.1 points)
2. **Matching Accuracy:** NSQL achieves 66.3% matching accuracy vs. GPT-4's 41.9% (+24.4 points), indicating more structurally correct SQL generation.
3. **Efficiency:** NSQL achieves near-parity with GPT-4 on overall execution accuracy (75.0% vs. 76.2%) while being roughly 250× smaller, based on the unofficial ~1.8T estimate of GPT-4's size.
4. **Local Deployment:** The 7B parameter size enables local deployment on commodity hardware, preserving data privacy.
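The "NSQL Advantage" figures are relative improvements over GPT-4's approximate scores, not percentage-point gaps. The short Python sketch below recomputes both kinds of gap from the values in the tables above; it introduces no new data, the helper names are illustrative, and the 1.8T figure is the unofficial GPT-4 size estimate used in this card.

```python
# Figures copied from the Spider tables above (execution accuracy, %).
# GPT-4's join/nested scores are marked approximate (~) in the source table.
nsql = {"join": 53.7, "nested": 57.2}
gpt4 = {"join": 37.6, "nested": 37.1}

def relative_gain(ours: float, baseline: float) -> float:
    """Relative improvement in percent: (ours - baseline) / baseline * 100."""
    return (ours - baseline) / baseline * 100

def absolute_gain(ours: float, baseline: float) -> float:
    """Absolute improvement in percentage points."""
    return ours - baseline

for category in ("join", "nested"):
    rel = relative_gain(nsql[category], gpt4[category])
    pts = absolute_gain(nsql[category], gpt4[category])
    print(f"{category}: +{pts:.1f} points absolute, +{rel:.0f}% relative")

# Matching-accuracy gap from the overall table.
print(f"matching: +{66.3 - 41.9:.1f} points")
# Parameter ratio, assuming the unofficial ~1.8T estimate for GPT-4.
print(f"size ratio: ~{1.8e12 / 7e9:.0f}x")
```

This prints +16.1 points (+43% relative) for joins, +20.1 points (+54% relative) for nested queries, a 24.4-point matching gap, and a size ratio of about 257x, which the findings round to ~250×.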
#### Why This Matters
GPT-4 achieves marginally higher overall execution accuracy primarily through superior performance on simple single-table queries. However, enterprise SQL workloads typically involve:
- Multiple table joins
- Nested subqueries
- Complex business logic
On these complex query types, NSQL substantially outperforms GPT-4 while enabling privacy-preserving local deployment.
### GeoQuery Benchmark
| Model | Size | Execution Accuracy | Matching Accuracy |
|-------|------|-------------------|-------------------|
| NSQL-llama-2-7B | 7B | 26.5% | 30.4% |
| GPT-4 | ~1.8T (unofficial estimate) | 55.1% | 39.1% |
*Note: GeoQuery is a narrower benchmark; Spider is the primary industry standard for Text-to-SQL evaluation.*
### NSQL Model Family Comparison
| Model | Size | Spider Execution Accuracy | Spider Matching Accuracy |
|-------|------|-------------|--------------|
| NSQL-350M | 350M | 51.7% | 45.6% |
| NSQL-2B | 2B | 59.3% | 53.2% |
| NSQL-6B | 6B | 63.6% | 57.4% |
| **NSQL-llama-2-7B** | **7B** | **75.0%** | **66.3%** |
---
## Evaluation Methodology
- **Benchmark:** Spider (Yu et al., 2018)
- **Metric - Execution Accuracy:** Percentage of predicted queries whose execution results match those of the ground-truth query
- **Metric - Matching Accuracy:** Percentage of predicted queries that structurally match the ground-truth query
- **Query Type Breakdown:** Spider queries grouped into Join, Nested, and Simple categories
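To make the difference between the two metrics concrete, here is a simplified execution-accuracy check using Python's standard-library `sqlite3` module. It is a sketch, not the official Spider evaluation script: it compares result rows as multisets (ignoring row order), counts a query that fails to execute as wrong, and runs on a tiny hypothetical `stadium` table with made-up rows.

```python
import sqlite3
from collections import Counter

def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> bool:
    """True if both queries return the same multiset of rows.

    Simplifications: row order is ignored, and a predicted query that raises
    an error counts as incorrect.
    """
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return False
    gold_rows = conn.execute(gold_sql).fetchall()
    return Counter(predicted_rows) == Counter(gold_rows)

def execution_accuracy(conn: sqlite3.Connection, pairs) -> float:
    """Percentage of (predicted, gold) pairs whose results match."""
    hits = sum(execution_match(conn, pred, gold) for pred, gold in pairs)
    return 100.0 * hits / len(pairs)

# Hypothetical toy database based on the stadium schema from the How to Use examples.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE stadium (stadium_id INTEGER, location TEXT, name TEXT, capacity INTEGER)")
conn.executemany(
    "INSERT INTO stadium VALUES (?, ?, ?, ?)",
    [(1, "Location A", "Stadium A", 10000), (2, "Location B", "Stadium B", 12000)],
)

pairs = [
    ("SELECT count(*) FROM stadium", "SELECT count(*) FROM stadium"),            # identical
    ("SELECT count(*) FROM stadium", "SELECT count(stadium_id) FROM stadium"),   # different SQL, same result
    ("SELECT max(capacity) FROM stadium", "SELECT min(capacity) FROM stadium"),  # wrong result
]
print(execution_accuracy(conn, pairs))
```

It prints 66.66...: the second pair is credited because both queries return the same count even though the SQL differs, which is the kind of case where execution accuracy and structural matching can disagree.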
## Training Procedure
NSQL was trained using cross-entropy loss to maximize the likelihood of sequential inputs. When fine-tuning on text-to-SQL pairs, we compute the loss only over the SQL portion of each pair. The model was trained on 80GB A100 GPUs using data and model parallelism. We pre-trained for 3 epochs and fine-tuned for 10 epochs.
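Computing the loss only over the SQL portion means the prompt positions (schema and question) contribute nothing to the objective. The dependency-free Python sketch below illustrates that idea on a toy three-token vocabulary; it is not the authors' training code, and the function names and numbers are hypothetical. In a PyTorch / Hugging Face setup the same effect is commonly obtained by setting the prompt positions in `labels` to `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def masked_cross_entropy(step_logits, targets, loss_mask):
    """Mean negative log-likelihood over positions where loss_mask is 1.

    step_logits[t] holds the next-token logits at position t, targets[t] is the
    token that actually comes next, and loss_mask[t] is 1 only for SQL tokens.
    """
    total, count = 0.0, 0
    for logits, target, keep in zip(step_logits, targets, loss_mask):
        if keep:  # prompt positions (mask 0) are skipped entirely
            total -= log_softmax(logits)[target]
            count += 1
    return total / count

# Hypothetical toy example: the first two target positions belong to the
# prompt (schema + question), the last two to the SQL completion.
step_logits = [[2.0, 0.5, 0.1], [0.3, 1.2, 0.4], [0.1, 0.2, 3.0], [1.5, 0.2, 0.2]]
targets = [0, 1, 2, 0]
loss_mask = [0, 0, 1, 1]
print(masked_cross_entropy(step_logits, targets, loss_mask))
```

Because masked positions are skipped, changing the prompt-position logits leaves the loss unchanged; only the predictions for SQL tokens are trained.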
## Intended Use and Limitations
The model was designed for text-to-SQL generation from a given table schema and a natural-language question. It works best with the prompt format defined below and when generating `SELECT` queries.
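The widget examples in this card's metadata all follow one layout: the `CREATE TABLE` statements separated by blank lines, the fixed SQLite instruction comment, the question as a `--` comment, and a trailing `SELECT` for the model to complete. A small helper that assembles that layout might look like this; the name `build_nsql_prompt` is hypothetical, and the format is inferred from the examples rather than from a published specification.

```python
def build_nsql_prompt(create_table_statements, question):
    """Assemble a prompt in the layout used by this card's examples.

    create_table_statements: list of CREATE TABLE strings, one per table.
    question: natural-language question, without the leading "-- ".
    """
    # Tables are separated by one blank line, as in the multi-table example.
    schema = "\n\n".join(stmt.strip() for stmt in create_table_statements)
    return (
        f"{schema}\n\n"
        "-- Using valid SQLite, answer the following questions for the tables provided above.\n\n"
        f"-- {question}\n\n"
        "SELECT"  # the model continues the query from here
    )

prompt = build_nsql_prompt(
    ["CREATE TABLE stadium (\n stadium_id number,\n location text,\n name text,\n capacity number,\n)"],
    "how many stadiums in total?",
)
print(prompt)
```

With the single `stadium` table, the result is character-for-character the "Number stadiums" widget text.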
## How to Use
Example 1:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-llama-2-7B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-llama-2-7B", torch_dtype=torch.bfloat16)

# Prompt: table schemas, the SQLite instruction, the question, then a trailing "SELECT" to complete.
text = """CREATE TABLE stadium (
    stadium_id number,
    location text,
    name text,
    capacity number,
    highest number,
    lowest number,
    average number
)

CREATE TABLE singer (
    singer_id number,
    name text,
    country text,
    song_name text,
    song_release_year text,
    age number,
    is_male others
)

CREATE TABLE concert (
    concert_id number,
    concert_name text,
    theme text,
    stadium_id text,
    year text
)

CREATE TABLE singer_in_concert (
    concert_id number,
    singer_id text
)

-- Using valid SQLite, answer the following questions for the tables provided above.

-- What is the maximum, the average, and the minimum capacity of stadiums ?

SELECT"""

input_ids = tokenizer(text, return_tensors="pt").input_ids
# max_length counts the prompt tokens plus the generated tokens.
generated_ids = model.generate(input_ids, max_length=500)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```
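The printed string contains the prompt followed by the completion. One way to recover just the query, assuming typical output rather than documented behavior, is to decode only the new tokens, for example with `tokenizer.decode(generated_ids[0][input_ids.shape[1]:], skip_special_tokens=True)`, and then re-attach the `SELECT` that ended the prompt. The helper `completion_to_sql` below is hypothetical and assumes the query ends at the first semicolon.

```python
def completion_to_sql(completion: str) -> str:
    """Turn the text generated after the prompt's trailing "SELECT" into one query.

    Assumption: the query ends at the first ";" and anything after it is
    dropped; if no semicolon is generated, the whole completion is kept.
    """
    body = completion.split(";", 1)[0].strip()
    return f"SELECT {body};"

print(completion_to_sql(" max(capacity), avg(capacity), min(capacity) FROM stadium;"))
```

Re-adding the keyword with an explicit space also sidesteps tokenizers that drop the leading space when decoding a slice of tokens.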
Example 2:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-llama-2-7B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-llama-2-7B", torch_dtype=torch.bfloat16)

# Prompt: table schema, the SQLite instruction, the question, then a trailing "SELECT" to complete.
text = """CREATE TABLE stadium (
    stadium_id number,
    location text,
    name text,
    capacity number,
)

-- Using valid SQLite, answer the following questions for the tables provided above.

-- how many stadiums in total?

SELECT"""

input_ids = tokenizer(text, return_tensors="pt").input_ids
# max_length counts the prompt tokens plus the generated tokens.
generated_ids = model.generate(input_ids, max_length=500)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```
Example 3:
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-llama-2-7B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-llama-2-7B", torch_dtype=torch.bfloat16)

# Prompt: table schema, the SQLite instruction, the question, then a trailing "SELECT" to complete.
text = """CREATE TABLE work_orders (
    ID NUMBER,
    CREATED_AT TEXT,
    COST FLOAT,
    INVOICE_AMOUNT FLOAT,
    IS_DUE BOOLEAN,
    IS_OPEN BOOLEAN,
    IS_OVERDUE BOOLEAN,
    COUNTRY_NAME TEXT,
)

-- Using valid SQLite, answer the following questions for the tables provided above.

-- how many work orders are open?

SELECT"""

input_ids = tokenizer(text, return_tensors="pt").input_ids
# max_length counts the prompt tokens plus the generated tokens.
generated_ids = model.generate(input_ids, max_length=500)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```
For more information (e.g., how to run the model against your local database), see the examples in [this repository](https://github.com/NumbersStationAI/NSQL).
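As a rough sketch of the local-database workflow (not the code from the linked repository), Python's standard-library `sqlite3` module can read the `CREATE TABLE` statements a database already stores in `sqlite_master`, which can then go at the top of the prompt, and execute the query the model returns. The `work_orders` rows and the query string below are hand-written stand-ins, not model output. Note that the example prompts above keep a trailing comma before `)`, which SQLite itself would reject as DDL, so the sketch creates the table without it.

```python
import sqlite3

def schema_text(conn: sqlite3.Connection) -> str:
    """Return the stored CREATE TABLE statements, separated by blank lines."""
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        "WHERE type = 'table' AND sql IS NOT NULL ORDER BY name"
    ).fetchall()
    return "\n\n".join(sql for (sql,) in rows)

def run_select(conn: sqlite3.Connection, sql: str) -> list:
    """Execute a generated query, refusing anything that is not a single SELECT."""
    statement = sql.strip().rstrip(";").strip()
    if not statement.upper().startswith("SELECT") or ";" in statement:
        raise ValueError("only a single SELECT statement is executed")
    return conn.execute(statement).fetchall()

# Hypothetical local data; a real setup would connect to an existing database file.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE work_orders (ID NUMBER, CREATED_AT TEXT, COST FLOAT, "
    "INVOICE_AMOUNT FLOAT, IS_DUE BOOLEAN, IS_OPEN BOOLEAN, IS_OVERDUE BOOLEAN, "
    "COUNTRY_NAME TEXT)"
)
conn.executemany(
    "INSERT INTO work_orders (ID, IS_OPEN) VALUES (?, ?)",
    [(1, 1), (2, 0), (3, 1)],
)

print(schema_text(conn))  # schema text to place at the top of the prompt
# Hand-written stand-in for an answer to "how many work orders are open?"
print(run_select(conn, "SELECT count(*) FROM work_orders WHERE IS_OPEN = 1;"))
```

The guard is deliberately conservative (it also rejects `WITH ... SELECT` queries). For a real database file, opening it read-only, e.g. `sqlite3.connect("file:my.db?mode=ro", uri=True)` with a hypothetical path, adds a further safeguard against accidental writes.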