import logging

import gradio as gr
import sqlparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
# Hugging Face Hub identifier of the text-to-SQL checkpoint used by this app.
model_name = "defog/llama-3-sqlcoder-8b"
# Tokenizer is loaded once at import time and shared by every request.
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
| |
def get_model():
    """Load the SQLCoder causal LM, picking precision from GPU memory.

    CUDA device 0 is probed for its ``total_memory``. When it reports more
    than ``20e9`` the checkpoint is loaded in float16; otherwise -- including
    when the probe itself fails -- it is loaded with ``load_in_4bit=True``.

    Returns:
        The model object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    try:
        available_memory = torch.cuda.get_device_properties(0).total_memory
    except Exception:
        # Best-effort probe: any failure is treated as "no GPU memory" so the
        # 4-bit path is taken. ``Exception`` rather than a bare ``except`` so
        # KeyboardInterrupt / SystemExit still propagate.
        available_memory = 0

    # Options common to both loading modes.
    # NOTE(review): trust_remote_code=True lets code shipped in the model repo
    # run locally -- confirm this checkpoint actually requires it.
    load_kwargs = {
        "trust_remote_code": True,
        "device_map": "auto",
        "use_cache": True,
    }
    if available_memory > 20e9:
        load_kwargs["torch_dtype"] = torch.float16
    else:
        load_kwargs["load_in_4bit"] = True

    return AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
|
|
# Load the model once at import time; generate_query() reuses this instance.
model = get_model()
|
|
# Chat-formatted prompt template. ``{question}`` appears twice and is filled in
# with ``str.format``; the template ends with an opening ```sql fence so the
# model's continuation is the SQL query itself.
prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:

CREATE TABLE expenses (
id INTEGER PRIMARY KEY,
date DATE NOT NULL,
amount DECIMAL(10,2) NOT NULL,
category VARCHAR(50) NOT NULL,
description TEXT,
payment_method VARCHAR(20),
user_id INTEGER
);

CREATE TABLE categories (
id INTEGER PRIMARY KEY,
name VARCHAR(50) UNIQUE NOT NULL,
description TEXT
);

CREATE TABLE users (
id INTEGER PRIMARY KEY,
username VARCHAR(50) UNIQUE NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE budgets (
id INTEGER PRIMARY KEY,
user_id INTEGER,
category VARCHAR(50),
amount DECIMAL(10,2) NOT NULL,
period VARCHAR(20) DEFAULT 'monthly',
start_date DATE,
end_date DATE
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""
|
|
def generate_query(question):
    """Generate a formatted SQL query that answers ``question``.

    The module-level ``prompt`` template is filled with the question, the
    model is run with greedy decoding (``do_sample=False``, ``num_beams=1``),
    and the text the model wrote after the prompt's opening ```sql fence is
    extracted and re-indented with ``sqlparse``.

    Args:
        question: Natural-language question inserted into the prompt.

    Returns:
        The generated SQL formatted by ``sqlparse.format(..., reindent=True)``.
        If extraction or formatting raises, a string starting with
        ``"SQL could not be parsed. Raw Output:"`` followed by the full decoded
        model output.
    """
    formatted_prompt = prompt.format(question=question)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
        temperature=0.0,
        top_p=1,
    )

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them, so a "```sql" fence typed inside the user's question can
    # never be mistaken for the start of the model's answer.
    prompt_length = inputs["input_ids"].shape[1]
    completion = tokenizer.decode(
        generated_ids[0][prompt_length:], skip_special_tokens=True
    )

    try:
        # The model normally closes the fence; keep only the text before it.
        sql_code = completion.split("```")[0].strip()
        return sqlparse.format(sql_code, reindent=True)
    except Exception:
        # Keep the UI responsive, but record why formatting failed instead of
        # swallowing the error silently.
        logging.getLogger(__name__).exception("Could not format generated SQL")
        output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return "SQL could not be parsed. Raw Output:\n\n" + output
|
|
| |
# Gradio UI: one multi-line textbox in, plain text out, backed by generate_query().
iface = gr.Interface(
    fn=generate_query,
    inputs=gr.Textbox(lines=3, placeholder="Enter your natural language question..."),
    outputs="text",
    title="LLaMA 3 SQLCoder 🦙",
    description="Enter a natural language question and get a SQL query based on predefined tables.",
)
|
|
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    iface.launch()
|
|