# Source: demoprep/smart_data_adjuster.py
# Author: mikeboone
# Commit 01e294d — fix: remove prompt_logger global singleton — per-session logger throughout
"""
Smart Data Adjuster
Understands liveboard and schema context; handles conversational, multi-turn
data adjustment requests in natural language.
Connects to:
- ThoughtSpot (to load liveboard viz context)
- Snowflake (to query and update data)
Works with any configured LLM (Claude, GPT-4, etc.) via litellm.
Schema is discovered dynamically β€” no hardcoded table names.
"""
from typing import Dict, List, Optional, Tuple
from snowflake_auth import get_snowflake_connection
from thoughtspot_deployer import ThoughtSpotDeployer
import json
import re
from llm_config import resolve_model_name
class SmartDataAdjuster:
    """
    Conversational data adjuster with liveboard context and schema discovery.
    Usage:
        adjuster = SmartDataAdjuster(database, schema, liveboard_guid, llm_model)
        adjuster.connect()
        adjuster.load_liveboard_context()
        # Per user message:
        result = adjuster.handle_message("make webcam revenue 40B")
        # result: {'type': 'confirmation', 'text': '...', 'pending': {...}}
        # or {'type': 'result', 'text': '...'}
        # or {'type': 'error', 'text': '...'}
    """

    def __init__(self, database: str, schema: str, liveboard_guid: str,
                 llm_model: Optional[str] = None, ts_url: Optional[str] = None,
                 ts_secret: Optional[str] = None, prompt_logger=None):
        """
        Args:
            database: Snowflake database name (used as an identifier in SQL).
            schema: Snowflake schema name (double-quoted verbatim in SQL).
            liveboard_guid: ThoughtSpot liveboard GUID for context loading.
            llm_model: LLM model name, resolved via llm_config. Required.
            ts_url: Optional ThoughtSpot base URL override; blank means default.
            ts_secret: Optional ThoughtSpot secret override; blank means default.
            prompt_logger: Optional per-session logger forwarded to logged_completion.

        Raises:
            ValueError: If llm_model is missing or blank.
        """
        self.database = database
        self.schema = schema
        self.liveboard_guid = liveboard_guid
        self.ts_url = (ts_url or "").strip() or None
        self.ts_secret = (ts_secret or "").strip() or None
        self.llm_model = (llm_model or "").strip()
        if not self.llm_model:
            raise ValueError("SmartDataAdjuster requires llm_model from settings.")
        self._prompt_logger = prompt_logger
        self.conn = None        # Snowflake connection, set by connect()
        self.ts_client = None   # ThoughtSpotDeployer, set by connect()
        # Populated by load_liveboard_context()
        self.liveboard_name: Optional[str] = None
        self.visualizations: List[Dict] = []
        # Populated by _discover_schema() in connect()
        self.schema_tables: Dict[str, List[str]] = {}   # table → [col, ...]
        self.fact_tables: List[str] = []                # tables likely to be updated
        self.dimension_tables: Dict[str, str] = {}      # table → name_column

    # ------------------------------------------------------------------
    # SQL literal escaping
    # ------------------------------------------------------------------
    @staticmethod
    def _sql_str(value: str) -> str:
        """Escape a value for inlining into a single-quoted SQL string literal.

        Doubles embedded single quotes so entity names like "O'Brien" neither
        break the generated query nor inject extra SQL through a crafted name.
        """
        return str(value).replace("'", "''")

    # ------------------------------------------------------------------
    # Connection & schema discovery
    # ------------------------------------------------------------------
    def connect(self):
        """Connect to Snowflake and ThoughtSpot, then discover schema."""
        self.conn = get_snowflake_connection()
        cursor = self.conn.cursor()
        # NOTE(review): database/schema are identifiers (cannot be bound as
        # parameters) and come from trusted config/TML resolution upstream.
        cursor.execute(f"USE DATABASE {self.database}")
        cursor.execute(f'USE SCHEMA "{self.schema}"')
        self.ts_client = ThoughtSpotDeployer(
            base_url=self.ts_url or None,
            secret_key=self.ts_secret or None,
        )
        self.ts_client.authenticate()
        self._discover_schema()

    def _discover_schema(self):
        """Read actual table/column structure from INFORMATION_SCHEMA.

        Populates schema_tables, dimension_tables (table → first *_NAME column)
        and fact_tables (tables with both numeric and date columns).
        """
        cursor = self.conn.cursor()
        # Escape the schema literal: it is user-configurable text inside quotes.
        cursor.execute(f"""
            SELECT TABLE_NAME, COLUMN_NAME, DATA_TYPE
            FROM {self.database}.INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_SCHEMA = '{self._sql_str(self.schema)}'
            ORDER BY TABLE_NAME, ORDINAL_POSITION
        """)
        raw: Dict[str, List[Dict]] = {}
        for table, column, dtype in cursor.fetchall():
            raw.setdefault(table, []).append({'name': column, 'type': dtype.upper()})
        self.schema_tables = {t: [c['name'] for c in cols] for t, cols in raw.items()}
        # Heuristic: dimension tables have a _NAME column; fact tables have date + numeric cols
        for table, cols in raw.items():
            col_names = [c['name'] for c in cols]
            col_types = {c['name']: c['type'] for c in cols}
            name_cols = [c for c in col_names if c.endswith('_NAME')]
            num_cols = [c for c in col_names
                        if any(t in col_types.get(c, '') for t in ('NUMBER', 'FLOAT', 'INT', 'DECIMAL', 'NUMERIC'))]
            date_cols = [c for c in col_names
                         if any(t in col_types.get(c, '') for t in ('DATE', 'TIME', 'TIMESTAMP'))]
            if name_cols:
                # Use the first _NAME column as the entity name column
                self.dimension_tables[table] = name_cols[0]
            if num_cols and date_cols:
                self.fact_tables.append(table)
        # If nothing looks like a fact table, fall back to all tables
        if not self.fact_tables and self.schema_tables:
            self.fact_tables = list(self.schema_tables.keys())

    def _call_llm(self, prompt: str) -> str:
        """Call the configured LLM via litellm (supports all providers).

        Uses the per-session prompt logger; temperature 0 for determinism.
        Returns the stripped text of the first completion choice.
        """
        from prompt_logger import logged_completion
        model = resolve_model_name(self.llm_model)
        response = logged_completion(
            stage="data_adjuster",
            logger=self._prompt_logger,
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            max_tokens=1000,
        )
        return response.choices[0].message.content.strip()

    # ------------------------------------------------------------------
    # Liveboard context
    # ------------------------------------------------------------------
    def load_liveboard_context(self) -> bool:
        """Load liveboard metadata and visualization list from ThoughtSpot.

        Returns True when at least one non-note visualization was found.
        """
        response = self.ts_client.session.post(
            f"{self.ts_client.base_url}/api/rest/2.0/metadata/search",
            json={
                "metadata": [{"type": "LIVEBOARD", "identifier": self.liveboard_guid}],
                "include_visualization_headers": True
            }
        )
        if response.status_code != 200:
            return False
        payload = response.json()
        if not payload:
            # Guard: an unknown GUID can return an empty list; indexing it
            # would raise IndexError.
            return False
        data = payload[0]
        self.liveboard_name = data.get('metadata_name', 'Unknown Liveboard')
        for viz in data.get('visualization_headers', []):
            name = viz.get('name', '')
            if 'note-tile' in name.lower():
                continue  # note tiles carry no data; skip them
            self.visualizations.append({'id': viz.get('id'), 'name': name})
        return bool(self.visualizations)

    # ------------------------------------------------------------------
    # Entity matching (schema-aware)
    # ------------------------------------------------------------------
    def _fuzzy_match(self, target: str, candidates: List[str]) -> Optional[str]:
        """Return the best matching candidate for target, or None.

        Match tiers, strongest first: exact (case-insensitive), exact after
        stripping spaces/hyphens/underscores, substring, normalized substring.
        """
        def norm(s):
            return s.lower().replace(' ', '').replace('-', '').replace('_', '')
        t = norm(target)
        t_lower = target.lower()
        for c in candidates:
            if c.lower() == t_lower:
                return c
        for c in candidates:
            if norm(c) == t:
                return c
        for c in candidates:
            if t_lower in c.lower() or c.lower() in t_lower:
                return c
        for c in candidates:
            if t in norm(c) or norm(c) in t:
                return c
        return None

    def _find_entity(self, entity_value: str, entity_type_hint: str = None) -> Tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Find the closest matching entity name in any dimension table.
        Returns: (matched_name, table_name, name_column) or (None, None, None)
        """
        cursor = self.conn.cursor()
        # Sort dimension tables: prefer ones whose name matches entity_type_hint
        tables_to_try = list(self.dimension_tables.items())
        if entity_type_hint:
            hint = entity_type_hint.lower()
            tables_to_try.sort(key=lambda x: 0 if hint in x[0].lower() else 1)
        for table, name_col in tables_to_try:
            # Identifiers below come from INFORMATION_SCHEMA discovery, not user input.
            cursor.execute(f'SELECT DISTINCT {name_col} FROM {self.database}."{self.schema}".{table}')
            candidates = [row[0] for row in cursor.fetchall() if row[0]]
            match = self._fuzzy_match(entity_value, candidates)
            if match:
                return match, table, name_col
        return None, None, None

    def _find_fact_join(self, dim_table: str) -> Optional[Tuple[str, str, str]]:
        """
        Find a fact table that has an FK column referencing dim_table.
        Returns: (fact_table, fact_fk_column, dim_pk_column) or None.
        """
        # Look for an ID column in dim_table
        dim_cols = self.schema_tables.get(dim_table, [])
        dim_id = next((c for c in dim_cols if c.endswith('_ID')), None)
        if not dim_id:
            return None
        # Look for that same column name in fact tables (name-based FK heuristic)
        for ft in self.fact_tables:
            if ft == dim_table:
                continue
            ft_cols = self.schema_tables.get(ft, [])
            if dim_id in ft_cols:
                return ft, dim_id, dim_id
        return None

    # ------------------------------------------------------------------
    # Request interpretation
    # ------------------------------------------------------------------
    def _parse_request_simple(self, message: str) -> Optional[Dict]:
        """
        Fast regex-based parser for common patterns:
        "decrease webcam by 10%", "make laptop 50B", "increase revenue for acme by 20%"
        Returns parsed dict or None if pattern not matched.
        """
        msg_lower = message.lower()
        # Percentage match: "by 20%", "-10%"
        pct_match = re.search(r'by\s+(-?\d+\.?\d*)%|(-?\d+\.?\d*)%', msg_lower)
        # Absolute value: "50B", "50 billion", "1.5M", "50k", "50 thousand".
        # FIX: the old alternative `k(?:thousand)?` could never match the word
        # "thousand" on its own; spell out `thousand|k` explicitly.
        val_match = re.search(r'(\d+\.?\d*)\s*(b(?:illion)?|m(?:illion)?|thousand|k)\b', msg_lower)
        bare_num = re.search(r'\b(\d{4,})\b', message)  # bare large integer
        is_percentage = bool(pct_match)
        percentage = None
        target_value = None
        if pct_match:
            raw_pct = float(pct_match.group(1) or pct_match.group(2))
            # Decrease-style verbs force the sign negative even if user wrote "10%"
            if any(w in msg_lower for w in ('decrease', 'reduce', 'lower', 'drop', 'cut')):
                raw_pct = -abs(raw_pct)
            percentage = raw_pct
        elif val_match:
            num = float(val_match.group(1))
            unit = val_match.group(2)[0].lower()
            # 't' covers the spelled-out "thousand" unit ('k' stays for "50k")
            multipliers = {'b': 1e9, 'm': 1e6, 'k': 1e3, 't': 1e3}
            target_value = num * multipliers.get(unit, 1)
        elif bare_num:
            target_value = float(bare_num.group(1))
        else:
            return None
        # Extract entity name: quoted or after action verb
        entity = None
        quoted = re.search(r'"([^"]+)"', message)
        if quoted:
            entity = quoted.group(1).strip()
        else:
            # "make/set/increase/decrease/adjust <entity> [by/to]"
            action_pat = r'(?:make|set|increase|decrease|reduce|boost|lower|adjust|change)\s+(?:the\s+)?(?:\w+\s+(?:for|of)\s+)?([a-z0-9][\w\s-]*?)(?:\s+(?:by|to|revenue|sales|at)\b|\s+\d|$)'
            am = re.search(action_pat, msg_lower, re.I)
            if am:
                entity = am.group(1).strip()
        if not entity:
            return None
        # Detect entity type from keywords
        entity_type = None
        for kw in ('seller', 'vendor', 'customer', 'product', 'item', 'region', 'store'):
            if kw in msg_lower:
                entity_type = kw
                break
        return {
            'entity_value': entity,
            'entity_type': entity_type,
            'is_percentage': is_percentage,
            'percentage': percentage,
            'target_value': target_value,
            'confidence': 'medium',
        }

    def match_request_to_viz(self, user_request: str) -> Optional[Dict]:
        """
        Parse request and enrich with schema context.

        Tries the fast regex parser first, then falls back to an LLM prompt
        built from the discovered schema and liveboard visualization list.
        Returns structured match dict or None.
        """
        result = self._parse_request_simple(user_request)
        if not result:
            # Fall back to LLM for complex requests
            schema_summary = "\n".join(
                f"  {t}: {', '.join(cols[:8])}" + (' ...' if len(cols) > 8 else '')
                for t, cols in self.schema_tables.items()
            )
            viz_summary = "\n".join(f"  {i+1}. {v['name']}" for i, v in enumerate(self.visualizations))
            prompt = f"""Parse this data adjustment request.
