Spaces:
Running
Running
| """ | |
| nl2sql-bench/server/db/seed.py | |
| ============================== | |
| Deterministic synthetic data generator for the NL2SQL-Bench SQLite database. | |
| Uses a fixed random seed so every fresh environment build produces | |
| IDENTICAL data, which is essential for reproducible grader scores across | |
| different machines, runs, and Docker containers. | |
| Call: seed_database(conn) once after creating tables. | |
| """ | |
| from __future__ import annotations | |
| import random | |
| import sqlite3 | |
| from datetime import date, timedelta | |
| from typing import List | |
# -- Deterministic seed -------------------------------------------------------
# Fixed seed value: the same number on every build means the same data.
SEED: int = 42
# One shared generator for the whole module; every random draw made while
# seeding goes through it, so the draw order determines the generated rows.
RNG: random.Random = random.Random(SEED)
# -- Domain constants ---------------------------------------------------------
# Category names; a category's database id is its 1-based position here.
CATEGORIES = [
    "Electronics", "Clothing", "Books", "Home & Garden",
    "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive",
]

# Product names per category (8 each, 64 in total). The keys are listed in
# the same order as CATEGORIES.
PRODUCT_NAMES = {
    "Electronics": [
        "Wireless Headphones", "USB-C Hub", "Mechanical Keyboard",
        "Webcam 4K", "Portable Charger", "Smart Speaker",
        "Monitor Stand", "HDMI Cable 2.1",
    ],
    "Clothing": [
        "Cotton T-Shirt", "Slim Fit Jeans", "Hoodie",
        "Running Shorts", "Winter Jacket", "Polo Shirt",
        "Casual Sneakers", "Wool Socks",
    ],
    "Books": [
        "Clean Code", "Designing Data-Intensive Applications",
        "The Pragmatic Programmer", "System Design Interview",
        "Deep Learning Book", "Python Cookbook",
        "Domain-Driven Design", "Refactoring",
    ],
    "Home & Garden": [
        "Coffee Maker", "Air Purifier", "LED Desk Lamp",
        "Plant Pot Set", "Storage Organiser", "Cutting Board",
        "Vacuum Cleaner", "Electric Kettle",
    ],
    "Sports & Outdoors": [
        "Yoga Mat", "Resistance Bands", "Cycling Gloves",
        "Trekking Poles", "Water Bottle 1L", "Jump Rope",
        "Foam Roller", "Compression Socks",
    ],
    "Toys & Games": [
        "Lego City Set", "Card Game Pack", "Puzzle 1000pc",
        "Remote Control Car", "Building Blocks",
        "Board Game Strategy", "Art Set", "Toy Drone",
    ],
    "Beauty": [
        "Face Serum", "SPF 50 Sunscreen", "Lip Balm",
        "Shampoo Pro", "Hair Mask", "Eye Cream",
        "Vitamin C Cream", "Toner Mist",
    ],
    "Automotive": [
        "Car Phone Mount", "Dash Cam", "Tyre Inflator",
        "Car Vacuum", "Seat Cushion", "Steering Wheel Cover",
        "OBD Scanner", "Jump Starter",
    ],
}

COUNTRIES = [
    "India", "USA", "Germany", "UK", "Canada",
    "Australia", "France", "Brazil", "Japan", "Singapore",
]

# Customer loyalty tiers and order lifecycle states.
TIERS = ["bronze", "silver", "gold"]
STATUSES = ["pending", "processing", "shipped", "delivered", "cancelled"]

# Name pools; customer names are a random first name plus a random last name.
FIRST_NAMES = [
    "Aarav", "Priya", "Rahul", "Neha", "Arjun", "Sneha", "Vikram", "Pooja",
    "Karthik", "Divya", "James", "Sarah", "Michael", "Emily", "David", "Jessica",
    "Hans", "Lena", "Oliver", "Sofia", "Pierre", "Amelie", "Carlos", "Laura",
    "Yuki", "Hana", "Wei", "Mei", "Aiden", "Zara",
]
LAST_NAMES = [
    "Sharma", "Singh", "Patel", "Kumar", "Gupta", "Verma", "Nair", "Reddy",
    "Smith", "Johnson", "Brown", "Williams", "Jones", "Davis", "Wilson",
    # "Müller" had been mis-decoded (UTF-8 read through a Greek code page).
    "Müller", "Schmidt", "Schneider", "Fischer", "Weber",
    "Martin", "Bernard", "Thomas", "Richard", "Petit",
    "Garcia", "Martinez", "Lopez", "Sanchez", "Gonzalez",
]
def _random_date(
    start_year: int = 2022,
    end_year: int = 2025,
    *,
    rng: random.Random | None = None,
) -> str:
    """Return a random calendar date as an ISO-8601 string (``YYYY-MM-DD``).

    The date is drawn uniformly from January 1 of ``start_year`` through
    December 31 of ``end_year``, both inclusive.

    Args:
        start_year: First year of the range.
        end_year: Last year of the range; must not be before ``start_year``.
        rng: Generator to draw from. Defaults to the module-level ``RNG`` so
            existing callers keep consuming the shared deterministic stream.

    Raises:
        ValueError: If ``end_year`` is earlier than ``start_year``.
    """
    if end_year < start_year:
        raise ValueError(
            f"end_year ({end_year}) must not be before start_year ({start_year})"
        )
    generator = RNG if rng is None else rng
    first_day = date(start_year, 1, 1)
    last_day = date(end_year, 12, 31)
    span_days = (last_day - first_day).days
    # randint is inclusive on both ends, so last_day itself can be returned.
    offset = generator.randint(0, span_days)
    return (first_day + timedelta(days=offset)).isoformat()
