"""
generate_data.py - Create training data for fine-tuning.

This script generates pairs of:
- a natural language question
- the correct SQL query for that question

We'll create hundreds of examples for the model to learn from.
"""
|
|
import json
import os
import random

from schema import SCHEMA, CITIES, CATEGORIES, STATUSES
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
# Question/SQL templates. Each entry is a dict with a "question" and a "sql"
# string; "{name}" placeholders are substituted later by fill_template().
TEMPLATES = [
    {"question": question, "sql": sql}
    for question, sql in (
        # --- Select every column of a table ---
        ("Show all customers", "SELECT * FROM customers"),
        ("List all products", "SELECT * FROM products"),
        ("Get all orders", "SELECT * FROM orders"),
        ("Display all order items", "SELECT * FROM order_items"),

        # --- Select specific columns ---
        ("Show customer names and emails", "SELECT name, email FROM customers"),
        ("List product names and prices", "SELECT name, price FROM products"),
        ("Get order dates and totals", "SELECT order_date, total FROM orders"),

        # --- Filter customers by city ---
        ("Find customers from {city}", "SELECT * FROM customers WHERE city = '{city}'"),
        ("Show all customers in {city}", "SELECT * FROM customers WHERE city = '{city}'"),
        ("List customers who live in {city}", "SELECT * FROM customers WHERE city = '{city}'"),

        # --- Filter products by category ---
        ("Find products in {category} category", "SELECT * FROM products WHERE category = '{category}'"),
        ("Show all {category} products", "SELECT * FROM products WHERE category = '{category}'"),
        ("List products from {category} category", "SELECT * FROM products WHERE category = '{category}'"),

        # --- Filter orders by status ---
        ("Find orders with status {status}", "SELECT * FROM orders WHERE status = '{status}'"),
        ("Show all {status} orders", "SELECT * FROM orders WHERE status = '{status}'"),
        ("Get orders that are {status}", "SELECT * FROM orders WHERE status = '{status}'"),

        # --- Numeric comparisons ---
        ("Find products cheaper than {price} dollars", "SELECT * FROM products WHERE price < {price}"),
        ("Show products with price above {price}", "SELECT * FROM products WHERE price > {price}"),
        ("List products under ${price}", "SELECT * FROM products WHERE price < {price}"),
        ("Find products costing more than {price}", "SELECT * FROM products WHERE price > {price}"),
        ("Show orders with total greater than {price}", "SELECT * FROM orders WHERE total > {price}"),
        ("Find orders under ${price}", "SELECT * FROM orders WHERE total < {price}"),
        ("List products with stock below {quantity}", "SELECT * FROM products WHERE stock < {quantity}"),
        ("Show products with more than {quantity} in stock", "SELECT * FROM products WHERE stock > {quantity}"),

        # --- Row counts ---
        ("How many customers are there?", "SELECT COUNT(*) FROM customers"),
        ("Count all products", "SELECT COUNT(*) FROM products"),
        ("How many orders do we have?", "SELECT COUNT(*) FROM orders"),
        ("Count customers from {city}", "SELECT COUNT(*) FROM customers WHERE city = '{city}'"),
        ("How many products are in {category} category?", "SELECT COUNT(*) FROM products WHERE category = '{category}'"),
        ("Count orders with status {status}", "SELECT COUNT(*) FROM orders WHERE status = '{status}'"),

        # --- Sorting ---
        ("Show products ordered by price", "SELECT * FROM products ORDER BY price"),
        ("List products from cheapest to most expensive", "SELECT * FROM products ORDER BY price ASC"),
        ("Show products from most expensive to cheapest", "SELECT * FROM products ORDER BY price DESC"),
        ("List customers alphabetically by name", "SELECT * FROM customers ORDER BY name ASC"),
        ("Show orders by date, newest first", "SELECT * FROM orders ORDER BY order_date DESC"),
        ("List orders by total amount, highest first", "SELECT * FROM orders ORDER BY total DESC"),

        # --- Sorting with LIMIT ---
        ("Show top 5 most expensive products", "SELECT * FROM products ORDER BY price DESC LIMIT 5"),
        ("Get the 10 most recent orders", "SELECT * FROM orders ORDER BY order_date DESC LIMIT 10"),
        ("Show 3 cheapest products", "SELECT * FROM products ORDER BY price ASC LIMIT 3"),
        ("List top 5 highest value orders", "SELECT * FROM orders ORDER BY total DESC LIMIT 5"),

        # --- Aggregate functions ---
        ("What is the average product price?", "SELECT AVG(price) FROM products"),
        ("Find the total value of all orders", "SELECT SUM(total) FROM orders"),
        ("What is the maximum product price?", "SELECT MAX(price) FROM products"),
        ("Find the minimum order total", "SELECT MIN(total) FROM orders"),
        ("What is the average order value?", "SELECT AVG(total) FROM orders"),
        ("Find total stock across all products", "SELECT SUM(stock) FROM products"),

        # --- GROUP BY ---
        ("Count customers by city", "SELECT city, COUNT(*) FROM customers GROUP BY city"),
        ("Show number of products per category", "SELECT category, COUNT(*) FROM products GROUP BY category"),
        ("Count orders by status", "SELECT status, COUNT(*) FROM orders GROUP BY status"),
        ("Find average product price by category", "SELECT category, AVG(price) FROM products GROUP BY category"),
        ("Show total sales by order status", "SELECT status, SUM(total) FROM orders GROUP BY status"),

        # --- Joins ---
        ("Show orders with customer names", "SELECT orders.*, customers.name FROM orders JOIN customers ON orders.customer_id = customers.id"),
        ("List order items with product names", "SELECT order_items.*, products.name FROM order_items JOIN products ON order_items.product_id = products.id"),
        ("Find all orders for customers from {city}", "SELECT orders.* FROM orders JOIN customers ON orders.customer_id = customers.id WHERE customers.city = '{city}'"),

        # --- Date filters ---
        # NOTE(review): YEAR()/MONTH() are not available in every SQL dialect
        # (e.g. SQLite, PostgreSQL) - confirm they match the target database.
        ("Find customers who joined in 2024", "SELECT * FROM customers WHERE YEAR(joined_date) = 2024"),
        ("Show orders from this month", "SELECT * FROM orders WHERE MONTH(order_date) = MONTH(CURRENT_DATE) AND YEAR(order_date) = YEAR(CURRENT_DATE)"),
        ("Find orders placed today", "SELECT * FROM orders WHERE order_date = CURRENT_DATE"),

        # --- Pattern matching with LIKE ---
        ("Find customers whose name starts with J", "SELECT * FROM customers WHERE name LIKE 'J%'"),
        ("Show products containing 'phone' in name", "SELECT * FROM products WHERE name LIKE '%phone%'"),
        ("Find customers with gmail email", "SELECT * FROM customers WHERE email LIKE '%gmail.com'"),

        # --- Ranges with BETWEEN ---
        ("Find products priced between {price1} and {price2} dollars", "SELECT * FROM products WHERE price BETWEEN {price1} AND {price2}"),
        ("Show orders with total between {price1} and {price2}", "SELECT * FROM orders WHERE total BETWEEN {price1} AND {price2}"),

        # --- Membership with IN ---
        ("Find customers from New York or Los Angeles", "SELECT * FROM customers WHERE city IN ('New York', 'Los Angeles')"),
        ("Show products in Electronics or Clothing category", "SELECT * FROM products WHERE category IN ('Electronics', 'Clothing')"),
        ("Find orders that are pending or shipped", "SELECT * FROM orders WHERE status IN ('pending', 'shipped')"),
    )
]
|
|
|
|
| |
| |
| |
|
|
# Candidate values for the numeric placeholders.
_PRICE_CHOICES = (10, 25, 50, 100, 200, 500, 1000)
_QUANTITY_CHOICES = (5, 10, 20, 50, 100)
_RANGE_LOW_CHOICES = (10, 25, 50, 100)
_RANGE_WIDTH_CHOICES = (50, 100, 200, 500)


