| import json |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
|
|
| |
# Minimal relational schema the model is prompted with: table name -> column list.
db_schema = dict(
    products=["product_id", "name", "price", "description", "type"],
    orders=["order_id", "product_id", "quantity", "order_date"],
    customers=["customer_id", "name", "email", "phone_number"],
)
|
|
| |
# Load GPT-NeoX-20B in half precision to roughly halve the weight memory.
# NOTE(review): device_map="auto" shards the weights across available devices
# and requires the `accelerate` package to be installed — confirm it is a
# declared dependency of this project.
model_name = "EleutherAI/gpt-neox-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
|
|
def generate_sql_query(context, question):
    """
    Generate an SQL query based on the question and context.

    Args:
        context (str): Description of the database schema or table relationships.
        question (str): User's natural language query.

    Returns:
        str: Generated SQL query (text after the final "Query:" marker in the
        model's completion). NOTE: this is model-generated SQL — review it
        before executing against a live database.
    """
    prompt = f"""
Context: {context}

Question: {question}

Write an SQL query to address the question based on the context.
Query:
"""

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Truncate overly long schemas so the prompt stays within 1024 tokens.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)

    # Bug fix: the original used max_length=512, which budgets prompt +
    # completion together — a prompt near the 1024-token cap left no room to
    # generate anything. max_new_tokens bounds only the generated tokens.
    # Also forward the attention_mask (previously dropped) and set an explicit
    # pad_token_id, since GPT-NeoX defines no pad token by default.
    output = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=256,
        num_beams=5,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    query = tokenizer.decode(output[0], skip_special_tokens=True)

    # The decoded text includes the prompt itself; keep only what follows the
    # last "Query:" marker.
    sql_query = query.split("Query:")[-1].strip()
    return sql_query
|
|
| |
# Demo run: serialize the schema as the prompt context and answer one question.
schema_json = json.dumps(db_schema, indent=4)
question = "Show all products that cost more than $50"

generated = generate_sql_query(schema_json, question)
print(f"Generated SQL Query:\n{generated}\n")
|
|