# nl2sql-bench / server/db/seed.py
# Uploaded by ritvik360 using huggingface_hub (commit a39d8ef, verified)
"""
nl2sql-bench/server/db/seed.py
==============================
Deterministic synthetic data generator for the NL2SQL-Bench SQLite database.
Uses a fixed random seed so every fresh environment build produces
IDENTICAL data, which is essential for reproducible grader scores across
different machines, runs, and Docker containers.
Call: seed_database(conn) once after creating tables.
"""
from __future__ import annotations
import random
import sqlite3
from datetime import date, timedelta
from typing import List
# ── Deterministic seed ────────────────────────────────────────────────────
SEED = 42
# Shared generator for all synthetic data; seed_database() re-seeds it so
# every build starts from the same state.
RNG = random.Random(SEED)
# ── Domain constants ──────────────────────────────────────────────────────
# Category ids are assigned 1-based in this order.
CATEGORIES = [
    "Electronics", "Clothing", "Books", "Home & Garden",
    "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive",
]
# Product names per category; keys must match CATEGORIES (same order) because
# product ids are assigned sequentially while walking the categories.
PRODUCT_NAMES = {
    "Electronics":      ["Wireless Headphones", "USB-C Hub", "Mechanical Keyboard",
                         "Webcam 4K", "Portable Charger", "Smart Speaker",
                         "Monitor Stand", "HDMI Cable 2.1"],
    "Clothing":         ["Cotton T-Shirt", "Slim Fit Jeans", "Hoodie",
                         "Running Shorts", "Winter Jacket", "Polo Shirt",
                         "Casual Sneakers", "Wool Socks"],
    "Books":            ["Clean Code", "Designing Data-Intensive Applications",
                         "The Pragmatic Programmer", "System Design Interview",
                         "Deep Learning Book", "Python Cookbook",
                         "Domain-Driven Design", "Refactoring"],
    "Home & Garden":    ["Coffee Maker", "Air Purifier", "LED Desk Lamp",
                         "Plant Pot Set", "Storage Organiser", "Cutting Board",
                         "Vacuum Cleaner", "Electric Kettle"],
    "Sports & Outdoors": ["Yoga Mat", "Resistance Bands", "Cycling Gloves",
                          "Trekking Poles", "Water Bottle 1L", "Jump Rope",
                          "Foam Roller", "Compression Socks"],
    "Toys & Games":     ["Lego City Set", "Card Game Pack", "Puzzle 1000pc",
                         "Remote Control Car", "Building Blocks",
                         "Board Game Strategy", "Art Set", "Toy Drone"],
    "Beauty":           ["Face Serum", "SPF 50 Sunscreen", "Lip Balm",
                         "Shampoo Pro", "Hair Mask", "Eye Cream",
                         "Vitamin C Cream", "Toner Mist"],
    "Automotive":       ["Car Phone Mount", "Dash Cam", "Tyre Inflator",
                         "Car Vacuum", "Seat Cushion", "Steering Wheel Cover",
                         "OBD Scanner", "Jump Starter"],
}
COUNTRIES = ["India", "USA", "Germany", "UK", "Canada",
             "Australia", "France", "Brazil", "Japan", "Singapore"]
TIERS = ["bronze", "silver", "gold"]
STATUSES = ["pending", "processing", "shipped", "delivered", "cancelled"]
FIRST_NAMES = [
    "Aarav", "Priya", "Rahul", "Neha", "Arjun", "Sneha", "Vikram", "Pooja",
    "Karthik", "Divya", "James", "Sarah", "Michael", "Emily", "David", "Jessica",
    "Hans", "Lena", "Oliver", "Sofia", "Pierre", "Amelie", "Carlos", "Laura",
    "Yuki", "Hana", "Wei", "Mei", "Aiden", "Zara",
]
LAST_NAMES = [
    "Sharma", "Singh", "Patel", "Kumar", "Gupta", "Verma", "Nair", "Reddy",
    "Smith", "Johnson", "Brown", "Williams", "Jones", "Davis", "Wilson",
    # "Müller" was previously stored as the mojibake "MΓΌller" (UTF-8 bytes
    # decoded as cp1253); keep this file saved as UTF-8.
    "Müller", "Schmidt", "Schneider", "Fischer", "Weber",
    "Martin", "Bernard", "Thomas", "Richard", "Petit",
    "Garcia", "Martinez", "Lopez", "Sanchez", "Gonzalez",
]
def _random_date(start_year: int = 2022, end_year: int = 2025) -> str:
    """Draw a date from ``RNG`` and return it as an ISO ``YYYY-MM-DD`` string.

    The date is uniform over the closed range 1 January *start_year* to
    31 December *end_year*; exactly one ``RNG.randint`` call is consumed.
    """
    first_day = date(start_year, 1, 1)
    span_days = (date(end_year, 12, 31) - first_day).days
    offset = timedelta(days=RNG.randint(0, span_days))
    return (first_day + offset).isoformat()
