File size: 11,511 Bytes
28035e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
"""Data profiler — samples the actual database to give the AI business context.

Profiles each table to discover:
- Categorical columns and their distinct values (status, type, category, etc.)
- Numeric column ranges (min, max, avg)
- Date column ranges
- Sample rows

This info is injected into the AI prompts so it can make smart
business decisions (e.g., filter by status='closed' for revenue).
"""

import time
from typing import Any

from sqlalchemy import text

from db.connection import get_engine
from db.schema import get_schema

# ── Cache ───────────────────────────────────────────────────────────────────
_profile_cache: str | None = None
_profile_ts: float = 0.0
_PROFILE_TTL: float = 600.0  # 10 minutes


def get_data_profile(force_refresh: bool = False) -> str:
    """Return a formatted data profile string for prompt injection.

    The profile consists of one section per table returned by ``get_schema()``
    (tables whose profile comes back empty are skipped), followed by the
    business-rules section. The result is cached at module level and reused
    for ``_PROFILE_TTL`` seconds.

    Args:
        force_refresh: If True, bypass the cache and rebuild the profile.

    Returns:
        The newline-joined profile text.
    """
    global _profile_cache, _profile_ts

    # Compare against None rather than relying on truthiness, so an empty
    # profile is still served from the cache. time.monotonic() is used for the
    # TTL because, unlike time.time(), it cannot jump when the clock is changed.
    if (
        not force_refresh
        and _profile_cache is not None
        and time.monotonic() - _profile_ts < _PROFILE_TTL
    ):
        return _profile_cache

    schema = get_schema()
    profile_parts: list[str] = []

    engine = get_engine()
    with engine.connect() as conn:
        for table, columns in schema.items():
            table_profile = _profile_table(conn, table, columns)
            if table_profile:
                profile_parts.append(table_profile)

    # Business rules are appended after all per-table sections.
    rules = _generate_business_rules(schema)
    if rules:
        profile_parts.append(rules)

    _profile_cache = "\n".join(profile_parts)
    _profile_ts = time.monotonic()
    return _profile_cache


def _profile_table(conn, table: str, columns: list[dict]) -> str:
    """Profile a single table.

    Emits a ``TABLE PROFILE`` header and the row count. For a non-empty table
    it then adds one line per column that is classified as categorical,
    numeric or date and whose profiler returns a non-empty result.

    Args:
        conn: Open SQLAlchemy connection used for every query.
        table: Table name; it is double-quoted into the generated SQL.
        columns: Column descriptors, each providing ``column_name`` and
            ``data_type`` keys.

    Returns:
        The table profile text, terminated by a blank line, or ``""`` when the
        row-count query fails.
    """
    lines: list[str] = [f"TABLE PROFILE: {table}"]

    # Every query runs inside its own SAVEPOINT. On PostgreSQL a failed
    # statement aborts the enclosing transaction. Without the savepoint, one
    # unreadable table or column would make every later query on ``conn`` fail
    # too, and the best-effort ``except`` below would hide that silently.
    try:
        with conn.begin_nested():
            count = conn.execute(text(f'SELECT count(*) FROM "{table}"')).scalar()
    except Exception:
        return ""
    lines.append(f"  Total rows: {count}")

    if count == 0:
        lines.append("  (empty table)")
        # Same trailing blank separator as non-empty tables.
        lines.append("")
        return "\n".join(lines)

    for col in columns:
        cname = col["column_name"]
        dtype = col["data_type"]

        try:
            if _is_categorical(dtype, cname):
                with conn.begin_nested():
                    profile = _profile_categorical(conn, table, cname, count)
            elif _is_numeric(dtype):
                with conn.begin_nested():
                    profile = _profile_numeric(conn, table, cname)
            elif _is_date(dtype):
                with conn.begin_nested():
                    profile = _profile_date(conn, table, cname)
            else:
                continue
        except Exception:
            # Best effort: skip columns whose classification or query fails.
            continue

        if profile:
            lines.append(profile)

    lines.append("")
    return "\n".join(lines)


def _is_categorical(dtype: str, cname: str) -> bool:
    """Check if a column is likely categorical (status, type, category, etc.)."""
    categorical_types = {"character varying", "text", "varchar", "char", "character"}
    categorical_keywords = {
        "status", "state", "type", "category", "kind", "class",
        "group", "level", "tier", "grade", "priority", "stage",
        "flag", "mode", "role", "region", "country", "city",
        "gender", "channel", "source", "segment", "department",
    }
    if dtype.lower() in categorical_types:
        # Check if the column name suggests it's categorical
        lower_name = cname.lower()
        if any(kw in lower_name for kw in categorical_keywords):
            return True
        # Also profile short text columns
        return True
    return False


def _is_numeric(dtype: str) -> bool:
    numeric_types = {
        "integer", "bigint", "smallint", "numeric", "real",
        "double precision", "decimal", "float", "int",
    }
    return dtype.lower() in numeric_types


