""" Smart Data Adjuster Understands liveboard and schema context; handles conversational, multi-turn data adjustment requests in natural language. Connects to: - ThoughtSpot (to load liveboard viz context) - Snowflake (to query and update data) Works with any configured LLM (Claude, GPT-4, etc.) via litellm. Schema is discovered dynamically — no hardcoded table names. """ from typing import Dict, List, Optional, Tuple from snowflake_auth import get_snowflake_connection from thoughtspot_deployer import ThoughtSpotDeployer import json import re from llm_config import resolve_model_name class SmartDataAdjuster: """ Conversational data adjuster with liveboard context and schema discovery. Usage: adjuster = SmartDataAdjuster(database, schema, liveboard_guid, llm_model) adjuster.connect() adjuster.load_liveboard_context() # Per user message: result = adjuster.handle_message("make webcam revenue 40B") # result: {'type': 'confirmation', 'text': '...', 'pending': {...}} # or {'type': 'result', 'text': '...'} # or {'type': 'error', 'text': '...'} """ def __init__(self, database: str, schema: str, liveboard_guid: str, llm_model: str = None, ts_url: str = None, ts_secret: str = None, prompt_logger=None): self.database = database self.schema = schema self.liveboard_guid = liveboard_guid self.ts_url = (ts_url or "").strip() or None self.ts_secret = (ts_secret or "").strip() or None self.llm_model = (llm_model or "").strip() if not self.llm_model: raise ValueError("SmartDataAdjuster requires llm_model from settings.") self._prompt_logger = prompt_logger self.conn = None self.ts_client = None # Populated by load_liveboard_context() self.liveboard_name: Optional[str] = None self.visualizations: List[Dict] = [] # Populated by _discover_schema() in connect() self.schema_tables: Dict[str, List[str]] = {} # table → [col, ...] self.fact_tables: List[str] = [] # tables likely to be updated self.dimension_tables: Dict[str, str] = {} # table → name_column # ------------------------------------------------------------------ # Connection & schema discovery # ------------------------------------------------------------------ def connect(self): """Connect to Snowflake and ThoughtSpot, then discover schema.""" self.conn = get_snowflake_connection() cursor = self.conn.cursor() cursor.execute(f"USE DATABASE {self.database}") cursor.execute(f'USE SCHEMA "{self.schema}"') self.ts_client = ThoughtSpotDeployer( base_url=self.ts_url or None, secret_key=self.ts_secret or None, ) self.ts_client.authenticate() self._discover_schema() def _discover_schema(self): """Read actual table/column structure from INFORMATION_SCHEMA.""" cursor = self.conn.cursor() cursor.execute(f""" SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE FROM {self.database}.INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = '{self.schema}' ORDER BY TABLE_NAME, ORDINAL_POSITION """) raw: Dict[str, List[Dict]] = {} for table, column, dtype in cursor.fetchall(): raw.setdefault(table, []).append({'name': column, 'type': dtype.upper()}) self.schema_tables = {t: [c['name'] for c in cols] for t, cols in raw.items()} # Heuristic: dimension tables have a _NAME column; fact tables have date + numeric cols for table, cols in raw.items(): col_names = [c['name'] for c in cols] col_types = {c['name']: c['type'] for c in cols} name_cols = [c for c in col_names if c.endswith('_NAME')] num_cols = [c for c in col_names if any(t in col_types.get(c, '') for t in ('NUMBER', 'FLOAT', 'INT', 'DECIMAL', 'NUMERIC'))] date_cols = [c for c in col_names if any(t in col_types.get(c, '') for t in ('DATE', 'TIME', 'TIMESTAMP'))] if name_cols: # Use the first _NAME column as the entity name column self.dimension_tables[table] = name_cols[0] if num_cols and date_cols: self.fact_tables.append(table) # If nothing looks like a fact table, fall back to largest table if not self.fact_tables and self.schema_tables: self.fact_tables = list(self.schema_tables.keys()) def _call_llm(self, prompt: str) -> str: """Call the configured LLM via litellm (supports all providers).""" from prompt_logger import logged_completion model = resolve_model_name(self.llm_model) response = logged_completion( stage="data_adjuster", logger=self._prompt_logger, model=model, messages=[{"role": "user", "content": prompt}], temperature=0, max_tokens=1000, ) return response.choices[0].message.content.strip() # ------------------------------------------------------------------ # Liveboard context # ------------------------------------------------------------------ def load_liveboard_context(self) -> bool: """Load liveboard metadata and visualization list from ThoughtSpot.""" response = self.ts_client.session.post( f"{self.ts_client.base_url}/api/rest/2.0/metadata/search", json={ "metadata": [{"type": "LIVEBOARD", "identifier": self.liveboard_guid}], "include_visualization_headers": True } ) if response.status_code != 200: return False data = response.json()[0] self.liveboard_name = data.get('metadata_name', 'Unknown Liveboard') for viz in data.get('visualization_headers', []): name = viz.get('name', '') if 'note-tile' in name.lower(): continue self.visualizations.append({'id': viz.get('id'), 'name': name}) return bool(self.visualizations) # ------------------------------------------------------------------ # Entity matching (schema-aware) # ------------------------------------------------------------------ def _fuzzy_match(self, target: str, candidates: List[str]) -> Optional[str]: """Return the best matching candidate for target, or None.""" def norm(s): return s.lower().replace(' ', '').replace('-', '').replace('_', '') t = norm(target) t_lower = target.lower() for c in candidates: if c.lower() == t_lower: return c for c in candidates: if norm(c) == t: return c for c in candidates: if t_lower in c.lower() or c.lower() in t_lower: return c for c in candidates: if t in norm(c) or norm(c) in t: return c return None def _find_entity(self, entity_value: str, entity_type_hint: str = None) -> Tuple[Optional[str], Optional[str], Optional[str]]: """ Find the closest matching entity name in any dimension table. Returns: (matched_name, table_name, name_column) or (None, None, None) """ cursor = self.conn.cursor() # Sort dimension tables: prefer ones whose name matches entity_type_hint tables_to_try = list(self.dimension_tables.items()) if entity_type_hint: hint = entity_type_hint.lower() tables_to_try.sort(key=lambda x: 0 if hint in x[0].lower() else 1) for table, name_col in tables_to_try: cursor.execute(f'SELECT DISTINCT {name_col} FROM {self.database}."{self.schema}".{table}') candidates = [row[0] for row in cursor.fetchall() if row[0]] match = self._fuzzy_match(entity_value, candidates) if match: return match, table, name_col return None, None, None def _find_fact_join(self, dim_table: str) -> Optional[Tuple[str, str, str]]: """ Find a fact table that has an FK column referencing dim_table. Returns: (fact_table, fact_fk_column, dim_pk_column) or None. """ # Look for an ID column in dim_table dim_cols = self.schema_tables.get(dim_table, []) dim_id = next((c for c in dim_cols if c.endswith('_ID')), None) if not dim_id: return None # Look for that column in fact tables for ft in self.fact_tables: if ft == dim_table: continue ft_cols = self.schema_tables.get(ft, []) if dim_id in ft_cols: return ft, dim_id, dim_id return None # ------------------------------------------------------------------ # Request interpretation # ------------------------------------------------------------------ def _parse_request_simple(self, message: str) -> Optional[Dict]: """ Fast regex-based parser for common patterns: "decrease webcam by 10%", "make laptop 50B", "increase revenue for acme by 20%" Returns parsed dict or None if pattern not matched. """ msg_lower = message.lower() # Percentage match: "by 20%", "-10%" pct_match = re.search(r'by\s+(-?\d+\.?\d*)%|(-?\d+\.?\d*)%', msg_lower) # Absolute value: "50B", "50 billion", "1.5M", "1000000" val_match = re.search(r'(\d+\.?\d*)\s*(b(?:illion)?|m(?:illion)?|k(?:thousand)?)\b', msg_lower) bare_num = re.search(r'\b(\d{4,})\b', message) # bare large integer is_percentage = bool(pct_match) percentage = None target_value = None if pct_match: raw_pct = float(pct_match.group(1) or pct_match.group(2)) if any(w in msg_lower for w in ('decrease', 'reduce', 'lower', 'drop', 'cut')): raw_pct = -abs(raw_pct) percentage = raw_pct elif val_match: num = float(val_match.group(1)) unit = val_match.group(2)[0].lower() multipliers = {'b': 1e9, 'm': 1e6, 'k': 1e3} target_value = num * multipliers.get(unit, 1) elif bare_num: target_value = float(bare_num.group(1)) else: return None # Extract entity name: quoted or after action verb entity = None quoted = re.search(r'"([^"]+)"', message) if quoted: entity = quoted.group(1).strip() else: # "make/set/increase/decrease/adjust [by/to]" action_pat = r'(?:make|set|increase|decrease|reduce|boost|lower|adjust|change)\s+(?:the\s+)?(?:\w+\s+(?:for|of)\s+)?([a-z0-9][\w\s-]*?)(?:\s+(?:by|to|revenue|sales|at)\b|\s+\d|$)' am = re.search(action_pat, msg_lower, re.I) if am: entity = am.group(1).strip() if not entity: return None # Detect entity type from keywords entity_type = None for kw in ('seller', 'vendor', 'customer', 'product', 'item', 'region', 'store'): if kw in msg_lower: entity_type = kw break return { 'entity_value': entity, 'entity_type': entity_type, 'is_percentage': is_percentage, 'percentage': percentage, 'target_value': target_value, 'confidence': 'medium', } def match_request_to_viz(self, user_request: str) -> Optional[Dict]: """ Parse request and enrich with schema context. Returns structured match dict or None. """ result = self._parse_request_simple(user_request) if not result: # Fall back to LLM for complex requests schema_summary = "\n".join( f" {t}: {', '.join(cols[:8])}" + (' ...' if len(cols) > 8 else '') for t, cols in self.schema_tables.items() ) viz_summary = "\n".join(f" {i+1}. {v['name']}" for i, v in enumerate(self.visualizations)) prompt = f"""Parse this data adjustment request. Request: "{user_request}" Snowflake schema tables: {schema_summary} Liveboard visualizations: {viz_summary} Return JSON with these fields (numbers only, not strings): {{ "entity_value": "the entity name to adjust (e.g. '1080p Webcam')", "entity_type": "product|seller|customer|region|null", "is_percentage": true|false, "percentage": , "target_value": , "metric_hint": "keyword like 'revenue', 'sales', 'profit_margin', or null", "confidence": "high|medium|low" }}""" try: raw = self._call_llm(prompt) if raw.startswith('```'): raw = '\n'.join(raw.split('\n')[1:-1]) result = json.loads(raw) except Exception: return None return result if result else None # ------------------------------------------------------------------ # Value retrieval & SQL generation # ------------------------------------------------------------------ def get_current_value(self, entity_value: str, metric_column: str, entity_type: str = None) -> Tuple[float, Optional[str], Optional[str], Optional[str]]: """ Query current aggregate value for entity from Snowflake. Returns: (current_value, matched_entity, dim_table, fact_table) """ matched, dim_table, name_col = self._find_entity(entity_value, entity_type) if not matched: return 0.0, None, None, None join_info = self._find_fact_join(dim_table) if dim_table else None cursor = self.conn.cursor() if join_info: fact_table, fk_col, dim_pk_col = join_info dim_cols = self.schema_tables.get(dim_table, []) dim_pk = next((c for c in dim_cols if c.endswith('_ID')), dim_pk_col) query = f""" SELECT SUM(f.{metric_column}) FROM {self.database}."{self.schema}".{fact_table} f JOIN {self.database}."{self.schema}".{dim_table} d ON f.{fk_col} = d.{dim_pk} WHERE LOWER(d.{name_col}) = LOWER('{matched}') """ else: # entity is directly in the table with the metric query = f""" SELECT SUM({metric_column}) FROM {self.database}."{self.schema}".{dim_table} WHERE LOWER({name_col}) = LOWER('{matched}') """ try: cursor.execute(query) row = cursor.fetchone() value = float(row[0]) if row and row[0] is not None else 0.0 fact_table_used = join_info[0] if join_info else dim_table return value, matched, dim_table, fact_table_used except Exception as e: print(f"[SmartDataAdjuster] get_current_value query failed: {e}") return 0.0, None, None, None def _pick_metric_column(self, metric_hint: str = None) -> Optional[str]: """Choose the best metric column from fact tables based on hint.""" # Build a list of all numeric-looking columns across fact tables candidates = [] for ft in self.fact_tables: for col in self.schema_tables.get(ft, []): if any(kw in col.upper() for kw in ('AMOUNT', 'REVENUE', 'TOTAL', 'SALES', 'VALUE', 'MARGIN', 'PROFIT', 'COST', 'PRICE')): candidates.append((ft, col)) if not candidates: return None if metric_hint: hint_upper = metric_hint.upper() for ft, col in candidates: if hint_upper in col: return col # Default: prefer TOTAL_AMOUNT, REVENUE, then first available for preferred in ('TOTAL_AMOUNT', 'TOTAL_REVENUE', 'REVENUE', 'AMOUNT'): for ft, col in candidates: if col == preferred: return col return candidates[0][1] if candidates else None def generate_strategy(self, entity_value: str, metric_column: str, current_value: float, target_value: float = None, percentage: float = None, entity_type: str = None) -> Dict: """Generate an UPDATE strategy based on the adjustment request.""" matched, dim_table, name_col = self._find_entity(entity_value, entity_type) if not matched: matched = entity_value if percentage is not None: multiplier = 1 + (percentage / 100) pct_change = percentage if target_value is None: target_value = current_value * multiplier elif target_value and current_value > 0: multiplier = target_value / current_value pct_change = (multiplier - 1) * 100 else: multiplier = 1.0 pct_change = 0.0 join_info = self._find_fact_join(dim_table) if dim_table else None if join_info: fact_table, fk_col, _ = join_info dim_cols = self.schema_tables.get(dim_table, []) dim_pk = next((c for c in dim_cols if c.endswith('_ID')), fk_col) sql = f"""UPDATE {self.database}."{self.schema}".{fact_table} SET {metric_column} = {metric_column} * {multiplier:.6f} WHERE {fk_col} IN ( SELECT {dim_pk} FROM {self.database}."{self.schema}".{dim_table} WHERE LOWER({name_col}) = LOWER('{matched}') )""" elif dim_table: sql = f"""UPDATE {self.database}."{self.schema}".{dim_table} SET {metric_column} = {metric_column} * {multiplier:.6f} WHERE LOWER({name_col}) = LOWER('{matched}')""" else: sql = f"-- Could not determine table structure for '{entity_value}'" return { 'id': 'A', 'name': 'Scale All Transactions', 'description': f"Multiply all rows for '{matched}' by {multiplier:.3f}x ({pct_change:+.1f}%)", 'sql': sql, 'matched_entity': matched, 'target_value': target_value, } def present_smart_confirmation(self, match: Dict, current_value: float, strategy: Dict, metric_column: str) -> str: """Format a human-readable confirmation message.""" entity = match.get('entity_value', '?') matched = strategy.get('matched_entity', entity) target = strategy.get('target_value', 0) or 0 if matched.lower() != entity.lower(): entity_display = f"{entity} → **{matched}**" else: entity_display = f"**{matched}**" change = target - current_value pct = (change / current_value * 100) if current_value else 0 lines = [ f"**Liveboard:** {self.liveboard_name}", f"**Entity:** {entity_display}", f"**Metric:** `{metric_column}`", f"**Current:** {current_value:,.0f}", f"**Target:** {target:,.0f} ({change:+,.0f} / {pct:+.1f}%)", f"**Strategy:** {strategy['description']}", "", f"```sql\n{strategy['sql']}\n```", ] if match.get('confidence') == 'low': lines.append("\n⚠️ Low confidence — please verify before confirming.") return "\n".join(lines) # ------------------------------------------------------------------ # SQL execution # ------------------------------------------------------------------ def execute_sql(self, sql: str) -> Dict: """Execute an UPDATE statement. Returns success/error dict.""" cursor = self.conn.cursor() try: cursor.execute(sql) rows_affected = cursor.rowcount self.conn.commit() return {'success': True, 'rows_affected': rows_affected} except Exception as e: try: self.conn.rollback() except Exception: pass return {'success': False, 'error': str(e)} # ------------------------------------------------------------------ # Teardown # ------------------------------------------------------------------ def close(self): """Close Snowflake connection.""" if self.conn: try: self.conn.close() except Exception: pass # --------------------------------------------------------------------------- # Liveboard-first context loader # --------------------------------------------------------------------------- def load_context_from_liveboard(liveboard_guid: str, ts_client) -> dict: """ Resolve Snowflake database/schema from a liveboard GUID. Flow: liveboard TML (export_fqn=True) → model GUID from visualizations[n].answer.tables[0].fqn → model TML → database / schema from model.tables[0].table.{db, schema} Args: liveboard_guid: ThoughtSpot liveboard GUID ts_client: Authenticated ThoughtSpotDeployer instance Returns: dict with keys: liveboard_name, model_guid, model_name, database, schema Raises: ValueError if any step fails to resolve. """ import yaml # Step 1: Export liveboard TML with FQNs response = ts_client.session.post( f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export", json={ "metadata": [{"identifier": liveboard_guid}], "export_associated": False, "export_fqn": True, "format_type": "YAML", } ) if response.status_code != 200: raise ValueError( f"Failed to export liveboard TML ({response.status_code}): {response.text[:300]}" ) tml_data = response.json() if not tml_data: raise ValueError("Empty response from liveboard TML export") lb_tml = yaml.safe_load(tml_data[0]['edoc']) liveboard_name = lb_tml.get('liveboard', {}).get('name', 'Unknown Liveboard') # Step 2: Find model GUID from first visualization with answer.tables[].fqn model_guid = None for viz in lb_tml.get('liveboard', {}).get('visualizations', []): for t in viz.get('answer', {}).get('tables', []): fqn = t.get('fqn') if fqn: model_guid = fqn break if model_guid: break if not model_guid: raise ValueError( "Could not find model GUID in liveboard TML — " "make sure the liveboard has at least one answer-based visualization." ) # Step 3: Export model TML to get database/schema response = ts_client.session.post( f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export", json={ "metadata": [{"identifier": model_guid, "type": "LOGICAL_TABLE"}], "export_associated": False, "export_fqn": True, "format_type": "YAML", } ) if response.status_code != 200: raise ValueError( f"Failed to export model TML ({response.status_code}): {response.text[:300]}" ) tml_data = response.json() model_tml = yaml.safe_load(tml_data[0]['edoc']) model_name = model_tml.get('model', {}).get('name', 'Unknown Model') # Step 4: Extract db/schema from first model table entry tables = model_tml.get('model', {}).get('tables', []) if not tables: raise ValueError("No tables found in model TML") first_table = tables[0].get('table', {}) database = first_table.get('db') schema = first_table.get('schema') if not database or not schema: raise ValueError( f"Could not resolve database/schema from model TML " f"(db={database!r}, schema={schema!r})" ) return { 'liveboard_name': liveboard_name, 'model_guid': model_guid, 'model_name': model_name, 'database': database, 'schema': schema, }