# File-viewer residue: Spaces: Running | File size: 9,946 Bytes | a39d8ef
"""
nl2sql-bench/server/db/seed.py
==============================
Deterministic synthetic data generator for the NL2SQL-Bench SQLite database.
Uses a fixed random seed so every fresh environment build produces
IDENTICAL data, which is essential for reproducible grader scores across
different machines, runs, and Docker containers.
Call: seed_database(conn) once after creating tables.
"""
from __future__ import annotations
import random
import sqlite3
from datetime import date, timedelta
from typing import List
# ── Deterministic seed ─────────────────────────────────────────────────────
SEED = 42
# Single shared generator: every random draw in this module goes through RNG,
# so the generated data depends only on SEED and the order of the draws.
RNG = random.Random(SEED)

# ── Domain constants ───────────────────────────────────────────────────────
# Category ids are assigned from the position in this list (1-based).
CATEGORIES = [
    "Electronics", "Clothing", "Books", "Home & Garden",
    "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive",
]

# Product names per category; keys are declared in the same order as
# CATEGORIES, eight products each.
PRODUCT_NAMES = {
    "Electronics": [
        "Wireless Headphones", "USB-C Hub", "Mechanical Keyboard",
        "Webcam 4K", "Portable Charger", "Smart Speaker",
        "Monitor Stand", "HDMI Cable 2.1",
    ],
    "Clothing": [
        "Cotton T-Shirt", "Slim Fit Jeans", "Hoodie",
        "Running Shorts", "Winter Jacket", "Polo Shirt",
        "Casual Sneakers", "Wool Socks",
    ],
    "Books": [
        "Clean Code", "Designing Data-Intensive Applications",
        "The Pragmatic Programmer", "System Design Interview",
        "Deep Learning Book", "Python Cookbook",
        "Domain-Driven Design", "Refactoring",
    ],
    "Home & Garden": [
        "Coffee Maker", "Air Purifier", "LED Desk Lamp",
        "Plant Pot Set", "Storage Organiser", "Cutting Board",
        "Vacuum Cleaner", "Electric Kettle",
    ],
    "Sports & Outdoors": [
        "Yoga Mat", "Resistance Bands", "Cycling Gloves",
        "Trekking Poles", "Water Bottle 1L", "Jump Rope",
        "Foam Roller", "Compression Socks",
    ],
    "Toys & Games": [
        "Lego City Set", "Card Game Pack", "Puzzle 1000pc",
        "Remote Control Car", "Building Blocks",
        "Board Game Strategy", "Art Set", "Toy Drone",
    ],
    "Beauty": [
        "Face Serum", "SPF 50 Sunscreen", "Lip Balm",
        "Shampoo Pro", "Hair Mask", "Eye Cream",
        "Vitamin C Cream", "Toner Mist",
    ],
    "Automotive": [
        "Car Phone Mount", "Dash Cam", "Tyre Inflator",
        "Car Vacuum", "Seat Cushion", "Steering Wheel Cover",
        "OBD Scanner", "Jump Starter",
    ],
}

COUNTRIES = [
    "India", "USA", "Germany", "UK", "Canada",
    "Australia", "France", "Brazil", "Japan", "Singapore",
]

TIERS = ["bronze", "silver", "gold"]

STATUSES = ["pending", "processing", "shipped", "delivered", "cancelled"]

FIRST_NAMES = [
    "Aarav", "Priya", "Rahul", "Neha", "Arjun", "Sneha", "Vikram", "Pooja",
    "Karthik", "Divya", "James", "Sarah", "Michael", "Emily", "David", "Jessica",
    "Hans", "Lena", "Oliver", "Sofia", "Pierre", "Amelie", "Carlos", "Laura",
    "Yuki", "Hana", "Wei", "Mei", "Aiden", "Zara",
]

LAST_NAMES = [
    "Sharma", "Singh", "Patel", "Kumar", "Gupta", "Verma", "Nair", "Reddy",
    "Smith", "Johnson", "Brown", "Williams", "Jones", "Davis", "Wilson",
    # "Müller" had been mis-decoded into mojibake; restored to proper UTF-8.
    "Müller", "Schmidt", "Schneider", "Fischer", "Weber",
    "Martin", "Bernard", "Thomas", "Richard", "Petit",
    "Garcia", "Martinez", "Lopez", "Sanchez", "Gonzalez",
]
def _random_date(start_year: int = 2022, end_year: int = 2025) -> str:
    """Return a random date as an ISO ``YYYY-MM-DD`` string.

    The date is drawn from the module-level ``RNG`` and lies between
    January 1 of ``start_year`` and December 31 of ``end_year``, both
    inclusive.
    """
    first_day = date(start_year, 1, 1)
    last_day = date(end_year, 12, 31)
    span_days = (last_day - first_day).days
    # One randint draw per call; the bounds are inclusive on both ends.
    offset = RNG.randint(0, span_days)
    return date.fromordinal(first_day.toordinal() + offset).isoformat()