Request: "{user_request}"
Snowflake schema tables:
{schema_summary}
Liveboard visualizations:
{viz_summary}
Return JSON with these fields (numbers only, not strings):
{{
"entity_value": "the entity name to adjust (e.g. '1080p Webcam')",
"entity_type": "product|seller|customer|region|null",
"is_percentage": true|false,
"percentage": <number or null>,
"target_value": <number or null>,
"metric_hint": "keyword like 'revenue', 'sales', 'profit_margin', or null",
"confidence": "high|medium|low"
}}"""
            try:
                raw = self._call_llm(prompt)
                if raw.startswith('```'):
                    # Strip a fenced code block the model may wrap JSON in
                    raw = '\n'.join(raw.split('\n')[1:-1])
                result = json.loads(raw)
            except Exception:
                return None
        return result if result else None

    # ------------------------------------------------------------------
    # Value retrieval & SQL generation
    # ------------------------------------------------------------------
    def get_current_value(self, entity_value: str, metric_column: str,
                          entity_type: str = None) -> Tuple[float, Optional[str], Optional[str], Optional[str]]:
        """
        Query current aggregate value for entity from Snowflake.
        Returns: (current_value, matched_entity, dim_table, fact_table)
        """
        matched, dim_table, name_col = self._find_entity(entity_value, entity_type)
        if not matched:
            return 0.0, None, None, None
        join_info = self._find_fact_join(dim_table) if dim_table else None
        cursor = self.conn.cursor()
        # Escape the matched name: dimension values may contain single quotes.
        safe_name = self._sql_str(matched)
        if join_info:
            fact_table, fk_col, dim_pk_col = join_info
            dim_cols = self.schema_tables.get(dim_table, [])
            dim_pk = next((c for c in dim_cols if c.endswith('_ID')), dim_pk_col)
            query = f"""
                SELECT SUM(f.{metric_column})
                FROM {self.database}."{self.schema}".{fact_table} f
                JOIN {self.database}."{self.schema}".{dim_table} d
                  ON f.{fk_col} = d.{dim_pk}
                WHERE LOWER(d.{name_col}) = LOWER('{safe_name}')
            """
        else:
            # entity is directly in the table with the metric
            query = f"""
                SELECT SUM({metric_column})
                FROM {self.database}."{self.schema}".{dim_table}
                WHERE LOWER({name_col}) = LOWER('{safe_name}')
            """
        try:
            cursor.execute(query)
            row = cursor.fetchone()
            value = float(row[0]) if row and row[0] is not None else 0.0
            fact_table_used = join_info[0] if join_info else dim_table
            return value, matched, dim_table, fact_table_used
        except Exception as e:
            print(f"[SmartDataAdjuster] get_current_value query failed: {e}")
            return 0.0, None, None, None

    def _pick_metric_column(self, metric_hint: str = None) -> Optional[str]:
        """Choose the best metric column from fact tables based on hint.

        NOTE(review): returns only the column name; the fact table actually
        updated is chosen later by _find_fact_join and may not contain this
        column — verify against real schemas.
        """
        # Build a list of all numeric-looking columns across fact tables
        candidates = []
        for ft in self.fact_tables:
            for col in self.schema_tables.get(ft, []):
                if any(kw in col.upper() for kw in ('AMOUNT', 'REVENUE', 'TOTAL', 'SALES', 'VALUE', 'MARGIN', 'PROFIT', 'COST', 'PRICE')):
                    candidates.append((ft, col))
        if not candidates:
            return None
        if metric_hint:
            hint_upper = metric_hint.upper()
            for ft, col in candidates:
                if hint_upper in col:
                    return col
        # Default: prefer TOTAL_AMOUNT, REVENUE, then first available
        for preferred in ('TOTAL_AMOUNT', 'TOTAL_REVENUE', 'REVENUE', 'AMOUNT'):
            for ft, col in candidates:
                if col == preferred:
                    return col
        return candidates[0][1] if candidates else None

    def generate_strategy(self, entity_value: str, metric_column: str,
                          current_value: float, target_value: float = None,
                          percentage: float = None, entity_type: str = None) -> Dict:
        """Generate an UPDATE strategy based on the adjustment request.

        Computes a single multiplicative scale factor from either a percentage
        or a target value, then emits the UPDATE SQL (not executed here).
        """
        matched, dim_table, name_col = self._find_entity(entity_value, entity_type)
        if not matched:
            matched = entity_value
        if percentage is not None:
            multiplier = 1 + (percentage / 100)
            pct_change = percentage
            if target_value is None:
                target_value = current_value * multiplier
        elif target_value and current_value > 0:
            multiplier = target_value / current_value
            pct_change = (multiplier - 1) * 100
        else:
            # No usable percentage or target (or current value is 0): no-op scale
            multiplier = 1.0
            pct_change = 0.0
        join_info = self._find_fact_join(dim_table) if dim_table else None
        # Escape the matched name: dimension values may contain single quotes.
        safe_name = self._sql_str(matched)
        if join_info:
            fact_table, fk_col, _ = join_info
            dim_cols = self.schema_tables.get(dim_table, [])
            dim_pk = next((c for c in dim_cols if c.endswith('_ID')), fk_col)
            sql = f"""UPDATE {self.database}."{self.schema}".{fact_table}