def _mentions(placeholder, *texts):
    """Return True if *placeholder* occurs in any of *texts*."""
    return any(placeholder in text for text in texts)


def fill_template(template):
    """Fill the placeholders of one template with randomly chosen values.

    Supported placeholders are {city}, {category}, {status}, {price},
    {quantity} and the pair {price1}/{price2}. A placeholder is filled when
    it appears in the question or in the SQL, and the same value is used in
    both strings.

    Args:
        template: Mapping with a "question" string and a "sql" string.

    Returns:
        A ``(question, sql)`` tuple with every supported placeholder replaced.
        Templates without placeholders are returned unchanged.
    """
    question = template["question"]
    sql = template["sql"]

    # (placeholder, lazy value chooser, value sits inside a quoted SQL literal)
    # The choosers are lambdas so CITIES/CATEGORIES/STATUSES are only touched
    # when the template needs them. The order fixes the sequence of random
    # draws, so it must stay stable for seeded runs.
    single_value_rules = (
        ("{city}", lambda: random.choice(CITIES), True),
        ("{category}", lambda: random.choice(CATEGORIES), True),
        ("{status}", lambda: random.choice(STATUSES), True),
        ("{price}", lambda: random.choice(_PRICE_CHOICES), False),
        ("{quantity}", lambda: random.choice(_QUANTITY_CHOICES), False),
    )

    for placeholder, choose, in_sql_literal in single_value_rules:
        if not _mentions(placeholder, question, sql):
            continue
        value = str(choose())
        question = question.replace(placeholder, value)
        # Inside a '...' SQL literal an apostrophe must be doubled, otherwise
        # a value such as "St. John's" would terminate the literal early.
        sql_value = value.replace("'", "''") if in_sql_literal else value
        sql = sql.replace(placeholder, sql_value)

    # {price1}/{price2} are drawn together so the upper bound always exceeds
    # the lower bound.
    if _mentions("{price1}", question, sql) or _mentions("{price2}", question, sql):
        low = random.choice(_RANGE_LOW_CHOICES)
        high = low + random.choice(_RANGE_WIDTH_CHOICES)
        for placeholder, bound in (("{price1}", low), ("{price2}", high)):
            question = question.replace(placeholder, str(bound))
            sql = sql.replace(placeholder, str(bound))

    return question, sql
