--- license: apache-2.0 --- # Description This is a LoRA-finetuned `codellama/CodeLlama-7b-hf` text2SQL model that generates SQLite queries. This is a relatively small model that was fine-tuned on 8 x A10Gs with a total GPU memory of 192GB for over 4 days for 3 epochs. For databases with different SQL syntaxes that do not adhere to SQLite's syntax, we plan to launch other models specifically catered to them. # Usage ## Huggingface Transformers Library ```py from transformers import AutoTokenizer, AutoModelForCausalLM model_name = 'unSQLv1-7b-generic-lora' device = 'cuda' model = AutoModelForCausalLM.from_pretrained(model_name).to(device) tokenizer = AutoTokenizer.from_pretrained(model_name) example_prompt = ''' ### Schema and the Natural Language Query: CREATE TABLE stadium ( stadium_id number, location text, name text, capacity number, highest number, lowest number, average number ) CREATE TABLE singer ( singer_id number, name text, country text, song_name text, song_release_year text, age number, is_male others ) CREATE TABLE concert ( concert_id number, concert_name text, theme text, stadium_id text, year text ) CREATE TABLE singer_in_concert ( concert_id number, singer_id text ) -- Using valid SQLite, answer the following questions for the tables provided above. -- What is the maximum, the average, and the minimum capacity of stadiums ? ''' inputs = tokenizer.encode(example_prompt, return_tensors="pt").to(device) outputs = model.generate(inputs, max_length=512) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` ## Sagemaker Endpoint I/O Example ```py payload = { "inputs": "### Schema and the Natural Language Query:\nCREATE TABLE stadium (\n stadium_id number,\n location text,\n name text,\n capacity number,\n highest number,\n lowest number,\n average number\n)\n\nCREATE TABLE singer (\n singer_id number,\n name text,\n country text,\n song_name text,\n song_release_year text,\n age number,\n is_male others\n)\n\nCREATE TABLE concert (\n concert_id number,\n concert_name text,\n theme text,\n stadium_id text,\n year text\n)\n\nCREATE TABLE singer_in_concert (\n concert_id number,\n singer_id text\n)\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- What is the maximum, the average, and the minimum capacity of stadiums ?", "parameters": { "maxNewTokens": 512, "topP": 0.9, "temperature": 0.2 } } client = boto3.client('runtime.sagemaker') endpoint_name = 'deployed_model_name' response = client.invoke_endpoint( EndpointName=endpoint_name, ContentType='application/json', Body=json.dumps(payload).encode('utf-8'), ) response = response["Body"].read().decode("utf8") response = json.loads(response) print(response[0]['generated_text']) ``` ```json { "body": [ { "generated_text": "\n\n\n### Response:\nSELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium", "details": { "finish_reason": "eos_token", "generated_tokens": 30, "seed": 14524408611356330000, "prefill": [], "tokens": [ { "id": 13, "text": "\n", "logprob": 0, "special": false }, { "id": 13, "text": "\n", "logprob": 0, "special": false }, { "id": 13, "text": "\n", "logprob": 0, "special": false }, { "id": 2277, "text": "##", "logprob": 0, "special": false }, { "id": 29937, "text": "#", "logprob": 0, "special": false }, { "id": 13291, "text": " Response", "logprob": 0, "special": false }, { "id": 29901, "text": ":", "logprob": 0, "special": false }, { "id": 13, "text": "\n", "logprob": 0, "special": false }, { "id": 6404, "text": "SELECT", "logprob": 0, "special": false }, { "id": 18134, "text": " MAX", "logprob": 0, "special": false }, { "id": 29898, "text": "(", "logprob": 0, "special": false }, { "id": 5030, "text": "cap", "logprob": 0, "special": false }, { "id": 5946, "text": "acity", "logprob": 0, "special": false }, { "id": 511, "text": "),", "logprob": 0, "special": false }, { "id": 16884, "text": " AV", "logprob": 0, "special": false }, { "id": 29954, "text": "G", "logprob": 0, "special": false }, { "id": 29898, "text": "(", "logprob": 0, "special": false }, { "id": 5030, "text": "cap", "logprob": 0, "special": false }, { "id": 5946, "text": "acity", "logprob": 0, "special": false }, { "id": 511, "text": "),", "logprob": 0, "special": false }, { "id": 341, "text": " M", "logprob": 0, "special": false }, { "id": 1177, "text": "IN", "logprob": 0, "special": false }, { "id": 29898, "text": "(", "logprob": 0, "special": false }, { "id": 5030, "text": "cap", "logprob": 0, "special": false }, { "id": 5946, "text": "acity", "logprob": 0, "special": false }, { "id": 29897, "text": ")", "logprob": 0, "special": false }, { "id": 3895, "text": " FROM", "logprob": 0, "special": false }, { "id": 10728, "text": " stad", "logprob": 0, "special": false }, { "id": 1974, "text": "ium", "logprob": 0, "special": false }, { "id": 2, "text": "", "logprob": 0, "special": true } ] } } ], "contentType": "application/json", "invokedProductionVariant": "AllTraffic" } ```