def _is_date(dtype: str) -> bool:
    date_types = {
        "date", "timestamp", "timestamp without time zone",
        "timestamp with time zone", "timestamptz",
    }
    return dtype.lower() in date_types


def _profile_categorical(conn, table: str, col: str, total_rows: int) -> str | None:
    """Summarise the most frequent values of a text-like column.

    Fetches up to 25 of the most frequent non-NULL values with their row
    counts. A column with more than 50 distinct non-NULL values is reported
    as high cardinality. Otherwise at most the 15 most frequent values are
    listed. When more distinct values exist than are shown, the line says
    "top 15 of N", so the model does not mistake a partial list for the
    complete value set.

    Args:
        conn: Open SQLAlchemy connection.
        table: Table name (double-quoted into the SQL).
        col: Column name (double-quoted into the SQL).
        total_rows: Row count of *table*; currently unused.

    Returns:
        A profile line indented by two spaces, or None when the column has no
        non-NULL values.
    """
    fetch_limit = 25
    shown_limit = 15
    max_distinct = 50

    result = conn.execute(text(
        f'SELECT "{col}", count(*) as cnt FROM "{table}" '
        f'WHERE "{col}" IS NOT NULL '
        f'GROUP BY "{col}" ORDER BY cnt DESC LIMIT {fetch_limit}'
    )).fetchall()

    if not result:
        return None

    distinct_count = len(result)

    # A page shorter than the LIMIT already contains every distinct value, so
    # an exact COUNT(DISTINCT) is only needed when the page is full.
    if distinct_count < fetch_limit:
        total_distinct = distinct_count
    else:
        total_distinct = conn.execute(text(
            f'SELECT count(DISTINCT "{col}") FROM "{table}" WHERE "{col}" IS NOT NULL'
        )).scalar()
        if total_distinct > max_distinct:
            return f"  {col}: {total_distinct} distinct values (high cardinality - not categorical)"

    values_str = ", ".join(
        f"'{r[0]}' ({r[1]} rows)" for r in result[:shown_limit]
    )
    if total_distinct > shown_limit:
        return f"  {col}: DISTINCT VALUES (top {shown_limit} of {total_distinct}) = [{values_str}]"
    return f"  {col}: DISTINCT VALUES = [{values_str}]"


def _profile_numeric(conn, table: str, col: str) -> str | None:
    """Summarise a numeric column by its minimum, maximum and mean.

    NULLs are excluded. The mean is rounded to two decimals in SQL (the
    ``::numeric`` cast is PostgreSQL syntax).

    Returns:
        ``"  <col>: min=..., max=..., avg=..."``, or None when no non-NULL
        value exists.
    """
    quoted = f'"{col}"'
    sql = (
        f'SELECT min({quoted}), max({quoted}), round(avg({quoted})::numeric, 2) '
        f'FROM "{table}" WHERE {quoted} IS NOT NULL'
    )
    row = conn.execute(text(sql)).fetchone()

    if not row:
        return None
    lowest = row[0]
    if lowest is None:
        return None

    highest, mean = row[1], row[2]
    return f"  {col}: min={lowest}, max={highest}, avg={mean}"


def _profile_date(conn, table: str, col: str) -> str | None:
    """Summarise a date/timestamp column as its earliest and latest value.

    NULLs are excluded.

    Returns:
        ``"  <col>: from <min> to <max>"``, or None when no non-NULL value
        exists.
    """
    quoted = f'"{col}"'
    sql = (
        f'SELECT min({quoted}), max({quoted}) '
        f'FROM "{table}" WHERE {quoted} IS NOT NULL'
    )
    row = conn.execute(text(sql)).fetchone()

    if not row:
        return None
    earliest = row[0]
    if earliest is None:
        return None

    latest = row[1]
    return f"  {col}: from {earliest} to {latest}"


