File size: 13,286 Bytes
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
"""
data_factory/augmentor.py
==========================
Rule-based Natural Language augmentation.

These transformations operate ONLY on NL question strings.
SQL is NEVER modified — it always comes from the verified template library.

Three augmentation strategies:
  1. Synonym replacement  — swaps domain words with semantically equivalent ones
  2. Condition reordering — shuffles conjunctive phrases (preserves meaning)
  3. Date normalisation   — expresses dates in different formats when applicable
"""

from __future__ import annotations

import random
import re
from copy import deepcopy
from typing import Iterator


# ─────────────────────────────────────────────────────────────────────────────
# SYNONYM DICTIONARIES
# ─────────────────────────────────────────────────────────────────────────────

# Format: "canonical_term": ["synonym1", "synonym2", ...]
# All synonyms are semantically equivalent in a business context.
#
# How this table is consumed (see _apply_synonyms):
#   * keys are matched case-insensitively as whole words/phrases;
#   * the key list is shuffled with the caller's RNG, and one synonym is
#     picked with rng.choice, so both the key order and the order of each
#     synonym list influence the output produced for a given seed.
# Singular and plural forms are separate keys because matching is on whole
# words ("order" does not match "orders").

_SYNONYMS: dict[str, list[str]] = {

    # Verbs / action starters
    "list":         ["show", "display", "return", "give me", "find", "retrieve"],
    "show":         ["list", "display", "return", "get", "retrieve"],
    "find":         ["identify", "locate", "get", "show", "retrieve", "look up"],
    "return":       ["show", "give", "list", "retrieve", "output"],
    "retrieve":     ["fetch", "get", "return", "pull"],
    "get":          ["retrieve", "fetch", "return", "give me"],

    # Aggregation words
    "total":        ["sum", "aggregate", "overall", "cumulative", "combined"],
    "average":      ["mean", "avg", "typical"],
    "count":        ["number of", "quantity of", "how many"],
    "highest":      ["largest", "maximum", "top", "greatest"],
    "lowest":       ["smallest", "minimum", "least"],

    # Business / domain
    "customer":     ["client", "buyer", "user", "account holder", "shopper"],
    "customers":    ["clients", "buyers", "users", "account holders", "shoppers"],
    "product":      ["item", "SKU", "article", "goods"],
    "products":     ["items", "SKUs", "articles", "goods"],
    "order":        ["purchase", "transaction", "sale"],
    "orders":       ["purchases", "transactions", "sales"],
    "revenue":      ["income", "earnings", "sales amount", "money earned"],
    "spending":     ["expenditure", "spend", "purchases"],
    "amount":       ["value", "sum", "total", "figure"],
    "price":        ["cost", "rate", "charge", "fee"],

    # Healthcare
    "patient":      ["person", "individual", "case"],
    "patients":     ["persons", "individuals", "cases"],
    "doctor":       ["physician", "clinician", "practitioner", "specialist"],
    "doctors":      ["physicians", "clinicians", "practitioners"],
    "appointment":  ["visit", "consultation", "session"],
    "appointments": ["visits", "consultations", "sessions"],
    "medication":   ["drug", "medicine", "pharmaceutical", "prescription drug"],
    "medications":  ["drugs", "medicines", "pharmaceuticals"],
    "diagnosis":    ["condition", "finding", "medical finding"],

    # Finance
    "account":      ["bank account", "profile", "portfolio entry"],
    "accounts":     ["bank accounts", "profiles"],
    "loan":         ["credit", "borrowing", "debt instrument"],
    "loans":        ["credits", "borrowings", "debt instruments"],
    "transaction":  ["transfer", "payment", "operation", "activity"],
    "transactions": ["transfers", "payments", "operations"],
    "balance":      ["funds", "available amount", "account balance"],

    # HR
    "employee":     ["staff member", "worker", "team member", "headcount"],
    "employees":    ["staff", "workers", "team members", "workforce"],
    "department":   ["team", "division", "unit", "group"],
    "departments":  ["teams", "divisions", "units"],
    "salary":       ["pay", "compensation", "remuneration", "earnings"],
    "project":      ["initiative", "program", "assignment", "engagement"],
    "projects":     ["initiatives", "programs", "assignments"],

    # Adjectives / Qualifiers
    "active":       ["current", "ongoing", "live", "existing"],
    "delivered":    ["completed", "fulfilled", "received"],
    "cancelled":    ["voided", "aborted", "terminated"],
    "alphabetically": ["by name", "in alphabetical order", "A to Z"],
    "descending":   ["from highest to lowest", "in decreasing order", "largest first"],
    "ascending":    ["from lowest to highest", "in increasing order", "smallest first"],
    "distinct":     ["unique", "different"],
    # The only multi-word key; re.escape + word boundaries handle the space.
    "in stock":     ["available", "with available inventory", "not out of stock"],
}


