Spaces:
Running
Running
| """ | |
| Smart Data Adjuster | |
| Understands liveboard and schema context; handles conversational, multi-turn | |
| data adjustment requests in natural language. | |
| Connects to: | |
| - ThoughtSpot (to load liveboard viz context) | |
| - Snowflake (to query and update data) | |
| Works with any configured LLM (Claude, GPT-4, etc.) via litellm. | |
| Schema is discovered dynamically β no hardcoded table names. | |
| """ | |
| from typing import Dict, List, Optional, Tuple | |
| from snowflake_auth import get_snowflake_connection | |
| from thoughtspot_deployer import ThoughtSpotDeployer | |
| import json | |
| import re | |
| from llm_config import resolve_model_name | |
class SmartDataAdjuster:
    """
    Conversational data adjuster with liveboard context and schema discovery.

    Usage:
        adjuster = SmartDataAdjuster(database, schema, liveboard_guid, llm_model)
        adjuster.connect()
        adjuster.load_liveboard_context()
        # Per user message:
        result = adjuster.handle_message("make webcam revenue 40B")
        # result: {'type': 'confirmation', 'text': '...', 'pending': {...}}
        # or {'type': 'result', 'text': '...'}
        # or {'type': 'error', 'text': '...'}

    NOTE(review): handle_message is not defined in this class as shown here —
    presumably implemented elsewhere; confirm before relying on the example.
    """

    def __init__(self, database: str, schema: str, liveboard_guid: str,
                 llm_model: str = None, ts_url: str = None, ts_secret: str = None,
                 prompt_logger=None):
        """
        Args:
            database: Snowflake database name.
            schema: Snowflake schema name (used verbatim, quoted in SQL).
            liveboard_guid: ThoughtSpot liveboard GUID supplying viz context.
            llm_model: Required LLM model name; resolved via llm_config.
            ts_url: Optional ThoughtSpot base URL override (blank -> None).
            ts_secret: Optional ThoughtSpot secret key override (blank -> None).
            prompt_logger: Optional logger forwarded to logged_completion.

        Raises:
            ValueError: if llm_model is missing or blank.
        """
        self.database = database
        self.schema = schema
        self.liveboard_guid = liveboard_guid
        self.ts_url = (ts_url or "").strip() or None
        self.ts_secret = (ts_secret or "").strip() or None
        self.llm_model = (llm_model or "").strip()
        if not self.llm_model:
            raise ValueError("SmartDataAdjuster requires llm_model from settings.")
        self._prompt_logger = prompt_logger
        self.conn = None
        self.ts_client = None
        # Populated by load_liveboard_context()
        self.liveboard_name: Optional[str] = None
        self.visualizations: List[Dict] = []
        # Populated by _discover_schema() in connect()
        self.schema_tables: Dict[str, List[str]] = {}   # table -> [col, ...]
        self.fact_tables: List[str] = []                # tables likely to be updated
        self.dimension_tables: Dict[str, str] = {}      # table -> name_column

    # ------------------------------------------------------------------
    # Connection & schema discovery
    # ------------------------------------------------------------------
    def connect(self):
        """Connect to Snowflake and ThoughtSpot, then discover schema."""
        self.conn = get_snowflake_connection()
        cursor = self.conn.cursor()
        cursor.execute(f"USE DATABASE {self.database}")
        cursor.execute(f'USE SCHEMA "{self.schema}"')
        self.ts_client = ThoughtSpotDeployer(
            base_url=self.ts_url or None,
            secret_key=self.ts_secret or None,
        )
        self.ts_client.authenticate()
        self._discover_schema()

    def _discover_schema(self):
        """Read actual table/column structure from INFORMATION_SCHEMA."""
        cursor = self.conn.cursor()
        # Bind the schema name instead of interpolating it into the WHERE clause.
        cursor.execute(f"""
            SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE
            FROM {self.database}.INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = %s
            ORDER BY TABLE_NAME, ORDINAL_POSITION
        """, (self.schema,))
        raw: Dict[str, List[Dict]] = {}
        for table, column, dtype in cursor.fetchall():
            raw.setdefault(table, []).append({'name': column, 'type': dtype.upper()})
        self.schema_tables = {t: [c['name'] for c in cols] for t, cols in raw.items()}
        # Heuristic: dimension tables have a _NAME column; fact tables have date + numeric cols
        for table, cols in raw.items():
            col_names = [c['name'] for c in cols]
            col_types = {c['name']: c['type'] for c in cols}
            name_cols = [c for c in col_names if c.endswith('_NAME')]
            num_cols = [c for c in col_names
                        if any(t in col_types.get(c, '') for t in ('NUMBER', 'FLOAT', 'INT', 'DECIMAL', 'NUMERIC'))]
            date_cols = [c for c in col_names
                         if any(t in col_types.get(c, '') for t in ('DATE', 'TIME', 'TIMESTAMP'))]
            if name_cols:
                # Use the first _NAME column as the entity name column
                self.dimension_tables[table] = name_cols[0]
            if num_cols and date_cols:
                self.fact_tables.append(table)
        # If nothing looks like a fact table, fall back to all tables
        if not self.fact_tables and self.schema_tables:
            self.fact_tables = list(self.schema_tables.keys())

    def _call_llm(self, prompt: str) -> str:
        """Call the configured LLM via litellm (supports all providers)."""
        from prompt_logger import logged_completion
        model = resolve_model_name(self.llm_model)
        response = logged_completion(
            stage="data_adjuster",
            logger=self._prompt_logger,
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=1000,
        )
        return response.choices[0].message.content.strip()

    # ------------------------------------------------------------------
    # Liveboard context
    # ------------------------------------------------------------------
    def load_liveboard_context(self) -> bool:
        """
        Load liveboard metadata and visualization list from ThoughtSpot.

        Returns:
            True if at least one non-note visualization was loaded.
        """
        response = self.ts_client.session.post(
            f"{self.ts_client.base_url}/api/rest/2.0/metadata/search",
            json={
                "metadata": [{"type": "LIVEBOARD", "identifier": self.liveboard_guid}],
                "include_visualization_headers": True
            }
        )
        if response.status_code != 200:
            return False
        # FIX: a 200 with an empty or non-JSON body previously raised
        # IndexError/ValueError here; treat it as "context unavailable".
        try:
            data = response.json()[0]
        except (IndexError, ValueError, KeyError):
            return False
        self.liveboard_name = data.get('metadata_name', 'Unknown Liveboard')
        for viz in data.get('visualization_headers', []):
            name = viz.get('name', '')
            if 'note-tile' in name.lower():
                continue  # note tiles carry no data to adjust
            self.visualizations.append({'id': viz.get('id'), 'name': name})
        return bool(self.visualizations)

    # ------------------------------------------------------------------
    # Entity matching (schema-aware)
    # ------------------------------------------------------------------
    def _fuzzy_match(self, target: str, candidates: List[str]) -> Optional[str]:
        """
        Return the best matching candidate for target, or None.

        Match order (strongest first): exact case-insensitive, normalized
        (spaces/dashes/underscores stripped), substring, normalized substring.
        """
        def norm(s):
            return s.lower().replace(' ', '').replace('-', '').replace('_', '')
        t = norm(target)
        t_lower = target.lower()
        for c in candidates:
            if c.lower() == t_lower:
                return c
        for c in candidates:
            if norm(c) == t:
                return c
        for c in candidates:
            if t_lower in c.lower() or c.lower() in t_lower:
                return c
        for c in candidates:
            if t in norm(c) or norm(c) in t:
                return c
        return None

    def _find_entity(self, entity_value: str, entity_type_hint: str = None) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Find the closest matching entity name in any dimension table.

        Returns: (matched_name, table_name, name_column) or (None, None, None)
        """
        cursor = self.conn.cursor()
        # Sort dimension tables: prefer ones whose name matches entity_type_hint
        tables_to_try = list(self.dimension_tables.items())
        if entity_type_hint:
            hint = entity_type_hint.lower()
            tables_to_try.sort(key=lambda x: 0 if hint in x[0].lower() else 1)
        # Identifiers below come from INFORMATION_SCHEMA, not user input.
        for table, name_col in tables_to_try:
            cursor.execute(f'SELECT DISTINCT {name_col} FROM {self.database}."{self.schema}".{table}')
            candidates = [row[0] for row in cursor.fetchall() if row[0]]
            match = self._fuzzy_match(entity_value, candidates)
            if match:
                return match, table, name_col
        return None, None, None

    def _find_fact_join(self, dim_table: str) -> Optional[Tuple[str, str, str]]:
        """
        Find a fact table that has an FK column referencing dim_table.

        Returns: (fact_table, fact_fk_column, dim_pk_column) or None.
        Assumes the FK in the fact table has the same name as the dimension's
        _ID column (common star-schema convention).
        """
        # Look for an ID column in dim_table
        dim_cols = self.schema_tables.get(dim_table, [])
        dim_id = next((c for c in dim_cols if c.endswith('_ID')), None)
        if not dim_id:
            return None
        # Look for that column in fact tables
        for ft in self.fact_tables:
            if ft == dim_table:
                continue
            ft_cols = self.schema_tables.get(ft, [])
            if dim_id in ft_cols:
                return ft, dim_id, dim_id
        return None

    # ------------------------------------------------------------------
    # Request interpretation
    # ------------------------------------------------------------------
    def _parse_request_simple(self, message: str) -> Optional[Dict]:
        """
        Fast regex-based parser for common patterns:
          "decrease webcam by 10%", "make laptop 50B", "increase revenue for acme by 20%"

        Returns parsed dict or None if pattern not matched.
        """
        msg_lower = message.lower()
        # Percentage match: "by 20%", "-10%"
        pct_match = re.search(r'by\s+(-?\d+\.?\d*)%|(-?\d+\.?\d*)%', msg_lower)
        # Absolute value: "50B", "50 billion", "1.5M", "50k", "50 thousand"
        # FIX: the old pattern `k(?:thousand)?` never matched the word
        # "thousand" on its own; accept "k" or "thousand" as alternatives.
        val_match = re.search(r'(\d+\.?\d*)\s*(b(?:illion)?|m(?:illion)?|k|thousand)\b', msg_lower)
        bare_num = re.search(r'\b(\d{4,})\b', message)  # bare large integer
        is_percentage = bool(pct_match)
        percentage = None
        target_value = None
        if pct_match:
            raw_pct = float(pct_match.group(1) or pct_match.group(2))
            # Words like "decrease" imply a negative delta even if the number is positive
            if any(w in msg_lower for w in ('decrease', 'reduce', 'lower', 'drop', 'cut')):
                raw_pct = -abs(raw_pct)
            percentage = raw_pct
        elif val_match:
            num = float(val_match.group(1))
            unit = val_match.group(2)[0].lower()  # 'b', 'm', 'k', or 't'(housand)
            multipliers = {'b': 1e9, 'm': 1e6, 'k': 1e3, 't': 1e3}
            target_value = num * multipliers.get(unit, 1)
        elif bare_num:
            target_value = float(bare_num.group(1))
        else:
            return None
        # Extract entity name: quoted or after action verb
        entity = None
        quoted = re.search(r'"([^"]+)"', message)
        if quoted:
            entity = quoted.group(1).strip()
        else:
            # "make/set/increase/decrease/adjust <entity> [by/to]"
            action_pat = r'(?:make|set|increase|decrease|reduce|boost|lower|adjust|change)\s+(?:the\s+)?(?:\w+\s+(?:for|of)\s+)?([a-z0-9][\w\s-]*?)(?:\s+(?:by|to|revenue|sales|at)\b|\s+\d|$)'
            am = re.search(action_pat, msg_lower, re.I)
            if am:
                entity = am.group(1).strip()
        if not entity:
            return None
        # Detect entity type from keywords
        entity_type = None
        for kw in ('seller', 'vendor', 'customer', 'product', 'item', 'region', 'store'):
            if kw in msg_lower:
                entity_type = kw
                break
        return {
            'entity_value': entity,
            'entity_type': entity_type,
            'is_percentage': is_percentage,
            'percentage': percentage,
            'target_value': target_value,
            'confidence': 'medium',
        }

    def match_request_to_viz(self, user_request: str) -> Optional[Dict]:
        """
        Parse request and enrich with schema context.

        Tries the fast regex parser first; falls back to the LLM for complex
        requests. Returns structured match dict or None.
        """
        result = self._parse_request_simple(user_request)
        if not result:
            # Fall back to LLM for complex requests
            schema_summary = "\n".join(
                f"  {t}: {', '.join(cols[:8])}" + (' ...' if len(cols) > 8 else '')
                for t, cols in self.schema_tables.items()
            )
            viz_summary = "\n".join(f"  {i+1}. {v['name']}" for i, v in enumerate(self.visualizations))
            prompt = f"""Parse this data adjustment request.
