pwnshx committed on
Commit 2c21b41 · 1 Parent(s): 2b3198f

Edited README and added the inference script

Files changed (3)
  1. .gitignore +1 -0
  2. README.md +282 -0
  3. inference.py +54 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .DS_Store
README.md CHANGED
@@ -1,3 +1,285 @@
  ---
  license: apache-2.0
  ---
+
+ # Description
+
+ This is a LoRA-finetuned `codellama/CodeLlama-7b-hf` text2SQL model that generates a generic flavor of SQL that executes on databases such as MySQL, Postgres, and Snowflake. This is a relatively small model that was fine-tuned on 8 x A10Gs with a total GPU memory of 192 GB for over 4 days. For databases whose SQL syntax does not adhere to this generic flavor, we plan to launch other models catered to them.
+
+ # Usage
+
+ ## Hugging Face Transformers Library
+
+ ```py
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ model_name = 'unSQLv1-7b-generic-lora'
+ device = 'cuda'
+
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ example_prompt = '''
+ You are a highly skilled SQL query generator that generates queries for 24 different databases. Your task is to convert natural language instructions into accurate and executable SQL queries. \nTo ensure precise translation, please follow these guidelines:\n\n1. Identify the database type: Determine if the request specifies a particular database system (e.g., MySQL, PostgreSQL, SQLite, etc.). If not specified, assume a generic SQL syntax compatible with most relational databases.\n2. Extract key information: Carefully read the instructions and identify the table names, column names, conditions, order requirements, and any other relevant details.\n3. Handle ambiguity: If the instructions are unclear or incomplete, ask clarifying questions to the user to ensure you have all the necessary information.\n4. Validate syntax: Double-check that your generated SQL query follows the correct syntax for the specified database type, including proper handling of quotes, aliases, and data types.\n5. Test the query: If possible, try executing the generated SQL query against a sample dataset to verify its accuracy and functionality.\n6. Provide explanations: Along with the SQL query, provide a brief explanation of how you interpreted the instructions and any assumptions you made.\n7. Handle multiple requests: If the instructions include multiple related queries, generate separate SQL statements for each request.\n8. Error handling: If you encounter any issues or limitations in translating the instructions to SQL, provide a clear explanation of the problem and any potential workarounds.\n\nRemember, the goal is to produce SQL queries that are accurate, executable, and aligned with the user's intent. Follow best practices for writing efficient and secure SQL code.
+
+ ### Schema and the Natural Language Query:
+ CREATE TABLE stadium (
+     stadium_id number,
+     location text,
+     name text,
+     capacity number,
+     highest number,
+     lowest number,
+     average number
+ )
+
+ CREATE TABLE singer (
+     singer_id number,
+     name text,
+     country text,
+     song_name text,
+     song_release_year text,
+     age number,
+     is_male others
+ )
+
+ CREATE TABLE concert (
+     concert_id number,
+     concert_name text,
+     theme text,
+     stadium_id text,
+     year text
+ )
+
+ CREATE TABLE singer_in_concert (
+     concert_id number,
+     singer_id text
+ )
+
+ -- Using valid SQLite, answer the following questions for the tables provided above.
+
+ -- What is the maximum, the average, and the minimum capacity of stadiums ?
+ '''
+
+ inputs = tokenizer.encode(example_prompt, return_tensors="pt").to(device)
+ outputs = model.generate(inputs, max_length=512)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
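Since the decoded output echoes the prompt and places the query after a `### Response:` marker (as in the sample endpoint response below), a small post-processing helper can pull out just the SQL. This is a sketch of our own — the `extract_sql` helper is not part of the model or scripts in this repo:

```python
def extract_sql(generated_text: str) -> str:
    # The model echoes the prompt and emits the query after a
    # "### Response:" marker; take everything after the last marker.
    marker = "### Response:"
    _, sep, tail = generated_text.rpartition(marker)
    return tail.strip() if sep else generated_text.strip()

# Shape of output shown in the endpoint example:
sample = "...prompt text...\n\n\n### Response:\nSELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium"
print(extract_sql(sample))  # SELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium
```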
+
+ ## SageMaker Endpoint I/O Example
+
+ ```json
+ {
+   "inputs": "### Schema and the Natural Language Query:\nCREATE TABLE stadium (\n    stadium_id number,\n    location text,\n    name text,\n    capacity number,\n    highest number,\n    lowest number,\n    average number\n)\n\nCREATE TABLE singer (\n    singer_id number,\n    name text,\n    country text,\n    song_name text,\n    song_release_year text,\n    age number,\n    is_male others\n)\n\nCREATE TABLE concert (\n    concert_id number,\n    concert_name text,\n    theme text,\n    stadium_id text,\n    year text\n)\n\nCREATE TABLE singer_in_concert (\n    concert_id number,\n    singer_id text\n)\n\n-- Using valid SQLite, answer the following questions for the tables provided above.\n\n-- What is the maximum, the average, and the minimum capacity of stadiums ?",
+   "parameters": {
+     "maxNewTokens": 512,
+     "topP": 0.9,
+     "temperature": 0.2,
+     "decoderInputDetails": true,
+     "details": true
+   }
+ }
+ ```
+
+ ```json
+ {
+   "body": [
+     {
+       "generated_text": "\n\n\n### Response:\nSELECT MAX(capacity), AVG(capacity), MIN(capacity) FROM stadium",
+       "details": {
+         "finish_reason": "eos_token",
+         "generated_tokens": 30,
+         "seed": 14524408611356330000,
+         "prefill": [],
+         "tokens": [
+           { "id": 13, "text": "\n", "logprob": 0, "special": false },
+           { "id": 13, "text": "\n", "logprob": 0, "special": false },
+           { "id": 13, "text": "\n", "logprob": 0, "special": false },
+           { "id": 2277, "text": "##", "logprob": 0, "special": false },
+           { "id": 29937, "text": "#", "logprob": 0, "special": false },
+           { "id": 13291, "text": " Response", "logprob": 0, "special": false },
+           { "id": 29901, "text": ":", "logprob": 0, "special": false },
+           { "id": 13, "text": "\n", "logprob": 0, "special": false },
+           { "id": 6404, "text": "SELECT", "logprob": 0, "special": false },
+           { "id": 18134, "text": " MAX", "logprob": 0, "special": false },
+           { "id": 29898, "text": "(", "logprob": 0, "special": false },
+           { "id": 5030, "text": "cap", "logprob": 0, "special": false },
+           { "id": 5946, "text": "acity", "logprob": 0, "special": false },
+           { "id": 511, "text": "),", "logprob": 0, "special": false },
+           { "id": 16884, "text": " AV", "logprob": 0, "special": false },
+           { "id": 29954, "text": "G", "logprob": 0, "special": false },
+           { "id": 29898, "text": "(", "logprob": 0, "special": false },
+           { "id": 5030, "text": "cap", "logprob": 0, "special": false },
+           { "id": 5946, "text": "acity", "logprob": 0, "special": false },
+           { "id": 511, "text": "),", "logprob": 0, "special": false },
+           { "id": 341, "text": " M", "logprob": 0, "special": false },
+           { "id": 1177, "text": "IN", "logprob": 0, "special": false },
+           { "id": 29898, "text": "(", "logprob": 0, "special": false },
+           { "id": 5030, "text": "cap", "logprob": 0, "special": false },
+           { "id": 5946, "text": "acity", "logprob": 0, "special": false },
+           { "id": 29897, "text": ")", "logprob": 0, "special": false },
+           { "id": 3895, "text": " FROM", "logprob": 0, "special": false },
+           { "id": 10728, "text": " stad", "logprob": 0, "special": false },
+           { "id": 1974, "text": "ium", "logprob": 0, "special": false },
+           { "id": 2, "text": "</s>", "logprob": 0, "special": true }
+         ]
+       }
+     }
+   ],
+   "contentType": "application/json",
+   "invokedProductionVariant": "AllTraffic"
+ }
+ ```
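To send the request body above to a deployed endpoint, a boto3 sketch like the following can work. This is an illustration, not part of the repo: the endpoint name and region are placeholders, and the camelCase parameter names simply mirror the sample payload (they may differ for your serving container):

```python
import json

def build_request(prompt, max_new_tokens=512, top_p=0.9, temperature=0.2):
    # Assemble the request body in the shape shown above; parameter
    # names follow the sample payload and may vary by serving container.
    return {
        "inputs": prompt,
        "parameters": {
            "maxNewTokens": max_new_tokens,
            "topP": top_p,
            "temperature": temperature,
            "decoderInputDetails": True,
            "details": True,
        },
    }

def invoke(endpoint_name, prompt, region_name="us-east-1"):
    # Requires AWS credentials and a live endpoint; sends the JSON body
    # to the SageMaker runtime and decodes the JSON response.
    import boto3
    client = boto3.client("sagemaker-runtime", region_name=region_name)
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,  # placeholder: your endpoint name
        ContentType="application/json",
        Body=json.dumps(build_request(prompt)),
    )
    return json.loads(response["Body"].read())
```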
inference.py ADDED
@@ -0,0 +1,54 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ model_name = 'unSQLv1-7b-generic-lora'
+ device = 'cuda'
+
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ example_prompt = '''
+ You are a highly skilled SQL query generator that generates queries for 24 different databases. Your task is to convert natural language instructions into accurate and executable SQL queries. \nTo ensure precise translation, please follow these guidelines:\n\n1. Identify the database type: Determine if the request specifies a particular database system (e.g., MySQL, PostgreSQL, SQLite, etc.). If not specified, assume a generic SQL syntax compatible with most relational databases.\n2. Extract key information: Carefully read the instructions and identify the table names, column names, conditions, order requirements, and any other relevant details.\n3. Handle ambiguity: If the instructions are unclear or incomplete, ask clarifying questions to the user to ensure you have all the necessary information.\n4. Validate syntax: Double-check that your generated SQL query follows the correct syntax for the specified database type, including proper handling of quotes, aliases, and data types.\n5. Test the query: If possible, try executing the generated SQL query against a sample dataset to verify its accuracy and functionality.\n6. Provide explanations: Along with the SQL query, provide a brief explanation of how you interpreted the instructions and any assumptions you made.\n7. Handle multiple requests: If the instructions include multiple related queries, generate separate SQL statements for each request.\n8. Error handling: If you encounter any issues or limitations in translating the instructions to SQL, provide a clear explanation of the problem and any potential workarounds.\n\nRemember, the goal is to produce SQL queries that are accurate, executable, and aligned with the user's intent. Follow best practices for writing efficient and secure SQL code.
+
+ ### Schema and the Natural Language Query:
+ CREATE TABLE stadium (
+     stadium_id number,
+     location text,
+     name text,
+     capacity number,
+     highest number,
+     lowest number,
+     average number
+ )
+
+ CREATE TABLE singer (
+     singer_id number,
+     name text,
+     country text,
+     song_name text,
+     song_release_year text,
+     age number,
+     is_male others
+ )
+
+ CREATE TABLE concert (
+     concert_id number,
+     concert_name text,
+     theme text,
+     stadium_id text,
+     year text
+ )
+
+ CREATE TABLE singer_in_concert (
+     concert_id number,
+     singer_id text
+ )
+
+ -- Using valid SQLite, answer the following questions for the tables provided above.
+
+ -- What is the maximum, the average, and the minimum capacity of stadiums ?
+ '''
+
+ inputs = tokenizer.encode(example_prompt, return_tensors="pt").to(device)
+ outputs = model.generate(inputs, max_length=512)
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+