pwnshx's picture
Update README.md
fd7845c verified
---
license: apache-2.0
---
# Description
This is a LoRA fine-tune of `codellama/CodeLlama-7b-hf` for text-to-SQL that generates SQLite queries. It is a relatively small model, fine-tuned for 3 epochs over more than 4 days on 8 x A10G GPUs (192 GB of GPU memory in total). For databases whose SQL dialects differ from SQLite's, we plan to release separate models tailored to those dialects.
# Usage
## Hugging Face Transformers Library
```py
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and its tokenizer, then move the model to the GPU.
model_name = 'unSQLv1-7b-generic-lora'
device = 'cuda'
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prompt layout: a header line, the CREATE TABLE statements describing the
# schema, then the instruction and the natural-language question written as
# SQL comments.
example_prompt = '''
### Schema and the Natural Language Query:
CREATE TABLE stadium (
stadium_id number,
location text,
name text,
capacity number,
highest number,
lowest number,
average number
)
CREATE TABLE singer (
singer_id number,
name text,
country text,
song_name text,
song_release_year text,
age number,
is_male others
)
CREATE TABLE concert (
concert_id number,
concert_name text,
theme text,
stadium_id text,
year text
)
CREATE TABLE singer_in_concert (
concert_id number,
singer_id text
)
-- Using valid SQLite, answer the following questions for the tables provided above.
-- What is the maximum, the average, and the minimum capacity of stadiums ?
'''

# Calling the tokenizer (instead of `encode`) also returns the attention mask,
# which is then passed to `generate` together with the input ids.
inputs = tokenizer(example_prompt, return_tensors="pt").to(device)

# `max_new_tokens` limits only the generated continuation. `max_length` would
# also count the prompt tokens, so a long schema would leave little or no room
# for the query itself.
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
## SageMaker Endpoint I/O Example
```py
import json

import boto3

# Request body: the prompt (schema + question) under "inputs" and the
# generation settings under "parameters".
# NOTE(review): the parameter key names must match what the deployed serving
# container expects; confirm against your endpoint.
payload = {
    "inputs": "### Schema and the Natural Language Query:\nCREATE TABLE stadium (\n stadium_id number,\n location text,\n name text,\n capacity number,\n highest number,\n lowest number,\n average number\n)\n\nCREATE TABLE singer (\n singer_id number,\n name text,\n country text,\n song_name text,\n song_release_year text,\n age number,\n is_male others\n)\n\nCREATE TABLE concert (\n concert_id number,\n concert_name text,\n theme text,\n stadium_id text,\n year text\n)\n\nCREATE TABLE singer_in_concert (\n concert_id number,\n singer_id text\n)\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- What is the maximum, the average, and the minimum capacity of stadiums ?",
    "parameters": {
        "maxNewTokens": 512,
        "topP": 0.9,
        "temperature": 0.2
    }
}

client = boto3.client('runtime.sagemaker')
endpoint_name = 'deployed_model_name'  # replace with the name of your endpoint

# Send the JSON-encoded payload to the endpoint.
response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType='application/json',
    Body=json.dumps(payload).encode('utf-8'),
)

# The response body is a stream: read it, decode it, and parse the JSON.
raw_body = response["Body"].read().decode("utf8")
results = json.loads(raw_body)

# The parsed result is indexed as a list; the first entry holds the generation.
print(results[0]['generated_text'])
```
```json
{
"body": [
{
"generated_text": "\n\n\n### Response:\nSELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium",
"details": {
"finish_reason": "eos_token",
"generated_tokens": 30,
"seed": 14524408611356330000,
"prefill": [],
"tokens": [
{
"id": 13,
"text": "\n",
"logprob": 0,
"special": false
},
{
"id": 13,
"text": "\n",
"logprob": 0,
"special": false
},
{
"id": 13,
"text": "\n",
"logprob": 0,
"special": false
},
{
"id": 2277,
"text": "##",
"logprob": 0,
"special": false
},
{
"id": 29937,
"text": "#",
"logprob": 0,
"special": false
},
{
"id": 13291,
"text": " Response",
"logprob": 0,
"special": false
},
{
"id": 29901,
"text": ":",
"logprob": 0,
"special": false
},
{
"id": 13,
"text": "\n",
"logprob": 0,
"special": false
},
{
"id": 6404,
"text": "SELECT",
"logprob": 0,
"special": false
},
{
"id": 18134,
"text": " MAX",
"logprob": 0,
"special": false
},
{
"id": 29898,
"text": "(",
"logprob": 0,
"special": false
},
{
"id": 5030,
"text": "cap",
"logprob": 0,
"special": false
},
{
"id": 5946,
"text": "acity",
"logprob": 0,
"special": false
},
{
"id": 511,
"text": "),",
"logprob": 0,
"special": false
},
{
"id": 16884,
"text": " AV",
"logprob": 0,
"special": false
},
{
"id": 29954,
"text": "G",
"logprob": 0,
"special": false
},
{
"id": 29898,
"text": "(",
"logprob": 0,
"special": false
},
{
"id": 5030,
"text": "cap",
"logprob": 0,
"special": false
},
{
"id": 5946,
"text": "acity",
"logprob": 0,
"special": false
},
{
"id": 511,
"text": "),",
"logprob": 0,
"special": false
},
{
"id": 341,
"text": " M",
"logprob": 0,
"special": false
},
{
"id": 1177,
"text": "IN",
"logprob": 0,
"special": false
},
{
"id": 29898,
"text": "(",
"logprob": 0,
"special": false
},
{
"id": 5030,
"text": "cap",
"logprob": 0,
"special": false
},
{
"id": 5946,
"text": "acity",
"logprob": 0,
"special": false
},
{
"id": 29897,
"text": ")",
"logprob": 0,
"special": false
},
{
"id": 3895,
"text": " FROM",
"logprob": 0,
"special": false
},
{
"id": 10728,
"text": " stad",
"logprob": 0,
"special": false
},
{
"id": 1974,
"text": "ium",
"logprob": 0,
"special": false
},
{
"id": 2,
"text": "</s>",
"logprob": 0,
"special": true
}
]
}
}
],
"contentType": "application/json",
"invokedProductionVariant": "AllTraffic"
}
```