Request: "{user_request}"
Snowflake schema tables:
{schema_summary}
Liveboard visualizations:
{viz_summary}
Return JSON with these fields (numbers only, not strings):
{{
  "entity_value": "the entity name to adjust (e.g. '1080p Webcam')",
  "entity_type": "product|seller|customer|region|null",
  "is_percentage": true|false,
  "percentage": <number or null>,
  "target_value": <number or null>,
  "metric_hint": "keyword like 'revenue', 'sales', 'profit_margin', or null",
  "confidence": "high|medium|low"
}}"""
            try:
                raw = self._call_llm(prompt)
                # Strip a markdown code fence if the model wrapped its JSON
                if raw.startswith('```'):
                    raw = '\n'.join(raw.split('\n')[1:-1])
                result = json.loads(raw)
            except Exception:
                return None
        return result if result else None

    # ------------------------------------------------------------------
    # Value retrieval & SQL generation
    # ------------------------------------------------------------------
    def get_current_value(self, entity_value: str, metric_column: str,
                          entity_type: str = None) -> Tuple[float, Optional[str], Optional[str], Optional[str]]:
        """
        Query current aggregate value for entity from Snowflake.

        Returns: (current_value, matched_entity, dim_table, fact_table)
        """
        matched, dim_table, name_col = self._find_entity(entity_value, entity_type)
        if not matched:
            return 0.0, None, None, None
        join_info = self._find_fact_join(dim_table) if dim_table else None
        cursor = self.conn.cursor()
        # FIX: bind the entity name as a query parameter instead of
        # interpolating it — names containing a single quote (e.g. "O'Brien")
        # previously produced broken SQL.
        if join_info:
            fact_table, fk_col, dim_pk_col = join_info
            dim_cols = self.schema_tables.get(dim_table, [])
            dim_pk = next((c for c in dim_cols if c.endswith('_ID')), dim_pk_col)
            query = f"""
                SELECT SUM(f.{metric_column})
                FROM {self.database}."{self.schema}".{fact_table} f
                JOIN {self.database}."{self.schema}".{dim_table} d
                  ON f.{fk_col} = d.{dim_pk}
                WHERE LOWER(d.{name_col}) = LOWER(%s)
            """
        else:
            # entity is directly in the table with the metric
            query = f"""
                SELECT SUM({metric_column})
                FROM {self.database}."{self.schema}".{dim_table}
                WHERE LOWER({name_col}) = LOWER(%s)
            """
        try:
            cursor.execute(query, (matched,))
            row = cursor.fetchone()
            value = float(row[0]) if row and row[0] is not None else 0.0
            fact_table_used = join_info[0] if join_info else dim_table
            return value, matched, dim_table, fact_table_used
        except Exception as e:
            print(f"[SmartDataAdjuster] get_current_value query failed: {e}")
            return 0.0, None, None, None

    def _pick_metric_column(self, metric_hint: str = None) -> Optional[str]:
        """Choose the best metric column from fact tables based on hint."""
        # Build a list of all numeric-looking columns across fact tables
        candidates = []
        for ft in self.fact_tables:
            for col in self.schema_tables.get(ft, []):
                if any(kw in col.upper() for kw in ('AMOUNT', 'REVENUE', 'TOTAL', 'SALES', 'VALUE', 'MARGIN', 'PROFIT', 'COST', 'PRICE')):
                    candidates.append((ft, col))
        if not candidates:
            return None
        if metric_hint:
            hint_upper = metric_hint.upper()
            for ft, col in candidates:
                if hint_upper in col:
                    return col
        # Default: prefer TOTAL_AMOUNT, REVENUE, then first available
        for preferred in ('TOTAL_AMOUNT', 'TOTAL_REVENUE', 'REVENUE', 'AMOUNT'):
            for ft, col in candidates:
                if col == preferred:
                    return col
        return candidates[0][1]

    def generate_strategy(self, entity_value: str, metric_column: str,
                          current_value: float, target_value: float = None,
                          percentage: float = None, entity_type: str = None) -> Dict:
        """
        Generate an UPDATE strategy based on the adjustment request.

        Returns a dict with the SQL to run, the matched entity name, and the
        computed target value.
        """
        matched, dim_table, name_col = self._find_entity(entity_value, entity_type)
        if not matched:
            matched = entity_value
        if percentage is not None:
            multiplier = 1 + (percentage / 100)
            pct_change = percentage
            if target_value is None:
                target_value = current_value * multiplier
        # FIX: "is not None" — an explicit target of 0 previously fell through
        # to the no-op branch because 0 is falsy.
        elif target_value is not None and current_value > 0:
            multiplier = target_value / current_value
            pct_change = (multiplier - 1) * 100
        else:
            multiplier = 1.0
            pct_change = 0.0
        join_info = self._find_fact_join(dim_table) if dim_table else None
        # The SQL is returned as text and executed later, so parameter binding
        # is not available here; escape single quotes in the entity name.
        safe_name = matched.replace("'", "''")
        if join_info:
            fact_table, fk_col, _ = join_info
            dim_cols = self.schema_tables.get(dim_table, [])
            dim_pk = next((c for c in dim_cols if c.endswith('_ID')), fk_col)
            sql = f"""UPDATE {self.database}."{self.schema}".{fact_table}