# ─────────────────────────────────────────────────────────────────────────────
# DATE PHRASE PATTERNS
# These will be replaced with alternative date expressions.
# ─────────────────────────────────────────────────────────────────────────────

# Each entry is (phrase, alternates): `phrase` is the literal text looked for
# in a question (case-insensitively) and `alternates` are the strings one of
# which may be substituted for it. _vary_dates walks the entries in list order,
# so earlier entries are considered (and possibly rewritten) first.
_DATE_ALTERNATES: list[tuple[str, list[str]]] = [
    # ISO partial
    ("2024-01-01",   ["January 1st 2024", "Jan 1, 2024", "the start of 2024", "2024 start"]),
    ("2023-01-01",   ["January 1st 2023", "Jan 1, 2023", "the start of 2023"]),
    ("2025-01-01",   ["January 1st 2025", "the start of 2025"]),
    # Quarter references
    ("Q1",           ["the first quarter", "January through March", "Jan-Mar"]),
    ("Q2",           ["the second quarter", "April through June", "Apr-Jun"]),
    ("Q3",           ["the third quarter", "July through September", "Jul-Sep"]),
    ("Q4",           ["the fourth quarter", "October through December", "Oct-Dec"]),
    # Year references
    ("in 2024",      ["during 2024", "throughout 2024", "for the year 2024"]),
    ("in 2023",      ["during 2023", "throughout 2023", "for the year 2023"]),
]


# ─────────────────────────────────────────────────────────────────────────────
# CONDITION REORDERING
# Splits on the first connector ("and", "who are", "that are", "with") between
# two conditions and reverses them.
# ─────────────────────────────────────────────────────────────────────────────

def _reorder_conditions(text: str, rng: random.Random) -> str:
    """
    Swap the two clauses around the first connector, roughly half of the time.

    Recognised connectors are "and", "who are", "that are" and "with",
    matched as whole words and case-insensitively. Only the FIRST connector
    in the text is used as the pivot.

    The text is returned unchanged when:
      - it contains no connector,
      - the random draw says "do not swap" (draw > 0.5), or
      - either side of the connector is empty, so there is nothing to swap.

    Terminal punctuation (".", "?", "!") found at the end of the trailing
    clause is kept at the end of the swapped sentence instead of being
    carried into the middle of it. The first character of the result is
    upper-cased; the rest of the text keeps its original casing.

    Example:
      "active employees and senior managers."
      → "Senior managers and active employees."

    Parameters
    ----------
    text : str
        The NL question (or fragment) to reorder.
    rng : random.Random
        Source of randomness. One draw is consumed only when a connector
        is present.

    Returns
    -------
    str
        The reordered text, or `text` itself when no swap is performed.
    """
    matches = list(re.finditer(r'\b(?:and|who are|that are|with)\b', text, re.IGNORECASE))
    # Short-circuit keeps RNG consumption tied to "a connector exists".
    if not matches or rng.random() > 0.5:
        return text

    # Pivot on the first connector only.
    pivot = matches[0]
    before = text[:pivot.start()].strip()
    after = text[pivot.end():].strip()
    connector = pivot.group(0).lower()

    # Peel sentence-final punctuation off the trailing clause so it can be
    # re-attached at the very end after the swap.
    trimmed = after.rstrip(".?!")
    terminal = after[len(trimmed):]
    after = trimmed.rstrip()

    # A connector at the very start or end has nothing to swap with; swapping
    # would only produce a dangling connector.
    if not before or not after:
        return text

    swapped = f"{after} {connector} {before}{terminal}"

    # The clause that moved to the front now starts the sentence.
    return swapped[0].upper() + swapped[1:]