|
|
|
|
def generate_dataset(num_train=160, num_test=40, rounds=50, seed=None):
    """Generate de-duplicated training and test example lists.

    Every template in TEMPLATES is filled ``rounds`` times, the results are
    shuffled, exact (question, sql) duplicates are dropped, and the remaining
    examples are split into a training slice followed by a test slice, so no
    example appears in both.

    Args:
        num_train: Number of examples wanted for the training set.
        num_test: Number of examples wanted for the test set.
        rounds: How many times each template is filled. More rounds give the
            random placeholders more chances to produce distinct examples.
        seed: Optional seed for the module-level ``random`` generator; pass a
            value to make the generated data reproducible.

    Returns:
        A ``(train_data, test_data)`` tuple of lists of
        ``{"question": ..., "sql": ...}`` dicts. If fewer unique examples
        exist than requested, the lists are shorter than requested and a
        warning is printed.
    """
    if seed is not None:
        random.seed(seed)

    # Fill every template `rounds` times; duplicates are expected here.
    pairs = [fill_template(template) for _ in range(rounds) for template in TEMPLATES]

    random.shuffle(pairs)

    # dict.fromkeys keeps the first occurrence of each (question, sql) pair
    # and preserves the shuffled order.
    unique_examples = [
        {"question": question, "sql": sql}
        for question, sql in dict.fromkeys(pairs)
    ]

    print(f"Generated {len(unique_examples)} unique examples")

    requested = num_train + num_test
    if len(unique_examples) < requested:
        # Slicing below would otherwise shrink the splits without any notice.
        print(
            f"Warning: requested {requested} examples but only "
            f"{len(unique_examples)} unique examples are available; "
            "the resulting splits will be smaller than requested"
        )

    train_data = unique_examples[:num_train]
    test_data = unique_examples[num_train:num_train + num_test]

    return train_data, test_data
|
|
|
|
| |
| |
| |
|
|
def format_for_openai(examples, output_file):
    """
    Convert examples to OpenAI fine-tuning format (JSONL) and write them out.

    Each line of the output file is one conversation:
    {
        "messages": [
            {"role": "system", "content": "You are a SQL expert..."},
            {"role": "user", "content": "Show all customers"},
            {"role": "assistant", "content": "SELECT * FROM customers"}
        ]
    }

    Args:
        examples: Iterable of dicts with "question" and "sql" keys; it must
            also support len() for the summary message.
        output_file: Path of the JSONL file to write. Missing parent
            directories are created; an existing file is overwritten.
    """
    system_prompt = f"""You are a SQL expert. Convert natural language questions to SQL queries.

{SCHEMA}

Rules:
- Return ONLY the SQL query, nothing else
- Do not explain the query
- Use proper SQL syntax
"""

    # Create the target directory (e.g. "data/") up front so a fresh checkout
    # does not fail with FileNotFoundError on open().
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit encoding keeps the output independent of the platform default.
    with open(output_file, "w", encoding="utf-8") as f:
        for ex in examples:
            conversation = {
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": ex["question"]},
                    {"role": "assistant", "content": ex["sql"]},
                ]
            }
            f.write(json.dumps(conversation) + "\n")

    print(f"Saved {len(examples)} examples to {output_file}")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Generate the datasets, write both JSONL files, and print a preview."""
    banner = "=" * 50

    print(banner)
    print("Generating Text-to-SQL Training Data")
    print(banner)

    # Build the two splits.
    train_data, test_data = generate_dataset(num_train=160, num_test=40)

    print(f"\nTraining examples: {len(train_data)}")
    print(f"Test examples: {len(test_data)}")

    # Write each split to its own JSONL file, training set first.
    outputs = (
        (train_data, "data/train.jsonl"),
        (test_data, "data/test.jsonl"),
    )
    for split, path in outputs:
        format_for_openai(split, path)

    # Show the first few training examples as a quick sanity check.
    print("\n" + banner)
    print("Sample Training Examples:")
    print(banner)
    for number, sample in enumerate(train_data[:5], start=1):
        print(f"\n[Example {number}]")
        print(f"Question: {sample['question']}")
        print(f"SQL: {sample['sql']}")

    print("\n" + banner)
    print("Data generation complete!")
    print(banner)


# Run the generator only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|