| | --- |
| | license: mit |
| | --- |
| | |
| | # SQLMaster |
| | A GPU with at least 8 GB of VRAM is required. |
| | |
| | ## Colab Example |
| | https://colab.research.google.com/drive/1kMv2nw4gqsQLkLGUUEAI31XOD_7BykDj?usp=sharing |
| | |
| | ## Install Prerequisites |
| | ```bash |
| | # In Colab or Jupyter, prefix the command with "!" to run it from a notebook cell. |
| | pip install peft transformers bitsandbytes accelerate |
| | ``` |
| | |
| | ## Log In with a Hugging Face Token |
| | ```python |
| | # You need a huggingface token that can access llama2 |
| | from huggingface_hub import notebook_login |
| | notebook_login() |
| | ``` |
| | |
| | ## Download Model |
| | ```python |
| | import torch |
| | from peft import PeftModel, PeftConfig |
| | from transformers import AutoModelForCausalLM, AutoTokenizer |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | |
| | peft_model_id = "Danjie/SQLMaster" |
| | config = PeftConfig.from_pretrained(peft_model_id) |
| | tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) |
| | model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, return_dict=True, load_in_8bit=True, device_map='auto') |
| | model.resize_token_embeddings(len(tokenizer) + 1) |
| |
|
| | # Load the Lora model |
| | model = PeftModel.from_pretrained(model, peft_model_id) |
| | ``` |
| | |
| | ## Inference |
| | ```python |
| | def create_sql_query(question: str, context: str) -> str: |
| | input = "Question: " + question + "\nContext:" + context + "\nAnswer" |
| | |
| | # Encode and move tensor into cuda if applicable. |
| | encoded_input = tokenizer(input, return_tensors='pt') |
| | encoded_input = {k: v.to(device) for k, v in encoded_input.items()} |
| | |
| | output = model.generate(**encoded_input, max_new_tokens=256) |
| | response = tokenizer.decode(output[0], skip_special_tokens=True) |
| | response = response[len(input):] |
| | return response |
| | ``` |
| | |
| | ## Example |
| | ```python |
| | create_sql_query("What is the highest age of users with name Danjie", "CREATE TABLE user (age INTEGER, name STRING)") |
| | ``` |