Commit
·
2cdb77d
1
Parent(s):
0a0064d
Update README.md
Browse files
README.md
CHANGED
|
@@ -17,22 +17,32 @@ tags:
|
|
| 17 |
SQL Generation model which is fine-tuned on the Mistral-7B-Instruct-v0.1.
|
| 18 |
Inspired from https://huggingface.co/kanxxyc/Mistral-7B-SQLTuned
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
SQL Generation model which is fine-tuned on the Mistral-7B-Instruct-v0.1.
|
| 18 |
Inspired from https://huggingface.co/kanxxyc/Mistral-7B-SQLTuned
|
| 19 |
|
| 20 |
+
### Code
|
| 21 |
+
```py
|
| 22 |
+
import torch
|
| 23 |
+
from peft import PeftModel, PeftConfig
|
| 24 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 25 |
+
peft_model_id = "AhmedSSoliman/Mistral-Instruct-SQL-Generation"
|
| 26 |
+
config = PeftConfig.from_pretrained(peft_model_id)
|
| 27 |
+
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, trust_remote_code=True, return_dict=True, load_in_4bit=True, device_map='auto')
|
| 28 |
+
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
|
| 29 |
+
|
| 30 |
+
# Load the Lora model
|
| 31 |
+
model = PeftModel.from_pretrained(model, peft_model_id)
|
| 32 |
+
|
| 33 |
+
def predict_SQL(table, question):
|
| 34 |
+
pipe = pipeline('text-generation', model = base_model, tokenizer = tokenizer)
|
| 35 |
+
prompt = f"[INST] Write SQL query to answer the following question given the database schema. Please wrap your code answer using ```: Schema: {table} Question: {question} [/INST] Here is the SQL query to answer to the question: {question}: ``` "
|
| 36 |
+
#prompt = f"### Schema: {table} ### Question: {question} # "
|
| 37 |
+
ans = pipe(prompt, max_new_tokens=200)
|
| 38 |
+
generatedSql = ans[0]['generated_text'].split('```')[2]
|
| 39 |
+
return generatedSql
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
table = "CREATE TABLE Employee (name VARCHAR, salary INTEGER);"
|
| 43 |
+
question = 'Show names for all employees with salary more than the average.'
|
| 44 |
+
|
| 45 |
+
generatedSql=predict_SQL(table, question)
|
| 46 |
+
print(generatedSql)
|
| 47 |
+
|
| 48 |
+
```
|