SET {metric_column} = {metric_column} * {multiplier:.6f}
WHERE {fk_col} IN (
    SELECT {dim_pk}
    FROM {self.database}."{self.schema}".{dim_table}
    WHERE LOWER({name_col}) = LOWER('{safe_name}')
)"""
        elif dim_table:
            sql = f"""UPDATE {self.database}."{self.schema}".{dim_table}
SET {metric_column} = {metric_column} * {multiplier:.6f}
WHERE LOWER({name_col}) = LOWER('{safe_name}')"""
        else:
            sql = f"-- Could not determine table structure for '{entity_value}'"
        return {
            'id': 'A',
            'name': 'Scale All Transactions',
            'description': f"Multiply all rows for '{matched}' by {multiplier:.3f}x ({pct_change:+.1f}%)",
            'sql': sql,
            'matched_entity': matched,
            'target_value': target_value,
        }

    def present_smart_confirmation(self, match: Dict, current_value: float,
                                   strategy: Dict, metric_column: str) -> str:
        """Format a human-readable (markdown) confirmation message."""
        entity = match.get('entity_value', '?')
        matched = strategy.get('matched_entity', entity)
        target = strategy.get('target_value', 0) or 0
        if matched.lower() != entity.lower():
            # Show the fuzzy-match resolution so the user can catch a bad match
            entity_display = f"{entity} → **{matched}**"
        else:
            entity_display = f"**{matched}**"
        change = target - current_value
        pct = (change / current_value * 100) if current_value else 0
        lines = [
            f"**Liveboard:** {self.liveboard_name}",
            f"**Entity:** {entity_display}",
            f"**Metric:** `{metric_column}`",
            f"**Current:** {current_value:,.0f}",
            f"**Target:** {target:,.0f} ({change:+,.0f} / {pct:+.1f}%)",
            f"**Strategy:** {strategy['description']}",
            "",
            f"```sql\n{strategy['sql']}\n```",
        ]
        if match.get('confidence') == 'low':
            lines.append("\n⚠️ Low confidence — please verify before confirming.")
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # SQL execution
    # ------------------------------------------------------------------
    def execute_sql(self, sql: str) -> Dict:
        """Execute an UPDATE statement. Returns success/error dict."""
        cursor = self.conn.cursor()
        try:
            cursor.execute(sql)
            rows_affected = cursor.rowcount
            self.conn.commit()
            return {'success': True, 'rows_affected': rows_affected}
        except Exception as e:
            try:
                self.conn.rollback()
            except Exception:
                pass  # best-effort rollback; report the original error
            return {'success': False, 'error': str(e)}

    # ------------------------------------------------------------------
    # Teardown
    # ------------------------------------------------------------------
    def close(self):
        """Close Snowflake connection (best-effort; never raises)."""
        if self.conn:
            try:
                self.conn.close()
            except Exception:
                pass
# ---------------------------------------------------------------------------
# Liveboard-first context loader
# ---------------------------------------------------------------------------
def load_context_from_liveboard(liveboard_guid: str, ts_client) -> dict:
    """
    Resolve Snowflake database/schema from a liveboard GUID.
    Flow:
        liveboard TML (export_fqn=True)
        → model GUID from visualizations[n].answer.tables[0].fqn
        → model TML
        → database / schema from model.tables[0].table.{db, schema}
    Args:
        liveboard_guid: ThoughtSpot liveboard GUID
        ts_client: Authenticated ThoughtSpotDeployer instance
    Returns:
        dict with keys: liveboard_name, model_guid, model_name, database, schema
    Raises:
        ValueError if any step fails to resolve.
    """
    import yaml
    # Step 1: Export liveboard TML with FQNs
    response = ts_client.session.post(
        f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export",
        json={
            "metadata": [{"identifier": liveboard_guid}],
            "export_associated": False,
            "export_fqn": True,
            "format_type": "YAML",
        }
    )
    if response.status_code != 200:
        raise ValueError(
            f"Failed to export liveboard TML ({response.status_code}): {response.text[:300]}"
        )
    tml_data = response.json()
    if not tml_data:
        raise ValueError("Empty response from liveboard TML export")
    lb_tml = yaml.safe_load(tml_data[0]['edoc'])
    liveboard_name = lb_tml.get('liveboard', {}).get('name', 'Unknown Liveboard')
    # Step 2: Find model GUID from first visualization with answer.tables[].fqn
    model_guid = None
    for viz in lb_tml.get('liveboard', {}).get('visualizations', []):
        for t in viz.get('answer', {}).get('tables', []):
            fqn = t.get('fqn')
            if fqn:
                model_guid = fqn
                break
        if model_guid:
            break
    if not model_guid:
        raise ValueError(
            "Could not find model GUID in liveboard TML — "
            "make sure the liveboard has at least one answer-based visualization."
        )
    # Step 3: Export model TML to get database/schema
    response = ts_client.session.post(
        f"{ts_client.base_url}/api/rest/2.0/metadata/tml/export",
        json={
            "metadata": [{"identifier": model_guid, "type": "LOGICAL_TABLE"}],
            "export_associated": False,
            "export_fqn": True,
            "format_type": "YAML",
        }
    )
    if response.status_code != 200:
        raise ValueError(
            f"Failed to export model TML ({response.status_code}): {response.text[:300]}"
        )
    tml_data = response.json()
    # FIX: mirror the Step-1 guard — an empty export previously raised a bare
    # IndexError on tml_data[0] instead of a descriptive ValueError.
    if not tml_data:
        raise ValueError("Empty response from model TML export")
    model_tml = yaml.safe_load(tml_data[0]['edoc'])
    # NOTE(review): assumes the exported TML's top-level key is 'model';
    # other logical-table kinds may use a different root key — verify.
    model_name = model_tml.get('model', {}).get('name', 'Unknown Model')
    # Step 4: Extract db/schema from first model table entry
    tables = model_tml.get('model', {}).get('tables', [])
    if not tables:
        raise ValueError("No tables found in model TML")
    first_table = tables[0].get('table', {})
    database = first_table.get('db')
    schema = first_table.get('schema')
    if not database or not schema:
        raise ValueError(
            f"Could not resolve database/schema from model TML "
            f"(db={database!r}, schema={schema!r})"
        )
    return {
        'liveboard_name': liveboard_name,
        'model_guid': model_guid,
        'model_name': model_name,
        'database': database,
        'schema': schema,
    }