File size: 9,946 Bytes
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""
nl2sql-bench/server/db/seed.py
==============================
Deterministic synthetic data generator for the NL2SQL-Bench SQLite database.

Uses a fixed random seed so every fresh environment build produces
IDENTICAL data, which is essential for reproducible grader scores across
different machines, runs, and Docker containers.

Call: seed_database(conn) once after creating tables.
"""

from __future__ import annotations

import random
import sqlite3
from datetime import date, timedelta
from typing import List

# ── Deterministic seed ────────────────────────────────────────────────────
# Every random draw in this module goes through this single generator, so the
# generated data is a pure function of SEED and of the order of the draws.
SEED = 42
RNG = random.Random(SEED)

# ── Domain constants ──────────────────────────────────────────────────────
# Category names; their 1-based position in this list is the category id.
CATEGORIES = [
    "Electronics", "Clothing", "Books", "Home & Garden",
    "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive",
]

# Product names per category. Keys should match CATEGORIES (same names, same
# order) so that product rows line up with the intended category ids.
PRODUCT_NAMES = {
    "Electronics":       ["Wireless Headphones", "USB-C Hub", "Mechanical Keyboard",
                          "Webcam 4K", "Portable Charger", "Smart Speaker",
                          "Monitor Stand", "HDMI Cable 2.1"],
    "Clothing":          ["Cotton T-Shirt", "Slim Fit Jeans", "Hoodie",
                          "Running Shorts", "Winter Jacket", "Polo Shirt",
                          "Casual Sneakers", "Wool Socks"],
    "Books":             ["Clean Code", "Designing Data-Intensive Applications",
                          "The Pragmatic Programmer", "System Design Interview",
                          "Deep Learning Book", "Python Cookbook",
                          "Domain-Driven Design", "Refactoring"],
    "Home & Garden":     ["Coffee Maker", "Air Purifier", "LED Desk Lamp",
                          "Plant Pot Set", "Storage Organiser", "Cutting Board",
                          "Vacuum Cleaner", "Electric Kettle"],
    "Sports & Outdoors": ["Yoga Mat", "Resistance Bands", "Cycling Gloves",
                          "Trekking Poles", "Water Bottle 1L", "Jump Rope",
                          "Foam Roller", "Compression Socks"],
    "Toys & Games":      ["Lego City Set", "Card Game Pack", "Puzzle 1000pc",
                          "Remote Control Car", "Building Blocks",
                          "Board Game Strategy", "Art Set", "Toy Drone"],
    "Beauty":            ["Face Serum", "SPF 50 Sunscreen", "Lip Balm",
                          "Shampoo Pro", "Hair Mask", "Eye Cream",
                          "Vitamin C Cream", "Toner Mist"],
    "Automotive":        ["Car Phone Mount", "Dash Cam", "Tyre Inflator",
                          "Car Vacuum", "Seat Cushion", "Steering Wheel Cover",
                          "OBD Scanner", "Jump Starter"],
}

# Values used for customers.country.
COUNTRIES = ["India", "USA", "Germany", "UK", "Canada",
             "Australia", "France", "Brazil", "Japan", "Singapore"]

# Values used for customers.tier and orders.status respectively.
TIERS = ["bronze", "silver", "gold"]
STATUSES = ["pending", "processing", "shipped", "delivered", "cancelled"]

# Name pools for synthetic customers (name and email are built from these).
FIRST_NAMES = [
    "Aarav", "Priya", "Rahul", "Neha", "Arjun", "Sneha", "Vikram", "Pooja",
    "Karthik", "Divya", "James", "Sarah", "Michael", "Emily", "David", "Jessica",
    "Hans", "Lena", "Oliver", "Sofia", "Pierre", "Amelie", "Carlos", "Laura",
    "Yuki", "Hana", "Wei", "Mei", "Aiden", "Zara",
]
LAST_NAMES = [
    "Sharma", "Singh", "Patel", "Kumar", "Gupta", "Verma", "Nair", "Reddy",
    "Smith", "Johnson", "Brown", "Williams", "Jones", "Davis", "Wilson",
    "Müller", "Schmidt", "Schneider", "Fischer", "Weber",
    "Martin", "Bernard", "Thomas", "Richard", "Petit",
    "Garcia", "Martinez", "Lopez", "Sanchez", "Gonzalez",
]


def _random_date(start_year: int = 2022, end_year: int = 2025) -> str:
    """Draw a uniformly random calendar date from the module RNG.

    The range runs from 1 January of *start_year* through 31 December of
    *end_year*, both inclusive. Exactly one ``RNG.randint`` call is made.

    Returns:
        The chosen date as an ISO-8601 string (``YYYY-MM-DD``).
    """
    first_day = date(start_year, 1, 1)
    span_days = (date(end_year, 12, 31) - first_day).days
    offset = timedelta(days=RNG.randint(0, span_days))
    return (first_day + offset).isoformat()


def seed_database(conn: sqlite3.Connection) -> None:
    """Populate the database with deterministic synthetic data.

    The shared ``RNG`` is re-seeded with ``SEED`` on every call, so the rows
    produced depend only on ``SEED``. This holds even when this function is
    called more than once in the same process (for example, by a test suite
    that builds several databases).

    All rows are written with ``INSERT OR IGNORE``, so rows whose ids already
    exist are left as they are. The exception is ``orders.total_amount``,
    which is always rewritten from the freshly generated items.

    Args:
        conn: Open connection on which the ``categories``, ``products``,
            ``customers``, ``orders``, ``order_items`` and ``reviews`` tables
            already exist.

    Foreign-key enforcement is switched on for *conn*, and the transaction is
    committed before returning.
    """
    # Restart the generator so the output does not depend on how many draws
    # earlier calls (of this function or _random_date) have already consumed.
    RNG.seed(SEED)
    conn.execute("PRAGMA foreign_keys = ON")
    cur = conn.cursor()

    n_customers = 150
    # NOTE: the helpers must run in this order; each consumes RNG draws, and
    # reordering them would change every generated value after that point.
    _seed_categories(cur)
    n_products = _seed_products(cur)
    _seed_customers(cur, n_customers)
    _seed_orders(cur, n_customers, n_products)
    _seed_reviews(cur, n_customers, n_products)

    conn.commit()


def _seed_categories(cur: sqlite3.Cursor) -> None:
    """Insert CATEGORIES with 1-based ids in list order (no RNG draws)."""
    for cat_id, name in enumerate(CATEGORIES, 1):
        cur.execute(
            "INSERT OR IGNORE INTO categories(id, name) VALUES (?, ?)",
            (cat_id, name),
        )


def _seed_products(cur: sqlite3.Cursor) -> int:
    """Insert every product in PRODUCT_NAMES with a random price and stock; return how many."""
    # Resolve ids by name, so products cannot be attached to the wrong
    # category if CATEGORIES and PRODUCT_NAMES are ever ordered differently.
    cat_ids = {name: i for i, name in enumerate(CATEGORIES, 1)}
    pid = 0
    for cat_name, names in PRODUCT_NAMES.items():
        cat_id = cat_ids[cat_name]
        for pname in names:
            pid += 1
            price = round(RNG.uniform(5.0, 250.0), 2)
            stock = RNG.randint(0, 500)
            cur.execute(
                "INSERT OR IGNORE INTO products(id, name, category_id, price, stock_quantity) "
                "VALUES (?, ?, ?, ?, ?)",
                (pid, pname, cat_id, price, stock),
            )
    return pid


def _seed_customers(cur: sqlite3.Cursor, n_customers: int) -> None:
    """Insert customers 1..n_customers with random names, tier, country and signup date."""
    used_emails: set[str] = set()
    for cid in range(1, n_customers + 1):
        fname = RNG.choice(FIRST_NAMES)
        lname = RNG.choice(LAST_NAMES)
        # The customer id alone already makes the local part unique. The loop
        # is a safety net that keeps appending "x" until the address is
        # unused; the old version reassigned the same value and could spin forever.
        local_part = f"{fname.lower()}.{lname.lower()}{cid}"
        while f"{local_part}@example.com" in used_emails:
            local_part += "x"
        email = f"{local_part}@example.com"
        used_emails.add(email)

        # Bias: 60% bronze, 30% silver, 10% gold
        tier = RNG.choices(TIERS, weights=[60, 30, 10])[0]
        country = RNG.choice(COUNTRIES)
        created = _random_date(2021, 2023)
        cur.execute(
            "INSERT OR IGNORE INTO customers(id, name, email, country, tier, created_at) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            (cid, f"{fname} {lname}", email, country, tier, created),
        )


def _seed_orders(cur: sqlite3.Cursor, n_customers: int, n_products: int) -> None:
    """Insert 0–8 orders per customer, each with 1–4 distinct products and a computed total."""
    # Weights for drawing 0..8 orders per customer. Higher tiers skew towards
    # more orders, and any tier other than bronze/silver uses the gold weights.
    order_count_weights = {
        "bronze": [5, 20, 20, 15, 15, 10, 8, 5, 2],
        "silver": [2, 10, 15, 20, 20, 15, 10, 5, 3],
        "gold":   [1, 5, 10, 15, 20, 20, 15, 10, 4],
    }
    oid = 1
    item_id = 1
    for cid in range(1, n_customers + 1):
        # Read the tier back from the table, so a pre-existing row wins.
        tier_row = cur.execute(
            "SELECT tier FROM customers WHERE id=?", (cid,)
        ).fetchone()
        tier = tier_row[0] if tier_row else "bronze"
        weights = order_count_weights.get(tier, order_count_weights["gold"])
        n_orders = RNG.choices(range(9), weights=weights)[0]

        for _ in range(n_orders):
            status = RNG.choices(STATUSES, weights=[5, 10, 15, 60, 10])[0]
            order_date = _random_date(2022, 2025)
            n_items = RNG.randint(1, 4)
            chosen_pids = RNG.sample(
                range(1, n_products + 1), k=min(n_items, n_products)
            )

            # The order row must exist before its items (foreign keys are
            # on), so insert it with a placeholder total and fix it up below.
            cur.execute(
                "INSERT OR IGNORE INTO orders(id, customer_id, status, created_at, total_amount) "
                "VALUES (?, ?, ?, ?, ?)",
                (oid, cid, status, order_date, 0.0),
            )

            total = 0.0
            for cpid in chosen_pids:
                qty = RNG.randint(1, 5)
                price_row = cur.execute(
                    "SELECT price FROM products WHERE id=?", (cpid,)
                ).fetchone()
                unit_price = price_row[0] if price_row else 10.0
                total += round(qty * unit_price, 2)
                cur.execute(
                    "INSERT OR IGNORE INTO order_items(id, order_id, product_id, quantity, unit_price) "
                    "VALUES (?, ?, ?, ?, ?)",
                    (item_id, oid, cpid, qty, unit_price),
                )
                item_id += 1

            cur.execute(
                "UPDATE orders SET total_amount=? WHERE id=?",
                (round(total, 2), oid),
            )
            oid += 1


def _seed_reviews(cur: sqlite3.Cursor, n_customers: int, n_products: int) -> None:
    """Insert up to 6 reviews per customer on uniformly random products."""
    # Products are drawn from the whole catalogue, not from what the
    # customer ordered. A repeated (customer, product) draw is skipped, so a
    # customer can end up with fewer reviews than drawn.
    rev_id = 1
    reviewed: set[tuple[int, int]] = set()
    for cid in range(1, n_customers + 1):
        n_reviews = RNG.randint(0, 6)
        for _ in range(n_reviews):
            rpid = RNG.randint(1, n_products)
            if (cid, rpid) in reviewed:
                continue
            reviewed.add((cid, rpid))
            rating = RNG.choices([1, 2, 3, 4, 5], weights=[5, 10, 15, 35, 35])[0]
            rev_date = _random_date(2022, 2025)
            cur.execute(
                "INSERT OR IGNORE INTO reviews(id, product_id, customer_id, rating, created_at) "
                "VALUES (?, ?, ?, ?, ?)",
                (rev_id, rpid, cid, rating, rev_date),
            )
            rev_id += 1


def get_db_summary(conn: sqlite3.Connection) -> dict:
    """Count the rows in each seeded table.

    Intended for debugging and for the stats quoted in the README.

    Returns:
        A dict mapping table name to row count. Keys are in the order
        categories, products, customers, orders, order_items, reviews.
    """
    seeded_tables = (
        "categories", "products", "customers",
        "orders", "order_items", "reviews",
    )

    def _count(table: str) -> int:
        # Table names come only from the fixed tuple above, so interpolating
        # them into the SQL text is safe.
        result = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()
        return result[0] if result else 0

    return {table: _count(table) for table in seeded_tables}


if __name__ == "__main__":
    import contextlib
    import os

    # Smoke test: build a throwaway in-memory database from the sibling
    # schema.sql, seed it, and print the row count of each table.
    schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
    with open(schema_path, encoding="utf-8") as f:
        schema_sql = f.read()

    # sqlite3.Connection's own context manager only commits or rolls back; it
    # does not close. closing() guarantees the connection is released even if
    # applying the schema or seeding raises.
    with contextlib.closing(sqlite3.connect(":memory:")) as conn:
        conn.row_factory = sqlite3.Row
        conn.executescript(schema_sql)
        seed_database(conn)
        print("Seed stats:", get_db_summary(conn))