|
|
--- |
|
|
license: apache-2.0 |
|
|
--- |
|
|
|
|
|
# Description |
|
|
|
|
|
This is a LoRA fine-tune of `codellama/CodeLlama-7b-hf` for text-to-SQL that generates SQLite queries. It is a relatively small model, fine-tuned for 3 epochs over more than 4 days on 8 x A10G GPUs (192GB of total GPU memory). For databases whose SQL dialects do not adhere to SQLite's syntax, we plan to release other models specifically tailored to them. |
|
|
|
|
|
|
|
|
# Usage |
|
|
|
|
|
## Hugging Face Transformers Library |
|
|
|
|
|
```py |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
model_name = 'unSQLv1-7b-generic-lora' |
|
|
device = 'cuda' |
|
|
|
|
|
model = AutoModelForCausalLM.from_pretrained(model_name).to(device) |
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
|
|
example_prompt = ''' |
|
|
### Schema and the Natural Language Query: |
|
|
CREATE TABLE stadium ( |
|
|
stadium_id number, |
|
|
location text, |
|
|
name text, |
|
|
capacity number, |
|
|
highest number, |
|
|
lowest number, |
|
|
average number |
|
|
) |
|
|
|
|
|
CREATE TABLE singer ( |
|
|
singer_id number, |
|
|
name text, |
|
|
country text, |
|
|
song_name text, |
|
|
song_release_year text, |
|
|
age number, |
|
|
is_male others |
|
|
) |
|
|
|
|
|
CREATE TABLE concert ( |
|
|
concert_id number, |
|
|
concert_name text, |
|
|
theme text, |
|
|
stadium_id text, |
|
|
year text |
|
|
) |
|
|
|
|
|
CREATE TABLE singer_in_concert ( |
|
|
concert_id number, |
|
|
singer_id text |
|
|
) |
|
|
|
|
|
-- Using valid SQLite, answer the following questions for the tables provided above. |
|
|
|
|
|
-- What is the maximum, the average, and the minimum capacity of stadiums ? |
|
|
''' |
|
|
|
|
|
|
|
|
inputs = tokenizer.encode(example_prompt, return_tensors="pt").to(device) |
|
|
outputs = model.generate(inputs, max_new_tokens=512) |
|
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True)) |
|
|
``` |
|
|
|
|
|
## SageMaker Endpoint I/O Example |
|
|
|
|
|
|
|
|
```py |
|
|
import json |
|
|
|
|
|
import boto3 |
|
|
|
|
|
payload = { |
|
|
"inputs": "### Schema and the Natural Language Query:\nCREATE TABLE stadium (\n stadium_id number,\n location text,\n name text,\n capacity number,\n highest number,\n lowest number,\n average number\n)\n\nCREATE TABLE singer (\n singer_id number,\n name text,\n country text,\n song_name text,\n song_release_year text,\n age number,\n is_male others\n)\n\nCREATE TABLE concert (\n concert_id number,\n concert_name text,\n theme text,\n stadium_id text,\n year text\n)\n\nCREATE TABLE singer_in_concert (\n concert_id number,\n singer_id text\n)\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- What is the maximum, the average, and the minimum capacity of stadiums ?", |
|
|
"parameters": { |
|
|
"maxNewTokens": 512, |
|
|
"topP": 0.9, |
|
|
"temperature": 0.2 |
|
|
} |
|
|
} |
|
|
|
|
|
client = boto3.client('runtime.sagemaker') |
|
|
endpoint_name = 'deployed_model_name' |
|
|
response = client.invoke_endpoint( |
|
|
EndpointName=endpoint_name, |
|
|
ContentType='application/json', |
|
|
Body=json.dumps(payload).encode('utf-8'), |
|
|
) |
|
|
response = response["Body"].read().decode("utf8") |
|
|
response = json.loads(response) |
|
|
print(response[0]['generated_text']) |
|
|
|
|
|
``` |
|
|
|
|
|
```json |
|
|
{ |
|
|
"body": [ |
|
|
{ |
|
|
"generated_text": "\n\n\n### Response:\nSELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium", |
|
|
"details": { |
|
|
"finish_reason": "eos_token", |
|
|
"generated_tokens": 30, |
|
|
"seed": 14524408611356330000, |
|
|
"prefill": [], |
|
|
"tokens": [ |
|
|
{ |
|
|
"id": 13, |
|
|
"text": "\n", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 13, |
|
|
"text": "\n", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 13, |
|
|
"text": "\n", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 2277, |
|
|
"text": "##", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 29937, |
|
|
"text": "#", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 13291, |
|
|
"text": " Response", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 29901, |
|
|
"text": ":", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 13, |
|
|
"text": "\n", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 6404, |
|
|
"text": "SELECT", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 18134, |
|
|
"text": " MAX", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 29898, |
|
|
"text": "(", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 5030, |
|
|
"text": "cap", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 5946, |
|
|
"text": "acity", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 511, |
|
|
"text": "),", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 16884, |
|
|
"text": " AV", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 29954, |
|
|
"text": "G", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 29898, |
|
|
"text": "(", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 5030, |
|
|
"text": "cap", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 5946, |
|
|
"text": "acity", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 511, |
|
|
"text": "),", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 341, |
|
|
"text": " M", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 1177, |
|
|
"text": "IN", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 29898, |
|
|
"text": "(", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 5030, |
|
|
"text": "cap", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 5946, |
|
|
"text": "acity", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 29897, |
|
|
"text": ")", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 3895, |
|
|
"text": " FROM", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 10728, |
|
|
"text": " stad", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 1974, |
|
|
"text": "ium", |
|
|
"logprob": 0, |
|
|
"special": false |
|
|
}, |
|
|
{ |
|
|
"id": 2, |
|
|
"text": "</s>", |
|
|
"logprob": 0, |
|
|
"special": true |
|
|
} |
|
|
] |
|
|
} |
|
|
} |
|
|
], |
|
|
"contentType": "application/json", |
|
|
"invokedProductionVariant": "AllTraffic" |
|
|
} |
|
|
``` |