# Number of synthetic customers created by seed_database().
_N_CUSTOMERS = 150

# Weights for drawing 0..8 orders per customer, keyed by loyalty tier;
# higher tiers put more weight on larger order counts.
_ORDER_COUNT_WEIGHTS = {
    "bronze": [5, 20, 20, 15, 15, 10, 8, 5, 2],
    "silver": [2, 10, 15, 20, 20, 15, 10, 5, 3],
    "gold": [1, 5, 10, 15, 20, 20, 15, 10, 4],
}


def seed_database(conn: sqlite3.Connection) -> None:
    """Populate the database with deterministic synthetic data.

    ``RNG`` is re-seeded with ``SEED`` at the start of every call. Each call
    therefore produces the same rows, even when several databases are seeded
    in one process. Without the re-seed, a second call would continue from
    the generator state left by the first and yield a different dataset.

    Rows are inserted with explicit ids via ``INSERT OR IGNORE`` in a single
    transaction. It is committed on success and rolled back if any statement
    raises.

    Args:
        conn: Open connection on which the ``categories``, ``products``,
            ``customers``, ``orders``, ``order_items`` and ``reviews`` tables
            already exist.
    """
    RNG.seed(SEED)
    conn.execute("PRAGMA foreign_keys = ON")
    with conn:  # commit on success, roll back the partial load on error
        cur = conn.cursor()
        # The call order matters: it fixes the sequence of RNG draws.
        _seed_categories(cur)
        prices = _seed_products(cur)
        tiers = _seed_customers(cur, _N_CUSTOMERS)
        _seed_orders(cur, tiers, prices)
        _seed_reviews(cur, _N_CUSTOMERS, len(prices))


def _seed_categories(cur: sqlite3.Cursor) -> None:
    """Insert one row per entry of CATEGORIES, with ids starting at 1."""
    cur.executemany(
        "INSERT OR IGNORE INTO categories(id, name) VALUES (?, ?)",
        list(enumerate(CATEGORIES, 1)),
    )


def _seed_products(cur: sqlite3.Cursor) -> dict[int, float]:
    """Insert every product in PRODUCT_NAMES (currently 8 per category -> 64).

    Returns:
        Mapping of product id to the price that was inserted, so that order
        items can be priced without re-querying the table.
    """
    prices: dict[int, float] = {}
    pid = 1
    # Walk CATEGORIES (not PRODUCT_NAMES) so category_id always matches the
    # id assigned by _seed_categories().
    for cat_id, cat_name in enumerate(CATEGORIES, 1):
        for pname in PRODUCT_NAMES[cat_name]:
            price = round(RNG.uniform(5.0, 250.0), 2)
            stock = RNG.randint(0, 500)
            cur.execute(
                "INSERT OR IGNORE INTO products(id, name, category_id, price, stock_quantity) "
                "VALUES (?, ?, ?, ?, ?)",
                (pid, pname, cat_id, price, stock),
            )
            prices[pid] = price
            pid += 1
    return prices


def _seed_customers(cur: sqlite3.Cursor, n_customers: int) -> dict[int, str]:
    """Insert customers 1..n_customers and return a mapping of id to tier."""
    tiers: dict[int, str] = {}
    for cid in range(1, n_customers + 1):
        fname = RNG.choice(FIRST_NAMES)
        lname = RNG.choice(LAST_NAMES)
        # The numeric id suffix already makes every address unique, because
        # names contain no digits. No collision handling is needed.
        email = f"{fname.lower()}.{lname.lower()}{cid}@example.com"
        # Bias: 60% bronze, 30% silver, 10% gold
        tier = RNG.choices(TIERS, weights=[60, 30, 10])[0]
        country = RNG.choice(COUNTRIES)
        created = _random_date(2021, 2023)
        cur.execute(
            "INSERT OR IGNORE INTO customers(id, name, email, country, tier, created_at) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (cid, f"{fname} {lname}", email, country, tier, created),
        )
        tiers[cid] = tier
    return tiers