def _generate_business_rules(schema: dict[str, list[dict]]) -> str:
    """Auto-infer business rules from column patterns across all tables."""
    rules: list[str] = [
        "=" * 60,
        "BUSINESS INTELLIGENCE RULES β€” YOU MUST FOLLOW THESE",
        "=" * 60,
    ]

    # ── Rule 0: Query type awareness
    rules.append("")
    rules.append("RULE 0 β€” KNOW YOUR QUERY TYPE:")
    rules.append("  PRODUCT ATTRIBUTE queries (category, name, weight, details):")
    rules.append("    β†’ Use product/variant catalog tables directly.")
    rules.append("    β†’ No status filter needed.")
    rules.append("  PRODUCT PRICE queries (most expensive, cheapest, price lookup):")
    rules.append("    β†’ Use sales_order_line_pricing.selling_price_per_unit as source of truth.")
    rules.append("    β†’ JOIN to product_master for product_name. GROUP BY to avoid duplicates.")
    rules.append("  TRANSACTIONAL queries (revenue, AOV, order counts, sales trends):")
    rules.append("    β†’ Use sales tables. MUST filter by sales_order.status = 'closed'.")
    rules.append("    β†’ Examples: 'total revenue', 'AOV', 'top customers by spending'")

    # ── Rule 1: Avoiding duplicates
    rules.append("")
    rules.append("RULE 1 β€” AVOID DUPLICATE ROWS (CRITICAL):")
    rules.append("  When JOINing tables, products may have MULTIPLE variants (different karat, quality, etc.).")
    rules.append("  This causes duplicate product names in results.")
    rules.append("  ALWAYS use one of these to prevent duplicates:")
    rules.append("    - GROUP BY product_id (or product_name) with MAX/MIN/AVG on value columns")
    rules.append("    - SELECT DISTINCT when you only need unique values")
    rules.append("    - Use subqueries with aggregation before joining")
    rules.append("  NEVER return raw joins that produce repeated product names.")

    # ── Rule 2: Product price lookup
    rules.append("")
    rules.append("RULE 2 β€” PRODUCT PRICE LOOKUP (SOURCE OF TRUTH):")
    rules.append("  The SOURCE OF TRUTH for product prices is the sales_order_line_pricing table.")
    rules.append("  It has 'selling_price_per_unit' which is the actual price per 1 unit of a product.")
    rules.append("  For 'most expensive products', 'cheapest products', 'product price':")
    rules.append("    β†’ Query sales_order_line_pricing and JOIN to product tables for product_name")
    rules.append("    β†’ Use selling_price_per_unit (NOT line_total_price, NOT selling_price from catalog)")
    rules.append("    β†’ GROUP BY product_id, product_name and use MAX(selling_price_per_unit)")
    rules.append("    β†’ Join path: sales_order_line_pricing.product_id = product_master.product_id")
    rules.append("  Do NOT use product_variant_summary.selling_price or variant_sku_table.selling_price")
    rules.append("  β€” those are catalog/list prices, not actual transaction prices.")
    rules.append("  For 'highest revenue products' or 'best selling products':")
    rules.append("    β†’ Use SUM(line_total_price) grouped by product, filtered by status='closed'")

    # ── Rule 3: Status filtering (only for transactional queries)
    rules.append("")
    rules.append("RULE 3 β€” STATUS FILTERING (TRANSACTIONAL ONLY):")
    rules.append("  The 'status' column on the sales_order table has values: closed, open, cancelled, processing.")
    rules.append("  For revenue, AOV, sales counts: WHERE status = 'closed'")
    rules.append("  For product catalog queries: NO status filter needed")
    rules.append("  IMPORTANT: The 'status' column is ONLY on the sales_order table.")
    rules.append("  Do NOT look for payment_status or status on pricing/line tables β€” it does not exist there.")

    # ── Rule 4: Unit price vs total price
    rules.append("")
    rules.append("RULE 4 β€” UNIT PRICE vs TOTAL PRICE:")
    rules.append("  line_total_price = selling_price_per_unit Γ— quantity (total for order line)")
    rules.append("  selling_price_per_unit = the actual price of 1 unit of the product")
    rules.append("  base_price_per_unit = cost price of 1 unit before margin")
    rules.append("  NEVER use line_total_price as a product's price β€” it includes quantity.")
    rules.append("  To get a product's price: use selling_price_per_unit or selling_price column")

    # ── Rule 5: Common metrics formulas
    rules.append("")
    rules.append("RULE 5 β€” METRIC FORMULAS:")
    rules.append("  AOV = SUM(so.total_amount) / COUNT(DISTINCT so.so_id) WHERE so.status='closed'")
    rules.append("  Revenue = SUM(so.total_amount) WHERE so.status='closed'")
    rules.append("  Most Expensive Product = MAX(pvs.selling_price) GROUP BY product_id, product_name")
    rules.append("  Margin % = (selling_price - base_price) / selling_price Γ— 100")
    rules.append("  Order Count = COUNT(DISTINCT so.so_id) WHERE so.status='closed'")

    # ── Rule 6: Table relationships
    rules.append("")
    rules.append("RULE 6 β€” TABLE JOIN PATHS:")
    rules.append("  Sales chain: sales_order(so_id) β†’ sales_order_line(so_id, sol_id) β†’ sales_order_line_pricing(sol_id)")
    rules.append("  Product chain: product_master(product_id) β†’ product_variant_summary(product_id) β†’ variant_sku_table(variant_sku)")
    rules.append("  Sales ↔ Product: sales_order_line.variant_sku = variant_sku_table.variant_sku")
    rules.append("  Sales ↔ Customer: sales_order.customer_id = customer_master.customer_id")
    rules.append("  Sales ↔ Payment: sales_order.so_id = sales_order_payments.so_id")

    return "\n".join(rules)