# Relative likelihood of a customer having 0..8 orders, keyed by tier.
# Higher tiers shift the distribution towards more orders.
_ORDER_COUNT_WEIGHTS = {
    "bronze": [5, 20, 20, 15, 15, 10, 8, 5, 2],
    "silver": [2, 10, 15, 20, 20, 15, 10, 5, 3],
    "gold": [1, 5, 10, 15, 20, 20, 15, 10, 4],
}


def _seed_categories(cur: sqlite3.Cursor) -> None:
    """Insert one row per entry of CATEGORIES, with ids starting at 1."""
    for cat_id, cat_name in enumerate(CATEGORIES, 1):
        cur.execute(
            "INSERT OR IGNORE INTO categories(id, name) VALUES (?, ?)",
            (cat_id, cat_name),
        )


def _seed_products(cur: sqlite3.Cursor) -> int:
    """Insert every product in PRODUCT_NAMES and return how many were generated."""
    pid = 0
    # Walk CATEGORIES (not the dict) so each product's category_id is the id
    # its category was actually inserted with, whatever the dict's key order.
    for cat_id, cat_name in enumerate(CATEGORIES, 1):
        for pname in PRODUCT_NAMES.get(cat_name, ()):
            pid += 1
            price = round(RNG.uniform(5.0, 250.0), 2)
            stock = RNG.randint(0, 500)
            cur.execute(
                "INSERT OR IGNORE INTO products(id, name, category_id, price, stock_quantity) "
                "VALUES (?, ?, ?, ?, ?)",
                (pid, pname, cat_id, price, stock),
            )
    return pid


def _seed_customers(cur: sqlite3.Cursor, n_customers: int) -> None:
    """Insert ``n_customers`` customers with ids 1..n_customers."""
    used_emails: set = set()
    for cid in range(1, n_customers + 1):
        fname = RNG.choice(FIRST_NAMES)
        lname = RNG.choice(LAST_NAMES)
        email_base = f"{fname.lower()}.{lname.lower()}"
        # Grow the suffix on every collision so this loop always terminates
        # (a fixed "x" suffix would spin forever if that address were taken).
        suffix = ""
        email = f"{email_base}{cid}@example.com"
        while email in used_emails:
            suffix += "x"
            email = f"{email_base}{cid}{suffix}@example.com"
        used_emails.add(email)
        # Bias: 60% bronze, 30% silver, 10% gold
        tier = RNG.choices(TIERS, weights=[60, 30, 10])[0]
        country = RNG.choice(COUNTRIES)
        created = _random_date(2021, 2023)
        cur.execute(
            "INSERT OR IGNORE INTO customers(id, name, email, country, tier, created_at) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (cid, f"{fname} {lname}", email, country, tier, created),
        )