def _seed_orders(
    cur: sqlite3.Cursor, tiers: dict[int, str], prices: dict[int, float]
) -> None:
    """Insert 0-8 orders per customer, each holding 1-4 distinct products.

    Args:
        cur: Cursor to insert through.
        tiers: Customer id -> tier, as returned by _seed_customers().
        prices: Product id -> unit price, as returned by _seed_products().
    """
    product_ids = range(1, len(prices) + 1)
    oid = 1
    item_id = 1
    for cid, tier in tiers.items():
        n_orders = RNG.choices(range(9), weights=_ORDER_COUNT_WEIGHTS[tier])[0]
        for _ in range(n_orders):
            status = RNG.choices(STATUSES, weights=[5, 10, 15, 60, 10])[0]
            order_date = _random_date(2022, 2025)
            n_items = RNG.randint(1, 4)
            chosen_pids = RNG.sample(product_ids, k=min(n_items, len(product_ids)))
            # Draw quantities and sum the total first, so the order row is
            # written once with its final total_amount (no follow-up UPDATE).
            lines = []
            total = 0.0
            for cpid in chosen_pids:
                qty = RNG.randint(1, 5)
                unit_price = prices[cpid]
                total += round(qty * unit_price, 2)
                lines.append((cpid, qty, unit_price))
            # The order is inserted before its items so their order_id
            # foreign key resolves.
            cur.execute(
                "INSERT OR IGNORE INTO orders(id, customer_id, status, created_at, total_amount) "
                "VALUES (?, ?, ?, ?, ?)",
                (oid, cid, status, order_date, round(total, 2)),
            )
            for cpid, qty, unit_price in lines:
                cur.execute(
                    "INSERT OR IGNORE INTO order_items(id, order_id, product_id, quantity, unit_price) "
                    "VALUES (?, ?, ?, ?, ?)",
                    (item_id, oid, cpid, qty, unit_price),
                )
                item_id += 1
            oid += 1


def _seed_reviews(cur: sqlite3.Cursor, n_customers: int, n_products: int) -> None:
    """Insert up to 6 reviews per customer for randomly drawn products.

    Products are drawn uniformly from the whole catalogue, independent of
    what the customer ordered. A repeated (customer, product) draw is skipped
    rather than redrawn, so a customer may end up with fewer reviews.
    """
    rev_id = 1
    for cid in range(1, n_customers + 1):
        reviewed: set[int] = set()
        for _ in range(RNG.randint(0, 6)):
            rpid = RNG.randint(1, n_products)
            if rpid in reviewed:
                continue
            reviewed.add(rpid)
            rating = RNG.choices([1, 2, 3, 4, 5], weights=[5, 10, 15, 35, 35])[0]
            rev_date = _random_date(2022, 2025)
            cur.execute(
                "INSERT OR IGNORE INTO reviews(id, product_id, customer_id, rating, created_at) "
                "VALUES (?, ?, ?, ?, ?)",
                (rev_id, rpid, cid, rating, rev_date),
            )
            rev_id += 1
def get_db_summary(conn: sqlite3.Connection) -> dict:
    """Count the rows in each seeded table.

    Returns:
        Dict of table name -> row count, in the fixed order categories,
        products, customers, orders, order_items, reviews. Used for debugging
        and README statistics.
    """
    counts = {}
    for table in ("categories", "products", "customers",
                  "orders", "order_items", "reviews"):
        # Table names come from the fixed tuple above, not from user input.
        first = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        counts[table] = first[0] if first else 0
    return counts
if __name__ == "__main__":
    # Smoke run: create the schema in an in-memory database, seed it, and
    # print per-table row counts. The connection is closed even when seeding
    # fails, and the schema is decoded as UTF-8 regardless of platform locale.
    from contextlib import closing
    from pathlib import Path

    schema_path = Path(__file__).with_name("schema.sql")
    with closing(sqlite3.connect(":memory:")) as conn:
        conn.row_factory = sqlite3.Row
        conn.executescript(schema_path.read_text(encoding="utf-8"))
        seed_database(conn)
        print("Seed stats:", get_db_summary(conn))