# ─────────────────────────────────────────────────────────────────────────────
# SYNONYM REPLACEMENT
# ─────────────────────────────────────────────────────────────────────────────

def _apply_synonyms(text: str, rng: random.Random, max_replacements: int = 3) -> str:
    """
    Replace up to `max_replacements` words/phrases of `text` with synonyms.

    Candidate terms are the keys of `_SYNONYMS`, visited in an order shuffled
    by `rng`. For each key that occurs in the text (whole word/phrase,
    case-insensitive) there is a 50% chance that its first occurrence is
    replaced by a randomly chosen synonym.

    All matching is done against the ORIGINAL text and the chosen
    replacements are spliced in at the end. A synonym that was just inserted
    is therefore never re-matched by a later key, which would otherwise chain
    substitutions (e.g. "orders" -> "transactions" -> "transfers") or undo a
    replacement (e.g. "list" -> "show" -> "list").

    Parameters
    ----------
    text : str
        The NL question to rewrite.
    rng : random.Random
        Source of randomness (key order, replace/skip draw, synonym choice).
    max_replacements : int
        Upper bound on the number of terms replaced. Values <= 0 leave the
        text unchanged.

    Returns
    -------
    str
        The text with the selected terms replaced. If the matched word starts
        with an upper-case letter, so does its replacement.
    """
    # Shuffle the synonym keys to get different replacement targets each call.
    keys = list(_SYNONYMS)
    rng.shuffle(keys)

    # (start, end, replacement) spans, all measured against the original text.
    edits: list[tuple[int, int, str]] = []

    for canonical in keys:
        if len(edits) >= max_replacements:
            break

        # Case-insensitive match on word boundaries.
        pattern = re.compile(r'\b' + re.escape(canonical) + r'\b', re.IGNORECASE)

        # First occurrence that does not overlap a span already claimed.
        match = next(
            (
                m for m in pattern.finditer(text)
                if all(m.end() <= start or m.start() >= end for start, end, _ in edits)
            ),
            None,
        )
        # The random draw is only consumed when the term is actually present.
        if match is None or rng.random() >= 0.5:
            continue

        replacement = rng.choice(_SYNONYMS[canonical])
        # Preserve the casing of the first character of the matched text.
        if match.group(0)[0].isupper():
            replacement = replacement[0].upper() + replacement[1:]
        edits.append((match.start(), match.end(), replacement))

    if not edits:
        return text

    # Splice the replacements into the original text from left to right.
    edits.sort()
    pieces: list[str] = []
    cursor = 0
    for start, end, replacement in edits:
        pieces.append(text[cursor:start])
        pieces.append(replacement)
        cursor = end
    pieces.append(text[cursor:])
    return "".join(pieces)


# ─────────────────────────────────────────────────────────────────────────────
# DATE FORMAT VARIATION
# ─────────────────────────────────────────────────────────────────────────────