def _seed_orders(cur: sqlite3.Cursor, n_customers: int, n_products: int) -> None:
    """Insert 0-8 orders per customer, each with 1-4 distinct products."""
    # Read tiers and prices back from the database once instead of issuing a
    # SELECT per customer / per order item; the values used are the stored ones.
    tiers = {row[0]: row[1] for row in cur.execute("SELECT id, tier FROM customers")}
    prices = {row[0]: row[1] for row in cur.execute("SELECT id, price FROM products")}
    product_ids = range(1, n_products + 1)

    oid = 1
    item_id = 1
    for cid in range(1, n_customers + 1):
        # A missing customer row counts as bronze; any tier other than
        # bronze/silver uses the gold distribution.
        tier = tiers.get(cid, "bronze")
        weights = _ORDER_COUNT_WEIGHTS.get(tier, _ORDER_COUNT_WEIGHTS["gold"])
        n_orders = RNG.choices(range(9), weights=weights)[0]
        for _ in range(n_orders):
            status = RNG.choices(STATUSES, weights=[5, 10, 15, 60, 10])[0]
            order_date = _random_date(2022, 2025)
            n_items = RNG.randint(1, 4)
            chosen_pids = RNG.sample(product_ids, k=min(n_items, n_products))
            # The order row must exist before its items (foreign key), so it
            # is inserted with a placeholder total that is corrected below.
            cur.execute(
                "INSERT OR IGNORE INTO orders(id, customer_id, status, created_at, total_amount) "
                "VALUES (?, ?, ?, ?, ?)",
                (oid, cid, status, order_date, 0.0),
            )
            total = 0.0
            for cpid in chosen_pids:
                qty = RNG.randint(1, 5)
                # 10.0 is the fallback price for a product id with no row.
                unit_price = prices.get(cpid, 10.0)
                total += round(qty * unit_price, 2)
                cur.execute(
                    "INSERT OR IGNORE INTO order_items(id, order_id, product_id, quantity, unit_price) "
                    "VALUES (?, ?, ?, ?, ?)",
                    (item_id, oid, cpid, qty, unit_price),
                )
                item_id += 1
            cur.execute(
                "UPDATE orders SET total_amount=? WHERE id=?",
                (round(total, 2), oid),
            )
            oid += 1


def _seed_reviews(cur: sqlite3.Cursor, n_customers: int, n_products: int) -> None:
    """Insert up to 6 reviews per customer, at most one per (customer, product)."""
    if n_products < 1:
        return  # nothing to review
    rev_id = 1
    reviewed: set = set()  # (customer_id, product_id) pairs already used
    for cid in range(1, n_customers + 1):
        for _ in range(RNG.randint(0, 6)):
            # Products are picked at random, independent of order history.
            rpid = RNG.randint(1, n_products)
            if (cid, rpid) in reviewed:
                continue  # a duplicate draw is dropped, not re-rolled
            reviewed.add((cid, rpid))
            rating = RNG.choices([1, 2, 3, 4, 5], weights=[5, 10, 15, 35, 35])[0]
            rev_date = _random_date(2022, 2025)
            cur.execute(
                "INSERT OR IGNORE INTO reviews(id, product_id, customer_id, rating, created_at) "
                "VALUES (?, ?, ?, ?, ?)",
                (rev_id, rpid, cid, rating, rev_date),
            )
            rev_id += 1


def seed_database(conn: sqlite3.Connection) -> None:
    """Populate the database with deterministic synthetic data.

    Fills, in this order: categories, products, customers, orders with their
    order items, and reviews. The tables must already exist. All rows use
    explicit ids and ``INSERT OR IGNORE``, and the work is committed at the end.

    The shared generator is rewound to ``SEED`` on entry, so every call
    produces the same rows regardless of how many random numbers were drawn
    earlier in the process. The sequence of draws is unchanged from a fresh
    process, so the generated data itself is unchanged.

    Args:
        conn: Open SQLite connection to the database to populate.
    """
    # Without this, a second call in the same process would continue the
    # random stream and generate different data.
    RNG.seed(SEED)

    conn.execute("PRAGMA foreign_keys = ON")
    cur = conn.cursor()
    n_customers = 150

    _seed_categories(cur)
    n_products = _seed_products(cur)
    _seed_customers(cur, n_customers)
    _seed_orders(cur, n_customers, n_products)
    _seed_reviews(cur, n_customers, n_products)

    conn.commit()
def get_db_summary(conn: sqlite3.Connection) -> dict:
    """Count the rows in each seeded table (handy for debugging / README stats).

    Returns a dict mapping table name to row count, in the order categories,
    products, customers, orders, order_items, reviews.
    """

    def _row_count(table_name: str) -> int:
        # Table names come from the fixed tuple below, never from user input.
        result = conn.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()
        return result[0] if result else 0

    table_names = (
        "categories",
        "products",
        "customers",
        "orders",
        "order_items",
        "reviews",
    )
    return {table_name: _row_count(table_name) for table_name in table_names}
if __name__ == "__main__":
    # Smoke test: build the schema in a throwaway in-memory database, seed it,
    # and print the per-table row counts.
    import os

    schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
    # Read with an explicit encoding so the result does not depend on the
    # platform's default locale (assumes schema.sql is UTF-8 -- TODO confirm).
    with open(schema_path, encoding="utf-8") as f:
        schema_sql = f.read()

    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        conn.executescript(schema_sql)
        seed_database(conn)
        print("Seed stats:", get_db_summary(conn))
    finally:
        # Release the connection even if schema loading or seeding fails.
        conn.close()