def seed_database(conn: sqlite3.Connection) -> None:
    """Populate the database with deterministic synthetic data.

    Inserts categories, products, customers, orders with their items, and
    reviews (in that order), then commits.  The shared module-level ``RNG``
    is reset to ``SEED`` on entry, so every call generates the same rows
    regardless of how often this function or ``_random_date`` has already
    run in the current process.

    All inserts use ``INSERT OR IGNORE``: a row whose id already exists is
    left untouched.  Order totals are written with a plain ``UPDATE``.

    Args:
        conn: Open SQLite connection on which the ``categories``,
            ``products``, ``customers``, ``orders``, ``order_items`` and
            ``reviews`` tables have already been created.
    """
    # Rewind the shared generator; without this a second call in the same
    # process would continue the random stream and produce different data.
    RNG.seed(SEED)

    conn.execute("PRAGMA foreign_keys = ON")
    cur = conn.cursor()
    n_customers = 150

    # The helper order is significant: it fixes the order of RNG draws.
    _seed_categories(cur)
    n_products = _seed_products(cur)
    _seed_customers(cur, n_customers)
    _seed_orders(cur, n_customers, n_products)
    _seed_reviews(cur, n_customers, n_products)

    conn.commit()


def _seed_categories(cur: sqlite3.Cursor) -> None:
    """Insert one row per entry of ``CATEGORIES`` with 1-based ids."""
    for i, name in enumerate(CATEGORIES, 1):
        cur.execute(
            "INSERT OR IGNORE INTO categories(id, name) VALUES (?, ?)",
            (i, name),
        )


def _seed_products(cur: sqlite3.Cursor) -> int:
    """Insert every product in ``PRODUCT_NAMES``; return how many were generated."""
    pid = 1
    # The category id is the 1-based position of the key in PRODUCT_NAMES,
    # which relies on that dict being declared in the same order as CATEGORIES.
    for cat_id, names in enumerate(PRODUCT_NAMES.values(), 1):
        for pname in names:
            price = round(RNG.uniform(5.0, 250.0), 2)
            stock = RNG.randint(0, 500)
            cur.execute(
                "INSERT OR IGNORE INTO products(id, name, category_id, price, stock_quantity) "
                "VALUES (?, ?, ?, ?, ?)",
                (pid, pname, cat_id, price, stock),
            )
            pid += 1
    return pid - 1


def _seed_customers(cur: sqlite3.Cursor, n_customers: int) -> None:
    """Insert customers with ids ``1..n_customers`` and unique emails."""
    used_emails: set = set()
    for cid in range(1, n_customers + 1):
        fname = RNG.choice(FIRST_NAMES)
        lname = RNG.choice(LAST_NAMES)
        name = f"{fname} {lname}"
        email_base = f"{fname.lower()}.{lname.lower()}"
        email = f"{email_base}{cid}@example.com"
        # Grow the suffix on every collision so the loop always terminates.
        suffix = ""
        while email in used_emails:
            suffix += "x"
            email = f"{email_base}{cid}{suffix}@example.com"
        used_emails.add(email)
        # Bias: 60% bronze, 30% silver, 10% gold
        tier = RNG.choices(TIERS, weights=[60, 30, 10])[0]
        country = RNG.choice(COUNTRIES)
        created = _random_date(2021, 2023)
        cur.execute(
            "INSERT OR IGNORE INTO customers(id, name, email, country, tier, created_at) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (cid, name, email, country, tier, created),
        )


