mikeboone (with Claude Sonnet 4.6) committed
Commit 5ac32c1 · 1 Parent(s): eae322e

feat: March 2026 sprint — new vision merge, pipeline improvements, settings refactor


Core changes:
- Merged DemoPrep_new_vision2: StorySpec outlier injection engine, domain packs,
quality validator, 1600-line generator with deterministic seeds + trend signals
- Consolidated outlier_system.py + 3 data_adjuster files → demo_personas.py + liveboard_questions
- Rewrote liveboard AI fill prompt: no redundant KPIs, business insight focus
- Fixed "Show me" title strip, Spotter Viz Story now persona-driven (Part 1) + AI (Part 2)
- Fixed progress meter stuck on Data during ThoughtSpot deployment
- Fixed liveboard name: UI field value now always takes priority over DB default
- Settings UI: split into Default Settings (AI Model, Use Case, LB Name) + App Settings
- Added session_logger.py, admin log viewer, state isolation (per-session loggers)
- TS environment dropdown, front page redesign, progress meter Init→Complete
- Supabase session logging with file fallback
- Sharing after every build (model + liveboard)
- Sage indexing retry on 10004, fallback TML skips invalid column refs
- Fixed domain-specific NAME column generation (DRUG_NAME was using fake.name())
- Batch test runner + test cases copied from new_vision2 branch

New files: session_logger.py, sprint_2026_03.md, legitdata storyspec/domain/quality packages,
tests/newvision_*, tests/test_mcp_liveboard.py
Removed: outlier_system.py, chat_data_adjuster.py, conversational_data_adjuster.py,
data_adjuster.py, sprint_2026_01.md

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
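
The "deterministic seeds" mentioned for the generator imply per-entity seeding so reruns reproduce identical demo data. A minimal sketch of that idea (the helper name and key scheme below are illustrative, not the generator's actual API):

```python
import hashlib
import random

def deterministic_rng(schema: str, table: str, column: str) -> random.Random:
    # Hash the identity of the column into a stable 64-bit seed so every
    # rerun of the generator emits the same values for the same column.
    key = f"{schema}|{table}|{column}".encode()
    seed = int.from_bytes(hashlib.sha256(key).digest()[:8], "big")
    return random.Random(seed)

# Same key -> identical sequence; a different key -> an independent sequence.
a = deterministic_rng("DEMO", "SALES", "REVENUE")
b = deterministic_rng("DEMO", "SALES", "REVENUE")
c = deterministic_rng("DEMO", "SALES", "QUANTITY")
```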

Files changed (42)
  1. .gitignore +1 -0
  2. CLAUDE.md +4 -2
  3. cdw_connector.py +177 -18
  4. chat_data_adjuster.py +0 -163
  5. chat_interface.py +787 -352
  6. conversational_data_adjuster.py +0 -448
  7. data_adjuster.py +0 -213
  8. demo_personas.py +366 -0
  9. legitdata_project/README.md +265 -0
  10. legitdata_project/legitdata/__init__.py +4 -0
  11. legitdata_project/legitdata/analyzer/column_classifier.py +129 -11
  12. legitdata_project/legitdata/ddl/parser.py +21 -2
  13. legitdata_project/legitdata/domain/__init__.py +22 -0
  14. legitdata_project/legitdata/domain/domain_packs.py +543 -0
  15. legitdata_project/legitdata/domain/semantic_types.py +292 -0
  16. legitdata_project/legitdata/generator.py +1286 -63
  17. legitdata_project/legitdata/quality/__init__.py +22 -0
  18. legitdata_project/legitdata/quality/quality_spec.py +149 -0
  19. legitdata_project/legitdata/quality/repair.py +612 -0
  20. legitdata_project/legitdata/quality/validator.py +935 -0
  21. legitdata_project/legitdata/relationships/fk_manager.py +6 -7
  22. legitdata_project/legitdata/sourcer/ai_generator.py +14 -31
  23. legitdata_project/legitdata/sourcer/generic.py +187 -28
  24. legitdata_project/legitdata/storyspec.py +220 -0
  25. legitdata_project/pyproject.toml +67 -0
  26. legitdata_project/test_legitdata.py +261 -0
  27. legitdata_project/test_with_ai.py +72 -0
  28. liveboard_creator.py +81 -75
  29. outlier_system.py +0 -769
  30. prompt_logger.py +2 -2
  31. prompts.py +14 -14
  32. session_logger.py +204 -0
  33. smart_data_adjuster.py +543 -536
  34. snowflake_auth.py +13 -43
  35. sprint_2026_01.md +0 -520
  36. sprint_2026_03.md +227 -0
  37. tests/__init__.py +1 -0
  38. tests/newvision_sample_runner.py +1246 -0
  39. tests/newvision_test_cases.yaml +46 -0
  40. tests/newvision_test_cases_2.yaml +278 -0
  41. tests/test_mcp_liveboard.py +54 -0
  42. thoughtspot_deployer.py +131 -46
.gitignore CHANGED
@@ -111,6 +111,7 @@ env.bak/
 venv.bak/
 demo_wire/ # Your virtual environment
 demo_wire/**
+demoprep/ # Local virtual environment
 
 # Spyder project settings
 .spyderproject
CLAUDE.md CHANGED
@@ -33,7 +33,7 @@ See `dev_notes/USE_CASE_FLOW.md` for the full use case framework documentation.
 
 ## Current Sprint
 
-**ALWAYS READ FIRST**: `sprint_2026_02.md` (in root)
+**ALWAYS READ FIRST**: `sprint_2026_03.md` (in root)
 
 For overall project status, see `PROJECT_STATUS.md` in root.
 
@@ -58,7 +58,7 @@ Example:
 - [x] snake_case naming ✅ - `_to_snake_case()` in deployer
 ```
 
-**Previous sprints:** `sprint_2026_01.md` (January - closed)
+**Previous sprints:** `sprint_2026_02.md` (February/March - closed), `sprint_2026_01.md` (January - closed)
 
 ---
 
@@ -112,6 +112,7 @@ Example:
 2. **No silent fallbacks** — if something is missing, error immediately with a clear message
    - ❌ WRONG: `value = os.getenv('KEY', 'some_default')`
    - ✅ RIGHT: `value = get_admin_setting('KEY')` → raises error if empty
+   - **User identity is required**: Never use fallback users like `default@user.com` or `USER_EMAIL` env defaults for settings. If authenticated username/email is missing, fail immediately with a clear error.
 
 3. **Single source of truth** — every piece of data should have one canonical source
    - Admin settings live in Supabase, read via `get_admin_setting()`
@@ -135,6 +136,7 @@ Example:
 2. **No silent fallbacks** - if something is missing, error immediately with a clear message
    - WRONG: `value = os.getenv('KEY', 'some_default')`
    - RIGHT: `value = get_admin_setting('KEY')` raises error if empty
+   - User identity is required: Never use fallback users like `default@user.com` or `USER_EMAIL` env defaults for settings. If authenticated username/email is missing, fail immediately with a clear error.
 
 3. **Single source of truth** - every piece of data should have one canonical source
    - Admin settings live in Supabase, read via `get_admin_setting()`
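
The "no silent fallbacks" rule above can be illustrated with a toy accessor (a dict-backed stand-in for illustration only, not the project's Supabase-backed `get_admin_setting`):

```python
def get_setting_strict(settings: dict, key: str) -> str:
    # Raise instead of substituting a default, so misconfiguration
    # surfaces immediately rather than as wrong behavior later.
    value = settings.get(key)
    if not value:
        raise ValueError(f"Required admin setting '{key}' is missing or empty")
    return value

settings = {"SNOWFLAKE_DATABASE": "DEMO_DB", "USER_EMAIL": ""}
print(get_setting_strict(settings, "SNOWFLAKE_DATABASE"))  # → DEMO_DB
```

An empty string is treated the same as a missing key, matching the "raises error if empty" wording in the guideline.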
cdw_connector.py CHANGED
@@ -3,9 +3,9 @@ CDW Connector - Sprint 2A
 Snowflake deployment functionality for generated schemas and data
 """
 
-import os
 from datetime import datetime
 from typing import Dict, List, Tuple, Optional
+import re
 
 try:
     import snowflake.connector
@@ -19,18 +19,17 @@ try:
 except ImportError:
     SQLPARSE_AVAILABLE = False
 
-try:
-    from dotenv import load_dotenv
-    load_dotenv(override=True)
-except ImportError:
-    pass
-
 try:
     from snowflake_auth import get_snowflake_connection_params
     SNOWFLAKE_AUTH_AVAILABLE = True
 except ImportError:
     SNOWFLAKE_AUTH_AVAILABLE = False
 
+try:
+    from supabase_client import inject_admin_settings_to_env
+except ImportError:
+    inject_admin_settings_to_env = None
+
 class SnowflakeDeployer:
     """Handle Snowflake deployment for generated schemas and data"""
 
@@ -42,14 +41,11 @@ class SnowflakeDeployer:
     def connect(self) -> Tuple[bool, str]:
         """Connect to Snowflake using OKTA authentication"""
         try:
-            # Validate required environment variables
-            required_vars = ['SNOWFLAKE_ACCOUNT', 'SNOWFLAKE_KP_USER', 'SNOWFLAKE_KP_PK', 'SNOWFLAKE_WAREHOUSE', 'SNOWFLAKE_DATABASE']
-            missing_vars = [var for var in required_vars if not os.getenv(var)]
-
-            if missing_vars:
-                return False, f"Missing required environment variables: {missing_vars}"
-
-            # Key pair authentication - let's fix the implementation
+            # Keep legacy env-based consumers aligned with Supabase admin settings.
+            if inject_admin_settings_to_env:
+                inject_admin_settings_to_env()
+
+            # Key pair authentication via centralized auth helper
             if not SNOWFLAKE_AUTH_AVAILABLE:
                 return False, "snowflake_auth module not available"
 
@@ -117,8 +113,10 @@ class SnowflakeDeployer:
 
         # Parse and execute DDL statements
         statements = self._parse_ddl_statements(ddl_statements, schema_name)
+        expected_tables = self._extract_expected_table_names(statements)
 
         executed_count = 0
+        failed_statements = []
         for statement in statements:
             if statement.strip():
                 try:
@@ -126,8 +124,32 @@
                     cursor.execute(statement)
                     executed_count += 1
                 except Exception as e:
-                    print(f"Warning: Failed to execute statement: {str(e)}")
-                    continue
+                    failed_statements.append((statement, str(e)))
+
+        # Retry failed statements once to handle dependency-order issues.
+        if failed_statements:
+            print(f"⚠️ First DDL pass had {len(failed_statements)} failures. Retrying once...")
+            retry_failures = []
+            for statement, first_error in failed_statements:
+                try:
+                    cursor.execute(statement)
+                    executed_count += 1
+                except Exception as e:
+                    retry_failures.append((statement, first_error, str(e)))
+            failed_statements = retry_failures
+
+        if failed_statements:
+            preview_errors = []
+            for statement, first_error, second_error in failed_statements[:3]:
+                stmt_preview = statement.replace("\n", " ")[:120]
+                preview_errors.append(
+                    f"- {stmt_preview} | first_error={first_error} | retry_error={second_error}"
+                )
+            error_summary = "\n".join(preview_errors)
+            raise Exception(
+                f"DDL deployment failed after retry; {len(failed_statements)} statements still failing.\n"
+                f"{error_summary}"
+            )
 
         # Check autocommit setting
         print(f"🔍 DEBUG: Connection autocommit: {self.connection.autocommit}")
@@ -144,6 +166,17 @@
         print(f"🔍 DEBUG: Current schema: {current_schema}")
         print(f"🔍 DEBUG: Schema verification completed - using current schema context")
 
+        # Verify expected tables actually exist before declaring success.
+        if expected_tables:
+            cursor.execute(f'SHOW TABLES IN SCHEMA "{schema_name}"')
+            existing_tables = {row[1].upper() for row in cursor.fetchall()}
+            missing_tables = sorted(tbl for tbl in expected_tables if tbl not in existing_tables)
+            if missing_tables:
+                raise Exception(
+                    "DDL deployment completed but required tables are missing in schema "
+                    f"{schema_name}: {missing_tables}"
+                )
+
         cursor.close()
 
         success_message = f"Schema '{schema_name}' created successfully with {executed_count} tables"
@@ -168,7 +201,12 @@
         clean_text = ddl_text.replace("```sql", "").replace("```", "").strip()
 
         # Split by semicolon
-        statements = [stmt.strip() for stmt in clean_text.split(';') if stmt.strip()]
+        statements = []
+        for stmt in clean_text.split(';'):
+            stmt = stmt.strip()
+            if not stmt:
+                continue
+            statements.append(self._sanitize_create_table_statement(stmt))
         processed_statements = []
 
         # Separate tables with and without foreign keys
@@ -190,6 +228,101 @@
         except Exception as e:
             print(f"Error parsing DDL statements: {str(e)}")
             return []
+
+    def _sanitize_create_table_statement(self, statement: str) -> str:
+        """
+        Remove CHECK constraints from CREATE TABLE statements.
+        This deploy path rejects CHECK constraints, so we sanitize deterministically.
+        """
+        if "CREATE TABLE" not in statement.upper():
+            return statement
+
+        open_idx = statement.find("(")
+        if open_idx == -1:
+            return statement
+
+        close_idx = self._find_matching_paren(statement, open_idx)
+        if close_idx == -1:
+            return statement
+
+        inner = statement[open_idx + 1:close_idx]
+        chunks = self._split_top_level_csv(inner)
+
+        sanitized_chunks = []
+        removed_any = False
+        for chunk in chunks:
+            stripped = chunk.strip()
+            if not stripped:
+                continue
+
+            upper = stripped.upper()
+            is_table_check = upper.startswith("CHECK") or ("CHECK" in upper and upper.startswith("CONSTRAINT"))
+            if is_table_check:
+                removed_any = True
+                continue
+
+            cleaned = self._remove_inline_check_clause(stripped)
+            if cleaned != stripped:
+                removed_any = True
+            if cleaned.strip():
+                sanitized_chunks.append(cleaned.strip())
+
+        if not removed_any:
+            return statement
+
+        rebuilt_inner = ",\n    ".join(sanitized_chunks)
+        rebuilt = f"{statement[:open_idx + 1]}\n    {rebuilt_inner}\n{statement[close_idx:]}"
+        print("⚠️ Sanitized unsupported CHECK constraint(s) from CREATE TABLE statement.")
+        return rebuilt
+
+    def _find_matching_paren(self, text: str, open_idx: int) -> int:
+        """Find matching closing parenthesis index for text[open_idx] == '('."""
+        depth = 0
+        for i in range(open_idx, len(text)):
+            ch = text[i]
+            if ch == "(":
+                depth += 1
+            elif ch == ")":
+                depth -= 1
+                if depth == 0:
+                    return i
+        return -1
+
+    def _split_top_level_csv(self, text: str) -> List[str]:
+        """Split by commas that are not nested in parentheses."""
+        parts = []
+        start = 0
+        depth = 0
+        for i, ch in enumerate(text):
+            if ch == "(":
+                depth += 1
+            elif ch == ")":
+                depth = max(0, depth - 1)
+            elif ch == "," and depth == 0:
+                parts.append(text[start:i])
+                start = i + 1
+        parts.append(text[start:])
+        return parts
+
+    def _remove_inline_check_clause(self, chunk: str) -> str:
+        """Remove inline `CHECK (...)` clause from a column definition chunk."""
+        m = re.search(r"\bCHECK\b", chunk, flags=re.IGNORECASE)
+        if not m:
+            return chunk
+
+        start = m.start()
+        i = m.end()
+        while i < len(chunk) and chunk[i].isspace():
+            i += 1
+        if i >= len(chunk) or chunk[i] != "(":
+            # Bare CHECK token; remove token only as fallback.
+            return (chunk[:start] + chunk[m.end():]).strip()
+
+        end = self._find_matching_paren(chunk, i)
+        if end == -1:
+            return chunk
+
+        return (chunk[:start] + " " + chunk[end + 1:]).strip()
 
     def _qualify_table_name(self, statement: str, schema_name: str) -> str:
         """Ensure table names are properly schema-qualified"""
@@ -223,6 +356,32 @@
         except Exception as e:
             print(f"Error qualifying table name: {str(e)}")
             return statement  # Return original if processing fails
+
+    def _extract_expected_table_names(self, statements: List[str]) -> List[str]:
+        """Extract CREATE TABLE targets for post-deploy existence verification."""
+        expected = []
+        table_pattern = re.compile(
+            r'CREATE\s+(?:OR\s+REPLACE\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?'
+            r'(?:(?:"[^"]+"|[A-Za-z0-9_]+)\.)?(?:"([^"]+)"|([A-Za-z0-9_]+))',
+            re.IGNORECASE
+        )
+
+        for statement in statements:
+            match = table_pattern.search(statement)
+            if not match:
+                continue
+            table_name = (match.group(1) or match.group(2) or "").strip()
+            if table_name:
+                expected.append(table_name.upper())
+
+        # Keep order but remove duplicates.
+        seen = set()
+        ordered_unique = []
+        for name in expected:
+            if name not in seen:
+                seen.add(name)
+                ordered_unique.append(name)
+        return ordered_unique
 
     def list_demo_schemas(self) -> List[Dict]:
         """List all demo schemas in the database"""
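
The depth-aware comma split that the new sanitizer relies on is the load-bearing trick: it keeps type arguments like `NUMBER(10,2)` intact while separating column definitions. A standalone sketch of the same idea (written as a free function for illustration, not the class method itself):

```python
def split_top_level_csv(text: str):
    # Split on commas that are not nested inside parentheses,
    # so "NUMBER(10,2)" survives as a single chunk.
    parts, start, depth = [], 0, 0
    for i, ch in enumerate(text):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(0, depth - 1)
        elif ch == "," and depth == 0:
            parts.append(text[start:i].strip())
            start = i + 1
    parts.append(text[start:].strip())
    return parts

cols = split_top_level_csv("ID INT, AMOUNT NUMBER(10,2) CHECK (AMOUNT > 0), NAME VARCHAR")
# → ['ID INT', 'AMOUNT NUMBER(10,2) CHECK (AMOUNT > 0)', 'NAME VARCHAR']
```

Once columns are isolated this way, removing a `CHECK (...)` clause from one chunk cannot corrupt its neighbors.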
chat_data_adjuster.py DELETED
@@ -1,163 +0,0 @@
-"""
-Simple Chat Interface for Data Adjustment
-
-A basic command-line chat to test the conversational data adjuster.
-User can keep making adjustments until they type 'done' or 'exit'.
-"""
-
-from dotenv import load_dotenv
-import os
-from conversational_data_adjuster import ConversationalDataAdjuster
-
-load_dotenv()
-
-
-def chat_loop():
-    """Main chat loop for data adjustment"""
-
-    print("""
-    ╔════════════════════════════════════════════════════════════╗
-    ║                                                            ║
-    ║                  Data Adjustment Chat                      ║
-    ║                                                            ║
-    ╚════════════════════════════════════════════════════════════╝
-
-    Commands:
-    - Type your adjustment request naturally
-    - "done" or "exit" to quit
-    - "help" for examples
-
-    Examples:
-    - "increase 1080p Webcam sales to 50B"
-    - "set profit margin to 25% for electronics"
-    - "make Tablet revenue 100 billion"
-
-    """)
-
-    # Initialize adjuster
-    database = os.getenv('SNOWFLAKE_DATABASE')
-    schema = "20251116_140933_AMAZO_SAL"
-    model_id = "3c97b0d6-448b-440a-b628-bac1f3d73049"
-
-    print(f"📊 Connected to: {database}.{schema}")
-    print(f"🎯 Model: {model_id}\n")
-
-    adjuster = ConversationalDataAdjuster(database, schema, model_id)
-    adjuster.connect()
-
-    tables = adjuster.get_available_tables()
-    print(f"📋 Available tables: {', '.join(tables)}\n")
-    print("="*80)
-    print("Ready! Type your adjustment request...")
-    print("="*80 + "\n")
-
-    while True:
-        # Get user input
-        user_input = input("\n💬 You: ").strip()
-
-        if not user_input:
-            continue
-
-        # Check for exit commands
-        if user_input.lower() in ['done', 'exit', 'quit', 'bye']:
-            print("\n👋 Goodbye!")
-            break
-
-        # Check for help
-        if user_input.lower() == 'help':
-            print("""
-            📚 Help - How to make adjustments:
-
-            Format: "make/increase/set [entity] [metric] to [value]"
-
-            Examples:
-            ✅ "increase 1080p Webcam revenue to 50 billion"
-            ✅ "set profit margin to 25% for electronics"
-            ✅ "make Laptop sales 100B"
-            ✅ "increase customer segment premium revenue by 30%"
-
-            You'll see 3 strategy options - type A, B, or C to pick one.
-            """)
-            continue
-
-        try:
-            # Step 1: Parse the request
-            print(f"\n🤔 Parsing your request...")
-            adjustment = adjuster.parse_adjustment_request(user_input, tables)
-
-            if 'error' in adjustment:
-                print(f"❌ {adjustment['error']}")
-                print("💡 Try rephrasing or type 'help' for examples")
-                continue
-
-            # Step 2: Analyze current data
-            print(f"📊 Analyzing current data...")
-            analysis = adjuster.analyze_current_data(adjustment)
-
-            if analysis['current_total'] == 0:
-                print(f"⚠️ No data found for '{adjustment['entity_value']}'")
-                print("💡 Try a different product/entity name")
-                continue
-
-            # Step 3: Generate strategies
-            strategies = adjuster.generate_strategy_options(adjustment, analysis)
-
-            # Step 4: Present options
-            adjuster.present_options(adjustment, analysis, strategies)
-
-            # Step 5: Get user's strategy choice
-            print("\n" + "="*80)
-            choice = input("Which strategy? [A/B/C or 'skip']: ").strip().upper()
-
-            if choice == 'SKIP' or not choice:
-                print("⏭️ Skipping this adjustment")
-                continue
-
-            # Find the chosen strategy
-            chosen = None
-            for s in strategies:
-                if s['id'] == choice:
-                    chosen = s
-                    break
-
-            if not chosen:
-                print(f"❌ Invalid choice: {choice}")
-                continue
-
-            # Step 6: Confirm
-            print(f"\n⚠️ About to execute: {chosen['name']}")
-            print(f"   This will affect {chosen.get('details', {}).get('rows_affected', 'some')} rows")
-            confirm = input("   Confirm? [yes/no]: ").strip().lower()
-
-            if confirm not in ['yes', 'y']:
-                print("❌ Cancelled")
-                continue
-
-            # Step 7: Execute
-            result = adjuster.execute_strategy(chosen)
-
-            if result['success']:
-                print(f"\n✅ {result['message']}")
-                print(f"🔄 Data updated! Refresh your ThoughtSpot liveboard to see changes.")
-            else:
-                print(f"\n❌ Failed: {result.get('error')}")
-
-        except KeyboardInterrupt:
-            print("\n\n⚠️ Interrupted")
-            break
-        except Exception as e:
-            print(f"\n❌ Error: {e}")
-            import traceback
-            print(traceback.format_exc())
-
-    # Cleanup
-    adjuster.close()
-    print("\n✅ Connection closed")
-
-
-if __name__ == "__main__":
-    try:
-        chat_loop()
-    except KeyboardInterrupt:
-        print("\n\n👋 Goodbye!")
chat_interface.py CHANGED
@@ -13,6 +13,7 @@ import sys
 import json
 import time
 import glob
 from dotenv import load_dotenv
 from demo_builder_class import DemoBuilder
 from supabase_client import load_gradio_settings, get_admin_setting, inject_admin_settings_to_env
@@ -35,6 +36,53 @@ from llm_config import (
 load_dotenv(override=True)
 
 # ==========================================================================
 # SETTINGS SCHEMA - Single source of truth for all settings
 # To add a new setting: add ONE entry here, then create the UI component
@@ -43,7 +91,7 @@ load_dotenv(override=True)
 SETTINGS_SCHEMA = [
     # App Settings - Left Column
     ('default_ai_model', 'default_llm', DEFAULT_LLM_MODEL, str),
-    ('default_company_url', 'default_company_url', '', str),
     ('default_use_case', 'default_use_case', 'Sales Analytics', str),
     ('liveboard_name', 'liveboard_name', '', str),
    ('tag_name', 'tag_name', '', str),
@@ -68,8 +116,10 @@ SETTINGS_SCHEMA = [
     ('default_warehouse', 'default_warehouse', 'COMPUTE_WH', str),
     ('default_database', 'default_database', 'DEMO_DB', str),
     ('default_schema', 'default_schema', 'PUBLIC', str),
-    ('ts_instance_url', 'thoughtspot_url', '', str),
     ('ts_username', 'thoughtspot_username', '', str),
 
     # Status (special - not saved, just displayed)
     ('settings_status', None, '', str),
 ]
@@ -144,10 +194,9 @@ def require_authenticated_email(request: gr.Request = None, user_email: str = No
 
 def build_initial_chat_message(company: str, use_case: str) -> str:
     """Build the pre-filled chat message from current settings."""
-    return (
-        "I am creating a demo for the company: "
-        f"{company} and for the use case: {use_case}"
-    )
 
 
 def safe_print(*args, **kwargs):
@@ -185,6 +234,9 @@ class ChatDemoInterface:
         self.live_progress_log = []  # Real-time deployment progress
         self.demo_pack_content = ""  # Generated demo pack markdown
         self.spotter_viz_story = ""  # Spotter Viz story (NL prompts for Spotter Viz agent)
 
     def _get_effective_user_email(self) -> str:
         """Resolve and cache effective user identity for settings access."""
@@ -252,37 +304,23 @@ class ChatDemoInterface:
 
         return f"""👋 **Welcome to ThoughtSpot Demo Builder!**
 
-I need two things to get started:
 
-**1. Company URL** (must be a real website)
-   - Examples: Nike.com, Target.com, Walmart.com
-   - I'll research their actual business model and data
 
-**2. Use Case** (Vertical × Function, or any custom use case!)
-
-**Configured combinations** (with KPIs, outliers, and Spotter questions):
 {uc_list}
 - ...and more!
-
-**Or create your own custom use case:**
-- "analyzing customer churn patterns across regions"
-- "tracking manufacturing defects by product line"
-- Literally ANYTHING — AI will research and build it!
-
-**How to tell me (BOTH required):**
-
-```
-I'm creating a demo for company: Nike.com use case: Retail Sales
-```
 
-Or with custom use case:
-```
-I'm creating a demo for company: Target.com use case: analyzing seasonal inventory trends
-```
 
-⚠️ **Important:** You MUST provide BOTH company AND use case in one message!
 
-*Ready? Tell me your company and use case!*"""
     else:
         return f"""👋 **Welcome to ThoughtSpot Demo Builder!**
 
@@ -335,10 +373,27 @@ What's your first step?"""
         Process user message and return updated chat history and state (with streaming)
         Returns: (chat_history, current_stage, current_model, company, use_case, next_textbox_value)
         """
-
         # Add user message to history
         chat_history.append((message, None))
-
         # Validate required settings before proceeding
         missing = self.validate_required_settings()
         if missing and current_stage == 'initialization':
@@ -369,6 +424,70 @@
 
         # Stage-based processing
         if current_stage == 'initialization':
             # Check if user just provided a standalone URL (e.g., "Comscore.com")
             standalone_url = re.search(r'^([a-zA-Z0-9-]+\.[a-zA-Z]{2,})$', cleaned_message.strip())
             if standalone_url:
@@ -540,7 +659,71 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
             yield chat_history, current_stage, current_model, company, use_case, ""
 
             return
-
         elif current_stage == 'awaiting_context':
             # User is providing context for ANY use case (both generic and established)
             if message_lower.strip() in ['proceed', 'continue', 'no', 'skip']:
@@ -618,6 +801,10 @@ Watch the AI Feedback tab for real-time progress!"""
             for progress_update in self.run_deployment_streaming():
                 if isinstance(progress_update, tuple):
                     final_result = progress_update
                 else:
                     deploy_update_count += 1
                     # Only yield every 5th deployment update
@@ -815,10 +1002,12 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
             final_result = None
             for progress_update in self.run_deployment_streaming():
                 if isinstance(progress_update, tuple):
-                    # Final result with next_msg
                     final_result = progress_update
                 else:
-                    # Progress update - show in chat
                     chat_history[-1] = (message, f"**DDL Approved - Deploying...**\n\n{progress_update}")
                     yield chat_history, current_stage, current_model, company, use_case, ""
 
@@ -881,10 +1070,12 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
             final_result = None
             for progress_update in self.run_deployment_streaming():
                 if isinstance(progress_update, tuple):
-                    # Final result with next_msg
                     final_result = progress_update
                 else:
-                    # Progress update - show in chat
                     chat_history[-1] = (message, progress_update)
                     yield chat_history, current_stage, current_model, company, use_case, ""
 
@@ -930,6 +1121,10 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
             for progress_update in self.run_deployment_streaming():
                 if isinstance(progress_update, tuple):
                     final_result = progress_update
                 else:
                     chat_history.append((None, progress_update))
                     yield chat_history, current_stage, current_model, company, use_case, ""
@@ -1191,7 +1386,7 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
                 return
             return
 
-        elif current_stage == 'outliers':
             # Handle outlier adjustment stage
             if 'done' in message_lower or 'finish' in message_lower or 'complete' in message_lower:
                 # Close adjuster connection
@@ -1316,40 +1511,49 @@ Try again or type **'done'** to finish."""
             yield chat_history, current_stage, current_model, company, use_case, ""
             return
 
-            # Get current value
-            current_value = adjuster.get_current_value(
-                match['entity_value'],
-                match['metric_column']
-            )
-
-            if current_value == 0:
-                response = f"""❌ **No data found**
 
-Could not find data for '{match['entity_value']}'.
 
-Please check the spelling or try a different entity.
-Type **'done'** to finish."""
             chat_history[-1] = (message, response)
             yield chat_history, current_stage, current_model, company, use_case, ""
             return
-
-            # Calculate target if percentage
             target_value = match.get('target_value')
-            if match.get('is_percentage'):
-                percentage = match.get('percentage', 0)
                 target_value = current_value * (1 + percentage / 100)
                 match['target_value'] = target_value
-
             # Generate strategy
             strategy = adjuster.generate_strategy(
                 match['entity_value'],
-                match['metric_column'],
                 current_value,
-                target_value
             )
-
             # Present smart confirmation
-            confirmation = adjuster.present_smart_confirmation(match, current_value, strategy)
 
             # Store for execution if user confirms
             self._pending_adjustment = {
@@ -1414,7 +1618,12 @@ Try a different request or type **'done'** to finish."""
         # Clean up trailing dots just in case
         company = company.rstrip('.')
         return company
-
         return None
 
     def extract_use_case_from_message(self, message):
@@ -1447,7 +1656,18 @@
             if re.search(r'\.(com|org|net|io|co|ai)\b', use_case, re.IGNORECASE):
                 continue
             return use_case
-
         return None
 
     def handle_override(self, message):
@@ -1505,12 +1725,15 @@ To change settings, use:
 
     def run_research_streaming(self, company, use_case, generic_context=""):
         """Run the research phase with streaming updates
-
         Args:
             company: Company URL/name
             use_case: Use case name
             generic_context: Additional context provided by user for generic use cases
         """
         print(f"\n\n[CACHE DEBUG] === run_research_streaming called ===")
         print(f"[CACHE DEBUG] company: {company}")
         print(f"[CACHE DEBUG] use_case: {use_case}\n\n")
@@ -1897,38 +2120,33 @@ To change settings, use:
         """Generate use-case specific Spotter questions based on the schema.
 
         Priority order:
-        1. OutlierPattern.spotter_questions from the vertical×function system
-        2. FUNCTIONS[fn].spotter_templates from the config
         3. Hardcoded fallbacks per use case
         4. Generic questions
         """
         import re
-
-        # Priority 1: Try the outlier system for configured Spotter questions
542
  return
543
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  elif current_stage == 'awaiting_context':
545
  # User is providing context for ANY use case (both generic and established)
546
  if message_lower.strip() in ['proceed', 'continue', 'no', 'skip']:
@@ -618,6 +801,10 @@ Watch the AI Feedback tab for real-time progress!"""
618
  for progress_update in self.run_deployment_streaming():
619
  if isinstance(progress_update, tuple):
620
  final_result = progress_update
 
 
 
 
621
  else:
622
  deploy_update_count += 1
623
  # Only yield every 5th deployment update
@@ -815,10 +1002,12 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
815
  final_result = None
816
  for progress_update in self.run_deployment_streaming():
817
  if isinstance(progress_update, tuple):
818
- # Final result with next_msg
819
  final_result = progress_update
 
 
 
 
820
  else:
821
- # Progress update - show in chat
822
  chat_history[-1] = (message, f"**DDL Approved - Deploying...**\n\n{progress_update}")
823
  yield chat_history, current_stage, current_model, company, use_case, ""
824
 
@@ -881,10 +1070,12 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
881
  final_result = None
882
  for progress_update in self.run_deployment_streaming():
      if isinstance(progress_update, tuple):
-         # Final result with next_msg
          final_result = progress_update
      else:
-         # Progress update - show in chat
          chat_history[-1] = (message, progress_update)
          yield chat_history, current_stage, current_model, company, use_case, ""

@@ -930,6 +1121,10 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
  for progress_update in self.run_deployment_streaming():
      if isinstance(progress_update, tuple):
          final_result = progress_update
      else:
          chat_history.append((None, progress_update))
          yield chat_history, current_stage, current_model, company, use_case, ""

@@ -1191,7 +1386,7 @@ This may take 1-2 minutes. Watch the AI Feedback tab for progress!"""
          return
      return

-     elif current_stage == 'outliers':
          # Handle outlier adjustment stage
          if 'done' in message_lower or 'finish' in message_lower or 'complete' in message_lower:
              # Close adjuster connection

@@ -1316,40 +1511,49 @@ Try again or type **'done'** to finish."""
      yield chat_history, current_stage, current_model, company, use_case, ""
      return

-     # Get current value
-     current_value = adjuster.get_current_value(
-         match['entity_value'],
-         match['metric_column']
-     )
-
-     if current_value == 0:
-         response = f"""❌ **No data found**
-
- Could not find data for '{match['entity_value']}'.
-
- Please check the spelling or try a different entity.
- Type **'done'** to finish."""
      chat_history[-1] = (message, response)
      yield chat_history, current_stage, current_model, company, use_case, ""
      return
-
-     # Calculate target if percentage
      target_value = match.get('target_value')
-     if match.get('is_percentage'):
-         percentage = match.get('percentage', 0)
          target_value = current_value * (1 + percentage / 100)
      match['target_value'] = target_value
-
      # Generate strategy
      strategy = adjuster.generate_strategy(
          match['entity_value'],
-         match['metric_column'],
          current_value,
-         target_value
      )
-
      # Present smart confirmation
-     confirmation = adjuster.present_smart_confirmation(match, current_value, strategy)

      # Store for execution if user confirms
      self._pending_adjustment = {

@@ -1414,7 +1618,12 @@ Try a different request or type **'done'** to finish."""
      # Clean up trailing dots just in case
      company = company.rstrip('.')
      return company
-
      return None

  def extract_use_case_from_message(self, message):

@@ -1447,7 +1656,18 @@ Try a different request or type **'done'** to finish."""
      if re.search(r'\.(com|org|net|io|co|ai)\b', use_case, re.IGNORECASE):
          continue
      return use_case
-
      return None

  def handle_override(self, message):

@@ -1505,12 +1725,15 @@ To change settings, use:

  def run_research_streaming(self, company, use_case, generic_context=""):
      """Run the research phase with streaming updates
-
      Args:
          company: Company URL/name
          use_case: Use case name
          generic_context: Additional context provided by user for generic use cases
      """
      print(f"\n\n[CACHE DEBUG] === run_research_streaming called ===")
      print(f"[CACHE DEBUG] company: {company}")
      print(f"[CACHE DEBUG] use_case: {use_case}\n\n")
@@ -1897,38 +2120,33 @@ To change settings, use:
      """Generate use-case specific Spotter questions based on the schema.

      Priority order:
-     1. OutlierPattern.spotter_questions from the vertical×function system
-     2. FUNCTIONS[fn].spotter_templates from the config
      3. Hardcoded fallbacks per use case
      4. Generic questions
      """
      import re
-
-     # Priority 1: Try the outlier system for configured Spotter questions
      try:
-         from outlier_system import get_outliers_for_use_case
-         v = self.vertical or "Generic"
-         f = self.function or "Generic"
-         outlier_config = get_outliers_for_use_case(v, f)
-
          configured_questions = []
-         for op in outlier_config.required + outlier_config.optional:
-             for sq in op.spotter_questions:
                  configured_questions.append({
                      'question': sq,
-                     'purpose': f'Reveals {op.name} pattern'
-                 })
-             for fq in op.spotter_followups[:1]:  # 1 followup per pattern
-                 configured_questions.append({
-                     'question': fq,
-                     'purpose': f'Follow-up on {op.name}'
                  })
-
          if configured_questions:
-             self.log_feedback(f"📋 Using {len(configured_questions)} Spotter questions from {v}×{f} outlier config")
-             return configured_questions[:8]  # Cap at 8
      except Exception as e:
-         self.log_feedback(f"⚠️ Outlier Spotter questions not available: {e}")

      # Priority 2: Try FUNCTIONS config for spotter_templates
      try:

@@ -2042,89 +2260,111 @@ To change settings, use:
      - **Ask questions**: Let the AI demonstrate natural language
      - **End with action**: Show how insights lead to decisions""")

- def _generate_spotter_viz_story(self, company_name: str, use_case: str,
                                  model_name: str = None, liveboard_name: str = None) -> str:
      """Generate a Spotter Viz story — a conversational sequence of NL prompts
      that can be entered into ThoughtSpot's Spotter Viz agent to build a liveboard.
-
-     Uses the build_prompt() system with stage="spotter_viz_story" + LLM call.
-     Falls back to a template-based story if LLM fails.
      """
      try:
          from prompts import build_prompt
-         from demo_personas import parse_use_case
-
-         v, f = parse_use_case(use_case or '')
          vertical = v or "Generic"
          function = f or "Generic"
-
-         # Build company context for the prompt
          company_context = f"Company: {company_name}\nUse Case: {use_case}"
          if model_name:
              company_context += f"\nData Source/Model: {model_name}"
          if liveboard_name:
              company_context += f"\nLiveboard Name: {liveboard_name}"
-
-         # Add research context if available
          if hasattr(self, 'demo_builder') and self.demo_builder:
              research = getattr(self.demo_builder, 'company_summary', '') or ''
              if research:
                  company_context += f"\n\nCompany Research:\n{research[:1500]}"
-
          prompt = build_prompt(
              stage="spotter_viz_story",
              vertical=vertical,
              function=function,
              company_context=company_context,
          )
-
-         # Make LLM call
-         from prompt_logger import logged_completion
-         llm_model = self.settings.get('model', DEFAULT_LLM_MODEL)
-         self.log_feedback(f"🎬 Generating Spotter Viz story ({llm_model})...")
-
-         response = logged_completion(
-             stage="spotter_viz_story",
-             model=llm_model,
-             messages=[{"role": "user", "content": prompt}],
-             max_tokens=2000,
-             temperature=0.7,
-         )
-
-         story = response.choices[0].message.content.strip()
-
-         # Add header
-         header = f"""# Spotter Viz Story: {company_name}
- ## {use_case}

- *Copy these prompts into ThoughtSpot Spotter Viz to build this liveboard interactively.*

- ---

- """
-         return header + story
-
      except Exception as e:
-         self.log_feedback(f"⚠️ Spotter Viz story generation failed: {e}")
-         # Fallback: build a basic template from what we know
-         return self._build_fallback_spotter_story(company_name, use_case, model_name)

  def _build_fallback_spotter_story(self, company_name: str, use_case: str,
                                    model_name: str = None) -> str:
      """Build a basic Spotter Viz story without LLM, using available context."""
      data_source = model_name or f"{company_name} model"

-     # Get spotter questions from outlier system
      spotter_qs = []
      try:
-         from demo_personas import parse_use_case
-         from outlier_system import get_outliers_for_use_case
          v, f = parse_use_case(use_case or '')
-         if v or f:
-             outlier_config = get_outliers_for_use_case(v or "Generic", f or "Generic")
-             for op in outlier_config.required:
-                 for sq in op.spotter_questions[:1]:
-                     spotter_qs.append(sq)
      except:
          pass
@@ -2641,7 +2881,10 @@ Generate complete CREATE TABLE statements with proper Snowflake syntax and depen
      # Generate domain-specific realistic data based on column name, then truncate to fit
      base_value = None
      if 'NAME' in col_name_upper and 'COMPANY' not in col_name_upper:
-         if 'PRODUCT' in col_name_upper:
              base_value = "random.choice(['Laptop Pro 15', 'Wireless Mouse 2.4GHz', 'USB-C Cable 6ft', 'Monitor Stand Adjustable', 'Mechanical Keyboard RGB', 'Noise Canceling Headphones', '1080p Webcam', 'Portable SSD 1TB', 'Power Bank 20000mAh', 'Tablet 10 inch', 'Smart Watch', 'Bluetooth Speaker', 'Gaming Mouse Pad', 'Phone Case', 'Screen Protector', 'Charging Cable', 'Desk Lamp LED', 'Laptop Bag', 'Wireless Earbuds', 'USB Hub'])"
          elif 'CUSTOMER' in col_name_upper or 'USER' in col_name_upper:
              base_value = "fake.name()"

@@ -2996,16 +3239,19 @@ LegitData will generate realistic, AI-powered data.

  def run_deployment_streaming(self):
      """Run deployment to Snowflake using LegitData - yields progress updates"""
      progress = ""
-
      # Clear and initialize live progress for Snowflake deployment
      self.live_progress_log = ["=" * 60, "SNOWFLAKE DEPLOYMENT STARTING", "=" * 60, ""]
-
      def log_progress(msg):
          """Log to both AI feedback and live progress"""
          self.log_feedback(msg)
          self.live_progress_log.append(msg)
-
      try:
          # Ensure deploy-time modules that still use os.getenv() see Supabase admin settings.
          inject_admin_settings_to_env()

@@ -3409,10 +3655,11 @@ Cannot deploy to ThoughtSpot without tables.""",

  yield f"**Starting ThoughtSpot Deployment...**\n\nSchema verified: {database}.{schema_name}\nFound {len(tables)} tables\n\n"

- # Create deployer with settings (fall back to .env if not in settings)
- ts_url = get_admin_setting('THOUGHTSPOT_URL')
  ts_user = get_admin_setting('THOUGHTSPOT_ADMIN_USER')
- ts_secret = get_admin_setting('THOUGHTSPOT_TRUSTED_AUTH_KEY')

  deployer = ThoughtSpotDeployer(
      base_url=ts_url,

@@ -3434,11 +3681,13 @@ Cannot deploy to ThoughtSpot without tables.""",
      safe_print(msg, flush=True)

  # Show initial message
- yield """**Starting ThoughtSpot Deployment...**

Authenticating with ThoughtSpot...

- **This takes 2-5 minutes.**

**Switch to the "Live Progress" tab** to watch real-time progress.

@@ -3449,6 +3698,7 @@ Steps:
  4. Liveboard creation

This chat will update when complete."""

  safe_print("\n" + "="*60, flush=True)
  safe_print("THOUGHTSPOT DEPLOYMENT STARTING", flush=True)

@@ -3492,6 +3742,7 @@ This chat will update when complete."""
      llm_model=llm_model,
      tag_name=tag_name_value,
      liveboard_method=liveboard_method,
      progress_callback=progress_callback
  )
  except Exception as e:
@@ -3660,17 +3911,21 @@ Ask these questions to showcase ThoughtSpot's AI capabilities:
  try:
      from smart_data_adjuster import SmartDataAdjuster

-     # Pass the selected LLM model to the adjuster
      llm_model = self.settings.get('model', DEFAULT_LLM_MODEL)
-     adjuster = SmartDataAdjuster(database, schema_name, liveboard_guid, llm_model=llm_model)
      adjuster.connect()

      if adjuster.load_liveboard_context():
          self._adjuster = adjuster

          viz_list = "\n".join([
-             f"  [{i+1}] {v['name']}\n      Columns: {', '.join(v['columns'][:5])}"
-             + (f"... (+{len(v['columns'])-5} more)" if len(v['columns']) > 5 else "")
              for i, v in enumerate(adjuster.visualizations)
          ])

@@ -3714,7 +3969,7 @@ I've loaded your liveboard context. Here are the visualizations:
  - Reference by viz number: "viz 3, increase laptop to 50B"

  **Try an adjustment now, or type 'done' to finish!**"""
-     final_stage = 'outliers'
  else:
      ts_url = get_admin_setting('THOUGHTSPOT_URL', required=False).rstrip('/')
      model_guid = results.get('model_guid') or ''

@@ -3765,27 +4020,27 @@ Your demo is ready!
Note: Could not load liveboard context for adjustments: {str(e)}
Type **'done'** to finish."""
  else:
-     # No liveboard_guid - this shouldn't happen on success, but handle gracefully
      ts_url = get_admin_setting('THOUGHTSPOT_URL', required=False).rstrip('/')
      model_guid = results.get('model_guid') or ''
-     lb_url = results.get('liveboard_url', '')
      table_names = results.get('tables', [])
      tables_list = ', '.join(table_names) if table_names else 'N/A'
-
-     # Build URLs for easy Slack pasting
-     model_url = f"{ts_url}/#/data/tables/{model_guid}" if model_guid and ts_url else None
-
-     response = f"""**ThoughtSpot Deployment Complete**

**Created:**
- Connection: {results.get('connection', 'N/A')}
- Tables: {tables_list}
- - Model: {model_url if model_url else results.get('model', 'N/A')}
- - Liveboard: {lb_url if lb_url else liveboard_name_result}
-
- Your demo is ready!

- You can now access it in ThoughtSpot.

Type **'done'** to finish."""

@@ -3793,18 +4048,55 @@ Type **'done'** to finish."""
  else:
      errors = results.get('errors', ['Unknown error'])
      error_details = '\n'.join(errors)
-
-     if 'schema validation' in error_details.lower() or 'schema' in error_details.lower():
-         guidance = "**Root Cause:** The model TML has validation errors."
-     elif 'connection' in error_details.lower():
-         guidance = "**Root Cause:** Connection issue with ThoughtSpot or Snowflake"
-     elif 'authenticate' in error_details.lower() or 'auth' in error_details.lower():
-         guidance = "**Root Cause:** Authentication failed"
      else:
-         guidance = "**Check the progress log above for details.**"
-
-     yield {
-         'response': f"""❌ **ThoughtSpot Deployment Failed**

**Error Details:**
```

@@ -3820,8 +4112,8 @@ Type **'done'** to finish."""
  **Next Steps:**
  - Type **'retry'** to try again
  - Or fix the issues above first""",
-     'stage': 'deploy'
- }

  except Exception as e:
      import traceback
@@ -3971,6 +4263,7 @@ def create_chat_interface():
      current_model = gr.State(default_settings['model'])
      current_company = gr.State(default_settings['company'])
      current_usecase = gr.State(default_settings['use_case'])

      # Header
      gr.Markdown("""

@@ -4234,17 +4527,22 @@ def create_chat_interface():
      gr.Markdown("### System-Wide Settings")
      gr.Markdown("These settings apply to all users. Only admins can view and edit.")

      with gr.Row():
          with gr.Column():
              gr.Markdown("#### ThoughtSpot Connection")
-             admin_ts_url = gr.Textbox(label="ThoughtSpot URL", placeholder="https://your-instance.thoughtspot.cloud")
              admin_ts_user = gr.Textbox(label="ThoughtSpot Admin Username", placeholder="admin@company.com")
-             admin_ts_auth_key = gr.Textbox(label="Trusted Auth Key", type="password")
-
-             gr.Markdown("#### LLM API Keys (Legacy storage - runtime uses .env)")
-             admin_openai_key = gr.Textbox(label="OpenAI API Key", type="password")
-             admin_google_key = gr.Textbox(label="Google API Key", type="password")
-
          with gr.Column():
              gr.Markdown("#### Snowflake Connection")
              admin_sf_account = gr.Textbox(label="Snowflake Account")

@@ -4268,14 +4566,16 @@ def create_chat_interface():
          admin_sf_account, admin_sf_kp_user, admin_sf_kp_pk,
          admin_sf_kp_pass, admin_sf_role, admin_sf_warehouse,
          admin_sf_database, admin_sf_sso_user,
      ]

      admin_keys_order = [
          "THOUGHTSPOT_URL", "THOUGHTSPOT_TRUSTED_AUTH_KEY", "THOUGHTSPOT_ADMIN_USER",
          "OPENAI_API_KEY", "GOOGLE_API_KEY",
          "SNOWFLAKE_ACCOUNT", "SNOWFLAKE_KP_USER", "SNOWFLAKE_KP_PK",
          "SNOWFLAKE_KP_PASSPHRASE", "SNOWFLAKE_ROLE", "SNOWFLAKE_WAREHOUSE",
          "SNOWFLAKE_DATABASE", "SNOWFLAKE_SSO_USER",
      ]

      def load_admin_settings_handler():

@@ -4322,6 +4622,59 @@ def create_chat_interface():
          outputs=admin_fields + [admin_settings_status]
      )

      # Check admin status and toggle admin-only settings visibility
      def check_admin_visibility(request: gr.Request):
          """Check if logged-in user is admin and toggle settings visibility."""
@@ -4411,6 +4764,7 @@ def create_chat_interface():
      company = (str(settings.get("default_company_url", "")).strip() or "Amazon.com")
      use_case = (str(settings.get("default_use_case", "")).strip() or "Sales Analytics")
      model = (str(settings.get("default_llm", "")).strip() or DEFAULT_LLM_MODEL)
      initial_message = build_initial_chat_message(company, use_case)

      return (

@@ -4418,10 +4772,12 @@ def create_chat_interface():
      model,
      company,
      use_case,
      gr.update(value=model),
      initial_message,
  )
  # Wire up load handler - outputs follow SETTINGS_SCHEMA order
  interface.load(
      fn=load_settings_on_startup,

@@ -4437,7 +4793,9 @@ def create_chat_interface():
          current_model,
          current_company,
          current_usecase,
          chat_components["model_dropdown"],
          chat_components["msg"],
      ]
  )
@@ -4473,14 +4831,9 @@ def create_chat_tab(chat_controller_state, settings, current_stage, current_mode
  )

  with gr.Row():
-     # Pre-populate with company and use case from settings
-     initial_message = build_initial_chat_message(settings['company'], settings['use_case'])
-     print(f"DEBUG: Initial message set to: {initial_message}")
-     print(f"DEBUG: Settings company: {settings['company']}, use_case: {settings['use_case']}")
-
      msg = gr.Textbox(
          label="Your message",
-         value=initial_message,
          placeholder="Type your message here or use /over to change settings...",
          lines=1,
          max_lines=1,

@@ -4498,17 +4851,16 @@ def create_chat_tab(chat_controller_state, settings, current_stage, current_mode

  # Right column - Status & Settings
  with gr.Column(scale=1):
-     gr.Markdown("### 📊 Current Status")
-
-     # Stage display (read-only)
-     stage_display = gr.Textbox(
-         label="Stage",
-         value=settings['stage'].replace('_', ' ').title(),
-         interactive=False,
-         lines=1
      )
-
-     # Model selector (editable)
      model_dropdown = gr.Dropdown(
          label="AI Model",
          choices=list(UI_MODEL_CHOICES),

@@ -4516,60 +4868,135 @@ def create_chat_tab(chat_controller_state, settings, current_stage, current_mode
          interactive=True,
          allow_custom_value=True
      )
-
      gr.Markdown("### 📈 Progress")
-
      def get_progress_html(stage):
-         """Generate progress HTML - highlight current stage in blue"""
-         stages = [
-             ('research', 'Research'),
-             ('create_ddl', 'Create DDL'),
-             ('deploy', 'Deploy'),
-             ('thoughtspot', 'ThoughtSpot'),
-             ('complete', 'Complete')
-         ]
-
-         html = "<div style='padding: 10px; font-size: 14px;'>"
-         for stage_key, stage_name in stages:
-             if stage == stage_key:
-                 # Current - blue
-                 html += f"<div style='margin: 5px 0; color: #3b82f6; font-weight: bold;'>🔵 {stage_name}</div>"
              else:
-                 # Not current - gray
-                 html += f"<div style='margin: 5px 0; color: #6b7280;'>⚪ {stage_name}</div>"
-
          html += "</div>"
          return html
-
      progress_html = gr.HTML(get_progress_html('initialization'))

      # Event handlers - each creates/uses session-specific controller
-     def send_message(controller, message, history, stage, model, company, usecase, request: gr.Request = None):
          """Handle sending a message - creates controller if needed"""
-         # Create controller for this session if it doesn't exist
          username = getattr(request, 'username', None) if request else None
          if controller is None:
              controller = ChatDemoInterface(user_email=username)
              print(f"[SESSION] Created new ChatDemoInterface for {username or 'anonymous'}")
-
-         # Process the message
-         for result in controller.process_chat_message(
-             message, history, stage, model, company, usecase
-         ):
-             # result = (chat_history, stage, model, company, usecase, msg_clear)
-             # Extract the new stage to update progress
-             new_stage = result[1] if len(result) > 1 else stage
-             progress = get_progress_html(new_stage)
-             # Yield: controller, chat_history, stage, model, company, usecase, msg_clear, progress_html
-             yield (controller,) + result + (progress,)
-
-     def quick_action(controller, action_text, history, stage, model, company, usecase, request: gr.Request = None):
          """Handle quick action button clicks"""
          username = getattr(request, 'username', None) if request else None
          if controller is None:
              controller = ChatDemoInterface(user_email=username)
              print(f"[SESSION] Created new ChatDemoInterface for {username or 'anonymous'}")
          for result in controller.process_chat_message(
              action_text, history, stage, model, company, usecase
          ):
@@ -4578,58 +5005,70 @@ def create_chat_tab(chat_controller_state, settings, current_stage, current_mode
      yield (controller,) + result + (progress,)

  # Wire up send button and enter key
- msg.submit(
-     fn=send_message,
-     inputs=[chat_controller_state, msg, chatbot, current_stage, current_model, current_company, current_usecase],
-     outputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, msg, progress_html]
- )
-
- send_btn.click(
-     fn=send_message,
-     inputs=[chat_controller_state, msg, chatbot, current_stage, current_model, current_company, current_usecase],
-     outputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, msg, progress_html]
- )

  # Quick action wrapper functions
- def start_action(controller, history, stage, model, company, usecase):
-     yield from quick_action(controller, "Start research", history, stage, model, company, usecase)
-
- def configure_action(controller, history, stage, model, company, usecase):
-     yield from quick_action(controller, "Configure settings", history, stage, model, company, usecase)
-
- def help_action(controller, history, stage, model, company, usecase):
-     yield from quick_action(controller, "Help", history, stage, model, company, usecase)
-
  # Quick action buttons
- start_btn.click(
-     fn=start_action,
-     inputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase],
-     outputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, msg, progress_html]
- )
-
- configure_btn.click(
-     fn=configure_action,
-     inputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase],
-     outputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, msg, progress_html]
- )
-
- help_btn.click(
-     fn=help_action,
-     inputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase],
-     outputs=[chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, msg, progress_html]
- )

  # Model dropdown change
  def update_model(new_model, controller, history):
      if controller is not None:
          controller.settings['model'] = new_model
      return new_model, history
  model_dropdown.change(
      fn=update_model,
      inputs=[model_dropdown, chat_controller_state, chatbot],
      outputs=[current_model, chatbot]
  )

  # Return components for external access
  return {

@@ -4637,8 +5076,9 @@ def create_chat_tab(chat_controller_state, settings, current_stage, current_mode
      'msg': msg,
      'model_dropdown': model_dropdown,
      'send_btn': send_btn,
-     'send_btn_ref': send_btn,  # Reference for updates
-     'stage_display': stage_display,
      'progress_html': progress_html
  }
@@ -4651,106 +5091,86 @@ def create_settings_tab():
  gr.Markdown("## ⚙️ Configuration Settings")
  gr.Markdown("Configure your demo builder preferences")

- # Main settings - Most important at the top in 2 columns
- gr.Markdown("### ⭐ App Settings")
-
  with gr.Row():
-     with gr.Column():
-         default_ai_model = gr.Dropdown(
-             label="Default AI Model",
-             choices=list(UI_MODEL_CHOICES),
-             value=DEFAULT_LLM_MODEL,
-             info="Primary model for demo generation",
-             allow_custom_value=True
-         )
-
-         default_company_url = gr.Textbox(
-             label="Default Company URL",
-             placeholder="https://example.com",
-             value="",
-             info="Pre-fill company for faster starts"
-         )
-
-         # Build use case choices from VERTICALS × FUNCTIONS matrix
-         use_case_choices = []
-         for v_name in VERTICALS:
-             for f_name in FUNCTIONS:
-                 use_case_choices.append(f"{v_name} {f_name}")
-         use_case_choices.append("Custom (type in chat)")
-
-         default_use_case = gr.Dropdown(
-             label="Default Use Case",
-             choices=use_case_choices,
-             value="Retail Sales",
-             info="Vertical × Function combination, or type any custom use case"
-         )
-
-         liveboard_name = gr.Textbox(
-             label="Liveboard Name",
-             placeholder="My Demo Liveboard",
-             value="",
-             info="Default name for generated liveboards"
-         )
-
          tag_name = gr.Textbox(
              label="Tag Name",
              placeholder="e.g., 'Sales_Demo' or 'Q4_2024'",
              value="",
              info="Tag to apply to all ThoughtSpot objects (connection, tables, model, liveboard)"
          )
-
-     with gr.Column():
          fact_table_size = gr.Dropdown(
              label="Fact Table Size",
              choices=["1000", "10000", "100000"],
              value="1000",
              info="Number of rows in fact table"
          )
          dim_table_size = gr.Dropdown(
              label="Dim Table Size",
              choices=["50", "100", "1000"],
              value="100",
              info="Number of rows in dimension tables"
          )
          object_naming_prefix = gr.Textbox(
              label="Object Naming Prefix",
              placeholder="e.g., 'ACME_' or 'DEMO_'",
              value="",
              info="Prefix for ThoughtSpot objects (for future use)"
          )
          column_naming_style = gr.Dropdown(
              label="Column Naming Style",
              choices=["Regular Case", "snake_case", "camelCase", "PascalCase", "UPPER_CASE", "original"],
              value="Regular Case",
              info="Naming convention for ThoughtSpot model columns (Regular Case = State Id, Total Revenue)"
          )

- # Existing Model Section
- gr.Markdown("### 🔗 Use Existing Model")
- gr.Markdown("*Skip table/model creation and create liveboard from an existing ThoughtSpot model*")
-
- with gr.Row():
-     with gr.Column():
-         use_existing_model = gr.Checkbox(
-             label="Use Existing Model",
-             value=False,
-             info="Enable to skip table creation and use an existing model"
-         )
-
-         existing_model_guid = gr.Textbox(
-             label="Model GUID",
-             placeholder="e.g., ce2b12d9-07c0-4f38-9394-2cc1d5a5dc6f",
-             value="",
-             info="GUID of the existing ThoughtSpot model to use",
-             visible=True
-         )
-
  gr.Markdown("---")
  gr.Markdown("### 🌍 Other Settings")

@@ -4814,8 +5234,8 @@ def create_settings_tab():
  admin_db_accordion = gr.Accordion("💾 Database Connections (Admin)", open=False, visible=True)
  with admin_db_accordion:
      gr.Markdown("""
-     **⚠️ Note:** These settings are **not currently used** by the application.
-     Connection credentials are read from your `.env` file.
      This section is for reference/future use only.
      """)

@@ -4862,11 +5282,8 @@ def create_settings_tab():
      with gr.Column():
          gr.Markdown("### 📊 ThoughtSpot Settings")

-         ts_instance_url = gr.Textbox(
-             label="ThoughtSpot URL",
-             placeholder="https://your-instance.thoughtspot.cloud",
-             info="Your ThoughtSpot instance"
-         )

          ts_username = gr.Textbox(
              label="ThoughtSpot Username",

@@ -4882,15 +5299,31 @@ def create_settings_tab():
          )

      gr.Markdown("---")
-     gr.Markdown("""
-     ### 🔐 Security Note
-     Settings are saved to your user profile. API keys remain in `.env` file.
-
-     Required in `.env`:
-     - `SNOWFLAKE_PASSWORD` or `SNOWFLAKE_KP_PK` (for key-pair auth)
-     - `OPENAI_API_KEY` (required for Codex/OpenAI models)
-     """)
-
      with gr.Row():
          save_settings_btn = gr.Button("💾 Save Settings", variant="primary", size="lg")
          reset_settings_btn = gr.Button("🔄 Reset to Defaults", size="lg")

@@ -4956,6 +5389,8 @@ def create_settings_tab():
      'default_schema': default_schema,
      'ts_instance_url': ts_instance_url,
      'ts_username': ts_username,
      # Status
      'settings_status': settings_status,
      # Admin-only visibility toggles
 
  import json
  import time
  import glob
+ from datetime import datetime
  from dotenv import load_dotenv
  from demo_builder_class import DemoBuilder
  from supabase_client import load_gradio_settings, get_admin_setting, inject_admin_settings_to_env

  load_dotenv(override=True)

+ # ==========================================================================
+ # TS ENVIRONMENT HELPERS
+ # ==========================================================================
+
+ def get_ts_environments() -> list:
+     """Return list of environment labels from TS_ENV_N_LABEL/URL .env entries."""
+     envs = []
+     i = 1
+     while True:
+         label = os.getenv(f'TS_ENV_{i}_LABEL', '').strip()
+         url = os.getenv(f'TS_ENV_{i}_URL', '').strip()
+         if not label or not url:
+             break
+         envs.append(label)
+         i += 1
+     return envs or ['(no environments configured)']
+
+
+ def get_ts_env_url(label: str) -> str:
+     """Return the URL for a given environment label."""
+     i = 1
+     while True:
+         env_label = os.getenv(f'TS_ENV_{i}_LABEL', '').strip()
+         if not env_label:
+             break
+         if env_label == label:
+             return os.getenv(f'TS_ENV_{i}_URL', '').strip().rstrip('/')
+         i += 1
+     return ''
+
+
+ def get_ts_env_auth_key(label: str) -> str:
+     """Return the actual auth key value for a given environment label.
+     TS_ENV_N_KEY_VAR holds the name of the env var containing the token.
+     """
+     i = 1
+     while True:
+         env_label = os.getenv(f'TS_ENV_{i}_LABEL', '').strip()
+         if not env_label:
+             break
+         if env_label == label:
+             key_var = os.getenv(f'TS_ENV_{i}_KEY_VAR', '').strip()
+             return os.getenv(key_var, '').strip() if key_var else ''
+         i += 1
+     return ''
+
+
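The `TS_ENV_N_*` lookup added above scans numbered `.env` entries until the first gap. A minimal, self-contained sketch of that pattern — the labels and URLs below are illustrative, not values from the repo:

```python
import os

def list_labels(prefix: str = "TS_ENV") -> list:
    """Collect labels from PREFIX_N_LABEL/PREFIX_N_URL pairs, stopping at the first gap."""
    labels, i = [], 1
    while True:
        label = os.getenv(f"{prefix}_{i}_LABEL", "").strip()
        url = os.getenv(f"{prefix}_{i}_URL", "").strip()
        if not label or not url:
            break  # numbering must be contiguous — a gap ends the scan
        labels.append(label)
        i += 1
    return labels or ["(no environments configured)"]

# Hypothetical .env entries, set inline for demonstration:
os.environ["TS_ENV_1_LABEL"] = "Prod"
os.environ["TS_ENV_1_URL"] = "https://prod.thoughtspot.cloud"
os.environ["TS_ENV_2_LABEL"] = "Staging"
os.environ["TS_ENV_2_URL"] = "https://staging.thoughtspot.cloud"

print(list_labels())  # ['Prod', 'Staging']
```

Note the contiguity requirement this implies: if `TS_ENV_2_*` is missing, `TS_ENV_3_*` is never reached.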
  # ==========================================================================
  # SETTINGS SCHEMA - Single source of truth for all settings
  # To add a new setting: add ONE entry here, then create the UI component

  SETTINGS_SCHEMA = [
      # App Settings - Left Column
      ('default_ai_model', 'default_llm', DEFAULT_LLM_MODEL, str),
+     # default_company_url removed — company is chat-driven now
      ('default_use_case', 'default_use_case', 'Sales Analytics', str),
      ('liveboard_name', 'liveboard_name', '', str),
      ('tag_name', 'tag_name', '', str),

      ('default_warehouse', 'default_warehouse', 'COMPUTE_WH', str),
      ('default_database', 'default_database', 'DEMO_DB', str),
      ('default_schema', 'default_schema', 'PUBLIC', str),
+     # ts_instance_url removed — replaced by TS Environment dropdown on front page
      ('ts_username', 'thoughtspot_username', '', str),
+     ('data_adjuster_url', 'data_adjuster_url', '', str),
+     ('share_with', 'share_with', '', str),
      # Status (special - not saved, just displayed)
      ('settings_status', None, '', str),
  ]
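Each `SETTINGS_SCHEMA` entry is a `(component_name, storage_key, default, type)` tuple. A hedged sketch of how such a schema can hydrate saved settings with defaults — the helper name and the trimmed-down schema are illustrative, not the app's actual loader:

```python
# Trimmed illustrative schema: (component_name, storage_key, default, type)
SCHEMA = [
    ("default_ai_model", "default_llm", "gpt-4o", str),
    ("default_use_case", "default_use_case", "Sales Analytics", str),
    ("tag_name", "tag_name", "", str),
    ("settings_status", None, "", str),  # display-only: no storage key
]

def hydrate(saved: dict) -> dict:
    """Map component names to values, falling back to schema defaults."""
    out = {}
    for component, key, default, typ in SCHEMA:
        raw = saved.get(key, default) if key else default
        out[component] = typ(raw) if raw is not None else default
    return out

print(hydrate({"default_llm": "claude-sonnet"}))
```

A single list like this keeps load and save handlers in lockstep: both iterate the same tuples, so adding a setting really is one entry plus one UI component.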
 

  def build_initial_chat_message(company: str, use_case: str) -> str:
      """Build the pre-filled chat message from current settings."""
+     if company and use_case:
+         return f"{company}, {use_case}"
+     return ""
 

  def safe_print(*args, **kwargs):
 
          self.live_progress_log = []  # Real-time deployment progress
          self.demo_pack_content = ""  # Generated demo pack markdown
          self.spotter_viz_story = ""  # Spotter Viz story (NL prompts for Spotter Viz agent)
+         # Per-session loggers (NOT module-level singletons — avoids cross-session contamination)
+         self._session_logger = None
+         self._prompt_logger = None
 
      def _get_effective_user_email(self) -> str:
          """Resolve and cache effective user identity for settings access."""
 
304
 
305
  return f"""👋 **Welcome to ThoughtSpot Demo Builder!**
306
 
307
+ Tell me who you're demoing to and what they care about — I'll build the whole thing.
308
 
309
+ **Just say something like:**
310
+ - *"Nike.com, Retail Sales"*
311
+ - *"Salesforce.com, Software Sales"*
312
+ - *"Target.com — supply chain analytics for their VP of Operations"*
313
+ - *"Pfizer.com — analyzing clinical trial pipeline for a CMO persona"*
314
 
315
+ **Pre-configured use cases** (KPIs, outliers, and Spotter questions ready to go):
 
 
316
  {uc_list}
317
  - ...and more!
 
 
 
 
 
 
 
 
 
 
 
318
 
319
+ **Or describe any custom use case** — AI will research and build it from scratch.
 
 
 
320
 
321
+ > 💡 **Tip:** Select your ThoughtSpot environment and AI model in the panel to the right before starting.
322
 
323
+ *What company and use case are you working on?*"""
324
  else:
325
  return f"""👋 **Welcome to ThoughtSpot Demo Builder!**
326
 
 
          Process user message and return updated chat history and state (with streaming)
          Returns: (chat_history, current_stage, current_model, company, use_case, next_textbox_value)
          """
+         from session_logger import init_session_logger
+         from prompt_logger import reset_prompt_logger
+
+         # Init per-session loggers on first message (stored on self, not module singletons)
+         if self._session_logger is None:
+             session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
+             self._session_logger = init_session_logger(session_id, user_email=getattr(self, 'user_email', None))
+             self._prompt_logger = reset_prompt_logger(session_id)
+         _slog = self._session_logger
+         _slog.log(current_stage or 'init', f"user message received: {message[:120]}")
+
          # Add user message to history
          chat_history.append((message, None))
+
+         # If data_adjuster_url is saved in settings and we're at init, inject it as the message
+         # so the user lands directly in Data Adjuster without having to paste the URL manually
+         da_url = self.settings.get('data_adjuster_url', '').strip()
+         if da_url and current_stage == 'initialization' and 'pinboard/' in da_url:
+             message = da_url
+             chat_history[-1] = (da_url, None)
+
          # Validate required settings before proceeding
          missing = self.validate_required_settings()
          if missing and current_stage == 'initialization':
 

          # Stage-based processing
          if current_stage == 'initialization':
+             # Check if user pasted a ThoughtSpot liveboard URL → jump straight to data adjuster
+             lb_guid_match = re.search(
+                 r'pinboard/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})',
+                 message, re.I
+             )
+             if lb_guid_match:
+                 liveboard_guid = lb_guid_match.group(1)
+                 chat_history[-1] = (message, "🔍 **Resolving liveboard context...**")
+                 yield chat_history, current_stage, current_model, company, use_case, ""
+
+                 try:
+                     from smart_data_adjuster import SmartDataAdjuster, load_context_from_liveboard
+                     from thoughtspot_deployer import ThoughtSpotDeployer
+                     from supabase_client import get_admin_setting
+
+                     ts_url = (self.settings.get('thoughtspot_url') or '').strip() or get_admin_setting('THOUGHTSPOT_URL')
+                     ts_user = get_admin_setting('THOUGHTSPOT_ADMIN_USER')
+                     ts_secret = (self.settings.get('thoughtspot_trusted_auth_key') or '').strip() or get_admin_setting('THOUGHTSPOT_TRUSTED_AUTH_KEY')
+
+                     ts_client = ThoughtSpotDeployer(base_url=ts_url, username=ts_user, secret_key=ts_secret)
+                     ts_client.authenticate()
+
+                     ctx = load_context_from_liveboard(liveboard_guid, ts_client)
+
+                     llm_model = self.settings.get('model', DEFAULT_LLM_MODEL)
+                     adjuster = SmartDataAdjuster(
+                         database=ctx['database'],
+                         schema=ctx['schema'],
+                         liveboard_guid=liveboard_guid,
+                         llm_model=llm_model,
+                         ts_url=ts_url,
+                         ts_secret=ts_secret,
+                     )
+                     adjuster.connect()
+
+                     if not adjuster.load_liveboard_context():
+                         raise ValueError("Liveboard has no answer-based visualizations to adjust.")
+
+                     self._adjuster = adjuster
+                     current_stage = 'outlier_adjustment'
+
+                     viz_list = "\n".join(
+                         f"  [{i+1}] {v['name']}"
+                         for i, v in enumerate(adjuster.visualizations)
+                     )
+                     response = f"""✅ **Liveboard context loaded — ready for data adjustments**
+
+ **Liveboard:** {ctx['liveboard_name']}
+ **Model:** {ctx['model_name']}
+ **Snowflake:** `{ctx['database']}`.`{ctx['schema']}`
+
+ **Visualizations:**
+ {viz_list}
+
+ Tell me what to change — e.g. *"increase webcam revenue by 20%"* or *"make Acme Corp 50B"*.
+ Type **done** when finished."""
+
+                 except Exception as e:
+                     response = f"❌ **Could not load liveboard context**\n\n`{e}`"
+
+                 chat_history[-1] = (message, response)
+                 yield chat_history, current_stage, current_model, company, use_case, ""
+                 return
+
              # Check if user just provided a standalone URL (e.g., "Comscore.com")
              standalone_url = re.search(r'^([a-zA-Z0-9-]+\.[a-zA-Z]{2,})$', cleaned_message.strip())
              if standalone_url:
 
              yield chat_history, current_stage, current_model, company, use_case, ""

              return
+
+             # --- Catch-all: try to extract company + use case from any free-form message ---
+             # Handles: "Nike.com, Retail Sales" / "Salesforce - Software Sales" / etc.
+             extracted_company = self.extract_company_from_message(cleaned_message)
+             extracted_use_case = self.extract_use_case_from_message(cleaned_message)
+
+             if extracted_company and extracted_use_case:
+                 company = extracted_company
+                 self.vertical, self.function = parse_use_case(extracted_use_case)
+                 self.use_case_config = get_use_case_config(
+                     self.vertical or "Generic", self.function or "Generic"
+                 )
+                 is_known = self.vertical and self.function and not self.use_case_config.get('is_generic')
+                 use_case_display = self.use_case_config.get('use_case_name', extracted_use_case)
+                 self.is_generic_use_case = not is_known
+                 self.pending_generic_company = company
+                 self.pending_generic_use_case = use_case_display
+
+                 if is_known:
+                     note = f"\n\n*Matched: **{self.vertical}** × **{self.function}** — KPIs, outliers, and Spotter questions ready.*"
+                 elif self.vertical or self.function:
+                     note = f"\n\n*Partial match: **{self.vertical or self.function}** recognized — AI will fill in the gaps.*"
+                 else:
+                     note = "\n\n*Custom use case — AI will research and build from scratch.*"
+
+                 context_prompt = f"""✅ **Demo Configuration**
+
+ **Company:** {company}
+ **Use Case:** {use_case_display}{note}
+
+ **Want to add any requirements?** (or just say "proceed")
+ - "Include a RETURNS table"
+ - "Focus on enterprise accounts only"
+ - "I need 2 fact tables: Sales and Inventory"
+
+ *Type your requirements, or say **"proceed"** to use defaults.*"""
+
+                 chat_history[-1] = (message, context_prompt)
+                 current_stage = 'awaiting_context'
+                 yield chat_history, current_stage, current_model, company, use_case_display, "proceed"
+                 return
+
+             elif extracted_company and not extracted_use_case:
+                 uc_opts = "\n".join([f"- {v} {f}" for v in list(VERTICALS.keys())[:3] for f in FUNCTIONS.keys()])
+                 response = f"""Got it — **{extracted_company}**!
+
+ What use case are we building? A few options:
+ {uc_opts}
+ - Or describe any custom scenario!"""
+                 chat_history[-1] = (message, response)
+                 yield chat_history, current_stage, current_model, company, use_case, ""
+                 return
+
+             else:
+                 # Nothing useful extracted — show a brief prompt
+                 response = """I need a **company** and **use case** to get started.
+
+ Try something like:
+ - *"Nike.com, Retail Sales"*
+ - *"Salesforce.com — Software pipeline analytics"*
+ - *"Walmart.com, Supply Chain"*"""
+                 chat_history[-1] = (message, response)
+                 yield chat_history, current_stage, current_model, company, use_case, ""
+                 return
+
728
  # User is providing context for ANY use case (both generic and established)
729
  if message_lower.strip() in ['proceed', 'continue', 'no', 'skip']:
 
801
  for progress_update in self.run_deployment_streaming():
802
  if isinstance(progress_update, tuple):
803
  final_result = progress_update
804
+ elif isinstance(progress_update, dict):
805
+ current_stage = progress_update.get('stage', current_stage)
806
+ chat_history[-1] = (message, progress_update['response'])
807
+ yield chat_history, current_stage, current_model, company, use_case, ""
808
  else:
809
  deploy_update_count += 1
810
  # Only yield every 5th deployment update
 
1002
  final_result = None
1003
  for progress_update in self.run_deployment_streaming():
1004
  if isinstance(progress_update, tuple):
 
1005
  final_result = progress_update
1006
+ elif isinstance(progress_update, dict):
1007
+ current_stage = progress_update.get('stage', current_stage)
1008
+ chat_history[-1] = (message, f"**DDL Approved - Deploying...**\n\n{progress_update['response']}")
1009
+ yield chat_history, current_stage, current_model, company, use_case, ""
1010
  else:
 
1011
  chat_history[-1] = (message, f"**DDL Approved - Deploying...**\n\n{progress_update}")
1012
  yield chat_history, current_stage, current_model, company, use_case, ""
1013
 
 
              final_result = None
              for progress_update in self.run_deployment_streaming():
                  if isinstance(progress_update, tuple):
                      final_result = progress_update
+                 elif isinstance(progress_update, dict):
+                     current_stage = progress_update.get('stage', current_stage)
+                     chat_history[-1] = (message, progress_update['response'])
+                     yield chat_history, current_stage, current_model, company, use_case, ""
                  else:
                      chat_history[-1] = (message, progress_update)
                      yield chat_history, current_stage, current_model, company, use_case, ""
 
 
              for progress_update in self.run_deployment_streaming():
                  if isinstance(progress_update, tuple):
                      final_result = progress_update
+                 elif isinstance(progress_update, dict):
+                     current_stage = progress_update.get('stage', current_stage)
+                     chat_history.append((None, progress_update['response']))
+                     yield chat_history, current_stage, current_model, company, use_case, ""
                  else:
                      chat_history.append((None, progress_update))
                      yield chat_history, current_stage, current_model, company, use_case, ""
 
                      return
                  return

+         elif current_stage == 'outlier_adjustment':
              # Handle outlier adjustment stage
              if 'done' in message_lower or 'finish' in message_lower or 'complete' in message_lower:
                  # Close adjuster connection
 
                  yield chat_history, current_stage, current_model, company, use_case, ""
                  return

+             # Pick metric column (use hint from match if available)
+             metric_hint = match.get('metric_hint')
+             metric_column = adjuster._pick_metric_column(metric_hint)
+             if not metric_column:
+                 response = "❌ Could not identify a metric column in your schema. Try specifying the column name."
+                 chat_history[-1] = (message, response)
+                 yield chat_history, current_stage, current_model, company, use_case, ""
+                 return

+             # Get current value (new 4-tuple return: value, matched_name, dim_table, fact_table)
+             entity_type = match.get('entity_type')
+             current_value, matched_entity, dim_table, fact_table = adjuster.get_current_value(
+                 match['entity_value'], metric_column, entity_type
+             )

+             if current_value == 0 or matched_entity is None:
+                 response = (
+                     f"❌ **No data found** for `{match['entity_value']}`.\n\n"
+                     f"Check the spelling or try a different entity. Type **'done'** to finish."
+                 )
                  chat_history[-1] = (message, response)
                  yield chat_history, current_stage, current_model, company, use_case, ""
                  return
+
+             # Calculate target
              target_value = match.get('target_value')
+             percentage = match.get('percentage') if match.get('is_percentage') else None
+             if percentage is not None:
                  target_value = current_value * (1 + percentage / 100)
                  match['target_value'] = target_value
+
              # Generate strategy
              strategy = adjuster.generate_strategy(
                  match['entity_value'],
+                 metric_column,
                  current_value,
+                 target_value=target_value,
+                 percentage=percentage,
+                 entity_type=entity_type,
              )
+
              # Present smart confirmation
+             confirmation = adjuster.present_smart_confirmation(match, current_value, strategy, metric_column)

              # Store for execution if user confirms
              self._pending_adjustment = {
 
              # Clean up trailing dots just in case
              company = company.rstrip('.')
              return company
+
+         # Fallback: any bare domain.tld in the message (e.g. "Nike.com, Retail Sales")
+         bare_url = re.search(r'\b([a-zA-Z0-9-]+\.[a-zA-Z]{2,})\b', cleaned_message, re.IGNORECASE)
+         if bare_url:
+             return bare_url.group(1).rstrip('.')
+
          return None
 
      def extract_use_case_from_message(self, message):

                  if re.search(r'\.(com|org|net|io|co|ai)\b', use_case, re.IGNORECASE):
                      continue
                  return use_case
+
+         # Fallback: text after a comma or dash following a domain.tld
+         # Handles: "Nike.com, Retail Sales" / "Nike.com - Supply Chain"
+         after_domain = re.search(
+             r'[a-zA-Z0-9-]+\.[a-zA-Z]{2,}[\s]*[,\-–—]\s*(.+)',
+             message, re.IGNORECASE
+         )
+         if after_domain:
+             use_case = after_domain.group(1).strip().rstrip('.')
+             if use_case and not re.search(r'\.(com|org|net|io|co|ai)\b', use_case, re.IGNORECASE):
+                 return use_case
+
          return None
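The two regex fallbacks (bare-domain company extraction, then text-after-separator use case) can be combined into one standalone helper for testing. `split_company_use_case` is a hypothetical name; the patterns mirror the ones in the diff:

```python
import re

def split_company_use_case(message: str):
    """Illustrative only: applies the two fallback regexes from the diff."""
    company = None
    m = re.search(r'\b([a-zA-Z0-9-]+\.[a-zA-Z]{2,})\b', message, re.IGNORECASE)
    if m:
        company = m.group(1).rstrip('.')

    use_case = None
    m = re.search(r'[a-zA-Z0-9-]+\.[a-zA-Z]{2,}[\s]*[,\-–—]\s*(.+)', message, re.IGNORECASE)
    if m:
        candidate = m.group(1).strip().rstrip('.')
        # Reject candidates that are themselves domains (e.g. a second URL)
        if candidate and not re.search(r'\.(com|org|net|io|co|ai)\b', candidate, re.IGNORECASE):
            use_case = candidate
    return company, use_case

print(split_company_use_case("Nike.com, Retail Sales"))   # ('Nike.com', 'Retail Sales')
print(split_company_use_case("Nike.com - Supply Chain"))  # ('Nike.com', 'Supply Chain')
```

Note the separator class `[,\-–—]` accepts a comma, hyphen, en dash, or em dash, which is why both sample phrasings parse.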
 
      def handle_override(self, message):
 

      def run_research_streaming(self, company, use_case, generic_context=""):
          """Run the research phase with streaming updates
+
          Args:
              company: Company URL/name
              use_case: Use case name
              generic_context: Additional context provided by user for generic use cases
          """
+         _slog = self._session_logger
+         _t = _slog.log_start("research") if _slog else None
+
          print(f"\n\n[CACHE DEBUG] === run_research_streaming called ===")
          print(f"[CACHE DEBUG] company: {company}")
          print(f"[CACHE DEBUG] use_case: {use_case}\n\n")
 
          """Generate use-case specific Spotter questions based on the schema.

          Priority order:
+         1. liveboard_questions.spotter_qs from the vertical×function config
+         2. FUNCTIONS[fn].spotter_templates from the config
          3. Hardcoded fallbacks per use case
          4. Generic questions
          """
          import re
+
+         # Priority 1: Use spotter_qs from liveboard_questions in use case config
          try:
+             uc_config = self.use_case_config or get_use_case_config(
+                 self.vertical or "Generic", self.function or "Generic"
+             )
+             lq = uc_config.get("liveboard_questions", [])
              configured_questions = []
+             for q in lq:
+                 for sq in q.get('spotter_qs', []):
                      configured_questions.append({
                          'question': sq,
+                         'purpose': f'Reveals {q["title"]} pattern'
                      })
              if configured_questions:
+                 v = self.vertical or "Generic"
+                 f = self.function or "Generic"
+                 self.log_feedback(f"📋 Using {len(configured_questions)} Spotter questions from {v}×{f} config")
+                 return configured_questions[:8]
          except Exception as e:
+             self.log_feedback(f"⚠️ Spotter questions not available: {e}")

          # Priority 2: Try FUNCTIONS config for spotter_templates
          try:
 
          - **Ask questions**: Let the AI demonstrate natural language
          - **End with action**: Show how insights lead to decisions""")
 
+     def _generate_spotter_viz_story(self, company_name: str, use_case: str,
                                      model_name: str = None, liveboard_name: str = None) -> str:
          """Generate a Spotter Viz story — a conversational sequence of NL prompts
          that can be entered into ThoughtSpot's Spotter Viz agent to build a liveboard.
+
+         Returns two sections:
+         1. Persona-driven story built from liveboard_questions config (always present)
+         2. AI-generated story from LLM (appended after a divider)
          """
+         from demo_personas import parse_use_case, get_use_case_config
+
+         v, f = parse_use_case(use_case or '')
+         uc_cfg = get_use_case_config(v or "Generic", f or "Generic")
+         lq = uc_cfg.get("liveboard_questions", [])
+         data_source = model_name or f"{company_name} model"
+
+         # --- Section 1: Persona-driven story from liveboard_questions ---
+         header = f"""# Spotter Viz Story: {company_name}
+ ## {use_case}
+
+ *Copy these prompts into ThoughtSpot Spotter Viz to build this liveboard interactively.*
+
+ ---
+
+ ## Part 1: Structured Demo Flow
+
+ """
+         steps = []
+         step_num = 1
+         for q in lq:
+             title = q['title']
+             viz_q = q['viz_question']
+             insight = q.get('insight', '')
+             spotter_qs = q.get('spotter_qs', [])
+
+             step = f"### Step {step_num}: {title}\n"
+             step += f'> "{viz_q}"\n\n'
+             if insight:
+                 step += f"**What to look for:** {insight}\n\n"
+             if spotter_qs:
+                 step += "**Follow-up Spotter questions:**\n"
+                 for sq in spotter_qs[:2]:
+                     step += f'> "{sq}"\n\n'
+             steps.append(step)
+             step_num += 1
+
+         if steps:
+             persona_section = header + "\n".join(steps)
+         else:
+             persona_section = header + f'> "Build a {use_case} dashboard for {company_name} using {data_source}"\n\n'
+
+         persona_section += "\n---\n\n"
+
+         # --- Section 2: AI-generated story ---
          try:
              from prompts import build_prompt
+
              vertical = v or "Generic"
              function = f or "Generic"
+
              company_context = f"Company: {company_name}\nUse Case: {use_case}"
              if model_name:
                  company_context += f"\nData Source/Model: {model_name}"
              if liveboard_name:
                  company_context += f"\nLiveboard Name: {liveboard_name}"

              if hasattr(self, 'demo_builder') and self.demo_builder:
                  research = getattr(self.demo_builder, 'company_summary', '') or ''
                  if research:
                      company_context += f"\n\nCompany Research:\n{research[:1500]}"
+
              prompt = build_prompt(
                  stage="spotter_viz_story",
                  vertical=vertical,
                  function=function,
                  company_context=company_context,
              )

+             llm_model = self.settings.get('model', DEFAULT_LLM_MODEL)
+             self.log_feedback(f"🎬 Generating AI Spotter Viz story ({llm_model})...")

+             provider_name, model_name_str = map_llm_display_to_provider(llm_model)
+             researcher = MultiLLMResearcher(provider=provider_name, model=model_name_str)
+             ai_story = researcher.make_request(prompt, max_tokens=2000, temperature=0.7)
+             ai_section = "## Part 2: AI-Generated Story\n\n" + ai_story

          except Exception as e:
+             self.log_feedback(f"⚠️ AI Spotter Viz story generation failed: {e}")
+             ai_section = "## Part 2: AI-Generated Story\n\n*(Generation failed; use the structured flow above.)*"
+
+         return persona_section + ai_section
 
      def _build_fallback_spotter_story(self, company_name: str, use_case: str,
                                        model_name: str = None) -> str:
          """Build a basic Spotter Viz story without LLM, using available context."""
          data_source = model_name or f"{company_name} model"

+         # Get spotter questions from use case config
          spotter_qs = []
          try:
+             from demo_personas import parse_use_case, get_use_case_config
              v, f = parse_use_case(use_case or '')
+             uc_cfg = get_use_case_config(v or "Generic", f or "Generic")
+             for q in uc_cfg.get("liveboard_questions", []):
+                 if q.get("required") and q.get("spotter_qs"):
+                     spotter_qs.append(q["spotter_qs"][0])
          except:
              pass
 
 
2881
  # Generate domain-specific realistic data based on column name, then truncate to fit
2882
  base_value = None
2883
  if 'NAME' in col_name_upper and 'COMPANY' not in col_name_upper:
2884
+ # Check domain-specific name columns BEFORE falling back to fake.name()
2885
+ if 'DRUG' in col_name_upper or 'MEDICATION' in col_name_upper or 'THERAPEUTIC' in col_name_upper:
2886
+ base_value = "random.choice(['Lipitor', 'Humira', 'Eliquis', 'Keytruda', 'Revlimid', 'Opdivo', 'Ozempic', 'Dupixent', 'Trulicity', 'Entresto', 'Metformin', 'Atorvastatin', 'Lisinopril', 'Amlodipine', 'Metoprolol', 'Omeprazole', 'Simvastatin', 'Losartan', 'Albuterol', 'Gabapentin'])"
2887
+ elif 'PRODUCT' in col_name_upper:
2888
  base_value = "random.choice(['Laptop Pro 15', 'Wireless Mouse 2.4GHz', 'USB-C Cable 6ft', 'Monitor Stand Adjustable', 'Mechanical Keyboard RGB', 'Noise Canceling Headphones', '1080p Webcam', 'Portable SSD 1TB', 'Power Bank 20000mAh', 'Tablet 10 inch', 'Smart Watch', 'Bluetooth Speaker', 'Gaming Mouse Pad', 'Phone Case', 'Screen Protector', 'Charging Cable', 'Desk Lamp LED', 'Laptop Bag', 'Wireless Earbuds', 'USB Hub'])"
2889
  elif 'CUSTOMER' in col_name_upper or 'USER' in col_name_upper:
2890
  base_value = "fake.name()"
 
3239
 
3240
  def run_deployment_streaming(self):
3241
  """Run deployment to Snowflake using LegitData - yields progress updates"""
3242
+ _slog = self._session_logger
3243
+ _t_deploy = _slog.log_start("deploy") if _slog else None
3244
+
3245
  progress = ""
3246
+
3247
  # Clear and initialize live progress for Snowflake deployment
3248
  self.live_progress_log = ["=" * 60, "SNOWFLAKE DEPLOYMENT STARTING", "=" * 60, ""]
3249
+
3250
  def log_progress(msg):
3251
  """Log to both AI feedback and live progress"""
3252
  self.log_feedback(msg)
3253
  self.live_progress_log.append(msg)
3254
+
3255
  try:
3256
  # Ensure deploy-time modules that still use os.getenv() see Supabase admin settings.
3257
  inject_admin_settings_to_env()
 
3655
 
3656
  yield f"**Starting ThoughtSpot Deployment...**\n\nSchema verified: {database}.{schema_name}\nFound {len(tables)} tables\n\n"
3657
 
3658
+ # Create deployer prefer session-selected env (from TS env dropdown),
3659
+ # fall back to admin settings
3660
+ ts_url = (self.settings.get('thoughtspot_url') or '').strip() or get_admin_setting('THOUGHTSPOT_URL')
3661
  ts_user = get_admin_setting('THOUGHTSPOT_ADMIN_USER')
3662
+ ts_secret = (self.settings.get('thoughtspot_trusted_auth_key') or '').strip() or get_admin_setting('THOUGHTSPOT_TRUSTED_AUTH_KEY')
3663
 
3664
  deployer = ThoughtSpotDeployer(
3665
  base_url=ts_url,
 
3681
  safe_print(msg, flush=True)
3682
 
3683
  # Show initial message
3684
+ yield {
3685
+ 'stage': 'thoughtspot',
3686
+ 'response': """**Starting ThoughtSpot Deployment...**
3687
 
3688
  Authenticating with ThoughtSpot...
3689
 
3690
+ **This takes 2-5 minutes.**
3691
 
3692
  **Switch to the "Live Progress" tab** to watch real-time progress.
3693
 
 
3698
  4. Liveboard creation
3699
 
3700
  This chat will update when complete."""
3701
+ }
3702
 
3703
  safe_print("\n" + "="*60, flush=True)
3704
  safe_print("THOUGHTSPOT DEPLOYMENT STARTING", flush=True)
 
3742
  llm_model=llm_model,
3743
  tag_name=tag_name_value,
3744
  liveboard_method=liveboard_method,
3745
+ share_with=self.settings.get('share_with', '').strip() or None,
3746
  progress_callback=progress_callback
3747
  )
3748
  except Exception as e:
 
3911
  try:
3912
  from smart_data_adjuster import SmartDataAdjuster
3913
 
3914
+ # Pass selected LLM model + session-selected TS env to adjuster
3915
  llm_model = self.settings.get('model', DEFAULT_LLM_MODEL)
3916
+ adjuster = SmartDataAdjuster(
3917
+ database, schema_name, liveboard_guid,
3918
+ llm_model=llm_model,
3919
+ ts_url=self.settings.get('thoughtspot_url') or None,
3920
+ ts_secret=self.settings.get('thoughtspot_trusted_auth_key') or None,
3921
+ )
3922
  adjuster.connect()
3923
 
3924
  if adjuster.load_liveboard_context():
3925
  self._adjuster = adjuster
3926
 
3927
  viz_list = "\n".join([
3928
+ f" [{i+1}] {v['name']}"
 
3929
  for i, v in enumerate(adjuster.visualizations)
3930
  ])
3931
 
 
  - Reference by viz number: "viz 3, increase laptop to 50B"

  **Try an adjustment now, or type 'done' to finish!**"""
+                     final_stage = 'outlier_adjustment'
                  else:
                      ts_url = get_admin_setting('THOUGHTSPOT_URL', required=False).rstrip('/')
                      model_guid = results.get('model_guid') or ''
 
  Note: Could not load liveboard context for adjustments: {str(e)}
  Type **'done'** to finish."""
                  else:
+                     # Deployed OK but no liveboard GUID returned; treat as partial success
                      ts_url = get_admin_setting('THOUGHTSPOT_URL', required=False).rstrip('/')
                      model_guid = results.get('model_guid') or ''
                      table_names = results.get('tables', [])
                      tables_list = ', '.join(table_names) if table_names else 'N/A'
+                     model_url = f"{ts_url}/#/data/tables/{model_guid}" if model_guid and ts_url else results.get('model', 'N/A')
+
+                     response = f"""⚠️ **Partial Success — Dataset & Model Created**
+
+ Your Snowflake data and ThoughtSpot model were deployed successfully.
+ The liveboard GUID couldn't be retrieved — it may still have been created in ThoughtSpot.

  **Created:**
  - Connection: {results.get('connection', 'N/A')}
  - Tables: {tables_list}
+ - Model: {model_url}

+ **What you can do:**
+ - Check the **Spotter Viz Story** tab to recreate the liveboard manually
+ - Log into ThoughtSpot to verify whether the liveboard was created
+ - Type **'retry liveboard'** to try building it again

  Type **'done'** to finish."""
4046
 
 
              else:
                  errors = results.get('errors', ['Unknown error'])
                  error_details = '\n'.join(errors)
+
+                 # Check for partial success: Snowflake + model deployed OK but liveboard failed
+                 model_ok = bool(results.get('model_guid'))
+                 liveboard_errors = [e for e in errors if 'liveboard' in e.lower()]
+                 non_liveboard_errors = [e for e in errors if 'liveboard' not in e.lower()]
+
+                 if model_ok and liveboard_errors and not non_liveboard_errors:
+                     # Partial success — dataset and model are live, only liveboard failed
+                     ts_url = get_admin_setting('THOUGHTSPOT_URL', required=False).rstrip('/')
+                     model_guid = results.get('model_guid') or ''
+                     table_names = results.get('tables', [])
+                     tables_list = ', '.join(table_names) if table_names else 'N/A'
+                     model_url = f"{ts_url}/#/data/tables/{model_guid}" if model_guid and ts_url else results.get('model', 'N/A')
+                     lb_error = liveboard_errors[0]
+
+                     yield {
+                         'response': f"""⚠️ **Partial Success — Dataset & Model Created**
+
+ Your Snowflake data and ThoughtSpot model were deployed successfully.
+ The liveboard couldn't be built automatically.
+
+ **Created:**
+ - Connection: {results.get('connection', 'N/A')}
+ - Tables: {tables_list}
+ - Model: {model_url}
+
+ **Liveboard error:**
+ ```
+ {lb_error}
+ ```
+
+ **What you can do:**
+ - Check the **Spotter Viz Story** tab — use that sequence to recreate the liveboard manually in Spotter Viz
+ - Type **'retry liveboard'** to try building it again
+ - Or continue in ThoughtSpot using your model directly""",
+                         'stage': 'deploy'
+                     }
                  else:
+                     if 'schema validation' in error_details.lower() or 'schema' in error_details.lower():
+                         guidance = "**Root Cause:** The model TML has validation errors."
+                     elif 'connection' in error_details.lower():
+                         guidance = "**Root Cause:** Connection issue with ThoughtSpot or Snowflake."
+                     elif 'authenticate' in error_details.lower() or 'auth' in error_details.lower():
+                         guidance = "**Root Cause:** Authentication failed."
+                     else:
+                         guidance = "**Check the progress log above for details.**"
+
+                     yield {
+                         'response': f"""❌ **ThoughtSpot Deployment Failed**

  **Error Details:**
  ```

  **Next Steps:**
  - Type **'retry'** to try again
  - Or fix the issues above first""",
+                         'stage': 'deploy'
+                     }
 
          except Exception as e:
              import traceback
 
      current_model = gr.State(default_settings['model'])
      current_company = gr.State(default_settings['company'])
      current_usecase = gr.State(default_settings['use_case'])
+     current_liveboard_name = gr.State(default_settings.get('liveboard_name', ''))

      # Header
      gr.Markdown("""
 
              gr.Markdown("### System-Wide Settings")
              gr.Markdown("These settings apply to all users. Only admins can view and edit.")

+             # Hidden fields — values still saved/loaded but not shown in UI
+             admin_ts_url = gr.Textbox(visible=False)
+             admin_ts_auth_key = gr.Textbox(visible=False)
+             admin_openai_key = gr.Textbox(visible=False)
+             admin_google_key = gr.Textbox(visible=False)
+
              with gr.Row():
                  with gr.Column():
                      gr.Markdown("#### ThoughtSpot Connection")
                      admin_ts_user = gr.Textbox(label="ThoughtSpot Admin Username", placeholder="admin@company.com")
+                     admin_share_with = gr.Textbox(
+                         label="Default Share With (User or Group)",
+                         placeholder="user@company.com or group-name",
+                         info="System-wide default: model + liveboard shared here after every build"
+                     )
+
                  with gr.Column():
                      gr.Markdown("#### Snowflake Connection")
                      admin_sf_account = gr.Textbox(label="Snowflake Account")
 
                  admin_sf_account, admin_sf_kp_user, admin_sf_kp_pk,
                  admin_sf_kp_pass, admin_sf_role, admin_sf_warehouse,
                  admin_sf_database, admin_sf_sso_user,
+                 admin_share_with,
              ]
+
              admin_keys_order = [
                  "THOUGHTSPOT_URL", "THOUGHTSPOT_TRUSTED_AUTH_KEY", "THOUGHTSPOT_ADMIN_USER",
                  "OPENAI_API_KEY", "GOOGLE_API_KEY",
                  "SNOWFLAKE_ACCOUNT", "SNOWFLAKE_KP_USER", "SNOWFLAKE_KP_PK",
                  "SNOWFLAKE_KP_PASSPHRASE", "SNOWFLAKE_ROLE", "SNOWFLAKE_WAREHOUSE",
                  "SNOWFLAKE_DATABASE", "SNOWFLAKE_SSO_USER",
+                 "SHARE_WITH",
              ]

              def load_admin_settings_handler():
 
                  outputs=admin_fields + [admin_settings_status]
              )

+             # --- Session Log Viewer ---
+             gr.Markdown("---")
+             gr.Markdown("### 📋 Session Logs")
+
+             with gr.Row():
+                 log_user_filter = gr.Textbox(label="Filter by user (email, blank=all)", scale=2)
+                 log_limit = gr.Dropdown(label="Show", choices=["25", "50", "100"], value="50", scale=1)
+                 log_refresh_btn = gr.Button("🔄 Refresh", scale=1)
+
+             session_log_display = gr.Dataframe(
+                 headers=["Time", "User", "Stage", "Event", "Duration (ms)", "Error"],
+                 label="Recent Sessions",
+                 interactive=False,
+                 wrap=True,
+             )
+
+             def load_session_logs(user_filter, limit):
+                 """Load session logs from Supabase session_logs table."""
+                 try:
+                     from supabase_client import SupabaseSettings
+                     ss = SupabaseSettings()
+                     if not ss.is_enabled():
+                         return [["Supabase not configured", "", "", "", "", ""]]
+
+                     query = ss.client.table("session_logs").select(
+                         "ts,user_email,stage,event,duration_ms,error"
+                     ).order("ts", desc=True).limit(int(limit))
+
+                     if user_filter and user_filter.strip():
+                         query = query.ilike("user_email", f"%{user_filter.strip()}%")
+
+                     result = query.execute()
+                     rows = []
+                     for r in result.data:
+                         ts = r.get("ts", "")[:19].replace("T", " ")  # trim to seconds
+                         rows.append([
+                             ts,
+                             r.get("user_email", ""),
+                             r.get("stage", ""),
+                             r.get("event", ""),
+                             str(r.get("duration_ms", "") or ""),
+                             r.get("error", "") or "",
+                         ])
+                     return rows if rows else [["No logs found", "", "", "", "", ""]]
+                 except Exception as e:
+                     return [[f"Error: {e}", "", "", "", "", ""]]
+
+             log_refresh_btn.click(
+                 fn=load_session_logs,
+                 inputs=[log_user_filter, log_limit],
+                 outputs=[session_log_display]
+             )
+
              # Check admin status and toggle admin-only settings visibility
              def check_admin_visibility(request: gr.Request):
                  """Check if logged-in user is admin and toggle settings visibility."""
 
4764
  company = (str(settings.get("default_company_url", "")).strip() or "Amazon.com")
  use_case = (str(settings.get("default_use_case", "")).strip() or "Sales Analytics")
  model = (str(settings.get("default_llm", "")).strip() or DEFAULT_LLM_MODEL)
+ liveboard_name = str(settings.get("liveboard_name", "")).strip()
  initial_message = build_initial_chat_message(company, use_case)

  return (

  model,
  company,
  use_case,
+ liveboard_name,
  gr.update(value=model),
+ gr.update(value=liveboard_name),
  initial_message,
  )
+
  # Wire up load handler - outputs follow SETTINGS_SCHEMA order
  interface.load(
  fn=load_settings_on_startup,

  current_model,
  current_company,
  current_usecase,
+ current_liveboard_name,
  chat_components["model_dropdown"],
+ chat_components["liveboard_name_input"],
  chat_components["msg"],
  ]
  )

  )

  with gr.Row():
  msg = gr.Textbox(
  label="Your message",
+ value="",
  placeholder="Type your message here or use /over to change settings...",
  lines=1,
  max_lines=1,

  # Right column - Status & Settings
  with gr.Column(scale=1):
+ # TS Environment selector
+ ts_env_choices = get_ts_environments()
+ ts_env_dropdown = gr.Dropdown(
+ label="TS Environment",
+ choices=ts_env_choices,
+ value=ts_env_choices[0] if ts_env_choices else None,
+ interactive=True,
  )
+
+ # AI Model selector
  model_dropdown = gr.Dropdown(
  label="AI Model",
  choices=list(UI_MODEL_CHOICES),

  interactive=True,
  allow_custom_value=True
  )
+
+ # Liveboard name (quick access — same setting as in Settings tab)
+ liveboard_name_input = gr.Textbox(
+ label="Liveboard Name",
+ placeholder="Auto-generated if blank",
+ value=settings.get('liveboard_name', ''),
+ lines=1,
+ interactive=True,
+ )
+
  gr.Markdown("### 📈 Progress")
+
+ # Stage order used to determine done/current/upcoming
+ _STAGE_ORDER = [
+ 'initialization', 'awaiting_context', 'research',
+ 'create_ddl', 'deploy', 'populate',
+ 'thoughtspot', 'outlier_adjustment', 'complete',
+ ]
+ # Each display step: (label, [stage keys that map to it])
+ _PROGRESS_STEPS = [
+ ('Init', ['initialization']),
+ ('Research', ['awaiting_context', 'research']),
+ ('DDL', ['create_ddl']),
+ ('Data', ['deploy', 'populate']),
+ ('ThoughtSpot', ['thoughtspot']),
+ ('Data Adjuster', ['outlier_adjustment']),
+ ('Complete', ['complete']),
+ ]
+
  def get_progress_html(stage):
+ """Generate progress HTML showing done/current/upcoming states."""
+ try:
+ current_idx = _STAGE_ORDER.index(stage)
+ except ValueError:
+ current_idx = 0
+
+ # Find which display step is current
+ current_step = None
+ for step_label, step_keys in _PROGRESS_STEPS:
+ for k in step_keys:
+ if stage == k:
+ current_step = step_label
+ break
+
+ # Build ordered list of display steps that are reached
+ reached = set()
+ for step_label, step_keys in _PROGRESS_STEPS:
+ for k in step_keys:
+ try:
+ if _STAGE_ORDER.index(k) <= current_idx:
+ reached.add(step_label)
+ except ValueError:
+ pass
+
+ html = "<div style='padding:8px 4px; font-size:13px; line-height:1.8;'>"
+ for step_label, _ in _PROGRESS_STEPS:
+ # Skip Data Adjuster unless it's active
+ if step_label == 'Data Adjuster' and step_label not in reached:
+ continue
+ if step_label == current_step:
+ html += (f"<div style='margin:3px 0; color:#3b82f6; font-weight:bold;'>"
+ f"▶ {step_label}</div>")
+ elif step_label in reached:
+ html += (f"<div style='margin:3px 0; color:#22c55e;'>"
+ f"✓ {step_label}</div>")
  else:
+ html += (f"<div style='margin:3px 0; color:#9ca3af;'>"
+ f" {step_label}</div>")

  html += "</div>"
  return html
+
  progress_html = gr.HTML(get_progress_html('initialization'))

  # Event handlers - each creates/uses session-specific controller
+ def send_message(controller, message, history, stage, model, company, usecase, env_label=None, liveboard_name_ui=None, request: gr.Request = None):
  """Handle sending a message - creates controller if needed"""
+ import traceback
  username = getattr(request, 'username', None) if request else None
  if controller is None:
  controller = ChatDemoInterface(user_email=username)
  print(f"[SESSION] Created new ChatDemoInterface for {username or 'anonymous'}")
+ # Apply selected TS environment settings to the new controller
+ if env_label:
+ _url = get_ts_env_url(env_label)
+ _key_value = get_ts_env_auth_key(env_label)
+ if _url:
+ controller.settings['thoughtspot_url'] = _url
+ if _key_value:
+ controller.settings['thoughtspot_trusted_auth_key'] = _key_value
+ # Always use the current UI value — takes priority over DB-loaded default
+ if liveboard_name_ui is not None:
+ controller.settings['liveboard_name'] = liveboard_name_ui
+ try:
+ for result in controller.process_chat_message(
+ message, history, stage, model, company, usecase
+ ):
+ new_stage = result[1] if len(result) > 1 else stage
+ progress = get_progress_html(new_stage)
+ yield (controller,) + result + (progress,)
+ except Exception as e:
+ err_tb = traceback.format_exc()
+ print(f"[ERROR] send_message unhandled exception:\n{err_tb}")
+ err_msg = (
+ f"❌ **An unexpected error occurred**\n\n"
+ f"`{type(e).__name__}: {e}`\n\n"
+ f"The pipeline has been interrupted. You can try again or start a new session."
+ )
+ history = history or []
+ history.append((message, err_msg))
+ yield (controller, history, stage, model, company, usecase, "", get_progress_html(stage))
+
+ def quick_action(controller, action_text, history, stage, model, company, usecase, env_label=None, liveboard_name_ui=None, request: gr.Request = None):
  """Handle quick action button clicks"""
  username = getattr(request, 'username', None) if request else None
  if controller is None:
  controller = ChatDemoInterface(user_email=username)
  print(f"[SESSION] Created new ChatDemoInterface for {username or 'anonymous'}")
+ # Apply selected TS environment settings to the new controller
+ if env_label:
+ _url = get_ts_env_url(env_label)
+ _key_value = get_ts_env_auth_key(env_label)
+ if _url:
+ controller.settings['thoughtspot_url'] = _url
+ if _key_value:
+ controller.settings['thoughtspot_trusted_auth_key'] = _key_value
+ # Always use the current UI value — takes priority over DB-loaded default
+ if liveboard_name_ui is not None:
+ controller.settings['liveboard_name'] = liveboard_name_ui
+
  for result in controller.process_chat_message(
  action_text, history, stage, model, company, usecase
  ):

  yield (controller,) + result + (progress,)

  # Wire up send button and enter key
+ _send_inputs = [chat_controller_state, msg, chatbot, current_stage, current_model, current_company, current_usecase, ts_env_dropdown, liveboard_name_input]
+ _send_outputs = [chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, msg, progress_html]
+
+ msg.submit(fn=send_message, inputs=_send_inputs, outputs=_send_outputs)
+ send_btn.click(fn=send_message, inputs=_send_inputs, outputs=_send_outputs)

  # Quick action wrapper functions
+ def start_action(controller, history, stage, model, company, usecase, env_label, liveboard_name_ui):
+ yield from quick_action(controller, "Start research", history, stage, model, company, usecase, env_label, liveboard_name_ui)
+
+ def configure_action(controller, history, stage, model, company, usecase, env_label, liveboard_name_ui):
+ yield from quick_action(controller, "Configure settings", history, stage, model, company, usecase, env_label, liveboard_name_ui)
+
+ def help_action(controller, history, stage, model, company, usecase, env_label, liveboard_name_ui):
+ yield from quick_action(controller, "Help", history, stage, model, company, usecase, env_label, liveboard_name_ui)
+
+ _action_inputs = [chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, ts_env_dropdown, liveboard_name_input]
+ _action_outputs = [chat_controller_state, chatbot, current_stage, current_model, current_company, current_usecase, msg, progress_html]
+
  # Quick action buttons
+ start_btn.click(fn=start_action, inputs=_action_inputs, outputs=_action_outputs)
+ configure_btn.click(fn=configure_action, inputs=_action_inputs, outputs=_action_outputs)
+ help_btn.click(fn=help_action, inputs=_action_inputs, outputs=_action_outputs)

  # Model dropdown change
  def update_model(new_model, controller, history):
  if controller is not None:
  controller.settings['model'] = new_model
  return new_model, history
+
  model_dropdown.change(
  fn=update_model,
  inputs=[model_dropdown, chat_controller_state, chatbot],
  outputs=[current_model, chatbot]
  )
+
+ # Liveboard name change — update controller settings in real-time
+ def update_liveboard_name(name, controller):
+ if controller is not None:
+ controller.settings['liveboard_name'] = name
+ return name
+
+ liveboard_name_input.change(
+ fn=update_liveboard_name,
+ inputs=[liveboard_name_input, chat_controller_state],
+ outputs=[liveboard_name_input]
+ )
+
+ # TS environment change — update controller settings in real-time
+ def update_ts_env(label, controller):
+ url = get_ts_env_url(label)
+ auth_key_value = get_ts_env_auth_key(label)
+ if controller is not None:
+ if url:
+ controller.settings['thoughtspot_url'] = url
+ if auth_key_value:
+ controller.settings['thoughtspot_trusted_auth_key'] = auth_key_value
+ return label
+
+ ts_env_dropdown.change(
+ fn=update_ts_env,
+ inputs=[ts_env_dropdown, chat_controller_state],
+ outputs=[]
+ )

  # Return components for external access
  return {

  'msg': msg,
  'model_dropdown': model_dropdown,
  'send_btn': send_btn,
+ 'send_btn_ref': send_btn,
+ 'ts_env_dropdown': ts_env_dropdown,
+ 'liveboard_name_input': liveboard_name_input,
  'progress_html': progress_html
  }

  gr.Markdown("## ⚙️ Configuration Settings")
  gr.Markdown("Configure your demo builder preferences")

+ # Default Settings: the three fields that appear on the chat page too
+ gr.Markdown("### ⭐ Default Settings")
+ gr.Markdown("*These also appear on the chat page — set them once here as defaults.*")
+
+ # default_company_url removed — company is set via chat conversation
+ default_company_url = gr.Textbox(visible=False)
+
+ # Build use case choices from VERTICALS × FUNCTIONS matrix
+ use_case_choices = []
+ for v_name in VERTICALS:
+ for f_name in FUNCTIONS:
+ use_case_choices.append(f"{v_name} {f_name}")
+ use_case_choices.append("Custom (type in chat)")
+
  with gr.Row():
+ default_ai_model = gr.Dropdown(
+ label="Default AI Model",
+ choices=list(UI_MODEL_CHOICES),
+ value=DEFAULT_LLM_MODEL,
+ info="Primary model for demo generation",
+ allow_custom_value=True
+ )
+ default_use_case = gr.Dropdown(
+ label="Default Use Case",
+ choices=use_case_choices,
+ value="Retail Sales",
+ info="Vertical × Function combination, or type any custom use case"
+ )
+ liveboard_name = gr.Textbox(
+ label="Default Liveboard Name",
+ placeholder="My Demo Liveboard",
+ value="",
+ info="Default name for generated liveboards (overridable on chat page)"
+ )

+ gr.Markdown("---")
+ gr.Markdown("### 🔧 App Settings")
+
+ with gr.Row():
+ with gr.Column():
  tag_name = gr.Textbox(
  label="Tag Name",
  placeholder="e.g., 'Sales_Demo' or 'Q4_2024'",
  value="",
  info="Tag to apply to all ThoughtSpot objects (connection, tables, model, liveboard)"
  )
+
  fact_table_size = gr.Dropdown(
  label="Fact Table Size",
  choices=["1000", "10000", "100000"],
  value="1000",
  info="Number of rows in fact table"
  )
+
  dim_table_size = gr.Dropdown(
  label="Dim Table Size",
  choices=["50", "100", "1000"],
  value="100",
  info="Number of rows in dimension tables"
  )
+
+ with gr.Column():
  object_naming_prefix = gr.Textbox(
  label="Object Naming Prefix",
  placeholder="e.g., 'ACME_' or 'DEMO_'",
  value="",
  info="Prefix for ThoughtSpot objects (for future use)"
  )
+
  column_naming_style = gr.Dropdown(
  label="Column Naming Style",
  choices=["Regular Case", "snake_case", "camelCase", "PascalCase", "UPPER_CASE", "original"],
  value="Regular Case",
  info="Naming convention for ThoughtSpot model columns (Regular Case = State Id, Total Revenue)"
  )

+ # Hidden — use_existing_model kept in schema for backward compat but not shown
+ use_existing_model = gr.Checkbox(visible=False, value=False)
+ existing_model_guid = gr.Textbox(visible=False, value="")
+

  gr.Markdown("---")
  gr.Markdown("### 🌍 Other Settings")

  admin_db_accordion = gr.Accordion("💾 Database Connections (Admin)", open=False, visible=True)
  with admin_db_accordion:
  gr.Markdown("""
+ **⚠️ Note:** These fields are **legacy placeholders** and are **not used by the deploy runtime**.
+ Active deployment credentials come from **Admin Settings** (system-wide `__admin__` values in Supabase).
  This section is for reference/future use only.
  """)

  with gr.Column():
  gr.Markdown("### 📊 ThoughtSpot Settings")

+ # ts_instance_url removed — replaced by TS Environment dropdown on front page
+ ts_instance_url = gr.Textbox(visible=False)

  ts_username = gr.Textbox(
  label="ThoughtSpot Username",

  )

  gr.Markdown("---")
+ gr.Markdown("### 🔧 Data Adjuster")
+ gr.Markdown("*Jump straight to data adjustment on an existing liveboard — skips the build pipeline entirely.*")
+
+ with gr.Row():
+ with gr.Column():
+ data_adjuster_url = gr.Textbox(
+ label="Liveboard URL",
+ placeholder="https://your-instance.thoughtspot.cloud/#/pinboard/guid",
+ value="",
+ info="Paste a ThoughtSpot liveboard URL to open it directly in Data Adjuster"
+ )
+
+ gr.Markdown("---")
+ gr.Markdown("### 🔗 Sharing")
+ gr.Markdown("*After building, model and liveboard are shared with this user or group (can edit).*")
+ with gr.Row():
+ with gr.Column():
+ share_with = gr.Textbox(
+ label="Share With (User or Group)",
+ placeholder="your.email@company.com or group-name",
+ value="",
+ info="User email (contains @) or group name. Leave blank to skip sharing."
+ )
+
+ gr.Markdown("---")
  with gr.Row():
  save_settings_btn = gr.Button("💾 Save Settings", variant="primary", size="lg")
  reset_settings_btn = gr.Button("🔄 Reset to Defaults", size="lg")

  'default_schema': default_schema,
  'ts_instance_url': ts_instance_url,
  'ts_username': ts_username,
+ 'data_adjuster_url': data_adjuster_url,
+ 'share_with': share_with,
  # Status
  'settings_status': settings_status,
  # Admin-only visibility toggles
conversational_data_adjuster.py DELETED
@@ -1,448 +0,0 @@
- """
- Conversational Data Adjuster
-
- Allows natural language data adjustments with strategy selection:
- User: "Make 1080p webcam sales 50B"
- System: Analyzes data, presents options
- User: Picks strategy
- System: Executes SQL
- """
-
- from typing import Dict, List, Optional
- from snowflake_auth import get_snowflake_connection
- import json
- from llm_config import DEFAULT_LLM_MODEL, build_openai_chat_token_kwargs
- from llm_client_factory import create_openai_client
-
-
- class ConversationalDataAdjuster:
- """Interactive data adjustment with user choice of strategy"""
-
- def __init__(self, database: str, schema: str, model_id: str):
- self.database = database
- self.schema = schema
- self.model_id = model_id
- self.conn = None
- self.openai_client = create_openai_client()
- self.current_context = {}
-
- def connect(self):
- """Connect to Snowflake"""
- self.conn = get_snowflake_connection()
- cursor = self.conn.cursor()
- cursor.execute(f"USE DATABASE {self.database}")
- cursor.execute(f'USE SCHEMA "{self.schema}"')  # Quote schema name (may start with number)
- print(f"✅ Connected to {self.database}.{self.schema}")
-
- def parse_adjustment_request(self, request: str, available_tables: List[str]) -> Dict:
- """
- Parse natural language request to identify what to adjust
-
- Args:
- request: e.g., "increase 1080p webcam sales to 50B"
- available_tables: List of table names in schema
-
- Returns:
- {
- 'table': 'SALES_TRANSACTIONS',
- 'entity_column': 'product_name',
- 'entity_value': '1080p webcam',
- 'metric_column': 'total_revenue',
- 'target_value': 50000000000,
- 'current_value': 30000000000  # if known
- }
- """
- prompt = f"""Parse this data adjustment request.
-
- Request: "{request}"
-
- Available tables: {', '.join(available_tables)}
-
- Common columns:
- - SALES_TRANSACTIONS: PRODUCT_ID, CUSTOMER_ID, SELLER_ID, TOTAL_REVENUE, QUANTITY_SOLD, PROFIT_MARGIN, ORDER_DATE
- - PRODUCTS: PRODUCT_ID, PRODUCT_NAME, CATEGORY
- - CUSTOMERS: CUSTOMER_ID, CUSTOMER_SEGMENT
-
- IMPORTANT - Column Meanings:
- - TOTAL_REVENUE = dollar value of sales (e.g., $50B means fifty billion dollars)
- - QUANTITY_SOLD = number of units sold (e.g., 1000 units)
-
- When user says "sales", "revenue", or dollar amounts → use TOTAL_REVENUE
- When user says "quantity", "units", or "items sold" → use QUANTITY_SOLD
-
- Note: To filter by product name, you'll need to reference PRODUCTS table or use PRODUCT_ID directly
-
- Extract:
- 1. table: Which table to modify (likely SALES_TRANSACTIONS for revenue/sales changes)
- 2. entity_column: Column to filter by (e.g., product_name, customer_segment)
- 3. entity_value: Specific value to filter (e.g., "1080p webcam", "Electronics")
- 4. metric_column: Numeric column to change
- - If request mentions "sales", "revenue", or dollar amounts → TOTAL_REVENUE
- - If request mentions "quantity", "units", "items" → QUANTITY_SOLD
- - If request mentions "profit margin" → PROFIT_MARGIN
- 5. target_value: The target numeric value (convert B to billions, M to millions)
-
- Return ONLY valid JSON: {{"table": "...", "entity_column": "...", "entity_value": "...", "metric_column": "...", "target_value": 123}}
-
- Examples:
- - "increase 1080p webcam sales to 50B" → {{"table": "SALES_TRANSACTIONS", "entity_column": "PRODUCT_ID", "entity_value": "1080p Webcam", "metric_column": "TOTAL_REVENUE", "target_value": 50000000000, "needs_join": "PRODUCTS", "join_column": "PRODUCT_NAME"}}
- - "make tablet revenue 100 billion" → {{"table": "SALES_TRANSACTIONS", "entity_column": "PRODUCT_ID", "entity_value": "Tablet", "metric_column": "TOTAL_REVENUE", "target_value": 100000000000, "needs_join": "PRODUCTS", "join_column": "PRODUCT_NAME"}}
- - "increase laptop quantity to 50000 units" → {{"table": "SALES_TRANSACTIONS", "entity_column": "PRODUCT_ID", "entity_value": "Laptop", "metric_column": "QUANTITY_SOLD", "target_value": 50000, "needs_join": "PRODUCTS", "join_column": "PRODUCT_NAME"}}
- - "set profit margin to 25% for electronics" → {{"table": "SALES_TRANSACTIONS", "entity_column": "PRODUCT_ID", "entity_value": "electronics", "metric_column": "PROFIT_MARGIN", "target_value": 25, "needs_join": "PRODUCTS", "join_column": "CATEGORY"}}
-
- If the entity refers to a column not in the target table (e.g., product_name when modifying SALES_TRANSACTIONS),
- include "needs_join" with the table name and "join_column" with the column to match on.
- """
-
- response = self.openai_client.chat.completions.create(
- model=DEFAULT_LLM_MODEL,
- messages=[{"role": "user", "content": prompt}],
- temperature=0,
- **build_openai_chat_token_kwargs(DEFAULT_LLM_MODEL, 2000),
- )
-
- content = response.choices[0].message.content
-
- # Strip markdown code blocks if present
- if content.startswith('```'):
- lines = content.split('\n')
- content = '\n'.join(lines[1:-1])  # Remove first and last line (``` markers)
-
- try:
- result = json.loads(content)
- print(f"✅ Parsed request: {result.get('entity_value')} - {result.get('metric_column')}")
- return result
- except json.JSONDecodeError as e:
- print(f"❌ Failed to parse JSON: {e}")
- print(f"Content was: {content}")
- return {'error': f'Failed to parse request: {content}'}
-
- def analyze_current_data(self, adjustment: Dict) -> Dict:
- """
- Query current state of the data
-
- Returns:
- {
- 'current_total': float,
- 'row_count': int,
- 'avg_value': float,
- 'min_value': float,
- 'max_value': float,
- 'gap': float  # target - current
- }
- """
- cursor = self.conn.cursor()
-
- table = adjustment['table']
- entity_col = adjustment['entity_column']
- entity_val = adjustment['entity_value']
- metric_col = adjustment['metric_column']
- target = adjustment['target_value']
-
- # Build WHERE clause - handle joins if needed
- if adjustment.get('needs_join'):
- join_table = adjustment['needs_join']
- join_col = adjustment['join_column']
- where_clause = f"""WHERE {entity_col} IN (
- SELECT PRODUCT_ID FROM {self.database}."{self.schema}".{join_table}
- WHERE LOWER({join_col}) = LOWER('{entity_val}')
- )"""
- else:
- where_clause = f"WHERE LOWER({entity_col}) = LOWER('{entity_val}')"
-
- # Query current state
- query = f"""
- SELECT
- SUM({metric_col}) as total,
- COUNT(*) as row_count,
- AVG({metric_col}) as avg_value,
- MIN({metric_col}) as min_value,
- MAX({metric_col}) as max_value
- FROM {self.database}."{self.schema}".{table}
- {where_clause}
- """
-
- print(f"\n🔍 Analyzing current data...")
- print(f" Query: {query}")
-
- cursor.execute(query)
- row = cursor.fetchone()
-
- current_total = float(row[0]) if row[0] else 0
- row_count = int(row[1])
- avg_value = float(row[2]) if row[2] else 0
- min_value = float(row[3]) if row[3] else 0
- max_value = float(row[4]) if row[4] else 0
-
- gap = target - current_total
-
- return {
- 'current_total': current_total,
- 'row_count': row_count,
- 'avg_value': avg_value,
- 'min_value': min_value,
- 'max_value': max_value,
- 'gap': gap
- }
-
- def generate_strategy_options(self, adjustment: Dict, analysis: Dict) -> List[Dict]:
- """
- Generate 3 strategy options for achieving the target
-
- Returns list of strategies with details
- """
- table = adjustment['table']
- entity_col = adjustment['entity_column']
- entity_val = adjustment['entity_value']
- metric_col = adjustment['metric_column']
- target = adjustment['target_value']
-
- # Build WHERE clause - handle joins if needed
- if adjustment.get('needs_join'):
- join_table = adjustment['needs_join']
- join_col = adjustment['join_column']
- where_clause = f"""{entity_col} IN (
- SELECT PRODUCT_ID FROM {self.database}."{self.schema}".{join_table}
- WHERE LOWER({join_col}) = LOWER('{entity_val}')
- )"""
- else:
- where_clause = f"LOWER({entity_col}) = LOWER('{entity_val}')"
-
- current = analysis['current_total']
- gap = analysis['gap']
- row_count = analysis['row_count']
-
- if gap <= 0:
- return [{
- 'id': 'decrease',
- 'name': 'Decrease All',
- 'description': f"Current value ({current:,.0f}) already exceeds target ({target:,.0f})",
- 'sql': None
- }]
-
- strategies = []
-
- # Strategy A: Distribute increase across all rows
- multiplier = target / current if current > 0 else 1
- percentage_increase = (multiplier - 1) * 100
-
- strategies.append({
- 'id': 'A',
- 'name': 'Distribute Across All Transactions',
- 'description': f"Increase all {row_count:,} existing transactions by {percentage_increase:.1f}%",
- 'details': {
- 'approach': 'Multiply all existing values',
- 'rows_affected': row_count,
- 'new_avg': analysis['avg_value'] * multiplier
- },
- 'sql': f"""UPDATE {self.database}."{self.schema}".{table}
- SET {metric_col} = {metric_col} * {multiplier:.6f}
- WHERE {where_clause}"""
- })
-
- # Strategy B: Add new large transactions
- num_new_transactions = max(1, int(gap / (analysis['max_value'] * 2)))  # Add transactions 2x the current max
- value_per_new = gap / num_new_transactions
-
- strategies.append({
- 'id': 'B',
- 'name': 'Add New Large Transactions',
- 'description': f"Insert {num_new_transactions} new transactions of ${value_per_new:,.0f} each",
- 'details': {
- 'approach': 'Create new outlier transactions',
- 'rows_to_add': num_new_transactions,
- 'value_each': value_per_new
- },
- 'sql': f"""-- INSERT new transactions (requires full row data)
- -- INSERT INTO {self.database}."{self.schema}".{table} ({entity_col}, {metric_col}, ...)
- -- VALUES ('{entity_val}', {value_per_new}, ...)
- -- NOTE: This requires knowing all required columns in the table"""
- })
-
- # Strategy C: Boost top transactions
- top_n = min(10, max(1, row_count // 10))  # Top 10% of transactions
- boost_needed_per_row = gap / top_n
-
- strategies.append({
- 'id': 'C',
- 'name': 'Boost Top Transactions',
- 'description': f"Increase the top {top_n} transactions by ${boost_needed_per_row:,.0f} each",
- 'details': {
- 'approach': 'Create outliers from existing top transactions',
- 'rows_affected': top_n,
- 'boost_per_row': boost_needed_per_row
- },
- 'sql': f"""WITH top_rows AS (
- SELECT * FROM {self.database}."{self.schema}".{table}
- WHERE {where_clause}
- ORDER BY {metric_col} DESC
- LIMIT {top_n}
- )
- UPDATE {self.database}."{self.schema}".{table} t
- SET {metric_col} = {metric_col} + {boost_needed_per_row:.2f}
- WHERE EXISTS (
- SELECT 1 FROM top_rows
- WHERE top_rows.rowid = t.rowid
- )"""
- })
-
- return strategies
-
- def present_options(self, adjustment: Dict, analysis: Dict, strategies: List[Dict]) -> None:
- """Display options to user in a friendly format"""
-
- print("\n" + "="*80)
- print("📊 DATA ADJUSTMENT OPTIONS")
- print("="*80)
-
- entity = f"{adjustment['entity_column']}='{adjustment['entity_value']}'"
- metric = adjustment['metric_column']
-
- print(f"\n🎯 Goal: Adjust {metric} for {entity}")
- print(f" Current Total: ${analysis['current_total']:,.0f}")
- print(f" Target Total: ${adjustment['target_value']:,.0f}")
- print(f" Gap to Fill: ${analysis['gap']:,.0f} ({analysis['gap']/analysis['current_total']*100:.1f}% increase)")
-
- print(f"\n📈 Current Data:")
- print(f" Rows: {analysis['row_count']:,}")
- print(f" Average: ${analysis['avg_value']:,.0f}")
- print(f" Range: ${analysis['min_value']:,.0f} - ${analysis['max_value']:,.0f}")
-
- print(f"\n" + "="*80)
- print("STRATEGY OPTIONS:")
- print("="*80)
-
- for strategy in strategies:
- print(f"\n[{strategy['id']}] {strategy['name']}")
- print(f" {strategy['description']}")
-
- if 'details' in strategy:
- details = strategy['details']
- print(f" Details:")
- for key, value in details.items():
- if isinstance(value, float):
- print(f" - {key}: ${value:,.0f}")
- else:
- print(f" - {key}: {value}")
-
- if strategy['sql']:
- print(f"\n SQL Preview:")
- sql_preview = strategy['sql'].strip().split('\n')
- for line in sql_preview[:3]:  # Show first 3 lines
- print(f" {line}")
- if len(sql_preview) > 3:
- print(f" ... ({len(sql_preview)-3} more lines)")
-
- print("\n" + "="*80)
-
- def execute_strategy(self, strategy: Dict) -> Dict:
- """Execute the chosen strategy"""
-
- if not strategy['sql']:
- return {
- 'success': False,
- 'error': 'This strategy requires manual implementation (INSERT statements)'
- }
-
- cursor = self.conn.cursor()
-
- print(f"\n⚙️ Executing strategy: {strategy['name']}")
- print(f" SQL: {strategy['sql'][:200]}...")
-
- try:
- cursor.execute(strategy['sql'])
- rows_affected = cursor.rowcount
- self.conn.commit()
-
- return {
- 'success': True,
- 'message': f"✅ Updated {rows_affected} rows",
- 'rows_affected': rows_affected
- }
- except Exception as e:
- self.conn.rollback()
- return {
- 'success': False,
- 'error': str(e)
- }
-
- def get_available_tables(self) -> List[str]:
- """Get list of tables in the schema"""
- cursor = self.conn.cursor()
- cursor.execute(f"""
- SELECT TABLE_NAME
- FROM {self.database}.INFORMATION_SCHEMA.TABLES
- WHERE TABLE_SCHEMA = '{self.schema}'
- """)
- tables = [row[0] for row in cursor.fetchall()]
- return tables
-
- def close(self):
- """Close connection"""
- if self.conn:
- self.conn.close()
-
-
- # Test/demo function
- def demo_conversation():
- """Simulate the conversational flow"""
-
- print("""
- ╔════════════════════════════════════════════════════════════╗
- ║                                                            ║
- ║  Conversational Data Adjuster Demo                         ║
- ║                                                            ║
- ╚════════════════════════════════════════════════════════════╝
- """)
-
- # Setup from environment variables
- from dotenv import load_dotenv
- load_dotenv()
-
- adjuster = ConversationalDataAdjuster(
- database=os.getenv('SNOWFLAKE_DATABASE'),
- schema="20251116_140933_AMAZO_SAL",  # Schema from deployment
- model_id="3c97b0d6-448b-440a-b628-bac1f3d73049"
- )
-
- print(f"Using database: {os.getenv('SNOWFLAKE_DATABASE')}")
- print(f"Using schema: 20251116_140933_AMAZO_SAL")
-
- adjuster.connect()
-
- # User request (using actual product from our data)
- user_request = "increase 1080p Webcam sales to 50 billion"
- print(f"\n💬 User: \"{user_request}\"")
- print(f" (Current: ~$17.6B, Target: $50B)")
-
- # Step 1: Parse request
- tables = adjuster.get_available_tables()
- adjustment = adjuster.parse_adjustment_request(user_request, tables)
-
- # Step 2: Analyze current data
- analysis = adjuster.analyze_current_data(adjustment)
-
- # Step 3: Generate strategies
- strategies = adjuster.generate_strategy_options(adjustment, analysis)
-
- # Step 4: Present options
- adjuster.present_options(adjustment, analysis, strategies)
-
- # Step 5: User picks (simulated)
- print("\n💬 User: \"Use strategy A\"")
- chosen_strategy = strategies[0]  # Strategy A
-
435
- # Step 6: Execute
436
- result = adjuster.execute_strategy(chosen_strategy)
437
-
438
- if result['success']:
439
- print(f"\n{result['message']}")
440
- else:
441
- print(f"\n❌ Error: {result.get('error')}")
442
-
443
- adjuster.close()
444
-
445
-
446
- if __name__ == "__main__":
447
- demo_conversation()
448
-
data_adjuster.py DELETED
@@ -1,213 +0,0 @@
-"""
-Data Adjustment Module for Liveboard Refinement
-
-Allows natural language adjustments to demo data:
-- "make Product A 55% higher"
-- "increase Customer B revenue by 20%"
-- "set profit margin to 15% for Segment C"
-"""
-
-import re
-from typing import Dict, Optional
-from snowflake_auth import get_snowflake_connection
-from llm_config import DEFAULT_LLM_MODEL, build_openai_chat_token_kwargs
-from llm_client_factory import create_openai_client
-
-
-class DataAdjuster:
-    """Adjust demo data based on natural language requests"""
-
-    def __init__(self, database: str, schema: str):
-        self.database = database
-        self.schema = schema
-        self.conn = None
-        self.openai_client = create_openai_client()
-
-    def connect(self):
-        """Connect to Snowflake"""
-        self.conn = get_snowflake_connection()
-        self.conn.cursor().execute(f"USE DATABASE {self.database}")
-        self.conn.cursor().execute(f"USE SCHEMA {self.schema}")
-        print(f"✅ Connected to {self.database}.{self.schema}")
-
-    def parse_adjustment_request(self, request: str, available_columns: list) -> Dict:
-        """
-        Parse natural language adjustment request using AI
-
-        Args:
-            request: e.g., "make Product A 55% higher" or "increase revenue for Customer B by 20%"
-            available_columns: List of column names in the data
-
-        Returns:
-            {
-                'entity_column': 'product_name',
-                'entity_value': 'Product A',
-                'metric_column': 'total_revenue',
-                'adjustment_type': 'percentage_increase',
-                'adjustment_value': 55
-            }
-        """
-        prompt = f"""Parse this data adjustment request and extract structured information.
-
-Request: "{request}"
-
-Available columns in the dataset: {', '.join(available_columns)}
-
-Extract:
-1. entity_column: Which column identifies what to change (e.g., product_name, customer_segment)
-2. entity_value: The specific value to filter by (e.g., "Product A", "Electronics")
-3. metric_column: Which numeric column to adjust (e.g., total_revenue, profit_margin, quantity_sold)
-4. adjustment_type: One of: "percentage_increase", "percentage_decrease", "set_value", "add_value"
-5. adjustment_value: The numeric value (e.g., 55 for "55%", 1000 for "add 1000")
-
-Return ONLY a JSON object with these fields. If you can't parse it, return {{"error": "description"}}.
-
-Examples:
-- "make Product A 55% higher" → {{"entity_column": "product_name", "entity_value": "Product A", "metric_column": "total_revenue", "adjustment_type": "percentage_increase", "adjustment_value": 55}}
-- "set profit margin to 15% for Electronics" → {{"entity_column": "category", "entity_value": "Electronics", "metric_column": "profit_margin", "adjustment_type": "set_value", "adjustment_value": 15}}
-"""
-
-        response = self.openai_client.chat.completions.create(
-            model=DEFAULT_LLM_MODEL,
-            messages=[{"role": "user", "content": prompt}],
-            temperature=0,
-            **build_openai_chat_token_kwargs(DEFAULT_LLM_MODEL, 2000),
-        )
-
-        import json
-        result = json.loads(response.choices[0].message.content)
-        return result
-
-    def get_available_columns(self, table_name: str) -> list:
-        """Get list of columns from a table"""
-        cursor = self.conn.cursor()
-        cursor.execute(f"""
-            SELECT COLUMN_NAME
-            FROM {self.database}.INFORMATION_SCHEMA.COLUMNS
-            WHERE TABLE_SCHEMA = '{self.schema}'
-            AND TABLE_NAME = '{table_name.upper()}'
-        """)
-        columns = [row[0].lower() for row in cursor.fetchall()]
-        return columns
-
-    def apply_adjustment(self, table_name: str, adjustment: Dict) -> Dict:
-        """
-        Apply the parsed adjustment to the database
-
-        Returns:
-            {'success': bool, 'message': str, 'rows_affected': int}
-        """
-        if 'error' in adjustment:
-            return {'success': False, 'message': adjustment['error']}
-
-        cursor = self.conn.cursor()
-
-        # Build the UPDATE statement
-        entity_col = adjustment['entity_column']
-        entity_val = adjustment['entity_value']
-        metric_col = adjustment['metric_column']
-        adj_type = adjustment['adjustment_type']
-        adj_value = adjustment['adjustment_value']
-
-        # Calculate new value based on adjustment type
-        if adj_type == 'percentage_increase':
-            new_value_expr = f"{metric_col} * (1 + {adj_value}/100.0)"
-        elif adj_type == 'percentage_decrease':
-            new_value_expr = f"{metric_col} * (1 - {adj_value}/100.0)"
-        elif adj_type == 'set_value':
-            new_value_expr = f"{adj_value}"
-        elif adj_type == 'add_value':
-            new_value_expr = f"{metric_col} + {adj_value}"
-        else:
-            return {'success': False, 'message': f"Unknown adjustment type: {adj_type}"}
-
-        # Execute UPDATE
-        update_sql = f"""
-        UPDATE {self.database}.{self.schema}.{table_name}
-        SET {metric_col} = {new_value_expr}
-        WHERE LOWER({entity_col}) = LOWER('{entity_val}')
-        """
-
-        print(f"\n🔧 Executing adjustment:")
-        print(f"   SQL: {update_sql}")
-
-        try:
-            cursor.execute(update_sql)
-            rows_affected = cursor.rowcount
-            self.conn.commit()
-
-            return {
-                'success': True,
-                'message': f"Updated {rows_affected} rows: {entity_col}='{entity_val}', adjusted {metric_col} by {adj_type}",
-                'rows_affected': rows_affected
-            }
-        except Exception as e:
-            return {
-                'success': False,
-                'message': f"Database error: {str(e)}"
-            }
-
-    def adjust_data_for_liveboard(self, request: str, table_name: str) -> Dict:
-        """
-        Full workflow: parse request, update data
-
-        Args:
-            request: Natural language request like "make Product A 55% higher"
-            table_name: Name of the table to update
-
-        Returns:
-            Result dictionary with success status and details
-        """
-        if not self.conn:
-            self.connect()
-
-        # Get available columns
-        columns = self.get_available_columns(table_name)
-        print(f"📋 Available columns: {', '.join(columns)}")
-
-        # Parse the request
-        print(f"\n🤔 Parsing request: '{request}'")
-        adjustment = self.parse_adjustment_request(request, columns)
-        print(f"✅ Parsed: {adjustment}")
-
-        if 'error' in adjustment:
-            return {'success': False, 'error': adjustment['error']}
-
-        # Apply the adjustment
-        result = self.apply_adjustment(table_name, adjustment)
-
-        return result
-
-    def close(self):
-        """Close database connection"""
-        if self.conn:
-            self.conn.close()
-
-
-# Example usage function
-def test_data_adjustment():
-    """Test the data adjustment functionality"""
-    adjuster = DataAdjuster(
-        database="DEMO_DATABASE",
-        schema="DEMO_SCHEMA"
-    )
-
-    # Example: "make Product A 55% higher"
-    result = adjuster.adjust_data_for_liveboard(
-        request="make Product A 55% higher",
-        table_name="FACT_SALES"
-    )
-
-    print(f"\n{'='*60}")
-    if result['success']:
-        print(f"✅ SUCCESS: {result['message']}")
-        print(f"📊 Rows affected: {result['rows_affected']}")
-    else:
-        print(f"❌ FAILED: {result.get('error', result.get('message'))}")
-
-    adjuster.close()
-
-
-if __name__ == "__main__":
-    test_data_adjustment()
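A note on the removed `DataAdjuster.apply_adjustment`: it interpolated the user-supplied entity value directly into the `UPDATE` statement, which is injection-prone even for demo data. A minimal sketch of a parameterized alternative is below — `build_adjustment_update` is a hypothetical helper, not part of this commit; column and table identifiers are still interpolated and would need validation against `INFORMATION_SCHEMA` (as `get_available_columns` already did):

```python
def build_adjustment_update(table: str, entity_col: str, entity_val: str,
                            metric_col: str, adj_type: str, adj_value: float):
    """Build a parameterized UPDATE for one adjustment (hypothetical helper).

    Values are bound via %s placeholders; identifiers are assumed to be
    pre-validated against INFORMATION_SCHEMA before being interpolated.
    """
    exprs = {
        "percentage_increase": f"{metric_col} * (1 + %s / 100.0)",
        "percentage_decrease": f"{metric_col} * (1 - %s / 100.0)",
        "set_value": "%s",
        "add_value": f"{metric_col} + %s",
    }
    if adj_type not in exprs:
        raise ValueError(f"Unknown adjustment type: {adj_type}")
    sql = (
        f"UPDATE {table} SET {metric_col} = {exprs[adj_type]} "
        f"WHERE LOWER({entity_col}) = LOWER(%s)"
    )
    # Parameter order matches the placeholders: adjustment value, then filter value
    return sql, (adj_value, entity_val)
```

The resulting `(sql, params)` pair can be passed to `cursor.execute(sql, params)` so the driver handles quoting of the entity value.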
demo_personas.py CHANGED
@@ -3,6 +3,8 @@ Demo Personas and Use Case Configurations
 All persona data and prompt templates for use case-driven demo preparation
 """
 
 from schema_utils import extract_key_business_terms
 
 # ============================================================================
@@ -12,27 +14,92 @@ from schema_utils import extract_key_business_terms
 # Keep USE_CASE_PERSONAS below for backward compatibility during transition
 # ============================================================================
 
 # VERTICALS: Industry-specific context
 VERTICALS = {
     "Retail": {
         "typical_entities": ["Store", "Product", "Category", "Region", "Customer"],
         "industry_terms": ["SKU", "basket", "shrink", "markdown", "comp sales", "footfall"],
         "data_patterns": ["seasonality", "holiday_spikes", "weather_impact", "back_to_school"],
     },
     "Banking": {
         "typical_entities": ["Account", "Customer", "Branch", "Product", "Loan"],
         "industry_terms": ["AUM", "NIM", "deposits", "charge-off", "delinquency", "APR"],
         "data_patterns": ["month_end_spikes", "rate_sensitivity", "quarter_close"],
     },
     "Software": {
         "typical_entities": ["Account", "User", "Subscription", "Feature", "License"],
         "industry_terms": ["ARR", "MRR", "churn", "NRR", "seats", "expansion"],
         "data_patterns": ["renewal_cycles", "usage_spikes", "trial_conversion"],
     },
     "Manufacturing": {
         "typical_entities": ["Plant", "Line", "Product", "Supplier", "Shift"],
         "industry_terms": ["OEE", "yield", "scrap", "downtime", "throughput", "WIP"],
         "data_patterns": ["shift_patterns", "maintenance_cycles", "supply_disruptions"],
     },
 }
 
@@ -53,6 +120,18 @@ FUNCTIONS = {
         "Why did {kpi} drop last month?",
         "Compare {kpi} across {dimension}",
     ],
     },
     "Supply Chain": {
         "kpis": ["Avg Inventory", "OTIF", "Days on Hand", "Stockout Rate"],
@@ -69,6 +148,19 @@
         "Show inventory levels by {dimension}",
         "Which suppliers have the longest lead times?",
     ],
     },
     "Marketing": {
         "kpis": ["CTR", "Bounce Rate", "Fill Rate", "Approval Rate"],
@@ -85,6 +177,102 @@
         "Show me the funnel for {campaign}",
         "Which channel has the highest CTR?",
     ],
     },
 }
 
@@ -100,6 +288,18 @@ MATRIX_OVERRIDES = {
         "add_viz": ["by_store", "by_category"],
         "target_persona": "VP Merchandising, Retail Sales Leader",
         "business_problem": "$1T lost annually to stockouts and overstock",
     },
     ("Banking", "Marketing"): {
         "add_kpis": ["Application Fill Rate", "Cost per Acquisition"],
@@ -110,6 +310,18 @@ MATRIX_OVERRIDES = {
         "rename_kpis": {"CTR": "Click-through Rate"},
         "target_persona": "CMO, VP Digital Marketing",
         "business_problem": "High cost per acquisition, low funnel conversion",
     },
     ("Software", "Sales"): {
         "add_kpis": ["ARR", "Net Revenue Retention", "Pipeline Coverage"],
@@ -120,7 +332,124 @@
         },
         "add_viz": ["by_segment", "by_rep"],
         "target_persona": "CRO, VP Sales",
     },
 }
 
@@ -142,6 +471,10 @@ def parse_use_case(user_input: str) -> tuple[str | None, str | None]:
         return (None, None)
 
     user_input_lower = user_input.strip().lower()
 
     # Try to find both vertical and function in the input
     found_vertical = None
@@ -211,6 +544,14 @@ def get_use_case_config(vertical: str, function: str) -> dict:
         "viz_types": f.get("viz_types", []).copy(),
         "outlier_categories": f.get("outlier_categories", []).copy(),
         "spotter_templates": f.get("spotter_templates", []).copy(),
 
         # Flags
         "is_generic": False,
@@ -233,6 +574,22 @@
         config["target_persona"] = override["target_persona"]
     if override.get("business_problem"):
         config["business_problem"] = override["business_problem"]
 
     # Handle generic cases
     if not is_known_vertical and not is_known_function:
@@ -240,6 +597,13 @@
         config["is_generic"] = True
         config["ai_should_determine"] = ["entities", "industry_terms", "kpis", "viz_types", "outliers"]
         config["prompt_user_for"] = ["key_metrics", "target_persona", "business_questions"]
     elif not is_known_vertical:
         # Known function, unknown vertical
         config["is_generic"] = True
@@ -264,6 +628,8 @@
         config["cost_impact"] = "Significant business impact through data-driven decisions"
     if "success_outcomes" not in config:
         config["success_outcomes"] = f"Improved {function.lower()} performance and faster insights"
 
     return config
 
 All persona data and prompt templates for use case-driven demo preparation
 """
 
+from copy import deepcopy
+
 from schema_utils import extract_key_business_terms
 
 # ============================================================================
 
 # Keep USE_CASE_PERSONAS below for backward compatibility during transition
 # ============================================================================
 
+# Story controls are intentionally strict and deterministic for demo realism.
+DEFAULT_STORY_CONTROLS = {
+    "realism_mode": "strict",
+    "trend_style": "smooth",
+    "trend_noise_band_pct": 0.04,
+    "seasonal_strength": 0.10,
+    "reproducibility_seed_strategy": "company_use_case_hash",
+    "outlier_budget": {
+        "max_events": 2,
+        "max_points_pct": 0.02,
+        "event_multiplier_min": 1.8,
+        "event_multiplier_max": 3.2,
+    },
+    "value_guardrails": {
+        "allow_negative_revenue": False,
+        "max_mom_change_pct": 35,
+    },
+    "quality_targets": {
+        "min_semantic_pass_ratio": 0.98,
+    },
+}
+
+
+def _deep_merge_dict(base: dict, update: dict) -> dict:
+    """Deep merge dictionaries while preserving base keys."""
+    merged = deepcopy(base)
+    for key, val in (update or {}).items():
+        if isinstance(val, dict) and isinstance(merged.get(key), dict):
+            merged[key] = _deep_merge_dict(merged[key], val)
+        else:
+            merged[key] = deepcopy(val)
+    return merged
+
+
+def _merge_story_controls(*controls: dict) -> dict:
+    merged = deepcopy(DEFAULT_STORY_CONTROLS)
+    for control in controls:
+        merged = _deep_merge_dict(merged, control or {})
+    return merged
+
+
 # VERTICALS: Industry-specific context
 VERTICALS = {
     "Retail": {
         "typical_entities": ["Store", "Product", "Category", "Region", "Customer"],
         "industry_terms": ["SKU", "basket", "shrink", "markdown", "comp sales", "footfall"],
         "data_patterns": ["seasonality", "holiday_spikes", "weather_impact", "back_to_school"],
+        "story_controls": {
+            "seasonal_strength": 0.18,
+            "trend_noise_band_pct": 0.035,
+            "outlier_budget": {"max_events": 3},
+            "recommended_companies": ["Amazon.com", "Walmart.com", "Target.com"],
+        },
     },
     "Banking": {
         "typical_entities": ["Account", "Customer", "Branch", "Product", "Loan"],
         "industry_terms": ["AUM", "NIM", "deposits", "charge-off", "delinquency", "APR"],
         "data_patterns": ["month_end_spikes", "rate_sensitivity", "quarter_close"],
+        "story_controls": {
+            "seasonal_strength": 0.06,
+            "trend_noise_band_pct": 0.025,
+            "outlier_budget": {"max_events": 2},
+            "recommended_companies": ["Chase.com", "CapitalOne.com", "BankOfAmerica.com"],
+        },
     },
     "Software": {
         "typical_entities": ["Account", "User", "Subscription", "Feature", "License"],
         "industry_terms": ["ARR", "MRR", "churn", "NRR", "seats", "expansion"],
         "data_patterns": ["renewal_cycles", "usage_spikes", "trial_conversion"],
+        "story_controls": {
+            "seasonal_strength": 0.05,
+            "trend_noise_band_pct": 0.03,
+            "outlier_budget": {"max_events": 2},
+            "recommended_companies": ["Salesforce.com", "HubSpot.com", "Atlassian.com"],
+        },
     },
     "Manufacturing": {
         "typical_entities": ["Plant", "Line", "Product", "Supplier", "Shift"],
         "industry_terms": ["OEE", "yield", "scrap", "downtime", "throughput", "WIP"],
         "data_patterns": ["shift_patterns", "maintenance_cycles", "supply_disruptions"],
+        "story_controls": {
+            "seasonal_strength": 0.08,
+            "trend_noise_band_pct": 0.03,
+            "outlier_budget": {"max_events": 2},
+            "recommended_companies": ["Caterpillar.com", "GE.com", "3M.com"],
+        },
     },
 }
 
         "Why did {kpi} drop last month?",
         "Compare {kpi} across {dimension}",
     ],
+        "liveboard_questions": [
+            {"title": "Revenue Trend", "viz_type": "LINE", "viz_question": "Dollar Sales trend by month", "required": True, "insight": "Monthly revenue trend reveals growth patterns and seasonality", "spotter_qs": ["What is the sales trend?", "Show me revenue by month", "When was our peak sales period?"]},
+            {"title": "Sales by Region", "viz_type": "COLUMN", "viz_question": "Dollar Sales by Region", "required": True, "insight": "Regional performance comparison reveals geographic variance", "spotter_qs": ["Which region has the highest sales?", "Compare regional performance", "Show me the top performing regions"]},
+            {"title": "ASP Trend", "viz_type": "KPI", "viz_question": "ASP weekly", "required": False, "insight": "Average selling price trend shows pricing health", "spotter_qs": ["Why did ASP change last month?", "Show me ASP by product", "Which products have the biggest discount?"]},
+            {"title": "Sales by Category", "viz_type": "COLUMN", "viz_question": "Dollar Sales by Category", "required": False, "insight": "Category mix reveals product performance drivers", "spotter_qs": ["Which category grew fastest?", "Compare top categories", "Show me category performance trend"]},
+        ],
+        "story_controls": {
+            "trend_style": "smooth",
+            "trend_noise_band_pct": 0.03,
+            "allowed_dimensions": ["Region", "Product", "Category", "Store", "Channel"],
+            "allowed_measures": ["Dollar Sales", "Unit Sales", "ASP"],
+        },
     },
     "Supply Chain": {
         "kpis": ["Avg Inventory", "OTIF", "Days on Hand", "Stockout Rate"],
 
         "Show inventory levels by {dimension}",
         "Which suppliers have the longest lead times?",
     ],
+        "liveboard_questions": [
+            {"title": "Inventory Levels", "viz_type": "KPI", "viz_question": "Average inventory by month", "required": True, "insight": "Inventory trend over time", "spotter_qs": ["What is current inventory level?", "Show me inventory trend by month"]},
+            {"title": "Stockout Risk by Product", "viz_type": "COLUMN", "viz_question": "Stockout rate by product", "required": True, "insight": "Products at risk of going out of stock", "spotter_qs": ["Which products are at risk of stockout?", "Show stockout rate by category"]},
+            {"title": "OTIF Performance", "viz_type": "KPI", "viz_question": "OTIF rate by month", "required": False, "insight": "On-time in-full delivery performance", "spotter_qs": ["What is our OTIF rate?", "Which suppliers have the lowest OTIF?"]},
+            {"title": "Days on Hand by Supplier", "viz_type": "COLUMN", "viz_question": "Days on hand by supplier", "required": False, "insight": "Supplier inventory coverage comparison", "spotter_qs": ["Which supplier has the highest days on hand?", "Show me inventory coverage by supplier"]},
+        ],
+        "story_controls": {
+            "trend_style": "smooth",
+            "trend_noise_band_pct": 0.025,
+            "allowed_dimensions": ["Supplier", "Warehouse", "Region", "Product", "Category"],
+            "allowed_measures": ["Avg Inventory", "OTIF", "Days on Hand", "Stockout Rate"],
+            "outlier_budget": {"max_events": 2},
+        },
     },
     "Marketing": {
         "kpis": ["CTR", "Bounce Rate", "Fill Rate", "Approval Rate"],
 
         "Show me the funnel for {campaign}",
         "Which channel has the highest CTR?",
     ],
+        "liveboard_questions": [
+            {"title": "Conversion Funnel", "viz_type": "COLUMN", "viz_question": "Conversion rate by funnel stage", "required": True, "insight": "Drop-off at each stage reveals where prospects are lost", "spotter_qs": ["Where is our biggest funnel drop-off?", "What is our application completion rate?"]},
+            {"title": "CTR by Channel", "viz_type": "COLUMN", "viz_question": "CTR by channel", "required": True, "insight": "Click-through rate comparison reveals best-performing channels", "spotter_qs": ["Which channel has the best CTR?", "Compare channel performance"]},
+            {"title": "Campaign ROI Trend", "viz_type": "KPI", "viz_question": "Campaign ROI by month", "required": False, "insight": "Return on marketing investment over time", "spotter_qs": ["Which campaigns have the best ROI?", "Show me ROI trend"]},
+            {"title": "Cost per Acquisition Trend", "viz_type": "LINE", "viz_question": "Cost per acquisition trend by month", "required": False, "insight": "CPA trend reveals acquisition efficiency", "spotter_qs": ["How is our CPA trending?", "Which channel has the lowest CPA?"]},
+        ],
+        "story_controls": {
+            "trend_style": "smooth",
+            "trend_noise_band_pct": 0.04,
+            "allowed_dimensions": ["Channel", "Campaign", "Segment", "Region"],
+            "allowed_measures": ["CTR", "Bounce Rate", "Fill Rate", "Approval Rate"],
+            "outlier_budget": {"max_events": 2},
+        },
+    },
+    "Finance": {
+        "kpis": [
+            "ARR",
+            "MRR",
+            "NRR",
+            "GRR",
+            "Net New ARR",
+            "Billings",
+            "Collections",
+            "Deferred Revenue",
+            "CAC",
+            "Gross Margin %",
+            "DSO",
+        ],
+        "kpi_definitions": {
+            "ARR": "Ending annual recurring revenue for the period",
+            "MRR": "Monthly recurring revenue",
+            "NRR": "(Starting ARR + Expansion - Contraction - Churn) ÷ Starting ARR",
+            "GRR": "(Starting ARR - Contraction - Churn) ÷ Starting ARR",
+            "Net New ARR": "New Logo ARR + Expansion ARR - Contraction ARR - Churned ARR",
+            "Billings": "Amount invoiced in the period",
+            "Collections": "Cash collected in the period",
+            "Deferred Revenue": "Billed but not yet recognized revenue balance",
+            "CAC": "Sales and marketing spend ÷ new customers acquired",
+            "Gross Margin %": "Gross Margin ÷ Revenue",
+            "DSO": "Average days sales outstanding",
+        },
+        "viz_types": [
+            "KPI_sparkline",
+            "arr_bridge",
+            "retention_trend",
+            "cash_flow",
+            "segment_comparison",
+            "by_region",
+        ],
+        "outlier_categories": [
+            "expansion_uplift",
+            "churn_spike",
+            "collections_lag",
+            "upsell_conversion",
+            "billing_prepay",
+        ],
+        "spotter_templates": [
+            "Why did ARR jump in Q4?",
+            "What is driving NRR by segment?",
+            "Show billings versus collections trend by month",
+            "Which region has the highest CAC?",
+        ],
+        "liveboard_questions": [
+            {"title": "ARR Trend", "viz_type": "KPI", "viz_question": "ARR by month", "required": True, "insight": "Annual recurring revenue growth trend", "spotter_qs": ["Why did ARR jump in Q4?", "Show me ARR trend by segment"]},
+            {"title": "ARR Bridge", "viz_type": "COLUMN", "viz_question": "New logo, expansion, contraction, and churn ARR by month", "required": True, "insight": "Waterfall breakdown of ARR movements", "spotter_qs": ["What is driving net new ARR?", "How much ARR did we lose to churn?"]},
+            {"title": "NRR by Segment", "viz_type": "COLUMN", "viz_question": "Net revenue retention by segment", "required": False, "insight": "Retention health across customer segments", "spotter_qs": ["What is driving NRR by segment?", "Which segment has the lowest retention?"]},
+            {"title": "CAC by Region", "viz_type": "COLUMN", "viz_question": "CAC by region", "required": False, "insight": "Customer acquisition cost comparison", "spotter_qs": ["Which region has the highest CAC?", "How is CAC trending?"]},
+        ],
+        "story_controls": {
+            "trend_style": "smooth",
+            "trend_noise_band_pct": 0.02,
+            "seasonal_strength": 0.04,
+            "allowed_dimensions": [
+                "Segment",
+                "Region",
+                "Product Family",
+                "Plan Tier",
+                "Contract Type",
+                "Customer Tier",
+            ],
+            "allowed_measures": [
+                "ARR",
+                "MRR",
+                "NRR",
+                "GRR",
+                "Net New ARR",
+                "Billings",
+                "Collections",
+                "Deferred Revenue",
+                "CAC",
+                "Gross Margin %",
+                "DSO",
+            ],
+            "outlier_budget": {"max_events": 2, "max_points_pct": 0.015},
+            "value_guardrails": {"max_mom_change_pct": 20},
+        },
     },
 }
 
         "add_viz": ["by_store", "by_category"],
         "target_persona": "VP Merchandising, Retail Sales Leader",
         "business_problem": "$1T lost annually to stockouts and overstock",
+        "liveboard_questions": [
+            {"title": "ASP Decline", "viz_type": "KPI", "viz_question": "average selling price by week", "required": True, "insight": "ASP dropped even though revenue is up — discounting too heavily", "spotter_qs": ["Why did ASP drop last month?", "Which products have the biggest discount?", "Show me ASP by region"]},
+            {"title": "Regional Variance", "viz_type": "COLUMN", "viz_question": "total sales revenue by region", "required": True, "insight": "West region outperforming by 40% — what are they doing differently?", "spotter_qs": ["Which region has the highest sales?", "Compare West to East performance", "What products are driving the top region?"]},
+            {"title": "Seasonal Trend", "viz_type": "LINE", "viz_question": "total sales revenue by category and month", "required": False, "insight": "Holiday surge 3x normal — were we prepared?", "spotter_qs": ["Show me sales trend for Q4", "When was our peak sales day?"]},
+            {"title": "Category Surge", "viz_type": "COLUMN", "viz_question": "total sales revenue by product category", "required": False, "insight": "Electronics up 60% YoY while Apparel flat", "spotter_qs": ["Which category grew fastest?", "Compare Electronics to Apparel"]},
+        ],
+        "story_controls": {
+            "seasonal_strength": 0.20,
+            "trend_noise_band_pct": 0.028,
+            "outlier_budget": {"max_events": 3, "max_points_pct": 0.025},
+            "allowed_dimensions": ["Store", "Category", "Region", "Product"],
+        },
     },
     ("Banking", "Marketing"): {
         "add_kpis": ["Application Fill Rate", "Cost per Acquisition"],
 
         "rename_kpis": {"CTR": "Click-through Rate"},
         "target_persona": "CMO, VP Digital Marketing",
         "business_problem": "High cost per acquisition, low funnel conversion",
+        "liveboard_questions": [
+            {"title": "Funnel Drop-off", "viz_type": "COLUMN", "viz_question": "Conversion rate by funnel stage", "required": True, "insight": "70% drop-off at application page — UX issue?", "spotter_qs": ["Where is our biggest funnel drop-off?", "What is our application completion rate?"]},
+            {"title": "Channel CTR", "viz_type": "COLUMN", "viz_question": "CTR by channel", "required": True, "insight": "Mobile CTR 2x desktop — shift budget?", "spotter_qs": ["Which channel has the best CTR?", "Compare mobile vs desktop performance"]},
+            {"title": "Cost per Acquisition Trend", "viz_type": "LINE", "viz_question": "Cost per acquisition by month", "required": False, "insight": "CPA trend over time reveals acquisition efficiency", "spotter_qs": ["How is CPA trending?", "Which channel has the lowest CPA?"]},
+            {"title": "Application Volume by Segment", "viz_type": "COLUMN", "viz_question": "Applications by customer segment", "required": False, "insight": "Which segments are converting on products?", "spotter_qs": ["Which segment submits the most applications?", "Compare approval rates by segment"]},
+        ],
+        "story_controls": {
+            "seasonal_strength": 0.07,
+            "trend_noise_band_pct": 0.03,
+            "outlier_budget": {"max_events": 2},
+            "allowed_dimensions": ["Channel", "Campaign", "Branch", "Segment"],
+        },
     },
     ("Software", "Sales"): {
         "add_kpis": ["ARR", "Net Revenue Retention", "Pipeline Coverage"],
 
         },
         "add_viz": ["by_segment", "by_rep"],
         "target_persona": "CRO, VP Sales",
+        "liveboard_questions": [
+            {"title": "ARR by Segment", "viz_type": "COLUMN", "viz_question": "ARR by segment", "required": True, "insight": "Revenue breakdown reveals which segments drive growth", "spotter_qs": ["Which segment has the highest ARR?", "Show ARR growth by segment"]},
+            {"title": "Pipeline Coverage", "viz_type": "KPI", "viz_question": "Pipeline coverage ratio by month", "required": True, "insight": "Pipeline vs quota ratio shows sales health", "spotter_qs": ["What is our pipeline coverage?", "Which rep has the lowest pipeline coverage?"]},
+            {"title": "Win Rate by Rep", "viz_type": "COLUMN", "viz_question": "Win rate by sales rep", "required": False, "insight": "Rep performance comparison reveals coaching opportunities", "spotter_qs": ["Which rep has the highest win rate?", "Compare win rates by region"]},
+            {"title": "Deal Velocity Trend", "viz_type": "LINE", "viz_question": "Average sales cycle length by month", "required": False, "insight": "Deal velocity trend shows if sales motion is tightening or slowing", "spotter_qs": ["How is deal cycle length trending?", "Which segment closes fastest?"]},
+        ],
+        "story_controls": {
+            "seasonal_strength": 0.05,
+            "trend_noise_band_pct": 0.025,
+            "outlier_budget": {"max_events": 2},
+            "allowed_dimensions": ["Segment", "Rep", "Region", "Product"],
+        },
    },
+    ("Software", "Finance"): {
+        "use_case_name": "SaaS Finance and Unit Economics",
+        "canonical_use_case": "SaaS Finance and Unit Economics",
+        "target_persona": "CFO, VP Finance, FP&A Leader",
+        "business_problem": (
+            "Recurring revenue growth can look healthy while retention, cash conversion, "
+            "and acquisition efficiency quietly deteriorate."
+        ),
+        "story_controls": {
+            "seasonal_strength": 0.03,
+            "trend_noise_band_pct": 0.015,
+            "outlier_budget": {"max_events": 2, "max_points_pct": 0.0125},
360
+ "value_guardrails": {"max_mom_change_pct": 20},
361
+ "allowed_dimensions": [
362
+ "Segment",
363
+ "Region",
364
+ "Product Family",
365
+ "Plan Tier",
366
+ "Contract Type",
367
+ "Customer Tier",
368
+ "Billing Cadence",
369
+ ],
370
+ },
371
+ "schema_contract": {
372
+ "mode": "saas_finance_gold",
373
+ "date_grain": "monthly",
374
+ "history_months": 24,
375
+ "required_dimensions": [
376
+ "DATES",
377
+ "CUSTOMERS",
378
+ "PRODUCTS",
379
+ "LOCATIONS",
380
+ ],
381
+ "required_facts": [
382
+ "SAAS_CUSTOMER_MONTHLY",
383
+ "SALES_MARKETING_SPEND_MONTHLY",
384
+ ],
385
+ "customer_month_fact": {
386
+ "table": "SAAS_CUSTOMER_MONTHLY",
387
+ "grain": "one row per customer per month",
388
+ "required_columns": [
389
+ "MONTH_KEY",
390
+ "CUSTOMER_KEY",
391
+ "PRODUCT_KEY",
392
+ "LOCATION_KEY",
393
+ "SCENARIO",
394
+ "STARTING_ARR_USD",
395
+ "NEW_LOGO_ARR_USD",
396
+ "EXPANSION_ARR_USD",
397
+ "CONTRACTION_ARR_USD",
398
+ "CHURNED_ARR_USD",
399
+ "ENDING_ARR_USD",
400
+ "MRR_USD",
401
+ "BILLINGS_USD",
402
+ "COLLECTIONS_USD",
403
+ "DEFERRED_REVENUE_USD",
404
+ "GROSS_MARGIN_USD",
405
+ "GROSS_MARGIN_PCT",
406
+ "SUPPORT_COST_USD",
407
+ "CAC_ALLOCATED_USD",
408
+ "DSO_DAYS",
409
+ "SEAT_COUNT",
410
+ "USAGE_UNITS",
411
+ "CUSTOMER_ACQUIRED_FLAG",
412
+ "CUSTOMER_CHURNED_FLAG",
413
+ ],
414
+ },
415
+ "spend_fact": {
416
+ "table": "SALES_MARKETING_SPEND_MONTHLY",
417
+ "grain": "one row per month per segment per region",
418
+ "required_columns": [
419
+ "MONTH_KEY",
420
+ "SEGMENT",
421
+ "REGION",
422
+ "SALES_SPEND_USD",
423
+ "MARKETING_SPEND_USD",
424
+ "TOTAL_S_AND_M_SPEND_USD",
425
+ "NEW_CUSTOMERS_ACQUIRED",
426
+ "NEW_LOGO_ARR_USD",
427
+ "NET_NEW_ARR_USD",
428
+ "NRR_PCT",
429
+ "GRR_PCT",
430
+ "CAC_USD",
431
+ ],
432
+ },
433
+ "identities": [
434
+ "ENDING_ARR_USD = STARTING_ARR_USD + NEW_LOGO_ARR_USD + EXPANSION_ARR_USD - CONTRACTION_ARR_USD - CHURNED_ARR_USD",
435
+ "MRR_USD = ENDING_ARR_USD / 12",
436
+ "TOTAL_S_AND_M_SPEND_USD = SALES_SPEND_USD + MARKETING_SPEND_USD",
437
+ "CAC_ALLOCATED_USD >= 0",
438
+ "COLLECTIONS_USD <= BILLINGS_USD + DEFERRED_REVENUE_USD",
439
+ ],
440
+ },
441
+ },
442
+ }
443
+
444
+
445
+ ROUTED_USE_CASES = {
446
+ "saas finance and unit economics": ("Software", "Finance"),
447
+ "saas finance": ("Software", "Finance"),
448
+ "software finance": ("Software", "Finance"),
449
+ "financial analytics": ("Software", "Finance"),
450
+ "finance analytics": ("Software", "Finance"),
451
+ "fp&a": ("Software", "Finance"),
452
+ "fpa": ("Software", "Finance"),
453
  }
454
 
455
 
 
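The `identities` in the schema contract above read as executable invariants for generated rows. A minimal sketch of checking the ARR bridge identity, assuming row dicts keyed by the contract's column names (the validator itself is illustrative, not the project's actual quality validator):

```python
def check_arr_bridge(row: dict, tolerance: float = 0.01) -> bool:
    """Verify ENDING_ARR_USD equals the ARR bridge within a small tolerance."""
    expected = (
        row["STARTING_ARR_USD"]
        + row["NEW_LOGO_ARR_USD"]
        + row["EXPANSION_ARR_USD"]
        - row["CONTRACTION_ARR_USD"]
        - row["CHURNED_ARR_USD"]
    )
    return abs(row["ENDING_ARR_USD"] - expected) <= tolerance

row = {
    "STARTING_ARR_USD": 1_200_000.0,
    "NEW_LOGO_ARR_USD": 150_000.0,
    "EXPANSION_ARR_USD": 60_000.0,
    "CONTRACTION_ARR_USD": 25_000.0,
    "CHURNED_ARR_USD": 85_000.0,
    "ENDING_ARR_USD": 1_300_000.0,
}
print(check_arr_bridge(row))  # True: 1.2M + 150K + 60K - 25K - 85K = 1.3M
```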
     return (None, None)

     user_input_lower = user_input.strip().lower()
+
+    for phrase, routed in sorted(ROUTED_USE_CASES.items(), key=lambda item: len(item[0]), reverse=True):
+        if phrase in user_input_lower:
+            return routed
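The routed-phrase lookup above matches against `ROUTED_USE_CASES` longest phrase first, so "saas finance and unit economics" wins before its substring "saas finance" can match. A standalone sketch of the same idea (the two-entry table here is abridged for illustration):

```python
# Abridged routing table; the real ROUTED_USE_CASES maps more phrases.
ROUTED_USE_CASES = {
    "saas finance and unit economics": ("Software", "Finance"),
    "saas finance": ("Software", "Finance"),
}

def route(user_input: str):
    """Return the routed (vertical, function) pair, testing longer phrases first."""
    text = user_input.strip().lower()
    for phrase, routed in sorted(
        ROUTED_USE_CASES.items(), key=lambda item: len(item[0]), reverse=True
    ):
        if phrase in text:
            return routed
    return (None, None)

print(route("Build a SaaS Finance and Unit Economics demo"))  # ('Software', 'Finance')
print(route("retail supply chain"))                           # (None, None)
```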

     # Try to find both vertical and function in the input
     found_vertical = None

         "viz_types": f.get("viz_types", []).copy(),
         "outlier_categories": f.get("outlier_categories", []).copy(),
         "spotter_templates": f.get("spotter_templates", []).copy(),
+        "story_controls": _merge_story_controls(
+            v.get("story_controls", {}),
+            f.get("story_controls", {}),
+            override.get("story_controls", {}),
+        ),
+        "schema_contract": deepcopy(override.get("schema_contract", {})),
+        # Override takes precedence; else use function-level questions; else empty
+        "liveboard_questions": override.get("liveboard_questions") or f.get("liveboard_questions", []),

         # Flags
         "is_generic": False,

         config["target_persona"] = override["target_persona"]
     if override.get("business_problem"):
         config["business_problem"] = override["business_problem"]
+    if override.get("use_case_name"):
+        config["use_case_name"] = override["use_case_name"]
+    if override.get("canonical_use_case"):
+        config["canonical_use_case"] = override["canonical_use_case"]
+
+    # Story guardrails derived from matrix controls
+    config["allowed_story_dimensions"] = (
+        config["story_controls"].get("allowed_dimensions")
+        or f.get("story_controls", {}).get("allowed_dimensions")
+        or config["entities"][:]
+    )
+    config["allowed_story_measures"] = (
+        config["story_controls"].get("allowed_measures")
+        or f.get("story_controls", {}).get("allowed_measures")
+        or config["kpis"][:]
+    )

     # Handle generic cases
     if not is_known_vertical and not is_known_function:

         config["is_generic"] = True
         config["ai_should_determine"] = ["entities", "industry_terms", "kpis", "viz_types", "outliers"]
         config["prompt_user_for"] = ["key_metrics", "target_persona", "business_questions"]
+        config["story_controls"] = _deep_merge_dict(
+            config["story_controls"],
+            {
+                "outlier_budget": {"max_events": 1},
+                "trend_noise_band_pct": 0.03,
+            },
+        )
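`_deep_merge_dict` is referenced but not shown in this diff; a plausible minimal implementation consistent with this call site (nested dicts merged recursively, with the second argument winning on conflicts) might look like:

```python
def _deep_merge_dict(base: dict, override: dict) -> dict:
    """Recursively merge override into a copy of base; override wins on conflicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = _deep_merge_dict(merged[key], value)
        else:
            merged[key] = value
    return merged

controls = _deep_merge_dict(
    {"outlier_budget": {"max_events": 3, "max_points_pct": 0.025}},
    {"outlier_budget": {"max_events": 1}, "trend_noise_band_pct": 0.03},
)
# outlier_budget keeps max_points_pct from base while max_events drops to 1
```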
     elif not is_known_vertical:
         # Known function, unknown vertical
         config["is_generic"] = True

         config["cost_impact"] = "Significant business impact through data-driven decisions"
     if "success_outcomes" not in config:
         config["success_outcomes"] = f"Improved {function.lower()} performance and faster insights"
+    if "recommended_companies" not in config:
+        config["recommended_companies"] = config["story_controls"].get("recommended_companies", [])

     return config
 
legitdata_project/README.md ADDED
@@ -0,0 +1,265 @@
+# LegitData
+
+> Generate realistic synthetic data for analytics warehouses - not dummy data.
+
+LegitData uses AI to generate contextually appropriate, realistic data for your dimension and fact tables. Instead of random strings and numbers from Faker, you get real product names, actual company names, and coherent data that tells a story.
+
+## Features
+
+- **Smart Column Classification**: Automatically determines how to source data for each column
+  - `SEARCH_REAL`: Web search for real products, brands, companies
+  - `AI_GEN`: AI-generated contextual data (regions, categories, segments)
+  - `GENERIC`: Calculated values (dates, amounts, IDs)
+- **Company-Aware Context**: Scrapes your company URL to understand your business and generate appropriate data
+- **FK Integrity**: Maintains referential integrity with realistic distributions (Pareto for fact tables)
+- **Size Presets**: Small, Medium, Large, XL presets for different use cases
+- **Caching**: Caches search results and generated values for consistency and speed
+
+## Installation
+
+```bash
+# Basic installation
+pip install legitdata
+
+# With Snowflake support
+pip install legitdata[snowflake]
+
+# With AI support (recommended)
+pip install legitdata[ai]
+
+# Everything
+pip install legitdata[all]
+```
+
+## Quick Start
+
+### Python Library
+
+```python
+from legitdata import LegitGenerator
+from anthropic import Anthropic
+
+# Initialize with your context
+gen = LegitGenerator(
+    url="https://amazon.com",
+    use_case="Retail Analytics",
+    connection_string="snowflake://user:pass@account/database/schema",
+    anthropic_client=Anthropic()  # Uses ANTHROPIC_API_KEY env var
+)
+
+# Load your schema
+gen.load_ddl("schema.sql")
+
+# Generate and insert data
+results = gen.generate(size="medium")
+# {'CUSTOMERS': 100, 'PRODUCTS': 100, 'SELLERS': 100, 'SALES_TRANSACTIONS': 1000}
+```
+
+### CLI
+
+```bash
+# Generate data
+legitdata generate \
+  --ddl schema.sql \
+  --url https://amazon.com \
+  --use-case "Retail Analytics" \
+  --connection "snowflake://user:pass@account/db/schema" \
+  --size medium
+
+# Preview without writing
+legitdata preview \
+  --ddl schema.sql \
+  --url https://amazon.com \
+  --use-case "Retail Analytics" \
+  --rows 5
+
+# Manage cache
+legitdata cache --stats
+legitdata cache --clear
+```
+
+## Size Presets
+
+| Size | Fact Rows | Dim Rows | Use Case |
+|------|-----------|----------|----------|
+| small | 100 | 20 | Quick testing |
+| medium | 1,000 | 100 | Demo/dev |
+| large | 10,000 | 500 | Realistic workload |
+| xl | 100,000 | 500 | Performance testing |
+
+Or specify exact counts:
+
+```python
+gen.generate(row_counts={
+    "CUSTOMERS": 200,
+    "PRODUCTS": 50,
+    "SALES_TRANSACTIONS": 5000
+})
+```
+
+## Use Cases
+
+Built-in use cases that help guide data generation:
+
+- Sales Analytics
+- Supply Chain Analytics
+- Customer Analytics
+- Financial Analytics
+- Marketing Analytics
+- Retail Analytics
+
+Or use any custom use case:
+
+```python
+gen = LegitGenerator(
+    url="https://example.com",
+    use_case="Call Center Sentiment Analysis",  # Custom!
+    ...
+)
+```
+
+## How It Works
+
+### 1. Context Building
+LegitData scrapes your company URL to understand:
+- Company name and industry
+- Products/services offered
+- Target customers
+- Geographic focus
+
+### 2. Column Classification
+Each column is analyzed and classified:
+
+| Column Type | Classification | Example |
+|-------------|----------------|---------|
+| ProductName | SEARCH_REAL | Web search for real products |
+| Brand | SEARCH_REAL | Web search for real brands |
+| Region | AI_GEN | AI generates realistic regions |
+| CustomerSegment | AI_GEN | AI generates "Enterprise", "SMB", etc. |
+| OrderDate | GENERIC | Random date in range |
+| Revenue | GENERIC | Calculated decimal |
+| CustomerID | GENERIC | FK reference |
+
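A toy version of the name-based heuristics in the table above; the patterns and return labels are illustrative, not LegitData's actual classification rules:

```python
import re

SEARCH_REAL, AI_GEN, GENERIC = "SEARCH_REAL", "AI_GEN", "GENERIC"

def classify(column_name: str) -> str:
    """Toy name-based classifier mirroring the table above."""
    name = column_name.lower()
    # Real-world entities are looked up via web search
    if re.search(r"(product|brand|company)_?name$", name) or name == "brand":
        return SEARCH_REAL
    # Contextual business categoricals come from the model
    if name in ("region", "customer_segment", "customersegment", "segment"):
        return AI_GEN
    # Dates, amounts, and IDs fall through to Faker/calculated values
    return GENERIC

print(classify("ProductName"))   # SEARCH_REAL
print(classify("Region"))        # AI_GEN
print(classify("CustomerID"))    # GENERIC
```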
+### 3. Data Generation
+- **SEARCH_REAL**: Web search → AI extracts real values
+- **AI_GEN**: AI generates contextually appropriate values
+- **GENERIC**: Faker/calculated values
+
+### 4. FK Relationships
+Fact tables use Pareto distribution (80/20) for realistic patterns:
+- Top 20% of customers generate 80% of orders
+- Popular products appear more frequently
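The 80/20 skew described above can be sketched as weighted foreign-key sampling. The helper below is an illustration, not LegitData's implementation: giving the top 20% of parents 16x weight hands them roughly 80% of the fact rows:

```python
import random

def pareto_fk_sample(parent_keys: list, n: int, seed: int = 0) -> list:
    """Sample n FK values so the top ~20% of parents receive ~80% of rows."""
    rng = random.Random(seed)
    cutoff = max(1, len(parent_keys) // 5)  # top 20% of parents
    # 16x weight: 20 keys hold 16*20 = 320 of 400 total weight = 80%
    weights = [16.0 if i < cutoff else 1.0 for i in range(len(parent_keys))]
    return rng.choices(parent_keys, weights=weights, k=n)

keys = list(range(100))
sample = pareto_fk_sample(keys, 10_000, seed=42)
top_share = sum(1 for k in sample if k < 20) / len(sample)
print(round(top_share, 2))  # close to 0.8
```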
+
+## Connection Strings
+
+### Snowflake
+```
+snowflake://user:password@account/database/schema
+snowflake://user:password@account/database/schema?warehouse=WH&role=ROLE
+```
+
+## Caching
+
+LegitData caches:
+- Company context (extracted from URL)
+- Web search results
+- Column classifications
+- Generated values
+
+Cache location: `.legitdata_cache/`
+
+```python
+# Disable caching
+gen = LegitGenerator(..., cache_enabled=False)
+
+# Clear cache
+from legitdata.cache import FileCache
+cache = FileCache()
+cache.clear()
+```
+
+## Preview Mode
+
+Test generation without writing to the database:
+
+```python
+# Preview 5 rows per table
+preview = gen.preview(num_rows=5)
+
+for table, rows in preview.items():
+    print(f"{table}: {len(rows)} rows")
+    print(rows[0])  # First row
+```
+
+Or use dry run:
+
+```python
+gen = LegitGenerator(..., dry_run=True)
+gen.generate()  # Prints operations but doesn't write
+```
+
+## API Reference
+
+### LegitGenerator
+
+```python
+LegitGenerator(
+    url: str,                      # Company website URL
+    use_case: str,                 # Analytics use case
+    connection_string: str,        # Database connection
+    anthropic_client=None,         # Anthropic client for AI
+    web_search_fn=None,            # Custom web search function
+    cache_enabled=True,            # Enable caching
+    cache_dir=".legitdata_cache",  # Cache directory
+    dry_run=False                  # Don't write to DB
+)
+```
+
+### Methods
+
+- `load_ddl(ddl_or_path)` - Load DDL from string or file
+- `generate(size, row_counts, truncate_first)` - Generate and insert data
+- `preview(table_name, num_rows)` - Preview without inserting
+
+## Extending
+
+### Custom Web Search
+
+```python
+def my_search(query: str) -> list[dict]:
+    # Your search implementation
+    return [{"title": "...", "snippet": "...", "url": "..."}]
+
+gen = LegitGenerator(
+    ...,
+    web_search_fn=my_search
+)
+```
+
+### Custom Database Writer
+
+```python
+from legitdata.writers import BaseWriter
+
+class MyWriter(BaseWriter):
+    def connect(self): ...
+    def disconnect(self): ...
+    def insert_rows(self, table, columns, rows, batch_size): ...
+    def truncate_table(self, table): ...
+    def table_exists(self, table): ...
+    def get_table_columns(self, table): ...
+```
+
+## Requirements
+
+- Python 3.10+
+- Snowflake Connector (for Snowflake support)
+- Anthropic SDK (for AI features)
+
+## License
+
+MIT
legitdata_project/legitdata/__init__.py CHANGED
@@ -18,6 +18,7 @@ from .config import SIZE_PRESETS, USE_CASES, GenerationConfig
 from .ddl import parse_ddl, parse_ddl_file, Schema, Table, Column, ColumnClassification
 from .analyzer import CompanyContext
 from .writers import SnowflakeWriter, DryRunWriter
+from .storyspec import StorySpec, build_story_spec, build_story_bundle

 __version__ = "0.1.0"

@@ -35,4 +36,7 @@ __all__ = [
     'CompanyContext',
     'SnowflakeWriter',
     'DryRunWriter',
+    'StorySpec',
+    'build_story_spec',
+    'build_story_bundle',
 ]
legitdata_project/legitdata/analyzer/column_classifier.py CHANGED
@@ -2,11 +2,20 @@

 import json
 import re
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Optional
 from ..ddl.models import Schema, Table, Column, ColumnClassification
 from ..config import COLUMN_HINTS
 from .context_builder import CompanyContext
+from ..domain import infer_semantic_type, is_business_categorical
+
+
+def _years_ago(dt: datetime, years: int) -> datetime:
+    """Return datetime shifted back by full years (leap-safe)."""
+    try:
+        return dt.replace(year=dt.year - years)
+    except ValueError:
+        return dt.replace(month=2, day=28, year=dt.year - years)


 class ColumnClassifier:
@@ -153,6 +162,7 @@ Return ONLY valid JSON, no other text."""
         """Classify a single column using heuristics."""
         col_name = column.name.lower()
         col_upper = column.name.upper()
+        semantic = infer_semantic_type(column.name, column.data_type, table.name, use_case)

         # Primary keys and identity columns
         if column.is_primary_key or column.is_identity:
@@ -161,6 +171,11 @@ Return ONLY valid JSON, no other text."""
         # Foreign keys
         if col_upper in fk_columns:
             return ColumnClassification.GENERIC, "foreign_key"
+
+        # Quality-first: keep business categorical columns contextual.
+        if is_business_categorical(semantic):
+            prompt = self._generate_ai_prompt(col_name, table.name, context, use_case)
+            return ColumnClassification.AI_GEN, prompt

         # Check patterns for SEARCH_REAL
         for pattern in COLUMN_HINTS["search_real_patterns"]:
@@ -220,7 +235,13 @@ Return ONLY valid JSON, no other text."""
         col_name = column.name.lower()
         data_type = column.data_type.upper()

-        # BOOLEAN COLUMNS - Check FIRST before phone (is_mobile should be boolean, not phone!)
+        # BOOLEAN COLUMNS - Check FIRST before phone/contact heuristics
+        # so names like MobileTransaction BOOLEAN don't get treated as
+        # phone numbers just because they contain "mobile".
+        if 'BOOLEAN' in data_type or data_type in ('BOOL', 'BIT'):
+            return "boolean:0.8"
+
+        # BOOLEAN NAME PATTERNS - still keep these ahead of contact fields.
         if col_name.startswith('is_') or col_name.startswith('has_') or col_name.endswith('_flag'):
             return "boolean:0.5"

@@ -231,14 +252,27 @@ Return ONLY valid JSON, no other text."""
         # Phone columns (but NOT is_mobile - that's boolean, checked above)
         if 'phone' in col_name or 'mobile' in col_name or 'tel' in col_name:
             return "phone"
+
+        # Identifier-like short codes should not be treated as monetary values.
+        if (
+            'tax_id' in col_name
+            or 'taxid' in col_name
+            or 'last4' in col_name
+            or col_name.endswith('_id_last4')
+        ):
+            if any(t in data_type for t in ('CHAR', 'TEXT', 'STRING', 'VARCHAR')):
+                return "random_string:4"
+            return "random_int:1000,9999"

         # Name columns
         if col_name in ('first_name', 'firstname', 'fname', 'given_name'):
             return "first_name"
         if col_name in ('last_name', 'lastname', 'lname', 'surname'):
             return "last_name"
-        if col_name in ('full_name', 'fullname', 'name', 'customer_name', 'user_name'):
+        if col_name in ('full_name', 'fullname', 'name', 'user_name'):
             return "name"
+        if col_name in ('customer_name', 'account_name', 'client_name', 'company_name', 'organization_name'):
+            return "company"

         # Address columns
         if 'address' in col_name and 'email' not in col_name:
@@ -252,9 +286,46 @@ Return ONLY valid JSON, no other text."""
         if 'zip' in col_name or 'postal' in col_name:
             return "zipcode"

-        # Date columns - use current date for end range
-        today = datetime.now().strftime('%Y-%m-%d')
-        if 'date' in col_name or data_type in ('DATE', 'TIMESTAMP', 'DATETIME'):
+        # Birth date columns - must be evaluated before generic date checks
+        # to avoid impossible ages from generic recent-date generation.
+        today_dt = datetime.now()
+        today = today_dt.strftime('%Y-%m-%d')
+        is_birth_col = (
+            'birthdate' in col_name or
+            'birth_date' in col_name or
+            'date_of_birth' in col_name or
+            col_name == 'dob' or
+            col_name.endswith('_dob')
+        )
+        if is_birth_col:
+            normalized_today = today_dt.replace(hour=0, minute=0, second=0, microsecond=0)
+            oldest = _years_ago(normalized_today, 95).strftime('%Y-%m-%d')
+            youngest = _years_ago(normalized_today, 18).strftime('%Y-%m-%d')
+            return f"date_between:{oldest},{youngest}"
+
+        # Date columns - use current date for end range.
+        # Guard against financial names like *_AMOUNT_TO_DATE on numeric fields.
+        financial_to_date_name = (
+            ('to_date' in col_name or col_name.endswith('todate'))
+            and any(
+                token in col_name
+                for token in (
+                    'amount',
+                    'balance',
+                    'cost',
+                    'price',
+                    'revenue',
+                    'sales',
+                    'fee',
+                    'tax',
+                    'total',
+                    'recovery',
+                    'payment',
+                )
+            )
+        )
+        is_name_date = ('date' in col_name) and not financial_to_date_name
+        if is_name_date or data_type in ('DATE', 'TIMESTAMP', 'DATETIME'):
             if 'created' in col_name or 'registration' in col_name:
                 return f"date_between:2020-01-01,{today}"
             elif 'launch' in col_name or 'open' in col_name:
@@ -269,8 +340,16 @@ Return ONLY valid JSON, no other text."""
         # Numeric columns
         if 'quantity' in col_name:
             return "random_int:1,100"
-        elif 'revenue' in col_name or 'amount' in col_name:
+        elif 'discount' in col_name:
+            return "random_decimal:0.00,200.00"
+        elif 'shipping' in col_name:
+            return "random_decimal:2.00,150.00"
+        elif 'tax' in col_name:
+            return "random_decimal:1.00,250.00"
+        elif 'revenue' in col_name:
             return "random_decimal:10.00,10000.00"
+        elif 'amount' in col_name:
+            return "random_decimal:5.00,5000.00"
         elif 'cost' in col_name or 'fee' in col_name or 'price' in col_name:
             return "random_decimal:1.00,500.00"
         elif 'rating' in col_name:
@@ -381,15 +460,43 @@ Return ONLY valid JSON, no other text."""
         These are columns where Faker produces better results than AI generation.
         """
         col_lower = col_name.lower()
+        semantic = infer_semantic_type(col_name, data_type)
+
+        # Never force business categorical columns to generic;
+        # these need domain-aware contextual generation.
+        if is_business_categorical(semantic):
+            return None

         # DEBUG: Log all checks
         print(f"   [FORCE_CHECK] Column: '{col_name}' (lower: '{col_lower}')")

-        # BOOLEAN COLUMNS - Check FIRST before phone (is_mobile should be boolean, not phone)
+        # BOOLEAN COLUMNS - Check FIRST before phone/contact heuristics so
+        # boolean names containing "mobile" don't become phone numbers.
+        if data_type:
+            dt_upper = data_type.upper()
+            if 'BOOLEAN' in dt_upper or dt_upper in ('BOOL', 'BIT'):
+                print(f"   → FORCED to 'boolean:0.5' (boolean data type)")
+                return "boolean:0.5"
+
+        # BOOLEAN NAME PATTERNS
         if col_lower.startswith('is_') or col_lower.startswith('has_') or col_lower.endswith('_flag'):
             print(f"   → FORCED to 'boolean:0.5' (boolean pattern)")
             return "boolean:0.5"

+        # Birth date columns must always produce realistic adult ages.
+        is_birth_col = (
+            'birthdate' in col_lower or
+            'birth_date' in col_lower or
+            'date_of_birth' in col_lower or
+            col_lower == 'dob' or
+            col_lower.endswith('_dob')
+        )
+        if is_birth_col:
+            today_dt = datetime.now()
+            oldest = _years_ago(today_dt, 95).strftime('%Y-%m-%d')
+            youngest = _years_ago(today_dt, 18).strftime('%Y-%m-%d')
+            return f"date_between:{oldest},{youngest}"
+
         # Check if data type is numeric
         is_numeric = False
         if data_type:
@@ -439,20 +546,27 @@ Return ONLY valid JSON, no other text."""
             return "last_name"

         # Full name - any column that's clearly a person's name
-        if col_lower in ('full_name', 'fullname', 'name', 'customer_name', 'user_name', 'username'):
+        if col_lower in ('full_name', 'fullname', 'name', 'user_name', 'username'):
             return "name"
+        if col_lower in ('customer_name', 'account_name', 'client_name', 'company_name', 'organization_name'):
+            return "company"

         # Person name columns that don't match exact list above
         # rep_name, manager_name, agent_name, employee_name, contact_name, etc.
         if col_lower.endswith('_name') or col_lower.endswith('name'):
             # Exclude non-person names (these are handled elsewhere as SEARCH_REAL or AI_GEN)
             non_person = ('product_name', 'productname', 'company_name', 'companyname',
+                          'account_name', 'accountname', 'customer_name', 'customername',
+                          'client_name', 'clientname', 'organization_name', 'organizationname',
                           'brand_name', 'brandname', 'store_name', 'storename',
                           'warehouse_name', 'warehousename', 'center_name', 'centername',
                           'campaign_name', 'campaignname', 'table_name', 'tablename',
                           'column_name', 'columnname', 'schema_name', 'schemaname',
                           'file_name', 'filename', 'host_name', 'hostname',
-                          'category_name', 'categoryname', 'holiday_name', 'holidayname')
+                          'category_name', 'categoryname', 'holiday_name', 'holidayname',
+                          'branch_name', 'branchname', 'department_name', 'departmentname',
+                          'region_name', 'regionname', 'city_name', 'cityname',
+                          'month_name', 'monthname', 'quarter_name', 'quartername', 'day_name', 'dayname')
             if col_lower not in non_person:
                 print(f"   → FORCED to 'name' (person name pattern: {col_lower})")
                 return "name"
@@ -481,7 +595,11 @@ Return ONLY valid JSON, no other text."""
             return "country"

         # URL/Website
-        if 'url' in col_lower or 'website' in col_lower:
+        # Avoid false positives like "hourly" containing "url".
+        if (
+            re.search(r'(^|[_\-])url([_\-]|$)', col_lower)
+            or 'website' in col_lower
+        ):
             return "url"

         return None
@@ -63,6 +63,9 @@ class DDLParser:
63
  table.columns.append(column)
64
  if column.is_primary_key:
65
  table.primary_key = column.name
 
 
 
66
 
67
  return table
68
 
@@ -116,14 +119,30 @@ class DDLParser:
116
  table.foreign_keys.append(fk)
117
  return
118
 
119
- # Primary key constraint: PRIMARY KEY (col)
120
- pk_match = re.search(r'PRIMARY\s+KEY\s*\((\w+)\)', part, re.IGNORECASE)
121
  if pk_match:
122
  table.primary_key = pk_match.group(1)
123
  # Mark the column as PK
124
  col = table.get_column(pk_match.group(1))
125
  if col:
126
  col.is_primary_key = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
  def _parse_column(self, part: str) -> Optional[Column]:
129
  """Parse a column definition."""
 
63
  table.columns.append(column)
64
  if column.is_primary_key:
65
  table.primary_key = column.name
66
+ inline_fk = self._parse_inline_foreign_key(part, column.name)
67
+ if inline_fk:
68
+ table.foreign_keys.append(inline_fk)
69
 
70
  return table
71
 
 
119
  table.foreign_keys.append(fk)
120
  return
121
 
122
+ # Primary key constraint: PRIMARY KEY (col) or PRIMARY KEY (col1, col2, ...)
123
+ pk_match = re.search(r'PRIMARY\s+KEY\s*\((\w+)', part, re.IGNORECASE)
124
  if pk_match:
125
  table.primary_key = pk_match.group(1)
126
  # Mark the column as PK
127
  col = table.get_column(pk_match.group(1))
128
  if col:
129
  col.is_primary_key = True
130
+
131
+ def _parse_inline_foreign_key(self, part: str, column_name: str) -> Optional[ForeignKey]:
132
+ """Parse inline FK syntax like `col INT REFERENCES parent(id)`."""
133
+ fk_match = re.search(
134
+ r'\bREFERENCES\s+(\w+)\s*\((\w+)\)',
135
+ part,
136
+ re.IGNORECASE
137
+ )
138
+ if not fk_match:
139
+ return None
140
+
141
+ return ForeignKey(
142
+ column_name=column_name,
143
+ references_table=fk_match.group(1),
144
+ references_column=fk_match.group(2)
145
+ )
146
 
147
  def _parse_column(self, part: str) -> Optional[Column]:
148
  """Parse a column definition."""
legitdata_project/legitdata/domain/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
1
+ """Domain-aware semantic typing and value packs."""
2
+
3
+ from .semantic_types import (
4
+ SemanticType,
5
+ BUSINESS_CATEGORICAL_TYPES,
6
+ NUMERIC_TYPES,
7
+ infer_semantic_type,
8
+ is_business_categorical,
9
+ )
10
+ from .domain_packs import get_domain_values, infer_vertical_key, map_state_to_region
11
+
12
+ __all__ = [
13
+ "SemanticType",
14
+ "BUSINESS_CATEGORICAL_TYPES",
15
+ "NUMERIC_TYPES",
16
+ "infer_semantic_type",
17
+ "is_business_categorical",
18
+ "get_domain_values",
19
+ "infer_vertical_key",
20
+ "map_state_to_region",
21
+ ]
22
+
legitdata_project/legitdata/domain/domain_packs.py ADDED
@@ -0,0 +1,543 @@
 
 
1
+ """Curated domain value packs for quality-first data generation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .semantic_types import SemanticType
6
+
7
+
8
+ BASE_DOMAIN_VALUES = {
9
+ SemanticType.REGION: [
10
+ "Northeast",
11
+ "Southeast",
12
+ "Midwest",
13
+ "Southwest",
14
+ "West",
15
+ ],
16
+ SemanticType.CATEGORY: [
17
+ "Electronics",
18
+ "Apparel",
19
+ "Home Goods",
20
+ "Sports & Outdoors",
21
+ "Health & Beauty",
22
+ "Office Supplies",
23
+ "Food & Grocery",
24
+ ],
25
+ SemanticType.SECTOR_NAME: [
26
+ "Technology",
27
+ "Healthcare",
28
+ "Consumer Discretionary",
29
+ "Consumer Staples",
30
+ "Industrials",
31
+ "Financial Services",
32
+ "Business Services",
33
+ ],
34
+ SemanticType.SECTOR_CATEGORY: ["Growth", "Defensive", "Cyclical", "Core"],
35
+ SemanticType.SEGMENT: ["Enterprise", "Mid-Market", "SMB", "Consumer", "Government"],
36
+ SemanticType.TIER: ["Platinum", "Gold", "Silver", "Bronze"],
37
+ SemanticType.STATUS: ["Active", "Pending", "Completed", "Cancelled", "On Hold"],
38
+ SemanticType.FUND_STRATEGY: [
39
+ "Buyout",
40
+ "Growth Equity",
41
+ "Venture Capital",
42
+ "Private Credit",
43
+ "Infrastructure",
44
+ "Secondaries",
45
+ ],
46
+ SemanticType.INVESTOR_TYPE: [
47
+ "Pension Fund",
48
+ "Endowment",
49
+ "Family Office",
50
+ "Sovereign Wealth",
51
+ "Insurance",
52
+ ],
53
+ SemanticType.INVESTMENT_STAGE: ["Active", "Realized", "Written Off"],
54
+ SemanticType.COVENANT_STATUS: ["Compliant", "Waived", "Breached"],
55
+ SemanticType.DEBT_PERFORMANCE_STATUS: ["Performing", "Watch List", "Non-Performing"],
56
+ SemanticType.CHANNEL: ["Online", "In-Store", "Mobile App", "Partner", "Marketplace"],
57
+ SemanticType.COUNTRY: ["United States", "Canada", "United Kingdom"],
58
+ SemanticType.STATE: [
59
+ "California",
60
+ "Texas",
61
+ "New York",
62
+ "Florida",
63
+ "Illinois",
64
+ "Washington",
65
+ "Georgia",
66
+ "Massachusetts",
67
+ ],
68
+ SemanticType.CITY: [
69
+ "Los Angeles",
70
+ "San Francisco",
71
+ "Dallas",
72
+ "Houston",
73
+ "New York City",
74
+ "Chicago",
75
+ "Seattle",
76
+ "Atlanta",
77
+ "Boston",
78
+ "Miami",
79
+ ],
80
+ SemanticType.POSTAL_CODE: [
81
+ "90001",
82
+ "94102",
83
+ "75201",
84
+ "77001",
85
+ "10001",
86
+ "60601",
87
+ "98101",
88
+ "30301",
89
+ "02101",
90
+ "33101",
91
+ ],
92
+ SemanticType.SEASON: ["Winter", "Spring", "Summer", "Fall"],
93
+ SemanticType.HOLIDAY_NAME: [
94
+ "New Year's Day",
95
+ "Memorial Day",
96
+ "Independence Day",
97
+ "Labor Day",
98
+ "Thanksgiving",
99
+ "Christmas Day",
100
+ "None",
101
+ ],
102
+ SemanticType.EVENT_NAME: [
103
+ "Regular Trading Day",
104
+ "Back to School",
105
+ "Holiday Promotions",
106
+ "Black Friday",
107
+ "Cyber Monday",
108
+ "Year End Campaign",
109
+ ],
110
+ SemanticType.PRODUCT_NAME: [
111
+ "Wireless Earbuds Pro",
112
+ "4K Smart TV 55in",
113
+ "Running Shoes Elite",
114
+ "Organic Cotton T-Shirt",
115
+ "Stainless Water Bottle 32oz",
116
+ "Laptop Backpack 15in",
117
+ "Bluetooth Soundbar",
118
+ "Air Fryer 6 Quart",
119
+ "Portable Charger 20000mAh",
120
+ "Smart Home Camera",
121
+ ],
122
+ SemanticType.BRAND_NAME: [
123
+ "Apex",
124
+ "Northline",
125
+ "EverPeak",
126
+ "BlueHarbor",
127
+ "Lumen",
128
+ "SummitCo",
129
+ "UrbanLeaf",
130
+ "PrimeWorks",
131
+ ],
132
+ SemanticType.BRANCH_NAME: [
133
+ "Downtown Branch",
134
+ "Midtown Branch",
135
+ "Northside Branch",
136
+ "Southside Branch",
137
+ "Riverside Branch",
138
+ "West End Branch",
139
+ "Lakeside Branch",
140
+ "Hillcrest Branch",
141
+ ],
142
+ SemanticType.ORG_NAME: [
143
+ "Apex Advisory Group",
144
+ "Northline Holdings",
145
+ "BlueHarbor Partners",
146
+ "PrimeWorks Consulting",
147
+ "Summit Operations Group",
148
+ "UrbanLeaf Services",
149
+ "Lumen Strategy Co",
150
+ "EverPeak Solutions",
151
+ ],
152
+ SemanticType.PERSON_NAME: [
153
+ "Olivia Carter",
154
+ "Liam Brooks",
155
+ "Ava Morgan",
156
+ "Noah Bennett",
157
+ "Emma Collins",
158
+ "Ethan Cooper",
159
+ "Sophia Reed",
160
+ "Mason Turner",
161
+ "Isabella Hayes",
162
+ "Lucas Ward",
163
+ ],
164
+ SemanticType.DEPARTMENT_NAME: [
165
+ "Operations",
166
+ "Finance",
167
+ "Legal",
168
+ "Sales",
169
+ "Marketing",
170
+ "Customer Success",
171
+ "Product",
172
+ "Engineering",
173
+ ],
174
+ SemanticType.COLOR_DESCRIPTION: [
175
+ "Black",
176
+ "White",
177
+ "Navy Blue",
178
+ "Charcoal Gray",
179
+ "Forest Green",
180
+ "Red",
181
+ "Beige",
182
+ ],
183
+ SemanticType.PACKAGING_SIZE: [
184
+ "Single Pack",
185
+ "2-Pack",
186
+ "4-Pack",
187
+ "8-Pack",
188
+ "16 oz",
189
+ "32 oz",
190
+ "64 oz",
191
+ ],
192
+ }
193
+
194
+
195
+ VERTICAL_OVERRIDES = {
196
+ "healthcare": {
197
+ SemanticType.CATEGORY: [
198
+ "Primary Care",
199
+ "Cardiology",
200
+ "Oncology",
201
+ "Orthopedics",
202
+ "Neurology",
203
+ "Emergency Services",
204
+ "Pediatrics",
205
+ ],
206
+ SemanticType.SEGMENT: ["Inpatient", "Outpatient", "Emergency", "Telehealth"],
207
+ },
208
+ "banking": {
209
+ SemanticType.CATEGORY: [
210
+ "Checking",
211
+ "Savings",
212
+ "Credit Card",
213
+ "Mortgage",
214
+ "Auto Loan",
215
+ "Investment",
216
+ ],
217
+ SemanticType.SEGMENT: ["Consumer", "Small Business", "Commercial", "Private Banking"],
218
+ SemanticType.CHANNEL: ["Branch", "Online Banking", "Mobile Banking", "Call Center", "ATM"],
219
+ SemanticType.BRANCH_NAME: [
220
+ "Main Street Branch",
221
+ "Riverside Branch",
222
+ "Oak Grove Branch",
223
+ "Lakeside Branch",
224
+ "Midtown Branch",
225
+ "Heritage Branch",
226
+ ],
227
+ SemanticType.DEPARTMENT_NAME: [
228
+ "Retail Banking",
229
+ "Consumer Lending",
230
+ "Mortgage Services",
231
+ "Wealth Management",
232
+ "Commercial Banking",
233
+ "Collections",
234
+ ],
235
+ },
236
+ "retail": {
237
+ SemanticType.CATEGORY: [
238
+ "Electronics",
239
+ "Apparel",
240
+ "Home & Kitchen",
241
+ "Beauty",
242
+ "Sports & Outdoors",
243
+ "Toys",
244
+ "Grocery",
245
+ ],
246
+ SemanticType.CHANNEL: ["E-commerce", "Store", "Marketplace", "Mobile App"],
247
+ SemanticType.BRAND_NAME: ["Nike", "Adidas", "Apple", "Samsung", "Sony", "Levi's", "KitchenAid", "Dyson"],
248
+ SemanticType.DEPARTMENT_NAME: ["Footwear", "Apparel", "Consumer Electronics", "Home Appliances", "Beauty", "Outdoor"],
249
+ },
250
+ "sportswear": {
251
+ SemanticType.CATEGORY: [
252
+ "Footwear",
253
+ "Apparel",
254
+ "Running",
255
+ "Training",
256
+ "Basketball",
257
+ "Soccer",
258
+ "Lifestyle",
259
+ ],
260
+ SemanticType.SEGMENT: [
261
+ "Performance Athlete",
262
+ "Everyday Active",
263
+ "Team Sports",
264
+ "Youth",
265
+ "Women",
266
+ "Men",
267
+ ],
268
+ SemanticType.CHANNEL: [
269
+ "Nike App",
270
+ "SNKRS",
271
+ "Nike.com",
272
+ "Brand Store",
273
+ "Outlet",
274
+ "Wholesale Partner",
275
+ ],
276
+ SemanticType.BRAND_NAME: [
277
+ "Nike",
278
+ "Jordan",
279
+ "Converse",
280
+ "Nike SB",
281
+ "ACG",
282
+ "Nike Pro",
283
+ ],
284
+ SemanticType.PRODUCT_NAME: [
285
+ "Air Zoom Pegasus",
286
+ "Air Force 1",
287
+ "Dunk Low",
288
+ "Metcon Trainer",
289
+ "Mercurial Vapor",
290
+ "Phantom GX",
291
+ "Dri-FIT Tee",
292
+ "Tech Fleece Hoodie",
293
+ ],
294
+ SemanticType.DEPARTMENT_NAME: [
295
+ "Running",
296
+ "Basketball",
297
+ "Training",
298
+ "Sportswear",
299
+ "Lifestyle",
300
+ "Accessories",
301
+ ],
302
+ SemanticType.BRANCH_NAME: [
303
+ "5th Ave Flagship",
304
+ "SoHo Experience Store",
305
+ "Melrose Brand House",
306
+ "Chicago Magnificent Mile",
307
+ "Miami Beach Store",
308
+ "Seattle Downtown Store",
309
+ ],
310
+ },
311
+ "pharma": {
312
+ SemanticType.CATEGORY: [
313
+ "Cardiometabolic",
314
+ "Oncology",
315
+ "Women's Health",
316
+ "Immunology",
317
+ "Primary Care",
318
+ "Rare Disease",
319
+ "Vaccine",
320
+ ],
321
+ SemanticType.SEGMENT: [
322
+ "High Prescriber",
323
+ "New Writer",
324
+ "Existing Writer",
325
+ "Hospital Network",
326
+ "Community Practice",
327
+ "Specialty Clinic",
328
+ ],
329
+ SemanticType.CHANNEL: [
330
+ "Field Sales",
331
+ "Digital HCP",
332
+ "Medical Affairs",
333
+ "Patient Support",
334
+ "Specialty Pharmacy",
335
+ "Wholesaler",
336
+ ],
337
+ SemanticType.BRAND_NAME: [
338
+ "Bayer",
339
+ "Xarelto",
340
+ "Eylea",
341
+ "Mirena",
342
+ "Yaz",
343
+ "Nexavar",
344
+ ],
345
+ SemanticType.PRODUCT_NAME: [
346
+ "Rivaroxaban 20mg",
347
+ "Aflibercept 2mg",
348
+ "Levonorgestrel IUS",
349
+ "Drospirenone Ethinyl Estradiol",
350
+ "Sorafenib 200mg",
351
+ "Aspirin 81mg",
352
+ "Canagliflozin 100mg",
353
+ "Rosuvastatin 20mg",
354
+ ],
355
+ SemanticType.DEPARTMENT_NAME: [
356
+ "Cardiology Franchise",
357
+ "Oncology Franchise",
358
+ "Women's Health Franchise",
359
+ "Medical Affairs",
360
+ "Market Access",
361
+ "Patient Services",
362
+ ],
363
+ },
364
+ "supply chain": {
365
+ SemanticType.CATEGORY: [
366
+ "Raw Materials",
367
+ "Finished Goods",
368
+ "Packaging",
369
+ "MRO",
370
+ "Transportation",
371
+ ],
372
+ SemanticType.STATUS: ["Planned", "In Transit", "Received", "Delayed", "Backordered"],
373
+ SemanticType.CHANNEL: ["Direct", "Distributor", "3PL", "Wholesale"],
374
+ },
375
+ "private_equity": {
376
+ SemanticType.SECTOR_NAME: [
377
+ "Industrials",
378
+ "Healthcare",
379
+ "Technology",
380
+ "Consumer",
381
+ "Financial Services",
382
+ ],
383
+ SemanticType.SECTOR_CATEGORY: ["Cyclical", "Defensive", "Growth"],
384
+ SemanticType.FUND_STRATEGY: ["Buyout", "Growth", "Venture"],
385
+ SemanticType.INVESTOR_TYPE: [
386
+ "Pension Fund",
387
+ "Endowment",
388
+ "Family Office",
389
+ "Sovereign Wealth",
390
+ "Insurance",
391
+ ],
392
+ SemanticType.INVESTMENT_STAGE: ["Active", "Realized", "Written Off"],
393
+ SemanticType.COVENANT_STATUS: ["Compliant", "Waived", "Breached"],
394
+ SemanticType.DEBT_PERFORMANCE_STATUS: ["Performing", "Watch List", "Non-Performing"],
395
+ },
396
+ "legal": {
397
+ SemanticType.CATEGORY: [
398
+ "Commercial",
399
+ "Litigation",
400
+ "Intellectual Property",
401
+ "Employment",
402
+ "Corporate Governance",
403
+ "Regulatory",
404
+ "Privacy",
405
+ "Compliance",
406
+ ],
407
+ SemanticType.DEPARTMENT_NAME: [
408
+ "Commercial Legal",
409
+ "Intellectual Property",
410
+ "Employment Law",
411
+ "Corporate Legal",
412
+ "Regulatory Affairs",
413
+ "Privacy & Compliance",
414
+ "Litigation",
415
+ ],
416
+ SemanticType.ORG_NAME: [
417
+ "Cooley LLP",
418
+ "Baker McKenzie",
419
+ "McDermott Will & Emery",
420
+ "Young Basile",
421
+ "Newman Du Wors",
422
+ "Morgan Lewis",
423
+ "Latham & Watkins",
424
+ "Wilson Sonsini",
425
+ ],
426
+ },
427
+ "saas": {
428
+ SemanticType.CATEGORY: [
429
+ "Core Platform",
430
+ "AI Analyst",
431
+ "Analytics",
432
+ "Security",
433
+ "Observability",
434
+ "Integrations",
435
+ "Embedded Analytics",
436
+ "Platform Add-ons",
437
+ "Support",
438
+ ],
439
+ SemanticType.SEGMENT: [
440
+ "Enterprise",
441
+ "Mid-Market",
442
+ "SMB",
443
+ ],
444
+ SemanticType.REGION: [
445
+ "North America",
446
+ "EMEA",
447
+ "APAC",
448
+ ],
449
+ SemanticType.CHANNEL: [
450
+ "Direct Sales",
451
+ "Partner",
452
+ "Digital Self-Serve",
453
+ "Expansion",
454
+ ],
455
+ SemanticType.DEPARTMENT_NAME: [
456
+ "Finance",
457
+ "FP&A",
458
+ "Revenue Operations",
459
+ "Customer Success",
460
+ "Sales",
461
+ "Product",
462
+ "Engineering",
463
+ ],
464
+ SemanticType.ORG_NAME: [
465
+ "Enterprise Customers",
466
+ "Mid-Market Customers",
467
+ "SMB Customers",
468
+ "Strategic Accounts",
469
+ "Channel Partners",
470
+ "Global Accounts",
471
+ ],
472
+ },
473
+ }
474
+
475
+
476
+ STATE_TO_REGION = {
477
+ "CALIFORNIA": "West",
478
+ "WASHINGTON": "West",
479
+ "OREGON": "West",
480
+ "TEXAS": "Southwest",
481
+ "ARIZONA": "Southwest",
482
+ "NEW MEXICO": "Southwest",
483
+ "ILLINOIS": "Midwest",
484
+ "OHIO": "Midwest",
485
+ "MICHIGAN": "Midwest",
486
+ "NEW YORK": "Northeast",
487
+ "MASSACHUSETTS": "Northeast",
488
+ "PENNSYLVANIA": "Northeast",
489
+ "FLORIDA": "Southeast",
490
+ "GEORGIA": "Southeast",
491
+ "NORTH CAROLINA": "Southeast",
492
+ }
493
+
494
+
495
+ def infer_vertical_key(use_case: str | None) -> str:
496
+ text = (use_case or "").lower()
497
+ if any(
498
+ k in text
499
+ for k in (
500
+ "private equity",
501
+ "lp reporting",
502
+ "tvpi",
503
+ "dpi",
504
+ "rvpi",
505
+ "buyout",
506
+ "growth equity",
507
+ "venture capital",
508
+ "fund strategy",
509
+ )
510
+ ):
511
+ return "private_equity"
512
+ if any(k in text for k in ("legal", "counsel", "litigation", "contract", "matter", "spend management", "gc", "attorney")):
513
+ return "legal"
514
+ if any(k in text for k in ("saas", "arr", "mrr", "nrr", "churn", "subscription", "fp&a")):
515
+ return "saas"
516
+ if any(k in text for k in ("pharma", "pharmaceutical", "prescription", "formulary", "payer", "drug", "hcp", "rx")):
517
+ return "pharma"
518
+ if any(k in text for k in ("nike", "athletic", "sportswear", "sneaker", "footwear", "running shoe")):
519
+ return "sportswear"
520
+ if "health" in text or "clinical" in text or "patient" in text:
521
+ return "healthcare"
522
+ if "bank" in text or "credit union" in text or "mortgage" in text or "loan" in text:
523
+ return "banking"
524
+ if "retail" in text or "ecommerce" in text or "commerce" in text or "merch" in text:
525
+ return "retail"
526
+ if "supply" in text or "inventory" in text or "logistics" in text:
527
+ return "supply chain"
528
+ return "default"
529
+
530
+
531
+ def get_domain_values(semantic_type: SemanticType, use_case: str | None = None) -> list[str]:
532
+ values = list(BASE_DOMAIN_VALUES.get(semantic_type, []))
533
+ vertical = infer_vertical_key(use_case)
534
+ if vertical in VERTICAL_OVERRIDES and semantic_type in VERTICAL_OVERRIDES[vertical]:
535
+ values = list(VERTICAL_OVERRIDES[vertical][semantic_type])
536
+ return values
537
+
538
+
539
+ def map_state_to_region(state_value: str) -> str | None:
540
+ if not state_value:
541
+ return None
542
+ return STATE_TO_REGION.get(str(state_value).strip().upper())
543
+
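How the override lookup composes with vertical inference, reduced to a self-contained sketch (the tiny `BASE`/`OVERRIDES` dicts here are illustrative stand-ins for the real packs):

```python
BASE = {"category": ["Electronics", "Apparel", "Home Goods"]}
OVERRIDES = {"banking": {"category": ["Checking", "Savings", "Mortgage"]}}

def infer_vertical_key(use_case):
    # Keyword routing, mirroring the real function's "bank"/"mortgage"/"loan" branch.
    text = (use_case or "").lower()
    if "bank" in text or "mortgage" in text or "loan" in text:
        return "banking"
    return "default"

def get_domain_values(semantic_type, use_case=None):
    # Start from the base pack, then swap in the vertical override wholesale
    # (overrides replace the list; they do not merge with it).
    values = list(BASE.get(semantic_type, []))
    vertical = infer_vertical_key(use_case)
    if semantic_type in OVERRIDES.get(vertical, {}):
        values = list(OVERRIDES[vertical][semantic_type])
    return values

print(get_domain_values("category"))                                  # base pack
print(get_domain_values("category", "retail bank branch analytics"))  # banking override
```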
legitdata_project/legitdata/domain/semantic_types.py ADDED
@@ -0,0 +1,292 @@
 
 
1
+ """Semantic typing for schema columns."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from enum import Enum
6
+
7
+
8
+ class SemanticType(str, Enum):
9
+ UNKNOWN = "unknown"
10
+
11
+ REGION = "region"
12
+ CATEGORY = "category"
13
+ SECTOR_NAME = "sector_name"
14
+ SECTOR_CATEGORY = "sector_category"
15
+ SEGMENT = "segment"
16
+ TIER = "tier"
17
+ STATUS = "status"
18
+ FUND_STRATEGY = "fund_strategy"
19
+ INVESTOR_TYPE = "investor_type"
20
+ INVESTMENT_STAGE = "investment_stage"
21
+ COVENANT_STATUS = "covenant_status"
22
+ DEBT_PERFORMANCE_STATUS = "debt_performance_status"
23
+ CHANNEL = "channel"
24
+ COUNTRY = "country"
25
+ STATE = "state"
26
+ CITY = "city"
27
+ POSTAL_CODE = "postal_code"
28
+ SEASON = "season"
29
+ HOLIDAY_NAME = "holiday_name"
30
+ EVENT_NAME = "event_name"
31
+
32
+ PERSON_NAME = "person_name"
33
+ ORG_NAME = "org_name"
34
+ PRODUCT_NAME = "product_name"
35
+ BRAND_NAME = "brand_name"
36
+ BRANCH_NAME = "branch_name"
37
+ DEPARTMENT_NAME = "department_name"
38
+ COLOR_DESCRIPTION = "color_description"
39
+ PACKAGING_SIZE = "packaging_size"
40
+
41
+ MONEY = "money"
42
+ QUANTITY = "quantity"
43
+ COUNT = "count"
44
+ PERCENT = "percent"
45
+ INTEREST_RATE = "interest_rate"
46
+ RETURN_RATE = "return_rate"
47
+ RETURN_MULTIPLE = "return_multiple"
48
+ BASIS_POINTS = "basis_points"
49
+ LEVERAGE_RATIO = "leverage_ratio"
50
+ SCORE = "score"
51
+ DURATION_DAYS = "duration_days"
52
+
53
+ DATE_EVENT = "date_event"
54
+ DATE_START = "date_start"
55
+ DATE_END = "date_end"
56
+ DATE_BIRTH = "date_birth"
57
+
58
+ BOOLEAN_FLAG = "boolean_flag"
59
+
60
+
61
+ BUSINESS_CATEGORICAL_TYPES = {
62
+ SemanticType.REGION,
63
+ SemanticType.CATEGORY,
64
+ SemanticType.SECTOR_NAME,
65
+ SemanticType.SECTOR_CATEGORY,
66
+ SemanticType.SEGMENT,
67
+ SemanticType.TIER,
68
+ SemanticType.STATUS,
69
+ SemanticType.FUND_STRATEGY,
70
+ SemanticType.INVESTOR_TYPE,
71
+ SemanticType.INVESTMENT_STAGE,
72
+ SemanticType.COVENANT_STATUS,
73
+ SemanticType.DEBT_PERFORMANCE_STATUS,
74
+ SemanticType.CHANNEL,
75
+ SemanticType.COUNTRY,
76
+ SemanticType.STATE,
77
+ SemanticType.CITY,
78
+ SemanticType.POSTAL_CODE,
79
+ SemanticType.PRODUCT_NAME,
80
+ SemanticType.BRAND_NAME,
81
+ SemanticType.BRANCH_NAME,
82
+ SemanticType.ORG_NAME,
83
+ SemanticType.DEPARTMENT_NAME,
84
+ SemanticType.COLOR_DESCRIPTION,
85
+ SemanticType.PACKAGING_SIZE,
86
+ SemanticType.SEASON,
87
+ SemanticType.HOLIDAY_NAME,
88
+ SemanticType.EVENT_NAME,
89
+ }
90
+
91
+
92
+ NUMERIC_TYPES = {
93
+ SemanticType.MONEY,
94
+ SemanticType.QUANTITY,
95
+ SemanticType.COUNT,
96
+ SemanticType.PERCENT,
97
+ SemanticType.INTEREST_RATE,
98
+ SemanticType.RETURN_RATE,
99
+ SemanticType.RETURN_MULTIPLE,
100
+ SemanticType.BASIS_POINTS,
101
+ SemanticType.LEVERAGE_RATIO,
102
+ SemanticType.SCORE,
103
+ SemanticType.DURATION_DAYS,
104
+ }
105
+
106
+
107
+ def _contains_any(text: str, tokens: tuple[str, ...]) -> bool:
108
+ return any(tok in text for tok in tokens)
109
+
110
+
111
+ def infer_semantic_type(
112
+ column_name: str,
113
+ data_type: str | None = None,
114
+ table_name: str | None = None,
115
+ use_case: str | None = None,
116
+ ) -> SemanticType:
117
+ """Infer a coarse semantic type from schema metadata."""
118
+ col = (column_name or "").lower()
119
+ dt = (data_type or "").upper()
120
+ table = (table_name or "").lower()
121
+ _ = (use_case or "").lower() # Reserved for future use-case specific typing.
122
+
123
+ if col.startswith("is_") or col.startswith("has_") or col.endswith("_flag") or "BOOL" in dt:
124
+ return SemanticType.BOOLEAN_FLAG
125
+
126
+ is_date_type = any(t in dt for t in ("DATE", "TIMESTAMP", "DATETIME"))
127
+ financial_to_date_name = (
128
+ ("to_date" in col or col.endswith("todate"))
129
+ and _contains_any(
130
+ col,
131
+ (
132
+ "amount",
133
+ "balance",
134
+ "cost",
135
+ "price",
136
+ "revenue",
137
+ "sales",
138
+ "fee",
139
+ "tax",
140
+ "total",
141
+ "recovery",
142
+ "payment",
143
+ ),
144
+ )
145
+ )
146
+ is_name_date = ("date" in col) and not financial_to_date_name
147
+ if _contains_any(col, ("birth", "dob", "date_of_birth")):
148
+ return SemanticType.DATE_BIRTH
149
+ if is_name_date or is_date_type:
150
+ if _contains_any(col, ("start", "begin", "open", "admission", "created")):
151
+ return SemanticType.DATE_START
152
+ if _contains_any(col, ("end", "close", "discharge", "resolved", "shipped")):
153
+ return SemanticType.DATE_END
154
+ return SemanticType.DATE_EVENT
155
+
156
+ # Identifier columns should remain code-like, not categorical labels.
157
+ if (
158
+ (col.endswith("_id") or col.endswith("id") or col.endswith("_key") or col.endswith("key"))
159
+ and not _contains_any(col, ("paid", "valid", "invalid"))
160
+ ):
161
+ return SemanticType.UNKNOWN
162
+
163
+ if _contains_any(col, ("region", "territory")):
164
+ return SemanticType.REGION
165
+ if "sector_category" in col:
166
+ return SemanticType.SECTOR_CATEGORY
167
+ if col == "sector" or "sector_name" in col or "sub_sector" in col:
168
+ return SemanticType.SECTOR_NAME
169
+ if ("month" in col and "name" in col) or ("quarter" in col and "name" in col):
170
+ # Calendar labels are handled by date-dimension generation rules.
171
+ return SemanticType.UNKNOWN
172
+ if _contains_any(col, ("category", "subcategory")):
173
+ return SemanticType.CATEGORY
174
+ if "segment" in col:
175
+ return SemanticType.SEGMENT
176
+ if _contains_any(col, ("tier", "level")):
177
+ return SemanticType.TIER
178
+ if "fund_strategy" in col or col == "strategy" or col.endswith("_strategy"):
179
+ return SemanticType.FUND_STRATEGY
180
+ if "investor_type" in col or "lp_type" in col:
181
+ return SemanticType.INVESTOR_TYPE
182
+ if "investment_stage" in col:
183
+ return SemanticType.INVESTMENT_STAGE
184
+ if "covenant_status" in col:
185
+ return SemanticType.COVENANT_STATUS
186
+ if "debt_performance_status" in col or "debt_status" in col:
187
+ return SemanticType.DEBT_PERFORMANCE_STATUS
188
+ if "status" in col:
189
+ return SemanticType.STATUS
190
+ if _contains_any(col, ("channel", "source")):
191
+ return SemanticType.CHANNEL
192
+ if "country" in col:
193
+ return SemanticType.COUNTRY
194
+ if _contains_any(col, ("state", "province")):
195
+ return SemanticType.STATE
196
+ if "city" in col:
197
+ return SemanticType.CITY
198
+ if _contains_any(col, ("zip", "postal")):
199
+ return SemanticType.POSTAL_CODE
200
+ if _contains_any(col, ("product_name", "item_name")):
201
+ return SemanticType.PRODUCT_NAME
202
+ if "brand" in col and "name" in col:
203
+ return SemanticType.BRAND_NAME
204
+ if _contains_any(col, ("branch_name", "branchname")) or col == "branch":
205
+ return SemanticType.BRANCH_NAME
206
+ if _contains_any(col, ("department_name", "dept_name", "department", "dept")):
207
+ return SemanticType.DEPARTMENT_NAME
208
+ if _contains_any(col, ("employee_name", "manager_name", "agent_name", "rep_name", "officer_name", "contact_name")):
209
+ return SemanticType.PERSON_NAME
210
+ if _contains_any(col, ("color_description", "colour_description", "color_desc")):
211
+ return SemanticType.COLOR_DESCRIPTION
212
+ if _contains_any(col, ("packaging_size", "package_size", "pack_size")):
213
+ return SemanticType.PACKAGING_SIZE
214
+ if "season" in col:
215
+ return SemanticType.SEASON
216
+ if "holiday" in col:
217
+ return SemanticType.HOLIDAY_NAME
218
+ if "event" in col:
219
+ return SemanticType.EVENT_NAME
220
+
221
+ if col.endswith("_name") or col == "name":
222
+ # Treat business-entity naming columns as organization labels first.
223
+ if _contains_any(
224
+ col,
225
+ (
226
+ "account",
227
+ "customer",
228
+ "client",
229
+ "company",
230
+ "organization",
231
+ "org",
232
+ "business",
233
+ "firm",
234
+ "vendor",
235
+ "supplier",
236
+ "partner",
237
+ "merchant",
238
+ "product",
239
+ "brand",
240
+ "store",
241
+ "warehouse",
242
+ "facility",
243
+ "branch",
244
+ "department",
245
+ ),
246
+ ):
247
+ return SemanticType.ORG_NAME
248
+ if _contains_any(table, ("customer", "employee", "staff", "rep", "agent", "user", "patient", "provider")):
249
+ return SemanticType.PERSON_NAME
250
+ return SemanticType.ORG_NAME
251
+
252
+ # SaaS/finance abbreviations frequently skip explicit "amount"/"revenue" suffixes.
253
+ if _contains_any(col, ("arr", "mrr", "cac", "ltv", "billings", "collections", "deferred_revenue")):
254
+ return SemanticType.MONEY
255
+ if _contains_any(col, ("net_new_arr", "churned_arr", "expansion_arr", "new_logo_arr", "contraction_arr")):
256
+ return SemanticType.MONEY
257
+ if "gross_margin" in col and _contains_any(col, ("_usd", "usd", "amount", "value")):
258
+ return SemanticType.MONEY
259
+ if "debt_to_ebitda" in col or "leverage_ratio" in col:
260
+ return SemanticType.LEVERAGE_RATIO
261
+ if "bps" in col:
262
+ return SemanticType.BASIS_POINTS
263
+ if "irr" in col:
264
+ return SemanticType.RETURN_RATE
265
+ if "multiple" in col:
266
+ return SemanticType.RETURN_MULTIPLE
267
+ if _contains_any(col, ("nrr", "grr", "gross_margin", "margin_pct", "churn_rate", "retention_rate")):
268
+ return SemanticType.PERCENT
269
+ if _contains_any(col, ("payback_months", "payback_period", "days_sales_outstanding", "dso")):
270
+ return SemanticType.DURATION_DAYS
271
+
272
+ if _contains_any(col, ("amount", "revenue", "sales", "cost", "price", "fee", "income", "profit", "total", "value", "balance", "commitment", "aum", "nav")):
273
+ return SemanticType.MONEY
274
+ if _contains_any(col, ("quantity", "qty", "units")):
275
+ return SemanticType.QUANTITY
276
+ if _contains_any(col, ("count", "num_", "number", "total_")):
277
+ return SemanticType.COUNT
278
+ if _contains_any(col, ("percent", "pct", "ratio", "margin")):
279
+ return SemanticType.PERCENT
280
+ if "interest_rate" in col or (("interest" in col or "apr" in col) and "rate" in col):
281
+ return SemanticType.INTEREST_RATE
282
+ if _contains_any(col, ("score", "rating", "grade")):
283
+ return SemanticType.SCORE
284
+ if _contains_any(col, ("days", "day_count", "length_of_stay", "duration")):
285
+ return SemanticType.DURATION_DAYS
286
+
287
+ return SemanticType.UNKNOWN
288
+
289
+
290
+ def is_business_categorical(semantic_type: SemanticType) -> bool:
291
+ return semantic_type in BUSINESS_CATEGORICAL_TYPES
292
+
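The subtlest branch above is the `*_to_date` disambiguation: running totals like `revenue_to_date` contain "date" in the name but must type as money, not as a calendar date. A standalone sketch of that check (`is_financial_to_date` is a hypothetical helper name):

```python
def is_financial_to_date(col: str) -> bool:
    # "revenue_to_date" / "recovery_to_date" are running totals (money),
    # not calendar dates, even though the name contains "date".
    financial = ("amount", "balance", "cost", "price", "revenue", "sales",
                 "fee", "tax", "total", "recovery", "payment")
    return (("to_date" in col) or col.endswith("todate")) \
        and any(tok in col for tok in financial)

for name in ("revenue_to_date", "order_date", "recovery_to_date"):
    kind = "money" if is_financial_to_date(name) else "date"
    print(name, "->", kind)
```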
legitdata_project/legitdata/generator.py CHANGED
@@ -1,9 +1,11 @@
1
  """Main LegitGenerator class - the primary interface for data generation."""
2
 
3
  import hashlib
 
4
  import random
 
5
  from typing import Any, Callable, Optional
6
- from datetime import datetime
7
 
8
  from .ddl import parse_ddl, parse_ddl_file, Schema, Table, Column, ColumnClassification
9
  from .analyzer import ContextBuilder, CompanyContext, ColumnClassifier
@@ -12,11 +14,37 @@ from .relationships import FKManager
12
  from .writers import BaseWriter, SnowflakeWriter, DryRunWriter
13
  from .cache import FileCache
14
  from .config import SIZE_PRESETS, GenerationConfig, MAX_AI_ROWS_PER_CALL
 
 
15
 
16
 
17
  class LegitGenerator:
18
  """Main class for generating realistic synthetic data."""
19
 
 
 
20
  def __init__(
21
  self,
22
  url: str,
@@ -42,8 +70,13 @@ class LegitGenerator:
42
  self.schema: Optional[Schema] = None
43
  self.context: Optional[CompanyContext] = None
44
  self.classifications: Optional[dict] = None
 
 
45
 
46
  self._init_components()
 
47
 
48
  def _init_components(self) -> None:
49
  self.context_builder = ContextBuilder(
@@ -67,6 +100,43 @@ class LegitGenerator:
67
  self.generic_sourcer = GenericSourcer()
68
  self.fk_manager = FKManager()
69
 
 
 
70
  def load_ddl(self, ddl_or_path: str) -> Schema:
71
  if '\n' not in ddl_or_path and (
72
  ddl_or_path.endswith('.sql') or
@@ -99,25 +169,42 @@ class LegitGenerator:
99
  row_counts=row_counts
100
  )
101
 
102
- print("\n=== Step 1: Building Company Context ===")
103
- self._build_context()
104
-
105
- print("\n=== Step 2: Classifying Columns ===")
106
- self._classify_columns()
107
-
 
108
  print("\n=== Step 3: Generating Data ===")
109
  generated_data = {}
110
-
111
- for table in self.schema.get_dependency_order():
112
- num_rows = config.get_table_row_count(table.name, table.is_fact_table)
113
- print(f"\nGenerating {num_rows} rows for {table.name}...")
114
-
115
- rows = self._generate_table_data(table, num_rows)
116
- generated_data[table.name] = rows
117
-
118
- # Register PK values and full rows for FK references
119
- self._register_pk_values(table, rows)
120
- self.fk_manager.register_rows(table.name, rows)
 
121
 
122
  print("\n=== Step 4: Writing to Database ===")
123
  results = self._write_to_database(generated_data, truncate_first)
@@ -127,28 +214,510 @@ class LegitGenerator:
127
  print(f"Total rows inserted: {total}")
128
 
129
  return results
 
 
130
 
131
  def _register_pk_values(self, table: Table, rows: list[dict]) -> None:
132
- """Register PK values for FK lookups. Generate fake IDs for identity columns."""
133
- pk_col = table.primary_key
134
- if not pk_col:
 
 
135
  return
136
-
137
- # Get the PK column definition
138
- pk_column = table.get_column(pk_col)
139
-
140
- pk_values = []
141
- for i, row in enumerate(rows):
142
- pk_val = row.get(pk_col)
143
-
144
- # If PK is None (identity column), generate a fake ID for FK purposes
145
- if pk_val is None:
146
- pk_val = i + 1 # Simple sequential ID
147
- row[pk_col] = pk_val # Update the row too
148
-
149
- pk_values.append(pk_val)
150
-
151
- self.fk_manager.register_pk_values(table.name, pk_col, pk_values)
 
 
152
 
153
  def _build_context(self) -> None:
154
  if self.cache_enabled and self.cache:
@@ -156,13 +725,18 @@ class LegitGenerator:
156
  if cached:
157
  print(f"Using cached context for {self.url}")
158
  self.context = CompanyContext.from_dict(cached)
 
 
159
  return
160
 
161
  self.context = self.context_builder.build_context(self.url, self.use_case)
162
  print(f"Built context for: {self.context.company_name}")
163
  print(f" Industry: {self.context.industry}")
164
  print(f" Focus: {self.context.geographic_focus}")
165
-
 
 
 
166
  if self.cache_enabled and self.cache:
167
  self.cache.set_context(self.url, self.context.to_dict())
168
 
@@ -177,6 +751,7 @@ class LegitGenerator:
177
  if cached:
178
  print("Using cached column classifications")
179
  self.classifications = cached
 
180
  return
181
 
182
  self.classifications = self.column_classifier.classify_schema(
@@ -194,13 +769,33 @@ class LegitGenerator:
194
  ai_gen = sum(1 for c in columns.values() if c['classification'] == 'AI_GEN')
195
  generic = sum(1 for c in columns.values() if c['classification'] == 'GENERIC')
196
  print(f" {table_name}: {search_real} SEARCH_REAL, {ai_gen} AI_GEN, {generic} GENERIC")
 
 
197
 
198
  if self.cache_enabled and self.cache:
199
  self.cache.set_classification(schema_hash, self.classifications)
200
 
201
  def _generate_table_data(self, table: Table, num_rows: int) -> list[dict]:
202
  rows = []
203
  table_class = self.classifications.get(table.name, {})
 
 
204
 
205
  # Build set of FK column names
206
  fk_columns = {fk.column_name.upper() for fk in table.foreign_keys}
@@ -213,6 +808,7 @@ class LegitGenerator:
213
  col_class = table_class.get(column.name, {})
214
  classification = col_class.get('classification', 'GENERIC').upper()
215
  strategy = col_class.get('strategy', '')
 
216
 
217
  # Skip identity/PK/FK columns
218
  if column.is_identity or column.is_primary_key:
@@ -224,16 +820,31 @@ class LegitGenerator:
224
  coherent_columns.append({
225
  'name': column.name,
226
  'classification': classification,
227
- 'strategy': strategy
 
228
  })
229
  print(f" [CLASSIFY] {column.name}: {classification} → coherent (AI-generated)")
230
  else:
231
  generic_columns.append(column)
232
  print(f" [CLASSIFY] {column.name}: {classification} strategy='{strategy}' → generic (Faker)")
233
 
234
  # For dimension tables with multiple coherent columns, generate as entities
235
  coherent_entities = []
236
- if not table.is_fact_table and len(coherent_columns) >= 2:
 
 
 
 
 
237
  print(f" Generating {num_rows} coherent entities for {table.name}...")
238
  coherent_entities = self.ai_generator.generate_entity_batch(
239
  table_name=table.name,
@@ -249,11 +860,12 @@ class LegitGenerator:
249
  col_name = col_info['name']
250
  classification = col_info['classification']
251
  strategy = col_info['strategy']
 
252
 
253
  if classification == 'SEARCH_REAL':
254
- values = self._get_search_real_values(col_name, strategy, num_rows)
255
  else: # AI_GEN
256
- values = self._get_ai_gen_values(col_name, strategy, num_rows)
257
  column_values[col_name] = values
258
 
259
  # Convert to entity format
@@ -342,36 +954,56 @@ class LegitGenerator:
342
 
343
  # GENERIC columns - use classification strategy if set, otherwise infer
344
  col_class = table_class.get(column.name, {})
345
- strategy = col_class.get('strategy', '')
346
-
347
- # If strategy looks like a Faker type (not an AI prompt), use it
348
- faker_strategies = ('email', 'first_name', 'last_name', 'name', 'phone',
349
- 'city', 'state', 'country', 'zipcode', 'address',
350
- 'uuid', 'url', 'company')
351
-
352
- # Check if column is numeric - don't use faker for numeric columns!
353
- data_type_upper = (column.data_type or '').upper()
354
- is_numeric = any(t in data_type_upper for t in ('INT', 'NUMBER', 'NUMERIC', 'DECIMAL', 'BIGINT', 'SMALLINT', 'FLOAT'))
355
 
356
  # DEBUG: Log first row only
357
  if i == 0:
358
- print(f" [GEN] {col_name}: strategy='{strategy}', is_faker={strategy in faker_strategies}, is_numeric={is_numeric}")
359
 
360
- if strategy and strategy in faker_strategies and not is_numeric:
361
- # Only use faker for non-numeric columns
362
- row[col_name] = self.generic_sourcer.generate_value(strategy)
 
 
 
 
 
 
 
 
 
 
363
  else:
364
- # Fall back to inferred strategy
365
- inferred = self._infer_strategy(column)
 
 
366
  if i == 0:
367
  print(f" [GEN] {col_name}: using inferred='{inferred}'")
368
- row[col_name] = self.generic_sourcer.generate_value(inferred)
 
 
 
 
 
 
 
369
 
370
  rows.append(row)
371
-
 
372
  return rows
373
 
374
- def _get_search_real_values(self, column_name: str, strategy: str, num_needed: int) -> list[str]:
 
 
 
 
 
 
375
  cache_key = f"search:{column_name}:{strategy}"
376
 
377
  if self.cache_enabled and self.cache:
@@ -396,12 +1028,20 @@ class LegitGenerator:
396
  use_case=self.use_case
397
  )
398
 
 
 
399
  if self.cache_enabled and self.cache and values:
400
  self.cache.set_generated_values(cache_key, values)
401
 
402
  return values
403
 
404
- def _get_ai_gen_values(self, column_name: str, strategy: str, num_needed: int) -> list[str]:
 
 
 
 
 
 
405
  cache_key = f"aigen:{column_name}:{strategy}:{self.context.company_name}"
406
 
407
  if self.cache_enabled and self.cache:
@@ -425,14 +1065,587 @@ class LegitGenerator:
425
  if not values:
426
  break
427
 
 
 
428
  if self.cache_enabled and self.cache and all_values:
429
  self.cache.set_generated_values(cache_key, all_values)
430
 
431
  return all_values
432
 
433
  def _infer_strategy(self, column: Column) -> str:
434
  col_lower = column.name.lower()
435
  data_type = column.data_type.upper() if column.data_type else ''
 
 
 
 
 
 
436
 
437
  # Check if data type is numeric
438
  is_numeric = any(t in data_type for t in ('INT', 'NUMBER', 'NUMERIC', 'DECIMAL', 'BIGINT', 'SMALLINT'))
@@ -443,13 +1656,23 @@ class LegitGenerator:
443
 
444
  # Date/timestamp columns - use current date for end range
445
  # Calculate dynamic date ranges relative to today
446
- from datetime import timedelta
447
  today_dt = datetime.now()
448
  today = today_dt.strftime('%Y-%m-%d')
449
  three_years_ago = (today_dt - timedelta(days=3*365)).strftime('%Y-%m-%d')
450
  two_years_ago = (today_dt - timedelta(days=2*365)).strftime('%Y-%m-%d')
451
  five_years_ago = (today_dt - timedelta(days=5*365)).strftime('%Y-%m-%d')
452
453
  if 'date' in col_lower or 'DATE' in data_type or 'TIMESTAMP' in data_type:
454
  if 'created' in col_lower:
455
  return f"date_between:{three_years_ago},{today}"
 
1
  """Main LegitGenerator class - the primary interface for data generation."""
2
 
3
  import hashlib
4
+ import math
5
  import random
6
+ import re
7
  from typing import Any, Callable, Optional
8
+ from datetime import datetime, date, timedelta
9
 
10
  from .ddl import parse_ddl, parse_ddl_file, Schema, Table, Column, ColumnClassification
11
  from .analyzer import ContextBuilder, CompanyContext, ColumnClassifier
 
14
  from .writers import BaseWriter, SnowflakeWriter, DryRunWriter
15
  from .cache import FileCache
16
  from .config import SIZE_PRESETS, GenerationConfig, MAX_AI_ROWS_PER_CALL
17
+ from .domain import (
18
+ SemanticType,
19
+ get_domain_values,
20
+ infer_semantic_type,
21
+ is_business_categorical,
22
+ map_state_to_region,
23
+ )
24
+ from .storyspec import StorySpec, build_story_spec
25
+
26
+ try:
27
+ from demo_personas import parse_use_case, get_use_case_config
28
+ except Exception:
29
+ parse_use_case = None
30
+ get_use_case_config = None
31
 
32
 
33
  class LegitGenerator:
34
  """Main class for generating realistic synthetic data."""
35
 
36
+ EXECUTABLE_GENERIC_STRATEGIES = {
37
+ "random_int", "random_decimal", "random_float", "random_string",
38
+ "choice", "weighted_choice", "uuid", "boolean", "date_between", "sequential",
39
+ "email", "phone", "name", "first_name", "last_name", "company",
40
+ "address", "city", "state", "country", "zipcode", "url", "text",
41
+ "paragraph", "sentence", "word", "lookup", "sku",
42
+ # aliases supported by GenericSourcer
43
+ "int", "integer", "decimal", "float", "number", "date", "datetime",
44
+ "bool", "flag", "guid", "id", "zip", "postal",
45
+ }
46
+ _GARBAGE_TOKEN_RE = re.compile(r'^[A-Z0-9]{8,}$')
47
+
48
  def __init__(
49
  self,
50
  url: str,
 
70
  self.schema: Optional[Schema] = None
71
  self.context: Optional[CompanyContext] = None
72
  self.classifications: Optional[dict] = None
73
+ self.semantic_types: dict[str, dict[str, SemanticType]] = {}
74
+ self.story_spec: StorySpec | None = None
75
+ self.story_events: list[dict[str, Any]] = []
76
+ self._story_rng = random.Random()
77
 
78
  self._init_components()
79
+ self._refresh_story_spec()
80
 
81
  def _init_components(self) -> None:
82
  self.context_builder = ContextBuilder(
 
100
  self.generic_sourcer = GenericSourcer()
101
  self.fk_manager = FKManager()
102
 
103
+ def _resolve_use_case_config(self) -> dict[str, Any]:
104
+ """Resolve vertical/function matrix config when available."""
105
+ if not parse_use_case or not get_use_case_config:
106
+ return {}
107
+ try:
108
+ vertical, function = parse_use_case(self.use_case or "")
109
+ return get_use_case_config(vertical or "Generic", function or "Generic")
110
+ except Exception:
111
+ return {}
112
+
113
+ def _refresh_story_spec(self) -> None:
114
+ """Build per-run StorySpec and deterministic RNG seed."""
115
+ cfg = self._resolve_use_case_config()
116
+ company_name = None
117
+ if self.context and getattr(self.context, "company_name", None):
118
+ company_name = self.context.company_name
119
+
120
+ self.story_spec = build_story_spec(
121
+ company_url=self.url,
122
+ use_case=self.use_case,
123
+ company_name=company_name,
124
+ use_case_config=cfg,
125
+ )
126
+
127
+ # Stabilize pseudo-random behavior for reproducible runs.
128
+ self._story_rng = random.Random(self.story_spec.seed)
129
+ random.seed(self.story_spec.seed)
130
+
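The effect of the seeding above is easy to verify in isolation: two generators built from the same seed produce identical sequences, which is what lets a rerun with the same StorySpec reproduce the same dataset (minimal sketch; the seed value here is made up):

```python
import random

seed = 12345  # stand-in for story_spec.seed
a = random.Random(seed)
b = random.Random(seed)

# Identical seeds -> identical draw sequences, so generation is repeatable.
assert [a.randint(0, 99) for _ in range(8)] == [b.randint(0, 99) for _ in range(8)]
print("deterministic")
```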
131
+ def _domain_use_case(self) -> str:
132
+ """Use company + use_case context for domain pack routing."""
133
+ company_hint = ""
134
+ if self.context and getattr(self.context, "company_name", None):
135
+ company_hint = str(self.context.company_name).strip()
136
+ elif self.url:
137
+ company_hint = str(self.url).strip()
138
+ return f"{company_hint} {self.use_case or ''}".strip()
139
+
140
  def load_ddl(self, ddl_or_path: str) -> Schema:
141
  if '\n' not in ddl_or_path and (
142
  ddl_or_path.endswith('.sql') or
 
169
  row_counts=row_counts
170
  )
171
 
172
+ is_saas_finance_gold = self._is_saas_finance_gold_schema()
173
+ if is_saas_finance_gold:
174
+ print("\n=== Step 1: Preparing Deterministic Finance Context ===")
175
+ self.context = None
176
+ self._refresh_story_spec()
177
+ print(f" Story seed: {self.story_spec.seed}")
178
+ else:
179
+ print("\n=== Step 1: Building Company Context ===")
180
+ self._build_context()
181
+
182
+ if is_saas_finance_gold:
183
+ print("\n=== Step 2: Skipping Column Classification ===")
184
+ print(" Using deterministic semantic typing for SaaS finance gold-path schema...")
185
+ self._build_semantic_types()
186
+ else:
187
+ print("\n=== Step 2: Classifying Columns ===")
188
+ self._classify_columns()
189
+
190
  print("\n=== Step 3: Generating Data ===")
191
  generated_data = {}
192
+ self.story_events = []
193
+
194
+ if is_saas_finance_gold:
195
+ print(" Using deterministic SaaS finance gold-path generator...")
196
+ generated_data = self._generate_saas_finance_gold(config)
197
+ else:
198
+ for table in self.schema.get_dependency_order():
199
+ num_rows = config.get_table_row_count(table.name, table.is_fact_table)
200
+ print(f"\nGenerating {num_rows} rows for {table.name}...")
201
+
202
+ rows = self._generate_table_data(table, num_rows)
203
+ generated_data[table.name] = rows
204
+
205
+ # Register PK values and full rows for FK references
206
+ self._register_pk_values(table, rows)
207
+ self.fk_manager.register_rows(table.name, rows)
208
 
209
  print("\n=== Step 4: Writing to Database ===")
210
  results = self._write_to_database(generated_data, truncate_first)
 
214
  print(f"Total rows inserted: {total}")
215
 
216
  return results
217
+
218
+ def _is_saas_finance_gold_schema(self) -> bool:
219
+ if not self.schema:
220
+ return False
221
+ table_names = {table.name.upper() for table in self.schema.tables}
222
+ required = {"DATES", "CUSTOMERS", "PRODUCTS", "LOCATIONS", "SAAS_CUSTOMER_MONTHLY", "SALES_MARKETING_SPEND_MONTHLY"}
223
+ if required.issubset(table_names):
224
+ return True
225
+
226
+ cfg = self._resolve_use_case_config()
227
+ schema_contract = cfg.get("schema_contract", {}) if isinstance(cfg, dict) else {}
228
+ return schema_contract.get("mode") == "saas_finance_gold"
229
+
230
+ def _generate_saas_finance_gold(self, config: GenerationConfig) -> dict[str, list[dict]]:
231
+ """Generate a deterministic recurring-revenue SaaS finance dataset."""
232
+ rng = self._story_rng
233
+ today = datetime.now().date().replace(day=1)
234
+ months = []
235
+ for offset in range(23, -1, -1):
236
+ year = today.year
237
+ month = today.month - offset
238
+ while month <= 0:
239
+ month += 12
240
+ year -= 1
241
+ while month > 12:
242
+ month -= 12
243
+ year += 1
244
+ month_date = date(year, month, 1)
245
+ months.append(month_date)
246
+
247
+ def month_key(dt: date) -> int:
248
+ return int(dt.strftime("%Y%m01"))
249
+
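The 24-month trailing window and the `YYYYMM01` key built above can be sketched independently with a fixed reference date (self-contained; the generator itself anchors on `datetime.now()`):

```python
from datetime import date

def trailing_months(today: date, n: int = 24) -> list[date]:
    """Return n month-start dates ending at today's month, oldest first."""
    today = today.replace(day=1)
    months = []
    for offset in range(n - 1, -1, -1):
        year, month = today.year, today.month - offset
        while month <= 0:  # borrow years for negative month offsets
            month += 12
            year -= 1
        months.append(date(year, month, 1))
    return months

def month_key(dt: date) -> int:
    return int(dt.strftime("%Y%m01"))

months = trailing_months(date(2026, 3, 15))
print(months[0], months[-1], month_key(months[-1]))
# 2024-04-01 2026-03-01 20260301
```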
250
+ dates_rows = []
251
+ for dt in months:
252
+ quarter = ((dt.month - 1) // 3) + 1
253
+ dates_rows.append(
254
+ {
255
+ "MONTH_KEY": month_key(dt),
256
+ "FULL_DATE": dt,
257
+ "MONTH_NAME": dt.strftime("%B"),
258
+ "QUARTER_NAME": f"Q{quarter}",
259
+ "YEAR_NUM": dt.year,
260
+ "IS_QUARTER_END": dt.month in {3, 6, 9, 12},
261
+ "IS_YEAR_END": dt.month == 12,
262
+ }
263
+ )
264
+
265
+ segment_weights = [("Enterprise", 0.28), ("Mid-Market", 0.34), ("SMB", 0.38)]
266
+ region_arr_multipliers = {"North America": 1.28, "EMEA": 1.04, "APAC": 0.81}
267
+ segment_motion_profiles = {
268
+ "Enterprise": {"expansion": (0.026, 0.078), "contraction": (0.000, 0.008)},
269
+ "Mid-Market": {"expansion": (0.006, 0.028), "contraction": (0.006, 0.026)},
270
+ "SMB": {"expansion": (0.000, 0.010), "contraction": (0.022, 0.072)},
271
+ }
272
+ region_motion_bias = {
273
+ "North America": {"expansion": 0.014, "contraction": -0.004},
274
+ "EMEA": {"expansion": -0.012, "contraction": 0.016},
275
+ "APAC": {"expansion": 0.002, "contraction": 0.006},
276
+ }
277
+ region_defs = [
278
+ ("North America", "NA - West", "Americas", "US", "CA", "San Francisco"),
279
+ ("North America", "NA - Central", "Americas", "US", "TX", "Austin"),
280
+ ("North America", "NA - East", "Americas", "US", "NY", "New York"),
281
+ ("EMEA", "EMEA - UKI", "Europe", "GB", "", "London"),
282
+ ("EMEA", "EMEA - DACH", "Europe", "DE", "", "Berlin"),
283
+ ("EMEA", "EMEA - Southern Europe", "Europe", "FR", "", "Paris"),
284
+ ("APAC", "APAC - ANZ", "Asia Pacific", "AU", "", "Sydney"),
285
+ ("APAC", "APAC - Japan", "Asia Pacific", "JP", "", "Tokyo"),
286
+ ("APAC", "APAC - India", "Asia Pacific", "IN", "", "Bengaluru"),
287
+ ]
288
+ product_defs = [
289
+ ("ThoughtSpot Analytics Cloud", "Core Platform", "Enterprise", "Subscription", True, False, 84.0, 120000.0),
290
+ ("Agentic Analytics", "AI Analyst", "Enterprise", "Subscription", False, True, 87.0, 42000.0),
291
+ ("Embedded Analytics", "Embedded Analytics", "Scale", "Subscription", False, True, 82.0, 36000.0),
292
+ ("Semantic Modeling Studio", "Semantic Modeling", "Professional", "Subscription", True, False, 86.0, 26000.0),
293
+ ("Data Prep Studio", "Data Prep & Studio", "Professional", "Subscription", False, True, 80.0, 18000.0),
294
+ ("Security & Governance", "Platform Add-ons", "Enterprise", "Subscription", False, True, 88.0, 22000.0),
295
+ ("Observability Cloud", "Observability", "Professional", "Usage", False, True, 79.0, 24000.0),
296
+ ("Developer APIs", "Platform Add-ons", "Growth", "Usage", False, True, 90.0, 12000.0),
297
+ ("Consumption Credits", "Usage Services", "Growth", "Usage", False, True, 76.0, 9000.0),
298
+ ("Customer Success Premium", "Services", "Enterprise", "Services", False, True, 68.0, 15000.0),
299
+ ]
300
+
301
+ location_rows = []
302
+ for idx, (region, sub_region, geo, country, state_abbr, city) in enumerate(region_defs, start=1):
303
+ location_rows.append(
304
+ {
305
+ "LOCATION_KEY": idx,
306
+ "GEO": geo,
307
+ "REGION": region,
308
+ "SUB_REGION": sub_region,
309
+ "COUNTRY": country,
310
+ "STATE_ABBR": state_abbr,
311
+ "CITY": city,
312
+ }
313
+ )
314
+
315
+ product_rows = []
316
+ for idx, product in enumerate(product_defs, start=1):
317
+ product_rows.append(
318
+ {
319
+ "PRODUCT_KEY": idx,
320
+ "PRODUCT_NAME": product[0],
321
+ "PRODUCT_FAMILY": product[1],
322
+ "PLAN_TIER": product[2],
323
+ "PRICING_MODEL": product[3],
324
+ "IS_CORE_PLATFORM": product[4],
325
+ "IS_ADD_ON": product[5],
326
+ "DEFAULT_GROSS_MARGIN_PCT": round(product[6], 2),
327
+ "LIST_PRICE_USD": round(product[7], 2),
328
+ }
329
+ )
330
+
331
+ requested_customer_count = config.row_counts.get("CUSTOMERS") if config.row_counts else None
332
+ if requested_customer_count is None:
333
+ requested_customer_count = config.get_table_row_count("CUSTOMERS", False)
334
+ customer_count = max(72, min(108, int(requested_customer_count or 84)))
335
+ company_tokens = ["Cloud", "Data", "Ops", "Signal", "Metric", "Scale", "Insight", "Lake", "Grid", "Core", "Logic", "Flow"]
336
+ suffixes = ["Systems", "Labs", "Software", "Analytics", "Platforms", "Works", "Networks", "Cloud", "Dynamics"]
337
+
338
+ def weighted_choice(weighted_values: list[tuple[str, float]]) -> str:
339
+ roll = rng.random()
340
+ running = 0.0
341
+ for value, weight in weighted_values:
342
+ running += weight
343
+ if roll <= running:
344
+ return value
345
+ return weighted_values[-1][0]
346
+
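The cumulative-roll helper above assumes the weights sum to roughly 1.0 and falls back to the last value to guard against float rounding; a seeded sanity check (standalone copy of the helper, here taking the RNG as a parameter):

```python
import random

def weighted_choice(rng: random.Random, weighted_values: list[tuple[str, float]]) -> str:
    roll = rng.random()
    running = 0.0
    for value, weight in weighted_values:
        running += weight
        if roll <= running:
            return value
    return weighted_values[-1][0]  # fallback for cumulative rounding error

rng = random.Random(7)
weights = [("Enterprise", 0.28), ("Mid-Market", 0.34), ("SMB", 0.38)]
draws = [weighted_choice(rng, weights) for _ in range(10_000)]
print(round(draws.count("SMB") / len(draws), 2))  # ≈ 0.38
```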
347
+ customers = []
348
+ customers_rows = []
349
+ for idx in range(1, customer_count + 1):
350
+ segment = weighted_choice(segment_weights)
351
+ region_row = rng.choice(location_rows)
352
+ product_row = rng.choice(product_rows[:8])
353
+ acquired_idx = rng.randint(0, 17)
354
+ acquired_month = months[acquired_idx]
355
+ billing_cadence = rng.choices(
356
+ ["Annual", "Quarterly", "Monthly"],
357
+ weights=[0.62, 0.18, 0.20] if segment != "SMB" else [0.18, 0.12, 0.70],
358
+ k=1,
359
+ )[0]
360
+ contract_type = rng.choices(
361
+ ["New Business", "Renewal", "Expansion"],
362
+ weights=[0.55, 0.25, 0.20],
363
+ k=1,
364
+ )[0]
365
+ risk_band = rng.choices(["Low", "Medium", "High"], weights=[0.55, 0.30, 0.15], k=1)[0]
366
+ owner = f"AE {idx:03d}"
367
+ token_a = company_tokens[idx % len(company_tokens)]
368
+ token_b = company_tokens[(idx * 3) % len(company_tokens)]
369
+ company_name = f"{token_a}{token_b} {suffixes[idx % len(suffixes)]}"
370
+ if segment == "Enterprise":
371
+ base_arr = rng.randint(90000, 340000)
372
+ elif segment == "Mid-Market":
373
+ base_arr = rng.randint(30000, 95000)
374
+ else:
375
+ base_arr = rng.randint(6000, 24000)
376
+ base_arr = int(
377
+ base_arr
378
+ * region_arr_multipliers[region_row["REGION"]]
379
+ * rng.uniform(0.88, 1.16)
380
+ )
381
+
382
+ customers.append(
383
+ {
384
+ "customer_key": idx,
385
+ "name": company_name,
386
+ "segment": segment,
387
+ "location_key": region_row["LOCATION_KEY"],
388
+ "region": region_row["REGION"],
389
+ "sub_region": region_row["SUB_REGION"],
390
+ "product_key": product_row["PRODUCT_KEY"],
391
+ "billing_cadence": billing_cadence,
392
+ "contract_type": contract_type,
393
+ "risk_band": risk_band,
394
+ "owner": owner,
395
+ "acquired_idx": acquired_idx,
396
+ "base_arr": float(base_arr),
397
+ "renewal_offset": acquired_idx % (12 if billing_cadence == "Annual" else 3 if billing_cadence == "Quarterly" else 1),
398
+ }
399
+ )
400
+
401
+ customers_rows.append(
402
+ {
403
+ "CUSTOMER_KEY": idx,
404
+ "CUSTOMER_NAME": company_name,
405
+ "CUSTOMER_TIER": "Strategic" if segment == "Enterprise" else "Growth" if segment == "Mid-Market" else "Velocity",
406
+ "SEGMENT": segment,
407
+ "CONTRACT_TYPE": contract_type,
408
+ "BILLING_CADENCE": billing_cadence,
409
+ "ACQUISITION_CHANNEL": rng.choices(
410
+ ["Direct Sales", "Partner", "Digital Self-Serve"],
411
+ weights=[0.58, 0.20, 0.22],
412
+ k=1,
413
+ )[0],
414
+ "CHURN_RISK_BAND": risk_band,
415
+ "CUSTOMER_STATUS": "Active",
416
+ "CUSTOMER_ACQUIRED_MONTH_KEY": month_key(acquired_month),
417
+ "PRIMARY_REGION": region_row["REGION"],
418
+ "ACCOUNT_OWNER": owner,
419
+ "PRIMARY_USE_CASE": rng.choice(
420
+ ["Executive Analytics", "Embedded Analytics", "FP&A", "Product Analytics", "Sales Operations"]
421
+ ),
422
+ }
423
+ )
424
+
425
+ fact_rows: list[dict[str, Any]] = []
426
+ new_logo_lookup: dict[tuple[int, str, str], list[int]] = {}
427
+ spend_accumulator: dict[tuple[int, str, str], dict[str, float]] = {}
428
+ fact_key = 1
429
+
430
+ july_spike_key = month_key(next((m for m in months if m.month == 7), months[len(months) // 2]))
431
+ q4_keys = {month_key(m) for m in months if m.month in {10, 11, 12} and m.year == months[-1].year}
432
+ emea_lag_keys = {month_key(m) for m in months if m.year == months[-1].year and m.month in {5, 6, 7}}
433
+ upsell_keys = {month_key(m) for m in months if m.year == months[-1].year and m.month in {4, 5, 6}}
434
+
435
+ customer_status_by_key = {row["CUSTOMER_KEY"]: row for row in customers_rows}
436
+ deferred_balance_by_customer: dict[int, float] = {}
437
+
438
+ for customer in customers:
439
+ churn_probability = 0.05 if customer["segment"] == "Enterprise" else 0.14 if customer["segment"] == "Mid-Market" else 0.30
440
+ if customer["region"] == "EMEA":
441
+ churn_probability += 0.05
442
+ elif customer["region"] == "North America":
443
+ churn_probability -= 0.02
444
+ if customer["risk_band"] == "High":
445
+ churn_probability += 0.06
446
+ elif customer["risk_band"] == "Low":
447
+ churn_probability -= 0.025
448
+ churn_probability = max(0.02, min(0.42, churn_probability))
449
+ churn_idx = None
450
+ if rng.random() < churn_probability:
451
+ earliest = max(customer["acquired_idx"] + 6, 8)
452
+ if earliest < len(months):
453
+ churn_idx = rng.randint(earliest, len(months) - 1)
454
+ prev_arr = 0.0
455
+ deferred_balance = 0.0
456
+ base_seats = max(5, int(customer["base_arr"] / (1200 if customer["segment"] == "Enterprise" else 600 if customer["segment"] == "Mid-Market" else 240)))
457
+
458
+ for month_index, month_start in enumerate(months):
459
+ if month_index < customer["acquired_idx"]:
460
+ continue
461
+ if churn_idx is not None and month_index > churn_idx:
462
+ continue
463
+
464
+ mk = month_key(month_start)
465
+ starting_arr = round(prev_arr, 2)
466
+ new_logo_arr = round(customer["base_arr"], 2) if month_index == customer["acquired_idx"] else 0.0
467
+
468
+ expansion_pct = 0.0
469
+ contraction_pct = 0.0
470
+ if starting_arr > 0:
471
+ motion_profile = segment_motion_profiles[customer["segment"]]
472
+ region_bias = region_motion_bias[customer["region"]]
473
+ base_expansion = rng.uniform(*motion_profile["expansion"]) + region_bias["expansion"]
474
+ base_contraction = rng.uniform(*motion_profile["contraction"]) + region_bias["contraction"]
475
+ if customer["segment"] == "Enterprise" and mk in q4_keys:
476
+ base_expansion += rng.uniform(0.05, 0.09)
477
+ if customer["segment"] == "SMB" and mk == july_spike_key:
478
+ base_contraction += rng.uniform(0.08, 0.14)
479
+ if customer["segment"] == "Mid-Market" and mk in upsell_keys:
480
+ base_expansion += rng.uniform(0.03, 0.06)
481
+ if customer["segment"] == "SMB" and customer["billing_cadence"] == "Monthly" and month_start.month in {3, 6, 9, 12}:
482
+ base_contraction += rng.uniform(0.015, 0.045)
483
+ if customer["segment"] == "Enterprise" and customer["region"] == "North America" and month_start.month in {2, 5, 8, 11}:
484
+ base_expansion += rng.uniform(0.015, 0.03)
485
+ if customer["region"] == "EMEA" and month_start.month in {4, 5, 6, 7}:
486
+ base_contraction += rng.uniform(0.012, 0.026)
487
+ expansion_pct = max(0.0, min(0.18, base_expansion))
488
+ contraction_pct = max(0.0, min(0.22, base_contraction))
489
+
490
+ expansion_arr = round(starting_arr * expansion_pct, 2)
491
+ contraction_arr = round(starting_arr * contraction_pct, 2)
492
+ churned_arr = 0.0
493
+ customer_churned_flag = False
494
+ if churn_idx is not None and month_index == churn_idx:
495
+ churned_arr = round(max(starting_arr + new_logo_arr + expansion_arr - contraction_arr, 0.0), 2)
496
+ customer_churned_flag = churned_arr > 0
497
+
498
+ ending_arr = round(max(starting_arr + new_logo_arr + expansion_arr - contraction_arr - churned_arr, 0.0), 2)
499
+ mrr_usd = round(ending_arr / 12.0, 2)
500
+
501
+ cadence = customer["billing_cadence"]
502
+ cadence_mod = 12 if cadence == "Annual" else 3 if cadence == "Quarterly" else 1
503
+ billing_due = ((month_index - customer["renewal_offset"]) % cadence_mod) == 0
504
+ billings_usd = 0.0
505
+ if cadence == "Monthly":
506
+ billings_usd = round(mrr_usd * rng.uniform(0.98, 1.03), 2)
507
+ elif billing_due and cadence == "Quarterly":
508
+ billings_usd = round(mrr_usd * 3.0 * rng.uniform(0.98, 1.04), 2)
509
+ elif billing_due and cadence == "Annual":
510
+ billings_usd = round(max(ending_arr, 0.0) * rng.uniform(0.99, 1.05), 2)
511
+
512
+ if customer["segment"] == "Enterprise" and month_start.month == 11 and cadence == "Annual":
513
+ billings_usd = round(billings_usd * 1.18, 2)
514
+
515
+ recognized_revenue = mrr_usd
516
+ lag_factor = 0.97
517
+ if customer["region"] == "EMEA":
518
+ lag_factor = 0.88 if mk in emea_lag_keys else 0.93
519
+ elif customer["region"] == "APAC":
520
+ lag_factor = 0.95
521
+ if customer["segment"] == "SMB":
522
+ lag_factor -= 0.02
523
+ collections_usd = round(min(billings_usd, max(billings_usd * lag_factor, 0.0)), 2)
524
+ deferred_balance = round(max(deferred_balance + billings_usd - recognized_revenue, 0.0), 2)
525
+ dso_days = round(max(22.0, min(92.0, 30.0 + ((billings_usd - collections_usd) / billings_usd * 45.0 if billings_usd else 0.0) + (10.0 if mk in emea_lag_keys and customer["region"] == "EMEA" else 0.0))), 2)
526
+
527
+ product_row = product_rows[customer["product_key"] - 1]
528
+ gm_pct = float(product_row["DEFAULT_GROSS_MARGIN_PCT"])
529
+ if product_row["PRODUCT_FAMILY"] == "Services":
530
+ gm_pct -= 12.0
531
+ if month_start.month in {5, 6} and product_row["PRODUCT_FAMILY"] == "Observability":
532
+ gm_pct -= 6.0
533
+ gm_pct = round(max(45.0, min(92.0, gm_pct + rng.uniform(-1.5, 1.5))), 2)
534
+ gross_margin_usd = round(recognized_revenue * (gm_pct / 100.0), 2)
535
+ support_cost_usd = round(max(recognized_revenue - gross_margin_usd, recognized_revenue * 0.06), 2)
536
+ seat_count = int(max(1, round(base_seats * max(ending_arr, customer["base_arr"]) / max(customer["base_arr"], 1.0))))
537
+ usage_units = round(seat_count * rng.uniform(1.8, 4.6), 2)
538
+ customer_acquired_flag = month_index == customer["acquired_idx"]
539
+
540
+ fact_rows.append(
541
+ {
542
+ "SAAS_CUSTOMER_MONTHLY_KEY": fact_key,
543
+ "MONTH_KEY": mk,
544
+ "CUSTOMER_KEY": customer["customer_key"],
545
+ "PRODUCT_KEY": customer["product_key"],
546
+ "LOCATION_KEY": customer["location_key"],
547
+ "SCENARIO": "Actual",
548
+ "CONTRACT_TERM_MONTHS": 12 if cadence == "Annual" else 3 if cadence == "Quarterly" else 1,
549
+ "STARTING_ARR_USD": starting_arr,
550
+ "NEW_LOGO_ARR_USD": new_logo_arr,
551
+ "EXPANSION_ARR_USD": expansion_arr,
552
+ "CONTRACTION_ARR_USD": contraction_arr,
553
+ "CHURNED_ARR_USD": churned_arr,
554
+ "ENDING_ARR_USD": ending_arr,
555
+ "MRR_USD": mrr_usd,
556
+ "BILLINGS_USD": billings_usd,
557
+ "COLLECTIONS_USD": collections_usd,
558
+ "DEFERRED_REVENUE_USD": deferred_balance,
559
+ "GROSS_MARGIN_USD": gross_margin_usd,
560
+ "GROSS_MARGIN_PCT": gm_pct,
561
+ "SUPPORT_COST_USD": support_cost_usd,
562
+ "CAC_ALLOCATED_USD": 0.0,
563
+ "DSO_DAYS": dso_days,
564
+ "SEAT_COUNT": seat_count,
565
+ "USAGE_UNITS": usage_units,
566
+ "CUSTOMER_ACQUIRED_FLAG": customer_acquired_flag,
567
+ "CUSTOMER_CHURNED_FLAG": customer_churned_flag,
568
+ }
569
+ )
570
+ if customer_acquired_flag:
571
+ lookup_key = (mk, customer["segment"], customer["region"])
572
+ new_logo_lookup.setdefault(lookup_key, []).append(fact_key)
573
+
574
+ spend_bucket = spend_accumulator.setdefault(
575
+ (mk, customer["segment"], customer["region"]),
576
+ {
577
+ "new_customers": 0.0,
578
+ "new_logo_arr": 0.0,
579
+ "starting_arr": 0.0,
580
+ "expansion_arr": 0.0,
581
+ "contraction_arr": 0.0,
582
+ "churned_arr": 0.0,
583
+ },
584
+ )
585
+ spend_bucket["starting_arr"] += starting_arr
586
+ spend_bucket["expansion_arr"] += expansion_arr
587
+ spend_bucket["contraction_arr"] += contraction_arr
588
+ spend_bucket["churned_arr"] += churned_arr
589
+ if customer_acquired_flag:
590
+ spend_bucket["new_customers"] += 1.0
591
+ spend_bucket["new_logo_arr"] += new_logo_arr
592
+
593
+ fact_key += 1
594
+ prev_arr = ending_arr
595
+ deferred_balance_by_customer[customer["customer_key"]] = deferred_balance
596
+
597
+ if churn_idx is not None:
598
+ customer_status_by_key[customer["customer_key"]]["CUSTOMER_STATUS"] = "Churned"
599
+
600
+ spend_rows: list[dict[str, Any]] = []
601
+ spend_key = 1
602
+        cac_targets = {
+            "Enterprise": 32000.0,
+            "Mid-Market": 14500.0,
+            "SMB": 4200.0,
+        }
+        region_multipliers = {"North America": 1.0, "EMEA": 1.12, "APAC": 0.94}
+
+        for month_start in months:
+            mk = month_key(month_start)
+            for segment in ("Enterprise", "Mid-Market", "SMB"):
+                for region in ("North America", "EMEA", "APAC"):
+                    bucket = spend_accumulator.get(
+                        (mk, segment, region),
+                        {
+                            "new_customers": 0.0,
+                            "new_logo_arr": 0.0,
+                            "starting_arr": 0.0,
+                            "expansion_arr": 0.0,
+                            "contraction_arr": 0.0,
+                            "churned_arr": 0.0,
+                        },
+                    )
+                    new_customers = int(bucket["new_customers"])
+                    baseline_new_customers = max(new_customers, 1 if segment != "Enterprise" else 0)
+                    target_cac = cac_targets[segment] * region_multipliers[region]
+                    if region == "EMEA" and month_start.month in {5, 6, 7}:
+                        target_cac *= 1.08
+                    total_spend = round(baseline_new_customers * target_cac, 2)
+                    sales_spend = round(total_spend * 0.64, 2)
+                    marketing_spend = round(total_spend - sales_spend, 2)
+                    net_new_arr = round(
+                        bucket["new_logo_arr"] + bucket["expansion_arr"] - bucket["contraction_arr"] - bucket["churned_arr"],
+                        2,
+                    )
+                    starting_arr = float(bucket["starting_arr"] or 0.0)
+                    nrr_pct = round(
+                        ((starting_arr + bucket["expansion_arr"] - bucket["contraction_arr"] - bucket["churned_arr"]) / starting_arr) * 100.0,
+                        4,
+                    ) if starting_arr > 0 else 100.0
+                    grr_pct = round(
+                        ((starting_arr - bucket["contraction_arr"] - bucket["churned_arr"]) / starting_arr) * 100.0,
+                        4,
+                    ) if starting_arr > 0 else 100.0
+                    cac_usd = round(total_spend / new_customers, 2) if new_customers > 0 else 0.0
+                    spend_rows.append(
+                        {
+                            "SALES_MARKETING_SPEND_MONTHLY_KEY": spend_key,
+                            "MONTH_KEY": mk,
+                            "SEGMENT": segment,
+                            "REGION": region,
+                            "SALES_SPEND_USD": sales_spend,
+                            "MARKETING_SPEND_USD": marketing_spend,
+                            "TOTAL_S_AND_M_SPEND_USD": round(sales_spend + marketing_spend, 2),
+                            "NEW_CUSTOMERS_ACQUIRED": new_customers,
+                            "NEW_LOGO_ARR_USD": round(bucket["new_logo_arr"], 2),
+                            "NET_NEW_ARR_USD": net_new_arr,
+                            "NRR_PCT": nrr_pct,
+                            "GRR_PCT": grr_pct,
+                            "CAC_USD": cac_usd,
+                        }
+                    )
+                    if new_customers > 0:
+                        per_customer_cac = cac_usd
+                        for fact_row_key in new_logo_lookup.get((mk, segment, region), []):
+                            fact_rows[fact_row_key - 1]["CAC_ALLOCATED_USD"] = per_customer_cac
+                    spend_key += 1
+
+        return {
+            "DATES": dates_rows,
+            "CUSTOMERS": customers_rows,
+            "PRODUCTS": product_rows,
+            "LOCATIONS": location_rows,
+            "SAAS_CUSTOMER_MONTHLY": fact_rows,
+            "SALES_MARKETING_SPEND_MONTHLY": spend_rows,
+        }
 
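The NRR/GRR arithmetic above can be sanity-checked in isolation. A minimal sketch (the function name and parameters mirror the bucket fields but are illustrative, not part of the generator):

```python
def nrr_grr(starting: float, expansion: float, contraction: float, churned: float) -> tuple[float, float]:
    """Net and gross revenue retention, as percentages of starting ARR."""
    if starting <= 0:
        # No starting ARR: treat the cohort as fully retained, as the diff does.
        return 100.0, 100.0
    nrr = round((starting + expansion - contraction - churned) / starting * 100.0, 4)
    grr = round((starting - contraction - churned) / starting * 100.0, 4)
    return nrr, grr

print(nrr_grr(1_000_000.0, 120_000.0, 30_000.0, 40_000.0))  # (105.0, 93.0)
```

GRR ignores expansion by construction, so it can never exceed 100; NRR can.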
     def _register_pk_values(self, table: Table, rows: list[dict]) -> None:
+        """Register key-like values for FK lookups.
+
+        Parser output is not always perfect with table-level PK constraints, so
+        we also register identity/id-style columns that are likely FK targets.
+        """
+        key_columns: list[str] = []
+        if table.primary_key:
+            key_columns.append(table.primary_key)
+
+        for col in table.columns:
+            if col.is_identity and col.name not in key_columns:
+                key_columns.append(col.name)
+
+        if not key_columns:
+            for col in table.columns:
+                c = col.name.upper()
+                if c == "ID" or c.endswith("_ID") or c.endswith("ID"):
+                    key_columns.append(col.name)
+                    break
+
+        if not key_columns:
             return
+
+        for key_col in key_columns:
+            col_def = table.get_column(key_col)
+            if not col_def:
+                continue
+
+            key_values = []
+            for i, row in enumerate(rows):
+                key_val = row.get(key_col)
+
+                # Keep deterministic IDs available for FK assignment before DB insert.
+                if key_val is None and col_def.is_identity:
+                    key_val = i + 1
+                    row[key_col] = key_val
+
+                if key_val is not None:
+                    key_values.append(key_val)
+
+            if key_values:
+                self.fk_manager.register_pk_values(table.name, key_col, key_values)

     def _build_context(self) -> None:
         if self.cache_enabled and self.cache:

             if cached:
                 print(f"Using cached context for {self.url}")
                 self.context = CompanyContext.from_dict(cached)
+                self._refresh_story_spec()
+                print(f" Story seed: {self.story_spec.seed}")
                 return

         self.context = self.context_builder.build_context(self.url, self.use_case)
         print(f"Built context for: {self.context.company_name}")
         print(f" Industry: {self.context.industry}")
         print(f" Focus: {self.context.geographic_focus}")
+
+        self._refresh_story_spec()
+        print(f" Story seed: {self.story_spec.seed}")
+
         if self.cache_enabled and self.cache:
             self.cache.set_context(self.url, self.context.to_dict())

         if cached:
             print("Using cached column classifications")
             self.classifications = cached
+            self._build_semantic_types()
             return

         self.classifications = self.column_classifier.classify_schema(

             ai_gen = sum(1 for c in columns.values() if c['classification'] == 'AI_GEN')
             generic = sum(1 for c in columns.values() if c['classification'] == 'GENERIC')
             print(f" {table_name}: {search_real} SEARCH_REAL, {ai_gen} AI_GEN, {generic} GENERIC")
+
+        self._build_semantic_types()

         if self.cache_enabled and self.cache:
             self.cache.set_classification(schema_hash, self.classifications)
+
+    def _build_semantic_types(self) -> None:
+        """Build semantic type map for all table columns."""
+        self.semantic_types = {}
+        if not self.schema:
+            return
+        for table in self.schema.tables:
+            table_semantics: dict[str, SemanticType] = {}
+            for column in table.columns:
+                table_semantics[column.name] = infer_semantic_type(
+                    column.name,
+                    column.data_type,
+                    table.name,
+                    self._domain_use_case(),
+                )
+            self.semantic_types[table.name] = table_semantics
 
     def _generate_table_data(self, table: Table, num_rows: int) -> list[dict]:
         rows = []
         table_class = self.classifications.get(table.name, {})
+        table_semantics = self.semantic_types.get(table.name, {})
+        date_dim_profile = self._build_date_dimension_profile(table, table_semantics, num_rows)

         # Build set of FK column names
         fk_columns = {fk.column_name.upper() for fk in table.foreign_keys}

             col_class = table_class.get(column.name, {})
             classification = col_class.get('classification', 'GENERIC').upper()
             strategy = col_class.get('strategy', '')
+            semantic_type = table_semantics.get(column.name, SemanticType.UNKNOWN)

             # Skip identity/PK/FK columns
             if column.is_identity or column.is_primary_key:

                 coherent_columns.append({
                     'name': column.name,
                     'classification': classification,
+                    'strategy': strategy,
+                    'semantic_type': semantic_type,
                 })
                 print(f" [CLASSIFY] {column.name}: {classification} → coherent (AI-generated)")
+            elif is_business_categorical(semantic_type):
+                # Treat business categorical columns as coherent to avoid noisy per-row drift.
+                coherent_columns.append({
+                    'name': column.name,
+                    'classification': 'AI_GEN',
+                    'strategy': strategy or f"Generate realistic {column.name} values",
+                    'semantic_type': semantic_type,
+                })
+                print(f" [CLASSIFY] {column.name}: semantic={semantic_type.value} → coherent (domain-aware)")
             else:
                 generic_columns.append(column)
                 print(f" [CLASSIFY] {column.name}: {classification} strategy='{strategy}' → generic (Faker)")

         # For dimension tables with multiple coherent columns, generate as entities
         coherent_entities = []
+        domain_use_case = self._domain_use_case()
+        has_freeform_coherent = any(
+            not is_business_categorical(col.get("semantic_type", SemanticType.UNKNOWN))
+            for col in coherent_columns
+        )
+        if not table.is_fact_table and len(coherent_columns) >= 2 and has_freeform_coherent:
             print(f" Generating {num_rows} coherent entities for {table.name}...")
             coherent_entities = self.ai_generator.generate_entity_batch(
                 table_name=table.name,

             col_name = col_info['name']
             classification = col_info['classification']
             strategy = col_info['strategy']
+            semantic_type = col_info.get('semantic_type', SemanticType.UNKNOWN)

             if classification == 'SEARCH_REAL':
+                values = self._get_search_real_values(col_name, strategy, num_rows, semantic_type)
             else:  # AI_GEN
+                values = self._get_ai_gen_values(col_name, strategy, num_rows, semantic_type)
             column_values[col_name] = values

         # Convert to entity format

                 # GENERIC columns - use classification strategy if set, otherwise infer
                 col_class = table_class.get(column.name, {})
+                classification = col_class.get('classification', 'GENERIC').upper()
+                strategy = (col_class.get('strategy') or '').strip()
+                strategy_is_executable = self._is_executable_generic_strategy(strategy)
+                semantic_type = table_semantics.get(col_name, SemanticType.UNKNOWN)

                 # DEBUG: Log first row only
                 if i == 0:
+                    print(f" [GEN] {col_name}: class='{classification}', strategy='{strategy}', executable={strategy_is_executable}")

+                # Quality-first: strong semantic types override generic model hints.
+                semantic_strategy = self.generic_sourcer.get_strategy_for_semantic(semantic_type, domain_use_case)
+                story_domain_values = self._get_story_domain_values(semantic_type)
+                if semantic_strategy and (is_business_categorical(semantic_type) or semantic_type in {SemanticType.ORG_NAME, SemanticType.PRODUCT_NAME, SemanticType.BRAND_NAME, SemanticType.PERSON_NAME}):
+                    if story_domain_values:
+                        row[col_name] = self._choose_story_value(story_domain_values, i)
+                    else:
+                        row[col_name] = self.generic_sourcer.generate_value(semantic_strategy, expected_type=column.data_type)
+                elif semantic_strategy:
+                    row[col_name] = self.generic_sourcer.generate_value(semantic_strategy, expected_type=column.data_type)
+                # Use explicit GENERIC strategy whenever possible.
+                elif classification == 'GENERIC' and strategy_is_executable:
+                    row[col_name] = self.generic_sourcer.generate_value(strategy, expected_type=column.data_type)
                 else:
+                    # Fall back to semantic strategy first, then inferred strategy.
+                    inferred = self.generic_sourcer.get_strategy_for_semantic(semantic_type, domain_use_case)
+                    if not inferred:
+                        inferred = self._infer_strategy(column)
                     if i == 0:
                         print(f" [GEN] {col_name}: using inferred='{inferred}'")
+                    row[col_name] = self.generic_sourcer.generate_value(inferred, expected_type=column.data_type)
+
+            # If this table looks like a date dimension, force coherent calendar fields.
+            if date_dim_profile:
+                self._apply_date_dimension_rules(row, i, date_dim_profile)
+
+            # Apply row-level realism rules after base generation.
+            self._apply_row_business_rules(table, row)

             rows.append(row)
+
+        self._apply_storyspec_time_series(table, rows, table_semantics)
         return rows
 
+    def _get_search_real_values(
+        self,
+        column_name: str,
+        strategy: str,
+        num_needed: int,
+        semantic_type: SemanticType = SemanticType.UNKNOWN,
+    ) -> list[str]:
         cache_key = f"search:{column_name}:{strategy}"

         if self.cache_enabled and self.cache:

             use_case=self.use_case
         )

+        values = self._sanitize_generated_values(values, column_name, semantic_type, num_needed)
+
         if self.cache_enabled and self.cache and values:
             self.cache.set_generated_values(cache_key, values)

         return values

+    def _get_ai_gen_values(
+        self,
+        column_name: str,
+        strategy: str,
+        num_needed: int,
+        semantic_type: SemanticType = SemanticType.UNKNOWN,
+    ) -> list[str]:
         cache_key = f"aigen:{column_name}:{strategy}:{self.context.company_name}"

         if self.cache_enabled and self.cache:

             if not values:
                 break

+        all_values = self._sanitize_generated_values(all_values, column_name, semantic_type, num_needed)
+
         if self.cache_enabled and self.cache and all_values:
             self.cache.set_generated_values(cache_key, all_values)

         return all_values
+
+    def _sanitize_generated_values(
+        self,
+        values: list[Any],
+        column_name: str,
+        semantic_type: SemanticType,
+        num_needed: int,
+    ) -> list[Any]:
+        """Normalize generated values to avoid obvious categorical junk."""
+        if not values:
+            fallback = self._get_story_domain_values(semantic_type)
+            if fallback:
+                return [fallback[i % len(fallback)] for i in range(num_needed)]
+            if semantic_type == SemanticType.PERSON_NAME:
+                return [self.generic_sourcer.generate_value("name") for _ in range(num_needed)]
+            if semantic_type in {SemanticType.ORG_NAME, SemanticType.BRANCH_NAME}:
+                return [self.generic_sourcer.generate_value("company") for _ in range(num_needed)]
+            return values
+
+        domain_values = self._get_story_domain_values(semantic_type)
+        strict_domain_semantics = {
+            SemanticType.CATEGORY,
+            SemanticType.SECTOR_NAME,
+            SemanticType.SECTOR_CATEGORY,
+            SemanticType.SEGMENT,
+            SemanticType.CHANNEL,
+            SemanticType.FUND_STRATEGY,
+            SemanticType.INVESTOR_TYPE,
+            SemanticType.INVESTMENT_STAGE,
+            SemanticType.COVENANT_STATUS,
+            SemanticType.DEBT_PERFORMANCE_STATUS,
+            SemanticType.PRODUCT_NAME,
+            SemanticType.BRAND_NAME,
+            SemanticType.DEPARTMENT_NAME,
+            SemanticType.BRANCH_NAME,
+            SemanticType.ORG_NAME,
+        }
+
+        def _looks_like_person_name(value: str) -> bool:
+            parts = [p for p in value.replace(".", " ").split() if p]
+            if len(parts) < 2:
+                return False
+            return all(re.fullmatch(r"[A-Za-z][A-Za-z'\-]*", p) for p in parts[:2])
+
+        if domain_values:
+            allowed = {v.lower() for v in domain_values}
+            cleaned = []
+            for idx, value in enumerate(values):
+                sval = str(value).strip()
+                upper = sval.upper()
+                is_number_suffixed = any(ch.isdigit() for ch in sval) and sval.lower() not in allowed
+                looks_placeholder = (
+                    not sval
+                    or sval.lower().endswith("_value")
+                    or sval.lower() == f"{column_name.lower()}_value"
+                    or upper in {"UNKNOWN", "N/A", "NA", "NULL"}
+                    or bool(self._GARBAGE_TOKEN_RE.match(upper))
+                )
+                needs_domain_value = semantic_type in strict_domain_semantics and sval.lower() not in allowed
+                if semantic_type == SemanticType.PERSON_NAME and not _looks_like_person_name(sval):
+                    cleaned.append(self.generic_sourcer.generate_value("name"))
+                elif looks_placeholder or (is_business_categorical(semantic_type) and is_number_suffixed) or needs_domain_value:
+                    cleaned.append(domain_values[idx % len(domain_values)])
+                else:
+                    cleaned.append(sval)
+            values = cleaned
+        elif semantic_type == SemanticType.PERSON_NAME:
+            values = [
+                (str(v).strip() if _looks_like_person_name(str(v).strip()) else self.generic_sourcer.generate_value("name"))
+                for v in values
+            ]
+
+        if len(values) < num_needed:
+            # Repeat without appending numeric suffixes (prevents "Central 7" style artifacts).
+            values = [values[i % len(values)] for i in range(num_needed)]
+        return values[:num_needed]
+
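The `_looks_like_person_name` heuristic above can be exercised standalone. A minimal sketch of the same regex check (without the surrounding sanitizer):

```python
import re

def looks_like_person_name(value: str) -> bool:
    # The first two whitespace-separated parts must be alphabetic tokens
    # (apostrophes and hyphens allowed), e.g. "Mary O'Brien-Smith".
    parts = [p for p in value.replace(".", " ").split() if p]
    if len(parts) < 2:
        return False
    return all(re.fullmatch(r"[A-Za-z][A-Za-z'\-]*", p) is not None for p in parts[:2])

print(looks_like_person_name("Jane Doe"))    # True
print(looks_like_person_name("Central 7"))   # False (digit token)
print(looks_like_person_name("DRUG_NAME"))   # False (single token)
```

This is deliberately loose: it only rejects values that clearly are not two-word names, which is what the sanitizer needs to catch `fake.name()`-shaped junk in non-name columns and vice versa.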
+    def _get_story_domain_values(self, semantic_type: SemanticType) -> list[str]:
+        """Get domain values with StorySpec overrides when available."""
+        base_values = list(get_domain_values(semantic_type, self._domain_use_case()))
+        if not self.story_spec:
+            return base_values
+
+        override_key_map = {
+            SemanticType.CATEGORY: "category",
+            SemanticType.SEGMENT: "segment",
+            SemanticType.CHANNEL: "channel",
+            SemanticType.ORG_NAME: "org_name",
+            SemanticType.PRODUCT_NAME: "product_name",
+            SemanticType.BRAND_NAME: "brand_name",
+            SemanticType.COUNTRY: "country",
+            SemanticType.REGION: "region",
+            SemanticType.CITY: "city",
+            SemanticType.BRANCH_NAME: "org_name",
+            SemanticType.DEPARTMENT_NAME: "category",
+        }
+        override_key = override_key_map.get(semantic_type)
+        if not override_key:
+            return base_values
+
+        override_values = list((self.story_spec.domain_overrides or {}).get(override_key, []))
+        if not override_values:
+            return base_values
+
+        # Keep story overrides inside the curated domain when one exists.
+        # This prevents use-case drift like legal demos inheriting noisy geo/org
+        # values from broad research context.
+        if base_values:
+            base_set = {str(v).strip().lower() for v in base_values if str(v).strip()}
+            filtered_override_values = [
+                str(v).strip()
+                for v in override_values
+                if str(v).strip() and str(v).strip().lower() in base_set
+            ]
+            if filtered_override_values:
+                override_values = filtered_override_values
+            else:
+                override_values = []
+
+        # Keep deterministic order and avoid duplicates.
+        seen = set()
+        merged = []
+        for value in override_values + base_values:
+            v = str(value).strip()
+            if not v:
+                continue
+            if v.lower() in seen:
+                continue
+            seen.add(v.lower())
+            merged.append(v)
+        return merged
+
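The override/base merge keeps first-seen casing and deterministic order. A minimal standalone sketch (the helper name is illustrative):

```python
def merge_domain_values(override: list[str], base: list[str]) -> list[str]:
    """Case-insensitive, order-preserving union: override values first, then base."""
    seen: set[str] = set()
    merged: list[str] = []
    for value in override + base:
        v = str(value).strip()
        if not v or v.lower() in seen:
            continue
        seen.add(v.lower())
        merged.append(v)
    return merged

print(merge_domain_values(["EMEA", "APAC"], ["emea", "North America"]))
# ['EMEA', 'APAC', 'North America']
```

Because overrides come first, the StorySpec's spelling wins on case-insensitive collisions, and the output is stable across runs for the same inputs.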
+    def _choose_story_value(self, values: list[str], row_index: int) -> str:
+        if not values:
+            return ""
+        if len(values) == 1:
+            return values[0]
+        idx = (row_index + self._story_rng.randint(0, max(len(values) - 1, 0))) % len(values)
+        return values[idx]
+
+    def _apply_storyspec_time_series(
+        self,
+        table: Table,
+        rows: list[dict[str, Any]],
+        table_semantics: dict[str, SemanticType],
+    ) -> None:
+        """Apply smooth, deterministic trend profile plus bounded outliers."""
+        if not rows or not table.is_fact_table or not self.story_spec:
+            return
+
+        date_cols = [
+            c for c, s in table_semantics.items()
+            if s in {SemanticType.DATE_EVENT, SemanticType.DATE_START, SemanticType.DATE_END}
+        ]
+        if not date_cols:
+            return
+
+        date_col = None
+        for candidate in date_cols:
+            if any(self._to_python_date(r.get(candidate)) for r in rows):
+                date_col = candidate
+                break
+        if not date_col:
+            return
+
+        measure_cols = []
+        for col_name, semantic in table_semantics.items():
+            if semantic in {SemanticType.MONEY, SemanticType.QUANTITY, SemanticType.COUNT}:
+                measure_cols.append(col_name)
+
+        if not measure_cols:
+            for col_name in rows[0].keys():
+                low = col_name.lower()
+                if any(tok in low for tok in ("revenue", "sales", "amount", "quantity", "units", "count", "trx")):
+                    measure_cols.append(col_name)
+        measure_cols = list(dict.fromkeys(measure_cols))
+        if not measure_cols:
+            return
+
+        indexed_dates = []
+        for idx, row in enumerate(rows):
+            dt = self._to_python_date(row.get(date_col))
+            if dt:
+                indexed_dates.append((idx, dt))
+        if len(indexed_dates) < 4:
+            return
+
+        indexed_dates.sort(key=lambda x: x[1])
+        n = len(indexed_dates)
+
+        controls = self.story_spec
+        noise_band = max(0.0, float(controls.trend_profile.noise_band_pct))
+        seasonal_strength = max(0.0, float(controls.trend_profile.seasonal_strength))
+
+        if controls.trend_profile.style == "smooth":
+            slope = 0.18
+        elif controls.trend_profile.style == "flat":
+            slope = 0.06
+        else:
+            slope = 0.24
+
+        signals: dict[int, float] = {}
+        for rank, (idx, dt) in enumerate(indexed_dates):
+            progress = rank / max(n - 1, 1)
+            trend_factor = 1.0 + slope * progress
+            seasonal_factor = 1.0 + seasonal_strength * math.sin((2.0 * math.pi * (dt.month - 1)) / 12.0)
+            noise_factor = 1.0 + self._story_rng.uniform(-noise_band, noise_band)
+            signal = max(0.2, trend_factor * seasonal_factor * noise_factor)
+            signals[idx] = signal
+
+        # Controlled outlier injection (sparse and explainable).
+        budget = controls.outlier_budget
+        max_points = max(0, int(round(n * float(budget.max_points_pct))))
+        outlier_events = min(max_points, int(budget.max_events))
+
+        if outlier_events > 0:
+            candidate_pos = list(range(max(2, int(n * 0.1)), max(3, int(n * 0.9))))
+            self._story_rng.shuffle(candidate_pos)
+            for pos in candidate_pos[:outlier_events]:
+                idx, dt = indexed_dates[pos]
+                multiplier = self._story_rng.uniform(float(budget.event_multiplier_min), float(budget.event_multiplier_max))
+                signals[idx] = max(0.2, signals[idx] * multiplier)
+                self.story_events.append({
+                    "table": table.name,
+                    "date": dt.isoformat(),
+                    "row_index": idx,
+                    "kind": "intentional_outlier",
+                    "multiplier": round(multiplier, 4),
+                })
+
+        for col_name in measure_cols:
+            numeric_points = []
+            for idx, _ in indexed_dates:
+                value = rows[idx].get(col_name)
+                if isinstance(value, (int, float)) and not isinstance(value, bool):
+                    numeric_points.append(float(value))
+            if not numeric_points:
+                continue
+
+            numeric_points_sorted = sorted(numeric_points)
+            baseline = numeric_points_sorted[len(numeric_points_sorted) // 2]
+
+            for idx, _ in indexed_dates:
+                raw = rows[idx].get(col_name)
+                if not isinstance(raw, (int, float)) or isinstance(raw, bool):
+                    continue
+
+                low_name = col_name.lower()
+                signal = signals.get(idx, 1.0)
+
+                if any(tok in low_name for tok in ("quantity", "units", "count", "trx", "orders")):
+                    signal = 1.0 + (signal - 1.0) * 0.85
+
+                blended = (0.65 * float(raw)) + (0.35 * baseline)
+                adjusted = blended * signal
+
+                if not self.story_spec.value_guardrails.get("allow_negative_revenue", False):
+                    if any(tok in low_name for tok in ("revenue", "sales", "amount", "gross", "net")):
+                        adjusted = max(0.0, adjusted)
+
+                rows[idx][col_name] = int(round(adjusted)) if isinstance(raw, int) else round(adjusted, 2)
+
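The per-point signal composes trend, seasonality, and noise multiplicatively, clamped at 0.2. A standalone sketch with illustrative slope/strength/noise values (the real numbers come from the StorySpec trend profile and seeded RNG):

```python
import math
import random

rng = random.Random(7)  # deterministic, mirroring the seeded story RNG

def point_signal(rank: int, n: int, month: int,
                 slope: float = 0.18, seasonal_strength: float = 0.08,
                 noise_band: float = 0.03) -> float:
    """Trend * seasonality * bounded noise, floored at 0.2."""
    progress = rank / max(n - 1, 1)
    trend = 1.0 + slope * progress
    seasonal = 1.0 + seasonal_strength * math.sin(2.0 * math.pi * (month - 1) / 12.0)
    noise = 1.0 + rng.uniform(-noise_band, noise_band)
    return max(0.2, trend * seasonal * noise)

signals = [point_signal(r, 12, month=r + 1) for r in range(12)]
print(all(0.2 <= s <= 1.35 for s in signals))
```

With these parameters the factors are individually bounded (trend ≤ 1.18, seasonal within ±8%, noise within ±3%), so the product stays in a narrow, believable band while still drifting upward across the year.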
+    def _build_date_dimension_profile(
+        self,
+        table: Table,
+        table_semantics: dict[str, SemanticType],
+        num_rows: int,
+    ) -> dict[str, Any] | None:
+        """
+        Detect date-dimension tables and prepare deterministic calendar generation.
+        """
+        if num_rows <= 0:
+            return None
+
+        table_name_upper = table.name.upper()
+        has_full_date_col = any("FULL_DATE" == c.name.upper() or c.name.upper().endswith("_DATE") for c in table.columns)
+        has_calendar_attrs = any(
+            any(tok in c.name.upper() for tok in ("WEEKEND", "HOLIDAY", "SEASON", "EVENT"))
+            for c in table.columns
+        )
+        is_date_dimension = ("DATE" in table_name_upper and has_full_date_col) or (has_full_date_col and has_calendar_attrs)
+        if not is_date_dimension:
+            return None
+
+        today = datetime.now().date()
+        start_date = today - timedelta(days=max(num_rows - 1, 0))
+        return {
+            "start_date": start_date,
+            "date_key_columns": [c.name for c in table.columns if c.name.upper().endswith("DATE_KEY")],
+            "full_date_columns": [c.name for c in table.columns if "DATE" in c.name.upper()],
+            "month_name_columns": [c.name for c in table.columns if "MONTH" in c.name.upper() and "NAME" in c.name.upper()],
+            "quarter_name_columns": [c.name for c in table.columns if "QUARTER" in c.name.upper() and "NAME" in c.name.upper()],
+            "year_columns": [
+                c.name
+                for c in table.columns
+                if "YEAR" in c.name.upper() and "NAME" not in c.name.upper()
+            ],
+            "is_weekend_columns": [col_name for col_name in table_semantics.keys() if "WEEKEND" in col_name.upper()],
+            "is_holiday_columns": [c.name for c in table.columns if "HOLIDAY" in c.name.upper() and c.name.upper().startswith("IS_")],
+            "holiday_name_columns": [col_name for col_name, s in table_semantics.items() if s == SemanticType.HOLIDAY_NAME],
+            "season_columns": [col_name for col_name, s in table_semantics.items() if s == SemanticType.SEASON],
+            "event_columns": [col_name for col_name, s in table_semantics.items() if s == SemanticType.EVENT_NAME],
+        }
+
+    def _apply_date_dimension_rules(self, row: dict[str, Any], row_index: int, profile: dict[str, Any]) -> None:
+        """
+        Force coherent calendar rows (full date, weekend/holiday/season/event).
+        """
+        base_date = profile["start_date"] + timedelta(days=row_index)
+        month = base_date.month
+        weekday = base_date.weekday()  # 0=Mon, 6=Sun
+        is_weekend = weekday >= 5
+
+        # Lightweight US-holiday approximation.
+        fixed_holidays = {
+            (1, 1): "New Year's Day",
+            (7, 4): "Independence Day",
+            (12, 25): "Christmas Day",
+        }
+        holiday_name = fixed_holidays.get((month, base_date.day), "None")
+        is_holiday = holiday_name != "None"
+
+        def season_for_month(m: int) -> str:
+            if m in (12, 1, 2):
+                return "Winter"
+            if m in (3, 4, 5):
+                return "Spring"
+            if m in (6, 7, 8):
+                return "Summer"
+            return "Fall"
+
+        season_name = season_for_month(month)
+        if month == 11 and base_date.day >= 20:
+            event_name = "Black Friday Campaign"
+        elif month == 12:
+            event_name = "Holiday Promotions"
+        elif month in (8, 9):
+            event_name = "Back to School"
+        else:
+            event_name = "Regular Trading Day"
+
+        for col in profile.get("full_date_columns", []):
+            col_upper = col.upper()
+            if col_upper.endswith("DATE_KEY"):
+                row[col] = int(base_date.strftime("%Y%m%d"))
+            elif col_upper.endswith("_DATE") or col_upper == "FULL_DATE":
+                row[col] = base_date
+        for col in profile.get("date_key_columns", []):
+            row[col] = int(base_date.strftime("%Y%m%d"))
+        for col in profile.get("month_name_columns", []):
+            col_upper = col.upper()
+            if "ABBR" in col_upper or "SHORT" in col_upper:
+                row[col] = base_date.strftime("%b")
+            else:
+                row[col] = base_date.strftime("%B")
+        for col in profile.get("quarter_name_columns", []):
+            quarter = ((month - 1) // 3) + 1
+            row[col] = f"Q{quarter}"
+        for col in profile.get("year_columns", []):
+            row[col] = base_date.year
+        for col in profile.get("is_weekend_columns", []):
+            row[col] = is_weekend
+        for col in profile.get("is_holiday_columns", []):
+            row[col] = is_holiday
+        for col in profile.get("holiday_name_columns", []):
+            row[col] = holiday_name
+        for col in profile.get("season_columns", []):
+            row[col] = season_name
+        for col in profile.get("event_columns", []):
+            row[col] = event_name
+
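The calendar derivations above (surrogate date key, quarter label, weekend flag) can be checked in isolation:

```python
from datetime import date

d = date(2026, 3, 15)                      # a Sunday
date_key = int(d.strftime("%Y%m%d"))       # 20260315
quarter = f"Q{((d.month - 1) // 3) + 1}"   # "Q1"
is_weekend = d.weekday() >= 5              # True: weekday() is 6 for Sunday
print(date_key, quarter, is_weekend)
```

Encoding the key as `YYYYMMDD` keeps it sortable as an integer, which is why both the `*_DATE_KEY` columns and the fallback `date_key_columns` use the same `strftime` round-trip.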
+    def _is_executable_generic_strategy(self, strategy: str) -> bool:
+        """Return True if strategy can be executed by GenericSourcer directly."""
+        if not strategy or not isinstance(strategy, str):
+            return False
+        strategy_type = strategy.split(":", 1)[0].strip().lower()
+        return strategy_type in self.EXECUTABLE_GENERIC_STRATEGIES
+
+    def _to_python_date(self, value: Any) -> date | None:
+        """Best-effort conversion of generated date values to date objects."""
+        if isinstance(value, datetime):
+            return value.date()
+        if isinstance(value, date):
+            return value
+        if isinstance(value, str):
+            try:
+                return datetime.strptime(value[:10], "%Y-%m-%d").date()
+            except ValueError:
+                return None
+        return None
+
+    @staticmethod
+    def _years_ago(base_date: date, years: int) -> date:
+        """Return a date shifted back by full years (leap-safe)."""
+        try:
+            return base_date.replace(year=base_date.year - years)
+        except ValueError:
+            return base_date.replace(month=2, day=28, year=base_date.year - years)
+
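The only `ValueError` that `date.replace(year=...)` can raise here is Feb 29 landing in a non-leap year, which the fallback clamps to Feb 28. A standalone sketch:

```python
from datetime import date

def years_ago(base_date: date, years: int) -> date:
    """Shift back by whole years; clamp Feb 29 to Feb 28 in non-leap years."""
    try:
        return base_date.replace(year=base_date.year - years)
    except ValueError:
        return base_date.replace(month=2, day=28, year=base_date.year - years)

print(years_ago(date(2024, 2, 29), 1))  # 2023-02-28
print(years_ago(date(2024, 2, 29), 4))  # 2020-02-29 (2020 is a leap year)
```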
+    def _apply_story_seasonality_scale(self, base_multiplier: float) -> float:
+        if not self.story_spec:
+            return base_multiplier
+        strength = max(0.0, min(0.35, float(self.story_spec.trend_profile.seasonal_strength)))
+        # 0.10 is the legacy baseline amplitude used by earlier heuristics.
+        scale = strength / 0.10
+        return 1.0 + ((base_multiplier - 1.0) * scale)
+
+    def _resolve_seasonality_multiplier(self, event_date: date) -> float:
+        """Use-case aware seasonality to keep trend lines jumpy but believable."""
+        month = event_date.month
+        use_case_lower = (self.use_case or "").lower()
+
+        if any(k in use_case_lower for k in ("retail", "sales", "ecommerce", "merch")):
+            by_month = {
+                1: 0.86, 2: 0.90, 3: 0.95, 4: 0.98, 5: 1.00, 6: 1.03,
+                7: 1.05, 8: 1.07, 9: 1.02, 10: 1.08, 11: 1.28, 12: 1.40,
+            }
+            return self._apply_story_seasonality_scale(by_month.get(month, 1.0))
+
+        if "marketing" in use_case_lower:
+            by_month = {
+                1: 0.92, 2: 0.95, 3: 0.98, 4: 1.00, 5: 1.04, 6: 1.07,
+                7: 1.03, 8: 1.01, 9: 1.05, 10: 1.12, 11: 1.26, 12: 1.34,
+            }
+            return self._apply_story_seasonality_scale(by_month.get(month, 1.0))
+
+        if "supply" in use_case_lower or "inventory" in use_case_lower:
+            by_month = {
+                1: 0.96, 2: 0.97, 3: 0.99, 4: 1.00, 5: 1.01, 6: 1.03,
+                7: 1.04, 8: 1.02, 9: 1.01, 10: 1.06, 11: 1.14, 12: 1.18,
+            }
+            return self._apply_story_seasonality_scale(by_month.get(month, 1.0))
+
+        # Conservative default with mild Q4 lift.
+        return self._apply_story_seasonality_scale({10: 1.05, 11: 1.10, 12: 1.14}.get(month, 1.0))
+
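The scaling rebases a legacy multiplier's deviation from 1.0: with the 0.10 baseline amplitude, a `seasonal_strength` of 0.20 doubles every monthly deviation. A worked sketch (function name is illustrative):

```python
def scale_seasonality(base_multiplier: float, strength: float, baseline: float = 0.10) -> float:
    """Rescale a multiplier's deviation from 1.0 by strength/baseline."""
    scale = max(0.0, min(0.35, strength)) / baseline
    return 1.0 + (base_multiplier - 1.0) * scale

# December retail lift of 1.40 under different story strengths:
print(round(scale_seasonality(1.40, 0.20), 4))  # 1.8  (deviation doubled)
print(round(scale_seasonality(1.40, 0.05), 4))  # 1.2  (deviation halved)
```

A neutral month (multiplier 1.0) stays exactly 1.0 at any strength, so rescaling never introduces a trend, only amplifies or dampens the seasonal shape.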
+    def _apply_row_business_rules(self, table: Table, row: dict[str, Any]) -> None:
+        """
+        Apply lightweight realism constraints and correlations.
+        Keeps data plausible without overfitting to one schema.
+        """
+        if not row:
+            return
+
+        today = datetime.now().date()
+
+        # 1) Birthdate guardrails: enforce realistic adult ages (18-95).
+        for key, value in list(row.items()):
+            col_lower = key.lower()
+            is_birth_col = (
+                ("birth" in col_lower and "date" in col_lower) or
+                col_lower == "dob" or
+                col_lower.endswith("_dob")
+            )
+            if not is_birth_col or value is None:
+                continue
+
+            dob = self._to_python_date(value)
+            if not dob:
+                continue
+
+            min_dob = self._years_ago(today, 95)
+            max_dob = self._years_ago(today, 18)
+            fixed = min(max(dob, min_dob), max_dob)
+            row[key] = fixed
+
+        # 2) Keep non-birth dates from drifting into the future.
+        for key, value in list(row.items()):
+            col_lower = key.lower()
+            if "date" not in col_lower or "birth" in col_lower or col_lower == "dob" or col_lower.endswith("_dob"):
+                continue
+            dt_val = self._to_python_date(value)
+            if dt_val and dt_val > today:
+                row[key] = today
+
+        # 2b) Enforce simple start/end date ordering and LOS coherence.
+        start_key = next((k for k in row if any(tok in k.lower() for tok in ("start_date", "admission_date", "created_date", "begin_date"))), None)
+        end_key = next((k for k in row if any(tok in k.lower() for tok in ("end_date", "discharge_date", "closed_date", "resolved_date"))), None)
+        los_key = next((k for k in row if any(tok in k.lower() for tok in ("length_of_stay", "los_days", "stay_days"))), None)
+        if start_key and end_key:
+            start_dt = self._to_python_date(row.get(start_key))
+            end_dt = self._to_python_date(row.get(end_key))
+            if start_dt and end_dt and end_dt < start_dt:
+                if los_key and isinstance(row.get(los_key), (int, float)):
+                    days = max(0, int(round(float(row[los_key]))))
+                    row[end_key] = start_dt + timedelta(days=days)
+                else:
+                    row[end_key] = start_dt
+        if start_key and los_key and end_key:
+            start_dt = self._to_python_date(row.get(start_key))
+            if start_dt and isinstance(row.get(los_key), (int, float)):
+                days = max(0, int(round(float(row[los_key]))))
+                row[end_key] = start_dt + timedelta(days=days)
+
+        # 3) Seasonal adjustments for fact-table measures.
+        event_date = None
+        for key, value in row.items():
+            col_lower = key.lower()
+            if "date" in col_lower and "birth" not in col_lower and col_lower != "dob" and not col_lower.endswith("_dob"):
+                event_date = self._to_python_date(value)
+                if event_date:
+                    break
+
+        if table.is_fact_table and event_date:
+            seasonal = self._resolve_seasonality_multiplier(event_date) * self._story_rng.uniform(0.92, 1.08)
+            for key, value in list(row.items()):
+                if isinstance(value, bool) or not isinstance(value, (int, float)):
+                    continue
+                col_lower = key.lower()
+                if any(tok in col_lower for tok in ("revenue", "sales", "amount", "total", "gross", "net")):
+                    adjusted = max(0.0, float(value) * seasonal)
+                    row[key] = int(round(adjusted)) if isinstance(value, int) else round(adjusted, 2)
+                elif any(tok in col_lower for tok in ("quantity", "qty", "units", "count")):
+                    adjusted = max(1.0, float(value) * seasonal * self._story_rng.uniform(0.95, 1.05))
+                    row[key] = int(round(adjusted))
+
+        # 3b) Normalize state/region coherence within row when both exist.
+        state_key = next((k for k in row if "state" in k.lower() or "province" in k.lower()), None)
+        region_key = next((k for k in row if "region" in k.lower() or "territory" in k.lower()), None)
+        if state_key and region_key and row.get(state_key):
+            mapped = map_state_to_region(str(row.get(state_key)))
+            if mapped:
+                row[region_key] = mapped
+
+        # 4) Enforce simple arithmetic consistency when possible.
+        quantity_key = next((k for k in row if any(tok in k.lower() for tok in ("quantity", "qty", "units"))), None)
+        price_key = next((k for k in row if "unit_price" in k.lower()), None)
+        if price_key is None:
+            price_key = next((k for k in row if k.lower().endswith("price")), None)
+        revenue_key = next(
+            (
+                k for k in row
+                if any(tok in k.lower() for tok in ("total_revenue", "sales_amount", "gross_revenue", "net_revenue", "total_amount", "revenue"))
+            ),
+            None
+        )
+        discount_key = next((k for k in row if any(tok in k.lower() for tok in ("discount_pct", "discount_rate"))), None)
+
+        if quantity_key and price_key and revenue_key:
+            qty = row.get(quantity_key)
+            price = row.get(price_key)
+            revenue_val = row.get(revenue_key)
+            if isinstance(qty, (int, float)) and isinstance(price, (int, float)) and isinstance(revenue_val, (int, float)):
+                discount = 0.0
+                if discount_key and isinstance(row.get(discount_key), (int, float)):
+                    raw_discount = float(row[discount_key])
+                    discount = raw_discount if raw_discount <= 1 else raw_discount / 100.0
+                    discount = min(max(discount, 0.0), 0.8)
+
+                computed = max(0.0, float(qty) * float(price) * (1.0 - discount))
+                computed *= self._story_rng.uniform(0.97, 1.03)  # mild variance
+                row[revenue_key] = int(round(computed)) if isinstance(revenue_val, int) else round(computed, 2)
+
+        # 5) Interest/rate sanity.
+        for key, value in list(row.items()):
+            col_lower = key.lower()
1630
+ if not isinstance(value, (int, float)) or isinstance(value, bool):
1631
+ continue
1632
+ if "interest_rate" in col_lower or ("interest" in col_lower and "rate" in col_lower):
1633
+ row[key] = round(min(max(float(value), 0.0), 1.0), 4)
1634
+ elif "percent" in col_lower or col_lower.endswith("_pct"):
1635
+ pct = float(value)
1636
+ if 0 <= pct <= 1:
1637
+ pct = pct * 100.0
1638
+ row[key] = round(min(max(pct, 0.0), 100.0), 2)
1639
 
     def _infer_strategy(self, column: Column) -> str:
         col_lower = column.name.lower()
         data_type = column.data_type.upper() if column.data_type else ''
+        domain_use_case = self._domain_use_case()
+        semantic = infer_semantic_type(column.name, column.data_type, use_case=domain_use_case)
+
+        semantic_strategy = self.generic_sourcer.get_strategy_for_semantic(semantic, domain_use_case)
+        if semantic_strategy:
+            return semantic_strategy
 
         # Check if data type is numeric
         is_numeric = any(t in data_type for t in ('INT', 'NUMBER', 'NUMERIC', 'DECIMAL', 'BIGINT', 'SMALLINT'))
 
         # Date/timestamp columns - use current date for end range
         # Calculate dynamic date ranges relative to today
         today_dt = datetime.now()
         today = today_dt.strftime('%Y-%m-%d')
         three_years_ago = (today_dt - timedelta(days=3*365)).strftime('%Y-%m-%d')
         two_years_ago = (today_dt - timedelta(days=2*365)).strftime('%Y-%m-%d')
         five_years_ago = (today_dt - timedelta(days=5*365)).strftime('%Y-%m-%d')
 
+        # Birth dates must be handled before generic date matching.
+        is_birth_col = (
+            ('birth' in col_lower and 'date' in col_lower) or
+            col_lower == 'dob' or
+            col_lower.endswith('_dob')
+        )
+        if is_birth_col:
+            oldest = self._years_ago(today_dt.date(), 95).strftime('%Y-%m-%d')
+            youngest = self._years_ago(today_dt.date(), 18).strftime('%Y-%m-%d')
+            return f"date_between:{oldest},{youngest}"
+
         if 'date' in col_lower or 'DATE' in data_type or 'TIMESTAMP' in data_type:
             if 'created' in col_lower:
                 return f"date_between:{three_years_ago},{today}"
legitdata_project/legitdata/quality/__init__.py ADDED
@@ -0,0 +1,22 @@
+"""Quality validation and repair modules."""
+
+from .quality_spec import (
+    QualityIssue,
+    QualityReport,
+    QualitySpec,
+    QualityThresholds,
+    default_quality_spec,
+)
+from .validator import DataQualityValidator
+from .repair import DataQualityRepairEngine
+
+__all__ = [
+    "QualityIssue",
+    "QualityReport",
+    "QualitySpec",
+    "QualityThresholds",
+    "default_quality_spec",
+    "DataQualityValidator",
+    "DataQualityRepairEngine",
+]
+
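Among the exports, `DataQualityRepairEngine` (defined in `repair.py` below) includes a KPI-consistency pass that rewrites a revenue measure when it drifts too far from quantity times price. Its core identity can be sketched as a standalone function (names and the 25% default tolerance here restate the diff's logic; the helper itself is illustrative):

```python
def repair_revenue_identity(
    row: dict,
    qty_key: str,
    price_key: str,
    revenue_key: str,
    tolerance: float = 0.25,
) -> bool:
    """Rewrite revenue to qty * price when relative error exceeds tolerance; return True if changed."""
    qty, price, revenue = row.get(qty_key), row.get(price_key), row.get(revenue_key)
    if not all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in (qty, price, revenue)):
        return False
    expected = float(qty) * float(price)
    if expected <= 0:
        return False
    rel_err = abs(float(revenue) - expected) / max(expected, 1.0)
    if rel_err <= tolerance:
        return False
    # Preserve the column's type: ints stay ints, floats are rounded to cents.
    row[revenue_key] = int(round(expected)) if isinstance(revenue, int) else round(expected, 2)
    return True
```

A revenue of 200.0 against qty 10 and price 5.0 (expected 50.0) is a 300% relative error, so it gets snapped back; a value within tolerance is left untouched.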
legitdata_project/legitdata/quality/quality_spec.py ADDED
@@ -0,0 +1,149 @@
+"""Quality thresholds and report models."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any
+
+
+@dataclass
+class QualityThresholds:
+    max_categorical_junk: int = 0
+    max_fk_orphans: int = 0
+    max_temporal_violations: int = 0
+    max_numeric_violations: int = 0
+    min_semantic_pass_ratio: float = 0.95
+
+    # Story-quality thresholds
+    max_volatility_breaches: int = 0
+    min_smoothness_score: float = 0.70
+    min_outlier_explainability: float = 0.70
+    min_kpi_consistency: float = 0.80
+
+
+@dataclass
+class QualitySpec:
+    thresholds: QualityThresholds = field(default_factory=QualityThresholds)
+    max_repair_attempts: int = 2
+    sample_limit_per_issue: int = 12
+
+
+@dataclass
+class QualityIssue:
+    issue_type: str
+    table: str
+    column: str
+    severity: str
+    count: int
+    message: str
+    samples: list[Any] = field(default_factory=list)
+
+
+@dataclass
+class QualityReport:
+    passed: bool
+    generated_at: str
+    summary: dict[str, Any]
+    issues: list[QualityIssue] = field(default_factory=list)
+    repair_actions: list[str] = field(default_factory=list)
+
+    def to_dict(self) -> dict[str, Any]:
+        return {
+            "passed": self.passed,
+            "generated_at": self.generated_at,
+            "summary": self.summary,
+            "issues": [
+                {
+                    "issue_type": issue.issue_type,
+                    "table": issue.table,
+                    "column": issue.column,
+                    "severity": issue.severity,
+                    "count": issue.count,
+                    "message": issue.message,
+                    "samples": issue.samples,
+                }
+                for issue in self.issues
+            ],
+            "repair_actions": self.repair_actions,
+        }
+
+    def to_markdown(self) -> str:
+        lines = [
+            "# Data Quality Report",
+            "",
+            f"- Generated: {self.generated_at}",
+            f"- Passed: {'YES' if self.passed else 'NO'}",
+            f"- Semantic pass ratio: {self.summary.get('semantic_pass_ratio', 0):.3f}",
+            f"- Categorical junk: {self.summary.get('categorical_junk_count', 0)}",
+            f"- FK orphans: {self.summary.get('fk_orphan_count', 0)}",
+            f"- Temporal violations: {self.summary.get('temporal_violations', 0)}",
+            f"- Numeric violations: {self.summary.get('numeric_violations', 0)}",
+            f"- Volatility breaches: {self.summary.get('volatility_breaches', 0)}",
+            f"- Smoothness score: {self.summary.get('smoothness_score', 0):.3f}",
+            f"- Outlier explainability: {self.summary.get('outlier_explainability', 0):.3f}",
+            f"- KPI consistency: {self.summary.get('kpi_consistency', 0):.3f}",
+            "",
+        ]
+        if self.repair_actions:
+            lines.append("## Repair Actions")
+            lines.extend([f"- {a}" for a in self.repair_actions])
+            lines.append("")
+        if self.issues:
+            lines.append("## Issues")
+            for issue in self.issues:
+                lines.append(
+                    f"- [{issue.severity}] {issue.issue_type} {issue.table}.{issue.column} "
+                    f"count={issue.count} - {issue.message}"
+                )
+            lines.append("")
+        return "\n".join(lines)
+
+
+def default_quality_spec(story_controls: dict[str, Any] | None = None) -> QualitySpec:
+    spec = QualitySpec()
+    controls = story_controls or {}
+
+    quality_targets = controls.get("quality_targets", {}) if isinstance(controls, dict) else {}
+    min_sem = quality_targets.get("min_semantic_pass_ratio")
+    if min_sem is not None:
+        try:
+            spec.thresholds.min_semantic_pass_ratio = float(min_sem)
+        except (TypeError, ValueError):
+            pass
+
+    guardrails = controls.get("value_guardrails", {}) if isinstance(controls, dict) else {}
+    max_mom = guardrails.get("max_mom_change_pct")
+    if max_mom is not None:
+        # Allow a small number of breaches for sparse datasets.
+        try:
+            max_mom_f = float(max_mom)
+            if max_mom_f <= 20:
+                spec.thresholds.max_volatility_breaches = 0
+            elif max_mom_f <= 35:
+                spec.thresholds.max_volatility_breaches = 2
+            else:
+                spec.thresholds.max_volatility_breaches = 4
+        except (TypeError, ValueError):
+            pass
+
+    outlier_budget = controls.get("outlier_budget", {}) if isinstance(controls, dict) else {}
+    max_events = outlier_budget.get("max_events")
+    if max_events is not None:
+        try:
+            max_events_i = int(max_events)
+            if max_events_i <= 1:
+                spec.thresholds.min_outlier_explainability = 0.85
+            elif max_events_i <= 3:
+                spec.thresholds.min_outlier_explainability = 0.75
+            else:
+                spec.thresholds.min_outlier_explainability = 0.65
+        except (TypeError, ValueError):
+            pass
+
+    return spec
+
+
+def now_utc_iso() -> str:
+    return datetime.now(timezone.utc).isoformat()
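The story-control mapping in `default_quality_spec` above is tiered: tighter month-over-month guardrails tolerate fewer volatility breaches, and smaller outlier budgets demand higher explainability. The two tiers can be restated as small pure functions (a sketch mirroring the diff's thresholds; the function names are illustrative):

```python
def volatility_breach_budget(max_mom_change_pct: float) -> int:
    """Allowed volatility breaches for a given max month-over-month change guardrail."""
    if max_mom_change_pct <= 20:
        return 0
    if max_mom_change_pct <= 35:
        return 2
    return 4


def outlier_explainability_floor(max_events: int) -> float:
    """Minimum outlier explainability for a given outlier-event budget."""
    if max_events <= 1:
        return 0.85
    if max_events <= 3:
        return 0.75
    return 0.65
```

For example, a 30% guardrail permits 2 breaches, while a budget of a single outlier event raises the explainability floor to 0.85.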
legitdata_project/legitdata/quality/repair.py ADDED
@@ -0,0 +1,612 @@
+"""Targeted repair passes for generated data."""
+
+from __future__ import annotations
+
+import random
+from datetime import date, datetime, timedelta
+from typing import Any
+
+from ..ddl.models import Schema
+from ..domain import SemanticType, get_domain_values, infer_semantic_type, map_state_to_region
+
+
+def _to_datetime(value: Any) -> datetime | None:
+    if isinstance(value, datetime):
+        return value
+    if isinstance(value, date):
+        # Normalize plain dates so temporal repair can handle generator date objects.
+        return datetime.combine(value, datetime.min.time())
+    if isinstance(value, str):
+        try:
+            return datetime.strptime(value[:19], "%Y-%m-%d %H:%M:%S")
+        except ValueError:
+            try:
+                return datetime.strptime(value[:10], "%Y-%m-%d")
+            except ValueError:
+                return None
+    return None
+
+
+def _semantic_map(schema: Schema, use_case: str | None) -> dict[str, dict[str, SemanticType]]:
+    mapping: dict[str, dict[str, SemanticType]] = {}
+    for table in schema.tables:
+        table_map: dict[str, SemanticType] = {}
+        for col in table.columns:
+            table_map[col.name] = infer_semantic_type(col.name, col.data_type, table.name, use_case)
+        mapping[table.name] = table_map
+    return mapping
+
+
+class DataQualityRepairEngine:
+    """Applies bounded, deterministic repair operations."""
+
+    def repair_dataset(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+        use_case: str | None = None,
+        story_controls: dict[str, Any] | None = None,
+    ) -> tuple[dict[str, list[dict]], list[str]]:
+        semantic_map = _semantic_map(schema, use_case)
+        actions: list[str] = []
+
+        actions.extend(self._repair_categorical(schema, generated_data, semantic_map, use_case))
+        actions.extend(self._repair_numeric(schema, generated_data, semantic_map))
+        actions.extend(self._repair_temporal(schema, generated_data, semantic_map))
+        actions.extend(self._repair_geography(schema, generated_data, semantic_map))
+        actions.extend(self._repair_fk_orphans(schema, generated_data))
+        actions.extend(self._repair_series_volatility(schema, generated_data, semantic_map, story_controls or {}))
+        actions.extend(self._repair_kpi_consistency(schema, generated_data, semantic_map))
+        actions.extend(self._repair_private_equity_financials(schema, generated_data))
+        actions.extend(self._repair_saas_finance_contract(generated_data))
+
+        return generated_data, actions
+
+    def _repair_saas_finance_contract(
+        self,
+        generated_data: dict[str, list[dict]],
+    ) -> list[str]:
+        actions: list[str] = []
+        fact_rows = generated_data.get("SAAS_CUSTOMER_MONTHLY", [])
+        spend_rows = generated_data.get("SALES_MARKETING_SPEND_MONTHLY", [])
+        if not fact_rows and not spend_rows:
+            return actions
+
+        fact_changes = 0
+        for row in fact_rows:
+            starting_arr = float(row.get("STARTING_ARR_USD") or 0.0)
+            new_logo_arr = float(row.get("NEW_LOGO_ARR_USD") or 0.0)
+            expansion_arr = float(row.get("EXPANSION_ARR_USD") or 0.0)
+            contraction_arr = float(row.get("CONTRACTION_ARR_USD") or 0.0)
+            churned_arr = float(row.get("CHURNED_ARR_USD") or 0.0)
+            ending_arr = round(max(starting_arr + new_logo_arr + expansion_arr - contraction_arr - churned_arr, 0.0), 2)
+            mrr_usd = round(ending_arr / 12.0, 2)
+            gm_pct = round(max(45.0, min(92.0, float(row.get("GROSS_MARGIN_PCT") or 78.0))), 2)
+            gross_margin_usd = round(mrr_usd * (gm_pct / 100.0), 2)
+            support_cost_usd = round(max(mrr_usd - gross_margin_usd, mrr_usd * 0.06), 2)
+            acquired_flag = bool(row.get("CUSTOMER_ACQUIRED_FLAG")) and new_logo_arr > 0.0
+            churned_flag = bool(row.get("CUSTOMER_CHURNED_FLAG")) and ending_arr == 0.0 and churned_arr > 0.0
+
+            updates = {
+                "ENDING_ARR_USD": ending_arr,
+                "MRR_USD": mrr_usd,
+                "GROSS_MARGIN_PCT": gm_pct,
+                "GROSS_MARGIN_USD": gross_margin_usd,
+                "SUPPORT_COST_USD": support_cost_usd,
+                "CUSTOMER_ACQUIRED_FLAG": acquired_flag,
+                "CUSTOMER_CHURNED_FLAG": churned_flag,
+            }
+            for key, new_value in updates.items():
+                if row.get(key) != new_value:
+                    row[key] = new_value
+                    fact_changes += 1
+
+        spend_changes = 0
+        for row in spend_rows:
+            sales = round(float(row.get("SALES_SPEND_USD") or 0.0), 2)
+            marketing = round(float(row.get("MARKETING_SPEND_USD") or 0.0), 2)
+            total = round(sales + marketing, 2)
+            if row.get("TOTAL_S_AND_M_SPEND_USD") != total:
+                row["TOTAL_S_AND_M_SPEND_USD"] = total
+                spend_changes += 1
+            new_customers = int(row.get("NEW_CUSTOMERS_ACQUIRED") or 0)
+            if new_customers <= 0:
+                new_logo_arr = round(max(float(row.get("NEW_LOGO_ARR_USD") or 0.0), 0.0), 2)
+                if row.get("NEW_LOGO_ARR_USD") != new_logo_arr:
+                    row["NEW_LOGO_ARR_USD"] = new_logo_arr
+                    spend_changes += 1
+
+        if fact_changes:
+            actions.append(f"SAAS_CUSTOMER_MONTHLY: recalculated {fact_changes} SaaS finance metric values")
+        if spend_changes:
+            actions.append(f"SALES_MARKETING_SPEND_MONTHLY: repaired {spend_changes} spend identity values")
+        return actions
+
+    def _repair_categorical(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+        semantic_map: dict[str, dict[str, SemanticType]],
+        use_case: str | None,
+    ) -> list[str]:
+        actions: list[str] = []
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+            for col_name, semantic in semantic_map.get(table.name, {}).items():
+                if semantic not in {
+                    SemanticType.REGION,
+                    SemanticType.CATEGORY,
+                    SemanticType.SECTOR_NAME,
+                    SemanticType.SECTOR_CATEGORY,
+                    SemanticType.SEGMENT,
+                    SemanticType.TIER,
+                    SemanticType.STATUS,
+                    SemanticType.FUND_STRATEGY,
+                    SemanticType.INVESTOR_TYPE,
+                    SemanticType.INVESTMENT_STAGE,
+                    SemanticType.COVENANT_STATUS,
+                    SemanticType.DEBT_PERFORMANCE_STATUS,
+                    SemanticType.CHANNEL,
+                    SemanticType.COUNTRY,
+                    SemanticType.STATE,
+                    SemanticType.CITY,
+                    SemanticType.POSTAL_CODE,
+                    SemanticType.PRODUCT_NAME,
+                    SemanticType.BRAND_NAME,
+                    SemanticType.DEPARTMENT_NAME,
+                    SemanticType.BRANCH_NAME,
+                    SemanticType.ORG_NAME,
+                    SemanticType.PERSON_NAME,
+                }:
+                    continue
+                domain_values = get_domain_values(semantic, use_case)
+                if not domain_values:
+                    continue
+                strict_domain_semantics = {
+                    SemanticType.CATEGORY,
+                    SemanticType.SECTOR_NAME,
+                    SemanticType.SECTOR_CATEGORY,
+                    SemanticType.SEGMENT,
+                    SemanticType.CHANNEL,
+                    SemanticType.FUND_STRATEGY,
+                    SemanticType.INVESTOR_TYPE,
+                    SemanticType.INVESTMENT_STAGE,
+                    SemanticType.COVENANT_STATUS,
+                    SemanticType.DEBT_PERFORMANCE_STATUS,
+                    SemanticType.PRODUCT_NAME,
+                    SemanticType.BRAND_NAME,
+                    SemanticType.DEPARTMENT_NAME,
+                    SemanticType.BRANCH_NAME,
+                    SemanticType.ORG_NAME,
+                }
+                allowed = {d.lower() for d in domain_values}
+                changed = 0
+                for idx, row in enumerate(rows):
+                    current = row.get(col_name)
+                    if current is None:
+                        row[col_name] = domain_values[idx % len(domain_values)]
+                        changed += 1
+                        continue
+                    sval = str(current).strip().lower()
+                    needs_repair = (
+                        (any(ch.isdigit() for ch in sval) and sval not in allowed)
+                        or (semantic in strict_domain_semantics and sval not in allowed)
+                    )
+                    if needs_repair:
+                        row[col_name] = domain_values[idx % len(domain_values)]
+                        changed += 1
+                if changed:
+                    actions.append(f"{table.name}.{col_name}: normalized {changed} categorical values")
+        return actions
+
+    def _repair_numeric(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+        semantic_map: dict[str, dict[str, SemanticType]],
+    ) -> list[str]:
+        actions: list[str] = []
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+            for col_name, semantic in semantic_map.get(table.name, {}).items():
+                changed = 0
+                for row in rows:
+                    value = row.get(col_name)
+                    if isinstance(value, bool) or not isinstance(value, (int, float)):
+                        continue
+                    v = float(value)
+                    fixed = value
+                    if semantic == SemanticType.INTEREST_RATE:
+                        fixed = min(max(v, 0.0), 1.0)
+                    elif semantic == SemanticType.RETURN_RATE:
+                        fixed = min(max(v, -0.25), 0.75)
+                    elif semantic == SemanticType.RETURN_MULTIPLE:
+                        fixed = min(max(v, 0.0), 10.0)
+                    elif semantic == SemanticType.BASIS_POINTS:
+                        fixed = min(max(v, -1000.0), 5000.0)
+                    elif semantic == SemanticType.LEVERAGE_RATIO:
+                        fixed = min(max(v, 0.0), 20.0)
+                    elif semantic == SemanticType.PERCENT:
+                        if 0 <= v <= 1:
+                            fixed = v * 100.0
+                        fixed = min(max(float(fixed), 0.0), 100.0)
+                    elif semantic == SemanticType.DURATION_DAYS:
+                        fixed = min(max(v, 0.0), 3650.0)
+                    elif semantic == SemanticType.QUANTITY:
+                        fixed = min(max(v, 0.0), 1_000_000.0)
+                    elif semantic == SemanticType.COUNT:
+                        fixed = min(max(v, 0.0), 100_000_000.0)
+                    elif semantic == SemanticType.SCORE:
+                        fixed = min(max(v, 0.0), 1000.0)
+                    elif semantic == SemanticType.MONEY:
+                        fixed = min(max(v, -100_000_000.0), 100_000_000.0)
+
+                    if fixed != value:
+                        row[col_name] = int(round(fixed)) if isinstance(value, int) else round(float(fixed), 4)
+                        changed += 1
+                if changed:
+                    actions.append(f"{table.name}.{col_name}: clamped {changed} numeric values")
+        return actions
+
+    def _repair_temporal(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+        semantic_map: dict[str, dict[str, SemanticType]],
+    ) -> list[str]:
+        actions: list[str] = []
+        today = datetime.now()
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+            table_sem = semantic_map.get(table.name, {})
+            date_starts = [c for c, s in table_sem.items() if s == SemanticType.DATE_START]
+            date_ends = [c for c, s in table_sem.items() if s == SemanticType.DATE_END]
+            date_births = [c for c, s in table_sem.items() if s == SemanticType.DATE_BIRTH]
+            event_dates = [c for c, s in table_sem.items() if s == SemanticType.DATE_EVENT]
+            changed = 0
+            for row in rows:
+                for col in date_births:
+                    dt = _to_datetime(row.get(col))
+                    if not dt:
+                        continue
+                    age = today.year - dt.year
+                    if dt > today or age < 18:
+                        row[col] = (today - timedelta(days=random.randint(18 * 365, 70 * 365))).date()
+                        changed += 1
+                for col in event_dates:
+                    dt = _to_datetime(row.get(col))
+                    if dt and dt > today:
+                        row[col] = today.date()
+                        changed += 1
+                for start_col in date_starts:
+                    start_dt = _to_datetime(row.get(start_col))
+                    if not start_dt:
+                        continue
+                    if start_dt > today:
+                        row[start_col] = today.date()
+                        start_dt = _to_datetime(row.get(start_col)) or start_dt
+                        changed += 1
+                    for end_col in date_ends:
+                        end_dt = _to_datetime(row.get(end_col))
+                        if not end_dt:
+                            continue
+                        if end_dt > today:
+                            row[end_col] = today.date()
+                            end_dt = _to_datetime(row.get(end_col)) or end_dt
+                            changed += 1
+                        if end_dt < start_dt:
+                            fixed_end = start_dt + timedelta(days=random.randint(0, 14))
+                            if fixed_end > today:
+                                fixed_end = today
+                            row[end_col] = fixed_end.date()
+                            changed += 1
+            if changed:
+                actions.append(f"{table.name}: repaired {changed} temporal violations")
+        return actions
+
+    def _repair_geography(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+        semantic_map: dict[str, dict[str, SemanticType]],
+    ) -> list[str]:
+        actions: list[str] = []
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+            table_sem = semantic_map.get(table.name, {})
+            state_cols = [c for c, s in table_sem.items() if s == SemanticType.STATE]
+            region_cols = [c for c, s in table_sem.items() if s == SemanticType.REGION]
+            if not state_cols or not region_cols:
+                continue
+            changed = 0
+            for row in rows:
+                state_val = row.get(state_cols[0])
+                if not state_val:
+                    continue
+                expected_region = map_state_to_region(str(state_val))
+                if not expected_region:
+                    continue
+                for region_col in region_cols:
+                    if str(row.get(region_col, "")).strip().lower() != expected_region.lower():
+                        row[region_col] = expected_region
+                        changed += 1
+            if changed:
+                actions.append(f"{table.name}: fixed {changed} state/region mismatches")
+        return actions
+
+    def _repair_fk_orphans(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+    ) -> list[str]:
+        actions: list[str] = []
+        parent_values: dict[tuple[str, str], list[Any]] = {}
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            for col in table.columns:
+                vals = [row.get(col.name) for row in rows if row.get(col.name) is not None]
+                parent_values[(table.name.upper(), col.name.upper())] = vals
+
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+            for fk in table.foreign_keys:
+                allowed = parent_values.get((fk.references_table.upper(), fk.references_column.upper()), [])
+                if not allowed:
+                    continue
+                changed = 0
+                for row in rows:
+                    value = row.get(fk.column_name)
+                    if value is not None and value not in allowed:
+                        row[fk.column_name] = random.choice(allowed)
+                        changed += 1
+                if changed:
+                    actions.append(
+                        f"{table.name}.{fk.column_name}: remapped {changed} orphan FK values "
+                        f"to {fk.references_table}.{fk.references_column}"
+                    )
+        return actions
+
+    def _repair_series_volatility(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+        semantic_map: dict[str, dict[str, SemanticType]],
+        story_controls: dict[str, Any],
+    ) -> list[str]:
+        actions: list[str] = []
+        guardrails = story_controls.get("value_guardrails", {}) if isinstance(story_controls, dict) else {}
+        max_mom = float(guardrails.get("max_mom_change_pct", 35.0)) / 100.0
+        # Smoothness scoring requires median change well below the raw guardrail.
+        # Use a tighter internal repair band so repaired series can pass both
+        # volatility and smoothness checks.
+        target_mom = max(0.05, max_mom * 0.35)
+
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+
+            table_sem = semantic_map.get(table.name, {})
+            date_cols = [c for c, s in table_sem.items() if s in {SemanticType.DATE_EVENT, SemanticType.DATE_START, SemanticType.DATE_END}]
+            measure_cols = [c for c, s in table_sem.items() if s in {SemanticType.MONEY, SemanticType.QUANTITY, SemanticType.COUNT}]
+            if not date_cols or not measure_cols:
+                continue
+
+            date_col = date_cols[0]
+            dated_rows = []
+            for row in rows:
+                dval = _to_datetime(row.get(date_col))
+                if dval:
+                    dated_rows.append((dval, row))
+            dated_rows.sort(key=lambda x: x[0])
+            if len(dated_rows) < 4:
+                continue
+
+            # Work at month granularity so repairs target trend volatility
+            # rather than row-level transaction noise.
+            month_rows: dict[str, list[dict[str, Any]]] = {}
+            for dval, row in dated_rows:
+                month_key = f"{dval.year:04d}-{dval.month:02d}"
+                month_rows.setdefault(month_key, []).append(row)
+
+            changed = 0
+            for col in measure_cols:
+                month_means: list[tuple[str, float]] = []
+                for month_key in sorted(month_rows.keys()):
+                    vals = [
+                        float(r[col])
+                        for r in month_rows[month_key]
+                        if isinstance(r.get(col), (int, float)) and not isinstance(r.get(col), bool)
+                    ]
+                    if vals:
+                        month_means.append((month_key, sum(vals) / len(vals)))
+                if len(month_means) < 2:
+                    continue
+
+                prev_target = None
+                for month_key, month_mean in month_means:
+                    target = month_mean
+                    if prev_target is not None and prev_target > 0:
+                        low = max(0.0, prev_target * (1.0 - target_mom))
+                        high = prev_target * (1.0 + target_mom)
+                        if target < low:
+                            target = low
+                        elif target > high:
+                            target = high
+
+                    if month_mean > 0 and abs(target - month_mean) > 1e-9:
+                        scale = target / month_mean
+                        for row in month_rows.get(month_key, []):
+                            value = row.get(col)
+                            if not isinstance(value, (int, float)) or isinstance(value, bool):
+                                continue
+                            new_value = float(value) * scale
+                            row[col] = int(round(new_value)) if isinstance(value, int) else round(new_value, 2)
+                            changed += 1
+                    prev_target = target
+
+            if changed:
+                actions.append(f"{table.name}: smoothed {changed} volatility points")
+
+        return actions
+
+    def _repair_kpi_consistency(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+        semantic_map: dict[str, dict[str, SemanticType]],
+    ) -> list[str]:
+        actions: list[str] = []
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+
+            cols = list(semantic_map.get(table.name, {}).keys())
+            qty_key = next((k for k in cols if any(t in k.lower() for t in ("quantity", "qty", "units"))), None)
+            price_key = next((k for k in cols if "price" in k.lower()), None)
+            revenue_key = next((k for k in cols if any(t in k.lower() for t in ("revenue", "sales_amount", "total_amount", "gross", "net"))), None)
+            if not qty_key or not price_key or not revenue_key:
+                continue
+
+            changed = 0
+            for row in rows:
+                qty = row.get(qty_key)
+                price = row.get(price_key)
+                revenue = row.get(revenue_key)
+                if not isinstance(qty, (int, float)) or not isinstance(price, (int, float)) or not isinstance(revenue, (int, float)):
+                    continue
+                expected = float(qty) * float(price)
+                if expected <= 0:
+                    continue
+                rel_err = abs(float(revenue) - expected) / max(expected, 1.0)
+                if rel_err > 0.25:
+                    row[revenue_key] = int(round(expected)) if isinstance(revenue, int) else round(expected, 2)
+                    changed += 1
+
+            if changed:
+                actions.append(f"{table.name}: repaired {changed} KPI consistency rows")
+
+        return actions
+
+    def _repair_private_equity_financials(
+        self,
+        schema: Schema,
+        generated_data: dict[str, list[dict]],
+    ) -> list[str]:
+        actions: list[str] = []
+        for table in schema.tables:
+            rows = generated_data.get(table.name, [])
+            if not rows:
+                continue
+
+            col_map = {col.name.lower(): col.name for col in table.columns}
+            total_value_col = col_map.get("total_value_usd") or col_map.get("total_value")
+            reported_value_col = col_map.get("reported_value_usd") or col_map.get("reported_value")
+            distributions_col = col_map.get("distributions_usd") or col_map.get("distributions")
+            gross_irr_col = col_map.get("gross_irr")
+            net_irr_col = col_map.get("net_irr")
+            gross_without_subline_col = col_map.get("gross_irr_without_sub_line")
+            irr_impact_bps_col = col_map.get("irr_sub_line_impact_bps")
+            total_return_multiple_col = col_map.get("total_return_multiple")
+            dpi_multiple_col = col_map.get("dpi_multiple")
+            rvpi_multiple_col = col_map.get("rvpi_multiple")
+            revenue_col = col_map.get("revenue_usd") or col_map.get("revenue")
+            ebitda_col = col_map.get("ebitda_usd") or col_map.get("ebitda")
+            ebitda_margin_col = col_map.get("ebitda_margin_pct")
+            net_debt_col = col_map.get("net_debt_usd") or col_map.get("net_debt")
+            debt_to_ebitda_col = col_map.get("debt_to_ebitda_ratio")
+
+            changed = 0
+            for row in rows:
+                if (
+                    total_value_col
+                    and reported_value_col
+                    and distributions_col
+                    and isinstance(row.get(reported_value_col), (int, float))
+                    and isinstance(row.get(distributions_col), (int, float))
+                ):
+                    computed_total = float(row[reported_value_col]) + float(row[distributions_col])
+                    current_total = row.get(total_value_col)
+                    if not isinstance(current_total, (int, float)) or abs(float(current_total) - computed_total) > 0.01:
+                        row[total_value_col] = round(computed_total, 2)
+                        changed += 1
+
+                if (
+                    gross_irr_col
+                    and net_irr_col
+                    and isinstance(row.get(gross_irr_col), (int, float))
+                    and isinstance(row.get(net_irr_col), (int, float))
+                    and float(row[net_irr_col]) > float(row[gross_irr_col])
+                ):
+                    row[net_irr_col] = round(float(row[gross_irr_col]) - 0.01, 4)
+                    changed += 1
+
+                if (
+                    gross_irr_col
+                    and gross_without_subline_col
+                    and irr_impact_bps_col
+                    and isinstance(row.get(gross_irr_col), (int, float))
+                    and isinstance(row.get(gross_without_subline_col), (int, float))
+                ):
+                    impact_bps = round((float(row[gross_irr_col]) - float(row[gross_without_subline_col])) * 10000, 0)
+                    current_bps = row.get(irr_impact_bps_col)
+                    if not isinstance(current_bps, (int, float)) or abs(float(current_bps) - impact_bps) > 1:
+                        row[irr_impact_bps_col] = int(impact_bps)
+                        changed += 1
+
+                if (
+                    total_return_multiple_col
+                    and dpi_multiple_col
+                    and rvpi_multiple_col
+                    and isinstance(row.get(dpi_multiple_col), (int, float))
+                    and isinstance(row.get(rvpi_multiple_col), (int, float))
+                ):
+                    tvpi = float(row[dpi_multiple_col]) + float(row[rvpi_multiple_col])
+                    current_tvpi = row.get(total_return_multiple_col)
+                    if not isinstance(current_tvpi, (int, float)) or abs(float(current_tvpi) - tvpi) > 0.01:
+                        row[total_return_multiple_col] = round(tvpi, 3)
+                        changed += 1
+
+                if (
+                    revenue_col
+                    and ebitda_col
+                    and ebitda_margin_col
+                    and isinstance(row.get(revenue_col), (int, float))
+                    and isinstance(row.get(ebitda_col), (int, float))
+                    and float(row[revenue_col]) > 0
+                ):
+                    margin_pct = (float(row[ebitda_col]) / float(row[revenue_col])) * 100.0
+                    current_margin = row.get(ebitda_margin_col)
+                    if not isinstance(current_margin, (int, float)) or abs(float(current_margin) - margin_pct) > 0.25:
+                        row[ebitda_margin_col] = round(margin_pct, 2)
+                        changed += 1
+
+                if (
+                    net_debt_col
+                    and ebitda_col
+ and ebitda_col
598
+ and debt_to_ebitda_col
599
+ and isinstance(row.get(net_debt_col), (int, float))
600
+ and isinstance(row.get(ebitda_col), (int, float))
601
+ and float(row[ebitda_col]) > 0
602
+ ):
603
+ leverage = float(row[net_debt_col]) / float(row[ebitda_col])
604
+ current_leverage = row.get(debt_to_ebitda_col)
605
+ if not isinstance(current_leverage, (int, float)) or abs(float(current_leverage) - leverage) > 0.05:
606
+ row[debt_to_ebitda_col] = round(leverage, 3)
607
+ changed += 1
608
+
609
+ if changed:
610
+ actions.append(f"{table.name}: repaired {changed} private-equity financial identity values")
611
+
612
+ return actions
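The repair pass above enforces a handful of per-row accounting identities (total value, net vs. gross IRR, TVPI, EBITDA margin, leverage). A minimal sketch of three of those repairs, using hypothetical lowercase column names in place of the `col_map` lookups:

```python
# Sketch of three of the private-equity identity repairs above, applied to a
# single hypothetical row. Column names are illustrative, not the real schema.
def repair_pe_row(row: dict) -> int:
    changed = 0
    # Identity 1: TOTAL_VALUE = REPORTED_VALUE + DISTRIBUTIONS
    computed_total = row["reported_value_usd"] + row["distributions_usd"]
    if abs(row["total_value_usd"] - computed_total) > 0.01:
        row["total_value_usd"] = round(computed_total, 2)
        changed += 1
    # Identity 2: net IRR may never exceed gross IRR; clamp just below it.
    if row["net_irr"] > row["gross_irr"]:
        row["net_irr"] = round(row["gross_irr"] - 0.01, 4)
        changed += 1
    # Identity 3: TVPI (total return multiple) = DPI + RVPI
    tvpi = row["dpi_multiple"] + row["rvpi_multiple"]
    if abs(row["total_return_multiple"] - tvpi) > 0.01:
        row["total_return_multiple"] = round(tvpi, 3)
        changed += 1
    return changed

row = {
    "reported_value_usd": 8_000_000.0, "distributions_usd": 2_500_000.0,
    "total_value_usd": 9_000_000.0,      # off by 1.5M -> gets repaired
    "gross_irr": 0.18, "net_irr": 0.21,  # net > gross -> gets clamped
    "dpi_multiple": 0.4, "rvpi_multiple": 1.1,
    "total_return_multiple": 1.5,        # already equals DPI + RVPI
}
n = repair_pe_row(row)
print(n)  # 2
```

As in the method above, values that already satisfy an identity within tolerance are left untouched, so the repair is idempotent.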
legitdata_project/legitdata/quality/validator.py ADDED
@@ -0,0 +1,935 @@
1
+ """Post-generation validator for demo data quality."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import re
7
+ from collections import defaultdict
8
+ from datetime import date, datetime
9
+ from typing import Any
10
+
11
+ from ..ddl.models import Schema
12
+ from ..domain import (
13
+ SemanticType,
14
+ get_domain_values,
15
+ infer_semantic_type,
16
+ is_business_categorical,
17
+ map_state_to_region,
18
+ )
19
+ from .quality_spec import QualityIssue, QualityReport, QualitySpec, now_utc_iso
20
+
21
+
22
+ JUNK_PATTERN = re.compile(
23
+ r"^(?:region|category|segment|tier|status)\s*\d+$|^[A-Za-z]+\s+\d+$",
24
+ re.IGNORECASE,
25
+ )
26
+ PERSON_NAME_PATTERN = re.compile(r"^[A-Za-z][A-Za-z'\-]*\s+[A-Za-z][A-Za-z'\-]*(?:\s+[A-Za-z][A-Za-z'\-]*)*$")
27
+
28
+
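The two module-level patterns above drive the categorical-junk and person-name checks. A quick demonstration of what each accepts and rejects (sample values are illustrative):

```python
import re

# Same heuristics as the validator's module-level patterns, exercised on samples.
JUNK_PATTERN = re.compile(
    r"^(?:region|category|segment|tier|status)\s*\d+$|^[A-Za-z]+\s+\d+$",
    re.IGNORECASE,
)
PERSON_NAME_PATTERN = re.compile(
    r"^[A-Za-z][A-Za-z'\-]*\s+[A-Za-z][A-Za-z'\-]*(?:\s+[A-Za-z][A-Za-z'\-]*)*$"
)

assert JUNK_PATTERN.match("Region 3")       # placeholder categorical -> junk
assert JUNK_PATTERN.match("Widget 12")      # generic "<word> <number>" -> junk
assert not JUNK_PATTERN.match("Northeast")  # real region name passes
assert PERSON_NAME_PATTERN.match("Ann O'Brien")    # apostrophes allowed
assert not PERSON_NAME_PATTERN.match("DRUG-4821")  # digits disqualify a name
```

Note the second alternative of `JUNK_PATTERN` catches any single word followed by digits, so legitimate values of that shape must appear in the allowed domain set to survive (the validator checks `allowed` before flagging).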
29
+ def _to_date(value: Any) -> date | None:
30
+ if isinstance(value, datetime):
31
+ return value.date()
32
+ if isinstance(value, date):
33
+ return value
34
+ if isinstance(value, str):
35
+ try:
36
+ return datetime.strptime(value[:10], "%Y-%m-%d").date()
37
+ except ValueError:
38
+ return None
39
+ return None
40
+
41
+
42
+ def _is_numeric(value: Any) -> bool:
43
+ if isinstance(value, bool):
44
+ return False
45
+ return isinstance(value, (int, float))
46
+
47
+
48
+ def _numeric_violation(semantic: SemanticType, value: Any) -> bool:
49
+ if not _is_numeric(value):
50
+ return False
51
+ v = float(value)
52
+ if math.isnan(v) or math.isinf(v):
53
+ return True
54
+ if semantic == SemanticType.INTEREST_RATE:
55
+ return v < 0 or v > 1.0
56
+ if semantic == SemanticType.RETURN_RATE:
57
+ return v < -0.25 or v > 0.75
58
+ if semantic == SemanticType.RETURN_MULTIPLE:
59
+ return v < 0 or v > 10.0
60
+ if semantic == SemanticType.BASIS_POINTS:
61
+ return v < -1000 or v > 5000
62
+ if semantic == SemanticType.LEVERAGE_RATIO:
63
+ return v < 0 or v > 20.0
64
+ if semantic == SemanticType.PERCENT:
65
+ if 0 <= v <= 1:
66
+ return False
67
+ return not (0 <= v <= 100)
68
+ if semantic == SemanticType.QUANTITY:
69
+ return v < 0 or v > 1_000_000
70
+ if semantic == SemanticType.COUNT:
71
+ return v < 0 or v > 100_000_000
72
+ if semantic == SemanticType.SCORE:
73
+ return v < 0 or v > 1000
74
+ if semantic == SemanticType.DURATION_DAYS:
75
+ return v < 0 or v > 3650
76
+ if semantic == SemanticType.MONEY:
77
+ return abs(v) > 100_000_000
78
+ return False
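The PERCENT branch above is the subtlest rule: values in [0, 1] are treated as fractional percentages and accepted, otherwise the 0-100 scale applies. A standalone sketch of just that branch (function name is illustrative):

```python
import math

# Mirrors the PERCENT case of _numeric_violation: accept fractional [0, 1]
# values as-is, otherwise require the 0-100 scale; NaN/inf always violate.
def percent_violation(v: float) -> bool:
    if math.isnan(v) or math.isinf(v):
        return True
    if 0 <= v <= 1:
        return False
    return not (0 <= v <= 100)

assert not percent_violation(0.37)      # 37% expressed as a fraction
assert not percent_violation(99.5)      # 99.5% on the 0-100 scale
assert percent_violation(150.0)         # out of range on either convention
assert percent_violation(float("nan"))  # non-finite values always fail
```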
79
+
80
+
81
+ def _find_pct_changes(values: list[float]) -> list[float]:
82
+ pct_changes: list[float] = []
83
+ for i in range(1, len(values)):
84
+ prev = values[i - 1]
85
+ cur = values[i]
86
+ if prev == 0:
87
+ continue
88
+ pct_changes.append((cur - prev) / abs(prev))
89
+ return pct_changes
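`_find_pct_changes` skips any step whose predecessor is zero rather than dividing by zero, which also means a drop *to* zero registers as -100% but the recovery from zero is silently dropped. A compact re-statement:

```python
# Equivalent to _find_pct_changes above: pairwise percentage deltas,
# skipping steps whose predecessor is zero.
def pct_changes(values: list[float]) -> list[float]:
    out = []
    for prev, cur in zip(values, values[1:]):
        if prev == 0:
            continue  # no defined pct change from a zero base
        out.append((cur - prev) / abs(prev))
    return out

print(pct_changes([100.0, 110.0, 0.0, 55.0]))  # [0.1, -1.0]
```

The 0.0 -> 55.0 step produces no entry, which is why the validator's volatility counters never see rebounds off a zero month.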
90
+
91
+
92
+ def _month_bucket(dval: date) -> str:
93
+ return f"{dval.year:04d}-{dval.month:02d}"
94
+
95
+
96
+ def _series_by_month(dated_rows: list[tuple[date, dict[str, Any]]], col_name: str) -> list[float]:
97
+ buckets: dict[str, list[float]] = {}
98
+ for dval, row in dated_rows:
99
+ value = row.get(col_name)
100
+ if not _is_numeric(value):
101
+ continue
102
+ buckets.setdefault(_month_bucket(dval), []).append(float(value))
103
+ if not buckets:
104
+ return []
105
+ series: list[float] = []
106
+ for month in sorted(buckets.keys()):
107
+ vals = buckets[month]
108
+ if vals:
109
+ series.append(sum(vals) / len(vals))
110
+ return series
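`_series_by_month` reduces row-level values to one mean per calendar month, in chronological order, which is what the smoothness/volatility checks later consume. A self-contained sketch on (date, value) pairs instead of full rows:

```python
from datetime import date

# Monthly-average series in the spirit of _series_by_month: bucket by
# "YYYY-MM", average each bucket, emit buckets in sorted (chronological) order.
def series_by_month(dated_values: list[tuple[date, float]]) -> list[float]:
    buckets: dict[str, list[float]] = {}
    for d, v in dated_values:
        buckets.setdefault(f"{d.year:04d}-{d.month:02d}", []).append(v)
    return [sum(vs) / len(vs) for _, vs in sorted(buckets.items())]

series = series_by_month([
    (date(2025, 1, 3), 100.0),
    (date(2025, 1, 20), 300.0),   # Jan averages to 200.0
    (date(2025, 2, 5), 240.0),
])
print(series)  # [200.0, 240.0]
```

Averaging per month is what lets the validator tolerate transaction-level jitter while still flagging month-over-month swings.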
111
+
112
+
113
+ class DataQualityValidator:
114
+ """Validates generated row data before warehouse write."""
115
+
116
+ def __init__(self, quality_spec: QualitySpec):
117
+ self.spec = quality_spec
118
+
119
+ def _semantic_map(self, schema: Schema, use_case: str | None) -> dict[str, dict[str, SemanticType]]:
120
+ mapping: dict[str, dict[str, SemanticType]] = {}
121
+ for table in schema.tables:
122
+ table_map: dict[str, SemanticType] = {}
123
+ for col in table.columns:
124
+ table_map[col.name] = infer_semantic_type(col.name, col.data_type, table.name, use_case)
125
+ mapping[table.name] = table_map
126
+ return mapping
127
+
128
+ def _validate_saas_finance_contract(
129
+ self,
130
+ generated_data: dict[str, list[dict]],
131
+ ) -> tuple[list[QualityIssue], int, int, int, dict[str, Any]]:
132
+ """Validate recurring-revenue accounting identities for the gold-path schema."""
133
+ fact_rows = generated_data.get("SAAS_CUSTOMER_MONTHLY", [])
134
+ spend_rows = generated_data.get("SALES_MARKETING_SPEND_MONTHLY", [])
135
+ if not fact_rows:
136
+ return [], 0, 0, 0, {}
137
+
138
+ issues: list[QualityIssue] = []
139
+ kpi_checks = 0
140
+ kpi_passes = 0
141
+ temporal_violations = 0
142
+
143
+ identity_failures = 0
144
+ continuity_failures = 0
145
+ retention_failures = 0
146
+ spend_failures = 0
147
+
148
+ customer_months: dict[Any, list[int]] = defaultdict(list)
149
+ retention_buckets: dict[tuple[int, str], dict[str, float]] = defaultdict(
150
+ lambda: {
151
+ "starting": 0.0,
152
+ "expansion": 0.0,
153
+ "contraction": 0.0,
154
+ "churned": 0.0,
155
+ }
156
+ )
157
+
158
+ for row in fact_rows:
159
+ month_key = int(row.get("MONTH_KEY") or 0)
160
+ customer_key = row.get("CUSTOMER_KEY")
161
+ segment = None
162
+ starting_arr = float(row.get("STARTING_ARR_USD") or 0.0)
163
+ new_logo_arr = float(row.get("NEW_LOGO_ARR_USD") or 0.0)
164
+ expansion_arr = float(row.get("EXPANSION_ARR_USD") or 0.0)
165
+ contraction_arr = float(row.get("CONTRACTION_ARR_USD") or 0.0)
166
+ churned_arr = float(row.get("CHURNED_ARR_USD") or 0.0)
167
+ ending_arr = float(row.get("ENDING_ARR_USD") or 0.0)
168
+ mrr_usd = float(row.get("MRR_USD") or 0.0)
169
+ gm_pct = float(row.get("GROSS_MARGIN_PCT") or 0.0)
170
+ gm_usd = float(row.get("GROSS_MARGIN_USD") or 0.0)
171
+ acquired_flag = bool(row.get("CUSTOMER_ACQUIRED_FLAG"))
172
+ churned_flag = bool(row.get("CUSTOMER_CHURNED_FLAG"))
173
+
174
+ kpi_checks += 1
175
+ expected_ending = starting_arr + new_logo_arr + expansion_arr - contraction_arr - churned_arr
176
+ if abs(ending_arr - expected_ending) <= 1.0:
177
+ kpi_passes += 1
178
+ else:
179
+ identity_failures += 1
180
+
181
+ kpi_checks += 1
182
+ if abs((mrr_usd * 12.0) - ending_arr) <= 12.0:
183
+ kpi_passes += 1
184
+ else:
185
+ identity_failures += 1
186
+
187
+ recognized_revenue = ending_arr / 12.0
188
+ kpi_checks += 1
189
+ if recognized_revenue <= 0 or abs(gm_usd - (recognized_revenue * (gm_pct / 100.0))) <= max(50.0, recognized_revenue * 0.15):
190
+ kpi_passes += 1
191
+ else:
192
+ identity_failures += 1
193
+
194
+ kpi_checks += 1
195
+ if (acquired_flag and new_logo_arr > 0) or (not acquired_flag):
196
+ kpi_passes += 1
197
+ else:
198
+ identity_failures += 1
199
+
200
+ kpi_checks += 1
201
+ if (churned_flag and ending_arr == 0.0) or (not churned_flag):
202
+ kpi_passes += 1
203
+ else:
204
+ identity_failures += 1
205
+
206
+ if customer_key is not None and month_key:
207
+ customer_months[customer_key].append(month_key)
208
+
209
+ customer_segments = {
210
+ row.get("CUSTOMER_KEY"): str(row.get("SEGMENT") or "")
211
+ for row in generated_data.get("CUSTOMERS", [])
212
+ }
213
+ for row in fact_rows:
214
+ month_key = int(row.get("MONTH_KEY") or 0)
215
+ customer_key = row.get("CUSTOMER_KEY")
216
+ segment = customer_segments.get(customer_key, "")
217
+ bucket = retention_buckets[(month_key, segment)]
218
+ bucket["starting"] += float(row.get("STARTING_ARR_USD") or 0.0)
219
+ bucket["expansion"] += float(row.get("EXPANSION_ARR_USD") or 0.0)
220
+ bucket["contraction"] += float(row.get("CONTRACTION_ARR_USD") or 0.0)
221
+ bucket["churned"] += float(row.get("CHURNED_ARR_USD") or 0.0)
222
+
223
+ for month_keys in customer_months.values():
224
+ ordered = sorted(set(month_keys))
225
+ for prev, cur in zip(ordered, ordered[1:]):
226
+ prev_year, prev_month = divmod(prev // 100, 100)
227
+ cur_year, cur_month = divmod(cur // 100, 100)
228
+ month_gap = (cur_year - prev_year) * 12 + (cur_month - prev_month)
229
+ if month_gap > 1:
230
+ continuity_failures += 1
231
+ temporal_violations += 1
232
+
233
+ for (month_key, segment), bucket in retention_buckets.items():
234
+ starting = bucket["starting"]
235
+ if starting <= 0:
236
+ continue
237
+ nrr = (starting + bucket["expansion"] - bucket["contraction"] - bucket["churned"]) / starting
238
+ grr = (starting - bucket["contraction"] - bucket["churned"]) / starting
239
+ kpi_checks += 2
240
+ if 0.70 <= nrr <= 1.30:
241
+ kpi_passes += 1
242
+ else:
243
+ retention_failures += 1
244
+ if 0.60 <= grr <= 1.05:
245
+ kpi_passes += 1
246
+ else:
247
+ retention_failures += 1
248
+
249
+ for row in spend_rows:
250
+ sales = float(row.get("SALES_SPEND_USD") or 0.0)
251
+ marketing = float(row.get("MARKETING_SPEND_USD") or 0.0)
252
+ total = float(row.get("TOTAL_S_AND_M_SPEND_USD") or 0.0)
253
+ new_customers = int(row.get("NEW_CUSTOMERS_ACQUIRED") or 0)
254
+ new_logo_arr = float(row.get("NEW_LOGO_ARR_USD") or 0.0)
255
+ nrr_pct = float(row.get("NRR_PCT") or 0.0)
256
+ grr_pct = float(row.get("GRR_PCT") or 0.0)
257
+ cac_usd = float(row.get("CAC_USD") or 0.0)
258
+
259
+ kpi_checks += 1
260
+ if abs(total - (sales + marketing)) <= 1.0:
261
+ kpi_passes += 1
262
+ else:
263
+ spend_failures += 1
264
+
265
+ kpi_checks += 1
266
+ if (new_customers == 0 and new_logo_arr >= 0.0) or (new_customers > 0 and new_logo_arr > 0.0):
267
+ kpi_passes += 1
268
+ else:
269
+ spend_failures += 1
270
+
271
+ kpi_checks += 2
272
+ if 70.0 <= nrr_pct <= 130.0:
273
+ kpi_passes += 1
274
+ else:
275
+ spend_failures += 1
276
+ if 60.0 <= grr_pct <= 105.0:
277
+ kpi_passes += 1
278
+ else:
279
+ spend_failures += 1
280
+
281
+ kpi_checks += 1
282
+ if (new_customers == 0 and cac_usd == 0.0) or (new_customers > 0 and cac_usd > 0.0):
283
+ kpi_passes += 1
284
+ else:
285
+ spend_failures += 1
286
+
287
+ if identity_failures:
288
+ issues.append(
289
+ QualityIssue(
290
+ issue_type="saas_finance_identity",
291
+ table="SAAS_CUSTOMER_MONTHLY",
292
+ column="ENDING_ARR_USD",
293
+ severity="high",
294
+ count=identity_failures,
295
+ message="Recurring-revenue identities do not reconcile",
296
+ samples=[],
297
+ )
298
+ )
299
+ if continuity_failures:
300
+ issues.append(
301
+ QualityIssue(
302
+ issue_type="saas_finance_continuity",
303
+ table="SAAS_CUSTOMER_MONTHLY",
304
+ column="MONTH_KEY",
305
+ severity="high",
306
+ count=continuity_failures,
307
+ message="Detected missing customer-month continuity in recurring-revenue history",
308
+ samples=[],
309
+ )
310
+ )
311
+ if retention_failures:
312
+ issues.append(
313
+ QualityIssue(
314
+ issue_type="saas_finance_retention",
315
+ table="SAAS_CUSTOMER_MONTHLY",
316
+ column="NRR/GRR",
317
+ severity="high",
318
+ count=retention_failures,
319
+ message="NRR or GRR fell outside the allowed range for one or more segment-months",
320
+ samples=[],
321
+ )
322
+ )
323
+ if spend_failures:
324
+ issues.append(
325
+ QualityIssue(
326
+ issue_type="saas_finance_spend",
327
+ table="SALES_MARKETING_SPEND_MONTHLY",
328
+ column="TOTAL_S_AND_M_SPEND_USD",
329
+ severity="high",
330
+ count=spend_failures,
331
+ message="Sales and marketing spend rows failed CAC/spend integrity checks",
332
+ samples=[],
333
+ )
334
+ )
335
+
336
+ summary = {
337
+ "saas_finance_identity_failures": identity_failures,
338
+ "saas_finance_continuity_failures": continuity_failures,
339
+ "saas_finance_retention_failures": retention_failures,
340
+ "saas_finance_spend_failures": spend_failures,
341
+ }
342
+ return issues, kpi_checks, kpi_passes, temporal_violations, summary
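The contract method above is anchored on the ARR bridge identity. A tiny worked example on one hypothetical customer-month, using the same tolerances the validator applies (1 dollar for the bridge, 12 dollars for MRR x 12 vs. ARR):

```python
# One hypothetical SAAS_CUSTOMER_MONTHLY row satisfying the identities checked
# above. The bridge: ENDING = STARTING + NEW_LOGO + EXPANSION - CONTRACTION - CHURNED.
row = {
    "STARTING_ARR_USD": 120_000.0,
    "NEW_LOGO_ARR_USD": 0.0,
    "EXPANSION_ARR_USD": 18_000.0,
    "CONTRACTION_ARR_USD": 3_000.0,
    "CHURNED_ARR_USD": 0.0,
    "ENDING_ARR_USD": 135_000.0,
    "MRR_USD": 11_250.0,  # 135,000 / 12
}
expected_ending = (
    row["STARTING_ARR_USD"]
    + row["NEW_LOGO_ARR_USD"]
    + row["EXPANSION_ARR_USD"]
    - row["CONTRACTION_ARR_USD"]
    - row["CHURNED_ARR_USD"]
)
bridge_ok = abs(row["ENDING_ARR_USD"] - expected_ending) <= 1.0
mrr_ok = abs(row["MRR_USD"] * 12.0 - row["ENDING_ARR_USD"]) <= 12.0
print(bridge_ok, mrr_ok)  # True True
```

The segment-level NRR/GRR bands (0.70-1.30 and 0.60-1.05) then apply the same starting/expansion/contraction/churned components aggregated per (month, segment) bucket.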
343
+
344
+ def validate(
345
+ self,
346
+ schema: Schema,
347
+ generated_data: dict[str, list[dict]],
348
+ use_case: str | None = None,
349
+ story_controls: dict[str, Any] | None = None,
350
+ ) -> QualityReport:
351
+ semantic_map = self._semantic_map(schema, use_case)
352
+ issues: list[QualityIssue] = []
353
+
354
+ categorical_junk_count = 0
355
+ fk_orphan_count = 0
356
+ temporal_violations = 0
357
+ numeric_violations = 0
358
+ semantic_checks = 0
359
+ semantic_passes = 0
360
+
361
+ # Story-quality metrics
362
+ volatility_breaches = 0
363
+ smoothness_scores: list[float] = []
364
+ outlier_candidates = 0
365
+ outlier_points_total = 0
366
+ kpi_checks = 0
367
+ kpi_passes = 0
368
+
369
+ controls = story_controls or {}
370
+ guardrails = controls.get("value_guardrails", {}) if isinstance(controls, dict) else {}
371
+ max_mom_change_pct = float(guardrails.get("max_mom_change_pct", 35))
372
+ outlier_budget = controls.get("outlier_budget", {}) if isinstance(controls, dict) else {}
373
+ max_outlier_points_pct = float(outlier_budget.get("max_points_pct", 0.02))
374
+ max_outlier_events = int(outlier_budget.get("max_events", 2))
375
+
376
+ today = datetime.now().date()
377
+
378
+ # Build parent value sets for FK integrity.
379
+ parent_values: dict[tuple[str, str], set[Any]] = {}
380
+ for table in schema.tables:
381
+ rows = generated_data.get(table.name, [])
382
+ for col in table.columns:
383
+ values = {row.get(col.name) for row in rows if row.get(col.name) is not None}
384
+ parent_values[(table.name.upper(), col.name.upper())] = values
385
+
386
+ for table in schema.tables:
387
+ rows = generated_data.get(table.name, [])
388
+ table_semantics = semantic_map.get(table.name, {})
389
+
390
+ # Column-level checks.
391
+ for col in table.columns:
392
+ semantic = table_semantics.get(col.name, SemanticType.UNKNOWN)
393
+ values = [row.get(col.name) for row in rows if row.get(col.name) is not None]
394
+ if not values:
395
+ continue
396
+
397
+ if is_business_categorical(semantic):
398
+ semantic_checks += len(values)
399
+ junk_values = []
400
+ allowed = {v.lower() for v in get_domain_values(semantic, use_case)}
401
+ strict_domain_semantics = {
402
+ SemanticType.CATEGORY,
403
+ SemanticType.SECTOR_NAME,
404
+ SemanticType.SECTOR_CATEGORY,
405
+ SemanticType.SEGMENT,
406
+ SemanticType.CHANNEL,
407
+ SemanticType.FUND_STRATEGY,
408
+ SemanticType.INVESTOR_TYPE,
409
+ SemanticType.INVESTMENT_STAGE,
410
+ SemanticType.COVENANT_STATUS,
411
+ SemanticType.DEBT_PERFORMANCE_STATUS,
412
+ SemanticType.PRODUCT_NAME,
413
+ SemanticType.BRAND_NAME,
414
+ SemanticType.DEPARTMENT_NAME,
415
+ SemanticType.BRANCH_NAME,
416
+ SemanticType.ORG_NAME,
417
+ }
418
+ for value in values:
419
+ sval = str(value).strip()
420
+ is_junk = (
421
+ (JUNK_PATTERN.match(sval) and sval.lower() not in allowed)
422
+ or (semantic in strict_domain_semantics and allowed and sval.lower() not in allowed)
423
+ )
424
+ if is_junk:
425
+ junk_values.append(sval)
426
+ else:
427
+ semantic_passes += 1
428
+ if junk_values:
429
+ categorical_junk_count += len(junk_values)
430
+ issues.append(
431
+ QualityIssue(
432
+ issue_type="categorical_junk",
433
+ table=table.name,
434
+ column=col.name,
435
+ severity="high",
436
+ count=len(junk_values),
437
+ message="Detected synthetic categorical values",
438
+ samples=junk_values[: self.spec.sample_limit_per_issue],
439
+ )
440
+ )
441
+
442
+ if semantic == SemanticType.PERSON_NAME:
443
+ semantic_checks += len(values)
444
+ bad_person_names = []
445
+ for value in values:
446
+ sval = str(value).strip()
447
+ if PERSON_NAME_PATTERN.match(sval):
448
+ semantic_passes += 1
449
+ else:
450
+ bad_person_names.append(sval)
451
+ if bad_person_names:
452
+ categorical_junk_count += len(bad_person_names)
453
+ issues.append(
454
+ QualityIssue(
455
+ issue_type="categorical_junk",
456
+ table=table.name,
457
+ column=col.name,
458
+ severity="high",
459
+ count=len(bad_person_names),
460
+ message="Detected unrealistic person-name values",
461
+ samples=bad_person_names[: self.spec.sample_limit_per_issue],
462
+ )
463
+ )
464
+
465
+ if semantic in {
466
+ SemanticType.MONEY,
467
+ SemanticType.QUANTITY,
468
+ SemanticType.COUNT,
469
+ SemanticType.PERCENT,
470
+ SemanticType.INTEREST_RATE,
471
+ SemanticType.RETURN_RATE,
472
+ SemanticType.RETURN_MULTIPLE,
473
+ SemanticType.BASIS_POINTS,
474
+ SemanticType.LEVERAGE_RATIO,
475
+ SemanticType.SCORE,
476
+ SemanticType.DURATION_DAYS,
477
+ }:
478
+ semantic_checks += len(values)
479
+ bad = [v for v in values if _numeric_violation(semantic, v)]
480
+ numeric_violations += len(bad)
481
+ semantic_passes += len(values) - len(bad)
482
+ if bad:
483
+ issues.append(
484
+ QualityIssue(
485
+ issue_type="numeric_range",
486
+ table=table.name,
487
+ column=col.name,
488
+ severity="high",
489
+ count=len(bad),
490
+ message=f"Values outside plausible range for {semantic.value}",
491
+ samples=bad[: self.spec.sample_limit_per_issue],
492
+ )
493
+ )
494
+
495
+ # Row-level temporal and geo checks.
496
+ for row in rows:
497
+ start_dates = []
498
+ end_dates = []
499
+ state_value = None
500
+ region_value = None
501
+ reference_date = None
502
+ month_name_values: list[str] = []
503
+ quarter_name_values: list[str] = []
504
+ year_values: list[int] = []
505
+
506
+ for col in table.columns:
507
+ value = row.get(col.name)
508
+ if value is None:
509
+ continue
510
+ col_lower = col.name.lower()
511
+ semantic = table_semantics.get(col.name, SemanticType.UNKNOWN)
512
+ if semantic == SemanticType.DATE_START:
513
+ dt = _to_date(value)
514
+ if dt:
515
+ start_dates.append((col.name, dt))
516
+ if reference_date is None:
517
+ reference_date = dt
518
+ if dt > today:
519
+ temporal_violations += 1
520
+ elif semantic == SemanticType.DATE_END:
521
+ dt = _to_date(value)
522
+ if dt:
523
+ end_dates.append((col.name, dt))
524
+ if reference_date is None:
525
+ reference_date = dt
526
+ if dt > today:
527
+ temporal_violations += 1
528
+ elif semantic == SemanticType.DATE_EVENT:
529
+ dt = _to_date(value)
530
+ if dt:
531
+ if reference_date is None:
532
+ reference_date = dt
533
+ if dt > today:
534
+ temporal_violations += 1
535
+ elif semantic == SemanticType.DATE_BIRTH:
536
+ dt = _to_date(value)
537
+ if dt and (dt > today or (today.year - dt.year) < 18):
538
+ temporal_violations += 1
539
+ elif semantic == SemanticType.STATE:
540
+ state_value = value
541
+ elif semantic == SemanticType.REGION:
542
+ region_value = value
543
+
544
+ if "month" in col_lower and "name" in col_lower:
545
+ month_name_values.append(str(value).strip())
546
+ elif "quarter" in col_lower and "name" in col_lower:
547
+ quarter_name_values.append(str(value).strip())
548
+ elif "year" in col_lower and "name" not in col_lower and _is_numeric(value):
549
+ year_values.append(int(float(value)))
550
+
551
+ for _, start_dt in start_dates:
552
+ for _, end_dt in end_dates:
553
+ if end_dt < start_dt:
554
+ temporal_violations += 1
555
+
556
+ if reference_date:
557
+ expected_month_full = reference_date.strftime("%B").lower()
558
+ expected_month_abbr = reference_date.strftime("%b").lower()
559
+ expected_quarter = f"q{((reference_date.month - 1) // 3) + 1}"
560
+ expected_year = reference_date.year
561
+
562
+ for month_value in month_name_values:
563
+ normalized = month_value.lower()
564
+ if not (normalized.startswith(expected_month_full) or normalized.startswith(expected_month_abbr)):
565
+ temporal_violations += 1
566
+
567
+ for quarter_value in quarter_name_values:
568
+ normalized = quarter_value.lower().replace(" ", "")
569
+ if not (normalized.startswith(expected_quarter) or normalized == expected_quarter.replace("q", "")):
570
+ temporal_violations += 1
571
+
572
+ for year_value in year_values:
573
+ if year_value != expected_year:
574
+ temporal_violations += 1
575
+
576
+ if state_value and region_value:
577
+ expected_region = map_state_to_region(str(state_value))
578
+ if expected_region and str(region_value).strip().lower() != expected_region.lower():
579
+ issues.append(
580
+ QualityIssue(
581
+ issue_type="geo_mismatch",
582
+ table=table.name,
583
+ column="region",
584
+ severity="medium",
585
+ count=1,
586
+ message=f"State/region mismatch ({state_value} -> {region_value})",
587
+ samples=[{"state": state_value, "region": region_value, "expected": expected_region}],
588
+ )
589
+ )
590
+
591
+ # FK integrity checks.
592
+ for fk in table.foreign_keys:
593
+ parent_key = (fk.references_table.upper(), fk.references_column.upper())
594
+ allowed = parent_values.get(parent_key, set())
595
+ if not allowed:
596
+ continue
597
+ missing = 0
598
+ for row in rows:
599
+ value = row.get(fk.column_name)
600
+ if value is None:
601
+ continue
602
+ if value not in allowed:
603
+ missing += 1
604
+ if missing:
605
+ fk_orphan_count += missing
606
+ issues.append(
607
+ QualityIssue(
608
+ issue_type="fk_orphan",
609
+ table=table.name,
610
+ column=fk.column_name,
611
+ severity="critical",
612
+ count=missing,
613
+ message=f"Foreign key not found in {fk.references_table}.{fk.references_column}",
614
+ samples=[],
615
+ )
616
+ )
617
+
618
+ # Story metrics: smoothness + volatility + KPI consistency.
619
+ date_cols = [c for c, s in table_semantics.items() if s in {SemanticType.DATE_EVENT, SemanticType.DATE_START, SemanticType.DATE_END}]
620
+ measure_cols = [c for c, s in table_semantics.items() if s in {SemanticType.MONEY, SemanticType.QUANTITY, SemanticType.COUNT}]
621
+
622
+ if date_cols and measure_cols and rows:
623
+ date_col = date_cols[0]
624
+ dated_rows = []
625
+ for row in rows:
626
+ dval = _to_date(row.get(date_col))
627
+ if dval:
628
+ dated_rows.append((dval, row))
629
+ dated_rows.sort(key=lambda x: x[0])
630
+
631
+ if len(dated_rows) >= 4:
632
+ for col_name in measure_cols:
633
+ # Evaluate volatility/smoothness on monthly aggregates to avoid
634
+ # false positives from row-level transaction jitter.
635
+ series = _series_by_month(dated_rows, col_name)
636
+ if len(series) < 4:
637
+ continue
638
+ pct_changes = _find_pct_changes(series)
639
+ if not pct_changes:
640
+ continue
641
+
642
+ outlier_points_total += len(pct_changes)
643
+ abs_changes = [abs(v) for v in pct_changes]
644
+
645
+ breaches = sum(1 for v in abs_changes if (v * 100.0) > max_mom_change_pct)
646
+ volatility_breaches += breaches
647
+
648
+ median_abs = sorted(abs_changes)[len(abs_changes) // 2]
649
+ max_band = max(max_mom_change_pct / 100.0, 0.01)
650
+ smoothness = max(0.0, 1.0 - (median_abs / (max_band * 1.4)))
651
+ smoothness_scores.append(smoothness)
652
+
653
+ outlier_candidates += sum(1 for v in abs_changes if v > 0.80)
654
+
655
+ qty_key = next((k for k in table_semantics.keys() if any(t in k.lower() for t in ("quantity", "qty", "units"))), None)
656
+ price_key = next((k for k in table_semantics.keys() if "price" in k.lower()), None)
657
+ revenue_key = next((k for k in table_semantics.keys() if any(t in k.lower() for t in ("revenue", "sales_amount", "total_amount", "gross", "net"))), None)
658
+ if qty_key and price_key and revenue_key and rows:
659
+ bad_kpi = 0
660
+ for row in rows:
661
+ qty = row.get(qty_key)
662
+ price = row.get(price_key)
663
+ revenue = row.get(revenue_key)
664
+ if not (_is_numeric(qty) and _is_numeric(price) and _is_numeric(revenue)):
665
+ continue
666
+ expected = float(qty) * float(price)
667
+ if expected <= 0:
668
+ continue
669
+ rel_err = abs(float(revenue) - expected) / max(expected, 1.0)
670
+ kpi_checks += 1
671
+ if rel_err <= 0.25:
672
+ kpi_passes += 1
673
+ else:
674
+ bad_kpi += 1
675
+ if bad_kpi:
676
+ issues.append(
677
+ QualityIssue(
678
+ issue_type="kpi_consistency",
679
+ table=table.name,
680
+ column=revenue_key,
681
+ severity="medium",
682
+ count=bad_kpi,
683
+ message="Revenue deviates from quantity x price beyond tolerance",
684
+ samples=[],
685
+ )
686
+ )
687
+
688
+ col_map = {col.name.lower(): col.name for col in table.columns}
689
+
690
+ total_value_key = col_map.get("total_value_usd") or col_map.get("total_value")
691
+ reported_value_key = col_map.get("reported_value_usd") or col_map.get("reported_value")
692
+ distributions_key = col_map.get("distributions_usd") or col_map.get("distributions")
693
+ gross_irr_key = col_map.get("gross_irr")
694
+ net_irr_key = col_map.get("net_irr")
695
+ gross_without_subline_key = col_map.get("gross_irr_without_sub_line")
696
+ irr_impact_bps_key = col_map.get("irr_sub_line_impact_bps")
697
+ total_return_multiple_key = col_map.get("total_return_multiple")
698
+ dpi_multiple_key = col_map.get("dpi_multiple")
699
+ rvpi_multiple_key = col_map.get("rvpi_multiple")
700
+ revenue_metric_key = col_map.get("revenue_usd") or col_map.get("revenue")
701
+ ebitda_key = col_map.get("ebitda_usd") or col_map.get("ebitda")
702
+ ebitda_margin_key = col_map.get("ebitda_margin_pct")
703
+ net_debt_key = col_map.get("net_debt_usd") or col_map.get("net_debt")
704
+ debt_to_ebitda_key = col_map.get("debt_to_ebitda_ratio")
705
+
706
+ pe_identity_failures: dict[str, int] = {
707
+ "total_value_identity": 0,
708
+ "net_vs_gross_irr": 0,
709
+ "irr_bps_identity": 0,
710
+ "tvpi_identity": 0,
711
+ "ebitda_margin_identity": 0,
712
+ "debt_to_ebitda_identity": 0,
713
+ }
714
+
715
+ for row in rows:
716
+ if (
717
+ total_value_key
718
+ and reported_value_key
719
+ and distributions_key
720
+ and _is_numeric(row.get(total_value_key))
721
+ and _is_numeric(row.get(reported_value_key))
722
+ and _is_numeric(row.get(distributions_key))
723
+ ):
724
+ expected_total = float(row[reported_value_key]) + float(row[distributions_key])
725
+ rel_err = abs(float(row[total_value_key]) - expected_total) / max(abs(expected_total), 1.0)
726
+ kpi_checks += 1
727
+ if rel_err <= 0.02:
728
+ kpi_passes += 1
729
+ else:
730
+ pe_identity_failures["total_value_identity"] += 1
731
+
732
+ if gross_irr_key and net_irr_key and _is_numeric(row.get(gross_irr_key)) and _is_numeric(row.get(net_irr_key)):
733
+ kpi_checks += 1
734
+ if float(row[net_irr_key]) <= float(row[gross_irr_key]):
735
+ kpi_passes += 1
736
+ else:
737
+ pe_identity_failures["net_vs_gross_irr"] += 1
738
+
739
+ if (
740
+ gross_irr_key
741
+ and gross_without_subline_key
742
+ and irr_impact_bps_key
743
+ and _is_numeric(row.get(gross_irr_key))
744
+ and _is_numeric(row.get(gross_without_subline_key))
745
+ and _is_numeric(row.get(irr_impact_bps_key))
746
+ ):
747
+ expected_bps = (float(row[gross_irr_key]) - float(row[gross_without_subline_key])) * 10000
748
+ kpi_checks += 1
749
+ if abs(float(row[irr_impact_bps_key]) - expected_bps) <= 5.0:
750
+ kpi_passes += 1
751
+ else:
752
+ pe_identity_failures["irr_bps_identity"] += 1
753
+
754
+ if (
755
+ total_return_multiple_key
756
+ and dpi_multiple_key
757
+ and rvpi_multiple_key
758
+ and _is_numeric(row.get(total_return_multiple_key))
759
+ and _is_numeric(row.get(dpi_multiple_key))
760
+ and _is_numeric(row.get(rvpi_multiple_key))
761
+ ):
762
+ expected_tvpi = float(row[dpi_multiple_key]) + float(row[rvpi_multiple_key])
763
+ kpi_checks += 1
764
+ if abs(float(row[total_return_multiple_key]) - expected_tvpi) <= 0.05:
765
+ kpi_passes += 1
766
+ else:
767
+ pe_identity_failures["tvpi_identity"] += 1
768
+
769
+ if (
770
+ revenue_metric_key
771
+ and ebitda_key
772
+ and ebitda_margin_key
773
+ and _is_numeric(row.get(revenue_metric_key))
774
+ and _is_numeric(row.get(ebitda_key))
775
+ and _is_numeric(row.get(ebitda_margin_key))
776
+ and float(row[revenue_metric_key]) > 0
777
+ ):
778
+ expected_margin = (float(row[ebitda_key]) / float(row[revenue_metric_key])) * 100.0
779
+ kpi_checks += 1
780
+ if abs(float(row[ebitda_margin_key]) - expected_margin) <= 0.5:
781
+ kpi_passes += 1
782
+ else:
783
+ pe_identity_failures["ebitda_margin_identity"] += 1
784
+
785
+ if (
786
+ net_debt_key
787
+ and ebitda_key
788
+ and debt_to_ebitda_key
789
+ and _is_numeric(row.get(net_debt_key))
790
+ and _is_numeric(row.get(ebitda_key))
791
+ and _is_numeric(row.get(debt_to_ebitda_key))
792
+ and float(row[ebitda_key]) > 0
793
+ ):
794
+ expected_ratio = float(row[net_debt_key]) / float(row[ebitda_key])
795
+ kpi_checks += 1
796
+ if abs(float(row[debt_to_ebitda_key]) - expected_ratio) <= 0.1:
797
+ kpi_passes += 1
798
+ else:
799
+ pe_identity_failures["debt_to_ebitda_identity"] += 1
800
+
801
+ pe_issue_messages = {
802
+ "total_value_identity": "Total value deviates from reported value plus distributions",
803
+ "net_vs_gross_irr": "Net IRR exceeds gross IRR",
804
+ "irr_bps_identity": "IRR basis-point impact does not match gross vs unsubsidized IRR delta",
805
+ "tvpi_identity": "Total return multiple deviates from DPI plus RVPI",
806
+ "ebitda_margin_identity": "EBITDA margin deviates from EBITDA divided by revenue",
807
+ "debt_to_ebitda_identity": "Debt-to-EBITDA ratio deviates from net debt divided by EBITDA",
808
+ }
809
+ pe_issue_columns = {
810
+ "total_value_identity": total_value_key or "*",
811
+ "net_vs_gross_irr": net_irr_key or "*",
812
+ "irr_bps_identity": irr_impact_bps_key or "*",
813
+ "tvpi_identity": total_return_multiple_key or "*",
814
+ "ebitda_margin_identity": ebitda_margin_key or "*",
815
+ "debt_to_ebitda_identity": debt_to_ebitda_key or "*",
816
+ }
817
+ for issue_name, failure_count in pe_identity_failures.items():
818
+ if failure_count:
819
+ issues.append(
820
+ QualityIssue(
821
+ issue_type=issue_name,
822
+ table=table.name,
823
+ column=pe_issue_columns[issue_name],
824
+ severity="high",
825
+ count=failure_count,
826
+ message=pe_issue_messages[issue_name],
827
+ samples=[],
828
+ )
829
+ )
830
+
831
+ semantic_pass_ratio = (semantic_passes / semantic_checks) if semantic_checks else 1.0
832
+
833
+ finance_summary: dict[str, Any] = {}
834
+ if "SAAS_CUSTOMER_MONTHLY" in generated_data:
835
+ finance_issues, extra_kpi_checks, extra_kpi_passes, finance_temporal_violations, finance_summary = (
836
+ self._validate_saas_finance_contract(generated_data)
837
+ )
838
+ issues.extend(finance_issues)
839
+ kpi_checks += extra_kpi_checks
840
+ kpi_passes += extra_kpi_passes
841
+ temporal_violations += finance_temporal_violations
842
+
843
+ smoothness_score = (sum(smoothness_scores) / len(smoothness_scores)) if smoothness_scores else 1.0
844
+ allowed_outliers = int(outlier_points_total * max_outlier_points_pct) + max_outlier_events
845
+ if outlier_candidates <= allowed_outliers:
846
+ outlier_explainability = 1.0
847
+ else:
848
+ outlier_explainability = allowed_outliers / max(outlier_candidates, 1)
849
+ kpi_consistency = (kpi_passes / kpi_checks) if kpi_checks else 1.0
850
+
851
+ thresholds = self.spec.thresholds
852
+
853
+ if volatility_breaches > thresholds.max_volatility_breaches:
854
+ issues.append(
855
+ QualityIssue(
856
+ issue_type="volatility",
857
+ table="*",
858
+ column="*",
859
+ severity="high",
860
+ count=volatility_breaches,
861
+ message=f"Detected {volatility_breaches} period-over-period volatility breaches",
862
+ samples=[],
863
+ )
864
+ )
865
+ if smoothness_score < thresholds.min_smoothness_score:
866
+ issues.append(
867
+ QualityIssue(
868
+ issue_type="smoothness",
869
+ table="*",
870
+ column="*",
871
+ severity="high",
872
+ count=1,
873
+ message=f"Smoothness score {smoothness_score:.3f} below threshold",
874
+ samples=[],
875
+ )
876
+ )
877
+ if outlier_explainability < thresholds.min_outlier_explainability:
878
+ issues.append(
879
+ QualityIssue(
880
+ issue_type="outlier_explainability",
881
+ table="*",
882
+ column="*",
883
+ severity="medium",
884
+ count=1,
885
+ message=f"Outlier explainability {outlier_explainability:.3f} below threshold",
886
+ samples=[],
887
+ )
888
+ )
889
+ if kpi_consistency < thresholds.min_kpi_consistency:
890
+ issues.append(
891
+ QualityIssue(
892
+ issue_type="kpi_consistency",
893
+ table="*",
894
+ column="*",
895
+ severity="medium",
896
+ count=1,
897
+ message=f"KPI consistency {kpi_consistency:.3f} below threshold",
898
+ samples=[],
899
+ )
900
+ )
901
+
902
+ passed = (
903
+ categorical_junk_count <= thresholds.max_categorical_junk
904
+ and fk_orphan_count <= thresholds.max_fk_orphans
905
+ and temporal_violations <= thresholds.max_temporal_violations
906
+ and numeric_violations <= thresholds.max_numeric_violations
907
+ and semantic_pass_ratio >= thresholds.min_semantic_pass_ratio
908
+ and volatility_breaches <= thresholds.max_volatility_breaches
909
+ and smoothness_score >= thresholds.min_smoothness_score
910
+ and outlier_explainability >= thresholds.min_outlier_explainability
911
+ and kpi_consistency >= thresholds.min_kpi_consistency
912
+ )
913
+
914
+ summary = {
915
+ "categorical_junk_count": categorical_junk_count,
916
+ "fk_orphan_count": fk_orphan_count,
917
+ "temporal_violations": temporal_violations,
918
+ "numeric_violations": numeric_violations,
919
+ "semantic_checks": semantic_checks,
920
+ "semantic_passes": semantic_passes,
921
+ "semantic_pass_ratio": semantic_pass_ratio,
922
+ "volatility_breaches": volatility_breaches,
923
+ "smoothness_score": smoothness_score,
924
+ "outlier_candidates": outlier_candidates,
925
+ "outlier_explainability": outlier_explainability,
926
+ "kpi_checks": kpi_checks,
927
+ "kpi_consistency": kpi_consistency,
928
+ }
929
+ summary.update(finance_summary)
930
+ return QualityReport(
931
+ passed=passed,
932
+ generated_at=now_utc_iso(),
933
+ summary=summary,
934
+ issues=issues,
935
+ )
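The `passed` flag computed above is a plain conjunction of threshold comparisons over the summary metrics. A minimal standalone sketch of that gating pattern (the `Thresholds` fields here are a hypothetical subset, not the project's actual spec class):

```python
from dataclasses import dataclass

@dataclass
class Thresholds:
    # Hypothetical subset of the spec's threshold fields.
    max_temporal_violations: int = 0
    min_smoothness_score: float = 0.95
    min_kpi_consistency: float = 0.98

def gate(summary: dict, t: Thresholds) -> bool:
    """Return True only when every metric clears its threshold."""
    return (
        summary["temporal_violations"] <= t.max_temporal_violations
        and summary["smoothness_score"] >= t.min_smoothness_score
        and summary["kpi_consistency"] >= t.min_kpi_consistency
    )

print(gate({"temporal_violations": 0, "smoothness_score": 0.97, "kpi_consistency": 0.99}, Thresholds()))
```

Ratios default to 1.0 when zero checks ran (see `kpi_consistency` above), so an empty table never fails a ratio threshold by accident.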
legitdata_project/legitdata/relationships/fk_manager.py CHANGED
@@ -62,10 +62,9 @@ class FKManager:
62
  table_key = references_table.upper()
63
  col_key = references_column.upper()
64
 
65
- if table_key not in self._pk_values:
66
- return None
67
-
68
- if col_key not in self._pk_values[table_key]:
69
  # Referenced column is not the PK - look it up from registered rows
70
  rows = self._table_rows.get(table_key, [])
71
  if rows:
@@ -81,12 +80,12 @@ class FKManager:
81
 
82
  # Fallback to PK only if it's the same type situation
83
  # (This is the old behavior - may return wrong type)
84
- if self._pk_values[table_key]:
85
- col_key = list(self._pk_values[table_key].keys())[0]
86
  else:
87
  return None
88
 
89
- values = self._pk_values[table_key][col_key]
90
  if not values:
91
  return None
92
 
 
62
  table_key = references_table.upper()
63
  col_key = references_column.upper()
64
 
65
+ table_pk_map = self._pk_values.get(table_key, {})
66
+
67
+ if col_key not in table_pk_map:
 
68
  # Referenced column is not the PK - look it up from registered rows
69
  rows = self._table_rows.get(table_key, [])
70
  if rows:
 
80
 
81
  # Fallback to PK only if it's the same type situation
82
  # (This is the old behavior - may return wrong type)
83
+ if table_pk_map:
84
+ col_key = list(table_pk_map.keys())[0]
85
  else:
86
  return None
87
 
88
+ values = table_pk_map.get(col_key, [])
89
  if not values:
90
  return None
91
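The fk_manager rewrite replaces two nested `in` checks plus direct indexing with a single `dict.get(..., {})`, so every later access goes through one local map and can never raise `KeyError`. The pattern in isolation (illustrative function name, not the real `FKManager` API):

```python
def first_pk_values(pk_values: dict[str, dict[str, list]], table: str, col: str) -> list:
    """Resolve values for (table, col), falling back to the table's first PK column."""
    table_pk_map = pk_values.get(table.upper(), {})  # missing table -> empty map, no KeyError
    key = col.upper()
    if key not in table_pk_map:
        if not table_pk_map:
            return []
        key = next(iter(table_pk_map))  # fallback: first registered PK column
    return table_pk_map.get(key, [])

print(first_pk_values({"ORDERS": {"ORDER_ID": [1, 2]}}, "orders", "missing_col"))  # -> [1, 2]
```

The behavior is unchanged for the happy path; only the unknown-table and unknown-column paths become uniform.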
 
legitdata_project/legitdata/sourcer/ai_generator.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  import re
5
  from typing import Optional
6
  from ..analyzer.context_builder import CompanyContext
 
7
 
8
 
9
  class AIGenerator:
@@ -167,37 +168,19 @@ Return ONLY a JSON array of strings. Example: ["value1", "value2", "value3"]"""
167
 
168
  def _generate_fallback(self, column_name: str, num_values: int) -> list[str]:
169
  """Generate fallback values when AI is unavailable."""
170
- col_lower = column_name.lower()
171
-
172
- # Common fallback values based on column type
173
- if 'region' in col_lower:
174
- base = ["North", "South", "East", "West", "Central", "Northeast", "Southeast", "Northwest", "Southwest", "Midwest"]
175
- elif 'country' in col_lower:
176
- base = ["United States", "Canada", "United Kingdom", "Germany", "France", "Japan", "Australia", "Brazil", "India", "Mexico"]
177
- elif 'status' in col_lower:
178
- base = ["Active", "Inactive", "Pending", "Completed", "Cancelled"]
179
- elif 'tier' in col_lower:
180
- base = ["Premium", "Standard", "Basic", "Enterprise", "Starter"]
181
- elif 'type' in col_lower:
182
- base = ["Type A", "Type B", "Type C", "Standard", "Special"]
183
- elif 'segment' in col_lower:
184
- base = ["Enterprise", "SMB", "Consumer", "Government", "Education"]
185
- elif 'category' in col_lower:
186
- base = ["Category 1", "Category 2", "Category 3", "Category 4", "Category 5"]
187
- elif 'group' in col_lower:
188
- base = ["Group A", "Group B", "Group C", "Group D", "Group E"]
189
- else:
190
- base = [f"{column_name}_{i}" for i in range(1, num_values + 1)]
191
-
192
- # Extend if needed
193
- result = base.copy()
194
- while len(result) < num_values:
195
- for i, v in enumerate(base):
196
- result.append(f"{v} {len(result) // len(base) + 1}")
197
- if len(result) >= num_values:
198
- break
199
-
200
- return result[:num_values]
201
 
202
  def clear_cache(self):
203
  """Clear the generation cache."""
 
4
  import re
5
  from typing import Optional
6
  from ..analyzer.context_builder import CompanyContext
7
+ from ..domain import get_domain_values, infer_semantic_type
8
 
9
 
10
  class AIGenerator:
 
168
 
169
  def _generate_fallback(self, column_name: str, num_values: int) -> list[str]:
170
  """Generate fallback values when AI is unavailable."""
171
+ semantic = infer_semantic_type(column_name)
172
+ base = get_domain_values(semantic)
173
+ if not base:
174
+ col_lower = column_name.lower()
175
+ if "type" in col_lower:
176
+ base = ["Standard", "Premium", "Basic"]
177
+ elif "group" in col_lower:
178
+ base = ["Group A", "Group B", "Group C"]
179
+ else:
180
+ base = [f"{column_name}_value"]
181
+
182
+ # Cycle deterministically without numeric suffix pollution.
183
+ return [base[i % len(base)] for i in range(num_values)]
 
184
 
185
  def clear_cache(self):
186
  """Clear the generation cache."""
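The new fallback cycles the base list with a modulo index instead of the old while-loop that appended numeric suffixes ("North 2", "North 3", ...). The core idiom:

```python
def cycle_values(base: list[str], n: int) -> list[str]:
    """Repeat base deterministically to length n, without suffix pollution."""
    return [base[i % len(base)] for i in range(n)]

print(cycle_values(["Standard", "Premium", "Basic"], 5))
# -> ['Standard', 'Premium', 'Basic', 'Standard', 'Premium']
```

Output is fully determined by the base list, which keeps regenerated demo data stable across runs.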
legitdata_project/legitdata/sourcer/generic.py CHANGED
@@ -4,15 +4,27 @@ FIXED VERSION - Proper fallbacks, no garbage concatenated words.
4
  """
5
 
6
  import random
 
 
7
  import uuid
8
  from datetime import datetime, timedelta
9
  from decimal import Decimal
10
  from typing import Any, Optional
11
  from faker import Faker
 
12
 
13
  fake = Faker()
14
 
15
 
 
 
 
 
 
 
 
 
 
16
  class GenericSourcer:
17
  """Generates generic/synthetic data based on column strategies."""
18
 
@@ -182,7 +194,7 @@ class GenericSourcer:
182
 
183
  return f"{prefix}_{self._key_counters[prefix]:05d}"
184
 
185
- def generate_value(self, strategy: str) -> Any:
186
  """
187
  Generate a value based on the strategy string.
188
 
@@ -215,31 +227,95 @@ class GenericSourcer:
215
  return self._generate(strategy_type, params)
216
  except Exception as e:
217
  print(f"Warning: Strategy '{strategy}' failed: {e}, using safe default")
218
- return self._gen_safe_default()
 
219
 
220
  def _generate(self, strategy_type: str, params: list) -> Any:
221
  """Route to specific generator based on strategy type."""
222
 
223
  # Random integers
224
  if strategy_type == 'random_int':
225
- min_val = int(params[0]) if params else 1
226
- max_val = int(params[1]) if len(params) > 1 else 1000
227
  return random.randint(min_val, max_val)
228
 
229
  # Random decimals
230
  elif strategy_type == 'random_decimal':
231
- min_val = float(params[0]) if params else 0.0
232
- max_val = float(params[1]) if len(params) > 1 else 100.0
233
- precision = int(params[2]) if len(params) > 2 else 2
234
  value = random.uniform(min_val, max_val)
235
  return Decimal(str(round(value, precision)))
236
 
237
  # Random floats
238
  elif strategy_type == 'random_float':
239
- min_val = float(params[0]) if params else 0.0
240
- max_val = float(params[1]) if len(params) > 1 else 100.0
241
  return round(random.uniform(min_val, max_val), 2)
242
 
243
  # Choice from list
244
  elif strategy_type == 'choice':
245
  if params:
@@ -251,7 +327,7 @@ class GenericSourcer:
251
  # Format: "weighted_choice:val1,weight1,val2,weight2,..."
252
  if len(params) >= 2:
253
  values = params[0::2]
254
- weights = [float(w) for w in params[1::2]]
255
  return random.choices(values, weights=weights)[0]
256
  return params[0] if params else "Unknown"
257
 
@@ -261,14 +337,14 @@ class GenericSourcer:
261
 
262
  # Boolean
263
  elif strategy_type == 'boolean':
264
- true_pct = float(params[0]) if params else 0.5
265
  return random.random() < true_pct
266
 
267
  # Dates
268
  elif strategy_type == 'date_between':
269
  if params and len(params) >= 2:
270
- start = datetime.strptime(params[0], '%Y-%m-%d')
271
- end = datetime.strptime(params[1], '%Y-%m-%d')
272
  else:
273
  end = datetime.now()
274
  start = end - timedelta(days=365*2)
@@ -280,7 +356,7 @@ class GenericSourcer:
280
  # Sequential (for IDs)
281
  elif strategy_type == 'sequential':
282
  # This should be handled by the generator with state
283
- start = int(params[0]) if params else 1
284
  return start
285
 
286
  # Faker-based generators
@@ -353,13 +429,39 @@ class GenericSourcer:
353
  else:
354
  return self._gen_safe_default()
355
 
356
- def _gen_safe_default(self) -> str:
357
  """
358
  Generate a safe default value - NOT garbage concatenated words.
359
  Returns a simple, realistic-looking value.
360
  """
 
361
  # Return a simple word, not concatenated garbage
362
  return self.fake.word().capitalize()
 
363
 
364
  def new_row(self):
365
  """
@@ -395,8 +497,25 @@ class GenericSourcer:
395
  Infer the best strategy for a column based on its name and type.
396
  This is called when no explicit strategy is provided.
397
  """
398
- col_lower = col_name.lower().replace('_', '').replace('-', '')
 
399
  col_type_upper = col_type.upper() if col_type else 'VARCHAR'
 
400
 
401
  # === ID COLUMNS ===
402
  if col_lower.endswith('id') or col_lower.endswith('key'):
@@ -415,6 +534,12 @@ class GenericSourcer:
415
  return 'random_int:100000,9999999'
416
 
417
  # === PRICE/MONEY COLUMNS ===
 
418
  if any(x in col_lower for x in ['price', 'cost', 'amount', 'total', 'subtotal']):
419
  return 'random_decimal:1.00,500.00'
420
 
@@ -530,12 +655,30 @@ class GenericSourcer:
530
  return 'lookup:quarter'
531
 
532
  # === DATE COLUMNS ===
533
- today = datetime.now().strftime('%Y-%m-%d')
534
- if any(x in col_lower for x in ['date', 'datetime', 'timestamp', 'createdat', 'updatedat']):
535
- return f'date_between:2024-01-01,{today}'
536
-
537
- if 'birthdate' in col_lower or 'dob' in col_lower:
538
- return 'date_between:1960-01-01,2005-12-31'
 
539
 
540
  # === BOOLEAN COLUMNS - Check FIRST before contact columns! ===
541
  # is_mobile should be boolean, not phone number
@@ -552,8 +695,10 @@ class GenericSourcer:
552
  return 'boolean:0.5'
553
 
554
  # === NAME COLUMNS ===
555
- if col_lower in ('name', 'fullname', 'customername', 'clientname'):
556
  return 'name'
 
 
557
 
558
  if col_lower in ('firstname', 'fname', 'givenname'):
559
  return 'first_name'
@@ -565,10 +710,12 @@ class GenericSourcer:
565
  # col_lower has underscores stripped, so check for 'name' suffix
566
  if col_lower.endswith('name'):
567
  # Exclude non-person names
568
- non_person = ('productname', 'companyname', 'brandname', 'storename',
 
569
  'warehousename', 'centername', 'campaignname', 'tablename',
570
  'columnname', 'schemaname', 'filename', 'hostname',
571
- 'categoryname', 'holidayname')
 
572
  if col_lower not in non_person:
573
  return 'name'
574
 
@@ -600,7 +747,20 @@ class GenericSourcer:
600
  return 'company'
601
 
602
  # === URL COLUMNS ===
603
- if 'url' in col_lower or 'website' in col_lower or 'link' in col_lower:
 
604
  return 'url'
605
 
606
  # === DESCRIPTION/TEXT COLUMNS ===
@@ -625,8 +785,7 @@ class GenericSourcer:
625
  return 'boolean:0.5'
626
 
627
  if 'DATE' in col_type_upper or 'TIME' in col_type_upper:
628
- today = datetime.now().strftime('%Y-%m-%d')
629
- return f'date_between:2024-01-01,{today}'
630
 
631
  # === FINAL FALLBACK ===
632
  # Return a single word, not garbage
 
4
  """
5
 
6
  import random
7
+ import re
8
+ import string
9
  import uuid
10
  from datetime import datetime, timedelta
11
  from decimal import Decimal
12
  from typing import Any, Optional
13
  from faker import Faker
14
+ from ..domain import SemanticType, get_domain_values, infer_semantic_type
15
 
16
  fake = Faker()
17
 
18
 
19
+ def _years_ago(dt: datetime, years: int) -> datetime:
20
+ """Return datetime shifted back by full years (leap-safe)."""
21
+ try:
22
+ return dt.replace(year=dt.year - years)
23
+ except ValueError:
24
+ # Handle Feb 29 on non-leap years.
25
+ return dt.replace(month=2, day=28, year=dt.year - years)
26
+
27
+
28
  class GenericSourcer:
29
  """Generates generic/synthetic data based on column strategies."""
30
 
 
194
 
195
  return f"{prefix}_{self._key_counters[prefix]:05d}"
196
 
197
+ def generate_value(self, strategy: str, expected_type: Optional[str] = None) -> Any:
198
  """
199
  Generate a value based on the strategy string.
200
 
 
227
  return self._generate(strategy_type, params)
228
  except Exception as e:
229
  print(f"Warning: Strategy '{strategy}' failed: {e}, using safe default")
230
+ return self._gen_safe_default(expected_type)
231
+
232
+ def get_strategy_for_semantic(self, semantic_type: SemanticType, use_case: str | None = None) -> Optional[str]:
233
+ """Return deterministic strategy for a semantic type when possible."""
234
+ if semantic_type in {
235
+ SemanticType.REGION,
236
+ SemanticType.CATEGORY,
237
+ SemanticType.SECTOR_NAME,
238
+ SemanticType.SECTOR_CATEGORY,
239
+ SemanticType.SEGMENT,
240
+ SemanticType.TIER,
241
+ SemanticType.STATUS,
242
+ SemanticType.FUND_STRATEGY,
243
+ SemanticType.INVESTOR_TYPE,
244
+ SemanticType.INVESTMENT_STAGE,
245
+ SemanticType.COVENANT_STATUS,
246
+ SemanticType.DEBT_PERFORMANCE_STATUS,
247
+ SemanticType.CHANNEL,
248
+ SemanticType.COUNTRY,
249
+ SemanticType.STATE,
250
+ SemanticType.CITY,
251
+ SemanticType.POSTAL_CODE,
252
+ SemanticType.SEASON,
253
+ SemanticType.HOLIDAY_NAME,
254
+ SemanticType.EVENT_NAME,
255
+ SemanticType.PRODUCT_NAME,
256
+ SemanticType.BRAND_NAME,
257
+ SemanticType.BRANCH_NAME,
258
+ SemanticType.ORG_NAME,
259
+ SemanticType.DEPARTMENT_NAME,
260
+ SemanticType.COLOR_DESCRIPTION,
261
+ SemanticType.PACKAGING_SIZE,
262
+ }:
263
+ values = get_domain_values(semantic_type, use_case)
264
+ if values:
265
+ return f"choice:{','.join(values)}"
266
+ if semantic_type == SemanticType.PERSON_NAME:
267
+ return "name"
268
+ if semantic_type == SemanticType.INTEREST_RATE:
269
+ return "random_decimal:0.01,0.35,4"
270
+ if semantic_type == SemanticType.RETURN_RATE:
271
+ return "random_decimal:-0.10,0.35,4"
272
+ if semantic_type == SemanticType.RETURN_MULTIPLE:
273
+ return "random_decimal:0.50,4.50,3"
274
+ if semantic_type == SemanticType.BASIS_POINTS:
275
+ return "random_int:25,250"
276
+ if semantic_type == SemanticType.LEVERAGE_RATIO:
277
+ return "random_decimal:0.50,8.00,3"
278
+ if semantic_type == SemanticType.PERCENT:
279
+ return "random_decimal:0.00,100.00,2"
280
+ if semantic_type == SemanticType.DURATION_DAYS:
281
+ return "random_int:0,365"
282
+ if semantic_type == SemanticType.QUANTITY:
283
+ return "random_int:1,50"
284
+ if semantic_type == SemanticType.COUNT:
285
+ return "random_int:0,50000"
286
+ if semantic_type == SemanticType.SCORE:
287
+ return "random_decimal:1.0,10.0,1"
288
+ return None
289
 
290
  def _generate(self, strategy_type: str, params: list) -> Any:
291
  """Route to specific generator based on strategy type."""
292
 
293
  # Random integers
294
  if strategy_type == 'random_int':
295
+ min_val = self._extract_int(params[0], 1) if params else 1
296
+ max_val = self._extract_int(params[1], 1000) if len(params) > 1 else 1000
297
  return random.randint(min_val, max_val)
298
 
299
  # Random decimals
300
  elif strategy_type == 'random_decimal':
301
+ min_val = self._extract_float(params[0], 0.0) if params else 0.0
302
+ max_val = self._extract_float(params[1], 100.0) if len(params) > 1 else 100.0
303
+ precision = self._extract_int(params[2], 2) if len(params) > 2 else 2
304
  value = random.uniform(min_val, max_val)
305
  return Decimal(str(round(value, precision)))
306
 
307
  # Random floats
308
  elif strategy_type == 'random_float':
309
+ min_val = self._extract_float(params[0], 0.0) if params else 0.0
310
+ max_val = self._extract_float(params[1], 100.0) if len(params) > 1 else 100.0
311
  return round(random.uniform(min_val, max_val), 2)
312
 
313
+ # Random strings (alphanumeric IDs/codes)
314
+ elif strategy_type == 'random_string':
315
+ length = self._extract_int(params[0], 10) if params else 10
316
+ alphabet = string.ascii_uppercase + string.digits
317
+ return ''.join(random.choices(alphabet, k=max(1, length)))
318
+
319
  # Choice from list
320
  elif strategy_type == 'choice':
321
  if params:
 
327
  # Format: "weighted_choice:val1,weight1,val2,weight2,..."
328
  if len(params) >= 2:
329
  values = params[0::2]
330
+ weights = [self._extract_float(w, 1.0) for w in params[1::2]]
331
  return random.choices(values, weights=weights)[0]
332
  return params[0] if params else "Unknown"
333
 
 
337
 
338
  # Boolean
339
  elif strategy_type == 'boolean':
340
+ true_pct = self._extract_float(params[0], 0.5) if params else 0.5
341
  return random.random() < true_pct
342
 
343
  # Dates
344
  elif strategy_type == 'date_between':
345
  if params and len(params) >= 2:
346
+ start = datetime.strptime(self._extract_date(params[0], '2024-01-01'), '%Y-%m-%d')
347
+ end = datetime.strptime(self._extract_date(params[1], datetime.now().strftime('%Y-%m-%d')), '%Y-%m-%d')
348
  else:
349
  end = datetime.now()
350
  start = end - timedelta(days=365*2)
 
356
  # Sequential (for IDs)
357
  elif strategy_type == 'sequential':
358
  # This should be handled by the generator with state
359
+ start = self._extract_int(params[0], 1) if params else 1
360
  return start
361
 
362
  # Faker-based generators
 
429
  else:
430
  return self._gen_safe_default()
431
 
432
+ def _gen_safe_default(self, expected_type: Optional[str] = None) -> Any:
433
  """
434
  Generate a safe default value - NOT garbage concatenated words.
435
  Returns a simple, realistic-looking value.
436
  """
437
+ if expected_type:
438
+ t = expected_type.upper()
439
+ if any(x in t for x in ('INT', 'BIGINT', 'SMALLINT')):
440
+ return 0
441
+ if any(x in t for x in ('NUMBER', 'NUMERIC', 'DECIMAL', 'FLOAT', 'DOUBLE', 'REAL')):
442
+ return Decimal("0.00")
443
+ if 'BOOL' in t or 'BOOLEAN' in t or 'BIT' in t:
444
+ return False
445
+ if 'DATE' in t:
446
+ return datetime.now()
447
+
448
  # Return a simple word, not concatenated garbage
449
  return self.fake.word().capitalize()
450
+
451
+ def _extract_float(self, raw: str, default: float) -> float:
452
+ cleaned = (raw or "").split("(", 1)[0].strip()
453
+ match = re.search(r"[-+]?\d*\.?\d+", cleaned)
454
+ return float(match.group(0)) if match else default
455
+
456
+ def _extract_int(self, raw: str, default: int) -> int:
457
+ cleaned = (raw or "").split("(", 1)[0].strip()
458
+ match = re.search(r"[-+]?\d+", cleaned)
459
+ return int(match.group(0)) if match else default
460
+
461
+ def _extract_date(self, raw: str, default: str) -> str:
462
+ cleaned = (raw or "").split("(", 1)[0].strip()
463
+ match = re.search(r"\d{4}-\d{2}-\d{2}", cleaned)
464
+ return match.group(0) if match else default
465
 
466
  def new_row(self):
467
  """
 
497
  Infer the best strategy for a column based on its name and type.
498
  This is called when no explicit strategy is provided.
499
  """
500
+ col_raw_lower = col_name.lower()
501
+ col_lower = col_raw_lower.replace('_', '').replace('-', '')
502
  col_type_upper = col_type.upper() if col_type else 'VARCHAR'
503
+
504
+ # Prefer semantic strategy first for quality-critical columns.
505
+ semantic = infer_semantic_type(col_name, col_type)
506
+ semantic_strategy = self.get_strategy_for_semantic(semantic)
507
+ if semantic_strategy:
508
+ return semantic_strategy
509
+
510
+ # Identifier-like short codes (e.g., TAX_ID_LAST4) should remain code-like.
511
+ if (
512
+ 'taxid' in col_lower
513
+ or 'last4' in col_lower
514
+ or col_lower.endswith('idlast4')
515
+ ):
516
+ if 'VARCHAR' in col_type_upper or 'CHAR' in col_type_upper or 'TEXT' in col_type_upper:
517
+ return 'random_string:4'
518
+ return 'random_int:1000,9999'
519
 
520
  # === ID COLUMNS ===
521
  if col_lower.endswith('id') or col_lower.endswith('key'):
 
534
  return 'random_int:100000,9999999'
535
 
536
  # === PRICE/MONEY COLUMNS ===
537
+ if 'discount' in col_lower:
538
+ return 'random_decimal:0.00,200.00'
539
+
540
+ if any(x in col_lower for x in ['tax', 'shipping']):
541
+ return 'random_decimal:1.00,250.00'
542
+
543
  if any(x in col_lower for x in ['price', 'cost', 'amount', 'total', 'subtotal']):
544
  return 'random_decimal:1.00,500.00'
545
 
 
655
  return 'lookup:quarter'
656
 
657
  # === DATE COLUMNS ===
658
+ today = datetime.now()
659
+ today_str = today.strftime('%Y-%m-%d')
660
+
661
+ # Birth dates MUST be handled before generic date matching to avoid toddler ages
662
+ is_birth_col = (
663
+ 'birthdate' in col_lower or
664
+ 'dateofbirth' in col_lower or
665
+ col_lower.endswith('dob') or
666
+ col_lower == 'dob'
667
+ )
668
+ if is_birth_col:
669
+ oldest = _years_ago(today, 95).strftime('%Y-%m-%d')
670
+ youngest = _years_ago(today, 18).strftime('%Y-%m-%d')
671
+ return f'date_between:{oldest},{youngest}'
672
+
673
+ financial_to_date_name = (
674
+ ('todate' in col_lower or 'to_date' in col_lower)
675
+ and any(
676
+ x in col_lower
677
+ for x in ['amount', 'balance', 'cost', 'price', 'revenue', 'sales', 'fee', 'tax', 'total', 'recovery', 'payment']
678
+ )
679
+ )
680
+ if any(x in col_lower for x in ['date', 'datetime', 'timestamp', 'createdat', 'updatedat']) and not financial_to_date_name:
681
+ return f'date_between:2024-01-01,{today_str}'
682
 
683
  # === BOOLEAN COLUMNS - Check FIRST before contact columns! ===
684
  # is_mobile should be boolean, not phone number
 
695
  return 'boolean:0.5'
696
 
697
  # === NAME COLUMNS ===
698
+ if col_lower in ('name', 'fullname'):
699
  return 'name'
700
+ if col_lower in ('customername', 'accountname', 'clientname', 'companyname', 'organizationname'):
701
+ return 'company'
702
 
703
  if col_lower in ('firstname', 'fname', 'givenname'):
704
  return 'first_name'
 
710
  # col_lower has underscores stripped, so check for 'name' suffix
711
  if col_lower.endswith('name'):
712
  # Exclude non-person names
713
+ non_person = ('productname', 'companyname', 'accountname', 'customername',
714
+ 'clientname', 'organizationname', 'brandname', 'storename',
715
  'warehousename', 'centername', 'campaignname', 'tablename',
716
  'columnname', 'schemaname', 'filename', 'hostname',
717
+ 'categoryname', 'holidayname', 'branchname', 'departmentname',
718
+ 'regionname', 'cityname')
719
  if col_lower not in non_person:
720
  return 'name'
721
 
 
747
  return 'company'
748
 
749
  # === URL COLUMNS ===
750
+ has_url_token = (
751
+ col_raw_lower == 'url'
752
+ or col_raw_lower.startswith('url_')
753
+ or col_raw_lower.endswith('_url')
754
+ or '_url_' in col_raw_lower
755
+ or 'website' in col_raw_lower
756
+ )
757
+ has_link_token = (
758
+ col_raw_lower == 'link'
759
+ or col_raw_lower.startswith('link_')
760
+ or col_raw_lower.endswith('_link')
761
+ or '_link_' in col_raw_lower
762
+ )
763
+ if has_url_token or has_link_token:
764
  return 'url'
765
 
766
  # === DESCRIPTION/TEXT COLUMNS ===
 
785
  return 'boolean:0.5'
786
 
787
  if 'DATE' in col_type_upper or 'TIME' in col_type_upper:
788
+ return f'date_between:2024-01-01,{today_str}'
 
789
 
790
  # === FINAL FALLBACK ===
791
  # Return a single word, not garbage
legitdata_project/legitdata/storyspec.py ADDED
@@ -0,0 +1,220 @@
1
+ """Story-first generation contract for deterministic, realistic demo data."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ from dataclasses import asdict, dataclass, field
7
+ from typing import Any
8
+
9
+
10
+ @dataclass
11
+ class TrendProfile:
12
+ """Controls for synthetic trend shape and reproducibility."""
13
+
14
+ style: str = "smooth"
15
+ noise_band_pct: float = 0.04
16
+ seasonal_strength: float = 0.10
17
+
18
+
19
+ @dataclass
20
+ class OutlierBudget:
21
+ """Controls for intentional outlier placement."""
22
+
23
+ max_events: int = 2
24
+ max_points_pct: float = 0.02
25
+ event_multiplier_min: float = 1.8
26
+ event_multiplier_max: float = 3.2
27
+
28
+
29
+ @dataclass
30
+ class StorySpec:
31
+ """Canonical story specification used during generation."""
32
+
33
+ company_url: str
34
+ company_name: str
35
+ use_case: str
36
+ vertical: str
37
+ function: str
38
+ seed: int
39
+ realism_mode: str = "strict"
40
+ trend_profile: TrendProfile = field(default_factory=TrendProfile)
41
+ outlier_budget: OutlierBudget = field(default_factory=OutlierBudget)
42
+ kpis: list[str] = field(default_factory=list)
43
+ allowed_dimensions: list[str] = field(default_factory=list)
44
+ allowed_measures: list[str] = field(default_factory=list)
45
+ quality_targets: dict[str, Any] = field(default_factory=dict)
46
+ value_guardrails: dict[str, Any] = field(default_factory=dict)
47
+ domain_overrides: dict[str, list[str]] = field(default_factory=dict)
48
+ recommended_companies: list[str] = field(default_factory=list)
49
+
50
+ def to_dict(self) -> dict[str, Any]:
51
+ return asdict(self)
52
+
53
+
54
+ def _stable_seed(company_url: str, use_case: str) -> int:
55
+ text = f"{(company_url or '').strip().lower()}|{(use_case or '').strip().lower()}"
56
+ digest = hashlib.sha256(text.encode("utf-8")).hexdigest()[:8]
57
+ return int(digest, 16)
58
+
59
+
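`_stable_seed` derives a 32-bit integer from the first 8 hex digits of a SHA-256 over the normalized inputs, so the same company/use-case pair always yields the same seed (and therefore the same generated data). Standalone:

```python
import hashlib

def stable_seed(company_url: str, use_case: str) -> int:
    """Deterministic seed: same inputs -> same value; case/whitespace normalized."""
    text = f"{(company_url or '').strip().lower()}|{(use_case or '').strip().lower()}"
    return int(hashlib.sha256(text.encode("utf-8")).hexdigest()[:8], 16)

print(stable_seed("https://example.com", "Retail") == stable_seed(" HTTPS://EXAMPLE.COM ", "retail"))
# -> True
```

Unlike Python's built-in `hash()`, SHA-256 is stable across processes, so reruns and batch tests reproduce identical datasets.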
60
+ def _safe_float(value: Any, default: float) -> float:
61
+ try:
62
+ return float(value)
63
+ except (TypeError, ValueError):
64
+ return default
65
+
66
+
67
+ def _safe_int(value: Any, default: int) -> int:
68
+ try:
69
+ return int(value)
70
+ except (TypeError, ValueError):
71
+ return default
72
+
73
+
74
+ def _build_domain_overrides(company_name: str, use_case: str) -> dict[str, list[str]]:
75
+ company_clean = (company_name or "Demo Company").strip().title()
76
+ text = f"{company_name} {use_case}".lower()
77
+ overrides: dict[str, list[str]] = {}
78
+
79
+ if "amazon" in text:
80
+ overrides["product_name"] = [
81
+ "Echo Dot (5th Gen)",
82
+ "Kindle Paperwhite",
83
+ "Fire TV Stick 4K",
84
+ "Amazon Basics USB-C Cable",
85
+ "Ring Video Doorbell",
86
+ "Blink Outdoor Camera",
87
+ "iRobot Roomba Combo",
88
+ "Instant Pot Duo",
89
+ "LEGO Star Wars Set",
90
+ "Ninja Air Fryer",
91
+ ]
92
+ overrides["category"] = [
93
+ "Electronics",
94
+ "Home & Kitchen",
95
+ "Books",
96
+ "Fashion",
97
+ "Grocery",
98
+ "Sports & Outdoors",
99
+ "Health & Personal Care",
100
+ ]
101
+ overrides["channel"] = ["Prime", "Marketplace", "Subscribe & Save", "Amazon Fresh", "Retail"]
102
+
103
+ if any(tok in text for tok in ("prescription", "formulary", "pharma", "drug", "payer")):
104
+ overrides["category"] = [
105
+ "Cardiometabolic",
106
+ "Oncology",
107
+ "Women's Health",
108
+ "Immunology",
109
+ "Rare Disease",
110
+ ]
111
+ overrides["segment"] = [
112
+ "High Prescriber",
113
+ "Growth Prescriber",
114
+ "Low Prescriber",
115
+ "New Writer",
116
+ "Existing Writer",
117
+ ]
118
+ overrides["channel"] = [
119
+ "Field Sales",
120
+ "Medical Affairs",
121
+ "Payer Team",
122
+ "Digital HCP",
123
+ "Patient Support",
124
+ ]
125
+ overrides["org_name"] = [
126
+ "AOK Bayern",
127
+ "Techniker Krankenkasse",
128
+ "Barmer",
129
+ "DAK-Gesundheit",
130
+ "IKK Classic",
131
+ "Bayer Field Team South",
132
+ "Bayer Field Team North",
133
+ ]
134
+ overrides["product_name"] = ["Drug X 5mg", "Drug X 10mg", "Drug X Starter Pack", "Drug X Maintenance"]
135
+ if "germany" in text:
136
+ overrides["country"] = ["Germany"]
137
+ overrides["region"] = ["Bavaria", "North Rhine-Westphalia", "Berlin", "Hesse", "Saxony"]
138
+ overrides["city"] = ["Berlin", "Munich", "Hamburg", "Cologne", "Frankfurt", "Stuttgart"]
139
+
140
+ if not overrides.get("org_name"):
141
+ overrides["org_name"] = [
142
+ f"{company_clean} North Division",
143
+ f"{company_clean} South Division",
144
+ f"{company_clean} Enterprise Team",
145
+ f"{company_clean} Commercial Operations",
146
+ ]
147
+
148
+ return overrides
149
+
150
+
151
+ def build_story_spec(
152
+ company_url: str,
153
+ use_case: str,
154
+ company_name: str | None = None,
155
+ use_case_config: dict[str, Any] | None = None,
156
+ ) -> StorySpec:
157
+ """Build StorySpec from matrix controls and runtime context."""
158
+ config = use_case_config or {}
159
+ controls = config.get("story_controls", {}) if isinstance(config, dict) else {}
160
+ outlier_cfg = controls.get("outlier_budget", {})
161
+ quality_targets = controls.get("quality_targets", {}) or {}
162
+ value_guardrails = controls.get("value_guardrails", {}) or {}
163
+
164
+ seed = _stable_seed(company_url, use_case)
165
+ trend_profile = TrendProfile(
166
+ style=str(controls.get("trend_style", "smooth")),
167
+ noise_band_pct=_safe_float(controls.get("trend_noise_band_pct"), 0.04),
168
+ seasonal_strength=_safe_float(controls.get("seasonal_strength"), 0.10),
169
+ )
170
+ outlier_budget = OutlierBudget(
171
+ max_events=max(0, _safe_int(outlier_cfg.get("max_events"), 2)),
172
+ max_points_pct=max(0.0, _safe_float(outlier_cfg.get("max_points_pct"), 0.02)),
173
+ event_multiplier_min=max(1.0, _safe_float(outlier_cfg.get("event_multiplier_min"), 1.8)),
174
+ event_multiplier_max=max(1.0, _safe_float(outlier_cfg.get("event_multiplier_max"), 3.2)),
175
+ )
176
+
177
+ if outlier_budget.event_multiplier_max < outlier_budget.event_multiplier_min:
178
+ outlier_budget.event_multiplier_max = outlier_budget.event_multiplier_min
179
+
180
+ name = (company_name or company_url or "Demo Company").split(".")[0].title()
181
+ vertical = str(config.get("vertical") or "Generic")
182
+ function = str(config.get("function") or "Generic")
183
+
184
+ return StorySpec(
185
+ company_url=company_url,
186
+ company_name=name,
187
+ use_case=use_case,
188
+ vertical=vertical,
189
+ function=function,
190
+ seed=seed,
191
+ realism_mode=str(controls.get("realism_mode", "strict")),
192
+ trend_profile=trend_profile,
193
+ outlier_budget=outlier_budget,
194
+ kpis=list(config.get("kpis", []) or []),
195
+ allowed_dimensions=list(config.get("allowed_story_dimensions", []) or []),
196
+ allowed_measures=list(config.get("allowed_story_measures", []) or []),
197
+ quality_targets=quality_targets,
198
+ value_guardrails=value_guardrails,
199
+ domain_overrides=_build_domain_overrides(name, use_case),
200
+ recommended_companies=list(config.get("recommended_companies", []) or []),
201
+ )
202
+
203
+
204
+ def build_story_bundle(spec: StorySpec) -> dict[str, Any]:
205
+ """Build compact payload used by downstream ThoughtSpot handoff."""
206
+ return {
207
+ "vertical": spec.vertical,
208
+ "function": spec.function,
209
+ "kpis": spec.kpis[:8],
210
+ "allowed_dimensions": spec.allowed_dimensions[:10],
211
+ "allowed_measures": spec.allowed_measures[:10],
212
+ "trend_style": spec.trend_profile.style,
213
+ "trend_noise_band_pct": spec.trend_profile.noise_band_pct,
214
+ "seasonal_strength": spec.trend_profile.seasonal_strength,
215
+ "outlier_budget": {
216
+ "max_events": spec.outlier_budget.max_events,
217
+ "max_points_pct": spec.outlier_budget.max_points_pct,
218
+ },
219
+ "quality_targets": spec.quality_targets,
220
+ }
legitdata_project/pyproject.toml ADDED
@@ -0,0 +1,67 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "legitdata"
+version = "0.1.0"
+description = "Generate realistic synthetic data for analytics warehouses"
+readme = "README.md"
+license = {text = "MIT"}
+requires-python = ">=3.10"
+authors = [
+    {name = "Your Name", email = "your.email@example.com"}
+]
+classifiers = [
+    "Development Status :: 3 - Alpha",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Topic :: Database",
+    "Topic :: Software Development :: Testing",
+]
+keywords = ["synthetic data", "data generation", "analytics", "snowflake", "testing"]
+
+dependencies = []
+
+[project.optional-dependencies]
+snowflake = ["snowflake-connector-python>=3.0.0"]
+ai = ["anthropic>=0.18.0"]
+all = [
+    "snowflake-connector-python>=3.0.0",
+    "anthropic>=0.18.0",
+]
+dev = [
+    "pytest>=7.0.0",
+    "pytest-cov>=4.0.0",
+    "black>=23.0.0",
+    "ruff>=0.1.0",
+]
+
+[project.scripts]
+legitdata = "legitdata.cli:main"
+
+[project.urls]
+Homepage = "https://github.com/yourname/legitdata"
+Documentation = "https://github.com/yourname/legitdata#readme"
+Repository = "https://github.com/yourname/legitdata"
+Issues = "https://github.com/yourname/legitdata/issues"
+
+[tool.setuptools.packages.find]
+include = ["legitdata*"]
+
+[tool.black]
+line-length = 100
+target-version = ['py310']
+
+[tool.ruff]
+line-length = 100
+select = ["E", "F", "W", "I", "N"]
+ignore = ["E501"]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = ["test_*.py"]
legitdata_project/test_legitdata.py ADDED
@@ -0,0 +1,261 @@
+#!/usr/bin/env python3
+"""Test script for legitdata functionality."""
+
+import sys
+sys.path.insert(0, '/home/claude')
+
+from legitdata import (
+    LegitGenerator,
+    parse_ddl_file,
+    SIZE_PRESETS,
+    USE_CASES
+)
+from legitdata.analyzer import ContextBuilder, ColumnClassifier, CompanyContext
+from legitdata.sourcer import GenericSourcer
+from legitdata.relationships import FKManager
+
+
+def test_ddl_parsing():
+    """Test DDL parsing."""
+    print("=" * 60)
+    print("TEST: DDL Parsing")
+    print("=" * 60)
+
+    schema = parse_ddl_file('legitdata/sample_schema.sql')
+
+    print(f"\nParsed schema: {schema}")
+    print(f"  Dimension tables: {len(schema.dimension_tables)}")
+    print(f"  Fact tables: {len(schema.fact_tables)}")
+
+    print("\nDependency order:")
+    for table in schema.get_dependency_order():
+        print(f"  {table}")
+        for fk in table.foreign_keys:
+            print(f"    - {fk}")
+
+    # Verify SALES_TRANSACTIONS has FKs
+    sales = schema.get_table('SALES_TRANSACTIONS')
+    assert sales is not None, "SALES_TRANSACTIONS table not found"
+    assert sales.is_fact_table, "SALES_TRANSACTIONS should be a fact table"
+    assert len(sales.foreign_keys) == 4, f"Expected 4 FKs, got {len(sales.foreign_keys)}"
+
+    print("\n✓ DDL parsing tests passed!")
+    return schema
+
+
+def test_context_builder():
+    """Test context building (without actual web scraping)."""
+    print("\n" + "=" * 60)
+    print("TEST: Context Builder")
+    print("=" * 60)
+
+    builder = ContextBuilder()
+
+    # Test minimal context creation
+    context = builder._create_minimal_context(
+        url="https://amazon.com",
+        use_case="Retail Analytics"
+    )
+
+    print("\nMinimal context:")
+    print(f"  Company: {context.company_name}")
+    print(f"  Industry: {context.industry}")
+    print(f"  Use case prompt:\n{context.to_prompt()}")
+
+    assert context.company_name, "Company name should not be empty"
+    assert context.industry == "Retail", "Industry should be Retail for Retail Analytics"
+
+    print("\n✓ Context builder tests passed!")
+    return context
+
+
+def test_column_classifier(schema, context):
+    """Test column classification."""
+    print("\n" + "=" * 60)
+    print("TEST: Column Classifier")
+    print("=" * 60)
+
+    classifier = ColumnClassifier()
+
+    # Test heuristic classification
+    classifications = classifier._classify_with_heuristics(
+        schema, context, "Retail Analytics"
+    )
+
+    print("\nClassifications:")
+    for table_name, columns in classifications.items():
+        print(f"\n{table_name}:")
+        for col_name, col_info in columns.items():
+            print(f"  {col_name}: {col_info['classification']} -> {col_info['strategy'][:50]}...")
+
+    # Verify some expected classifications (case-insensitive)
+    products = classifications.get('PRODUCTS', {})
+    assert products.get('ProductName', {}).get('classification', '').upper() == 'SEARCH_REAL', \
+        "ProductName should be SEARCH_REAL"
+    assert products.get('ProductID', {}).get('classification', '').upper() == 'GENERIC', \
+        "ProductID should be GENERIC"
+
+    print("\n✓ Column classifier tests passed!")
+    return classifications
+
+
+def test_generic_sourcer():
+    """Test generic data generation."""
+    print("\n" + "=" * 60)
+    print("TEST: Generic Sourcer")
+    print("=" * 60)
+
+    sourcer = GenericSourcer()
+
+    # Test various strategies
+    tests = [
+        ("random_int:1,100", "integer"),
+        ("random_decimal:10.00,500.00", "decimal"),
+        ("date_between:2023-01-01,2024-12-31", "date"),
+        ("boolean:0.7", "boolean"),
+        ("uuid", "uuid"),
+        ("choice:Gold|Silver|Bronze", "choice"),
+    ]
+
+    print("\nGenerated values:")
+    for strategy, desc in tests:
+        value = sourcer.generate_value(strategy)
+        print(f"  {desc} ({strategy}): {value}")
+
+    # Test key generation
+    key = sourcer.generate_key("CUS")
+    print(f"  business key: {key}")
+
+    # Test amount generation
+    amount = sourcer.generate_amount(10, 1000)
+    print(f"  amount: {amount}")
+
+    print("\n✓ Generic sourcer tests passed!")
+
+
+def test_fk_manager():
+    """Test FK relationship manager."""
+    print("\n" + "=" * 60)
+    print("TEST: FK Manager")
+    print("=" * 60)
+
+    manager = FKManager()
+
+    # Register some PK values
+    customer_ids = [1, 2, 3, 4, 5]
+    product_ids = [10, 20, 30, 40, 50]
+
+    manager.register_pk_values("CUSTOMERS", "CustomerID", customer_ids)
+    manager.register_pk_values("PRODUCTS", "ProductID", product_ids)
+
+    print("\nRegistered tables:", manager.get_available_tables())
+
+    # Get FK values with different distributions
+    print("\nUniform distribution (10 values):")
+    uniform_values = manager.get_fk_values("CUSTOMERS", "CustomerID", 10, "uniform")
+    print(f"  {uniform_values}")
+
+    print("\nPareto distribution (20 values):")
+    pareto_values = manager.get_fk_values("CUSTOMERS", "CustomerID", 20, "pareto")
+    print(f"  {pareto_values}")
+
+    # Count distribution
+    from collections import Counter
+    counts = Counter(pareto_values)
+    print(f"  Distribution: {dict(counts)}")
+
+    print("\n✓ FK manager tests passed!")
+
+
+def test_full_preview():
+    """Test full preview generation."""
+    print("\n" + "=" * 60)
+    print("TEST: Full Preview (Dry Run)")
+    print("=" * 60)
+
+    # Create generator in dry run mode
+    gen = LegitGenerator(
+        url="https://amazon.com",
+        use_case="Retail Analytics",
+        connection_string="",
+        dry_run=True,
+        cache_enabled=False
+    )
+
+    # Load DDL
+    gen.load_ddl('legitdata/sample_schema.sql')
+
+    # Preview (this exercises the full pipeline except DB writes)
+    print("\nGenerating preview...")
+    preview = gen.preview(num_rows=3)
+
+    print("\nPreview results:")
+    for table_name, rows in preview.items():
+        print(f"\n{table_name} ({len(rows)} rows):")
+        if rows:
+            # Show first row
+            row = rows[0]
+            for key, value in row.items():
+                if value is not None:
+                    print(f"  {key}: {value}")
+
+    # Verify FK relationships in fact table
+    sales_rows = preview.get('SALES_TRANSACTIONS', [])
+    if sales_rows:
+        row = sales_rows[0]
+        print("\nFK integrity check:")
+        print(f"  CustomerID: {row.get('CustomerID')}")
+        print(f"  ProductID: {row.get('ProductID')}")
+        print(f"  SellerID: {row.get('SellerID')}")
+        print(f"  FulfillmentCenterID: {row.get('FulfillmentCenterID')}")
+
+    print("\n✓ Full preview tests passed!")
+
+
+def test_size_presets():
+    """Test size presets."""
+    print("\n" + "=" * 60)
+    print("TEST: Size Presets")
+    print("=" * 60)
+
+    print("\nAvailable presets:")
+    for name, preset in SIZE_PRESETS.items():
+        print(f"  {name}: {preset.fact_rows} facts, {preset.dim_rows} dims - {preset.description}")
+
+    print("\nAvailable use cases:")
+    for uc in USE_CASES:
+        print(f"  - {uc}")
+
+    print("\n✓ Size presets test passed!")
+
+
+def main():
+    """Run all tests."""
+    print("\n" + "=" * 60)
+    print("LEGITDATA TEST SUITE")
+    print("=" * 60)
+
+    try:
+        # Run tests
+        schema = test_ddl_parsing()
+        context = test_context_builder()
+        test_column_classifier(schema, context)
+        test_generic_sourcer()
+        test_fk_manager()
+        test_size_presets()
+        test_full_preview()
+
+        print("\n" + "=" * 60)
+        print("ALL TESTS PASSED!")
+        print("=" * 60)
+        return 0
+
+    except Exception as e:
+        print(f"\n✗ TEST FAILED: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+
+if __name__ == '__main__':
+    sys.exit(main())
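
The FK manager test above relies on `get_fk_values` producing a skewed pick when asked for a "pareto" distribution. The real implementation lives in `legitdata.relationships`; the following is only an illustrative sketch of the idea (`pick_fk_values` is a hypothetical helper, not the library's code):

```python
import random
from collections import Counter


def pick_fk_values(pk_values, n, distribution="uniform", seed=0):
    """Illustrative only: harmonic weights skew picks toward the first
    keys for "pareto", giving the heavy-head shape the test inspects."""
    rng = random.Random(seed)
    if distribution == "pareto":
        weights = [1.0 / (rank + 1) for rank in range(len(pk_values))]
        return rng.choices(pk_values, weights=weights, k=n)
    return [rng.choice(pk_values) for _ in range(n)]


picks = pick_fk_values([1, 2, 3, 4, 5], 200, "pareto")
print(Counter(picks))  # a few customers dominate, mirroring real-world FK skew
```

Printing the `Counter` of a "pareto" draw shows the first key claiming a large share of the 200 picks, which is exactly what the fact-table generator wants: a handful of customers account for most transactions.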
legitdata_project/test_with_ai.py ADDED
@@ -0,0 +1,72 @@
+#!/usr/bin/env python3
+"""Test legitdata with AI generation - larger dataset."""
+
+from dotenv import load_dotenv
+load_dotenv()
+
+from anthropic import Anthropic
+from legitdata import LegitGenerator
+
+client = Anthropic()
+
+
+def web_search(query: str) -> list[dict]:
+    """Use Claude to search the web."""
+    response = client.messages.create(
+        model="claude-sonnet-4-20250514",
+        max_tokens=1024,
+        tools=[{
+            "type": "web_search_20250305",
+            "name": "web_search"
+        }],
+        messages=[{"role": "user", "content": f"Search for: {query}. Return a list of results."}]
+    )
+
+    results = []
+    for block in response.content:
+        if hasattr(block, 'text'):
+            results.append({"title": block.text[:100], "snippet": block.text})
+
+    return results
+
+
+gen = LegitGenerator(
+    url="https://amazon.com",
+    use_case="Retail Analytics",
+    connection_string="",
+    anthropic_client=client,
+    web_search_fn=web_search,
+    dry_run=True
+)
+
+gen.load_ddl('legitdata_project/legitdata/sample_schema.sql')
+
+# Use "small" preset: 100 fact rows, 20 dim rows
+preview = gen.preview(num_rows=20)
+
+print("\n" + "=" * 60)
+print("DATASET SUMMARY")
+print("=" * 60)
+
+for table, rows in preview.items():
+    print(f"\n{table}: {len(rows)} rows")
+
+    # Show first 3 rows
+    for i, row in enumerate(rows[:3]):
+        print(f"  Row {i+1}: {row}")
+
+    if len(rows) > 3:
+        print(f"  ... and {len(rows) - 3} more rows")
+
+# Show FK distribution in fact table
+print("\n" + "=" * 60)
+print("FK DISTRIBUTION IN SALES_TRANSACTIONS")
+print("=" * 60)
+
+from collections import Counter
+sales = preview.get('SALES_TRANSACTIONS', [])
+if sales:
+    customer_dist = Counter(r['CustomerID'] for r in sales)
+    product_dist = Counter(r['ProductID'] for r in sales)
+
+    print(f"\nCustomerID distribution: {dict(customer_dist)}")
+    print(f"ProductID distribution: {dict(product_dist)}")
liveboard_creator.py CHANGED
@@ -158,13 +158,17 @@ def _get_answer_direct(question: str, model_id: str) -> dict:
158
 
159
  try:
160
  resp = session.post(url, json=payload, timeout=60)
161
-
162
  if resp.status_code == 200:
163
  data = resp.json()
164
- # Format response similar to MCP getAnswer
165
- # IMPORTANT: MCP's createLiveboard requires 'question' field!
 
 
 
 
166
  return {
167
- 'question': question, # MCP createLiveboard requires this!
168
  'session_identifier': data.get('session_identifier'),
169
  'tokens': data.get('tokens'),
170
  'display_tokens': data.get('display_tokens'),
@@ -173,10 +177,10 @@ def _get_answer_direct(question: str, model_id: str) -> dict:
173
  'message_type': data.get('message_type', 'TSAnswer')
174
  }
175
  else:
176
- error_msg = resp.text[:200] if resp.text else f"Status {resp.status_code}"
177
- print(f" ⚠️ Direct API answer failed: {error_msg}")
178
  return None
179
-
180
  except Exception as e:
181
  print(f" ⚠️ Direct API exception: {str(e)}")
182
  return None
@@ -378,6 +382,7 @@ def clean_viz_title(question: str) -> str:
378
  (r'^What is the ', ''),
379
  (r'^What are the ', ''),
380
  (r'^Show me the ', ''),
 
381
  (r'^Show the ', ''),
382
  (r'^Show ', ''),
383
  (r'^Create a detailed table showing ', ''),
@@ -519,11 +524,18 @@ CRITICAL RULES:
519
  4. If query mentions "location", find the EXACT column name (e.g., "Storename", "locaLocationid")
520
  5. ThoughtSpot auto-aggregates - DO NOT use sum(), count(), avg() functions
521
  6. Use [column name] syntax with exact column name: [Productname] not [product]
 
 
 
 
 
522
 
523
  Examples:
524
  - "Show products where stock > 900" → "[Productname] [Stocklevel] [Stocklevel] > 900"
525
  - "Show customers where lifetime value > 50000" → "[Name] [Lifetimevalue] [Lifetimevalue] > 50000"
526
  - "Show sales by region" → "[Salesamount] [Region]"
 
 
527
 
528
  Convert the natural query above to ThoughtSpot search syntax using EXACT column names.
529
  Return ONLY the search query string, nothing else."""
@@ -1093,6 +1105,18 @@ Examples:
1093
  )
1094
  print(f" Search: {search_query}")
1095
 
 
 
 
 
 
 
 
 
 
 
 
 
1096
  chart_type = self.query_translator.infer_chart_type(
1097
  search_query,
1098
  self.model_columns,
@@ -1111,17 +1135,8 @@ Examples:
1111
  viz_counter += 1
1112
 
1113
  except Exception as e:
1114
- print(f" ⚠️ Error: {e}")
1115
- # Fallback visualization
1116
- viz_configs.append({
1117
- 'id': f'Viz_{viz_counter}',
1118
- 'name': outlier['title'],
1119
- 'description': outlier['insight'],
1120
- 'chart_type': 'TABLE',
1121
- 'search_query_direct': '[data]', # Minimal fallback
1122
- 'outlier_context': outlier
1123
- })
1124
- viz_counter += 1
1125
 
1126
  print(f" Generated {len(viz_configs)} visualizations total")
1127
  return viz_configs
@@ -2714,14 +2729,12 @@ def _convert_outlier_to_mcp_question(outlier: Dict) -> str:
2714
  # Clean up the query - remove quotes, extra formatting
2715
  question = show_me.replace('"', '').replace("'", '').strip()
2716
 
2717
- # If it doesn't start with question words, make it a question
2718
- if not any(question.lower().startswith(q) for q in ['what', 'which', 'how', 'show', 'who', 'when', 'where']):
2719
- question = f"Show me {question}"
2720
-
2721
  return question
2722
-
2723
- # Fallback: create question from title
2724
- return f"Show me {outlier.get('title', 'data')}"
2725
 
2726
 
2727
  def _create_kpi_question_from_outlier(outlier: Dict) -> Optional[str]:
@@ -2820,45 +2833,24 @@ For example, if measures are [IMPRESSIONS, CLICKS, CONVERSIONS], use those - NOT
2820
  Company: {company_data.get('name', 'Unknown Company')}
2821
  Use Case: {use_case}
2822
  {column_context}
2823
- Create questions that will produce high-quality visualizations. Each question MUST:
2824
- - Use the ACTUAL column names from the data model above (NOT generic terms like "revenue" or "products")
2825
- - Be specific and actionable for business users
2826
- - Use "top N" with specific numbers (top 5, top 10, top 15)
2827
- - Include time periods (last 12 months, this year, last quarter, past 18 months)
2828
- - DO NOT include chart type in the question (no "bar chart", "line chart", etc.)
2829
-
2830
- REQUIRED FORMAT - Questions should be concise and business-focused. NO chart type hints!
2831
-
2832
- IMPORTANT: The FIRST TWO questions MUST be KPIs in this EXACT format:
2833
- - Question 1: "[MEASURE] weekly" (exactly 2-3 words, creates KPI with WoW sparkline)
2834
- - Question 2: "[MEASURE] monthly" (exactly 2-3 words, creates KPI with MoM sparkline)
2835
-
2836
- EXAMPLES OF CORRECT KPI QUESTIONS:
2837
- - "spend weekly" ✅
2838
- - "impressions monthly" ✅
2839
- - "conversions weekly" ✅
2840
- - "total revenue monthly" ✅
2841
-
2842
- REQUIRED QUESTIONS (adapt to actual columns above):
2843
- 1. KPI: "[PRIMARY_MEASURE] weekly" (2-3 words only!)
2844
- 2. KPI: "[SECONDARY_MEASURE] monthly" (2-3 words only!)
2845
- 3. "Top 10 [DIMENSION] by [MEASURE] last 12 months"
2846
- 4. "[MEASURE] by [DIMENSION] last 12 months"
2847
- 5. "[MEASURE] by month last 18 months"
2848
- 6. "[MEASURE] by [DIMENSION1] vs [DIMENSION2] last 12 months"
2849
- 7. "Top 15 [ANOTHER_DIMENSION] by [MEASURE]"
2850
- 8. "[MEASURE] by [DIMENSION] last 2 years"
2851
-
2852
- TIME FILTER RULES (IMPORTANT):
2853
- - Use "last 12 months" or "last 18 months" instead of "this year" (more data coverage)
2854
- - Use "last 12 months" instead of "last 6 months" (avoids empty results)
2855
  - Use "last 2 years" for broad comparisons
2856
- - NEVER use "this year" or "this quarter" - too restrictive with generated data
2857
 
2858
- CRITICAL:
2859
- - Use the EXACT column names from the data model
2860
- - DO NOT add "bar chart", "line chart", "stacked" to questions - ThoughtSpot picks chart type automatically
2861
- - Include variety: rankings, breakdowns, trends, comparisons
2862
 
2863
  Return ONLY a JSON object with this exact structure (no other text):
2864
  {{
@@ -2947,18 +2939,18 @@ Return ONLY a JSON object with this exact structure (no other text):
2947
  m2 = measures[1].lower().replace('_', ' ') if len(measures) > 1 else m1
2948
  d1 = dimensions[0].lower().replace('_', ' ') if dimensions else 'category'
2949
  return [
2950
- f"{m1} weekly", # KPI with sparkline
2951
- f"{m2} monthly", # KPI with sparkline
2952
  f"top 10 {d1} by {m1}",
2953
  f"{m1} by {d1}",
2954
  f"{m1} by month last 12 months",
2955
  f"{m2} by {d1}"
2956
  ][:num_questions]
2957
-
2958
- # Ultimate fallback - simple KPI format
2959
  return [
2960
- "revenue weekly", # Simple KPI format for sparkline
2961
- "sales monthly", # Simple KPI format for sparkline
2962
  "top 10 products by revenue",
2963
  "revenue by category",
2964
  "revenue by month last 12 months",
@@ -3262,10 +3254,14 @@ def create_liveboard_from_model_mcp(
3262
  print(f"🎨 Creating liveboard: {final_liveboard_name}")
3263
  print(f"📊 Preparing to send {len(answers)} answers to createLiveboard")
3264
 
3265
- # Debug: Log what we're sending
3266
- print(f"🔍 Answer structure sample:")
3267
- if answers:
3268
- print(f" Keys in first answer: {list(answers[0].keys())}")
 
 
 
 
3269
 
3270
  try:
3271
  print(f"📦 [createLiveboard] Sending {len(answers)} answers...", flush=True)
@@ -3279,8 +3275,8 @@ def create_liveboard_from_model_mcp(
3279
  })
3280
 
3281
  result_text = liveboard_result.content[0].text
3282
- print(f" [createLiveboard] Response: {result_text[:200]}", flush=True)
3283
-
3284
  if not result_text or result_text.strip() == '':
3285
  raise ValueError("createLiveboard returned empty response")
3286
 
@@ -3688,6 +3684,12 @@ def enhance_mcp_liveboard(
3688
 
3689
  visualizations = liveboard_tml.get('liveboard', {}).get('visualizations', [])
3690
  print(f" Found {len(visualizations)} visualizations", flush=True)
 
 
 
 
 
 
3691
 
3692
  # Step 2: Classify visualizations by type and purpose
3693
  kpi_vizs = [] # KPI charts (big numbers, sparklines)
@@ -3735,7 +3737,9 @@ def enhance_mcp_liveboard(
3735
  print(f" Classification: {len(kpi_vizs)} KPIs, {len(trend_vizs)} trends, {len(bar_vizs)} bars, {len(table_vizs)} tables, {len(note_vizs)} notes", flush=True)
3736
 
3737
  # Step 2.5: Remove note tiles (MCP requires them but we don't want them in final liveboard)
3738
- if note_vizs:
 
 
3739
  print(f" Removing {len(note_vizs)} note tile(s)...", flush=True)
3740
  original_count = len(visualizations)
3741
  liveboard_tml['liveboard']['visualizations'] = [
@@ -3744,6 +3748,8 @@ def enhance_mcp_liveboard(
3744
  visualizations = liveboard_tml['liveboard']['visualizations']
3745
  print(f" [OK] Removed note tiles ({original_count} -> {len(visualizations)} visualizations)", flush=True)
3746
  enhancements_applied.append(f"Removed {len(note_vizs)} note tile(s)")
 
 
3747
 
3748
  # Step 3: Add Groups - simplified: just KPI section, rest ungrouped
3749
  if add_groups:
 
158
 
159
  try:
160
  resp = session.post(url, json=payload, timeout=60)
161
+
162
  if resp.status_code == 200:
163
  data = resp.json()
164
+ print(f" [direct API] HTTP 200 keys: {list(data.keys())}", flush=True)
165
+ print(f" [direct API] session_identifier: {data.get('session_identifier')}", flush=True)
166
+ print(f" [direct API] visualization_type: {data.get('visualization_type')}", flush=True)
167
+ print(f" [direct API] message_type: {data.get('message_type')}", flush=True)
168
+ has_tokens = bool(data.get('tokens'))
169
+ print(f" [direct API] tokens present: {has_tokens}", flush=True)
170
  return {
171
+ 'question': question,
172
  'session_identifier': data.get('session_identifier'),
173
  'tokens': data.get('tokens'),
174
  'display_tokens': data.get('display_tokens'),
 
177
  'message_type': data.get('message_type', 'TSAnswer')
178
  }
179
  else:
180
+ error_msg = resp.text[:500] if resp.text else f"Status {resp.status_code}"
181
+ print(f" ⚠️ Direct API answer FAILED — HTTP {resp.status_code}: {error_msg}", flush=True)
182
  return None
183
+
184
  except Exception as e:
185
  print(f" ⚠️ Direct API exception: {str(e)}")
186
  return None
 
382
  (r'^What is the ', ''),
383
  (r'^What are the ', ''),
384
  (r'^Show me the ', ''),
385
+ (r'^Show me ', ''),
386
  (r'^Show the ', ''),
387
  (r'^Show ', ''),
388
  (r'^Create a detailed table showing ', ''),
 
524
  4. If query mentions "location", find the EXACT column name (e.g., "Storename", "locaLocationid")
525
  5. ThoughtSpot auto-aggregates - DO NOT use sum(), count(), avg() functions
526
  6. Use [column name] syntax with exact column name: [Productname] not [product]
527
+ 7. TIME GRANULARITY: NEVER use bare "weekly", "monthly", "daily" as standalone tokens.
528
+ Always attach granularity to the date column: [Date Column].weekly, [Date Column].monthly
529
+ Example: "sales weekly" → "[Sales Amount] [Order Date].weekly"
530
+ Example: "revenue by month" → "[Revenue] [Transaction Date].monthly"
531
+ 8. If no date column is needed, just use the measure and dimension columns.
532
 
533
  Examples:
534
  - "Show products where stock > 900" → "[Productname] [Stocklevel] [Stocklevel] > 900"
535
  - "Show customers where lifetime value > 50000" → "[Name] [Lifetimevalue] [Lifetimevalue] > 50000"
536
  - "Show sales by region" → "[Salesamount] [Region]"
537
+ - "Show ASP weekly trend" → "[Asp Amount] [Order Date].weekly"
538
+ - "Revenue by month last 12 months" → "[Revenue] [Date].monthly [Date].'last 12 months'"
539
 
540
  Convert the natural query above to ThoughtSpot search syntax using EXACT column names.
541
  Return ONLY the search query string, nothing else."""
 
1105
  )
1106
  print(f" Search: {search_query}")
1107
 
1108
+ # Validate all [Column] tokens exist in the model — skip if any are missing
1109
+ import re as _re
1110
+ col_names_lower = {col.get('name', '').lower() for col in self.model_columns if col.get('name')}
1111
+ # Extract tokens (strip granularity suffix like .weekly)
1112
+ unknown_tokens = [
1113
+ t for t in _re.findall(r'\[([^\]]+)\]', search_query)
1114
+ if t.split('.')[0].strip().lower() not in col_names_lower
1115
+ ]
1116
+ if unknown_tokens:
1117
+ print(f" ⚠️ Skipping '{outlier['title']}' — columns not in model: {unknown_tokens}")
1118
+ continue
1119
+
1120
  chart_type = self.query_translator.infer_chart_type(
1121
  search_query,
1122
  self.model_columns,
 
1135
  viz_counter += 1
1136
 
1137
  except Exception as e:
1138
+ print(f" ⚠️ Skipping '{outlier.get('title', 'unknown')}' due to error: {e}")
1139
+ # Don't add a broken viz — skip and continue
 
 
 
 
 
 
 
 
 
1140
 
1141
  print(f" Generated {len(viz_configs)} visualizations total")
1142
  return viz_configs
 
2729
  # Clean up the query - remove quotes, extra formatting
2730
  question = show_me.replace('"', '').replace("'", '').strip()
2731
 
2732
+ # Capitalize first letter if needed
2733
+ question = question[0].upper() + question[1:] if question else question
 
 
2734
  return question
2735
+
2736
+ # Fallback: use title directly
2737
+ return outlier.get('title', 'data')
2738
 
2739
 
2740
  def _create_kpi_question_from_outlier(outlier: Dict) -> Optional[str]:
 
2833
  Company: {company_data.get('name', 'Unknown Company')}
2834
  Use Case: {use_case}
2835
  {column_context}
2836
+ The liveboard already has KPI cards showing single metrics. These fill questions should surface INTERESTING BUSINESS INSIGHTS that a senior executive would find compelling — comparisons, rankings, and breakdowns that tell a story.
2837
+
2838
+ WHAT MAKES A GREAT FILL QUESTION:
2839
+ - Reveals which dimension is driving performance: "Top 10 [DIMENSION] by [MEASURE] last 12 months"
2840
+ - Compares performance across a meaningful grouping: "[MEASURE] by [DIMENSION] last 12 months"
2841
+ - Shows how a breakdown has shifted: "[MEASURE] by [DIMENSION1] and [DIMENSION2] last 18 months"
2842
+ - Surfaces a ranking with business stakes: "Top 5 [DIMENSION] by [MEASURE] vs prior year"
2843
+
2844
+ DO NOT generate simple single-metric questions like "[MEASURE] by month" — those become redundant KPI cards and the liveboard already has enough of those. Every question here should involve at least one dimension to compare against.
2845
+
2846
+ TIME FILTER RULES:
2847
+ - Use "last 12 months" or "last 18 months" never "this year" or "this quarter" (too restrictive with generated data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2848
  - Use "last 2 years" for broad comparisons
 
2849
 
2850
+ CRITICAL:
2851
+ - Use the EXACT column names from the data model above
2852
+ - Do NOT mention chart types ThoughtSpot picks automatically
2853
+ - Each question should make a business leader say "that's interesting" — not just confirm a number
2854
 
2855
  Return ONLY a JSON object with this exact structure (no other text):
2856
  {{
 
2939
  m2 = measures[1].lower().replace('_', ' ') if len(measures) > 1 else m1
2940
  d1 = dimensions[0].lower().replace('_', ' ') if dimensions else 'category'
2941
  return [
2942
+ f"{m1} by week",
2943
+ f"{m2} by month",
2944
  f"top 10 {d1} by {m1}",
2945
  f"{m1} by {d1}",
2946
  f"{m1} by month last 12 months",
2947
  f"{m2} by {d1}"
2948
  ][:num_questions]
2949
+
2950
+ # Ultimate fallback
2951
  return [
2952
+ "revenue by week",
2953
+ "sales by month",
2954
  "top 10 products by revenue",
2955
  "revenue by category",
2956
  "revenue by month last 12 months",
 
3254
  print(f"🎨 Creating liveboard: {final_liveboard_name}")
3255
  print(f"📊 Preparing to send {len(answers)} answers to createLiveboard")
3256
 
3257
+ # Log what we're sending to createLiveboard
3258
+ print(f"📋 [createLiveboard] Answer summary ({len(answers)} total):", flush=True)
3259
+ for idx, ans in enumerate(answers):
3260
+ sid = ans.get('session_identifier', 'MISSING')
3261
+ viz = ans.get('visualization_type', 'MISSING')
3262
+ q = ans.get('question', 'MISSING')[:60]
3263
+ has_tok = bool(ans.get('tokens'))
3264
+ print(f" [{idx+1}] session={sid} viz={viz} tokens={has_tok} q='{q}'", flush=True)
3265
 
3266
  try:
3267
  print(f"📦 [createLiveboard] Sending {len(answers)} answers...", flush=True)
 
3275
  })
3276
 
3277
  result_text = liveboard_result.content[0].text
3278
+ print(f" [createLiveboard] FULL response: {result_text}", flush=True)
3279
+
3280
  if not result_text or result_text.strip() == '':
3281
  raise ValueError("createLiveboard returned empty response")
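The empty-response check above is part of a broader pattern: MCP tool responses arrive as raw text, so they should be validated and decoded before any TML post-processing. A minimal sketch of that pattern, with a hypothetical `parse_mcp_response` helper (not the repo's actual function):

```python
import json

def parse_mcp_response(result_text: str) -> dict:
    """Validate and decode a raw MCP tool response before post-processing."""
    if not result_text or not result_text.strip():
        raise ValueError("createLiveboard returned empty response")
    try:
        payload = json.loads(result_text)
    except json.JSONDecodeError as exc:
        # Surface a truncated copy of the bad payload for debugging.
        raise ValueError(f"createLiveboard returned non-JSON response: {result_text[:200]}") from exc
    if "error" in payload:
        raise RuntimeError(f"createLiveboard failed: {payload['error']}")
    return payload
```

Failing loudly here keeps a malformed response from propagating into the TML enhancement steps below.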
3282
 
 
3684
 
3685
  visualizations = liveboard_tml.get('liveboard', {}).get('visualizations', [])
3686
  print(f" Found {len(visualizations)} visualizations", flush=True)
3687
+ for _v in visualizations:
3688
+ _vid = _v.get('id', '?')
3689
+ _is_note = 'note_tile' in _v
3690
+ _chart_type = _v.get('answer', {}).get('chart', {}).get('type', 'N/A') if not _is_note else 'note'
3691
+ _name = _v.get('answer', {}).get('name', '') if not _is_note else '(note tile)'
3692
+ print(f" viz {_vid}: type={_chart_type} note={_is_note} name='{_name}'", flush=True)
3693
 
3694
  # Step 2: Classify visualizations by type and purpose
3695
  kpi_vizs = [] # KPI charts (big numbers, sparklines)
 
3737
  print(f" Classification: {len(kpi_vizs)} KPIs, {len(trend_vizs)} trends, {len(bar_vizs)} bars, {len(table_vizs)} tables, {len(note_vizs)} notes", flush=True)
3738
 
3739
  # Step 2.5: Remove note tiles (MCP requires them but we don't want them in final liveboard)
3740
+ # Safety: only remove note tiles if there are other vizzes to show — never leave liveboard empty
3741
+ non_note_count = len(visualizations) - len(note_vizs)
3742
+ if note_vizs and non_note_count > 0:
3743
  print(f" Removing {len(note_vizs)} note tile(s)...", flush=True)
3744
  original_count = len(visualizations)
3745
  liveboard_tml['liveboard']['visualizations'] = [
 
3748
  visualizations = liveboard_tml['liveboard']['visualizations']
3749
  print(f" [OK] Removed note tiles ({original_count} -> {len(visualizations)} visualizations)", flush=True)
3750
  enhancements_applied.append(f"Removed {len(note_vizs)} note tile(s)")
3751
+ elif note_vizs and non_note_count == 0:
3752
+ print(f" ⚠️ Only note tiles found ({len(note_vizs)}) — MCP may have returned no chart answers. Keeping note tiles to preserve liveboard content.", flush=True)
3753
 
3754
  # Step 3: Add Groups - simplified: just KPI section, rest ungrouped
3755
  if add_groups:
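The note-tile guard above (never strip tiles when nothing else would remain) is easy to isolate and test as a pure function. A sketch of the same logic under a hypothetical name (`strip_note_tiles` is not in the repo):

```python
def strip_note_tiles(visualizations: list[dict]) -> list[dict]:
    """Remove note tiles, but keep them if they are the only content.

    MCP requires note tiles during creation, yet an all-note liveboard
    usually means no chart answers came back; stripping would leave it empty.
    """
    note_vizs = [v for v in visualizations if "note_tile" in v]
    non_note_count = len(visualizations) - len(note_vizs)
    if note_vizs and non_note_count > 0:
        return [v for v in visualizations if "note_tile" not in v]
    # Either no note tiles, or notes are all we have: return unchanged.
    return visualizations
```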
outlier_system.py DELETED
@@ -1,769 +0,0 @@
1
- """
2
- Outlier System - Unified SQL generation for data pattern injection.
3
-
4
- This system handles two use cases with the SAME core mechanism:
5
- 1. Auto-Injection: AI identifies patterns during research → SQL UPDATEs after population
6
- 2. Chat Adjustment: User requests changes via chat → SQL UPDATEs on demand
7
-
8
- The core insight: Both are SQL UPDATE statements that modify data to create patterns.
9
-
10
- Usage:
11
- from outlier_system import OutlierGenerator, apply_outliers
12
-
13
- # Generate SQL from story descriptions
14
- generator = OutlierGenerator(schema_info)
15
- sql_updates = generator.generate_outlier_sql(
16
- "Premium subscriptions spike 3x in Q4",
17
- target_column="subscription_tier",
18
- target_value="premium"
19
- )
20
-
21
- # Apply multiple outliers after population
22
- apply_outliers(snowflake_conn, outliers, schema_name)
23
- """
24
-
25
- import re
26
- import os
27
- from typing import Dict, List, Optional, Tuple
28
- from dataclasses import dataclass, field
29
- from datetime import datetime
30
-
31
-
32
- # ============================================================================
33
- # Phase 1: New Structured Outlier System (February 2026 Sprint)
34
- # ============================================================================
35
-
36
- @dataclass
37
- class OutlierPattern:
38
- """
39
- Defines a single outlier pattern that serves three purposes:
40
- 1. Liveboard visualizations
41
- 2. Spotter questions
42
- 3. Demo talking points
43
- """
44
- # Identity
45
- name: str # "ASP Decline"
46
- category: str # "pricing", "volume", "inventory"
47
-
48
- # For LIVEBOARD (visualization)
49
- viz_type: str # "KPI", "COLUMN", "LINE"
50
- viz_question: str # "ASP weekly"
51
- viz_talking_point: str # "ASP dropped 12% — excessive discounting"
52
-
53
- # For SPOTTER (ad-hoc questions)
54
- spotter_questions: List[str] = field(default_factory=list)
55
- spotter_followups: List[str] = field(default_factory=list)
56
-
57
- # For DATA INJECTION (SQL generation)
58
- sql_template: str = "" # "UPDATE {fact_table} SET {column} = ..."
59
- affected_columns: List[str] = field(default_factory=list)
60
- magnitude: str = "" # "15% below normal"
61
- target_filter: str = "" # "WHERE REGION = 'West'"
62
-
63
- # For DEMO NOTES
64
- demo_setup: str = "" # "Start by showing overall sales are UP"
65
- demo_payoff: str = "" # "Then reveal ASP is DOWN — 'at what cost?'"
66
-
67
-
68
- @dataclass
69
- class OutlierConfig:
70
- """
71
- Configuration for outliers per use case.
72
- Combines required patterns, optional patterns, and AI generation guidance.
73
- """
74
- required: List[OutlierPattern] = field(default_factory=list) # Always include
75
- optional: List[OutlierPattern] = field(default_factory=list) # AI picks 1-2
76
- allow_ai_generated: bool = True # AI can create 1 custom
77
- ai_guidance: str = "" # Hint for AI generation
78
-
79
-
80
- OUTLIER_CONFIGS = {
81
- ("Retail", "Sales"): OutlierConfig(
82
- required=[
83
- OutlierPattern(
84
- name="ASP Decline",
85
- category="pricing",
86
- viz_type="KPI",
87
- viz_question="ASP weekly",
88
- viz_talking_point="ASP dropped 12% even though revenue is up — we're discounting too heavily",
89
- spotter_questions=[
90
- "Why did ASP drop last month?",
91
- "Which products have the biggest discount?",
92
- "Show me ASP by region",
93
- ],
94
- spotter_followups=[
95
- "Compare to same period last year",
96
- "Which stores are discounting most?",
97
- ],
98
- sql_template="UPDATE {fact_table} SET UNIT_PRICE = UNIT_PRICE * 0.85 WHERE REGION = 'West' AND {date_column} > '{recent_date}'",
99
- affected_columns=["UNIT_PRICE", "DISCOUNT_PCT"],
100
- magnitude="15% below normal",
101
- target_filter="WHERE REGION = 'West'",
102
- demo_setup="Start by showing overall sales are UP — everything looks good",
103
- demo_payoff="Then reveal ASP is DOWN — 'but at what cost?' moment",
104
- ),
105
- OutlierPattern(
106
- name="Regional Variance",
107
- category="geographic",
108
- viz_type="COLUMN",
109
- viz_question="Dollar Sales by Region",
110
- viz_talking_point="West region outperforming by 40% — what are they doing differently?",
111
- spotter_questions=[
112
- "Which region has the highest sales?",
113
- "Compare West to East performance",
114
- ],
115
- spotter_followups=[
116
- "What products are driving West?",
117
- "Show me the trend for West region",
118
- ],
119
- sql_template="UPDATE {fact_table} SET QUANTITY = QUANTITY * 1.4 WHERE REGION = 'West'",
120
- affected_columns=["QUANTITY", "REVENUE"],
121
- magnitude="40% above other regions",
122
- target_filter="WHERE REGION = 'West'",
123
- demo_setup="Show overall sales by region",
124
- demo_payoff="West is crushing it — drill in to find out why",
125
- ),
126
- ],
127
- optional=[
128
- OutlierPattern(
129
- name="Seasonal Spike",
130
- category="temporal",
131
- viz_type="LINE",
132
- viz_question="Dollar Sales trend by month",
133
- viz_talking_point="Holiday surge 3x normal — were we prepared?",
134
- spotter_questions=["Show me sales trend for Q4", "When was our peak sales day?"],
135
- spotter_followups=[],
136
- sql_template="UPDATE {fact_table} SET QUANTITY = QUANTITY * 3 WHERE MONTH IN (11, 12)",
137
- affected_columns=["QUANTITY", "REVENUE"],
138
- magnitude="3x normal",
139
- target_filter="WHERE MONTH IN (11, 12)",
140
- demo_setup="",
141
- demo_payoff="",
142
- ),
143
- OutlierPattern(
144
- name="Category Surge",
145
- category="product",
146
- viz_type="COLUMN",
147
- viz_question="Dollar Sales by Category",
148
- viz_talking_point="Electronics up 60% YoY while Apparel flat",
149
- spotter_questions=["Which category grew fastest?", "Compare Electronics to Apparel"],
150
- spotter_followups=[],
151
- sql_template="",
152
- affected_columns=[],
153
- magnitude="60% YoY",
154
- target_filter="",
155
- demo_setup="",
156
- demo_payoff="",
157
- ),
158
- ],
159
- allow_ai_generated=True,
160
- ai_guidance="If company has sustainability initiatives, create outlier around eco-friendly product sales",
161
- ),
162
-
163
- ("Banking", "Marketing"): OutlierConfig(
164
- required=[
165
- OutlierPattern(
166
- name="Funnel Drop-off",
167
- category="conversion",
168
- viz_type="COLUMN",
169
- viz_question="Conversion rate by funnel stage",
170
- viz_talking_point="70% drop-off at application page — UX issue?",
171
- spotter_questions=[
172
- "Where is our biggest funnel drop-off?",
173
- "What's our application completion rate?",
174
- ],
175
- spotter_followups=[],
176
- sql_template="",
177
- affected_columns=[],
178
- magnitude="70% drop-off",
179
- target_filter="",
180
- demo_setup="Show the full funnel from impression to approval",
181
- demo_payoff="The application page is killing conversions",
182
- ),
183
- ],
184
- optional=[
185
- OutlierPattern(
186
- name="Channel Performance",
187
- category="channel",
188
- viz_type="COLUMN",
189
- viz_question="CTR by channel",
190
- viz_talking_point="Mobile CTR 2x desktop — shift budget?",
191
- spotter_questions=["Which channel has the best CTR?"],
192
- spotter_followups=[],
193
- sql_template="",
194
- affected_columns=[],
195
- magnitude="2x desktop",
196
- target_filter="",
197
- demo_setup="",
198
- demo_payoff="",
199
- ),
200
- ],
201
- allow_ai_generated=True,
202
- ai_guidance="Consider seasonal patterns in loan applications",
203
- ),
204
- }
205
-
206
-
207
- def get_outliers_for_use_case(vertical: str, function: str) -> OutlierConfig:
208
- """Get outlier configuration for a use case, with fallback to empty config."""
209
- return OUTLIER_CONFIGS.get(
210
- (vertical, function),
211
- OutlierConfig(
212
- required=[],
213
- optional=[],
214
- allow_ai_generated=True,
215
- ai_guidance=f"Generate outliers appropriate for {vertical} {function}"
216
- )
217
- )
218
-
219
-
220
- # ============================================================================
221
- # Legacy Outlier System (existing code below)
222
- # ============================================================================
223
-
224
- @dataclass
225
- class LegacyOutlierPattern:
226
- """Represents a data pattern to inject (legacy structure)."""
227
- title: str
228
- description: str
229
- sql_update: str
230
- affected_table: str
231
- affected_column: str
232
- target_value: Optional[str] = None
233
- conditions: Optional[str] = None
234
- impact_description: Optional[str] = None
235
- spotter_question: Optional[str] = None
236
- talking_point: Optional[str] = None
237
-
238
-
239
- class OutlierGenerator:
240
- """
241
- Generate SQL UPDATE statements from text descriptions.
242
-
243
- The generator understands common patterns like:
244
- - Time-based: "spike in Q4", "drop in summer"
245
- - Value-based: "premium users 3x higher", "decrease electronics by 10%"
246
- - Conditional: "when category is X, increase Y"
247
- """
248
-
249
- def __init__(self, schema_info: Dict = None):
250
- """
251
- Initialize with schema information.
252
-
253
- Args:
254
- schema_info: Dict with table and column information
255
- {
256
- 'tables': {'SUBSCRIPTIONS': {'columns': [...]}, ...},
257
- 'fact_tables': ['SUBSCRIPTIONS', ...],
258
- 'dimension_tables': ['USERS', ...]
259
- }
260
- """
261
- self.schema_info = schema_info or {}
262
- self.tables = schema_info.get('tables', {}) if schema_info else {}
263
-
264
- def generate_outlier_sql(
265
- self,
266
- pattern_description: str,
267
- target_table: str = None,
268
- target_column: str = None,
269
- multiplier: float = None,
270
- absolute_value: float = None,
271
- conditions: Dict = None
272
- ) -> OutlierPattern:
273
- """
274
- Generate SQL UPDATE from a pattern description.
275
-
276
- Args:
277
- pattern_description: Natural language description
278
- "Premium subscriptions spike 3x in Q4"
279
- "Decrease electronics revenue by 10%"
280
- target_table: Table to update (optional, inferred from description)
281
- target_column: Column to update (optional, inferred)
282
- multiplier: Factor to multiply by (e.g., 3.0 for "3x")
283
- absolute_value: Exact value to set
284
- conditions: WHERE clause conditions as dict
285
- {'month': ['Oct', 'Nov', 'Dec'], 'category': 'Electronics'}
286
-
287
- Returns:
288
- OutlierPattern with SQL and documentation
289
- """
290
- # Parse the description to extract intent
291
- parsed = self._parse_pattern_description(pattern_description)
292
-
293
- # Use provided values or parsed ones
294
- target_table = target_table or parsed.get('table')
295
- target_column = target_column or parsed.get('column')
296
- multiplier = multiplier or parsed.get('multiplier')
297
- absolute_value = absolute_value if absolute_value is not None else parsed.get('absolute_value')
298
- conditions = conditions or parsed.get('conditions', {})
299
-
300
- # Build the SQL UPDATE
301
- sql = self._build_update_sql(
302
- target_table, target_column, multiplier, absolute_value, conditions
303
- )
304
-
305
- # Generate Spotter question
306
- spotter_q = self._generate_spotter_question(
307
- target_table, target_column, conditions, pattern_description
308
- )
309
-
310
- return LegacyOutlierPattern(
311
- title=parsed.get('title', pattern_description[:50]),
312
- description=pattern_description,
313
- sql_update=sql,
314
- affected_table=target_table,
315
- affected_column=target_column,
316
- target_value=str(absolute_value) if absolute_value else f"{multiplier}x",
317
- conditions=str(conditions) if conditions else None,
318
- impact_description=parsed.get('impact'),
319
- spotter_question=spotter_q,
320
- talking_point=self._generate_talking_point(pattern_description)
321
- )
322
-
323
- def _parse_pattern_description(self, description: str) -> Dict:
324
- """
325
- Parse natural language pattern description.
326
-
327
- Examples:
328
- "Premium subscriptions spike 3x in Q4"
329
- → {'multiplier': 3.0, 'conditions': {'quarter': 'Q4'}, 'column': 'subscription_tier'}
330
-
331
- "Decrease electronics by 10%"
332
- → {'multiplier': 0.9, 'conditions': {'category': 'electronics'}}
333
- """
334
- result = {
335
- 'title': description[:50] + ('...' if len(description) > 50 else ''),
336
- 'conditions': {}
337
- }
338
-
339
- desc_lower = description.lower()
340
-
341
- # Parse multiplier patterns
342
- # "3x", "spike 3x", "increase 3x"
343
- multiplier_match = re.search(r'(\d+(?:\.\d+)?)\s*x', desc_lower)
344
- if multiplier_match:
345
- result['multiplier'] = float(multiplier_match.group(1))
346
-
347
- # "increase by 20%", "decrease by 10%", "decrease revenue by 10%"
348
- # More flexible pattern that allows words between action and amount
349
- percent_match = re.search(r'(increase|decrease|drop|spike|rise)s?\s+.*?(?:by\s+)?(\d+(?:\.\d+)?)\s*%', desc_lower)
350
- if percent_match:
351
- action = percent_match.group(1)
352
- pct = float(percent_match.group(2))
353
- if action in ['decrease', 'drop']:
354
- result['multiplier'] = 1 - (pct / 100)
355
- else:
356
- result['multiplier'] = 1 + (pct / 100)
357
-
358
- # "set to 50B", "make 40B"
359
- absolute_match = re.search(r'(?:set|make|to)\s+(\d+(?:\.\d+)?)\s*(B|M|K)?', description, re.IGNORECASE)
360
- if absolute_match:
361
- value = float(absolute_match.group(1))
362
- suffix = absolute_match.group(2)
363
- if suffix:
364
- suffix = suffix.upper()
365
- if suffix == 'B':
366
- value *= 1_000_000_000
367
- elif suffix == 'M':
368
- value *= 1_000_000
369
- elif suffix == 'K':
370
- value *= 1_000
371
- result['absolute_value'] = value
372
-
373
- # Parse time conditions
374
- # Q1, Q2, Q3, Q4
375
- quarter_match = re.search(r'Q([1-4])', description, re.IGNORECASE)
376
- if quarter_match:
377
- q = int(quarter_match.group(1))
378
- quarter_months = {
379
- 1: ['Jan', 'Feb', 'Mar'],
380
- 2: ['Apr', 'May', 'Jun'],
381
- 3: ['Jul', 'Aug', 'Sep'],
382
- 4: ['Oct', 'Nov', 'Dec']
383
- }
384
- result['conditions']['month'] = quarter_months[q]
385
-
386
- # Specific months
387
- months = ['january', 'february', 'march', 'april', 'may', 'june',
388
- 'july', 'august', 'september', 'october', 'november', 'december']
389
- for i, month in enumerate(months):
390
- if month in desc_lower:
391
- result['conditions']['month'] = month.capitalize()[:3]
392
-
393
- # "in summer", "in winter"
394
- if 'summer' in desc_lower:
395
- result['conditions']['month'] = ['Jun', 'Jul', 'Aug']
396
- elif 'winter' in desc_lower:
397
- result['conditions']['month'] = ['Dec', 'Jan', 'Feb']
398
- elif 'spring' in desc_lower:
399
- result['conditions']['month'] = ['Mar', 'Apr', 'May']
400
- elif 'fall' in desc_lower or 'autumn' in desc_lower:
401
- result['conditions']['month'] = ['Sep', 'Oct', 'Nov']
402
-
403
- # Parse category/product conditions
404
- # "electronics", "premium", "gold tier"
405
- tier_patterns = ['premium', 'gold', 'silver', 'bronze', 'free', 'basic', 'pro', 'enterprise']
406
- for tier in tier_patterns:
407
- if tier in desc_lower:
408
- result['conditions']['tier'] = tier.capitalize()
409
- result['target_value'] = tier.capitalize()
410
-
411
- # Try to infer table/column from description
412
- if 'subscription' in desc_lower:
413
- result['table'] = 'SUBSCRIPTIONS'
414
- result['column'] = 'subscription_tier' if 'tier' in desc_lower or 'premium' in desc_lower else 'revenue'
415
- elif 'revenue' in desc_lower or 'sales' in desc_lower:
416
- result['table'] = 'SALES'
417
- result['column'] = 'revenue' if 'revenue' in desc_lower else 'total_amount'
418
- elif 'churn' in desc_lower:
419
- result['table'] = 'SUBSCRIPTIONS'
420
- result['column'] = 'churned'
421
-
422
- return result
423
-
424
- def _build_update_sql(
425
- self,
426
- table: str,
427
- column: str,
428
- multiplier: float = None,
429
- absolute_value: float = None,
430
- conditions: Dict = None
431
- ) -> str:
432
- """Build SQL UPDATE statement."""
433
- if not table or not column:
434
- return "-- Unable to generate SQL: missing table or column"
435
-
436
- # Build SET clause
437
- if absolute_value is not None:
438
- set_clause = f"{column} = {absolute_value}"
439
- elif multiplier is not None:
440
- set_clause = f"{column} = {column} * {multiplier}"
441
- else:
442
- return "-- Unable to generate SQL: no value modification specified"
443
-
444
- # Build WHERE clause
445
- where_parts = []
446
- if conditions:
447
- for col, val in conditions.items():
448
- if isinstance(val, list):
449
- # IN clause for multiple values
450
- values_str = ', '.join(f"'{v}'" for v in val)
451
- where_parts.append(f"{col} IN ({values_str})")
452
- else:
453
- where_parts.append(f"{col} = '{val}'")
454
-
455
- where_clause = ' AND '.join(where_parts) if where_parts else '1=1'
456
-
457
- # Add randomization for realistic distribution (not all rows affected)
458
- if multiplier and multiplier != 1.0:
459
- # Apply to ~30% of matching rows for variation
460
- where_clause += ' AND RANDOM() < 0.3'
461
-
462
- sql = f"""UPDATE {table}
463
- SET {set_clause}
464
- WHERE {where_clause};"""
465
-
466
- return sql
467
-
468
- def _generate_spotter_question(
469
- self,
470
- table: str,
471
- column: str,
472
- conditions: Dict,
473
- description: str
474
- ) -> str:
475
- """Generate a Spotter question that reveals this outlier."""
476
- # Base question on the pattern
477
- if conditions and 'month' in conditions:
478
- months = conditions['month']
479
- if isinstance(months, list):
480
- time_phrase = f"by month"
481
- else:
482
- time_phrase = f"in {months}"
483
- else:
484
- time_phrase = "over time"
485
-
486
- if 'revenue' in column.lower() or 'sales' in column.lower():
487
- return f"What is the total {column} {time_phrase}?"
488
- elif 'tier' in column.lower() or 'subscription' in column.lower():
489
- return f"What is the breakdown of subscription tiers {time_phrase}?"
490
- elif 'churn' in column.lower():
491
- return f"What is the churn rate {time_phrase}?"
492
- else:
493
- return f"Show me {column} {time_phrase}"
494
-
495
- def _generate_talking_point(self, description: str) -> str:
496
- """Generate a sales-ready talking point."""
497
- return f"Notice how ThoughtSpot instantly surfaces this pattern: {description}. This is the kind of insight that would take hours to find manually."
498
-
499
-
500
- def parse_chat_adjustment(message: str) -> Dict:
501
- """
502
- Parse a chat adjustment request into structured form.
503
-
504
- Examples:
505
- "decrease electronics by 10%" →
506
- {'action': 'decrease', 'entity': 'electronics', 'amount': 0.1, 'type': 'percent'}
507
-
508
- "make laptop 50B" →
509
- {'action': 'set', 'entity': 'laptop', 'amount': 50000000000, 'type': 'absolute'}
510
-
511
- "viz 3, increase premium by 20%" →
512
- {'action': 'increase', 'entity': 'premium', 'amount': 0.2, 'type': 'percent', 'viz': 3}
513
- """
514
- result = {'raw_message': message}
515
- msg_lower = message.lower()
516
-
517
- # Extract viz reference
518
- viz_match = re.search(r'viz\s*(\d+)', msg_lower)
519
- if viz_match:
520
- result['viz'] = int(viz_match.group(1))
521
-
522
- # Extract action
523
- if any(word in msg_lower for word in ['decrease', 'reduce', 'drop', 'lower']):
524
- result['action'] = 'decrease'
525
- elif any(word in msg_lower for word in ['increase', 'raise', 'boost', 'spike']):
526
- result['action'] = 'increase'
527
- elif any(word in msg_lower for word in ['set', 'make', 'change to']):
528
- result['action'] = 'set'
529
-
530
- # Extract amount - percentage
531
- pct_match = re.search(r'(\d+(?:\.\d+)?)\s*%', message)
532
- if pct_match:
533
- result['amount'] = float(pct_match.group(1)) / 100
534
- result['type'] = 'percent'
535
-
536
- # Extract amount - absolute with suffix
537
- abs_match = re.search(r'(\d+(?:\.\d+)?)\s*(B|M|K)(?:\s|$)', message, re.IGNORECASE)
538
- if abs_match:
539
- value = float(abs_match.group(1))
540
- suffix = abs_match.group(2).upper()
541
- if suffix == 'B':
542
- value *= 1_000_000_000
543
- elif suffix == 'M':
544
- value *= 1_000_000
545
- elif suffix == 'K':
546
- value *= 1_000
547
- result['amount'] = value
548
- result['type'] = 'absolute'
549
-
550
- # Extract entity (product name, category, etc.)
551
- # Remove action words and amounts to find entity
552
- entity_text = msg_lower
553
- for word in ['decrease', 'increase', 'reduce', 'raise', 'boost', 'spike',
554
- 'set', 'make', 'change', 'to', 'by', 'viz', '%']:
555
- entity_text = entity_text.replace(word, ' ')
556
- # Remove numbers
557
- entity_text = re.sub(r'\d+(?:\.\d+)?(?:B|M|K)?', '', entity_text, flags=re.IGNORECASE)
558
- # Clean up
559
- entity_text = ' '.join(entity_text.split()).strip()
560
- if entity_text:
561
- result['entity'] = entity_text
562
-
563
- return result
564
-
565
-
566
- def generate_adjustment_sql(
567
- adjustment: Dict,
568
- schema_info: Dict,
569
- schema_name: str
570
- ) -> Tuple[str, str]:
571
- """
572
- Generate SQL UPDATE for a chat adjustment.
573
-
574
- Args:
575
- adjustment: Parsed adjustment dict from parse_chat_adjustment()
576
- schema_info: Schema information with tables/columns
577
- schema_name: Snowflake schema name
578
-
579
- Returns:
580
- Tuple of (sql_update, explanation)
581
- """
582
- entity = adjustment.get('entity', '')
583
- action = adjustment.get('action', 'set')
584
- amount = adjustment.get('amount', 0)
585
- amount_type = adjustment.get('type', 'percent')
586
-
587
- # Determine which table/column to update based on entity
588
- # For now, assume it's a product/category affecting revenue
589
- # TODO: Make this smarter with schema analysis
590
-
591
- if amount_type == 'percent':
592
- if action == 'decrease':
593
- multiplier = 1 - amount
594
- else:
595
- multiplier = 1 + amount
596
- set_clause = f"total_revenue = total_revenue * {multiplier}"
597
- else:
598
- set_clause = f"total_revenue = {amount}"
599
-
600
- # Build WHERE clause to find the entity
601
- where_clause = f"LOWER(product_name) LIKE '%{entity.lower()}%'"
602
-
603
- sql = f"""UPDATE {schema_name}.SALES
604
- SET {set_clause}
605
- WHERE product_id IN (
606
- SELECT product_id FROM {schema_name}.PRODUCTS
607
- WHERE {where_clause}
608
- );"""
609
-
610
- explanation = f"{'Decreased' if action == 'decrease' else 'Increased'} {entity} by {amount * 100 if amount_type == 'percent' else amount}"
611
-
612
- return sql, explanation
613
-
614
-
615
- def apply_outliers(
616
- snowflake_conn,
617
- outliers: List[LegacyOutlierPattern],
618
- schema_name: str,
619
- dry_run: bool = False
620
- ) -> List[Dict]:
621
- """
622
- Apply multiple outlier patterns to the database.
623
-
624
- Args:
625
- snowflake_conn: Snowflake connection
626
- outliers: List of OutlierPattern objects
627
- schema_name: Schema to apply updates to
628
- dry_run: If True, only print SQL without executing
629
-
630
- Returns:
631
- List of results with success/failure for each outlier
632
- """
633
- results = []
634
-
635
- cursor = snowflake_conn.cursor()
636
-
637
- try:
638
- cursor.execute(f"USE SCHEMA {schema_name}")
639
-
640
- for outlier in outliers:
641
- result = {
642
- 'title': outlier.title,
643
- 'sql': outlier.sql_update,
644
- 'success': False,
645
- 'rows_affected': 0,
646
- 'error': None
647
- }
648
-
649
- print(f"📊 Applying outlier: {outlier.title}", flush=True)
650
- print(f" SQL: {outlier.sql_update[:100]}...", flush=True)
651
-
652
- if dry_run:
653
- print(f" [DRY RUN - not executed]", flush=True)
654
- result['success'] = True
655
- else:
656
- try:
657
- cursor.execute(outlier.sql_update)
658
- result['rows_affected'] = cursor.rowcount
659
- result['success'] = True
660
- print(f" ✅ Affected {result['rows_affected']} rows", flush=True)
661
- except Exception as e:
662
- result['error'] = str(e)
663
- print(f" ❌ Error: {e}", flush=True)
664
-
665
- results.append(result)
666
-
667
- if not dry_run:
668
- snowflake_conn.commit()
669
- print(f"✅ All outliers applied and committed", flush=True)
670
-
671
- finally:
672
- cursor.close()
673
-
674
- return results
675
-
676
-
677
- def generate_demo_pack(
678
- outliers: List[LegacyOutlierPattern],
679
- company_name: str,
680
- use_case: str
681
- ) -> str:
682
- """
683
- Generate a Demo Pack markdown document from outliers.
684
-
685
- Args:
686
- outliers: List of OutlierPattern objects (applied or to be applied)
687
- company_name: Company name for the demo
688
- use_case: Use case description
689
-
690
- Returns:
691
- Markdown string with demo documentation
692
- """
693
- today = datetime.now().strftime("%Y-%m-%d")
694
-
695
- md = f"""# {company_name} Demo Pack
696
- ## {use_case}
697
-
698
- *Generated: {today}*
699
-
700
- ---
701
-
702
- ## Overview
703
-
704
- This demo showcases ThoughtSpot's AI-powered analytics for {company_name}'s {use_case} use case. The data has been enhanced with realistic patterns that highlight ThoughtSpot's ability to surface business-critical insights.
705
-
706
- ---
707
-
708
- ## Key Insights (Outliers)
709
-
710
- """
711
-
712
- for i, outlier in enumerate(outliers, 1):
713
- md += f"""### {i}. {outlier.title}
714
-
715
- **Pattern:** {outlier.description}
716
-
717
- **Spotter Question:** _{outlier.spotter_question or 'N/A'}_
718
-
719
- **Talking Point:** {outlier.talking_point or 'N/A'}
720
-
721
- **Technical Details:**
722
- - Table: `{outlier.affected_table}`
723
- - Column: `{outlier.affected_column}`
724
- - Conditions: {outlier.conditions or 'None'}
725
-
726
- ---
727
-
728
- """
729
-
730
- md += """## Demo Flow
731
-
732
- 1. **Start with overview liveboard** - Show the big picture
733
- 2. **Ask Spotter questions** - Use the questions above to surface each insight
734
- 3. **Drill into details** - Let ThoughtSpot guide exploration
735
- 4. **Highlight AI capabilities** - Show how patterns were auto-detected
736
-
737
- ---
738
-
739
- ## Notes
740
-
741
- - Data patterns are realistic but not actual customer data
742
- - Refresh liveboard after any data adjustments
743
- - Use Monitor feature to show proactive alerts
744
-
745
- """
746
-
747
- return md
748
-
749
-
750
- # CLI interface for testing
751
- if __name__ == "__main__":
752
- print("Outlier System - Testing")
753
- print("=" * 40)
754
-
755
- generator = OutlierGenerator()
756
-
757
- # Test cases
758
- test_patterns = [
759
- "Premium subscriptions spike 3x in Q4",
760
- "Decrease electronics revenue by 10%",
761
- "Make laptop revenue 50B",
762
- "Churn increases 20% in summer months"
763
- ]
764
-
765
- for pattern in test_patterns:
766
- print(f"\nPattern: {pattern}")
767
- result = generator.generate_outlier_sql(pattern)
768
- print(f"SQL:\n{result.sql_update}")
769
- print(f"Spotter Q: {result.spotter_question}")
prompt_logger.py CHANGED
@@ -11,7 +11,7 @@ Usage:
     # Option 1: Wrap litellm.completion() calls
     response = logged_completion(
         stage="ddl",
-        model="claude-sonnet-4.5",
+        model="gpt-4o",
         messages=[{"role": "user", "content": "..."}],
         max_tokens=4000
     )
@@ -20,7 +20,7 @@ Usage:
     logger = get_prompt_logger()
     logger.log_prompt(
         stage="research_company",
-        model="anthropic/claude-sonnet-4.5",
+        model="openai/gpt-4o",
        messages=[...],
        response_text="...",
        tokens_in=500,
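The docstring above only swaps the documented default model. As a minimal sketch of the wrapper pattern it describes (the `completion_fn` parameter and in-memory `log` list here are stand-ins of mine, not the module's real litellm integration):

```python
import time

def logged_completion(stage, model, messages, completion_fn, log, **kwargs):
    """Call a completion function and record stage, model, and latency."""
    start = time.monotonic()
    response = completion_fn(model=model, messages=messages, **kwargs)
    log.append({
        "stage": stage,
        "model": model,
        "duration_ms": int((time.monotonic() - start) * 1000),
    })
    return response

# Usage with a stubbed completion function:
log = []
fake = lambda model, messages, **kw: {"content": "ok", "model": model}
out = logged_completion("ddl", "gpt-4o",
                        [{"role": "user", "content": "..."}], fake, log)
```

The real module presumably also captures token counts and the response text; this sketch shows only the call-site shape the docstring implies.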
prompts.py CHANGED
@@ -438,19 +438,18 @@ def build_prompt(
         Complete prompt string ready for LLM
     """
     from demo_personas import get_use_case_config
-    from outlier_system import get_outliers_for_use_case
 
     # Get merged configuration
     config = get_use_case_config(vertical, function)
-    outliers = get_outliers_for_use_case(vertical, function)
+    lq = config.get("liveboard_questions", [])
 
     # Build sections
     sections = []
 
     # Section A: Company Context
     sections.append(f"""## COMPANY CONTEXT
 {company_context}""")
 
     # Section B: Use Case Framework
     persona = config.get("target_persona", "Business Leader")
     problem = config.get("business_problem", "Need for faster, data-driven decisions")
@@ -460,7 +459,7 @@ def build_prompt(
     - **Business Problem:** {problem}
     - **Industry Terms:** {', '.join(config.get('industry_terms', []))}
     - **Typical Entities:** {', '.join(config.get('entities', []))}""")
 
     # Section C: Required KPIs and Visualizations
     kpi_text = "\n".join([f"- {kpi}: {config['kpi_definitions'].get(kpi, '')}" for kpi in config.get('kpis', [])])
     sections.append(f"""## REQUIRED KPIs
@@ -468,17 +467,18 @@ def build_prompt(
 
     ## REQUIRED VISUALIZATIONS
     {', '.join(config.get('viz_types', []))}""")
 
-    # Section D: Outlier Patterns
-    if outliers.required:
-        outlier_text = "\n".join([f"- **{o.name}:** {o.viz_talking_point}" for o in outliers.required])
+    # Section D: Data stories from liveboard questions
+    required_qs = [q for q in lq if q.get("required")]
+    if required_qs:
+        outlier_text = "\n".join([f"- **{q['title']}:** {q.get('insight', '')}" for q in required_qs])
         sections.append(f"""## DATA STORIES TO CREATE
 {outlier_text}""")
 
     # Section E: Spotter Questions
     spotter_qs = []
-    for o in outliers.required:
-        spotter_qs.extend(o.spotter_questions[:2])  # Top 2 from each required outlier
+    for q in required_qs:
+        spotter_qs.extend(q.get('spotter_qs', [])[:2])
     if spotter_qs:
         sections.append(f"""## SPOTTER QUESTIONS TO ENABLE
 {chr(10).join(['- ' + q for q in spotter_qs[:6]])}""")
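The rewritten Sections D and E read a `liveboard_questions` list from the merged use-case config instead of the removed `outlier_system`. A hypothetical config entry (all field values invented for illustration) showing the keys the new code consumes — `required`, `title`, `insight`, and `spotter_qs`:

```python
# Hypothetical use-case config; only the key names are taken from the diff.
config = {
    "liveboard_questions": [
        {"title": "Q4 revenue spike", "required": True,
         "insight": "Revenue jumped 40% after the October launch",
         "spotter_qs": ["what drove revenue in Q4",
                        "top products by revenue growth",
                        "revenue by region in Q4"]},
        {"title": "Churn by segment", "required": False,
         "insight": "SMB churn is double enterprise churn",
         "spotter_qs": ["churn rate by segment"]},
    ]
}

# Same filtering logic as the new build_prompt:
lq = config.get("liveboard_questions", [])
required_qs = [q for q in lq if q.get("required")]
outlier_text = "\n".join(f"- **{q['title']}:** {q.get('insight', '')}"
                         for q in required_qs)

spotter_qs = []
for q in required_qs:
    spotter_qs.extend(q.get("spotter_qs", [])[:2])  # top 2 per required question
```

Only required questions feed the "DATA STORIES TO CREATE" section, and each contributes at most two Spotter questions (capped at six overall in the prompt).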
session_logger.py ADDED
@@ -0,0 +1,204 @@
+"""
+Session Logger — Structured pipeline event logger.
+
+Writes events to both a local JSONL file and the Supabase `session_logs` table.
+Falls back gracefully to file-only logging if Supabase is unavailable.
+
+===========================================================================
+Supabase table setup (run once in Supabase SQL editor):
+===========================================================================
+
+CREATE TABLE session_logs (
+    id BIGSERIAL PRIMARY KEY,
+    session_id TEXT NOT NULL,
+    user_email TEXT,
+    ts TIMESTAMPTZ DEFAULT NOW(),
+    stage TEXT,
+    event TEXT NOT NULL,
+    duration_ms INTEGER,
+    error TEXT,
+    meta JSONB
+);
+CREATE INDEX ON session_logs (session_id);
+CREATE INDEX ON session_logs (user_email);
+CREATE INDEX ON session_logs (ts DESC);
+
+===========================================================================
+
+Usage:
+    from session_logger import init_session_logger, get_session_logger
+
+    logger = init_session_logger("20260319_143000", user_email="alice@example.com")
+
+    t = logger.log_start("research")
+    # ... do work ...
+    logger.log_end("research", t, model="gpt-4o", rows=42)
+
+    logger.log("deploy", "Liveboard created", liveboard_id="abc-123")
+"""
+
+import json
+import sys
+import time
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Optional
+
+
+# ---------------------------------------------------------------------------
+# SessionLogger class
+# ---------------------------------------------------------------------------
+
+class SessionLogger:
+    """Structured event logger that writes to file and Supabase session_logs."""
+
+    TABLE = "session_logs"
+
+    def __init__(self, session_id: str, user_email: str = None):
+        """
+        Initialize the session logger.
+
+        Args:
+            session_id: Unique ID for this build session (e.g. datetime string).
+            user_email: Email of the user running this session.
+        """
+        self.session_id = session_id
+        self.user_email = user_email
+
+        # File log path: logs/sessions/{session_id}.log (one JSON line per event)
+        script_dir = Path(__file__).parent
+        log_dir = script_dir / "logs" / "sessions"
+        try:
+            log_dir.mkdir(parents=True, exist_ok=True)
+        except Exception as e:
+            print(f"[SessionLogger] Could not create log directory {log_dir}: {e}", file=sys.stderr)
+        self._log_file = log_dir / f"{session_id}.log"
+
+        # Try to initialise Supabase — never raise
+        self._supabase_ok = False
+        self._client = None
+        self._init_supabase()
+
+    def _init_supabase(self):
+        """Attempt to connect to Supabase. Sets self._supabase_ok and self._client."""
+        try:
+            # Lazy import to avoid circular imports
+            from supabase_client import SupabaseSettings
+            ss = SupabaseSettings()
+            if ss.is_enabled():
+                self._client = ss.client
+                self._supabase_ok = True
+        except Exception as e:
+            print(f"[SessionLogger] Supabase unavailable, falling back to file-only logging: {e}",
+                  file=sys.stderr)
+
+    # ------------------------------------------------------------------
+    # Core log method
+    # ------------------------------------------------------------------
+
+    def log(self, stage: str, event: str, duration_ms: int = None,
+            error: str = None, **meta):
+        """
+        Log one pipeline event.
+
+        Args:
+            stage: Pipeline stage name (e.g. 'research', 'deploy').
+            event: Short description of what happened.
+            duration_ms: Optional elapsed time in milliseconds.
+            error: Optional error message if the event represents a failure.
+            **meta: Arbitrary key/value pairs stored in the meta JSONB column.
+        """
+        ts = datetime.now(timezone.utc).isoformat()
+        record = {
+            "session_id": self.session_id,
+            "user_email": self.user_email,
+            "ts": ts,
+            "stage": stage,
+            "event": event,
+            "duration_ms": duration_ms,
+            "error": error,
+            "meta": meta if meta else None,
+        }
+
+        # Always write to file
+        self._write_file(record)
+
+        # Write to Supabase if available
+        if self._supabase_ok:
+            self._write_supabase(record)
+
+    def _write_file(self, record: dict):
+        """Append one JSON line to the session log file. Never raises."""
+        try:
+            with open(self._log_file, "a", encoding="utf-8") as fh:
+                fh.write(json.dumps(record, default=str) + "\n")
+        except Exception as e:
+            print(f"[SessionLogger] File write failed: {e}", file=sys.stderr)
+
+    def _write_supabase(self, record: dict):
+        """Insert one row into session_logs. Never raises."""
+        try:
+            # Build the insert payload, omitting None values for cleanliness
+            payload = {k: v for k, v in record.items() if v is not None}
+            self._client.table(self.TABLE).insert(payload).execute()
+        except Exception as e:
+            # Supabase write failure is non-fatal — demote to stderr
+            print(f"[SessionLogger] Supabase write failed: {e}", file=sys.stderr)
+            # Mark Supabase as unavailable so we stop trying for this session
+            self._supabase_ok = False
+
+    # ------------------------------------------------------------------
+    # Convenience helpers
+    # ------------------------------------------------------------------
+
+    def log_start(self, stage: str) -> float:
+        """
+        Log that a pipeline stage has started.
+
+        Returns:
+            Monotonic start time (pass to log_end).
+        """
+        self.log(stage, f"{stage} started")
+        return time.monotonic()
+
+    def log_end(self, stage: str, start_time: float, error: str = None, **meta):
+        """
+        Log that a pipeline stage has ended, computing duration from start_time.
+
+        Args:
+            stage: Pipeline stage name (must match the one passed to log_start).
+            start_time: Value returned by the corresponding log_start call.
+            error: Optional error message if the stage failed.
+            **meta: Arbitrary key/value pairs stored in the meta column.
+        """
+        elapsed_ms = int((time.monotonic() - start_time) * 1000)
+        event = f"{stage} failed" if error else f"{stage} completed"
+        self.log(stage, event, duration_ms=elapsed_ms, error=error, **meta)
+
+
+# ---------------------------------------------------------------------------
+# Module-level singleton helpers
+# ---------------------------------------------------------------------------
+
+_current_logger: Optional[SessionLogger] = None
+
+
+def get_session_logger() -> Optional[SessionLogger]:
+    """Return the active SessionLogger, or None if not yet initialised."""
+    return _current_logger
+
+
+def init_session_logger(session_id: str, user_email: str = None) -> SessionLogger:
+    """
+    Create (or replace) the module-level SessionLogger.
+
+    Args:
+        session_id: Unique ID for this build session.
+        user_email: Email of the user running this session.
+
+    Returns:
+        The newly created SessionLogger instance.
+    """
+    global _current_logger
+    _current_logger = SessionLogger(session_id, user_email)
+    return _current_logger
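Because each session log is one JSON object per line, a log can be post-processed without importing the module at all. A sketch (using an in-memory sample instead of a real `logs/sessions/*.log` file) that tallies per-stage durations and errors from the record shape shown above:

```python
import json
import io

# Sample lines matching the record fields SessionLogger writes.
sample_log = io.StringIO(
    '{"session_id": "s1", "stage": "research", "event": "research started"}\n'
    '{"session_id": "s1", "stage": "research", "event": "research completed", "duration_ms": 1200}\n'
    '{"session_id": "s1", "stage": "deploy", "event": "deploy failed", "duration_ms": 300, "error": "timeout"}\n'
)

durations = {}   # stage -> last reported duration_ms
errors = []      # (stage, error) pairs
for line in sample_log:
    rec = json.loads(line)
    if rec.get("duration_ms") is not None:
        durations[rec["stage"]] = rec["duration_ms"]
    if rec.get("error"):
        errors.append((rec["stage"], rec["error"]))
```

The admin log viewer mentioned in the commit message could be built on exactly this kind of pass over the JSONL file when Supabase is unavailable.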
smart_data_adjuster.py CHANGED
@@ -1,78 +1,139 @@
1
  """
2
- Smart Conversational Data Adjuster
3
 
4
- Understands liveboard context and asks smart qualifying questions.
5
- Bundles confirmations into one step when confident.
 
 
 
 
 
 
 
6
  """
7
 
8
  from typing import Dict, List, Optional, Tuple
9
  from snowflake_auth import get_snowflake_connection
10
  from thoughtspot_deployer import ThoughtSpotDeployer
11
  import json
12
- from llm_config import build_openai_chat_token_kwargs, is_openai_model_name, resolve_model_name
13
- from llm_client_factory import create_openai_client
14
 
15
 
16
  class SmartDataAdjuster:
17
- """Smart adjuster with liveboard context and conversational flow"""
18
-
19
- def __init__(self, database: str, schema: str, liveboard_guid: str, llm_model: str = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  self.database = database
21
  self.schema = schema
22
  self.liveboard_guid = liveboard_guid
23
- self.conn = None
24
- self.ts_client = None
25
-
26
- # LLM setup - require explicit model from settings/controller
27
  self.llm_model = (llm_model or "").strip()
28
  if not self.llm_model:
29
  raise ValueError("SmartDataAdjuster requires llm_model from settings.")
30
- self._llm_client = None
31
-
32
- # Context about the liveboard
33
- self.liveboard_name = None
34
- self.visualizations = [] # List of viz metadata
35
-
36
- def _call_llm(self, prompt: str) -> str:
37
- """Call the configured LLM (OpenAI GPT only)."""
38
- target_model = resolve_model_name(self.llm_model)
39
- if not is_openai_model_name(target_model):
40
- raise ValueError(
41
- f"SmartDataAdjuster only supports OpenAI GPT/Codex models. Received '{self.llm_model}'."
42
- )
43
 
44
- client = create_openai_client()
 
45
 
46
- request_kwargs = {
47
- "model": target_model,
48
- "messages": [{"role": "user", "content": prompt}],
49
- "temperature": 0,
50
- }
51
- request_kwargs.update(build_openai_chat_token_kwargs(target_model, 2000))
 
 
 
 
 
 
52
 
53
- response = client.chat.completions.create(**request_kwargs)
54
- return response.choices[0].message.content
55
-
56
  def connect(self):
57
- """Connect to Snowflake and ThoughtSpot"""
58
- # Snowflake
59
  self.conn = get_snowflake_connection()
60
  cursor = self.conn.cursor()
61
  cursor.execute(f"USE DATABASE {self.database}")
62
  cursor.execute(f'USE SCHEMA "{self.schema}"')
63
-
64
- # ThoughtSpot
65
- self.ts_client = ThoughtSpotDeployer()
 
 
66
  self.ts_client.authenticate()
67
-
68
- print(f"✅ Connected to {self.database}.{self.schema}")
69
- print(f"✅ Connected to ThoughtSpot")
70
-
71
- def load_liveboard_context(self):
72
- """Load liveboard metadata and visualization details"""
73
- print(f"\n📊 Loading liveboard context...")
74
-
75
- # Get liveboard metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  response = self.ts_client.session.post(
77
  f"{self.ts_client.base_url}/api/rest/2.0/metadata/search",
78
  json={
@@ -80,544 +141,490 @@ class SmartDataAdjuster:
80
  "include_visualization_headers": True
81
  }
82
  )
83
-
84
  if response.status_code != 200:
85
- print(f"❌ Failed to load liveboard")
86
  return False
87
-
88
  data = response.json()[0]
89
  self.liveboard_name = data.get('metadata_name', 'Unknown Liveboard')
90
-
91
- viz_headers = data.get('visualization_headers', [])
92
-
93
- print(f" Liveboard: {self.liveboard_name}")
94
- print(f" Visualizations: {len(viz_headers)}")
95
-
96
- # Extract viz details
97
- for viz in viz_headers:
98
  name = viz.get('name', '')
99
- viz_id = viz.get('id')
100
-
101
- # Skip note tiles
102
  if 'note-tile' in name.lower():
103
  continue
104
-
105
- # Parse the name to extract columns used
106
- # Names like "top 10 product_name by total revenue"
107
- viz_info = {
108
- 'id': viz_id,
109
- 'name': name,
110
- 'columns': self._extract_columns_from_name(name)
111
- }
112
-
113
- self.visualizations.append(viz_info)
114
- print(f" - {name}")
115
-
116
- return True
117
-
118
- def _extract_columns_from_name(self, name: str) -> List[str]:
119
- """Extract column names from visualization name"""
120
- # Simple heuristic: look for column-like words
121
- # e.g., "top 10 product_name by total revenue" → [product_name, total_revenue]
122
-
123
- columns = []
124
- name_lower = name.lower()
125
-
126
- # Common column patterns
127
- if 'product_name' in name_lower:
128
- columns.append('PRODUCT_NAME')
129
- if 'total revenue' in name_lower or 'total_revenue' in name_lower:
130
- columns.append('TOTAL_AMOUNT')
131
- if 'quantity' in name_lower:
132
- columns.append('QUANTITY_SOLD')
133
- if 'profit margin' in name_lower or 'profit_margin' in name_lower:
134
- columns.append('PROFIT_MARGIN')
135
- if 'customer_segment' in name_lower:
136
- columns.append('CUSTOMER_SEGMENT')
137
- if 'category' in name_lower:
138
- columns.append('CATEGORY')
139
- if 'seller' in name_lower:
140
- columns.append('SELLER_NAME')
141
-
142
- return columns
143
-
144
- def _simple_parse(self, message: str) -> Optional[Dict]:
145
- """Simple regex-based parser for common patterns like 'decrease phone case by 10%' or 'decrease seller acme by 10%'"""
146
- import re
147
-
148
- print(f"🔍 DEBUG _simple_parse: message='{message}'")
149
- msg_lower = message.lower()
150
-
151
- # Detect if user specified a viz number
152
- viz_match = re.search(r'(?:viz|visualization)\s+(\d+)', msg_lower)
153
- viz_number = int(viz_match.group(1)) if viz_match else None
154
-
155
- # Detect entity type (product or seller)
156
- # Check explicit "seller" keyword, or infer from viz number
157
- is_seller = 'seller' in msg_lower
158
-
159
- # If viz number is specified, check if it's a seller viz
160
- if viz_number and not is_seller and len(self.visualizations) >= viz_number:
161
- viz = self.visualizations[viz_number - 1]
162
- if 'seller' in viz['name'].lower():
163
- is_seller = True
164
-
165
- entity_type = 'seller' if is_seller else 'product'
166
-
167
- # Extract entity name - try quotes first, then words after action verbs
168
- entity_match = re.search(r'"([^"]+)"', message)
169
- if not entity_match:
170
- # Try to find entity name after action words, but stop before numbers
171
- # Include "seller" keyword if present
172
- if is_seller:
173
- # Match: "decrease seller home depot by 20%"
174
- action_pattern = r'(?:decrease|increase|make|set|adjust)\s+(?:the\s+)?(?:profit\s+margin\s+for\s+)?seller\s+([a-z\s]+?)(?:\s+\d|\s+by|\s+to|\s*$)'
175
- else:
176
- # Match: "decrease bluetooth speaker by 10%" OR "decrease the revenue for bluetooth speaker by 10%"
177
- action_pattern = r'(?:decrease|increase|make|set|adjust)\s+(?:the\s+)?(?:revenue\s+for\s+|profit\s+margin\s+for\s+)?([a-z\s]+?)(?:\s+\d|\s+by|\s+to|\s*$)'
178
- entity_match = re.search(action_pattern, msg_lower, re.I)
179
-
180
- if not entity_match:
181
  return None
182
-
183
- entity = entity_match.group(1).strip()
184
-
185
- # Find percentage or absolute value
186
- is_percentage = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  percentage = None
188
  target_value = None
189
-
190
- # Look for percentage like "by 10%" or "10%"
191
- pct_match = re.search(r'by\s+(\d+\.?\d*)%|(\d+\.?\d*)%', msg_lower)
192
  if pct_match:
193
- is_percentage = True
194
- percentage = float(pct_match.group(1) or pct_match.group(2))
195
- # Check if it's decrease or increase
196
- if 'decrease' in msg_lower or 'reduce' in msg_lower or 'lower' in msg_lower:
197
- percentage = -percentage
 
 
 
 
 
 
198
  else:
199
- # Look for absolute value like "50B", "50 billion", "1000000"
200
- val_match = re.search(r'(\d+\.?\d*)\s*([bBmMkK]|billion|million|thousand)?', message)
201
- if val_match:
202
- num = float(val_match.group(1))
203
- unit = (val_match.group(2) or '').lower()
204
- if unit in ['b', 'billion']:
205
- target_value = num * 1_000_000_000
206
- elif unit in ['m', 'million']:
207
- target_value = num * 1_000_000
208
- elif unit in ['k', 'thousand']:
209
- target_value = num * 1_000
210
- else:
211
- target_value = num
212
-
213
- if not is_percentage and not target_value:
214
  return None
215
-
216
- # Find appropriate viz and determine metric column
217
- # If user specified viz number, use it; otherwise search for matching viz
218
- if viz_number:
219
- viz_num = viz_number
220
- elif is_seller:
221
- # Look for seller-related viz
222
- viz_num = 1 # Default
223
- for i, viz in enumerate(self.visualizations, 1):
224
- if 'seller' in viz['name'].lower():
225
- viz_num = i
226
- break
227
- else:
228
- # Look for product-related viz
229
- viz_num = 1 # Default
230
- for i, viz in enumerate(self.visualizations, 1):
231
- if 'product' in viz['name'].lower():
232
- viz_num = i
233
- break
234
-
235
- # Determine metric based on entity type
236
- if is_seller:
237
- metric_column = 'PROFIT_MARGIN' # Sellers typically use profit margin
238
  else:
239
- metric_column = 'TOTAL_AMOUNT' # Products typically use revenue (column is TOTAL_AMOUNT)
240
-
241
- result = {
242
- 'viz_number': viz_num,
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  'entity_value': entity,
244
  'entity_type': entity_type,
245
- 'metric_column': metric_column,
246
- 'target_value': target_value,
247
  'is_percentage': is_percentage,
248
  'percentage': percentage,
 
249
  'confidence': 'medium',
250
- 'reasoning': f'Simple {entity_type} parse'
251
  }
252
-
253
- print(f"🔍 DEBUG _simple_parse result: entity='{entity}', percentage={percentage}, metric={metric_column}, viz_num={viz_num}")
254
-
255
- if viz_num <= len(self.visualizations):
256
- result['viz'] = self.visualizations[viz_num - 1]
257
-
258
- return result
259
-
260
  def match_request_to_viz(self, user_request: str) -> Optional[Dict]:
261
  """
262
- Use AI to match user request to specific visualization
263
-
264
- Returns:
265
- {
266
- 'viz': {...},
267
- 'confidence': 'high'|'medium'|'low',
268
- 'entity_value': '1080p Webcam',
269
- 'metric_column': 'TOTAL_AMOUNT',
270
- 'target_value': 50000000000
271
- }
272
  """
273
- # Try simple parse first (faster, no AI needed)
274
- simple_result = self._simple_parse(user_request)
275
- if simple_result:
276
- print(f" ⚡ Quick parse: '{simple_result['entity_value']}' {simple_result.get('percentage', simple_result.get('target_value'))}")
277
- return simple_result
278
-
279
- viz_list = "\n".join([
280
- f"{i+1}. {v['name']} (columns: {', '.join(v['columns'])})"
281
- for i, v in enumerate(self.visualizations)
282
- ])
283
-
284
- prompt = f"""User is looking at a ThoughtSpot liveboard and wants to adjust data.
285
-
286
- User request: "{user_request}"
287
-
288
- Available visualizations on the liveboard:
289
- {viz_list}
290
-
291
- Analyze the request and determine:
292
- 1. Which visualization (by number) is the user referring to?
293
- 2. What entity/product are they talking about? (e.g., "1080p Webcam")
294
- 3. What metric should be adjusted? (TOTAL_AMOUNT, QUANTITY_SOLD, PROFIT_MARGIN)
295
- 4. What's the target value?
296
- - If absolute value (e.g., "40B", "100M"): convert to number (40B = 40000000000)
297
- - If percentage increase (e.g., "increase by 20%"): set is_percentage=true and percentage=20
298
- 5. How confident are you? (high/medium/low)
299
-
300
- Return JSON:
301
- {{
302
- "viz_number": 1,
303
- "entity_value": "1080p Webcam",
304
- "metric_column": "TOTAL_AMOUNT",
305
- "target_value": 50000000000,
306
- "is_percentage": false,
307
- "percentage": null,
308
- "confidence": "high",
309
- "reasoning": "User mentioned product and the top 10 products viz uses PRODUCT_NAME and TOTAL_AMOUNT"
310
- }}
311
-
312
- OR for percentage increase:
313
  {{
314
- "viz_number": 1,
315
- "entity_value": "1080p Webcam",
316
- "metric_column": "TOTAL_AMOUNT",
317
- "target_value": null,
318
- "is_percentage": true,
319
- "percentage": 20,
320
- "confidence": "high",
321
- "reasoning": "User wants to increase revenue by 20%"
322
- }}
323
-
324
- CRITICAL: target_value and percentage must be numbers, never strings.
325
- If unsure about ANY field, set confidence to "low" or "medium".
326
- """
327
-
328
- content = self._call_llm(prompt)
329
- if content.startswith('```'):
330
- lines = content.split('\n')
331
- content = '\n'.join(lines[1:-1])
332
-
333
- try:
334
- result = json.loads(content)
335
-
336
- # Add the actual viz object
337
- viz_num = result.get('viz_number', 1)
338
- if 1 <= viz_num <= len(self.visualizations):
339
- result['viz'] = self.visualizations[viz_num - 1]
340
-
341
- return result
342
- except:
343
- return None
344
-
345
- def _find_closest_entity(self, entity_value: str, entity_type: str = 'product') -> Optional[str]:
346
- """Find the closest matching entity name (product or seller) in the database"""
347
- cursor = self.conn.cursor()
348
-
349
- # Get all entity names based on type
350
- if entity_type == 'seller':
351
- cursor.execute(f"""
352
- SELECT DISTINCT SELLER_NAME
353
- FROM {self.database}."{self.schema}".SELLERS
354
- """)
355
- else: # product
356
- cursor.execute(f"""
357
- SELECT DISTINCT PRODUCT_NAME
358
- FROM {self.database}."{self.schema}".PRODUCTS
359
- """)
360
-
361
- entities = [row[0] for row in cursor.fetchall()]
362
-
363
- # Normalize: lowercase and remove spaces for comparison
364
- def normalize(s):
365
- return s.lower().replace(' ', '').replace('-', '').replace('_', '')
366
-
367
- entity_normalized = normalize(entity_value)
368
-
369
- # First try exact case-insensitive match
370
- entity_lower = entity_value.lower()
371
- for entity in entities:
372
- if entity.lower() == entity_lower:
373
- return entity
374
-
375
- # Try normalized match (ignoring spaces/dashes)
376
- for entity in entities:
377
- if normalize(entity) == entity_normalized:
378
- return entity
379
-
380
- # Try partial match (contains)
381
- for entity in entities:
382
- if entity_lower in entity.lower() or entity.lower() in entity_lower:
383
- return entity
384
-
385
- # Try normalized partial match
386
- for entity in entities:
387
- if entity_normalized in normalize(entity) or normalize(entity) in entity_normalized:
388
- return entity
389
-
390
- return None
391
-
392
- def _find_closest_product(self, entity_value: str) -> Optional[str]:
393
- """Backward compatibility wrapper"""
394
- return self._find_closest_entity(entity_value, 'product')
395
-
396
- def get_current_value(self, entity_value: str, metric_column: str, entity_type: str = 'product') -> float:
397
- """Query current value from Snowflake"""
398
  cursor = self.conn.cursor()
399
-
400
- # Find closest matching entity
401
- matched_entity = self._find_closest_entity(entity_value, entity_type)
402
-
403
- if not matched_entity:
404
- print(f"⚠️ Could not find {entity_type} matching '{entity_value}'")
405
- return 0
406
-
407
- if matched_entity.lower() != entity_value.lower():
408
- print(f" 📝 Using closest match: '{matched_entity}'")
409
-
410
- # Build query based on entity type
411
- if entity_type == 'seller':
412
  query = f"""
413
- SELECT AVG(st.{metric_column})
414
- FROM {self.database}."{self.schema}".SALES_TRANSACTIONS st
415
- JOIN {self.database}."{self.schema}".SELLERS s ON st.SELLER_ID = s.SELLER_ID
416
- WHERE LOWER(s.SELLER_NAME) = LOWER('{matched_entity}')
 
417
  """
418
- else: # product
 
419
  query = f"""
420
- SELECT SUM(st.{metric_column})
421
- FROM {self.database}."{self.schema}".SALES_TRANSACTIONS st
422
- JOIN {self.database}."{self.schema}".PRODUCTS p ON st.PRODUCT_ID = p.PRODUCT_ID
423
- WHERE LOWER(p.PRODUCT_NAME) = LOWER('{matched_entity}')
424
  """
425
-
426
- cursor.execute(query)
427
- result = cursor.fetchone()
428
- return float(result[0]) if result and result[0] else 0
429
-
430
- def generate_strategy(self, entity_value: str, metric_column: str, current_value: float, target_value: float = None, percentage: float = None, entity_type: str = 'product') -> Dict:
431
- """Generate the best strategy (default to Strategy A for now)"""
432
-
433
- print(f"🔍 DEBUG generate_strategy: entity='{entity_value}', metric={metric_column}, percentage={percentage}, current={current_value}")
434
-
435
- # Find the actual entity name
436
- matched_entity = self._find_closest_entity(entity_value, entity_type)
437
- if not matched_entity:
438
- matched_entity = entity_value # Fallback
439
-
440
- print(f"🔍 DEBUG matched_entity: '{matched_entity}'")
441
-
442
- # Calculate multiplier
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
  if percentage is not None:
444
- # Percentage-based: "decrease by 10%" means multiply by 0.9
445
  multiplier = 1 + (percentage / 100)
446
- percentage_change = percentage
447
- target_value = current_value * multiplier
 
 
 
 
448
  else:
449
- # Absolute target value
450
- multiplier = target_value / current_value if current_value > 0 else 1
451
- percentage_change = (multiplier - 1) * 100
452
-
453
- # Build SQL based on entity type
454
- if entity_type == 'seller':
455
- sql = f"""UPDATE {self.database}."{self.schema}".SALES_TRANSACTIONS
 
 
 
456
  SET {metric_column} = {metric_column} * {multiplier:.6f}
457
- WHERE SELLER_ID IN (
458
- SELECT SELLER_ID FROM {self.database}."{self.schema}".SELLERS
459
- WHERE LOWER(SELLER_NAME) = LOWER('{matched_entity}')
 
460
  )"""
461
- else: # product
462
- sql = f"""UPDATE {self.database}."{self.schema}".SALES_TRANSACTIONS
463
  SET {metric_column} = {metric_column} * {multiplier:.6f}
464
- WHERE PRODUCT_ID IN (
465
- SELECT PRODUCT_ID FROM {self.database}."{self.schema}".PRODUCTS
466
- WHERE LOWER(PRODUCT_NAME) = LOWER('{matched_entity}')
467
- )"""
468
-
469
- print(f"🔍 DEBUG SQL generated:\n{sql}")
470
-
471
  return {
472
  'id': 'A',
473
- 'name': 'Distribute Across All Transactions',
474
- 'description': f"Multiply all transactions by {multiplier:.2f}x ({percentage_change:+.1f}%)",
475
  'sql': sql,
476
- 'matched_product': matched_entity, # Keep key name for compatibility
477
- 'target_value': target_value
478
  }
479
-
480
- def present_smart_confirmation(self, match: Dict, current_value: float, strategy: Dict) -> str:
481
- """Create a bundled confirmation prompt"""
482
-
483
- viz_name = match['viz']['name']
484
- entity = match['entity_value']
485
- matched_product = strategy.get('matched_product', entity)
486
- metric = match['metric_column']
487
- target = strategy.get('target_value', match.get('target_value')) # Use calculated target from strategy
488
- confidence = match['confidence']
489
-
490
- # Show if we fuzzy matched
491
- entity_display = entity
492
- if matched_product.lower() != entity.lower():
493
- entity_display = f"{entity} → '{matched_product}'"
494
-
495
- confirmation = f"""
496
- {'='*80}
497
- 📋 SMART CONFIRMATION
498
- {'='*80}
499
-
500
- Liveboard: {self.liveboard_name}
501
- Visualization: [{viz_name}]
502
-
503
- Adjustment:
504
- Entity: {entity_display}
505
- Metric: {metric}
506
- Current Value: ${current_value:,.0f}
507
- Target Value: ${target:,.0f}
508
- Change: ${target - current_value:+,.0f} ({(target/current_value - 1)*100:+.1f}%)
509
-
510
- Strategy: {strategy['name']}
511
- {strategy['description']}
512
-
513
- Confidence: {confidence.upper()}
514
- {match.get('reasoning', '')}
515
-
516
- SQL Preview:
517
- {strategy['sql'][:200]}...
518
 
519
- """
520
-
521
- if confidence == 'low':
522
- confirmation += "\n⚠️ Low confidence - please verify this is correct\n"
523
-
524
- confirmation += "\n" + "="*80 + "\n"
525
- return confirmation
526
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  def execute_sql(self, sql: str) -> Dict:
528
- """Execute the SQL update"""
529
- print(f"🔍 DEBUG execute_sql: About to execute SQL")
530
- print(f"SQL:\n{sql}")
531
  cursor = self.conn.cursor()
532
-
533
  try:
534
  cursor.execute(sql)
535
  rows_affected = cursor.rowcount
536
  self.conn.commit()
537
- print(f"✅ SQL executed successfully, rows affected: {rows_affected}")
538
-
539
- return {
540
- 'success': True,
541
- 'rows_affected': rows_affected
542
- }
543
  except Exception as e:
544
- self.conn.rollback()
545
- return {
546
- 'success': False,
547
- 'error': str(e)
548
- }
549
-
 
 
 
 
550
  def close(self):
551
- """Close connections"""
552
  if self.conn:
553
- self.conn.close()
554
-
555
-
556
- def test_smart_adjuster():
557
- """Test the smart adjuster"""
558
- from dotenv import load_dotenv
559
- load_dotenv()
560
-
561
- print("""
562
- ╔════════════════════════════════════════════════════════════╗
563
- ║ ║
564
- ║ Smart Data Adjuster Test ║
565
- ║ ║
566
- ╚════════════════════════════════════════════════════════════╝
567
- """)
568
-
569
- adjuster = SmartDataAdjuster(
570
- database=get_admin_setting('SNOWFLAKE_DATABASE'),
571
- schema="20251116_140933_AMAZO_SAL",
572
- liveboard_guid="9a30c9e4-efba-424a-8359-b16eb3a43ec3"
573
  )
574
-
575
- adjuster.connect()
576
- adjuster.load_liveboard_context()
577
-
578
- # Test request
579
- user_request = "make 1080p Webcam 50 billion"
580
- print(f"\n💬 User: \"{user_request}\"")
581
-
582
- # Match to viz
583
- print(f"\n🤔 Analyzing request...")
584
- match = adjuster.match_request_to_viz(user_request)
585
-
586
- if not match:
587
- print("❌ Could not understand request")
588
- return
589
-
590
- # Get current value
591
- current = adjuster.get_current_value(match['entity_value'], match['metric_column'])
592
-
593
- # Generate strategy (handle both absolute and percentage)
594
- strategy = adjuster.generate_strategy(
595
- match['entity_value'],
596
- match['metric_column'],
597
- current,
598
- target_value=match.get('target_value'),
599
- percentage=match.get('percentage')
600
  )
601
-
602
- # Present confirmation
603
- confirmation = adjuster.present_smart_confirmation(match, current, strategy)
604
- print(confirmation)
605
-
606
- # Ask for confirmation
607
- response = input("Run SQL? [yes/no]: ").strip().lower()
608
-
609
- if response in ['yes', 'y']:
610
- result = adjuster.execute_sql(strategy['sql'])
611
- if result['success']:
612
- print(f"\n✅ Success! Updated {result['rows_affected']} rows")
613
- else:
614
- print(f"\n❌ Failed: {result['error']}")
615
- else:
616
- print("\n❌ Cancelled")
617
-
618
- adjuster.close()
619
 
 
 
 
 
620
 
621
- if __name__ == "__main__":
622
- test_smart_adjuster()
 
 
 
 
 
 
 
623
 
 
 
 
 
 
 
 
 
1
  """
2
+ Smart Data Adjuster
3
 
4
+ Understands liveboard and schema context; handles conversational, multi-turn
5
+ data adjustment requests in natural language.
6
+
7
+ Connects to:
8
+ - ThoughtSpot (to load liveboard viz context)
9
+ - Snowflake (to query and update data)
10
+
11
+ Works with any configured LLM (Claude, GPT-4, etc.) via litellm.
12
+ Schema is discovered dynamically — no hardcoded table names.
13
  """
14
 
15
  from typing import Dict, List, Optional, Tuple
16
  from snowflake_auth import get_snowflake_connection
17
  from thoughtspot_deployer import ThoughtSpotDeployer
18
  import json
19
+ import re
20
+ from llm_config import resolve_model_name
21
 
22
 
23
  class SmartDataAdjuster:
24
+ """
25
+ Conversational data adjuster with liveboard context and schema discovery.
26
+
27
+ Usage:
28
+ adjuster = SmartDataAdjuster(database, schema, liveboard_guid, llm_model)
29
+ adjuster.connect()
30
+ adjuster.load_liveboard_context()
31
+
32
+ # Per user message:
33
+ result = adjuster.handle_message("make webcam revenue 40B")
34
+ # result: {'type': 'confirmation', 'text': '...', 'pending': {...}}
35
+ # or {'type': 'result', 'text': '...'}
36
+ # or {'type': 'error', 'text': '...'}
37
+ """
38
+
39
+ def __init__(self, database: str, schema: str, liveboard_guid: str,
40
+ llm_model: str = None, ts_url: str = None, ts_secret: str = None):
41
  self.database = database
42
  self.schema = schema
43
  self.liveboard_guid = liveboard_guid
44
+ self.ts_url = (ts_url or "").strip() or None
45
+ self.ts_secret = (ts_secret or "").strip() or None
46
+
 
47
  self.llm_model = (llm_model or "").strip()
48
  if not self.llm_model:
49
  raise ValueError("SmartDataAdjuster requires llm_model from settings.")
 
50
 
51
+ self.conn = None
52
+ self.ts_client = None
53
 
54
+ # Populated by load_liveboard_context()
55
+ self.liveboard_name: Optional[str] = None
56
+ self.visualizations: List[Dict] = []
57
+
58
+ # Populated by _discover_schema() in connect()
59
+ self.schema_tables: Dict[str, List[str]] = {} # table → [col, ...]
60
+ self.fact_tables: List[str] = [] # tables likely to be updated
61
+ self.dimension_tables: Dict[str, str] = {} # table → name_column
62
+
63
+ # ------------------------------------------------------------------
64
+ # Connection & schema discovery
65
+ # ------------------------------------------------------------------
66
 
 
 
 
67
  def connect(self):
68
+ """Connect to Snowflake and ThoughtSpot, then discover schema."""
 
69
  self.conn = get_snowflake_connection()
70
  cursor = self.conn.cursor()
71
  cursor.execute(f"USE DATABASE {self.database}")
72
  cursor.execute(f'USE SCHEMA "{self.schema}"')
73
+
74
+ self.ts_client = ThoughtSpotDeployer(
75
+ base_url=self.ts_url or None,
76
+ secret_key=self.ts_secret or None,
77
+ )
78
  self.ts_client.authenticate()
79
+
80
+ self._discover_schema()
81
+
82
+ def _discover_schema(self):
83
+ """Read actual table/column structure from INFORMATION_SCHEMA."""
84
+ cursor = self.conn.cursor()
85
+ cursor.execute(f"""
86
+ SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE
87
+ FROM {self.database}.INFORMATION_SCHEMA.COLUMNS
88
+ WHERE TABLE_SCHEMA = '{self.schema}'
89
+ ORDER BY TABLE_NAME, ORDINAL_POSITION
90
+ """)
91
+ raw: Dict[str, List[Dict]] = {}
92
+ for table, column, dtype in cursor.fetchall():
93
+ raw.setdefault(table, []).append({'name': column, 'type': dtype.upper()})
94
+
95
+ self.schema_tables = {t: [c['name'] for c in cols] for t, cols in raw.items()}
96
+
97
+ # Heuristic: dimension tables have a _NAME column; fact tables have date + numeric cols
98
+ for table, cols in raw.items():
99
+ col_names = [c['name'] for c in cols]
100
+ col_types = {c['name']: c['type'] for c in cols}
101
+
102
+ name_cols = [c for c in col_names if c.endswith('_NAME')]
103
+ num_cols = [c for c in col_names
104
+ if any(t in col_types.get(c, '') for t in ('NUMBER', 'FLOAT', 'INT', 'DECIMAL', 'NUMERIC'))]
105
+ date_cols = [c for c in col_names
106
+ if any(t in col_types.get(c, '') for t in ('DATE', 'TIME', 'TIMESTAMP'))]
107
+
108
+ if name_cols:
109
+ # Use the first _NAME column as the entity name column
110
+ self.dimension_tables[table] = name_cols[0]
111
+ if num_cols and date_cols:
112
+ self.fact_tables.append(table)
113
+
114
+ # If nothing looks like a fact table, fall back to largest table
115
+ if not self.fact_tables and self.schema_tables:
116
+ self.fact_tables = list(self.schema_tables.keys())
117
+
118
+ def _call_llm(self, prompt: str) -> str:
119
+ """Call the configured LLM via litellm (supports all providers)."""
120
+ from prompt_logger import logged_completion
121
+ model = resolve_model_name(self.llm_model)
122
+ response = logged_completion(
123
+ stage="data_adjuster",
124
+ model=model,
125
+ messages=[{"role": "user", "content": prompt}],
126
+ temperature=0,
127
+ max_tokens=1000,
128
+ )
129
+ return response.choices[0].message.content.strip()
130
+
131
+ # ------------------------------------------------------------------
132
+ # Liveboard context
133
+ # ------------------------------------------------------------------
134
+
135
+ def load_liveboard_context(self) -> bool:
136
+ """Load liveboard metadata and visualization list from ThoughtSpot."""
137
  response = self.ts_client.session.post(
138
  f"{self.ts_client.base_url}/api/rest/2.0/metadata/search",
139
  json={
 
141
  "include_visualization_headers": True
142
  }
143
  )
 
144
  if response.status_code != 200:
 
145
  return False
146
+
147
  data = response.json()[0]
148
  self.liveboard_name = data.get('metadata_name', 'Unknown Liveboard')
149
+
150
+ for viz in data.get('visualization_headers', []):
 
 
 
 
 
 
151
  name = viz.get('name', '')
 
 
 
152
  if 'note-tile' in name.lower():
153
  continue
154
+ self.visualizations.append({'id': viz.get('id'), 'name': name})
155
+
156
+ return bool(self.visualizations)
157
+
158
+ # ------------------------------------------------------------------
159
+ # Entity matching (schema-aware)
160
+ # ------------------------------------------------------------------
161
+
162
+ def _fuzzy_match(self, target: str, candidates: List[str]) -> Optional[str]:
163
+ """Return the best matching candidate for target, or None."""
164
+ def norm(s):
165
+ return s.lower().replace(' ', '').replace('-', '').replace('_', '')
166
+
167
+ t = norm(target)
168
+ t_lower = target.lower()
169
+
170
+ for c in candidates:
171
+ if c.lower() == t_lower:
172
+ return c
173
+ for c in candidates:
174
+ if norm(c) == t:
175
+ return c
176
+ for c in candidates:
177
+ if t_lower in c.lower() or c.lower() in t_lower:
178
+ return c
179
+ for c in candidates:
180
+ if t in norm(c) or norm(c) in t:
181
+ return c
182
+ return None
183
+
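The tiered matching above (exact → normalized → substring → normalized substring) is easy to sanity-check in isolation; this sketch mirrors `_fuzzy_match` outside the class:

```python
from typing import List, Optional

def fuzzy_match(target: str, candidates: List[str]) -> Optional[str]:
    """Mirror of _fuzzy_match: try progressively looser comparisons."""
    def norm(s: str) -> str:
        return s.lower().replace(' ', '').replace('-', '').replace('_', '')

    t, t_lower = norm(target), target.lower()
    tiers = (
        lambda c: c.lower() == t_lower,                          # exact (case-insensitive)
        lambda c: norm(c) == t,                                  # normalized equality
        lambda c: t_lower in c.lower() or c.lower() in t_lower,  # substring
        lambda c: t in norm(c) or norm(c) in t,                  # normalized substring
    )
    for tier in tiers:
        for c in candidates:
            if tier(c):
                return c
    return None

products = ['1080p Webcam', '4K Monitor', 'USB-C Hub']
assert fuzzy_match('webcam', products) == '1080p Webcam'   # substring tier
assert fuzzy_match('usbc hub', products) == 'USB-C Hub'    # normalized tier
```

Because each tier is exhausted over all candidates before the next, looser tier runs, an exact match always wins over a substring match.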
184
+ def _find_entity(self, entity_value: str, entity_type_hint: str = None) -> Tuple[Optional[str], Optional[str], Optional[str]]:
185
+ """
186
+ Find the closest matching entity name in any dimension table.
187
+
188
+ Returns: (matched_name, table_name, name_column) or (None, None, None)
189
+ """
190
+ cursor = self.conn.cursor()
191
+
192
+ # Sort dimension tables: prefer ones whose name matches entity_type_hint
193
+ tables_to_try = list(self.dimension_tables.items())
194
+ if entity_type_hint:
195
+ hint = entity_type_hint.lower()
196
+ tables_to_try.sort(key=lambda x: 0 if hint in x[0].lower() else 1)
197
+
198
+ for table, name_col in tables_to_try:
199
+ cursor.execute(f'SELECT DISTINCT {name_col} FROM {self.database}."{self.schema}".{table}')
200
+ candidates = [row[0] for row in cursor.fetchall() if row[0]]
201
+ match = self._fuzzy_match(entity_value, candidates)
202
+ if match:
203
+ return match, table, name_col
204
+
205
+ return None, None, None
206
+
207
+ def _find_fact_join(self, dim_table: str) -> Optional[Tuple[str, str, str]]:
208
+ """
209
+ Find a fact table that has an FK column referencing dim_table.
210
+ Returns: (fact_table, fact_fk_column, dim_pk_column) or None.
211
+ """
212
+ # Look for an ID column in dim_table
213
+ dim_cols = self.schema_tables.get(dim_table, [])
214
+ dim_id = next((c for c in dim_cols if c.endswith('_ID')), None)
215
+ if not dim_id:
216
  return None
217
+
218
+ # Look for that column in fact tables
219
+ for ft in self.fact_tables:
220
+ if ft == dim_table:
221
+ continue
222
+ ft_cols = self.schema_tables.get(ft, [])
223
+ if dim_id in ft_cols:
224
+ return ft, dim_id, dim_id
225
+
226
+ return None
227
+
228
+ # ------------------------------------------------------------------
229
+ # Request interpretation
230
+ # ------------------------------------------------------------------
231
+
232
+ def _parse_request_simple(self, message: str) -> Optional[Dict]:
233
+ """
234
+ Fast regex-based parser for common patterns:
235
+ "decrease webcam by 10%", "make laptop 50B", "increase revenue for acme by 20%"
236
+ Returns parsed dict or None if pattern not matched.
237
+ """
238
+ msg_lower = message.lower()
239
+
240
+ # Percentage match: "by 20%", "-10%"
241
+ pct_match = re.search(r'by\s+(-?\d+\.?\d*)%|(-?\d+\.?\d*)%', msg_lower)
242
+ # Absolute value: "50B", "50 billion", "1.5M", "1000000"
243
+ val_match = re.search(r'(\d+\.?\d*)\s*(b(?:illion)?|m(?:illion)?|k|thousand)\b', msg_lower)
244
+ bare_num = re.search(r'\b(\d{4,})\b', message) # bare large integer
245
+
246
+ is_percentage = bool(pct_match)
247
  percentage = None
248
  target_value = None
249
+


250
  if pct_match:
251
+ raw_pct = float(pct_match.group(1) or pct_match.group(2))
252
+ if any(w in msg_lower for w in ('decrease', 'reduce', 'lower', 'drop', 'cut')):
253
+ raw_pct = -abs(raw_pct)
254
+ percentage = raw_pct
255
+ elif val_match:
256
+ num = float(val_match.group(1))
257
+ unit = val_match.group(2)[0].lower()
258
+ multipliers = {'b': 1e9, 'm': 1e6, 'k': 1e3, 't': 1e3}
259
+ target_value = num * multipliers.get(unit, 1)
260
+ elif bare_num:
261
+ target_value = float(bare_num.group(1))
262
  else:
 
263
  return None
264
+
265
+ # Extract entity name: quoted or after action verb
266
+ entity = None
267
+ quoted = re.search(r'"([^"]+)"', message)
268
+ if quoted:
269
+ entity = quoted.group(1).strip()
270
  else:
271
+ # "make/set/increase/decrease/adjust <entity> [by/to]"
272
+ action_pat = r'(?:make|set|increase|decrease|reduce|boost|lower|adjust|change)\s+(?:the\s+)?(?:\w+\s+(?:for|of)\s+)?([a-z0-9][\w\s-]*?)(?:\s+(?:by|to|revenue|sales|at)\b|\s+\d|$)'
273
+ am = re.search(action_pat, msg_lower, re.I)
274
+ if am:
275
+ entity = am.group(1).strip()
276
+
277
+ if not entity:
278
+ return None
279
+
280
+ # Detect entity type from keywords
281
+ entity_type = None
282
+ for kw in ('seller', 'vendor', 'customer', 'product', 'item', 'region', 'store'):
283
+ if kw in msg_lower:
284
+ entity_type = kw
285
+ break
286
+
287
+ return {
288
  'entity_value': entity,
289
  'entity_type': entity_type,
 
 
290
  'is_percentage': is_percentage,
291
  'percentage': percentage,
292
+ 'target_value': target_value,
293
  'confidence': 'medium',
 
294
  }
295
+
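A standalone sketch of the value-extraction branch above (hypothetical `parse_target` helper, not part of the class; same regex idea, returning only the absolute target value):

```python
import re

MULTIPLIERS = {'b': 1e9, 'm': 1e6, 'k': 1e3, 't': 1e3}  # 't' covers "thousand"

def parse_target(text: str):
    """Extract an absolute target from text like '50B', '1.5 million', '2000'."""
    m = re.search(r'(\d+\.?\d*)\s*(b(?:illion)?|m(?:illion)?|k|thousand)\b', text.lower())
    if m:
        return float(m.group(1)) * MULTIPLIERS[m.group(2)[0]]
    # Fall back to a bare 4+ digit integer ("set it to 2000")
    bare = re.search(r'\b(\d{4,})\b', text)
    return float(bare.group(1)) if bare else None

assert parse_target('make webcam 50B') == 50e9
assert parse_target('set laptop to 1.5 million') == 1.5e6
assert parse_target('revenue to 2000') == 2000.0
```

Mapping the unit's first character into a dict keeps the lookup uniform across the long and short spellings.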
 
 
 
 
 
 
 
296
  def match_request_to_viz(self, user_request: str) -> Optional[Dict]:
297
  """
298
+ Parse request and enrich with schema context.
299
+ Returns structured match dict or None.
 
 
 
 
 
 
 
 
300
  """
301
+ result = self._parse_request_simple(user_request)
302
+
303
+ if not result:
304
+ # Fall back to LLM for complex requests
305
+ schema_summary = "\n".join(
306
+ f" {t}: {', '.join(cols[:8])}" + (' ...' if len(cols) > 8 else '')
307
+ for t, cols in self.schema_tables.items()
308
+ )
309
+ viz_summary = "\n".join(f" {i+1}. {v['name']}" for i, v in enumerate(self.visualizations))
310
+ prompt = f"""Parse this data adjustment request.
311
+
312
+ Request: "{user_request}"
313
+
314
+ Snowflake schema tables:
315
+ {schema_summary}
316
+
317
+ Liveboard visualizations:
318
+ {viz_summary}
319
+
320
+ Return JSON with these fields (numbers only, not strings):
321
  {{
322
+ "entity_value": "the entity name to adjust (e.g. '1080p Webcam')",
323
+ "entity_type": "product|seller|customer|region|null",
324
+ "is_percentage": true|false,
325
+ "percentage": <number or null>,
326
+ "target_value": <number or null>,
327
+ "metric_hint": "keyword like 'revenue', 'sales', 'profit_margin', or null",
328
+ "confidence": "high|medium|low"
329
+ }}"""
330
+ try:
331
+ raw = self._call_llm(prompt)
332
+ if raw.startswith('```'):
333
+ raw = '\n'.join(raw.split('\n')[1:-1])
334
+ result = json.loads(raw)
335
+ except Exception:
336
+ return None
337
+
338
+ return result if result else None
339
+
340
+ # ------------------------------------------------------------------
341
+ # Value retrieval & SQL generation
342
+ # ------------------------------------------------------------------
343
+
344
+ def get_current_value(self, entity_value: str, metric_column: str,
345
+ entity_type: str = None) -> Tuple[float, Optional[str], Optional[str], Optional[str]]:
346
+ """
347
+ Query current aggregate value for entity from Snowflake.
348
+
349
+ Returns: (current_value, matched_entity, dim_table, fact_table)
350
+ """
351
+ matched, dim_table, name_col = self._find_entity(entity_value, entity_type)
352
+ if not matched:
353
+ return 0.0, None, None, None
354
+
355
+ join_info = self._find_fact_join(dim_table) if dim_table else None
356
+
357
  cursor = self.conn.cursor()
358
+ if join_info:
359
+ fact_table, fk_col, dim_pk_col = join_info
360
+ dim_cols = self.schema_tables.get(dim_table, [])
361
+ dim_pk = next((c for c in dim_cols if c.endswith('_ID')), dim_pk_col)
 
 
 
 
 
 
 
 
 
362
  query = f"""
363
+ SELECT SUM(f.{metric_column})
364
+ FROM {self.database}."{self.schema}".{fact_table} f
365
+ JOIN {self.database}."{self.schema}".{dim_table} d
366
+ ON f.{fk_col} = d.{dim_pk}
367
+ WHERE LOWER(d.{name_col}) = LOWER('{matched}')
368
  """
369
+ else:
370
+ # entity is directly in the table with the metric
371
  query = f"""
372
+ SELECT SUM({metric_column})
373
+ FROM {self.database}."{self.schema}".{dim_table}
374
+ WHERE LOWER({name_col}) = LOWER('{matched}')
 
375
  """
376
+
377
+ try:
378
+ cursor.execute(query)
379
+ row = cursor.fetchone()
380
+ value = float(row[0]) if row and row[0] is not None else 0.0
381
+ fact_table_used = join_info[0] if join_info else dim_table
382
+ return value, matched, dim_table, fact_table_used
383
+ except Exception as e:
384
+ print(f"[SmartDataAdjuster] get_current_value query failed: {e}")
385
+ return 0.0, None, None, None
386
+
387
+ def _pick_metric_column(self, metric_hint: str = None) -> Optional[str]:
388
+ """Choose the best metric column from fact tables based on hint."""
389
+ # Build a list of all numeric-looking columns across fact tables
390
+ candidates = []
391
+ for ft in self.fact_tables:
392
+ for col in self.schema_tables.get(ft, []):
393
+ if any(kw in col.upper() for kw in ('AMOUNT', 'REVENUE', 'TOTAL', 'SALES', 'VALUE', 'MARGIN', 'PROFIT', 'COST', 'PRICE')):
394
+ candidates.append((ft, col))
395
+
396
+ if not candidates:
397
+ return None
398
+
399
+ if metric_hint:
400
+ hint_upper = metric_hint.upper()
401
+ for ft, col in candidates:
402
+ if hint_upper in col:
403
+ return col
404
+
405
+ # Default: prefer TOTAL_AMOUNT, REVENUE, then first available
406
+ for preferred in ('TOTAL_AMOUNT', 'TOTAL_REVENUE', 'REVENUE', 'AMOUNT'):
407
+ for ft, col in candidates:
408
+ if col == preferred:
409
+ return col
410
+
411
+ return candidates[0][1] if candidates else None
412
+
413
+ def generate_strategy(self, entity_value: str, metric_column: str,
414
+ current_value: float, target_value: float = None,
415
+ percentage: float = None, entity_type: str = None) -> Dict:
416
+ """Generate an UPDATE strategy based on the adjustment request."""
417
+ matched, dim_table, name_col = self._find_entity(entity_value, entity_type)
418
+ if not matched:
419
+ matched = entity_value
420
+
421
  if percentage is not None:
 
422
  multiplier = 1 + (percentage / 100)
423
+ pct_change = percentage
424
+ if target_value is None:
425
+ target_value = current_value * multiplier
426
+ elif target_value and current_value > 0:
427
+ multiplier = target_value / current_value
428
+ pct_change = (multiplier - 1) * 100
429
  else:
430
+ multiplier = 1.0
431
+ pct_change = 0.0
432
+
433
+ join_info = self._find_fact_join(dim_table) if dim_table else None
434
+
435
+ if join_info:
436
+ fact_table, fk_col, _ = join_info
437
+ dim_cols = self.schema_tables.get(dim_table, [])
438
+ dim_pk = next((c for c in dim_cols if c.endswith('_ID')), fk_col)
439
+ sql = f"""UPDATE {self.database}."{self.schema}".{fact_table}
440
  SET {metric_column} = {metric_column} * {multiplier:.6f}
441
+ WHERE {fk_col} IN (
442
+ SELECT {dim_pk}
443
+ FROM {self.database}."{self.schema}".{dim_table}
444
+ WHERE LOWER({name_col}) = LOWER('{matched}')
445
  )"""
446
+ elif dim_table:
447
+ sql = f"""UPDATE {self.database}."{self.schema}".{dim_table}
448
  SET {metric_column} = {metric_column} * {multiplier:.6f}
449
+ WHERE LOWER({name_col}) = LOWER('{matched}')"""
450
+ else:
451
+ sql = f"-- Could not determine table structure for '{entity_value}'"
452
+
 
 
 
453
  return {
454
  'id': 'A',
455
+ 'name': 'Scale All Transactions',
456
+ 'description': f"Multiply all rows for '{matched}' by {multiplier:.3f}x ({pct_change:+.1f}%)",
457
  'sql': sql,
458
+ 'matched_entity': matched,
459
+ 'target_value': target_value,
460
  }
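The multiplier arithmetic above reduces to two cases: a percentage gives `multiplier = 1 + pct/100`, and an absolute target gives `multiplier = target / current`. A small standalone sketch (hypothetical `scale_factors` helper) makes both checkable:

```python
def scale_factors(current, target=None, percentage=None):
    """Resolve (multiplier, pct_change, target) the way generate_strategy does."""
    if percentage is not None:
        multiplier = 1 + percentage / 100
        if target is None:
            target = current * multiplier
        return multiplier, percentage, target
    if target and current > 0:
        multiplier = target / current
        return multiplier, (multiplier - 1) * 100, target
    # No usable inputs: leave the data unchanged
    return 1.0, 0.0, target

# "cut acme by 50%" on a $2M current total
assert scale_factors(2_000_000, percentage=-50) == (0.5, -50, 1_000_000.0)
# "make it 50B" when current is $40B
assert scale_factors(40e9, target=50e9) == (1.25, 25.0, 50e9)
```

Scaling every matching row by one multiplier preserves the shape of the time series while moving the aggregate to the requested value.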
461
 
462
+ def present_smart_confirmation(self, match: Dict, current_value: float,
463
+ strategy: Dict, metric_column: str) -> str:
464
+ """Format a human-readable confirmation message."""
465
+ entity = match.get('entity_value', '?')
466
+ matched = strategy.get('matched_entity', entity)
467
+ target = strategy.get('target_value', 0) or 0
468
+
469
+ if matched.lower() != entity.lower():
470
+ entity_display = f"{entity} → **{matched}**"
471
+ else:
472
+ entity_display = f"**{matched}**"
473
+
474
+ change = target - current_value
475
+ pct = (change / current_value * 100) if current_value else 0
476
+
477
+ lines = [
478
+ f"**Liveboard:** {self.liveboard_name}",
479
+ f"**Entity:** {entity_display}",
480
+ f"**Metric:** `{metric_column}`",
481
+ f"**Current:** {current_value:,.0f}",
482
+ f"**Target:** {target:,.0f} ({change:+,.0f} / {pct:+.1f}%)",
483
+ f"**Strategy:** {strategy['description']}",
484
+ "",
485
+ f"```sql\n{strategy['sql']}\n```",
486
+ ]
487
+
488
+ if match.get('confidence') == 'low':
489
+ lines.append("\n⚠️ Low confidence — please verify before confirming.")
490
+
491
+ return "\n".join(lines)
492
+
493
+ # ------------------------------------------------------------------
494
+ # SQL execution
495
+ # ------------------------------------------------------------------
496
+
497
  def execute_sql(self, sql: str) -> Dict:
498
+ """Execute an UPDATE statement. Returns success/error dict."""
 
 
499
  cursor = self.conn.cursor()
 
500
  try:
501
  cursor.execute(sql)
502
  rows_affected = cursor.rowcount
503
  self.conn.commit()
504
+ return {'success': True, 'rows_affected': rows_affected}
 
 
 
 
 
505
  except Exception as e:
506
+ try:
507
+ self.conn.rollback()
508
+ except Exception:
509
+ pass
510
+ return {'success': False, 'error': str(e)}
511
+
512
+ # ------------------------------------------------------------------
513
+ # Teardown
514
+ # ------------------------------------------------------------------
515
+
516
  def close(self):
517
+ """Close Snowflake connection."""
518
  if self.conn:
519
+ try:
520
+ self.conn.close()
521
+ except Exception:
522
+ pass
523
+
524
+
525
+ # ---------------------------------------------------------------------------
526
+ # Liveboard-first context loader
527
+ # ---------------------------------------------------------------------------
528
+
529
+ def load_context_from_liveboard(liveboard_guid: str, ts_client) -> dict:
530
+ """
531
+ Resolve Snowflake database/schema from a liveboard GUID.
532
+
533
+ Flow:
534
+ liveboard TML (export_fqn=True)
535
+ model GUID from visualizations[n].answer.tables[0].fqn
536
+ → model TML
537
+ → database / schema from model.tables[0].table.{db, schema}
538
+
539
+ Args:
540
+ liveboard_guid: ThoughtSpot liveboard GUID
541
+ ts_client: Authenticated ThoughtSpotDeployer instance
542
+
543
+ Returns:
544
+ dict with keys: liveboard_name, model_guid, model_name, database, schema
545
+
546
+ Raises:
547
+ ValueError if any step fails to resolve.
548
+ """
549
+ import yaml
550
+
551
+ # Step 1: Export liveboard TML with FQNs
552
+ response = ts_client.session.post(
553
+ f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export",
554
+ json={
555
+ "metadata": [{"identifier": liveboard_guid}],
556
+ "export_associated": False,
557
+ "export_fqn": True,
558
+ "format_type": "YAML",
559
+ }
560
  )
561
+ if response.status_code != 200:
562
+ raise ValueError(
563
+ f"Failed to export liveboard TML ({response.status_code}): {response.text[:300]}"
564
+ )
565
+
566
+ tml_data = response.json()
567
+ if not tml_data:
568
+ raise ValueError("Empty response from liveboard TML export")
569
+
570
+ lb_tml = yaml.safe_load(tml_data[0]['edoc'])
571
+ liveboard_name = lb_tml.get('liveboard', {}).get('name', 'Unknown Liveboard')
572
+
573
+ # Step 2: Find model GUID from first visualization with answer.tables[].fqn
574
+ model_guid = None
575
+ for viz in lb_tml.get('liveboard', {}).get('visualizations', []):
576
+ for t in viz.get('answer', {}).get('tables', []):
577
+ fqn = t.get('fqn')
578
+ if fqn:
579
+ model_guid = fqn
580
+ break
581
+ if model_guid:
582
+ break
583
+
584
+ if not model_guid:
585
+ raise ValueError(
586
+ "Could not find model GUID in liveboard TML — "
587
+ "make sure the liveboard has at least one answer-based visualization."
588
+ )
589
+
590
+ # Step 3: Export model TML to get database/schema
591
+ response = ts_client.session.post(
592
+ f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export",
593
+ json={
594
+ "metadata": [{"identifier": model_guid, "type": "LOGICAL_TABLE"}],
595
+ "export_associated": False,
596
+ "export_fqn": True,
597
+ "format_type": "YAML",
598
+ }
599
  )
600
+ if response.status_code != 200:
601
+ raise ValueError(
602
+ f"Failed to export model TML ({response.status_code}): {response.text[:300]}"
603
+ )
604
+
605
+ tml_data = response.json()
606
+ model_tml = yaml.safe_load(tml_data[0]['edoc'])
607
+ model_name = model_tml.get('model', {}).get('name', 'Unknown Model')
 
 
 
 
 
 
 
 
 
 
608
 
609
+ # Step 4: Extract db/schema from first model table entry
610
+ tables = model_tml.get('model', {}).get('tables', [])
611
+ if not tables:
612
+ raise ValueError("No tables found in model TML")
613
 
614
+ first_table = tables[0].get('table', {})
615
+ database = first_table.get('db')
616
+ schema = first_table.get('schema')
617
+
618
+ if not database or not schema:
619
+ raise ValueError(
620
+ f"Could not resolve database/schema from model TML "
621
+ f"(db={database!r}, schema={schema!r})"
622
+ )
623
 
624
+ return {
625
+ 'liveboard_name': liveboard_name,
626
+ 'model_guid': model_guid,
627
+ 'model_name': model_name,
628
+ 'database': database,
629
+ 'schema': schema,
630
+ }
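Step 2 of the flow above is a pure dict walk, so it can be exercised against a minimal stand-in for the parsed TML (placeholder GUID and names, not real ThoughtSpot output):

```python
def find_model_guid(lb_tml: dict):
    """Walk liveboard TML for the first answer table fqn, as in Step 2 above."""
    for viz in lb_tml.get('liveboard', {}).get('visualizations', []):
        for t in viz.get('answer', {}).get('tables', []):
            if t.get('fqn'):
                return t['fqn']
    return None

sample = {
    'liveboard': {
        'name': 'Sales Overview',
        'visualizations': [
            {'note': 'text tile, no answer'},                   # skipped: no answer key
            {'answer': {'tables': [{'fqn': 'model-guid-123'}]}},
        ],
    }
}
assert find_model_guid(sample) == 'model-guid-123'
assert find_model_guid({'liveboard': {}}) is None
```

Note-tile visualizations have no `answer` block, which is why the walk quietly skips them instead of raising.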
snowflake_auth.py CHANGED
@@ -6,14 +6,10 @@ Supports multiple private key formats for cloud deployment flexibility:
6
  - Base64-encoded PEM (single line, recommended for HF Spaces)
7
  - Newline-escaped PEM (\\n replaced with actual newlines)
8
  """
9
- import os
10
  import base64
11
- from dotenv import load_dotenv
12
  from cryptography.hazmat.primitives import serialization
13
  from cryptography.hazmat.primitives.serialization import load_pem_private_key
14
-
15
- # Load environment variables (no-op if not using .env file)
16
- load_dotenv()
17
 
18
 
19
  def _decode_private_key(raw_key: str) -> str:
@@ -79,35 +75,8 @@ def get_snowflake_connection_params():
79
  Returns:
80
  dict: Connection parameters for snowflake.connector.connect()
81
  """
82
- # Get private key from environment
83
- private_key_raw = os.getenv('SNOWFLAKE_KP_PK')
84
-
85
- # Fallback: try to read from .env file directly (for local development)
86
- if not private_key_raw:
87
- try:
88
- env_path = '.env'
89
- if os.path.exists(env_path):
90
- with open(env_path, 'r') as f:
91
- content = f.read()
92
- # Find the private key section manually
93
- start_marker = 'SNOWFLAKE_KP_PK=-----BEGIN'
94
-
95
- if start_marker in content:
96
- start_idx = content.find('-----BEGIN', content.find(start_marker))
97
- end_idx = content.find('-----END', start_idx)
98
- if end_idx != -1:
99
- end_idx = content.find('-----', end_idx + 7) + 5
100
- private_key_raw = content[start_idx:end_idx].strip()
101
- print("✅ Loaded private key from .env file")
102
- except Exception as e:
103
- print(f"⚠️ Could not read .env file: {e}")
104
-
105
- if not private_key_raw:
106
- raise ValueError(
107
- "SNOWFLAKE_KP_PK environment variable not set or could not be parsed.\n"
108
- "For HF Spaces: Add as a Secret in Settings > Repository Secrets\n"
109
- "Supported formats: Direct PEM, Base64-encoded PEM, or escaped newlines"
110
- )
111
 
112
  # Decode the private key from various formats
113
  private_key_pem = _decode_private_key(private_key_raw)
@@ -116,7 +85,7 @@ def get_snowflake_connection_params():
116
  password = None
117
  if 'ENCRYPTED' in private_key_pem:
118
  # Try to get password from environment if key is encrypted
119
- password = os.getenv('SNOWFLAKE_KP_PASSPHRASE')
120
  if password:
121
  password = password.encode()
122
  print("✅ Using passphrase for encrypted private key")
@@ -130,8 +99,8 @@ def get_snowflake_connection_params():
130
  except Exception as e:
131
  if 'ENCRYPTED' in private_key_pem and not password:
132
  raise ValueError(
133
- "Private key is encrypted but SNOWFLAKE_KP_PASSPHRASE not provided.\n"
134
- "Add SNOWFLAKE_KP_PASSPHRASE as a Secret in HF Spaces settings."
135
  )
136
  else:
137
  raise ValueError(
@@ -148,14 +117,15 @@ def get_snowflake_connection_params():
148
  )
149
 
150
  # Return connection parameters
 
151
  return {
152
- 'user': os.getenv('SNOWFLAKE_KP_USER'),
153
  'private_key': private_key_bytes,
154
- 'account': os.getenv('SNOWFLAKE_ACCOUNT'),
155
- 'role': os.getenv('SNOWFLAKE_ROLE'),
156
- 'warehouse': os.getenv('SNOWFLAKE_WAREHOUSE'),
157
- 'database': os.getenv('SNOWFLAKE_DATABASE'),
158
- 'schema': os.getenv('SNOWFLAKE_SCHEMA', 'PUBLIC'),
159
  }
160
 
161
  def get_snowflake_connection():
 
6
  - Base64-encoded PEM (single line, recommended for HF Spaces)
7
  - Newline-escaped PEM (\\n replaced with actual newlines)
8
  """
 
9
  import base64
 
10
  from cryptography.hazmat.primitives import serialization
11
  from cryptography.hazmat.primitives.serialization import load_pem_private_key
12
+ from supabase_client import get_admin_setting
 
 
13
 
14
 
15
  def _decode_private_key(raw_key: str) -> str:
 
75
  Returns:
76
  dict: Connection parameters for snowflake.connector.connect()
77
  """
78
+ # Source of truth: admin settings in Supabase
79
+ private_key_raw = get_admin_setting('SNOWFLAKE_KP_PK')
80
 
81
  # Decode the private key from various formats
82
  private_key_pem = _decode_private_key(private_key_raw)
 
85
  password = None
86
  if 'ENCRYPTED' in private_key_pem:
87
  # Try to get password from environment if key is encrypted
88
+ password = get_admin_setting('SNOWFLAKE_KP_PASSPHRASE', required=False)
89
  if password:
90
  password = password.encode()
91
  print("✅ Using passphrase for encrypted private key")
 
99
  except Exception as e:
100
  if 'ENCRYPTED' in private_key_pem and not password:
101
  raise ValueError(
102
+ "Private key is encrypted but SNOWFLAKE_KP_PASSPHRASE is missing. "
103
+ "Set it in Admin Settings."
104
  )
105
  else:
106
  raise ValueError(
 
117
  )
118
 
119
  # Return connection parameters
120
+ schema_name = get_admin_setting('SNOWFLAKE_SCHEMA', required=False) or 'PUBLIC'
121
  return {
122
+ 'user': get_admin_setting('SNOWFLAKE_KP_USER'),
123
  'private_key': private_key_bytes,
124
+ 'account': get_admin_setting('SNOWFLAKE_ACCOUNT'),
125
+ 'role': get_admin_setting('SNOWFLAKE_ROLE', required=False),
126
+ 'warehouse': get_admin_setting('SNOWFLAKE_WAREHOUSE'),
127
+ 'database': get_admin_setting('SNOWFLAKE_DATABASE'),
128
+ 'schema': schema_name,
129
  }
130
 
131
  def get_snowflake_connection():
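The diff references `_decode_private_key` without showing its body. As a rough sketch of how the three formats listed in the header could be detected (an assumption about the real implementation, not a copy of it):

```python
import base64

def decode_private_key(raw: str) -> str:
    """Normalize the three accepted key formats to a plain PEM string (assumed logic)."""
    raw = raw.strip()
    if raw.startswith('-----BEGIN'):
        # Direct PEM; un-escape literal '\n' sequences that env-style storage introduces
        return raw.replace('\\n', '\n')
    # Anything else is treated as base64-encoded PEM
    return base64.b64decode(raw).decode('utf-8')

pem = "-----BEGIN PRIVATE KEY-----\nMIIB...\n-----END PRIVATE KEY-----"
assert decode_private_key(pem) == pem
assert decode_private_key(pem.replace('\n', '\\n')) == pem
assert decode_private_key(base64.b64encode(pem.encode()).decode()) == pem
```

Base64 is the recommended format for HF Spaces secrets precisely because it collapses the multi-line PEM into a single line that survives environment-variable storage.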
sprint_2026_01.md DELETED
@@ -1,520 +0,0 @@
1
- # Sprint: January 2026 - Making the Demo Better
2
-
3
- ## Core Mission
4
-
5
- **"The goal of the demo builder is to tell a better story based on different personas, such as a line of business leader or data analyst, and to use a framework around the demo to effectively tell that story so that we can sell cloud software."**
6
-
7
- This framework influences the objects, visualizations, and data generated - including **outliers** that drive compelling demo narratives.
8
-
9
- **Terminology Clarification:**
10
- - **Outliers** = Creating interesting data points during data generation (story points for demos)
11
- - **Outliers Adjustment** = Modifying those outliers to make them stand out more
12
-
13
- See `dev_notes/USE_CASE_FLOW.md` for full documentation of the use case framework.
14
-
15
- ---
16
-
17
- ## Sprint Objectives
18
-
19
- ### 1. Generic Use Case Handling ✅ DONE
20
-
21
- **Problem:** Use cases not in the predefined list (like "subscription" for tinder.com) don't work as well as desired.
22
-
23
- **Solution:** Conversational flow for generic use cases - asks user for additional context when use case doesn't match packaged options.
24
-
25
- **Files:** `chat_interface.py` - use case validation and initialization flow
26
-
27
- **Also fixed:**
28
- - Company extraction now handles "with the use" format
29
- - Research continues even if website is inaccessible
30
-
31
- ---
32
-
33
- ### 2. Unified Outlier System ✅ DONE
34
-
35
- **File:** `outlier_system.py`
36
-
37
- **Key Classes:**
38
- - `OutlierGenerator` - Generates SQL UPDATEs from text descriptions
39
- - `OutlierPattern` - Data class for outlier with SQL, documentation, Spotter question
40
- - `parse_chat_adjustment()` - Parse chat messages like "decrease electronics by 10%"
41
- - `apply_outliers()` - Execute multiple outlier SQL UPDATEs
42
- - `generate_demo_pack()` - Generate markdown demo documentation
43
-
44
- **Features:**
45
- - Parses natural language patterns ("Premium subscriptions spike 3x in Q4")
46
- - Extracts multipliers, percentages, absolute values
47
- - Understands time conditions (Q4, summer, specific months)
48
- - Generates Spotter questions automatically
49
- - Creates talking points for demos
50
-
51
- #### Two Entry Points, Same System
52
-
53
- | Entry Point | When | Who Decides | Example |
54
- |-------------|------|-------------|---------|
55
- | **Auto-Injection** | After population | AI (from research) | "Create compelling story patterns" |
56
- | **Chat Adjustment** | After deployment | User via chat | "decrease electronics by 10%" |
57
-
58
- ---
59
-
60
- ### 3. Demo Pack Generation ✅ DONE
61
-
62
- **Function:** `generate_demo_pack()` in `outlier_system.py`
63
-
64
- **Output:** Markdown document with:
65
- - Outlier stories that were injected
66
- - Suggested Spotter questions for each story
67
- - Talking points
68
- - Technical details (table, column, conditions)
69
- - Demo flow guide
70
-
71
- ---
72
-
73
- ### 4. Golden Demo with Hybrid MCP/TML Approach ✅ DONE
74
-
75
- **Goal:** Create the "perfect demo" liveboard using best of both methods.
76
-
77
- **Implementation Complete (Jan 18, 2026):**
78
-
79
- **Three-Method System:**
80
- - `TML` - Pure template-based approach, full control
81
- - `MCP` - Pure AI-driven approach, fast but basic
82
- - `HYBRID` - MCP creates + TML post-processing (recommended)
83
-
84
- **Configuration:**
85
- - Settings UI: Admin tab → "Liveboard Creation Method" dropdown
86
- - Environment variable: `LIVEBOARD_METHOD=TML|MCP|HYBRID`
87
- - Default: HYBRID
88
-
89
- **Files Modified:**
90
- - `thoughtspot_deployer.py` - Extended deploy_all() with liveboard_method parameter
91
- - `demo_prep.py` - Added liveboard_method_dropdown to Settings UI
92
- - `liveboard_creator.py` - Added `enhance_mcp_liveboard()` function
93
- - `chat_interface.py` - Passes liveboard_method from settings
94
- - `CLAUDE.md` - Updated documentation
95
-
96
- **enhance_mcp_liveboard() Function:**
97
- 1. Exports MCP-created liveboard TML
98
- 2. Classifies visualizations by type (KPI, trend, categorical)
99
- 3. Adds Groups (tabs) for organization
100
- 4. Fixes KPI sparklines and comparisons
101
- 5. Applies brand colors (GBC_* for groups, TBC_* for tiles)
102
- 6. Re-imports enhanced TML
103
-
104
- ---
105
-
106
- ### 5. Justin Recommendations
107
-
108
- **a) snake_case Column Names in ThoughtSpot** ✅ DONE
109
- - Model column names now use snake_case (e.g., `shipping_mode`, `days_to_ship`)
110
- - Function: `_to_snake_case()` in `thoughtspot_deployer.py`
111
- - Updated `_resolve_column_name_conflict()` to use snake_case
112
-
113
- **b) ThoughtSpot Date Handling**
114
- - Proper date type configuration for time-series analysis
115
- - Calendar types, fiscal calendars, etc.
116
- - Affects: Model TML column properties
117
-
118
- ---
119
-
120
- ### 6. Remove Chart Type Hints from Viz Titles ✅ DONE
121
-
122
- **Problem:** Viz titles include "(line chart)", "(bar chart)" etc. which look unprofessional.
123
-
124
- **Solution:** Added `_strip_chart_type_hints()` function in `liveboard_creator.py` that strips these hints before sending to MCP.
125
-
126
- ---
127
-
128
- ### 7. Better Messaging / Progress Indication ✅ DONE
129
-
130
- **Problem:** After "Population Complete", UI appears frozen while processing continues in background. User thinks app is stuck.
131
-
132
- **Solution:** Added `flush=True` to critical print statements in:
133
- - `legitdata_bridge.py` - Population steps
134
- - `liveboard_creator.py` - MCP workflow
135
- - `chat_interface.py` - Deployment messages
136
-
137
- ---
138
-
139
- ## Tasks
140
-
141
- ### Done
142
- - [x] Documentation reorganization ✅
143
- - [x] LegitData integration ✅
144
- - [x] Chat interface working end-to-end ✅
145
- - [x] MCP liveboard creation working ✅
146
- - [x] Generic use case handling ✅
147
- - [x] Fix LegitData `generate_key` error ✅
148
- - [x] Full Tinder demo run successful! ✅
149
-
150
- ### In Progress
151
- - [x] Debug MCP viz failures (2 vizs failed on Tinder demo) ✅ DONE
152
- - [x] **MCP on SEBE cluster** ✅ DONE - Bearer auth implementation fixed this (Feb 2)
153
-
154
- ---
155
-
156
- ## Planned Features
157
-
158
- ### Existing Model Selection (Future)
159
-
160
- **Goal:** Allow users to select an existing ThoughtSpot model and create liveboard / adjust data without creating new tables.
161
-
162
- **Questions to resolve:**
163
- 1. **Model Selection** - Dropdown of models? Search by name? Filter by connection?
164
- 2. **Use Case Matching** - How to validate model matches the use case? AI analyze columns vs keywords?
165
- 3. **Database Access Check** - Verify connection is valid? Run test query? Check permissions?
166
- 4. **Data Adjustment** - Outlier injection into existing data? Create views?
167
- 5. **Entry Point** - "Create new demo" vs "Use existing model" option at start?
168
- 6. **Cluster Scope** - Only models from cluster in .env
169
-
170
- **Implementation ideas:**
171
- - Use `/api/rest/2.0/metadata/search` to list models user has access to
172
- - Export model TML to get column names for use case matching
173
- - Check model's connection status via API
174
-
175
- ---
176
-
177
- ### MCP Multi-Cluster Auth (Needed)
178
-
179
- **Problem:** MCP OAuth tokens are cached per-endpoint. When switching between SE and SEBE clusters, must manually clear `~/.mcp-auth/` and re-authenticate.
180
-
181
- **Goal:** Be smart about which cluster MCP is authenticated to and switch automatically based on .env settings.
182
-
183
- **Options explored:**
184
- 1. **Bearer auth with `@thoughtspot/mcp-server`** - Use stdio transport with `TS_INSTANCE` and `TS_AUTH_TOKEN` env vars. Generate token via trusted auth from `THOUGHTSPOT_SECRET_KEY`.
185
- 2. **Multiple cached tokens** - Store tokens by cluster URL, switch based on .env `THOUGHTSPOT_URL`
186
- 3. **Manual for now** - Clear `~/.mcp-auth/` when switching clusters
187
-
188
- **Current workaround:** `rm -rf ~/.mcp-auth/` then re-auth when switching clusters.
189
-
190
- **Jan 31 findings:**
191
- - MCP works fine on SE cluster with same code
192
- - MCP fails on SEBE with "No answer found" errors
193
- - Not a code issue - same DONT_INDEX settings work on SE
194
- - Could be SEBE cluster config, Sage not enabled, or service issue
195
-
196
- ### High Priority - Use Case Framework ✅ DONE
197
- - [x] **Complete 6 core use case configs** ✅ - Added all missing persona configs:
198
- - Customer Analytics ✅
199
- - Financial Analytics ✅
200
- - Marketing Analytics ✅
201
- - Retail Analytics ✅
202
- - [x] **Define outlier patterns for each use case** ✅ - 4-5 outlier patterns per use case
203
- - [x] **Use case aliases** ✅ - Flexible matching (e.g., "Sales" → "Sales Analytics")
204
- - [ ] **Generic use case improvements** - Better context gathering flow (future)
205
- - See `dev_notes/USE_CASE_FLOW.md` for documentation
206
-
207
- ### To Do
208
- - [ ] Auto-injection step (hook outlier_system into population flow) - **DEFER: unclear what this is, revisit later**
209
- - [ ] Chat adjustment using outlier system - **DEFER to next sprint**
210
- - [x] ThoughtSpot date handling (Justin rec) ✅ DONE
211
-
212
- ### UI Enhancements
213
- - [ ] 🧙 **Wizard Tab** - Old school step-by-step wizard (right of Chat, left of Settings)
214
- - [x] 📺 **Live Progress Tab** ✅ - Real-time deployment output in UI
215
- - [x] 📋 **Demo Pack Tab** ✅ - Auto-generated demo notes/talking points after deployment
216
- - [x] Removed Population Code tab ✅ (not needed with LegitData)
217
-
218
- ### Done
219
- - [x] Outlier SQL generation core ✅ - `outlier_system.py` created
220
- - [x] Demo Pack generation ✅ - `generate_demo_pack()` function
221
- - [x] snake_case naming for ThoughtSpot objects ✅ - `_to_snake_case()` in deployer
222
- - [x] Fix UI feedback buffering issue ✅ - Added flush=True to critical print statements
223
- - [x] Remove "(line chart)" / "(bar chart)" from viz titles ✅ - `_strip_chart_type_hints()`
224
- - [x] Fix date range mismatch - dates now generated up to current date ✅
225
- - [x] **MCP Bearer Auth** ✅ - Major weekend effort (Jan 31 - Feb 2). Bearer auth with trusted token ensures same org context. Works on both SE and SEBE clusters.
226
- - [x] Dynamic date generation ✅ - LegitData now generates dates relative to today
227
- - [x] Universal context prompt ✅ - All use cases (not just generic) now get context prompt
228
- - [x] ThoughtSpot date handling ✅
229
-
230
- ---
231
-
232
- ## Notes
233
-
234
- ### Jan 13, 2026 (Evening)
235
- **🎉 SUCCESSFUL FULL RUN: Tinder Subscription Demo**
236
- - Liveboard created: https://se-thoughtspot-cloud.thoughtspot.cloud/#/pinboard/e5c0eb54-ba91-4380-b7f4-353c134038a7
237
- - 5 of 6 vizs succeeded, 1 failed (KPI viz error)
238
-
239
- **Issues found:**
240
- 1. **MCP Viz "No Data"** - 2 vizs showing no data due to date range mismatch:
241
- - AI questions use "last 12 months" → filters to 2025-2026
242
- - Generated data has dates from 2022-2024
243
- - Fix: Either generate data with recent dates OR avoid relative time filters in questions
244
- 2. **Viz Title Cleanup** - Remove "(line chart)" / "(bar chart)" from visualization titles
245
- 3. **UI Feedback Buffering** - After "Population Complete", UI freezes while work continues in background. Terminal output is buffered, not real-time. User thinks app is stuck.
246
- 4. **Progress Indication** - Need better way to show user that processing is still happening
247
-
248
- **Fixes today:**
249
- - Generic use case handling complete
250
- - Company/use case extraction improved
251
- - Research continues without website access
252
- - LegitData generate_key method added
253
- - Key realization: Outliers + Outliers Adjustment = same system
254
- - Unified model: SQL UPDATEs for pattern injection
255
- - Two entry points: AI-driven (auto) + User-driven (chat)
256
-
257
- ### Jan 12, 2026
258
- - Returned to project after a month away
259
- - Reorganized documentation structure
260
-
261
- ---
262
-
263
- ## Backlog (Future Sprints)
264
-
265
- - ~~**Concurrent demo builds**~~ ✅ FIXED - ChatDemoInterface now stored in `gr.State()` for per-session isolation
266
- - Site creator
267
- - Bot creator
268
- - Approval gates in chat workflow
269
- - LegitData integration strategy (pip package, submodule, etc.)
270
- - Real-time terminal output (fix buffering)
271
-
272
- ---
273
-
274
- *Sprint Started: January 12, 2026*
275
-
276
- ---
277
-
278
- ## Interface Mode Refactor (Jan 26, 2026)
279
-
280
- ### Status: Planning Complete
281
-
282
- **Goal:** Revive `demo_prep.py` (form-based interface) and add ability to switch between interfaces.
283
-
284
- ### Findings
285
-
286
- #### 1. demo_prep.py Status: ✅ WORKS
287
- - Tested successfully on port 7870
288
- - Only issue was port conflict (7860 in use by chat interface)
289
- - All imports work, interface creates, Gradio launches
290
-
291
- #### 2. Code Duplication Analysis
292
-
293
- **Duplicated between demo_prep.py and chat_interface.py:**
294
-
295
- | Area | Notes |
296
- |------|-------|
297
- | Cache checking logic | ~60 lines each - nearly identical domain extraction, filename building, age check |
298
- | DDL generation prompt | ~80 lines each - same schema requirements, Snowflake syntax |
299
- | Research workflow | ~150 lines each - website extract → company analysis → industry research |
300
- | ThoughtSpot deployment | ~100 lines each - schema verification, deployer calls, progress callbacks |
301
-
302
- **Already Shared (chat_interface imports from demo_prep):**
303
- - `map_llm_display_to_provider()` - LLM provider mapping
304
- - `execute_population_script()` - Population execution
305
- - `generate_demo_base_name()` - Naming convention
306
-
307
- **Unique to demo_prep.py:**
308
- - `validate_ddl_syntax()` - DDL validation
309
- - `validate_python_syntax()` - Python code validation
310
- - `extract_outliers_from_population_script()` - Outlier extraction (see below)
311
- - `extract_python_code()` - Code block extraction
312
- - Expert/Fast mode toggle
313
-
314
- **Unique to chat_interface.py:**
315
- - `ChatDemoInterface` class - State management
316
- - Message parsing methods
317
- - Stage-based conversation flow
318
- - `SmartDataAdjuster` integration
319
-
320
- #### 3. Proposed Shared Class: `DemoWorkflowEngine`
321
-
322
- ```python
323
- class DemoWorkflowEngine:
324
- """Shared workflow logic for both form and chat interfaces"""
325
-
326
- def __init__(self, use_case: str, company_url: str, settings: dict = None)
327
-
328
- # Research
329
- def check_cache(self, company, use_case) -> tuple[bool, str, dict]
330
- def run_research(self, company, use_case, use_cache=True) -> Generator
331
-
332
- # DDL Generation
333
- def generate_ddl(self) -> Generator
334
- def validate_ddl(self, ddl: str) -> tuple[bool, str]
335
-
336
- # Population
337
- def generate_population(self) -> Generator
338
- def execute_population(self, schema_name) -> tuple[bool, str]
339
-
340
- # Deployment
341
- def deploy_to_snowflake(self, schema_name) -> Generator
342
- def deploy_to_thoughtspot(self, schema_name, ...) -> Generator
343
-
344
- # State
345
- @property
346
- def current_stage(self) -> str
347
- def advance_stage(self)
348
- ```
349
-
350
- #### 4. workflow_style Setting
351
-
352
- - **Setting name:** `workflow_style`
353
- - **Values:** `classic` (form) or `chat`
354
- - **Default:** `classic`
355
- - **Storage:** Supabase user settings
356
- - **Switch mechanism:** Requires app restart
357
-
358
- ---
359
-
360
- ## Outlier Feature Comparison (Phase 2)
361
-
362
- ### Two Different Approaches
363
-
364
- #### demo_prep.py - Pre-Planned Outlier Documentation
365
-
366
- **How it works:**
367
- 1. LLM generates structured comments in population script:
368
- ```python
369
- # DEMO_OUTLIER: High-Value Customers at Risk
370
- # INSIGHT: Top 5 customers (>$50K LTV) showing declining satisfaction
371
- # SHOW_ME: "Show customers where lifetime_value > 50000 and satisfaction < 3"
372
- # IMPACT: $250K annual revenue at risk
373
- # TALKING_POINT: "Notice how ThoughtSpot surfaces your most valuable at-risk accounts"
374
- ```
375
- 2. `extract_outliers_from_population_script()` parses these comments
376
- 3. Generates "Demo Notes & Presentation Guide" with talking points
377
-
378
- **Strength:** Creates pre-planned cheat sheet for demos
379
-
380
- #### chat_interface.py - Interactive Post-Deployment Adjustment
381
-
382
- **How it works:**
383
- 1. Uses `SmartDataAdjuster` class after deployment
384
- 2. User types natural language commands:
385
- - "make 1080p webcam 40B"
386
- - "increase smart watch by 20%"
387
- 3. Executes SQL UPDATEs directly to Snowflake
388
- 4. User refreshes liveboard to see changes
389
-
390
- **Strength:** Real-time customization before demos
391
-
392
- ### Key Differences
393
-
394
- | Aspect | demo_prep.py | chat_interface.py |
395
- |--------|-------------|-------------------|
396
- | When | During data generation | After deployment |
397
- | What | Documents planned outliers | Modifies existing data |
398
- | Output | Presentation guide/cheat sheet | SQL updates |
399
- | Persistence | Baked into generated data | Changes live data |
400
-
401
- ### Phase 2 Recommendation
402
-
403
- 1. **Keep both approaches** - they serve different purposes
404
- 2. **demo_prep.py outliers** should be migrated to shared engine (generates cheat sheet)
405
- 3. **SmartDataAdjuster** already separate - can be used by both interfaces
406
- 4. **Combine them:** Show pre-planned outliers AND allow interactive adjustment
407
- 5. **chat_interface.py currently doesn't parse structured comments** - opportunity to add
408
-
409
- **The demo_prep.py approach is arguably MORE valuable** because:
410
- - Creates documentation that survives after the session
411
- - Gives salespeople a script to follow
412
- - Has richer metadata (viz types, KPIs, talking points)
413
-
414
- ---
415
-
416
- ## Updated Tasks
417
-
418
- ### Interface Refactor - Phase 1
419
- - [x] Test demo_prep.py - confirm it works ✅
420
- - [x] Map duplicated code sections ✅
421
- - [x] Design shared class API ✅
422
- - [ ] Create `demo_workflow_engine.py` with shared logic
423
- - [ ] Refactor demo_prep.py to use shared engine
424
- - [ ] Refactor chat_interface.py to use shared engine
425
- - [ ] Add `workflow_style` setting to Supabase
426
- - [ ] Create unified launcher
427
-
428
- ### Outlier System - Phase 2
429
- - [ ] Port `extract_outliers_from_population_script()` to shared engine
430
- - [ ] Add outlier documentation parsing to chat_interface.py
431
- - [ ] Combine pre-planned outliers with interactive adjustment
432
- - [ ] Unified demo pack generation
433
-
434
- ---
435
-
436
- *Updated: January 26, 2026*
437
-
438
- ---
439
-
440
- ## Bugs Found (Jan 26, 2026)
441
-
442
- ### BUG: Liveboard link not clickable in chat output
443
- **Priority:** High (dev101)
444
- **Location:** `chat_interface.py` - deployment success message
445
- **Issue:** Liveboard URL renders as plain text, not a clickable markdown link
446
- **Expected:** `[Liveboard Name](https://url)` format that's clickable
447
- **Fix:** Check markdown link formatting in response string
448
-
449
- ### BUG: TML Liveboard Quality - Comscore Demo
450
- **Priority:** High
451
- **Method:** TML (not MCP/HYBRID)
452
- **Company:** Comscore
453
- **Issues observed:**
454
- 1. **No KPIs** - Missing headline KPI tiles with big numbers
455
- 2. **Identical-looking vizs** - "Total Sales Performance" and "Sales Trend by Channel" both render as identical blue line charts
456
- 3. **"Sales Trend by Channel"** - Shows single line, should show multiple lines (one per channel) or stacked area
457
- 4. **Third viz failed** - Error icon, viz didn't render
458
- 5. **No groups/tabs** - Flat layout instead of organized groups
459
- 6. **No brand colors** - Everything default blue
460
- 7. **Only 5 vizs** - Should have more variety
461
-
462
- **Root cause investigation needed:**
463
- - `liveboard_creator.py` - `create_visualization_tml()` not generating KPIs?
464
- - `_generate_smart_questions_with_ai()` generating boring/duplicate questions?
465
- - Group creation logic missing or not running?
466
- - Visualization type selection too conservative?
467
-
468
- **Reference:** Golden demo at `dev_notes/liveboard_demogold2/`
469
-
470
- ---
471
-
472
- ## Fixes Applied (Jan 26, 2026 - Evening)
473
-
474
- ### FIX: Liveboard link not clickable
475
- **File:** `chat_interface.py`
476
- **Issue:** Bold formatting inside link brackets `[**name**](url)` breaking Gradio markdown
477
- **Fix:** Changed to `[name](url)` format (4 places)
478
-
479
- ### FIX: TML Liveboard Quality - No Groups/Tabs/Styling
480
- **File:** `liveboard_creator.py`
481
- **Issue:** `USE_ADVANCED_TML` defaulted to `'false'`, disabling groups/tabs/styling
482
- **Fix:** Changed default to `'true'` - now groups and brand colors enabled by default
483
-
484
- ### FIX: TML Liveboard Quality - No KPIs, Boring Charts
485
- **File:** `liveboard_creator.py`
486
- **Issue:** AI-generated visualizations not enforced to have variety
487
- **Fix:** Added `_enforce_visualization_variety()` method that:
488
- - Ensures at least 2 KPIs are always created
489
- - Converts LINE charts without dimensions to KPIs if needed
490
- - Prevents more than 2 of the same chart type (diversifies to BAR, COLUMN, etc.)
491
- - Called after AI generates visualizations, before returning
492
-
493
- **Test needed:** Run a new demo to verify improvements
494
-
495
- ---
496
-
497
- ## Carry Forward to February Sprint
498
-
499
- ### Unsatisfied / Needs More Work
500
- - **Unified Outlier System** - Core done but needs refinement, not satisfied with output quality
501
- - **Demo Pack Generation** - Very unsatisfied, needs significant improvement
502
- - **Chart Titles** - Still not happy with viz titles, needs better naming
503
-
504
- ### Needs Testing/Verification
505
- - **Existing Model Selection** - May be done but needs confirmation testing
506
- - **Universal Context Prompt** - Double-test this feature
507
- - **Self-join skip in models** - Verify this is working correctly
508
-
509
- ### Deferred Features
510
- - **Chat adjustment using outlier system** - Never got to this
511
- - **Wizard Tab UI** - Not started
512
- - **Interface Mode Refactor** - `DemoWorkflowEngine` shared class concept
513
-
514
- ### Unclear / Revisit Later
515
- - **Auto-injection step** - Unclear what this was supposed to be
516
-
517
- ---
518
-
519
- *Sprint Closed: February 2, 2026*
520
- t
sprint_2026_03.md ADDED
@@ -0,0 +1,227 @@
+ # Sprint: March 2026
+
+ *Started: March 16, 2026*
+ *Planning doc: `dev_notes/plan_march_2026.md`*
+
+ ---
+
+ ## Context
+
+ App is live. Small group launch early this week, broader rollout (4–5 people) within a couple weeks.
+ This sprint covers hardening, settings, and new capabilities before that happens.
+
+ ---
+
+ ## Sprint Objectives
+
+ ### Before Small Group Launch (This Week)
+
+ - [x] **TS Environment Dropdown** ✅ — ENV-based dropdown on front page; URL→key map hard-coded in app
+   - `TS_ENV_N_LABEL/URL` pattern in `.env`; `get_ts_environments()` reads all to build dropdown
+   - Dropdown in right panel alongside AI Model + Liveboard Name
+   - `update_ts_env()` resolves URL + auth key (via `os.getenv(key_name)`) into controller settings
+   - Bug fixed: was storing ENV var name instead of actual secret value
+   - Controller created on first message also receives current dropdown env selection
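The `TS_ENV_N_LABEL/URL` pattern above can be sketched as a small env reader. The function name comes from these notes; the stop-at-first-gap behaviour and the tuple return shape are assumptions:

```python
import os

def get_ts_environments():
    """Read TS_ENV_<N>_LABEL / TS_ENV_<N>_URL pairs from the environment.

    Returns (label, url) tuples in index order, stopping at the first
    index where either variable is missing. Sketch of the pattern
    described in the sprint notes, not the production implementation.
    """
    envs = []
    n = 1
    while True:
        label = os.getenv(f"TS_ENV_{n}_LABEL")
        url = os.getenv(f"TS_ENV_{n}_URL")
        if not (label and url):
            break
        envs.append((label, url))
        n += 1
    return envs
```

The dropdown then shows the labels, and `update_ts_env()` maps the selected label back to its URL plus the matching auth-key env var.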
+ - [x] **Front Page Redesign** ✅ — Right panel (where Stage + AI Model currently live):
+   - Add **TS Environment dropdown** here alongside AI Model
+   - Add **Liveboard Name** field here
+   - **Remove the Stage textbox** — replace with proper progress indicator (see below)
+   - Company/use case stays chat-driven only
+   - Remove `default_company_url` setting (replaced by chat-driven flow)
+ - [x] **Progress Meter Fix** ✅ — current stage textbox is not great UX
+   - Add **Init** as the first stage in the progress sequence
+   - Replace the stage textbox with a visual progress indicator (step 1–N style)
+   - Stages: Init → Research → DDL → Data → Model → Liveboard (→ Data Adjuster when in that phase)
+ - [x] **Chat Flow UX Improvements** ✅
+   - In-chat help text at session start: brief instructions message
+   - Clearer prompting back to user when use case is ambiguous
+   - Consider `?` tooltip near chat input
+ - [x] **Error Handling Review** ✅
+   - **Liveboard partial success** ✅: Snowflake + model OK but liveboard fails → ⚠️ message with Spotter Viz Story tab pointer + 'retry liveboard' prompt
+   - **TML import errors**: parse TS `error_list`, show which viz failed specifically
+   - **MCP failures**: show which step failed, whether partial work should be kept
+   - **Top-level wrapper** ✅: `process_chat_message` wrapped in try/except → yield friendly error + log full traceback
+   - **Data Adjuster errors**: SQL execution failures shown clearly with context, not swallowed
+   - **Snowflake connection errors**: distinguish auth failure vs. query failure vs. timeout
+ - [x] **Supabase Session Logging** ✅ — `session_logger.py`
+   - `SessionLogger` class: writes to `logs/sessions/{id}.log` + Supabase `session_logs` table
+   - Falls back silently to file-only if Supabase unavailable
+   - Initialized at first message in `process_chat_message`; logs research + deploy stage start/end
+   - `init_session_logger()` / `get_session_logger()` module-level helpers
+   - Table DDL in file docstring (run once in Supabase SQL editor)
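A minimal sketch of the file-fallback behaviour described above. The class name is from the notes; the Supabase call shape, column names, and log-line format are assumptions:

```python
import datetime
from pathlib import Path

class SessionLogger:
    """Write every event to logs/sessions/{id}.log; mirror to Supabase
    when a client is available, and never let a Supabase failure
    interrupt the session (file-only fallback)."""

    def __init__(self, session_id, log_dir="logs/sessions", supabase=None):
        self.session_id = session_id
        self.path = Path(log_dir) / f"{session_id}.log"
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self.supabase = supabase  # optional client; None -> file-only

    def log(self, stage, message):
        line = f"{datetime.datetime.now().isoformat()} [{stage}] {message}"
        with self.path.open("a") as f:
            f.write(line + "\n")
        if self.supabase is not None:
            try:
                self.supabase.table("session_logs").insert(
                    {"session_id": self.session_id,
                     "stage": stage,
                     "message": message}
                ).execute()
            except Exception:
                pass  # silent fallback: the file already has the event
```

Holding the instance on the controller (`self._session_logger`) rather than in a module singleton is what the State Isolation Fix below ends up doing.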
+ - [x] **Admin Log Viewer** ✅ — added to Admin Settings tab; email filter + row limit + Refresh button; queries `session_logs` via Supabase
+ - [x] **Data Adjuster Cleanup & Controller Integration** ✅
+   - Existing files: `data_adjuster.py`, `smart_data_adjuster.py`, `conversational_data_adjuster.py`, `chat_data_adjuster.py`
+   - Currently wired post-liveboard but state lives on `self._adjuster` / `self._pending_adjustment` (instance vars — wrong)
+   - **Controller owns the adjuster phase**: adjuster state moves into the chat controller phase flow, not instance vars
+   - **Multi-turn**: controller stays in the adjuster phase across multiple messages; user can ask multiple questions and make multiple adjustments in one session
+   - **Smart adjustments**: LLM understands the request, maps to the right table/column, proposes the SQL, confirms with user, executes
+   - Consolidate: decide which of the 4 files survives (likely `smart_data_adjuster.py` as the engine, rest retired or merged)
+ - [x] **State Isolation Audit** ✅ — audit complete; two HIGH issues identified
+   - **HIGH**: `session_logger.py` module-level `_current_logger` singleton — concurrent sessions overwrite each other's logger
+   - **HIGH**: `prompt_logger.py` module-level `_prompt_logger` singleton — all users' LLM prompts mix in same in-memory list
+   - **MEDIUM**: `inject_admin_settings_to_env()` writes to process-global `os.environ` — concurrent deployments could use wrong Snowflake account
+   - **MEDIUM**: Admin settings cache has no TTL — external Supabase edits invisible until restart
+   - Main `ChatDemoInterface` state IS isolated via `gr.State` — the pipeline itself is safe
+   - Fix (loggers + os.environ) tracked under "Before Broader Rollout → State Isolation Fix" below
+
+ ### Before Broader Rollout
+ - [x] **Settings Audit & Cleanup** ✅
+   - `ts_instance_url` removed from SETTINGS_SCHEMA + hidden in Settings UI (replaced by env dropdown)
+   - `default_company_url` removed from SETTINGS_SCHEMA + hidden in Settings UI (chat-driven now)
+   - AI Model selection already on front page ✅
+ - [ ] **demo_prep.py Refresh** — scope too large for this sprint, moved to Phase 2
+   - Audit done: ~2-3 day job (Spotter Viz tab, outlier integration, logged_completion, class refactor)
+ - [x] **Session Persistence Verification** ✅ — verified working
+   - `ts_username` in SETTINGS_SCHEMA → pre-fills on load via `load_settings_on_startup`
+   - `liveboard_name`, `default_use_case`, `default_llm` all pre-fill on startup
+   - company no longer pre-filled (chat-driven) — working as intended
+ - [x] **State Isolation Fix** ✅ — HIGH risk items resolved
+   - Session logger: stored on `self._session_logger` (per controller instance, not module singleton)
+   - Prompt logger: `reset_prompt_logger(session_id)` called at session start → fresh instance per session
+   - Both loggers initialized on first message of each session
+   - `inject_admin_settings_to_env()` still in use (deferred — requires cdw_connector refactor)
+
+ ---
+
+ ## Phase 2 (Next Sprint or Later)
+
+ ### demo_prep.py Refresh (from March Sprint)
+
+ - [ ] **demo_prep.py Refresh** — sync with `chat_interface.py` improvements (~2-3 days)
+   - Add Spotter Viz Story tab + `_generate_spotter_viz_story()` (5h)
+   - Add Demo Pack tab with outlier-driven talking points + Spotter questions (4h)
+   - Replace all `researcher.make_request()` calls with `logged_completion()` wrapper (5h)
+   - Refactor to class-based pattern (like `ChatDemoInterface`) for state isolation (7h)
+   - Full outlier system integration (2.5h)
+   - Per-user session logging throughout (2h)
+
+ ---
+
+ ### Carry-forward from Sprint 2
+
+ - [ ] **Unified Outlier System** — core done, not satisfied with output quality; needs refinement
+ - [ ] **Demo Pack Generation** — very unsatisfied, needs significant improvement
+ - [ ] **Chart Titles** — not happy with viz titles/naming; needs better approach
+ - [ ] **Existing Model Selection + Self-Join Skip** — may be done; needs confirmation test + verify self-join skip is working correctly
+ - [ ] **Universal Context Prompt** — double-test this feature end-to-end
+ - [ ] **Chat Adjustment Using Outlier System** — never got to this
+ - [ ] **Interface Mode Refactor** (`DemoWorkflowEngine` shared class concept)
+ - [ ] **Wizard Tab UI** — not started
+ - [ ] **Tag Assignment to Models** — returns 404 (works for tables, not models); needs investigation
+ - [ ] **Spotter Viz Story Verification** — run end-to-end and verify story generation + blank viz (ASP, Total Sales Weekly) and brand colors rendering
+ - [ ] **Fix Research Cache Not Loading** — relative path issue; fix was ready, needs test
+ - [ ] **Fix DAYSONHAND Generation** — currently random; needs business logic (realistic 15–120 day distribution)
+ - [ ] **Verify KPIs in Liveboard** — requires live deployment test
+ - [ ] **Auto-injection step** — revisit what this was supposed to be
+ - [ ] **Dead code cleanup: model TML generators** — `thoughtspot_deployer.py` has 3 model TML functions; only `_create_model_with_constraints` is called by `deploy_all`; remove `create_actual_model_tml` and `create_model_tml`
+
+ ### From March Plan
+
+ - [x] **Data Adjuster — Liveboard-First Entry Point** ✅
+   - Paste any TS liveboard URL in the init stage → jumps straight to adjuster (skips build pipeline)
+   - `load_context_from_liveboard()` in `smart_data_adjuster.py`: liveboard TML (export_fqn) → model GUID → model TML → db/schema
+   - Detection in `chat_interface.py` init stage: regex on `pinboard/<guid>` pattern → auth TS client → load context → init SmartDataAdjuster → `outlier_adjustment` stage
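The init-stage detection can be sketched with a GUID regex; the exact pattern used in `chat_interface.py` may differ:

```python
import re

# GUID after "pinboard/" in a ThoughtSpot liveboard URL (8-4-4-4-12 hex).
# Sketch of the detection idea, not the production regex.
LIVEBOARD_RE = re.compile(
    r"pinboard/([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}"
    r"-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})"
)

def extract_liveboard_guid(message):
    """Return the liveboard GUID if the message contains a TS liveboard URL,
    else None (message is treated as a normal build request)."""
    m = LIVEBOARD_RE.search(message)
    return m.group(1) if m else None
```

When a GUID is found, the controller skips the build pipeline and goes straight to authenticating, loading context from the liveboard, and entering the `outlier_adjustment` stage.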
125
+ - [x] **Sharing** ✅ — model + liveboard shared (can_edit / MODIFY) after every build
126
+ - `share_objects()` method in `thoughtspot_deployer.py`: POST `/api/rest/2.0/security/metadata/share`
127
+ - Detects `@` in value → USER type, otherwise → USER_GROUP
128
+ - `share_with` in regular Settings (per-user); `SHARE_WITH` in Admin Settings (system-wide default)
129
+ - Per-user setting takes priority; falls back to admin setting if empty
130
+ - Model shared after creation, liveboard shared after creation
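The share-type detection and setting precedence above can be sketched as follows (function names are illustrative, not the real `share_objects` internals):

```python
def resolve_share_principal(value):
    """Classify a share target: an email-style value (contains '@') is a
    USER, anything else is assumed to be a USER_GROUP name."""
    return "USER" if "@" in value else "USER_GROUP"


def pick_share_target(user_setting, admin_setting):
    """Per-user share_with wins; fall back to the admin-wide SHARE_WITH."""
    return (user_setting or "").strip() or (admin_setting or "").strip() or None
```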
131
+ - [x] **Sage Indexing Retry** ✅ — `_get_answer_direct` now retries once with 20s wait on 10004 "No answer found"; flag is module-level so the wait happens only once per build run, not per question
132
+ - [x] **Fallback TML: Skip Invalid Column Refs** ✅ — after `convert_natural_to_search`, validates `[Column]` tokens against model columns; skips viz (instead of failing the whole liveboard) if any token is missing
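The token validation can be sketched like this (`find_missing_columns` is a hypothetical name; a non-empty result means the viz gets skipped):

```python
import re


def find_missing_columns(search_query, model_columns):
    """Extract [Column] tokens from a converted search query and return the
    ones absent from the model's columns, so the caller can skip the viz
    instead of failing the whole liveboard."""
    tokens = re.findall(r"\[([^\]]+)\]", search_query)
    return [t for t in tokens if t not in model_columns]
```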
133
+ - [ ] **MCP 500 Retry Logic** — broader retry for other 5xx errors
134
+ - [ ] **Model Generator: Chasm Trap Fix** — when two fact tables share a dimension, model generator must:
135
+ - Include ALL FK joins from each fact table to shared dimensions (e.g. `PRIOR_AUTHORIZATIONS.DRUG_NDC → DRUGS`)
136
+ - Set `is_attribution_dimension: false` on shared dimension tables so TS doesn't fold fact tables together
137
+ - Without this: queries fan out through a shared date dimension → every group gets the same average
138
+ - Fixed manually for Abarca: added `PRIOR_AUTHORIZATIONS → DRUGS` join + `DRUGS.is_attribution_dimension=false`
139
+ - [ ] **Data Narrative Layer for Population** — LLM generates random/flat data because it doesn't know what story the KPIs should tell
140
+ - Root cause: population script gets DDL + company context but NOT the KPI formulas or desired metric distributions
141
+ - Fix: before data generation, build a "data narrative" spec from the vertical×function matrix: explicit per-column constraints ("IS_GENERIC: 93% for Medicaid rows, 80% for Commercial"), outlier targets, trend directions
142
+ - Pass this narrative spec as a required section in the population prompt
143
+ - Domain rules baked in: specialty/biologic drugs get low PA approval, GLP-1s face high scrutiny, Medicaid has highest GDR, etc.
144
+ - Goal: generated data tells the story on first run — no manual Snowflake fixups required
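One possible shape for the narrative spec described above — field names are invented, and the real spec would be derived from the vertical×function matrix before population runs:

```python
from dataclasses import dataclass, field


@dataclass
class DataNarrative:
    """Sketch of a per-build data narrative spec rendered into a required
    section of the population prompt. All field names are illustrative."""
    column_constraints: dict = field(default_factory=dict)
    outlier_targets: list = field(default_factory=list)
    trend_directions: dict = field(default_factory=dict)

    def to_prompt_section(self):
        lines = ["## Data narrative (required)"]
        lines += [f"- {col}: {rule}" for col, rule in self.column_constraints.items()]
        lines += [f"- outlier target: {o}" for o in self.outlier_targets]
        lines += [f"- trend for {metric}: {direction}"
                  for metric, direction in self.trend_directions.items()]
        return "\n".join(lines)
```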
145
+ - [x] **Fix Domain-Specific NAME Column Generation** ✅ — `DRUG_NAME` was falling through to `fake.name()` (person name) because `'NAME' in col_name_upper` matched first, before the drug-specific check; fixed by adding DRUG/MEDICATION check at the top of the NAME block in `chat_interface.py`
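The ordering fix can be sketched as a dispatch where the drug check precedes the generic NAME check (`fake_person`/`fake_drug` stand in for the real Faker-backed generators):

```python
def name_value_for_column(col_name, fake_person, fake_drug):
    """Order matters: check drug-specific patterns before the generic NAME
    substring, or DRUG_NAME falls through to a person name."""
    upper = col_name.upper()
    if "DRUG" in upper or "MEDICATION" in upper:
        return fake_drug()
    if "NAME" in upper:
        return fake_person()
    return None
```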
146
+ - [x] **Abarca Demo Data — KPI Variation** ✅ — GDR and PA Approval Rate KPIs had flat sparklines
147
+ - Root cause: IS_GENERIC set by plan type only (uniform across months); PA rate set by therapeutic class only
148
+ - Fix: `scratch/fix_abarca_kpi_variation.py` — full per-month reset using plan-type + monthly adjustment
149
+ - GDR visible range: 84–93% with clear oscillations; PA visible range: 74–86%
150
+ - **⚠️ Re-run `fix_abarca_kpi_variation.py` if other data changes clobber monthly variation**
151
+ - PA by therapeutic class fix: `scratch/fix_abarca_pa_therapeutic_class.py` + `scratch/fix_abarca_pa_ts_table.py`
152
+ - Root cause of flat PA-by-class chart: no PRIOR_AUTHORIZATIONS→DRUGS join in TS model; added THERAPEUTIC_CLASS column directly to PRIOR_AUTHORIZATIONS table instead
153
+
154
+ ---
155
+
156
+ ## Phase 3 (Future)
157
+
158
+ - [ ] **OAuth/SSO Login** — swap Gradio auth for proper OAuth flow
159
+ - [ ] **Batch Runner Gradio Tab** — after CLI proves out, add Gradio tab for batch testing
160
+ - [ ] **Batch Runner: Full Pipeline Stages** — add population, deploy_snowflake, deploy_thoughtspot, liveboard stages
161
+ - [ ] **Request New Environment Form** — if/when needed
162
+ - [ ] **Liveboard Question Column Mapping** — `liveboard_questions[].viz_question` uses natural language that may not match actual DDL column names. After the model is built, map each question to real column names before sending it to MCP. Currently worked around with generic NL ("average selling price by week" instead of "ASP weekly"), but proper runtime column substitution would be more reliable.
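The precedence rule from the liveboard name fix (Done section below) can be sketched as a one-liner — the final fallback name here is an invented placeholder:

```python
def effective_liveboard_name(ui_value, db_default):
    """The UI field value always wins over the DB-loaded default; fall back
    to a placeholder name only when both are empty."""
    return (ui_value or "").strip() or (db_default or "").strip() or "Demo Liveboard"
```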
163
+
164
+ ---
165
+
166
+ ## Cancelled / Resolved
167
+
168
+ - ~~**MCP Bearer Auth investigation**~~ — resolved; bearer auth working, no further action needed
169
+
170
+ ---
171
+
172
+ ## Done
173
+
174
+ ### Session: March 26, 2026 — Liveboard Name Fix + Settings Reorganization
175
+
176
+ - [x] **Liveboard Name Bug Fixed** ✅ — UI field value now takes priority over DB-loaded default
177
+ - `send_message` and `quick_action` accept `liveboard_name_ui` param
178
+ - `liveboard_name_input` added to `_send_inputs` and `_action_inputs`
179
+ - Applied to `controller.settings['liveboard_name']` on every message — always uses current UI value
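Wait, the liveboard-name sketch belongs with the Phase 3 list above; for this fix the relevant piece is applying the UI value on every message, which can be shown as a tiny setter (names illustrative):

```python
def apply_liveboard_name(settings, liveboard_name_ui):
    """Mirror the send_message/quick_action behavior: on every message the
    current UI field value overwrites the DB-loaded default in settings."""
    value = (liveboard_name_ui or "").strip()
    if value:
        settings["liveboard_name"] = value
    return settings
```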
180
+ - [x] **Settings UI Reorganized** ✅ — Split into "Default Settings" and "App Settings"
181
+ - **Default Settings**: AI Model, Default Use Case, Default Liveboard Name (3-up row)
182
+ - **App Settings**: Tag Name, Fact Table Size, Dim Table Size, Object Naming Prefix, Column Naming Style
183
+
184
+ ---
185
+
186
+ ### Session: March 26, 2026 — New Vision Merge + Pipeline Investigation
187
+
188
+ - [x] **Spotter enable fix verified** ✅ — `spotter_config` placement confirmed correct (nested inside `model.properties`, not sibling). Tested on model `f40ff5bd` via `scratch/test_spotter_enable.py` — Spotter answered (HTTP 200).
189
+
190
+ - [x] **Liveboard pipeline bugs documented** ✅ — Full trace written in `dev_notes/liveboard_flow_amazon_retail.md`
191
+ - **6 KPI root cause**: AI-generated fill questions (slots 5–8) are all time+metric → MCP creates them all as KPI. Fix: cap AI questions to max 2 single-metric.
192
+ - **"Show me" title bug**: `_convert_outlier_to_mcp_question` prepends "Show me"; `_clean_viz_title` strips "Show " leaving "me...". Fix: add `(r'^Show me ', '')` before `(r'^Show ', '')`.
193
+ - **Spotter Viz Story mismatch**: `_generate_spotter_viz_story` never sees actual viz names — generates independently. Fix: pass actual viz names post-build.
194
+ - OutlierPattern fields `sql_template`, `magnitude`, `affected_columns`, `target_filter`, `demo_setup`, `demo_payoff` are all dead — never read anywhere.
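The pattern-order fix for the "Show me" title bug above can be sketched as a simplified stand-in for `_clean_viz_title`:

```python
import re

# Order matters: strip the longer "Show me " prefix before "Show ",
# otherwise "Show me total sales" is left as "me total sales".
TITLE_STRIP_PATTERNS = [r"^Show me ", r"^Show "]


def clean_viz_title(title):
    """Strip the first matching leading prefix, then stop."""
    for pattern in TITLE_STRIP_PATTERNS:
        stripped = re.sub(pattern, "", title)
        if stripped != title:
            return stripped
    return title
```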
195
+
196
+ - [x] **DemoPrep_new_vision2 merge completed** ✅ — New data generation engine with real outlier injection merged into current codebase:
197
+ - `demo_personas.py` — replaced with new version: `DEFAULT_STORY_CONTROLS`, `story_controls` on every vertical/function, Finance/SaaS overrides, `ROUTED_USE_CASES`, merged `get_use_case_config()`
198
+ - `legitdata_project/legitdata/generator.py` — replaced with 1,600-line version: `_refresh_story_spec()`, `_apply_storyspec_time_series()` (actual outlier injection with deterministic seed + trend/seasonal signals), `_generate_saas_finance_gold()`
199
+ - `legitdata_project/legitdata/storyspec.py` — new file: `StorySpec`, `TrendProfile`, `OutlierBudget`, `ValueGuardrails` dataclasses
200
+ - `legitdata_project/legitdata/domain/` — new package: `SemanticType` enum + domain value libraries
201
+ - `legitdata_project/legitdata/quality/` — new package: quality rules, validator, repair
202
+ - 5 updated source files: `column_classifier.py`, `ai_generator.py`, `generic.py`, `fk_manager.py`, `parser.py`
203
+ - `legitdata_project/legitdata/__init__.py` — updated with StorySpec exports
204
+ - Verified end-to-end: seed=1780963166 (deterministic from `amazon.com+Retail Sales`), 1 outlier injected Sept 9 2024 at 2.4x multiplier in SALES_TRANSACTIONS
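The deterministic seeding can be sketched as below — the merged generator's actual hash may differ (CRC32 here will not reproduce the 1780963166 value above), but the reproducibility property is the same:

```python
import zlib


def deterministic_seed(company, use_case):
    """Derive a stable per-build seed from company + use case so outlier
    placement reproduces across runs. CRC32 is one deterministic choice;
    the real generator's derivation is not shown here."""
    return zlib.crc32(f"{company}+{use_case}".encode("utf-8"))
```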
205
+
206
+ ---
207
+
208
+ ## Notes
209
+
210
+ ### Vertical × Function Matrix System
211
+
212
+ The matrix determines what gets built — KPIs, visualizations, outliers, target persona.
213
+ See `dev_notes/plan_march_2026.md` appendix for full documentation.
214
+
215
+ **Current coverage:**
216
+
217
+ | Vertical | Sales | Supply Chain | Marketing |
218
+ |----------|-------|-------------|-----------|
219
+ | Retail | ✅ Override | Base merge | Base merge |
220
+ | Banking | Base merge | Base merge | ✅ Override |
221
+ | Software | ✅ Override | Base merge | Base merge |
222
+ | Manufacturing | Base merge | Base merge | Base merge |
223
+ | *other* | Generic | Generic | Generic |
224
+
225
+ *Override = enriched with persona, extra KPIs, specific viz*
226
+ *Base merge = Vertical + Function combined, no special override*
227
+ *Generic = AI adapts from closest function match*
tests/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # DemoPrep test package
tests/newvision_sample_runner.py ADDED
@@ -0,0 +1,1246 @@
1
+ #!/usr/bin/env python3
2
+ """Run the New Vision 4-case sample set without UI login flow.
3
+
4
+ Modes:
5
+ 1) Full chat pipeline mode (uses configured default_llm and required settings)
6
+ 2) Offline DDL mode (deterministic schema template, still validates settings up front)
7
+
8
+ Usage:
9
+ source ./demoprep/bin/activate
10
+ python tests/newvision_sample_runner.py
11
+ python tests/newvision_sample_runner.py --offline-ddl
12
+ python tests/newvision_sample_runner.py --skip-thoughtspot
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import sys
21
+ from datetime import datetime, timezone
22
+ from pathlib import Path
23
+ from typing import Any
24
+
25
+ import yaml
26
+
27
+ PROJECT_ROOT = Path(__file__).parent.parent
28
+ sys.path.insert(0, str(PROJECT_ROOT))
29
+
30
+ from dotenv import load_dotenv
31
+
32
+ load_dotenv(PROJECT_ROOT / ".env")
33
+ os.environ.setdefault("DEMOPREP_NO_AUTH", "true")
34
+
35
+ # Pull admin settings into environment when available.
36
+ try:
37
+ from supabase_client import inject_admin_settings_to_env
38
+
39
+ inject_admin_settings_to_env()
40
+ except Exception as exc: # noqa: BLE001
41
+ print(f"[newvision_runner] Admin setting injection unavailable: {exc}")
42
+
43
+ OFFLINE_DEMO_DDL = """
44
+ CREATE TABLE DIM_DATE (
45
+ DATE_KEY INT PRIMARY KEY,
46
+ ORDER_DATE DATE,
47
+ MONTH_NAME VARCHAR(30),
48
+ QUARTER_NAME VARCHAR(10),
49
+ YEAR_NUM INT,
50
+ IS_WEEKEND BOOLEAN
51
+ );
52
+
53
+ CREATE TABLE DIM_LOCATION (
54
+ LOCATION_KEY INT PRIMARY KEY,
55
+ COUNTRY VARCHAR(100),
56
+ REGION VARCHAR(100),
57
+ STATE VARCHAR(100),
58
+ CITY VARCHAR(100),
59
+ SALES_CHANNEL VARCHAR(100),
60
+ CUSTOMER_SEGMENT VARCHAR(100)
61
+ );
62
+
63
+ CREATE TABLE DIM_PRODUCT (
64
+ PRODUCT_KEY INT PRIMARY KEY,
65
+ PRODUCT_NAME VARCHAR(200),
66
+ BRAND_NAME VARCHAR(100),
67
+ CATEGORY VARCHAR(100),
68
+ SUB_CATEGORY VARCHAR(100),
69
+ PRODUCT_TIER VARCHAR(50),
70
+ UNIT_PRICE DECIMAL(12,2)
71
+ );
72
+
73
+ CREATE TABLE FACT_RETAIL_DAILY (
74
+ TRANSACTION_KEY INT PRIMARY KEY,
75
+ DATE_KEY INT,
76
+ LOCATION_KEY INT,
77
+ PRODUCT_KEY INT,
78
+ ORDER_DATE DATE,
79
+ ORDER_COUNT INT,
80
+ UNITS_SOLD INT,
81
+ UNIT_PRICE DECIMAL(12,2),
82
+ GROSS_REVENUE DECIMAL(14,2),
83
+ NET_REVENUE DECIMAL(14,2),
84
+ SALES_AMOUNT DECIMAL(14,2),
85
+ DISCOUNT_PCT DECIMAL(5,2),
86
+ INVENTORY_ON_HAND INT,
87
+ LOST_SALES_USD DECIMAL(14,2),
88
+ IS_OOS BOOLEAN,
89
+ FOREIGN KEY (DATE_KEY) REFERENCES DIM_DATE(DATE_KEY),
90
+ FOREIGN KEY (LOCATION_KEY) REFERENCES DIM_LOCATION(LOCATION_KEY),
91
+ FOREIGN KEY (PRODUCT_KEY) REFERENCES DIM_PRODUCT(PRODUCT_KEY)
92
+ );
93
+ """.strip()
94
+
95
+
96
+ def _now_utc_iso() -> str:
97
+ return datetime.now(timezone.utc).isoformat()
98
+
99
+
100
+ def _load_cases(cases_file: Path) -> list[dict[str, Any]]:
101
+ data = yaml.safe_load(cases_file.read_text(encoding="utf-8")) or {}
102
+ return list(data.get("test_cases", []))
103
+
104
+
105
+ def _parse_quality_report_path(message: str) -> str | None:
106
+ for line in (message or "").splitlines():
107
+ if "Report:" in line:
108
+ return line.split("Report:", 1)[1].strip()
109
+ if "See report:" in line:
110
+ return line.split("See report:", 1)[1].strip()
111
+ return None
112
+
113
+
114
+ def _load_quality_report(report_path: str | None) -> dict[str, Any]:
115
+ if not report_path:
116
+ return {}
117
+ path = Path(report_path)
118
+ json_path = path.with_suffix(".json")
119
+ if json_path.exists():
120
+ try:
121
+ return json.loads(json_path.read_text(encoding="utf-8"))
122
+ except Exception: # noqa: BLE001
123
+ return {}
124
+ return {}
125
+
126
+
127
+ def _build_quality_gate_stage(report_path: str | None) -> dict[str, Any]:
128
+ report = _load_quality_report(report_path)
129
+ summary = report.get("summary", {}) if isinstance(report, dict) else {}
130
+ passed = report.get("passed") if isinstance(report, dict) else None
131
+ return {
132
+ "ok": bool(passed) if passed is not None else False,
133
+ "report_path": report_path,
134
+ "passed": passed,
135
+ "summary": {
136
+ "semantic_pass_ratio": summary.get("semantic_pass_ratio"),
137
+ "categorical_junk_count": summary.get("categorical_junk_count"),
138
+ "fk_orphan_count": summary.get("fk_orphan_count"),
139
+ "temporal_violations": summary.get("temporal_violations"),
140
+ "numeric_violations": summary.get("numeric_violations"),
141
+ "volatility_breaches": summary.get("volatility_breaches"),
142
+ "smoothness_score": summary.get("smoothness_score"),
143
+ "outlier_explainability": summary.get("outlier_explainability"),
144
+ "kpi_consistency": summary.get("kpi_consistency"),
145
+ },
146
+ }
147
+
148
+
149
+ def _resolve_runtime_settings() -> tuple[str, str]:
150
+ user_email = (
151
+ os.getenv("USER_EMAIL")
152
+ or os.getenv("INITIAL_USER")
153
+ or os.getenv("THOUGHTSPOT_ADMIN_USER")
154
+ or "default@user.com"
155
+ ).strip()
156
+ default_llm = (os.getenv("DEFAULT_LLM") or os.getenv("OPENAI_MODEL") or "").strip()
157
+ if not default_llm:
158
+ raise ValueError("Missing required env var: DEFAULT_LLM or OPENAI_MODEL")
159
+ return user_email, default_llm
160
+
161
+
162
+ def _run_realism_sanity_checks(schema_name: str, case: dict[str, Any]) -> dict[str, Any]:
163
+ """Fast, opinionated sanity checks for demo realism.
164
+
165
+ These checks intentionally target user-visible demo breakages that can slip
166
+ through structural quality gates (e.g., null dimensions in top-N charts).
167
+ """
168
+ checks: list[dict[str, Any]] = []
169
+ failures: list[str] = []
170
+ if not schema_name:
171
+ return {"ok": False, "checks": checks, "failures": ["Missing schema name for sanity checks"]}
172
+
173
+ use_case = str(case.get("use_case", "") or "")
174
+ use_case_lower = (use_case or "").lower()
175
+ case_name = str(case.get("name", "") or "").lower()
176
+ is_legal = "legal" in use_case_lower
177
+ is_private_equity = any(
178
+ marker in use_case_lower
179
+ for marker in ("private equity", "lp reporting", "state street")
180
+ )
181
+ is_saas_finance = any(
182
+ marker in use_case_lower
183
+ for marker in ("saas finance", "unit economics", "financial analytics", "fp&a", "fpa")
184
+ )
185
+ if not is_legal and not is_private_equity and not is_saas_finance:
186
+ # Keep runtime fast by evaluating only scoped vertical checks.
187
+ return {"ok": True, "checks": checks, "failures": []}
188
+
189
+ from supabase_client import inject_admin_settings_to_env
190
+ from snowflake_auth import get_snowflake_connection
191
+
192
+ inject_admin_settings_to_env()
193
+ conn = None
194
+ cur = None
195
+ try:
196
+ db_name = (os.getenv("SNOWFLAKE_DATABASE") or "DEMOBUILD").strip()
197
+ safe_schema = schema_name.replace('"', "")
198
+ conn = get_snowflake_connection()
199
+ cur = conn.cursor()
200
+ cur.execute(f'USE DATABASE "{db_name}"')
201
+ cur.execute(f'USE SCHEMA "{safe_schema}"')
202
+
203
+ if is_legal:
204
+ cur.execute("SHOW TABLES")
205
+ legal_tables = {str(row[1]).upper() for row in cur.fetchall()}
206
+ has_split_legal = {"LEGAL_MATTERS", "OUTSIDE_COUNSEL_INVOICES", "ATTORNEYS", "MATTER_TYPES"}.issubset(legal_tables)
207
+ has_event_legal = "LEGAL_SPEND_EVENTS" in legal_tables
208
+
209
+ if has_split_legal:
210
+ # 1) Invoice -> matter -> attorney join coverage.
211
+ cur.execute(
212
+ """
213
+ SELECT
214
+ COUNT(*) AS total_rows,
215
+ COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows
216
+ FROM OUTSIDE_COUNSEL_INVOICES oci
217
+ LEFT JOIN LEGAL_MATTERS lm ON oci.MATTER_ID = lm.MATTER_ID
218
+ LEFT JOIN ATTORNEYS a ON lm.ASSIGNED_ATTORNEY_ID = a.ATTORNEY_ID
219
+ """
220
+ )
221
+ total_rows, null_rows = cur.fetchone()
222
+ null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
223
+ checks.append(
224
+ {
225
+ "name": "legal_attorney_join_null_pct",
226
+ "value": round(null_pct, 2),
227
+ "threshold": "<= 5.0",
228
+ "ok": null_pct <= 5.0,
229
+ }
230
+ )
231
+ if null_pct > 5.0:
232
+ failures.append(
233
+ f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)"
234
+ )
235
+
236
+ # 1b) Invoice MATTER_ID linkage must be complete.
237
+ cur.execute(
238
+ """
239
+ SELECT
240
+ COUNT(*) AS total_rows,
241
+ COUNT_IF(MATTER_ID IS NULL) AS null_rows
242
+ FROM OUTSIDE_COUNSEL_INVOICES
243
+ """
244
+ )
245
+ total_rows, null_rows = cur.fetchone()
246
+ null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
247
+ checks.append(
248
+ {
249
+ "name": "legal_invoice_matter_id_null_pct",
250
+ "value": round(null_pct, 2),
251
+ "threshold": "== 0.0",
252
+ "ok": null_pct == 0.0,
253
+ }
254
+ )
255
+ if null_pct != 0.0:
256
+ failures.append(
257
+ f"Invoice MATTER_ID null rate is {null_pct:.2f}% (expected 0%)"
258
+ )
259
+
260
+ # 2) Region cardinality should be compact for legal executive demos.
261
+ cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_MATTERS WHERE REGION IS NOT NULL")
262
+ region_cardinality = int(cur.fetchone()[0] or 0)
263
+ checks.append(
264
+ {
265
+ "name": "legal_region_distinct_count",
266
+ "value": region_cardinality,
267
+ "threshold": "<= 6",
268
+ "ok": region_cardinality <= 6,
269
+ }
270
+ )
271
+ if region_cardinality > 6:
272
+ failures.append(
273
+ f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)"
274
+ )
275
+
276
+ # 3) Firm names should not contain obvious cross-vertical banking/org jargon.
277
+ cur.execute(
278
+ """
279
+ SELECT COUNT(*)
280
+ FROM OUTSIDE_COUNSEL
281
+ WHERE REGEXP_LIKE(
282
+ LOWER(FIRM_NAME),
283
+ 'retail banking|consumer lending|digital channels|enterprise operations|regional service'
284
+ )
285
+ """
286
+ )
287
+ bad_firm_count = int(cur.fetchone()[0] or 0)
288
+ checks.append(
289
+ {
290
+ "name": "legal_firm_name_cross_vertical_count",
291
+ "value": bad_firm_count,
292
+ "threshold": "== 0",
293
+ "ok": bad_firm_count == 0,
294
+ }
295
+ )
296
+ if bad_firm_count != 0:
297
+ failures.append(
298
+ f"Detected {bad_firm_count} cross-vertical/non-legal firm names"
299
+ )
300
+
301
+ # 4) Matter type taxonomy should remain concise and demo-friendly.
302
+ cur.execute(
303
+ """
304
+ SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME)
305
+ FROM LEGAL_MATTERS lm
306
+ LEFT JOIN MATTER_TYPES mt ON lm.MATTER_TYPE_ID = mt.MATTER_TYPE_ID
307
+ WHERE mt.MATTER_TYPE_NAME IS NOT NULL
308
+ """
309
+ )
310
+ matter_type_cardinality = int(cur.fetchone()[0] or 0)
311
+ checks.append(
312
+ {
313
+ "name": "legal_matter_type_distinct_count",
314
+ "value": matter_type_cardinality,
315
+ "threshold": "<= 15",
316
+ "ok": matter_type_cardinality <= 15,
317
+ }
318
+ )
319
+ if matter_type_cardinality > 15:
320
+ failures.append(
321
+ f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)"
322
+ )
323
+ elif has_event_legal:
324
+ # 1) Attorney dimension join coverage (critical for "Top Attorney by Cost").
325
+ cur.execute(
326
+ """
327
+ SELECT
328
+ COUNT(*) AS total_rows,
329
+ COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows
330
+ FROM LEGAL_SPEND_EVENTS lse
331
+ LEFT JOIN ATTORNEYS a ON lse.ATTORNEY_ID = a.ATTORNEY_ID
332
+ """
333
+ )
334
+ total_rows, null_rows = cur.fetchone()
335
+ null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
336
+ checks.append(
337
+ {
338
+ "name": "legal_attorney_join_null_pct",
339
+ "value": round(null_pct, 2),
340
+ "threshold": "<= 5.0",
341
+ "ok": null_pct <= 5.0,
342
+ }
343
+ )
344
+ if null_pct > 5.0:
345
+ failures.append(
346
+ f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)"
347
+ )
348
+
349
+ # 2) Region cardinality should be compact for legal executive demos.
350
+ cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_SPEND_EVENTS WHERE REGION IS NOT NULL")
351
+ region_cardinality = int(cur.fetchone()[0] or 0)
352
+ checks.append(
353
+ {
354
+ "name": "legal_region_distinct_count",
355
+ "value": region_cardinality,
356
+ "threshold": "<= 6",
357
+ "ok": region_cardinality <= 6,
358
+ }
359
+ )
360
+ if region_cardinality > 6:
361
+ failures.append(
362
+ f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)"
363
+ )
364
+
365
+ # 3) Firm names should not contain obvious cross-vertical banking/org jargon.
366
+ cur.execute(
367
+ """
368
+ SELECT COUNT(*)
369
+ FROM OUTSIDE_COUNSEL_FIRMS
370
+ WHERE REGEXP_LIKE(
371
+ LOWER(FIRM_NAME),
372
+ 'retail banking|consumer lending|digital channels|enterprise operations|regional service'
373
+ )
374
+ """
375
+ )
376
+ bad_firm_count = int(cur.fetchone()[0] or 0)
377
+ checks.append(
378
+ {
379
+ "name": "legal_firm_name_cross_vertical_count",
380
+ "value": bad_firm_count,
381
+ "threshold": "== 0",
382
+ "ok": bad_firm_count == 0,
383
+ }
384
+ )
385
+ if bad_firm_count != 0:
386
+ failures.append(
387
+ f"Detected {bad_firm_count} cross-vertical/non-legal firm names"
388
+ )
389
+
390
+ # 4) Matter type taxonomy should remain concise and demo-friendly.
391
+ cur.execute(
392
+ """
393
+ SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME)
394
+ FROM LEGAL_SPEND_EVENTS lse
395
+ LEFT JOIN MATTER_TYPES mt ON lse.MATTER_TYPE_ID = mt.MATTER_TYPE_ID
396
+ WHERE mt.MATTER_TYPE_NAME IS NOT NULL
397
+ """
398
+ )
399
+ matter_type_cardinality = int(cur.fetchone()[0] or 0)
400
+ checks.append(
401
+ {
402
+ "name": "legal_matter_type_distinct_count",
403
+ "value": matter_type_cardinality,
404
+ "threshold": "<= 15",
405
+ "ok": matter_type_cardinality <= 15,
406
+ }
407
+ )
408
+ if matter_type_cardinality > 15:
409
+ failures.append(
410
+ f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)"
411
+ )
412
+ else:
413
+ failures.append("Could not find supported legal schema shape for realism checks")
414
+
415
+ if is_private_equity:
416
+ # Guard against semantic leakage where sector/strategy dimensions are
417
+ # accidentally populated with company names.
418
+ cur.execute(
419
+ """
420
+ WITH dim_companies AS (
421
+ SELECT DISTINCT COMPANY_NAME
422
+ FROM PORTFOLIO_COMPANIES
423
+ WHERE COMPANY_NAME IS NOT NULL
424
+ ),
425
+ dim_sectors AS (
426
+ SELECT DISTINCT SECTOR_NAME
427
+ FROM SECTORS
428
+ WHERE SECTOR_NAME IS NOT NULL
429
+ ),
430
+ dim_strategies AS (
431
+ SELECT DISTINCT FUND_STRATEGY
432
+ FROM FUNDS
433
+ WHERE FUND_STRATEGY IS NOT NULL
434
+ )
435
+ SELECT
436
+ (SELECT COUNT(*) FROM dim_sectors),
437
+ (SELECT COUNT(*) FROM dim_strategies),
438
+ (SELECT COUNT(*) FROM dim_sectors s JOIN dim_companies c ON s.SECTOR_NAME = c.COMPANY_NAME),
439
+ (SELECT COUNT(*) FROM dim_strategies f JOIN dim_companies c ON f.FUND_STRATEGY = c.COMPANY_NAME)
440
+ """
441
+ )
442
+ sector_distinct, strategy_distinct, sector_overlap, strategy_overlap = cur.fetchone()
443
+ sector_distinct = int(sector_distinct or 0)
444
+ strategy_distinct = int(strategy_distinct or 0)
445
+ sector_overlap = int(sector_overlap or 0)
446
+ strategy_overlap = int(strategy_overlap or 0)
447
+
448
+ checks.append(
449
+ {
450
+ "name": "pe_sector_name_company_overlap_count",
451
+ "value": sector_overlap,
452
+ "threshold": "== 0",
453
+ "ok": sector_overlap == 0,
454
+ }
455
+ )
456
+ if sector_overlap != 0:
457
+ failures.append(
458
+ f"Sector names overlap company names ({sector_overlap} overlaps); likely mislabeled dimensions"
459
+ )
460
+
461
+ checks.append(
462
+ {
463
+ "name": "pe_fund_strategy_company_overlap_count",
464
+ "value": strategy_overlap,
465
+ "threshold": "== 0",
466
+ "ok": strategy_overlap == 0,
467
+ }
468
+ )
469
+ if strategy_overlap != 0:
470
+ failures.append(
471
+ f"Fund strategy values overlap company names ({strategy_overlap} overlaps); likely mislabeled dimensions"
472
+ )
473
+
474
+ checks.append(
475
+ {
476
+ "name": "pe_sector_distinct_count",
477
+ "value": sector_distinct,
478
+ "threshold": ">= 4 and <= 20",
479
+ "ok": 4 <= sector_distinct <= 20,
480
+ }
481
+ )
482
+ if not (4 <= sector_distinct <= 20):
483
+ failures.append(
484
+ f"Sector distinct count out of expected demo range: {sector_distinct} (expected 4-20)"
485
+ )
486
+
487
+ checks.append(
488
+ {
489
+ "name": "pe_fund_strategy_distinct_count",
490
+ "value": strategy_distinct,
491
+ "threshold": ">= 4 and <= 20",
492
+ "ok": 4 <= strategy_distinct <= 20,
493
+ }
494
+ )
495
+ if not (4 <= strategy_distinct <= 20):
496
+ failures.append(
497
+ f"Fund strategy distinct count out of expected demo range: {strategy_distinct} (expected 4-20)"
498
+ )
499
+
500
+ if case_name == "statestreet_private_equity_lp_reporting":
501
+ cur.execute(
502
+ """
503
+ SELECT
504
+ COUNT(*) AS total_rows,
505
+ COUNT_IF(ABS(TOTAL_VALUE_USD - (REPORTED_VALUE_USD + DISTRIBUTIONS_USD)) > 0.01) AS bad_rows
506
+ FROM PORTFOLIO_PERFORMANCE
507
+ """
508
+ )
509
+ total_rows, bad_rows = cur.fetchone()
510
+ total_rows = int(total_rows or 0)
511
+ bad_rows = int(bad_rows or 0)
512
+ identity_ok = total_rows > 0 and bad_rows == 0
513
+ checks.append(
514
+ {
515
+ "name": "pe_total_value_identity_bad_rows",
516
+ "value": bad_rows,
517
+ "threshold": "== 0",
518
+ "ok": identity_ok,
519
+ }
520
+ )
521
+ if not identity_ok:
522
+ failures.append(
523
+ f"Total value identity broken in {bad_rows} PE fact rows"
524
+ )
525
+
526
+ cur.execute(
527
+ """
528
+ SELECT
529
+ COUNT(*) AS total_rows,
530
+ COUNT_IF(IRR_SUB_LINE_IMPACT_BPS BETWEEN 80 AND 210) AS in_band_rows,
531
+ COUNT_IF(ABS(IRR_SUB_LINE_IMPACT_BPS - ((GROSS_IRR - GROSS_IRR_WITHOUT_SUB_LINE) * 10000)) <= 5) AS identity_rows
532
+ FROM PORTFOLIO_PERFORMANCE
533
+ """
534
+ )
535
+ total_rows, in_band_rows, identity_rows = cur.fetchone()
536
+ total_rows = int(total_rows or 0)
537
+ in_band_rows = int(in_band_rows or 0)
538
+ identity_rows = int(identity_rows or 0)
539
+ irr_band_ok = total_rows > 0 and in_band_rows == total_rows and identity_rows == total_rows
540
+ checks.append(
541
+ {
542
+ "name": "pe_subscription_line_impact_rows_valid",
543
+ "value": {"total": total_rows, "in_band": in_band_rows, "identity": identity_rows},
544
+ "threshold": "all rows in 80-210 bps band and identity holds",
545
+ "ok": irr_band_ok,
546
+ }
547
+ )
548
+ if not irr_band_ok:
549
+ failures.append("Subscription line impact rows do not consistently satisfy PE IRR delta rules")
550
+
551
+ cur.execute(
552
+ """
553
+ SELECT
554
+ COUNT(*) AS apex_rows,
555
+ MAX(pp.IRR_SUB_LINE_IMPACT_BPS) AS apex_max_bps,
556
+ (
557
+ SELECT MAX(IRR_SUB_LINE_IMPACT_BPS)
558
+ FROM PORTFOLIO_PERFORMANCE
559
+ ) AS overall_max_bps
560
+ FROM PORTFOLIO_PERFORMANCE pp
561
+ JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
562
+ WHERE LOWER(pc.COMPANY_NAME) = 'apex industrial solutions'
563
+ """
564
+ )
565
+ apex_rows, apex_max_bps, overall_max_bps = cur.fetchone()
566
+ apex_ok = int(apex_rows or 0) > 0 and apex_max_bps is not None and abs(float(apex_max_bps) - 210.0) <= 1.0 and overall_max_bps is not None and abs(float(overall_max_bps) - 210.0) <= 1.0
567
+ checks.append(
+                {
+                    "name": "pe_apex_subscription_line_outlier",
+                    "value": {"rows": int(apex_rows or 0), "apex_max_bps": apex_max_bps, "overall_max_bps": overall_max_bps},
+                    "threshold": "Apex exists and max impact == 210 bps",
+                    "ok": apex_ok,
+                }
+            )
+            if not apex_ok:
+                failures.append("Apex Industrial Solutions outlier is missing or not set to the expected 210 bps impact")
+
+            cur.execute(
+                """
+                WITH covenant_exceptions AS (
+                    SELECT
+                        LOWER(pc.COMPANY_NAME) AS company_name,
+                        LOWER(pp.COVENANT_STATUS) AS covenant_status,
+                        COUNT(*) AS row_count
+                    FROM PORTFOLIO_PERFORMANCE pp
+                    JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
+                    WHERE LOWER(pp.COVENANT_STATUS) <> 'compliant'
+                    GROUP BY 1, 2
+                )
+                SELECT
+                    COUNT_IF(company_name = 'meridian specialty chemicals' AND covenant_status = 'waived') AS meridian_waived_groups,
+                    COUNT_IF(company_name <> 'meridian specialty chemicals' OR covenant_status <> 'waived') AS invalid_groups
+                FROM covenant_exceptions
+                """
+            )
+            meridian_groups, invalid_groups = cur.fetchone()
+            meridian_ok = int(meridian_groups or 0) > 0 and int(invalid_groups or 0) == 0
+            checks.append(
+                {
+                    "name": "pe_meridian_covenant_exception",
+                    "value": {"meridian_groups": int(meridian_groups or 0), "invalid_groups": int(invalid_groups or 0)},
+                    "threshold": "Meridian only, status waived",
+                    "ok": meridian_ok,
+                }
+            )
+            if not meridian_ok:
+                failures.append("Meridian Specialty Chemicals is not the sole waived/non-compliant covenant exception")
+
+            cur.execute(
+                """
+                WITH sector_perf AS (
+                    SELECT
+                        s.SECTOR_NAME,
+                        AVG(pp.ENTRY_EV_EBITDA_MULTIPLE) AS avg_entry_multiple,
+                        AVG(pp.TOTAL_RETURN_MULTIPLE) AS avg_tvpi
+                    FROM PORTFOLIO_PERFORMANCE pp
+                    JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
+                    JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID
+                    GROUP BY 1
+                )
+                SELECT
+                    MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN avg_entry_multiple END) AS tech_entry,
+                    MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN avg_tvpi END) AS tech_tvpi,
+                    MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_entry_multiple END) AS other_entry_max,
+                    MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_tvpi END) AS other_tvpi_max
+                FROM sector_perf
+                """
+            )
+            tech_entry, tech_tvpi, other_entry_max, other_tvpi_max = cur.fetchone()
+            tech_sector_ok = (
+                tech_entry is not None
+                and tech_tvpi is not None
+                and other_entry_max is not None
+                and other_tvpi_max is not None
+                and float(tech_entry) >= float(other_entry_max)
+                and float(tech_tvpi) >= float(other_tvpi_max)
+            )
+            checks.append(
+                {
+                    "name": "pe_technology_sector_leads_multiples",
+                    "value": {
+                        "tech_entry": tech_entry,
+                        "tech_tvpi": tech_tvpi,
+                        "other_entry_max": other_entry_max,
+                        "other_tvpi_max": other_tvpi_max,
+                    },
+                    "threshold": "Technology leads average entry and return multiples",
+                    "ok": tech_sector_ok,
+                }
+            )
+            if not tech_sector_ok:
+                failures.append("Technology sector does not lead entry and return multiples as required by the State Street narrative")
+
+            cur.execute(
+                """
+                WITH vintage_rank AS (
+                    SELECT
+                        VINTAGE_YEAR,
+                        SUM(REPORTED_VALUE_USD) AS total_reported_value,
+                        DENSE_RANK() OVER (ORDER BY SUM(REPORTED_VALUE_USD) DESC) AS value_rank
+                    FROM PORTFOLIO_PERFORMANCE
+                    GROUP BY 1
+                )
+                SELECT LISTAGG(TO_VARCHAR(VINTAGE_YEAR), ',') WITHIN GROUP (ORDER BY VINTAGE_YEAR)
+                FROM vintage_rank
+                WHERE value_rank <= 2
+                """
+            )
+            top_vintages = cur.fetchone()[0] or ""
+            top_vintage_set = {part.strip() for part in str(top_vintages).split(",") if part.strip()}
+            vintage_ok = top_vintage_set == {"2021", "2022"}
+            checks.append(
+                {
+                    "name": "pe_top_vintages_reported_value",
+                    "value": sorted(top_vintage_set),
+                    "threshold": "top 2 vintages are 2021 and 2022",
+                    "ok": vintage_ok,
+                }
+            )
+            if not vintage_ok:
+                failures.append("2021 and 2022 are not the top reported-value vintages")
+
+            cur.execute(
+                """
+                WITH healthcare_quarters AS (
+                    SELECT
+                        DATE_TRUNC('quarter', pp.FULL_DATE) AS quarter_start,
+                        AVG(pp.TOTAL_VALUE_USD) AS avg_total_value
+                    FROM PORTFOLIO_PERFORMANCE pp
+                    JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
+                    JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID
+                    WHERE LOWER(s.SECTOR_NAME) = 'healthcare'
+                    GROUP BY 1
+                )
+                SELECT
+                    MAX(CASE WHEN quarter_start = DATE '2024-07-01' THEN avg_total_value END) AS q3_2024_value,
+                    MAX(CASE WHEN quarter_start = DATE '2024-10-01' THEN avg_total_value END) AS q4_2024_value
+                FROM healthcare_quarters
+                """
+            )
+            q3_2024_value, q4_2024_value = cur.fetchone()
+            healthcare_dip_ok = (
+                q3_2024_value is not None
+                and q4_2024_value is not None
+                and float(q4_2024_value) < float(q3_2024_value)
+            )
+            checks.append(
+                {
+                    "name": "pe_healthcare_q4_2024_dip",
+                    "value": {"q3_2024": q3_2024_value, "q4_2024": q4_2024_value},
+                    "threshold": "Q4 2024 healthcare total value lower than Q3 2024",
+                    "ok": healthcare_dip_ok,
+                }
+            )
+            if not healthcare_dip_ok:
+                failures.append("Healthcare Q4 2024 performance dip is missing")
+
+            cur.execute(
+                """
+                WITH company_trends AS (
+                    SELECT
+                        pc.COMPANY_NAME,
+                        FIRST_VALUE(pp.REVENUE_USD) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_revenue,
+                        LAST_VALUE(pp.REVENUE_USD) OVER (
+                            PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE
+                            ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+                        ) AS last_revenue,
+                        FIRST_VALUE(pp.EBITDA_MARGIN_PCT) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_margin,
+                        LAST_VALUE(pp.EBITDA_MARGIN_PCT) OVER (
+                            PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE
+                            ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+                        ) AS last_margin
+                    FROM PORTFOLIO_PERFORMANCE pp
+                    JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
+                )
+                SELECT COUNT(DISTINCT COMPANY_NAME)
+                FROM company_trends
+                WHERE last_revenue > first_revenue AND last_margin < first_margin
+                """
+            )
+            trend_company_count = int(cur.fetchone()[0] or 0)
+            trend_ok = trend_company_count >= 1
+            checks.append(
+                {
+                    "name": "pe_revenue_up_margin_down_company_count",
+                    "value": trend_company_count,
+                    "threshold": ">= 1",
+                    "ok": trend_ok,
+                }
+            )
+            if not trend_ok:
+                failures.append("No portfolio company shows the required revenue-up / EBITDA-margin-down trend")
+
+        if is_saas_finance:
+            cur.execute("SELECT COUNT(DISTINCT MONTH_KEY) FROM DATES")
+            month_count = int(cur.fetchone()[0] or 0)
+            checks.append(
+                {
+                    "name": "saas_month_count",
+                    "value": month_count,
+                    "threshold": ">= 24",
+                    "ok": month_count >= 24,
+                }
+            )
+            if month_count < 24:
+                failures.append(f"SaaS finance month count too low: {month_count} (expected >= 24)")
+
+            cur.execute("SELECT COUNT(DISTINCT SEGMENT) FROM CUSTOMERS WHERE SEGMENT IS NOT NULL")
+            segment_count = int(cur.fetchone()[0] or 0)
+            checks.append(
+                {
+                    "name": "saas_segment_distinct_count",
+                    "value": segment_count,
+                    "threshold": ">= 3",
+                    "ok": segment_count >= 3,
+                }
+            )
+            if segment_count < 3:
+                failures.append(f"SaaS finance segment count too low: {segment_count} (expected >= 3)")
+
+            cur.execute("SELECT COUNT(DISTINCT REGION) FROM LOCATIONS WHERE REGION IS NOT NULL")
+            region_count = int(cur.fetchone()[0] or 0)
+            checks.append(
+                {
+                    "name": "saas_region_distinct_count",
+                    "value": region_count,
+                    "threshold": ">= 3",
+                    "ok": region_count >= 3,
+                }
+            )
+            if region_count < 3:
+                failures.append(f"SaaS finance region count too low: {region_count} (expected >= 3)")
+
+            cur.execute(
+                """
+                SELECT
+                    COUNT(*) AS total_rows,
+                    COUNT_IF(
+                        ABS(
+                            ENDING_ARR_USD - (
+                                STARTING_ARR_USD + NEW_LOGO_ARR_USD + EXPANSION_ARR_USD
+                                - CONTRACTION_ARR_USD - CHURNED_ARR_USD
+                            )
+                        ) > 1.0
+                    ) AS bad_arr_rows,
+                    COUNT_IF(ABS((MRR_USD * 12.0) - ENDING_ARR_USD) > 12.0) AS bad_mrr_rows
+                FROM SAAS_CUSTOMER_MONTHLY
+                """
+            )
+            total_rows, bad_arr_rows, bad_mrr_rows = cur.fetchone()
+            total_rows = int(total_rows or 0)
+            bad_arr_rows = int(bad_arr_rows or 0)
+            bad_mrr_rows = int(bad_mrr_rows or 0)
+            arr_identity_ok = total_rows > 0 and bad_arr_rows == 0 and bad_mrr_rows == 0
+            checks.append(
+                {
+                    "name": "saas_arr_rollforward_bad_rows",
+                    "value": {"total": total_rows, "bad_arr": bad_arr_rows, "bad_mrr": bad_mrr_rows},
+                    "threshold": "all rows reconcile",
+                    "ok": arr_identity_ok,
+                }
+            )
+            if not arr_identity_ok:
+                failures.append(
+                    f"SaaS finance ARR identities broken (bad_arr={bad_arr_rows}, bad_mrr={bad_mrr_rows})"
+                )
+
+            cur.execute(
+                """
+                WITH month_counts AS (
+                    SELECT CUSTOMER_KEY, COUNT(DISTINCT MONTH_KEY) AS active_months
+                    FROM SAAS_CUSTOMER_MONTHLY
+                    GROUP BY 1
+                )
+                SELECT AVG(active_months), MIN(active_months), MAX(active_months)
+                FROM month_counts
+                """
+            )
+            avg_active_months, min_active_months, max_active_months = cur.fetchone()
+            avg_active_months = float(avg_active_months or 0.0)
+            min_active_months = int(min_active_months or 0)
+            max_active_months = int(max_active_months or 0)
+            density_ok = avg_active_months >= 12.0 and max_active_months >= 20
+            checks.append(
+                {
+                    "name": "saas_customer_month_density",
+                    "value": {
+                        "avg_active_months": round(avg_active_months, 2),
+                        "min_active_months": min_active_months,
+                        "max_active_months": max_active_months,
+                    },
+                    "threshold": "avg >= 12.0 and max >= 20",
+                    "ok": density_ok,
+                }
+            )
+            if not density_ok:
+                failures.append(
+                    f"SaaS finance customer-month density too sparse (avg={avg_active_months:.2f}, max={max_active_months})"
+                )
+
+            cur.execute(
+                """
+                SELECT
+                    COUNT(*) AS total_rows,
+                    COUNT_IF(ABS(TOTAL_S_AND_M_SPEND_USD - (SALES_SPEND_USD + MARKETING_SPEND_USD)) > 1.0) AS bad_rows
+                FROM SALES_MARKETING_SPEND_MONTHLY
+                """
+            )
+            spend_total_rows, bad_spend_rows = cur.fetchone()
+            spend_total_rows = int(spend_total_rows or 0)
+            bad_spend_rows = int(bad_spend_rows or 0)
+            spend_ok = spend_total_rows > 0 and bad_spend_rows == 0
+            checks.append(
+                {
+                    "name": "saas_spend_identity_bad_rows",
+                    "value": {"total": spend_total_rows, "bad_rows": bad_spend_rows},
+                    "threshold": "== 0",
+                    "ok": spend_ok,
+                }
+            )
+            if not spend_ok:
+                failures.append(f"SaaS finance spend identity broken in {bad_spend_rows} rows")
+
+    except Exception as exc:  # noqa: BLE001
+        failures.append(f"Realism sanity checks failed to execute: {exc}")
+    finally:
+        try:
+            if cur is not None:
+                cur.close()
+        except Exception:
+            pass
+        try:
+            if conn is not None:
+                conn.close()
+        except Exception:
+            pass
+
+    return {"ok": len(failures) == 0, "checks": checks, "failures": failures}
+
+
+def _run_case_chat(
+    case: dict[str, Any],
+    default_llm: str,
+    user_email: str,
+    skip_thoughtspot: bool = False,
+) -> dict[str, Any]:
+    from chat_interface import ChatDemoInterface
+    from demo_personas import get_use_case_config, parse_use_case
+
+    company = case["company"]
+    use_case = case["use_case"]
+    model = default_llm
+    context = case.get("context", "")
+
+    controller = ChatDemoInterface(user_email=user_email)
+    controller.settings["model"] = model
+    controller.vertical, controller.function = parse_use_case(use_case or "")
+    controller.use_case_config = get_use_case_config(
+        controller.vertical or "Generic",
+        controller.function or "Generic",
+    )
+
+    result: dict[str, Any] = {
+        "name": case.get("name") or f"{company}_{use_case}",
+        "company": company,
+        "use_case": use_case,
+        "mode": "chat",
+        "started_at": _now_utc_iso(),
+        "success": False,
+        "stages": {},
+    }
+
+    stage_start = datetime.now(timezone.utc)
+    last_research = None
+    for update in controller.run_research_streaming(company, use_case, generic_context=context):
+        last_research = update
+    result["stages"]["research"] = {
+        "ok": bool(controller.demo_builder and controller.demo_builder.company_analysis_results),
+        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+        "preview": str(last_research)[:500] if last_research else "",
+    }
+
+    stage_start = datetime.now(timezone.utc)
+    ddl_text = (controller.demo_builder.schema_generation_results or "") if controller.demo_builder else ""
+    if not ddl_text:
+        _, ddl_text = controller.run_ddl_creation()
+    result["stages"]["ddl"] = {
+        "ok": bool(ddl_text and "CREATE TABLE" in ddl_text.upper()),
+        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+        "ddl_length": len(ddl_text or ""),
+    }
+
+    if not result["stages"]["ddl"]["ok"]:
+        result["error"] = "DDL generation failed"
+        result["finished_at"] = _now_utc_iso()
+        return result
+
+    stage_start = datetime.now(timezone.utc)
+    deploy_error = None
+    try:
+        for _ in controller.run_deployment_streaming():
+            pass
+    except Exception as exc:  # noqa: BLE001
+        deploy_error = str(exc)
+    deployed_schema = getattr(controller, "_deployed_schema_name", None)
+    schema_candidate = deployed_schema or getattr(controller, "_last_schema_name", None)
+    result["stages"]["deploy_snowflake"] = {
+        "ok": bool(deployed_schema) and deploy_error is None,
+        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+        "schema": schema_candidate,
+        "error": deploy_error,
+    }
+
+    if schema_candidate:
+        quality_report_path = getattr(controller, "_last_population_quality_report_path", None)
+        result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path)
+        if not result["stages"]["quality_gate"]["ok"]:
+            result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}"
+    elif deploy_error and not result.get("error"):
+        result["error"] = deploy_error
+
+    stage_start = datetime.now(timezone.utc)
+    # Guard with .get(): the quality_gate stage only exists when a schema candidate was found.
+    if result["stages"].get("quality_gate", {}).get("ok"):
+        sanity = _run_realism_sanity_checks(schema_candidate, case)
+        result["stages"]["realism_sanity"] = {
+            "ok": bool(sanity.get("ok")),
+            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+            "checks": sanity.get("checks", []),
+            "failures": sanity.get("failures", []),
+        }
+        if not result["stages"]["realism_sanity"]["ok"] and not result.get("error"):
+            result["error"] = "Realism sanity checks failed before ThoughtSpot deployment"
+
+    if not skip_thoughtspot and deployed_schema and not result.get("error"):
+        stage_start = datetime.now(timezone.utc)
+        ts_ok = True
+        ts_last = None
+        try:
+            for ts_update in controller._run_thoughtspot_deployment(deployed_schema, company, use_case):
+                ts_last = ts_update
+        except Exception as exc:  # noqa: BLE001
+            ts_ok = False
+            ts_last = str(exc)
+        # Some deployment paths return a structured failure payload rather than
+        # raising; treat those as failures so pass/fail reporting is accurate.
+        ts_preview_text = str(ts_last) if ts_last is not None else ""
+        if ts_ok and (
+            "THOUGHTSPOT DEPLOYMENT FAILED" in ts_preview_text.upper()
+            or "MODEL VALIDATION FAILED" in ts_preview_text.upper()
+            or "LIVEBOARD CREATION FAILED" in ts_preview_text.upper()
+        ):
+            ts_ok = False
+        result["stages"]["deploy_thoughtspot"] = {
+            "ok": ts_ok,
+            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+            "preview": ts_preview_text[:1000],
+        }
+
+    result["schema_name"] = schema_candidate
+    result["success"] = all(stage.get("ok") for stage in result["stages"].values())
+    result["finished_at"] = _now_utc_iso()
+    return result
+
+
+def _run_case_offline(
+    case: dict[str, Any],
+    default_llm: str,
+    user_email: str,
+    skip_thoughtspot: bool = False,
+) -> dict[str, Any]:
+    from cdw_connector import SnowflakeDeployer
+    from demo_prep import generate_demo_base_name
+    from legitdata_bridge import populate_demo_data
+    from thoughtspot_deployer import deploy_to_thoughtspot
+
+    company = case["company"]
+    use_case = case["use_case"]
+
+    result: dict[str, Any] = {
+        "name": case.get("name") or f"{company}_{use_case}",
+        "company": company,
+        "use_case": use_case,
+        "mode": "offline_ddl",
+        "started_at": _now_utc_iso(),
+        "success": False,
+        "stages": {},
+        "ddl_template": "offline_star_schema_v1",
+    }
+
+    deployer = SnowflakeDeployer()
+
+    # 1) Snowflake schema + DDL deploy
+    stage_start = datetime.now(timezone.utc)
+    ok, msg = deployer.connect()
+    if not ok:
+        result["stages"]["snowflake_connect"] = {
+            "ok": False,
+            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+            "message": msg,
+        }
+        result["error"] = msg
+        result["finished_at"] = _now_utc_iso()
+        return result
+
+    base_name = generate_demo_base_name("", company)
+    ok, schema_name, ddl_msg = deployer.create_demo_schema_and_deploy(base_name, OFFLINE_DEMO_DDL)
+    result["stages"]["snowflake_ddl"] = {
+        "ok": ok,
+        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+        "schema": schema_name,
+        "message": ddl_msg,
+    }
+    if not ok or not schema_name:
+        result["error"] = ddl_msg
+        result["finished_at"] = _now_utc_iso()
+        return result
+
+    # 2) Data population via LegitData
+    stage_start = datetime.now(timezone.utc)
+    pop_ok, pop_msg, pop_results = populate_demo_data(
+        ddl_content=OFFLINE_DEMO_DDL,
+        company_url=company,
+        use_case=use_case,
+        schema_name=schema_name,
+        llm_model=default_llm,
+        user_email=user_email,
+        size="medium",
+    )
+    result["stages"]["populate_data"] = {
+        "ok": pop_ok,
+        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+        "rows": pop_results,
+        "quality_report": _parse_quality_report_path(pop_msg),
+    }
+    if not pop_ok:
+        result["error"] = pop_msg
+        result["finished_at"] = _now_utc_iso()
+        return result
+
+    quality_report_path = _parse_quality_report_path(pop_msg)
+    result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path)
+    if not result["stages"]["quality_gate"]["ok"]:
+        result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}"
+        result["schema_name"] = schema_name
+        result["finished_at"] = _now_utc_iso()
+        return result
+
+    stage_start = datetime.now(timezone.utc)
+    sanity = _run_realism_sanity_checks(schema_name, case)
+    result["stages"]["realism_sanity"] = {
+        "ok": bool(sanity.get("ok")),
+        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+        "checks": sanity.get("checks", []),
+        "failures": sanity.get("failures", []),
+    }
+    if not result["stages"]["realism_sanity"]["ok"]:
+        result["error"] = "Realism sanity checks failed before ThoughtSpot deployment"
+        result["schema_name"] = schema_name
+        result["finished_at"] = _now_utc_iso()
+        return result
+
+    # 3) ThoughtSpot model + liveboard
+    if not skip_thoughtspot:
+        stage_start = datetime.now(timezone.utc)
+        ts_result = deploy_to_thoughtspot(
+            ddl=OFFLINE_DEMO_DDL,
+            database=os.getenv("SNOWFLAKE_DATABASE", "DEMOBUILD"),
+            schema=schema_name,
+            base_name=base_name,
+            connection_name=f"{base_name}_conn",
+            company_name=company,
+            use_case=use_case,
+            llm_model=default_llm,
+        )
+        result["stages"]["deploy_thoughtspot"] = {
+            "ok": bool(ts_result and not ts_result.get("errors")),
+            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
+            "result": ts_result,
+        }
+
+    result["schema_name"] = schema_name
+    result["success"] = all(stage.get("ok") for stage in result["stages"].values())
+    result["finished_at"] = _now_utc_iso()
+    return result
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Run New Vision sample set")
+    parser.add_argument(
+        "--cases-file",
+        default="tests/newvision_test_cases.yaml",
+        help="Path to YAML test case file",
+    )
+    parser.add_argument(
+        "--skip-thoughtspot",
+        action="store_true",
+        help="Run through data generation only and skip ThoughtSpot object creation",
+    )
+    parser.add_argument(
+        "--offline-ddl",
+        action="store_true",
+        help="Force offline DDL mode (no LLM dependency)",
+    )
+    args = parser.parse_args()
+    user_email, default_llm = _resolve_runtime_settings()
+    from startup_validation import validate_required_pipeline_settings_or_raise
+
+    validate_required_pipeline_settings_or_raise(
+        default_llm=default_llm,
+        require_thoughtspot=not args.skip_thoughtspot,
+        require_snowflake=True,
+    )
+
+    cases_file = (PROJECT_ROOT / args.cases_file).resolve()
+    cases = _load_cases(cases_file)
+
+    run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+    out_dir = PROJECT_ROOT / "results" / "newvision_samples" / run_id
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    use_offline = bool(args.offline_ddl)
+    print(f"Mode: {'offline_ddl' if use_offline else 'chat'}", flush=True)
+    print(f"default_llm: {default_llm}", flush=True)
+
+    results = []
+    for idx, case in enumerate(cases, start=1):
+        print(f"\n[{idx}/{len(cases)}] {case.get('name', case['company'])} -> {case['use_case']}", flush=True)
+        try:
+            if use_offline:
+                case_result = _run_case_offline(
+                    case,
+                    default_llm=default_llm,
+                    user_email=user_email,
+                    skip_thoughtspot=args.skip_thoughtspot,
+                )
+            else:
+                case_result = _run_case_chat(
+                    case,
+                    default_llm=default_llm,
+                    user_email=user_email,
+                    skip_thoughtspot=args.skip_thoughtspot,
+                )
+        except Exception as exc:  # noqa: BLE001
+            case_result = {
+                "name": case.get("name") or f"{case['company']}_{case['use_case']}",
+                "company": case["company"],
+                "use_case": case["use_case"],
+                "mode": "offline_ddl" if use_offline else "chat",
+                "started_at": _now_utc_iso(),
+                "finished_at": _now_utc_iso(),
+                "success": False,
+                "error": f"Runner exception: {exc}",
+                "stages": {
+                    "runner_exception": {
+                        "ok": False,
+                        "message": str(exc),
+                    }
+                },
+            }
+        results.append(case_result)
+        (out_dir / f"{case_result['name']}.json").write_text(
+            json.dumps(case_result, indent=2),
+            encoding="utf-8",
+        )
+        print(f"  success={case_result['success']} schema={case_result.get('schema_name')}", flush=True)
+
+    summary = {
+        "run_id": run_id,
+        "mode": "offline_ddl" if use_offline else "chat",
+        "cases_file": str(cases_file),
+        "total": len(results),
+        "passed": sum(1 for r in results if r.get("success")),
+        "failed": sum(1 for r in results if not r.get("success")),
+        "results": results,
+    }
+    (out_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
+
+    print("\nSaved sample artifacts:", out_dir)
+    print(f"Passed: {summary['passed']} / {summary['total']}")
+
+    if summary["failed"]:
+        raise SystemExit(1)
+
+
+if __name__ == "__main__":
+    main()
tests/newvision_test_cases.yaml ADDED
@@ -0,0 +1,46 @@
+# New Vision 4-case review set (no-login runner)
+
+test_cases:
+  - name: "amazon_defined"
+    company: "amazon.com"
+    use_case: "Sales Analytics"
+    model: "claude-sonnet-4.5"
+    tags: [defined, retail, baseline]
+
+  - name: "chase_defined"
+    company: "chase.com"
+    use_case: "Financial Analytics"
+    model: "claude-sonnet-4.5"
+    tags: [defined, banking]
+
+  - name: "nike_defined"
+    company: "nike.com"
+    use_case: "Retail Analytics"
+    model: "claude-sonnet-4.5"
+    tags: [defined, retail]
+
+  - name: "bayer_generic"
+    company: "bayer.com"
+    use_case: "Prescription Performance and Sales Impact"
+    model: "claude-sonnet-4.5"
+    tags: [generic, pharma]
+    context: |
+      Primary business driver:
+      Prescription growth and revenue performance of a branded pharmaceutical product.
+
+      Frequent changes and initiatives:
+      - Price adjustments and contract negotiations
+      - Reimbursement and formulary access changes
+      - Sales force targeting and promotional activity
+      - Launch or modification of patient access programs
+
+      Highly segmented by:
+      - Geography and sales territory
+      - Payer and formulary status
+      - Prescriber specialty and prescribing volume
+      - Customer segment (high vs low prescribers, new vs existing writers)
+
+      Business questions:
+      1) How did new and total prescriptions for Drug X change in Germany after the price and formulary update, compared to the previous month?
+      2) Break this down by payer status and prescriber segment.
+      3) Did prescribing frequency or sales volume change among physicians who did not increase prescribing after the update?
tests/newvision_test_cases_2.yaml ADDED
@@ -0,0 +1,278 @@
1
+ test_cases:
2
+ - name: "thoughtspot_legal_ops_spend"
3
+ company: "thoughtspot.com"
4
+ use_case: "Legal Operations and Spend Management"
5
+ model: "claude-sonnet-4.5"
6
+ tags: [legal, operations, spend, generic]
7
+ context: |
8
+ Target persona: General Counsel / VP of Legal Operations.
9
+
10
+ Build a schema with exactly 2 fact tables and 4 dimension tables.
11
+
12
+ Fact table 1: LEGAL_MATTERS (workload/matter management)
13
+ Columns:
14
+ - matter_id
15
+ - matter_key (example: LGL-6691)
16
+ - summary
17
+ - request_type (Sales, NDA, Procurement, Alliance, Marketing, Other)
18
+ - priority (P1, P2, P3)
19
+ - status (New, In Progress, Done)
20
+ - assigned_attorney_id (FK)
21
+ - reporter_name
22
+ - created_date
23
+ - updated_date
24
+ - resolved_date
25
+ - matter_type_id (FK to MATTER_TYPES)
26
+ - region (US East, US West, EMEA, APAC)
27
+ - is_customer_deal (boolean)
28
+ - sfdc_opportunity_value
29
+
30
+ Fact table 2: OUTSIDE_COUNSEL_INVOICES (spend management)
31
+ Columns:
32
+ - invoice_id
33
+ - outside_counsel_id (FK)
34
+ - matter_id (FK to LEGAL_MATTERS)
35
+ - matter_number (business reference only; do not use as the relationship key)
36
+ - description
37
+ - services_amount
38
+ - costs_amount
39
+ - total_amount
40
+ - invoice_date
41
+ - invoice_month
42
+ - payment_status (Paid, Pending, Overdue)
43
+
44
+ IMPORTANT RELATIONSHIP RULE:
45
+ - MATTER_TYPE belongs to LEGAL_MATTERS through matter_type_id.
46
+ - OUTSIDE_COUNSEL_INVOICES must link to LEGAL_MATTERS through matter_id.
47
+ - Do not make invoices carry an independent matter_type truth.
48
+
49
+ Dimension table 1: ATTORNEYS
50
+ - 100 in-house attorneys
51
+ - attorney_id, attorney_name, title (Associate GC, Senior Counsel, Paralegal, VP Legal, GC),
52
+ department (Commercial, IP, Employment, Corporate, Compliance), hire_date, manager_name
53
+
54
+ Dimension table 2: OUTSIDE_COUNSEL
55
+ - 5 firms: Cooley LLP, Young Basile, Baker McKenzie, McDermott Will & Emery, Newman Du Wors
56
+ - firm_id, firm_name, specialty (IP, Corporate, Employment, Commercial, Regulatory),
57
+ billing_model (Hourly, Flat Fee, Blended), avg_hourly_rate
58
+
59
+ Dimension table 3: MATTER_TYPES
60
+ - NDA, Sales Agreement, Procurement, Alliance, Marketing, Patent Filing, Trademark Application,
61
+ Litigation, Employment, Corporate Governance
62
+
63
+ Dimension table 4: DATE_DIM
64
+ - Standard date dimension
65
+
66
+ Outlier patterns to include:
67
+ - November spend spike about 40% higher than other months, driven by 4 new patent filings and litigation surge
68
+ - Travis Guerre spends disproportionate time on NDA reviews versus peers
69
+ - One attorney should have significantly longer completion times on Sales matters
70
+ - Cooley LLP should be highest-spend firm, followed by Baker McKenzie
71
+ - Q4 should show higher volume across all request types
72
+
73
+ Data volume:
74
+ - About 2,000 legal matters over 24 months
75
+ - About 500 invoices across the 5 firms
76
+
77
+ Key Spotter questions to support:
78
+ - Why was spend so high in November?
79
+ - What were the top 3 outside counsels by spend?
80
+ - What does Travis spend most of his time on?
81
+ - Average time to close NDA requests
82
+ - Which attorneys have the longest completion times for Sales matters?
83
+
84
+ QUALITY GUARDRAILS (required):
85
+ - Keep month-over-month change for core spend and volume measures within +/-20% in normal months.
86
+ - Allow large changes only for named events (November spend spike and Q4 lift), with at most 2 outlier periods.
87
+ - Enforce arithmetic consistency: total_amount = services_amount + costs_amount.
88
+ - Keep payment_status and request_type values strictly within listed domains.
89
+ - Keep trend lines smooth and realistic, not random jagged spikes.
90
+
91
+ - name: "thoughtspot_saas_finance_defined"
92
+ company: "thoughtspot.com"
93
+ use_case: "Financial Analytics"
94
+ model: "claude-sonnet-4.5"
95
+ tags: [saas, finance, arr, defined]
96
+ context: |
97
+ Build a SaaS finance story for a CFO/VP Finance persona.
98
+
99
+ Focus metrics:
100
+ - ARR, MRR, NRR, Gross Revenue Retention, churn_rate
101
+ - expansion_arr, contraction_arr, new_logo_arr
102
+ - CAC, LTV, payback_months
103
+ - billings, collections, deferred_revenue
104
+
105
+ Data expectations:
106
+ - 24 months of monthly history
107
+ - 1,500-2,500 customer-account records
108
+ - Enterprise, Mid-Market, SMB segments
109
+ - Regions: North America, EMEA, APAC
110
+ - Required dimensions: DATES, CUSTOMERS, PRODUCTS, LOCATIONS
111
+ - Required facts: SAAS_CUSTOMER_MONTHLY, SALES_MARKETING_SPEND_MONTHLY
112
+ - SAAS_CUSTOMER_MONTHLY grain must be one row per customer per month
113
+ - SALES_MARKETING_SPEND_MONTHLY grain must be one row per month per segment per region
114
+ - SALES_MARKETING_SPEND_MONTHLY must carry NRR_PCT, GRR_PCT, NET_NEW_ARR_USD, and CAC_USD
115
+
116
+ Outlier patterns to include:
117
+ - Q4 ARR uplift from enterprise expansion
118
+ - July SMB churn spike and temporary NRR dip
119
+ - EMEA collections lag increases DSO
120
+ - One segment with unusually strong upsell conversion
121
+
122
+ Key Spotter questions:
123
+ - Why did ARR jump in Q4?
124
+ - Which segment is driving the most churn?
125
+ - What are the top drivers of NRR change month over month?
126
+ - Where are collections delays impacting cash flow?
127
+
128
+ QUALITY GUARDRAILS (required):
129
+ - Keep ARR/MRR/NRR time series smooth with normal month-over-month movement within +/-20%.
130
+ - Restrict major variance to named events only (Q4 ARR uplift and July SMB churn spike), max 2 outlier periods.
131
+ - Preserve metric relationships: ending_arr approximately prior_arr + expansion_arr + new_logo_arr - churned_arr - contraction_arr.
132
+ - Preserve metric relationships: mrr approximately ending_arr / 12.
133
+ - Preserve spend relationships: total_s_and_m_spend = sales_spend + marketing_spend.
134
+ - Preserve derived KPI relationships: nrr and grr should reconcile to the monthly ARR movement components for each segment/region.
135
+ - Keep categorical fields in-domain only (segment and region lists above).
136
+ - Avoid synthetic or nonsense category values.
+
+ - name: "datadog_saas_finance_generic"
+ company: "datadog.com"
+ use_case: "SaaS Finance and Unit Economics"
+ model: "claude-sonnet-4.5"
+ tags: [saas, finance, unit_economics, generic]
+ context: |
+ Create a generic SaaS finance and unit economics demo for FP&A and Finance Ops.
+
+ Include realistic SaaS business entities:
+ - Customers/accounts, subscriptions, invoices, payments, usage
+ - Product lines (core platform, security, observability, add-ons)
+ - Segments, geo regions, and contract terms
+ - Required dimensions: DATES, CUSTOMERS, PRODUCTS, LOCATIONS
+ - Required facts: SAAS_CUSTOMER_MONTHLY, SALES_MARKETING_SPEND_MONTHLY
+ - SAAS_CUSTOMER_MONTHLY grain must be one row per customer per month
+ - SALES_MARKETING_SPEND_MONTHLY grain must be one row per month per segment per region
+ - SALES_MARKETING_SPEND_MONTHLY must carry NRR_PCT, GRR_PCT, NET_NEW_ARR_USD, and CAC_USD
+
+ Core metrics:
+ - ARR, MRR, NRR, GRR
+ - churned_arr, expansion_arr, net_new_arr
+ - CAC, LTV, gross_margin, payback_period
+ - billed_vs_collected, deferred_revenue
+
+ Outlier patterns to include:
+ - November billings spike due to annual prepay renewals
+ - One product line has margin compression for two quarters
+ - APAC has stronger net new ARR growth than other regions
+ - SMB churn rises while enterprise expansion offsets total ARR
+
+ Key Spotter questions:
+ - Why were billings unusually high in November?
+ - Which product line has the weakest margin trend?
+ - What is driving NRR by segment?
+ - Which region contributes most to net new ARR?
+
+ QUALITY GUARDRAILS (required):
+ - Keep monthly finance series smooth with normal month-over-month movement within +/-20%.
+ - Allow major movement only for named events (November billings spike, two-quarter margin compression), max 2 outlier periods.
+ - Use strict arithmetic consistency where relevant (for example billed_vs_collected and deferred revenue behavior over time).
+ - Preserve metric relationships: ending_arr approximately prior_arr + expansion_arr + new_logo_arr - churned_arr - contraction_arr.
+ - Preserve metric relationships: mrr approximately ending_arr / 12.
+ - Preserve spend relationships: total_s_and_m_spend = sales_spend + marketing_spend.
+ - Preserve derived KPI relationships: nrr and grr should reconcile to the monthly ARR movement components for each segment/region.
+ - Use only realistic categorical values for product lines, segments, and regions.
+ - For organization entities, prefer ACCOUNT_NAME or COMPANY_NAME style values; avoid person-like names in org columns.
+
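Among the unit-economics metrics listed above, CAC payback is the one most easily sanity-checked against the generated spend and MRR figures. A sketch of the standard formula, under the assumption that payback is measured against gross-margin-adjusted new MRR (metric names are illustrative, not the generator's schema):

```python
def cac_payback_months(cac_usd, new_mrr_usd, gross_margin_pct):
    """Months to recover customer acquisition cost from the
    gross-margin-adjusted MRR of the acquired customer.

    gross_margin_pct is expressed 0-100. Returns inf when there is
    no margin-positive MRR to recover against.
    """
    margin_mrr = new_mrr_usd * gross_margin_pct / 100.0
    if margin_mrr <= 0:
        return float("inf")
    return cac_usd / margin_mrr

print(round(cac_payback_months(12_000, 1_000, 80), 1))  # 15.0
```

A validator could assert that generated payback_period values stay within a plausible SaaS band (roughly 6-36 months) given CAC, MRR, and gross_margin in the same row.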
+ - name: "statestreet_private_equity_lp_reporting"
+ company: "statestreet.com"
+ use_case: "Private Equity Portfolio Analytics for LP Reporting"
+ model: "claude-sonnet-4.5"
+ tags: [private_equity, lp_reporting, portfolio_analytics, statestreet]
+ context: |
+ Schema Requirements - 1 fact table and 5 dimension tables:
+
+ FACT TABLE: PORTFOLIO_PERFORMANCE (quarterly reporting data per portfolio company per quarter)
+ - record_id
+ - fund_id (FK)
+ - company_id (FK)
+ - quarter_date
+ - vintage_year
+ - invested_capital
+ - reported_value (NAV)
+ - distributions
+ - total_value (reported_value + distributions)
+ - gross_irr
+ - net_irr
+ - gross_irr_without_sub_line
+ - irr_sub_line_impact_bps
+ - total_return_multiple (TVPI)
+ - dpi_multiple
+ - rvpi_multiple
+ - revenue
+ - ebitda
+ - net_debt
+ - entry_ev_ebitda_multiple
+ - current_ev_ebitda_multiple
+ - revenue_growth_pct
+ - ebitda_margin_pct
+ - debt_to_ebitda_ratio
+ - covenant_status (Compliant, Waived, Breached)
+ - debt_performance_status (Performing, Watch List, Non-Performing)
+
+ DIM TABLE 1: FUNDS
+ - fund_id
+ - fund_name ("Alpha Private Equity Fund III, L.P." and 2-3 others)
+ - fund_manager
+ - fund_size
+ - fund_currency
+ - fund_strategy (Buyout, Growth, Venture)
+ - inception_date
+
+ DIM TABLE 2: PORTFOLIO_COMPANIES
+ - company_id
+ - company_name (10 realistic PE-backed companies like "Apex Industrial Solutions", "Meridian Specialty Chemicals", "NovaTech Data Systems")
+ - sector (Industrials, Healthcare, Technology, Consumer, Financial Services)
+ - sub_sector
+ - headquarters_state
+ - investment_date
+ - investment_stage (Active, Realized, Written Off)
+ - board_seats_held
+
+ DIM TABLE 3: SECTORS
+ - sector_id
+ - sector_name
+ - sector_category (Cyclical, Defensive, Growth)
+
+ DIM TABLE 4: LP_INVESTORS
+ - investor_id
+ - investor_name (pension funds, endowments, sovereign wealth)
+ - investor_type (Pension Fund, Endowment, Family Office, Sovereign Wealth, Insurance)
+ - commitment_amount
+ - aum_total
+
+ DIM TABLE 5: DATE_DIM
+ - standard quarterly date dimension
+
+ Outlier patterns to bake in (critical for demo narrative):
+ - Subscription line impact: all companies should have gross_irr_without_sub_line that is 80-210 bps lower than gross_irr. Apex Industrial Solutions should have the highest impact at exactly 210 bps.
+ - Meridian Specialty Chemicals should have covenant_status = 'Waived' as the only company with a non-compliant status.
+ - Technology sector investments made at higher entry multiples (12-15x EV/EBITDA) should have the highest total return multiples (2.5-3.8x TVPI).
+ - 2021 and 2022 vintage years should hold the most reported value.
+ - Q4 2024 should show a performance dip in Healthcare sector.
+ - One company should show revenue growing but EBITDA margin declining.
+
+ Data volume:
+ - 10 portfolio companies x 12 quarters (3 years) = ~120 rows in fact table
+ - Keep it small and realistic (private equity reporting model)
+ - 5 LP investors, 3 funds
+
+ Target persona:
+ - Portfolio Manager at a $50B Pension Fund (LP client using the embedded experience)
+
+ Key Spotter questions the demo MUST support (in this order):
+ - show me my total invested and reported value by sector
+ - now by vintage year
+ - compare gross irr with and without subscription facility for my active investments
+ - which company has the highest irr impact from the sub line?
+ - show me the trend of revenue vs ebitda for Apex Industrial Solutions since investment
+ - what is the covenant status for my active investments with debt status performing?
+ - show me the relationship between entry ev/ebitda multiple and total return multiple by sector
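The subscription-line outlier above is a hard constraint (80-210 bps for every company, with Apex pinned at exactly 210 bps), so it is a natural target for the quality validator. A minimal sketch of that check, using hypothetical row values since IRR percentages and bps conversion are the only assumptions:

```python
def sub_line_impact_bps(gross_irr_pct, gross_irr_without_sub_line_pct):
    """IRR uplift attributable to the subscription line, in basis points.

    IRRs are percentages (e.g. 18.4 for 18.4%); 1 percentage point = 100 bps.
    Rounding absorbs float noise in the subtraction.
    """
    return round((gross_irr_pct - gross_irr_without_sub_line_pct) * 100)

# Hypothetical generated rows, consistent with the spec's 80-210 bps band.
rows = [
    {"company": "Apex Industrial Solutions", "gross_irr": 18.4, "without": 16.3},
    {"company": "Meridian Specialty Chemicals", "gross_irr": 14.1, "without": 13.0},
]
impacts = {r["company"]: sub_line_impact_bps(r["gross_irr"], r["without"]) for r in rows}
print(impacts)  # {'Apex Industrial Solutions': 210, 'Meridian Specialty Chemicals': 110}
```

The validator would assert every impact lies in [80, 210], that the maximum belongs to Apex Industrial Solutions, and that the fact table's irr_sub_line_impact_bps column matches this derived value.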
tests/test_mcp_liveboard.py ADDED
@@ -0,0 +1,54 @@
+ #!/usr/bin/env python3
+ """
+ Quick test script for MCP liveboard creation
+ """
+ from liveboard_creator import create_liveboard_from_model_mcp
+ from thoughtspot_deployer import ThoughtSpotDeployer
+ from dotenv import load_dotenv
+ import os
+
+ load_dotenv()
+
+ # Create ThoughtSpot client
+ ts_client = ThoughtSpotDeployer(
+     base_url=os.getenv('THOUGHTSPOT_URL'),
+     username=os.getenv('THOUGHTSPOT_USERNAME'),
+     password=os.getenv('THOUGHTSPOT_PASSWORD')
+ )
+
+ # Test with the model you specified
+ model_name = "GEN_02031143_YWV_mdl"
+ model_id = "36de83be-e284-4d76-98cf-330902ba1973"
+
+ company_data = {
+     'name': 'Test Company',
+     'website': 'test.com'
+ }
+
+ print(f"Testing MCP liveboard creation with model: {model_name}")
+ print(f"Model ID: {model_id}")
+ print("-" * 60)
+
+ result = create_liveboard_from_model_mcp(
+     ts_client=ts_client,
+     model_id=model_id,
+     model_name=model_name,
+     company_data=company_data,
+     use_case="Retail Sales",
+     num_visualizations=3,  # Just 3 to test quickly
+     liveboard_name="Test Liveboard MCP",
+     llm_model="claude-sonnet-4"
+ )
+
+ print("-" * 60)
+ print(f"Result: {result}")
+
+ if result.get('success'):
+     print(f"✅ SUCCESS!")
+     print(f"   Liveboard: {result.get('liveboard_name')}")
+     print(f"   GUID: {result.get('liveboard_guid')}")
+     print(f"   URL: {result.get('liveboard_url')}")
+ else:
+     print(f"❌ FAILED: {result.get('error')}")
+
thoughtspot_deployer.py CHANGED
@@ -749,7 +749,10 @@ class ThoughtSpotDeployer:
             'columns': [],
             'properties': {
                 'is_bypass_rls': False,
-                'join_progressive': True
+                'join_progressive': True,
+                'spotter_config': {
+                    'is_spotter_enabled': True
+                }
             }
         }
     }
@@ -904,7 +907,10 @@ class ThoughtSpotDeployer:
             'columns': [],
             'properties': {
                 'is_bypass_rls': False,
-                'join_progressive': True
+                'join_progressive': True,
+                'spotter_config': {
+                    'is_spotter_enabled': True
+                }
             }
         }
     }
@@ -1456,7 +1462,10 @@ class ThoughtSpotDeployer:
             'worksheet_columns': [],  # Adding back - but with GUID references
             'properties': {
                 'is_bypass_rls': False,
-                'join_progressive': True
+                'join_progressive': True,
+                'spotter_config': {
+                    'is_spotter_enabled': True
+                }
             }
         }
     }
@@ -1731,6 +1740,49 @@ class ThoughtSpotDeployer:
             print(f"[ThoughtSpot] ⚠️ Tag assignment error: {str(e)}", flush=True)
             return False
 
+    def share_objects(self, object_guids: List[str], object_type: str, share_with: str) -> bool:
+        """
+        Share ThoughtSpot objects with a user or group (can_edit / MODIFY).
+
+        Args:
+            object_guids: GUIDs to share
+            object_type: 'LOGICAL_TABLE' for models/tables, 'LIVEBOARD' for liveboards
+            share_with: user email (contains '@') or group name
+        """
+        if not share_with or not object_guids:
+            return True
+
+        principal_type = "USER" if '@' in share_with else "USER_GROUP"
+
+        try:
+            response = self.session.post(
+                f"{self.base_url}/api/rest/2.0/security/metadata/share",
+                json={
+                    "permissions": [
+                        {
+                            "principal": {
+                                "identifier": share_with,
+                                "type": principal_type
+                            },
+                            "share_mode": "MODIFY"
+                        }
+                    ],
+                    "metadata": [
+                        {"identifier": guid, "type": object_type}
+                        for guid in object_guids
+                    ]
+                }
+            )
+            if response.status_code in [200, 204]:
+                print(f"[ThoughtSpot] ✅ Shared {len(object_guids)} {object_type} with {principal_type} '{share_with}'", flush=True)
+                return True
+            else:
+                print(f"[ThoughtSpot] ⚠️ Share failed: {response.status_code} - {response.text[:200]}", flush=True)
+                return False
+        except Exception as e:
+            print(f"[ThoughtSpot] ⚠️ Share error: {str(e)}", flush=True)
+            return False
+
     def _generate_demo_names(self, company_name: str = None, use_case: str = None):
         """Generate standardized demo names using DM convention"""
         from datetime import datetime
@@ -1764,10 +1816,10 @@ class ThoughtSpotDeployer:
     }
 
     def deploy_all(self, ddl: str, database: str, schema: str, base_name: str,
-                   connection_name: str = None, company_name: str = None,
-                   use_case: str = None, liveboard_name: str = None,
-                   llm_model: str = None, tag_name: str = None,
-                   liveboard_method: str = None,
+                   connection_name: str = None, company_name: str = None,
+                   use_case: str = None, liveboard_name: str = None,
+                   llm_model: str = None, tag_name: str = None,
+                   liveboard_method: str = None, share_with: str = None,
                    progress_callback=None) -> Dict:
         """
         Deploy complete data model to ThoughtSpot
@@ -2169,18 +2221,48 @@ class ThoughtSpotDeployer:
                 log_progress(f"Assigning tag '{tag_name}' to model...")
                 self.assign_tags_to_objects([model_guid], 'LOGICAL_TABLE', tag_name)
 
+                # Share model
+                _effective_share = share_with or get_admin_setting('SHARE_WITH', required=False)
+                if _effective_share:
+                    log_progress(f"Sharing model with '{_effective_share}'...")
+                    self.share_objects([model_guid], 'LOGICAL_TABLE', _effective_share)
+
-                # Step 3.5: Enable Spotter on the model via API
+                # Step 3.5: Enable Spotter via TML update
+                # create_new=True import ignores spotter_config — must export then re-import to set it
                 try:
-                    enable_response = self.session.post(
-                        f"{self.base_url}/api/rest/2.0/metadata/sage/enable",
+                    export_resp = self.session.post(
+                        f"{self.base_url}/api/rest/2.0/metadata/tml/export",
                         json={
-                            "metadata_identifiers": [model_guid]
+                            "metadata": [{"identifier": model_guid, "type": "LOGICAL_TABLE"}],
+                            "export_associated": False,
+                            "format_type": "YAML"
                         }
                     )
-                    if enable_response.status_code == 200:
-                        log_progress(f"🤖 Spotter enabled")
+                    if export_resp.status_code == 200:
+                        export_data = export_resp.json()
+                        if export_data and 'edoc' in export_data[0]:
+                            model_tml_dict = json.loads(export_data[0]['edoc'])
+                            # Set spotter_config inside properties (correct location per golden demo TML)
+                            model_tml_dict.setdefault('model', {}).setdefault('properties', {})['spotter_config'] = {'is_spotter_enabled': True}
+                            updated_tml = yaml.dump(model_tml_dict, allow_unicode=True, sort_keys=False)
+                            update_resp = self.session.post(
+                                f"{self.base_url}/api/rest/2.0/metadata/tml/import",
+                                json={
+                                    "metadata_tmls": [updated_tml],
+                                    "import_policy": "ALL_OR_NONE",
+                                    "create_new": False
+                                }
+                            )
+                            if update_resp.status_code == 200:
+                                log_progress(f"🤖 Spotter enabled on model")
+                            else:
+                                log_progress(f"🤖 Spotter enable update failed: HTTP {update_resp.status_code} — {update_resp.text[:200]}")
+                        else:
+                            log_progress(f"🤖 Spotter enable: export returned no edoc")
+                    else:
+                        log_progress(f"🤖 Spotter enable: export failed HTTP {export_resp.status_code}")
                 except Exception as spotter_error:
-                    pass  # Not critical
+                    log_progress(f"🤖 Spotter enable exception: {spotter_error}")
 
                 # Step 4: Auto-create Liveboard from model
                 lb_start = time.time()
@@ -2216,40 +2298,37 @@ class ThoughtSpotDeployer:
                         for col in columns_list:
                             model_columns.append(col)
 
-                        # Get outlier patterns from the vertical×function system
+                        # Get liveboard questions from the vertical×function config
                         outlier_dicts = []
                         try:
-                            from outlier_system import get_outliers_for_use_case
-                            from demo_personas import parse_use_case
+                            from demo_personas import parse_use_case, get_use_case_config
                            uc_vertical, uc_function = parse_use_case(use_case or '')
-                            if uc_vertical or uc_function:
-                                outlier_config = get_outliers_for_use_case(
-                                    uc_vertical or "Generic",
-                                    uc_function or "Generic"
-                                )
-                                # Convert OutlierPattern objects to dicts for MCP function
-                                for op in outlier_config.required:
-                                    outlier_dicts.append({
-                                        'title': op.name,
-                                        'insight': op.viz_talking_point,
-                                        'viz_type': op.viz_type,
-                                        'show_me_query': op.viz_question,
-                                        'kpi_companion': True,
-                                        'spotter_questions': op.spotter_questions,
-                                    })
-                                for op in outlier_config.optional[:2]:  # Include up to 2 optional
-                                    outlier_dicts.append({
-                                        'title': op.name,
-                                        'insight': op.viz_talking_point,
-                                        'viz_type': op.viz_type,
-                                        'show_me_query': op.viz_question,
-                                        'kpi_companion': False,
-                                        'spotter_questions': op.spotter_questions,
-                                    })
-                                if outlier_dicts:
-                                    log_progress(f"  [MCP] Using {len(outlier_dicts)} outlier patterns from {uc_vertical}×{uc_function}")
+                            uc_config = get_use_case_config(uc_vertical or "Generic", uc_function or "Generic")
+                            lq = uc_config.get("liveboard_questions", [])
+                            required_qs = [q for q in lq if q.get("required")]
+                            optional_qs = [q for q in lq if not q.get("required")]
+                            for q in required_qs:
+                                outlier_dicts.append({
+                                    'title': q['title'],
+                                    'insight': q.get('insight', ''),
+                                    'viz_type': q['viz_type'],
+                                    'show_me_query': q['viz_question'],
+                                    'kpi_companion': True,
+                                    'spotter_questions': q.get('spotter_qs', []),
+                                })
+                            for q in optional_qs[:2]:
+                                outlier_dicts.append({
+                                    'title': q['title'],
+                                    'insight': q.get('insight', ''),
+                                    'viz_type': q['viz_type'],
+                                    'show_me_query': q['viz_question'],
+                                    'kpi_companion': False,
+                                    'spotter_questions': q.get('spotter_qs', []),
+                                })
+                            if outlier_dicts:
+                                log_progress(f"  [MCP] Using {len(outlier_dicts)} liveboard questions from {uc_vertical}×{uc_function}")
                         except Exception as outlier_err:
-                            log_progress(f"  [MCP] Outlier loading skipped: {outlier_err}")
+                            log_progress(f"  [MCP] Liveboard questions loading skipped: {outlier_err}")
 
                         # MCP creates liveboard
                         if method == 'HYBRID':
@@ -2265,7 +2344,7 @@ class ThoughtSpotDeployer:
                                 model_name=model_name,
                                 company_data=company_data,
                                 use_case=use_case or 'General Analytics',
-                                num_visualizations=8,
+                                num_visualizations=10,
                                 liveboard_name=liveboard_name,
                                 outliers=outlier_dicts if outlier_dicts else None,
                                 llm_model=llm_model,
@@ -2290,7 +2369,7 @@ class ThoughtSpotDeployer:
                                 model_name=model_name,
                                 company_data=company_data,
                                 use_case=use_case or 'General Analytics',
-                                num_visualizations=8,
+                                num_visualizations=10,
                                 liveboard_name=liveboard_name,
                                 llm_model=llm_model,
                                 outliers=outlier_dicts if outlier_dicts else None,
@@ -2343,6 +2422,12 @@ class ThoughtSpotDeployer:
                     if tag_name and liveboard_result.get('liveboard_guid'):
                         log_progress(f"Assigning tag '{tag_name}' to liveboard...")
                         self.assign_tags_to_objects([liveboard_result['liveboard_guid']], 'PINBOARD_ANSWER_BOOK', tag_name)
+
+                    # Share liveboard
+                    _effective_share = share_with or get_admin_setting('SHARE_WITH', required=False)
+                    if _effective_share and liveboard_result.get('liveboard_guid'):
+                        log_progress(f"Sharing liveboard with '{_effective_share}'...")
+                        self.share_objects([liveboard_result['liveboard_guid']], 'LIVEBOARD', _effective_share)
                 else:
                     error = f"Liveboard creation failed: {liveboard_result.get('error', 'Unknown error')}"
                     print(f"❌ DEBUG: Liveboard creation failed! Error: {error}")