SET {metric_column} = {metric_column} * {multiplier:.6f}
WHERE {fk_col} IN (
    SELECT {dim_pk}
    FROM {self.database}."{self.schema}".{dim_table}
    WHERE LOWER({name_col}) = LOWER('{safe_name}')
)"""
        elif dim_table:
            sql = f"""UPDATE {self.database}."{self.schema}".{dim_table}
SET {metric_column} = {metric_column} * {multiplier:.6f}
WHERE LOWER({name_col}) = LOWER('{safe_name}')"""
        else:
            sql = f"-- Could not determine table structure for '{entity_value}'"
        return {
            'id': 'A',
            'name': 'Scale All Transactions',
            'description': f"Multiply all rows for '{matched}' by {multiplier:.3f}x ({pct_change:+.1f}%)",
            'sql': sql,
            'matched_entity': matched,
            'target_value': target_value,
        }

    def present_smart_confirmation(self, match: Dict, current_value: float,
                                   strategy: Dict, metric_column: str) -> str:
        """Format a human-readable confirmation message (markdown)."""
        entity = match.get('entity_value', '?')
        matched = strategy.get('matched_entity', entity)
        target = strategy.get('target_value', 0) or 0
        if matched.lower() != entity.lower():
            # Show the fuzzy-match correction to the user
            entity_display = f"{entity} → **{matched}**"
        else:
            entity_display = f"**{matched}**"
        change = target - current_value
        pct = (change / current_value * 100) if current_value else 0
        lines = [
            f"**Liveboard:** {self.liveboard_name}",
            f"**Entity:** {entity_display}",
            f"**Metric:** `{metric_column}`",
            f"**Current:** {current_value:,.0f}",
            f"**Target:** {target:,.0f} ({change:+,.0f} / {pct:+.1f}%)",
            f"**Strategy:** {strategy['description']}",
            "",
            f"```sql\n{strategy['sql']}\n```",
        ]
        if match.get('confidence') == 'low':
            lines.append("\n⚠️ Low confidence — please verify before confirming.")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # SQL execution
    # ------------------------------------------------------------------
    def execute_sql(self, sql: str) -> Dict:
        """Execute an UPDATE statement. Returns success/error dict."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql)
            rows_affected = cursor.rowcount
            self.conn.commit()
            return {'success': True, 'rows_affected': rows_affected}
        except Exception as e:
            try:
                self.conn.rollback()
            except Exception:
                pass  # best-effort rollback; report the original error
            return {'success': False, 'error': str(e)}

    # ------------------------------------------------------------------
    # Teardown
    # ------------------------------------------------------------------
    def close(self):
        """Close Snowflake connection (best-effort)."""
        if self.conn:
            try:
                self.conn.close()
            except Exception:
                pass