def _vary_dates(text: str, rng: random.Random) -> str:
    """
    Replace date phrases with alternate representations.

    Every (phrase, alternates) entry of `_DATE_ALTERNATES` is tried in list
    order. When the phrase occurs in the text as a whole token
    (case-insensitive, not glued to a neighbouring letter, digit or
    underscore) there is a 60% chance that its first occurrence is replaced
    by a randomly chosen alternate.

    Matching on token boundaries matters: a plain substring test would find
    "in 2024" inside "within 2024" and produce "withduring 2024", or "Q1"
    inside "Q10".

    Parameters
    ----------
    text : str
        The NL question to rewrite.
    rng : random.Random
        Source of randomness. Draws are consumed only for phrases that are
        present in the text.

    Returns
    -------
    str
        The text with zero or more date phrases rewritten.
    """
    result = text
    for phrase, alternates in _DATE_ALTERNATES:
        # Lookarounds instead of a substring test: the phrase must not be
        # embedded in a longer word or number.
        pattern = re.compile(
            r'(?<!\w)' + re.escape(phrase) + r'(?!\w)',
            re.IGNORECASE,
        )
        if pattern.search(result) and rng.random() < 0.6:
            alt = rng.choice(alternates)
            # A function replacement inserts `alt` literally, so backslashes
            # or group references in an alternate can never be misread as a
            # regex replacement template.
            result = pattern.sub(lambda _m, alt=alt: alt, result, count=1)
    return result


# ─────────────────────────────────────────────────────────────────────────────
# PUBLIC API
# ─────────────────────────────────────────────────────────────────────────────

def augment_nl(
    nl_question: str,
    n: int = 3,
    seed: int = 42,
) -> list[str]:
    """
    Generate `n` rule-based augmented variants of a natural language question.

    Up to `3 * n` attempts are made, cycling through five strategies that
    combine:
      - synonym replacement
      - condition reordering
      - date format variation

    Each attempt uses its own RNG seeded with `seed + attempt_index * 31`, so
    the output is reproducible for a given (question, n, seed). Candidates are
    whitespace-normalised (trimmed, inner runs of whitespace collapsed) before
    de-duplication.

    The original question is NOT included in the output. This is checked
    against both the raw and the whitespace-normalised form of the question,
    so a question with irregular spacing cannot be returned as a "variant"
    of itself.

    Parameters
    ----------
    nl_question : str
        The base NL question to augment.
    n : int
        Number of variants to generate. Values <= 0 yield an empty list.
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    list[str]
        Up to `n` distinct augmented strings. May be fewer if the question
        is too short to vary meaningfully.
    """
    strategies = [
        # Strategy 1: synonym only
        lambda t, r: _apply_synonyms(t, r, max_replacements=2),
        # Strategy 2: synonym + date
        lambda t, r: _vary_dates(_apply_synonyms(t, r, max_replacements=2), r),
        # Strategy 3: condition reorder + synonym
        lambda t, r: _apply_synonyms(_reorder_conditions(t, r), r, max_replacements=1),
        # Strategy 4: heavy synonym
        lambda t, r: _apply_synonyms(t, r, max_replacements=4),
        # Strategy 5: date only
        lambda t, r: _vary_dates(t, r),
    ]

    variants: list[str] = []
    # Candidates are compared after whitespace normalisation, so the original
    # must be registered in that form too (the raw form is kept as well).
    seen: set[str] = {nl_question, " ".join(nl_question.split())}

    for i in range(n * 3):   # Over-generate, then deduplicate
        strategy = strategies[i % len(strategies)]
        # Use a different seed offset per variant attempt
        local_rng = random.Random(seed + i * 31)

        # Normalise whitespace (split/join also trims both ends)
        candidate = " ".join(strategy(nl_question, local_rng).split())

        if candidate and candidate not in seen:
            seen.add(candidate)
            variants.append(candidate)

        if len(variants) >= n:
            break

    return variants


def generate_all_augmentations(
    nl_question: str,
    seed: int = 42,
    n_per_template: int = 3,
) -> Iterator[str]:
    """
    Generator wrapper around `augment_nl`.

    Yields the augmented variants of `nl_question` one at a time, in the
    order `augment_nl` returns them, so callers can stream them straight
    into a larger dataset.

    Parameters
    ----------
    nl_question : str
        The base NL question to augment.
    seed : int
        Random seed forwarded to `augment_nl`.
    n_per_template : int
        Maximum number of variants to produce (forwarded as `n`).

    Yields
    ------
    str
        One augmented variant per iteration.
    """
    # augment_nl returns a list; it is built when iteration first starts.
    produced = augment_nl(nl_question, seed=seed, n=n_per_template)
    yield from produced