def _seed_orders(cur: sqlite3.Cursor, n_customers: int, n_products: int) -> None:
    """Insert 0-8 orders per customer, each with 1-4 distinct products."""
    # Weights for "number of orders" (index 0..8); higher tiers skew larger.
    weights_by_tier = {
        "bronze": [5, 20, 20, 15, 15, 10, 8, 5, 2],
        "silver": [2, 10, 15, 20, 20, 15, 10, 5, 3],
    }
    # Used for "gold" and for any tier value other than bronze/silver.
    fallback_weights = [1, 5, 10, 15, 20, 20, 15, 10, 4]

    oid = 1
    item_id = 1
    for cid in range(1, n_customers + 1):
        # Read the tier back from the table so that a customer row that
        # already existed (and was skipped by INSERT OR IGNORE) is honoured.
        tier_row = cur.execute(
            "SELECT tier FROM customers WHERE id=?", (cid,)
        ).fetchone()
        tier = tier_row[0] if tier_row else "bronze"
        n_orders = RNG.choices(
            range(9),
            weights=weights_by_tier.get(tier, fallback_weights),
        )[0]
        for _ in range(n_orders):
            status = RNG.choices(STATUSES, weights=[5, 10, 15, 60, 10])[0]
            order_date = _random_date(2022, 2025)
            # Pick 1-4 distinct products for this order
            n_items = RNG.randint(1, 4)
            chosen_pids = RNG.sample(
                range(1, n_products + 1), k=min(n_items, n_products)
            )
            # Insert the order first (items reference it); the real total is
            # written once all items are known.
            cur.execute(
                "INSERT OR IGNORE INTO orders(id, customer_id, status, created_at, total_amount) "
                "VALUES (?, ?, ?, ?, ?)",
                (oid, cid, status, order_date, 0.0),
            )
            total = 0.0
            for cpid in chosen_pids:
                qty = RNG.randint(1, 5)
                price_row = cur.execute(
                    "SELECT price FROM products WHERE id=?", (cpid,)
                ).fetchone()
                # Falls back to 10.0 when the product row is missing.
                unit_price = price_row[0] if price_row else 10.0
                total += round(qty * unit_price, 2)
                cur.execute(
                    "INSERT OR IGNORE INTO order_items(id, order_id, product_id, quantity, unit_price) "
                    "VALUES (?, ?, ?, ?, ?)",
                    (item_id, oid, cpid, qty, unit_price),
                )
                item_id += 1
            cur.execute(
                "UPDATE orders SET total_amount=? WHERE id=?",
                (round(total, 2), oid),
            )
            oid += 1


def _seed_reviews(cur: sqlite3.Cursor, n_customers: int, n_products: int) -> None:
    """Insert up to 6 reviews per customer, at most one per (customer, product)."""
    if n_products < 1:
        return  # nothing to review; also avoids randint(1, 0)
    rev_id = 1
    reviewed: set = set()  # (customer_id, product_id) pairs
    for cid in range(1, n_customers + 1):
        n_reviews = RNG.randint(0, 6)
        for _ in range(n_reviews):
            # Products are picked at random, not from the customer's orders.
            rpid = RNG.randint(1, n_products)
            if (cid, rpid) in reviewed:
                continue
            reviewed.add((cid, rpid))
            rating = RNG.choices([1, 2, 3, 4, 5], weights=[5, 10, 15, 35, 35])[0]
            rev_date = _random_date(2022, 2025)
            cur.execute(
                "INSERT OR IGNORE INTO reviews(id, product_id, customer_id, rating, created_at) "
                "VALUES (?, ?, ?, ?, ?)",
                (rev_id, rpid, cid, rating, rev_date),
            )
            rev_id += 1
def get_db_summary(conn: sqlite3.Connection) -> dict:
    """Map each benchmark table name to its current row count.

    Intended for debugging and README statistics.  The tables are queried
    in a fixed order: categories, products, customers, orders, order_items,
    reviews.
    """
    table_names = (
        "categories",
        "products",
        "customers",
        "orders",
        "order_items",
        "reviews",
    )

    def _count_rows(table: str) -> int:
        # Table names come from the fixed tuple above, never from user input.
        hit = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        return hit[0] if hit else 0

    return {table: _count_rows(table) for table in table_names}
if __name__ == "__main__":
    # Smoke test: build the schema in an in-memory database, seed it, and
    # print the per-table row counts.
    import os

    schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
    conn = sqlite3.connect(":memory:")
    try:
        conn.row_factory = sqlite3.Row
        # Explicit encoding so the read does not depend on the platform's
        # locale default.  NOTE(review): assumes schema.sql is UTF-8 — confirm.
        with open(schema_path, encoding="utf-8") as f:
            conn.executescript(f.read())
        seed_database(conn)
        print("Seed stats:", get_db_summary(conn))
    finally:
        # Close the connection even if loading the schema or seeding fails.
        conn.close()