| # --------------------------------------------------------------------------- | |
| # Liveboard-first context loader | |
| # --------------------------------------------------------------------------- | |
def load_context_from_liveboard(liveboard_guid: str, ts_client) -> dict:
    """
    Resolve Snowflake database/schema from a liveboard GUID.

    Flow:
        liveboard TML (export_fqn=True)
          -> model GUID from visualizations[n].answer.tables[0].fqn
          -> model TML
          -> database / schema from model.tables[0].table.{db, schema}

    Args:
        liveboard_guid: ThoughtSpot liveboard GUID
        ts_client: Authenticated ThoughtSpotDeployer instance

    Returns:
        dict with keys: liveboard_name, model_guid, model_name, database, schema

    Raises:
        ValueError if any step fails to resolve.
    """
    import yaml

    # Step 1: Export liveboard TML with FQNs
    response = ts_client.session.post(
        f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export",
        json={
            "metadata": [{"identifier": liveboard_guid}],
            "export_associated": False,
            "export_fqn": True,
            "format_type": "YAML",
        }
    )
    if response.status_code != 200:
        raise ValueError(
            f"Failed to export liveboard TML ({response.status_code}): {response.text[:300]}"
        )
    tml_data = response.json()
    if not tml_data:
        raise ValueError("Empty response from liveboard TML export")
    lb_tml = yaml.safe_load(tml_data[0]['edoc'])
    liveboard_name = lb_tml.get('liveboard', {}).get('name', 'Unknown Liveboard')

    # Step 2: Find model GUID from first visualization with answer.tables[].fqn
    model_guid = None
    for viz in lb_tml.get('liveboard', {}).get('visualizations', []):
        for t in viz.get('answer', {}).get('tables', []):
            fqn = t.get('fqn')
            if fqn:
                model_guid = fqn
                break
        if model_guid:
            break
    if not model_guid:
        raise ValueError(
            "Could not find model GUID in liveboard TML — "
            "make sure the liveboard has at least one answer-based visualization."
        )

    # Step 3: Export model TML to get database/schema
    response = ts_client.session.post(
        f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export",
        json={
            "metadata": [{"identifier": model_guid, "type": "LOGICAL_TABLE"}],
            "export_associated": False,
            "export_fqn": True,
            "format_type": "YAML",
        }
    )
    if response.status_code != 200:
        raise ValueError(
            f"Failed to export model TML ({response.status_code}): {response.text[:300]}"
        )
    tml_data = response.json()
    # FIX: step 1 guarded against an empty export payload but step 3 did not,
    # so an empty list raised a bare IndexError here instead of a clear error.
    if not tml_data:
        raise ValueError("Empty response from model TML export")
    model_tml = yaml.safe_load(tml_data[0]['edoc'])
    model_name = model_tml.get('model', {}).get('name', 'Unknown Model')

    # Step 4: Extract db/schema from first model table entry
    tables = model_tml.get('model', {}).get('tables', [])
    if not tables:
        raise ValueError("No tables found in model TML")
    first_table = tables[0].get('table', {})
    database = first_table.get('db')
    schema = first_table.get('schema')
    if not database or not schema:
        raise ValueError(
            f"Could not resolve database/schema from model TML "
            f"(db={database!r}, schema={schema!r})"
        )
    return {
        'liveboard_name': liveboard_name,
        'model_guid': model_guid,
        'model_name': model_name,
        'database': database,
        'schema': schema,
    }