# Dataset: b-mc2/sql-create-context (Hub listing: 78.6k · 3.77k · 497)
# How to use pavan-naik/Llama-3.2-1B-Instruct-Text-to-SQL with Transformers:
# Load model directly
from transformers import AutoModel

# NOTE(review): AutoModel instantiates the bare base-model class (no language-
# modeling head), so this object cannot be used with .generate(). The script
# below loads the base model with AutoModelForCausalLM and attaches the adapter
# with PEFT instead; it then rebinds `model`, so this load is only needed if
# you want the standalone checkpoint.
model = AutoModel.from_pretrained("pavan-naik/Llama-3.2-1B-Instruct-Text-to-SQL", dtype="auto")

# Base model: meta-llama/Llama-3.2-1B-Instruct
# Adapter:    pavan-naik/Llama-3.2-1B-Instruct-Text-to-SQL
#
# Install the dependencies used below:
#   pip install peft transformers bitsandbytes
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
# Optional 4-bit quantization settings for the base model. They only take
# effect if passed as `quantization_config=` when the base model is loaded.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4 bits
    bnb_4bit_quant_type="nf4",             # 4-bit NormalFloat data type
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for the actual computation
)
# Load base model. It loads unquantized by default; uncomment the
# quantization_config line to load it with the 4-bit settings in `bnb_config`.
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct",
    # quantization_config=bnb_config,  # uncomment to use the quantized version
    device_map="auto",
)

# Load tokenizer from the fine-tuned repo. trust_remote_code is deliberately not
# enabled: it allows code shipped in the Hub repo to run locally. The tokenizer is
# presumably a standard Llama tokenizer, so it should not be needed. Re-enable it
# only if the repo ships custom tokenizer code that you have reviewed.
tokenizer = AutoTokenizer.from_pretrained(
    "pavan-naik/Llama-3.2-1B-Instruct-Text-to-SQL",
)
# Reuse the EOS token for padding (presumably because the tokenizer defines no
# pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

# Attach the fine-tuned text-to-SQL PEFT adapter on top of the base model.
model = PeftModel.from_pretrained(base_model, "pavan-naik/Llama-3.2-1B-Instruct-Text-to-SQL")
# Prompt format for the fine-tuned model. generate_sql() fills the {question} and
# {context} (database schema) placeholders with str.format, and then decodes only the
# tokens generated after this prompt, i.e. after "### Completion: ".
# Because str.format is used, any literal brace added here must be doubled ({{ }}).
sql_prompt_template = """You are a database management system expert, proficient in Structured Query Language (SQL).
Your job is to write an SQL query that answers the following question, based on the given database schema and any additional information provided. Use SQLite syntax.
Please output only SQL (without any explanations).
### Question: {question}
### Schema: {context}
### Completion: """
def generate_sql(
    question,
    context,
    model,
    tokenizer,
    max_length=128,
    *,
    max_prompt_length=None,
    temperature=0.7,
    do_sample=True,
    prompt_template=None,
):
    """Generate an SQL query that answers *question* against the schema *context*.

    The question and schema are substituted into the prompt template and
    tokenized; the model then generates a continuation. Only the newly
    generated tokens are decoded, so the prompt is not part of the result.

    Args:
        question: Natural-language question to answer.
        context: Database schema text (e.g. ``CREATE TABLE`` statements).
        model: Model exposing ``.device`` and ``.generate`` (e.g. a
            ``PeftModel`` wrapping a causal LM).
        tokenizer: Tokenizer matching *model*; its ``pad_token_id`` is
            forwarded to ``generate``.
        max_length: Maximum number of *new* tokens to generate. It is not
            applied to the prompt.
        max_prompt_length: Optional cap on the number of prompt tokens.
            ``None`` (the default) keeps the whole prompt, so the schema and
            the trailing "### Completion:" marker are never cut off.
        temperature: Sampling temperature. It is only forwarded when
            *do_sample* is true.
        do_sample: Sample when true; decode greedily when false.
        prompt_template: Format string with ``{question}`` and ``{context}``
            placeholders. Defaults to the module-level ``sql_prompt_template``.

    Returns:
        The generated SQL text with surrounding whitespace stripped.
    """
    template = sql_prompt_template if prompt_template is None else prompt_template
    prompt = template.format(question=question, context=context)

    # The generation budget must not double as a prompt limit. Truncating the
    # prompt to it silently drops the end of the schema and the completion marker.
    if max_prompt_length is None:
        inputs = tokenizer(prompt, return_tensors="pt")
    else:
        inputs = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=max_prompt_length
        )
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
    prompt_length = inputs["input_ids"].shape[-1]

    gen_kwargs = {
        "max_new_tokens": max_length,
        "num_return_sequences": 1,
        "do_sample": do_sample,
        # Pass padding explicitly rather than relying on generate()'s fallback.
        "pad_token_id": tokenizer.pad_token_id,
    }
    if do_sample:
        # temperature only affects sampling; forwarding it during greedy
        # decoding would just trigger a warning.
        gen_kwargs["temperature"] = temperature

    outputs = model.generate(**inputs, **gen_kwargs)
    # The sequence returned for a decoder-only model starts with the prompt, so skip it.
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
# Example inputs: a natural-language question plus the schema (DDL) it refers to.
question = (
    "For each continent, show the city with the highest population"
    " and what percentage of its country's total population it represents"
)
context = "\n".join(
    [
        "",
        "CREATE TABLE city (city_id INTEGER, name VARCHAR, population INTEGER, country_id INTEGER);",
        "CREATE TABLE country (country_id INTEGER, name VARCHAR, continent VARCHAR)",
        "",
    ]
)

# Ask the fine-tuned model for the query and show it.
sql_query = generate_sql(question, context, model=model, tokenizer=tokenizer)
print(sql_query)
# Tip: adjust the max_length parameter based on your query complexity.
#
# Base model: meta-llama/Llama-3.2-1B-Instruct