#!/usr/bin/env python3 """Run the New Vision 4-case sample set without UI login flow. Modes: 1) Full chat pipeline mode (uses configured default_llm and required settings) 2) Offline DDL mode (deterministic schema template, still validates settings up front) Usage: source ./demoprep/bin/activate python tests/newvision_sample_runner.py python tests/newvision_sample_runner.py --offline-ddl python tests/newvision_sample_runner.py --skip-thoughtspot """ from __future__ import annotations import argparse import json import os import sys from datetime import datetime, timezone from pathlib import Path from typing import Any import yaml PROJECT_ROOT = Path(__file__).parent.parent sys.path.insert(0, str(PROJECT_ROOT)) from dotenv import load_dotenv load_dotenv(PROJECT_ROOT / ".env") os.environ.setdefault("DEMOPREP_NO_AUTH", "true") # Pull admin settings into environment when available. try: from supabase_client import inject_admin_settings_to_env inject_admin_settings_to_env() except Exception as exc: # noqa: BLE001 print(f"[newvision_runner] Admin setting injection unavailable: {exc}") OFFLINE_DEMO_DDL = """ CREATE TABLE DIM_DATE ( DATE_KEY INT PRIMARY KEY, ORDER_DATE DATE, MONTH_NAME VARCHAR(30), QUARTER_NAME VARCHAR(10), YEAR_NUM INT, IS_WEEKEND BOOLEAN ); CREATE TABLE DIM_LOCATION ( LOCATION_KEY INT PRIMARY KEY, COUNTRY VARCHAR(100), REGION VARCHAR(100), STATE VARCHAR(100), CITY VARCHAR(100), SALES_CHANNEL VARCHAR(100), CUSTOMER_SEGMENT VARCHAR(100) ); CREATE TABLE DIM_PRODUCT ( PRODUCT_KEY INT PRIMARY KEY, PRODUCT_NAME VARCHAR(200), BRAND_NAME VARCHAR(100), CATEGORY VARCHAR(100), SUB_CATEGORY VARCHAR(100), PRODUCT_TIER VARCHAR(50), UNIT_PRICE DECIMAL(12,2) ); CREATE TABLE FACT_RETAIL_DAILY ( TRANSACTION_KEY INT PRIMARY KEY, DATE_KEY INT, LOCATION_KEY INT, PRODUCT_KEY INT, ORDER_DATE DATE, ORDER_COUNT INT, UNITS_SOLD INT, UNIT_PRICE DECIMAL(12,2), GROSS_REVENUE DECIMAL(14,2), NET_REVENUE DECIMAL(14,2), SALES_AMOUNT DECIMAL(14,2), DISCOUNT_PCT DECIMAL(5,2), 
INVENTORY_ON_HAND INT, LOST_SALES_USD DECIMAL(14,2), IS_OOS BOOLEAN, FOREIGN KEY (DATE_KEY) REFERENCES DIM_DATE(DATE_KEY), FOREIGN KEY (LOCATION_KEY) REFERENCES DIM_LOCATION(LOCATION_KEY), FOREIGN KEY (PRODUCT_KEY) REFERENCES DIM_PRODUCT(PRODUCT_KEY) ); """.strip() def _now_utc_iso() -> str: return datetime.now(timezone.utc).isoformat() def _load_cases(cases_file: Path) -> list[dict[str, Any]]: data = yaml.safe_load(cases_file.read_text(encoding="utf-8")) or {} return list(data.get("test_cases", [])) def _parse_quality_report_path(message: str) -> str | None: for line in (message or "").splitlines(): if "Report:" in line: return line.split("Report:", 1)[1].strip() if "See report:" in line: return line.split("See report:", 1)[1].strip() return None def _load_quality_report(report_path: str | None) -> dict[str, Any]: if not report_path: return {} path = Path(report_path) json_path = path.with_suffix(".json") if json_path.exists(): try: return json.loads(json_path.read_text(encoding="utf-8")) except Exception: # noqa: BLE001 return {} return {} def _build_quality_gate_stage(report_path: str | None) -> dict[str, Any]: report = _load_quality_report(report_path) summary = report.get("summary", {}) if isinstance(report, dict) else {} passed = report.get("passed") if isinstance(report, dict) else None return { "ok": bool(passed) if passed is not None else False, "report_path": report_path, "passed": passed, "summary": { "semantic_pass_ratio": summary.get("semantic_pass_ratio"), "categorical_junk_count": summary.get("categorical_junk_count"), "fk_orphan_count": summary.get("fk_orphan_count"), "temporal_violations": summary.get("temporal_violations"), "numeric_violations": summary.get("numeric_violations"), "volatility_breaches": summary.get("volatility_breaches"), "smoothness_score": summary.get("smoothness_score"), "outlier_explainability": summary.get("outlier_explainability"), "kpi_consistency": summary.get("kpi_consistency"), }, } def _resolve_runtime_settings() -> 
tuple[str, str]: user_email = ( os.getenv("USER_EMAIL") or os.getenv("INITIAL_USER") or os.getenv("THOUGHTSPOT_ADMIN_USER") or "default@user.com" ).strip() default_llm = (os.getenv("DEFAULT_LLM") or os.getenv("OPENAI_MODEL") or "").strip() if not default_llm: raise ValueError("Missing required env var: DEFAULT_LLM or OPENAI_MODEL") return user_email, default_llm def _run_realism_sanity_checks(schema_name: str, case: dict[str, Any]) -> dict[str, Any]: """Fast, opinionated sanity checks for demo realism. These checks intentionally target user-visible demo breakages that can slip through structural quality gates (e.g., null dimensions in top-N charts). """ checks: list[dict[str, Any]] = [] failures: list[str] = [] if not schema_name: return {"ok": False, "checks": checks, "failures": ["Missing schema name for sanity checks"]} use_case = str(case.get("use_case", "") or "") use_case_lower = (use_case or "").lower() case_name = str(case.get("name", "") or "").lower() is_legal = "legal" in use_case_lower is_private_equity = any( marker in use_case_lower for marker in ("private equity", "lp reporting", "state street") ) is_saas_finance = any( marker in use_case_lower for marker in ("saas finance", "unit economics", "financial analytics", "fp&a", "fpa") ) if not is_legal and not is_private_equity and not is_saas_finance: # Keep runtime fast by evaluating only scoped vertical checks. 
return {"ok": True, "checks": checks, "failures": []} from supabase_client import inject_admin_settings_to_env from snowflake_auth import get_snowflake_connection inject_admin_settings_to_env() conn = None cur = None try: db_name = (os.getenv("SNOWFLAKE_DATABASE") or "DEMOBUILD").strip() safe_schema = schema_name.replace('"', "") conn = get_snowflake_connection() cur = conn.cursor() cur.execute(f'USE DATABASE "{db_name}"') cur.execute(f'USE SCHEMA "{safe_schema}"') if is_legal: cur.execute("SHOW TABLES") legal_tables = {str(row[1]).upper() for row in cur.fetchall()} has_split_legal = {"LEGAL_MATTERS", "OUTSIDE_COUNSEL_INVOICES", "ATTORNEYS", "MATTER_TYPES"}.issubset(legal_tables) has_event_legal = "LEGAL_SPEND_EVENTS" in legal_tables if has_split_legal: # 1) Invoice -> matter -> attorney join coverage. cur.execute( """ SELECT COUNT(*) AS total_rows, COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows FROM OUTSIDE_COUNSEL_INVOICES oci LEFT JOIN LEGAL_MATTERS lm ON oci.MATTER_ID = lm.MATTER_ID LEFT JOIN ATTORNEYS a ON lm.ASSIGNED_ATTORNEY_ID = a.ATTORNEY_ID """ ) total_rows, null_rows = cur.fetchone() null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0 checks.append( { "name": "legal_attorney_join_null_pct", "value": round(null_pct, 2), "threshold": "<= 5.0", "ok": null_pct <= 5.0, } ) if null_pct > 5.0: failures.append( f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)" ) # 1b) Invoice MATTER_ID linkage must be complete. 
cur.execute( """ SELECT COUNT(*) AS total_rows, COUNT_IF(MATTER_ID IS NULL) AS null_rows FROM OUTSIDE_COUNSEL_INVOICES """ ) total_rows, null_rows = cur.fetchone() null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0 checks.append( { "name": "legal_invoice_matter_id_null_pct", "value": round(null_pct, 2), "threshold": "== 0.0", "ok": null_pct == 0.0, } ) if null_pct != 0.0: failures.append( f"Invoice MATTER_ID null rate is {null_pct:.2f}% (expected 0%)" ) # 2) Region cardinality should be compact for legal executive demos. cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_MATTERS WHERE REGION IS NOT NULL") region_cardinality = int(cur.fetchone()[0] or 0) checks.append( { "name": "legal_region_distinct_count", "value": region_cardinality, "threshold": "<= 6", "ok": region_cardinality <= 6, } ) if region_cardinality > 6: failures.append( f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)" ) # 3) Firm names should not contain obvious cross-vertical banking/org jargon. cur.execute( """ SELECT COUNT(*) FROM OUTSIDE_COUNSEL WHERE REGEXP_LIKE( LOWER(FIRM_NAME), 'retail banking|consumer lending|digital channels|enterprise operations|regional service' ) """ ) bad_firm_count = int(cur.fetchone()[0] or 0) checks.append( { "name": "legal_firm_name_cross_vertical_count", "value": bad_firm_count, "threshold": "== 0", "ok": bad_firm_count == 0, } ) if bad_firm_count != 0: failures.append( f"Detected {bad_firm_count} cross-vertical/non-legal firm names" ) # 4) Matter type taxonomy should remain concise and demo-friendly. 
cur.execute( """ SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME) FROM LEGAL_MATTERS lm LEFT JOIN MATTER_TYPES mt ON lm.MATTER_TYPE_ID = mt.MATTER_TYPE_ID WHERE mt.MATTER_TYPE_NAME IS NOT NULL """ ) matter_type_cardinality = int(cur.fetchone()[0] or 0) checks.append( { "name": "legal_matter_type_distinct_count", "value": matter_type_cardinality, "threshold": "<= 15", "ok": matter_type_cardinality <= 15, } ) if matter_type_cardinality > 15: failures.append( f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)" ) elif has_event_legal: # 1) Attorney dimension join coverage (critical for "Top Attorney by Cost"). cur.execute( """ SELECT COUNT(*) AS total_rows, COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows FROM LEGAL_SPEND_EVENTS lse LEFT JOIN ATTORNEYS a ON lse.ATTORNEY_ID = a.ATTORNEY_ID """ ) total_rows, null_rows = cur.fetchone() null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0 checks.append( { "name": "legal_attorney_join_null_pct", "value": round(null_pct, 2), "threshold": "<= 5.0", "ok": null_pct <= 5.0, } ) if null_pct > 5.0: failures.append( f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)" ) # 2) Region cardinality should be compact for legal executive demos. cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_SPEND_EVENTS WHERE REGION IS NOT NULL") region_cardinality = int(cur.fetchone()[0] or 0) checks.append( { "name": "legal_region_distinct_count", "value": region_cardinality, "threshold": "<= 6", "ok": region_cardinality <= 6, } ) if region_cardinality > 6: failures.append( f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)" ) # 3) Firm names should not contain obvious cross-vertical banking/org jargon. 
cur.execute( """ SELECT COUNT(*) FROM OUTSIDE_COUNSEL_FIRMS WHERE REGEXP_LIKE( LOWER(FIRM_NAME), 'retail banking|consumer lending|digital channels|enterprise operations|regional service' ) """ ) bad_firm_count = int(cur.fetchone()[0] or 0) checks.append( { "name": "legal_firm_name_cross_vertical_count", "value": bad_firm_count, "threshold": "== 0", "ok": bad_firm_count == 0, } ) if bad_firm_count != 0: failures.append( f"Detected {bad_firm_count} cross-vertical/non-legal firm names" ) # 4) Matter type taxonomy should remain concise and demo-friendly. cur.execute( """ SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME) FROM LEGAL_SPEND_EVENTS lse LEFT JOIN MATTER_TYPES mt ON lse.MATTER_TYPE_ID = mt.MATTER_TYPE_ID WHERE mt.MATTER_TYPE_NAME IS NOT NULL """ ) matter_type_cardinality = int(cur.fetchone()[0] or 0) checks.append( { "name": "legal_matter_type_distinct_count", "value": matter_type_cardinality, "threshold": "<= 15", "ok": matter_type_cardinality <= 15, } ) if matter_type_cardinality > 15: failures.append( f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)" ) else: failures.append("Could not find supported legal schema shape for realism checks") if is_private_equity: # Guard against semantic leakage where sector/strategy dimensions are # accidentally populated with company names. 
cur.execute( """ WITH dim_companies AS ( SELECT DISTINCT COMPANY_NAME FROM PORTFOLIO_COMPANIES WHERE COMPANY_NAME IS NOT NULL ), dim_sectors AS ( SELECT DISTINCT SECTOR_NAME FROM SECTORS WHERE SECTOR_NAME IS NOT NULL ), dim_strategies AS ( SELECT DISTINCT FUND_STRATEGY FROM FUNDS WHERE FUND_STRATEGY IS NOT NULL ) SELECT (SELECT COUNT(*) FROM dim_sectors), (SELECT COUNT(*) FROM dim_strategies), (SELECT COUNT(*) FROM dim_sectors s JOIN dim_companies c ON s.SECTOR_NAME = c.COMPANY_NAME), (SELECT COUNT(*) FROM dim_strategies f JOIN dim_companies c ON f.FUND_STRATEGY = c.COMPANY_NAME) """ ) sector_distinct, strategy_distinct, sector_overlap, strategy_overlap = cur.fetchone() sector_distinct = int(sector_distinct or 0) strategy_distinct = int(strategy_distinct or 0) sector_overlap = int(sector_overlap or 0) strategy_overlap = int(strategy_overlap or 0) checks.append( { "name": "pe_sector_name_company_overlap_count", "value": sector_overlap, "threshold": "== 0", "ok": sector_overlap == 0, } ) if sector_overlap != 0: failures.append( f"Sector names overlap company names ({sector_overlap} overlaps); likely mislabeled dimensions" ) checks.append( { "name": "pe_fund_strategy_company_overlap_count", "value": strategy_overlap, "threshold": "== 0", "ok": strategy_overlap == 0, } ) if strategy_overlap != 0: failures.append( f"Fund strategy values overlap company names ({strategy_overlap} overlaps); likely mislabeled dimensions" ) checks.append( { "name": "pe_sector_distinct_count", "value": sector_distinct, "threshold": ">= 4 and <= 20", "ok": 4 <= sector_distinct <= 20, } ) if not (4 <= sector_distinct <= 20): failures.append( f"Sector distinct count out of expected demo range: {sector_distinct} (expected 4-20)" ) checks.append( { "name": "pe_fund_strategy_distinct_count", "value": strategy_distinct, "threshold": ">= 4 and <= 20", "ok": 4 <= strategy_distinct <= 20, } ) if not (4 <= strategy_distinct <= 20): failures.append( f"Fund strategy distinct count out of expected demo 
range: {strategy_distinct} (expected 4-20)" ) if case_name == "statestreet_private_equity_lp_reporting": cur.execute( """ SELECT COUNT(*) AS total_rows, COUNT_IF(ABS(TOTAL_VALUE_USD - (REPORTED_VALUE_USD + DISTRIBUTIONS_USD)) > 0.01) AS bad_rows FROM PORTFOLIO_PERFORMANCE """ ) total_rows, bad_rows = cur.fetchone() total_rows = int(total_rows or 0) bad_rows = int(bad_rows or 0) identity_ok = total_rows > 0 and bad_rows == 0 checks.append( { "name": "pe_total_value_identity_bad_rows", "value": bad_rows, "threshold": "== 0", "ok": identity_ok, } ) if not identity_ok: failures.append( f"Total value identity broken in {bad_rows} PE fact rows" ) cur.execute( """ SELECT COUNT(*) AS total_rows, COUNT_IF(IRR_SUB_LINE_IMPACT_BPS BETWEEN 80 AND 210) AS in_band_rows, COUNT_IF(ABS(IRR_SUB_LINE_IMPACT_BPS - ((GROSS_IRR - GROSS_IRR_WITHOUT_SUB_LINE) * 10000)) <= 5) AS identity_rows FROM PORTFOLIO_PERFORMANCE """ ) total_rows, in_band_rows, identity_rows = cur.fetchone() total_rows = int(total_rows or 0) in_band_rows = int(in_band_rows or 0) identity_rows = int(identity_rows or 0) irr_band_ok = total_rows > 0 and in_band_rows == total_rows and identity_rows == total_rows checks.append( { "name": "pe_subscription_line_impact_rows_valid", "value": {"total": total_rows, "in_band": in_band_rows, "identity": identity_rows}, "threshold": "all rows in 80-210 bps band and identity holds", "ok": irr_band_ok, } ) if not irr_band_ok: failures.append("Subscription line impact rows do not consistently satisfy PE IRR delta rules") cur.execute( """ SELECT COUNT(*) AS apex_rows, MAX(pp.IRR_SUB_LINE_IMPACT_BPS) AS apex_max_bps, ( SELECT MAX(IRR_SUB_LINE_IMPACT_BPS) FROM PORTFOLIO_PERFORMANCE ) AS overall_max_bps FROM PORTFOLIO_PERFORMANCE pp JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID WHERE LOWER(pc.COMPANY_NAME) = 'apex industrial solutions' """ ) apex_rows, apex_max_bps, overall_max_bps = cur.fetchone() apex_ok = int(apex_rows or 0) > 0 and apex_max_bps is not None and 
abs(float(apex_max_bps) - 210.0) <= 1.0 and overall_max_bps is not None and abs(float(overall_max_bps) - 210.0) <= 1.0 checks.append( { "name": "pe_apex_subscription_line_outlier", "value": {"rows": int(apex_rows or 0), "apex_max_bps": apex_max_bps, "overall_max_bps": overall_max_bps}, "threshold": "Apex exists and max impact == 210 bps", "ok": apex_ok, } ) if not apex_ok: failures.append("Apex Industrial Solutions outlier is missing or not set to the expected 210 bps impact") cur.execute( """ WITH covenant_exceptions AS ( SELECT LOWER(pc.COMPANY_NAME) AS company_name, LOWER(pp.COVENANT_STATUS) AS covenant_status, COUNT(*) AS row_count FROM PORTFOLIO_PERFORMANCE pp JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID WHERE LOWER(pp.COVENANT_STATUS) <> 'compliant' GROUP BY 1, 2 ) SELECT COUNT_IF(company_name = 'meridian specialty chemicals' AND covenant_status = 'waived') AS meridian_waived_groups, COUNT_IF(company_name <> 'meridian specialty chemicals' OR covenant_status <> 'waived') AS invalid_groups FROM covenant_exceptions """ ) meridian_groups, invalid_groups = cur.fetchone() meridian_ok = int(meridian_groups or 0) > 0 and int(invalid_groups or 0) == 0 checks.append( { "name": "pe_meridian_covenant_exception", "value": {"meridian_groups": int(meridian_groups or 0), "invalid_groups": int(invalid_groups or 0)}, "threshold": "Meridian only, status waived", "ok": meridian_ok, } ) if not meridian_ok: failures.append("Meridian Specialty Chemicals is not the sole waived/non-compliant covenant exception") cur.execute( """ WITH sector_perf AS ( SELECT s.SECTOR_NAME, AVG(pp.ENTRY_EV_EBITDA_MULTIPLE) AS avg_entry_multiple, AVG(pp.TOTAL_RETURN_MULTIPLE) AS avg_tvpi FROM PORTFOLIO_PERFORMANCE pp JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID GROUP BY 1 ) SELECT MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN avg_entry_multiple END) AS tech_entry, MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN 
avg_tvpi END) AS tech_tvpi, MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_entry_multiple END) AS other_entry_max, MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_tvpi END) AS other_tvpi_max FROM sector_perf """ ) tech_entry, tech_tvpi, other_entry_max, other_tvpi_max = cur.fetchone() tech_sector_ok = ( tech_entry is not None and tech_tvpi is not None and other_entry_max is not None and other_tvpi_max is not None and float(tech_entry) >= float(other_entry_max) and float(tech_tvpi) >= float(other_tvpi_max) ) checks.append( { "name": "pe_technology_sector_leads_multiples", "value": { "tech_entry": tech_entry, "tech_tvpi": tech_tvpi, "other_entry_max": other_entry_max, "other_tvpi_max": other_tvpi_max, }, "threshold": "Technology leads average entry and return multiples", "ok": tech_sector_ok, } ) if not tech_sector_ok: failures.append("Technology sector does not lead entry and return multiples as required by the State Street narrative") cur.execute( """ WITH vintage_rank AS ( SELECT VINTAGE_YEAR, SUM(REPORTED_VALUE_USD) AS total_reported_value, DENSE_RANK() OVER (ORDER BY SUM(REPORTED_VALUE_USD) DESC) AS value_rank FROM PORTFOLIO_PERFORMANCE GROUP BY 1 ) SELECT LISTAGG(TO_VARCHAR(VINTAGE_YEAR), ',') WITHIN GROUP (ORDER BY VINTAGE_YEAR) FROM vintage_rank WHERE value_rank <= 2 """ ) top_vintages = cur.fetchone()[0] or "" top_vintage_set = {part.strip() for part in str(top_vintages).split(",") if part.strip()} vintage_ok = top_vintage_set == {"2021", "2022"} checks.append( { "name": "pe_top_vintages_reported_value", "value": sorted(top_vintage_set), "threshold": "top 2 vintages are 2021 and 2022", "ok": vintage_ok, } ) if not vintage_ok: failures.append("2021 and 2022 are not the top reported-value vintages") cur.execute( """ WITH healthcare_quarters AS ( SELECT DATE_TRUNC('quarter', pp.FULL_DATE) AS quarter_start, AVG(pp.TOTAL_VALUE_USD) AS avg_total_value FROM PORTFOLIO_PERFORMANCE pp JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID 
JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID WHERE LOWER(s.SECTOR_NAME) = 'healthcare' GROUP BY 1 ) SELECT MAX(CASE WHEN quarter_start = DATE '2024-07-01' THEN avg_total_value END) AS q3_2024_value, MAX(CASE WHEN quarter_start = DATE '2024-10-01' THEN avg_total_value END) AS q4_2024_value FROM healthcare_quarters """ ) q3_2024_value, q4_2024_value = cur.fetchone() healthcare_dip_ok = ( q3_2024_value is not None and q4_2024_value is not None and float(q4_2024_value) < float(q3_2024_value) ) checks.append( { "name": "pe_healthcare_q4_2024_dip", "value": {"q3_2024": q3_2024_value, "q4_2024": q4_2024_value}, "threshold": "Q4 2024 healthcare total value lower than Q3 2024", "ok": healthcare_dip_ok, } ) if not healthcare_dip_ok: failures.append("Healthcare Q4 2024 performance dip is missing") cur.execute( """ WITH company_trends AS ( SELECT pc.COMPANY_NAME, FIRST_VALUE(pp.REVENUE_USD) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_revenue, LAST_VALUE(pp.REVENUE_USD) OVER ( PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS last_revenue, FIRST_VALUE(pp.EBITDA_MARGIN_PCT) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_margin, LAST_VALUE(pp.EBITDA_MARGIN_PCT) OVER ( PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING ) AS last_margin FROM PORTFOLIO_PERFORMANCE pp JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID ) SELECT COUNT(DISTINCT COMPANY_NAME) FROM company_trends WHERE last_revenue > first_revenue AND last_margin < first_margin """ ) trend_company_count = int(cur.fetchone()[0] or 0) trend_ok = trend_company_count >= 1 checks.append( { "name": "pe_revenue_up_margin_down_company_count", "value": trend_company_count, "threshold": ">= 1", "ok": trend_ok, } ) if not trend_ok: failures.append("No portfolio company shows the required revenue-up / EBITDA-margin-down trend") if is_saas_finance: 
cur.execute("SELECT COUNT(DISTINCT MONTH_KEY) FROM DATES") month_count = int(cur.fetchone()[0] or 0) checks.append( { "name": "saas_month_count", "value": month_count, "threshold": ">= 24", "ok": month_count >= 24, } ) if month_count < 24: failures.append(f"SaaS finance month count too low: {month_count} (expected >= 24)") cur.execute("SELECT COUNT(DISTINCT SEGMENT) FROM CUSTOMERS WHERE SEGMENT IS NOT NULL") segment_count = int(cur.fetchone()[0] or 0) checks.append( { "name": "saas_segment_distinct_count", "value": segment_count, "threshold": ">= 3", "ok": segment_count >= 3, } ) if segment_count < 3: failures.append(f"SaaS finance segment count too low: {segment_count} (expected >= 3)") cur.execute("SELECT COUNT(DISTINCT REGION) FROM LOCATIONS WHERE REGION IS NOT NULL") region_count = int(cur.fetchone()[0] or 0) checks.append( { "name": "saas_region_distinct_count", "value": region_count, "threshold": ">= 3", "ok": region_count >= 3, } ) if region_count < 3: failures.append(f"SaaS finance region count too low: {region_count} (expected >= 3)") cur.execute( """ SELECT COUNT(*) AS total_rows, COUNT_IF( ABS( ENDING_ARR_USD - ( STARTING_ARR_USD + NEW_LOGO_ARR_USD + EXPANSION_ARR_USD - CONTRACTION_ARR_USD - CHURNED_ARR_USD ) ) > 1.0 ) AS bad_arr_rows, COUNT_IF(ABS((MRR_USD * 12.0) - ENDING_ARR_USD) > 12.0) AS bad_mrr_rows FROM SAAS_CUSTOMER_MONTHLY """ ) total_rows, bad_arr_rows, bad_mrr_rows = cur.fetchone() total_rows = int(total_rows or 0) bad_arr_rows = int(bad_arr_rows or 0) bad_mrr_rows = int(bad_mrr_rows or 0) arr_identity_ok = total_rows > 0 and bad_arr_rows == 0 and bad_mrr_rows == 0 checks.append( { "name": "saas_arr_rollforward_bad_rows", "value": {"total": total_rows, "bad_arr": bad_arr_rows, "bad_mrr": bad_mrr_rows}, "threshold": "all rows reconcile", "ok": arr_identity_ok, } ) if not arr_identity_ok: failures.append( f"SaaS finance ARR identities broken (bad_arr={bad_arr_rows}, bad_mrr={bad_mrr_rows})" ) cur.execute( """ WITH month_counts AS ( SELECT 
CUSTOMER_KEY, COUNT(DISTINCT MONTH_KEY) AS active_months FROM SAAS_CUSTOMER_MONTHLY GROUP BY 1 ) SELECT AVG(active_months), MIN(active_months), MAX(active_months) FROM month_counts """ ) avg_active_months, min_active_months, max_active_months = cur.fetchone() avg_active_months = float(avg_active_months or 0.0) min_active_months = int(min_active_months or 0) max_active_months = int(max_active_months or 0) density_ok = avg_active_months >= 12.0 and max_active_months >= 20 checks.append( { "name": "saas_customer_month_density", "value": { "avg_active_months": round(avg_active_months, 2), "min_active_months": min_active_months, "max_active_months": max_active_months, }, "threshold": "avg >= 12.0 and max >= 20", "ok": density_ok, } ) if not density_ok: failures.append( f"SaaS finance customer-month density too sparse (avg={avg_active_months:.2f}, max={max_active_months})" ) cur.execute( """ SELECT COUNT(*) AS total_rows, COUNT_IF(ABS(TOTAL_S_AND_M_SPEND_USD - (SALES_SPEND_USD + MARKETING_SPEND_USD)) > 1.0) AS bad_rows FROM SALES_MARKETING_SPEND_MONTHLY """ ) spend_total_rows, bad_spend_rows = cur.fetchone() spend_total_rows = int(spend_total_rows or 0) bad_spend_rows = int(bad_spend_rows or 0) spend_ok = spend_total_rows > 0 and bad_spend_rows == 0 checks.append( { "name": "saas_spend_identity_bad_rows", "value": {"total": spend_total_rows, "bad_rows": bad_spend_rows}, "threshold": "== 0", "ok": spend_ok, } ) if not spend_ok: failures.append(f"SaaS finance spend identity broken in {bad_spend_rows} rows") except Exception as exc: # noqa: BLE001 failures.append(f"Realism sanity checks failed to execute: {exc}") finally: try: if cur is not None: cur.close() except Exception: pass try: if conn is not None: conn.close() except Exception: pass return {"ok": len(failures) == 0, "checks": checks, "failures": failures} def _run_case_chat( case: dict[str, Any], default_llm: str, user_email: str, skip_thoughtspot: bool = False, ) -> dict[str, Any]: from chat_interface import 
ChatDemoInterface from demo_personas import get_use_case_config, parse_use_case company = case["company"] use_case = case["use_case"] model = default_llm context = case.get("context", "") controller = ChatDemoInterface(user_email=user_email) controller.settings["model"] = model controller.vertical, controller.function = parse_use_case(use_case or "") controller.use_case_config = get_use_case_config( controller.vertical or "Generic", controller.function or "Generic", ) result: dict[str, Any] = { "name": case.get("name") or f"{company}_{use_case}", "company": company, "use_case": use_case, "mode": "chat", "started_at": _now_utc_iso(), "success": False, "stages": {}, } stage_start = datetime.now(timezone.utc) last_research = None for update in controller.run_research_streaming(company, use_case, generic_context=context): last_research = update result["stages"]["research"] = { "ok": bool(controller.demo_builder and controller.demo_builder.company_analysis_results), "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "preview": str(last_research)[:500] if last_research else "", } stage_start = datetime.now(timezone.utc) ddl_text = (controller.demo_builder.schema_generation_results or "") if controller.demo_builder else "" if not ddl_text: _, ddl_text = controller.run_ddl_creation() result["stages"]["ddl"] = { "ok": bool(ddl_text and "CREATE TABLE" in ddl_text.upper()), "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "ddl_length": len(ddl_text or ""), } if not result["stages"]["ddl"]["ok"]: result["error"] = "DDL generation failed" result["finished_at"] = _now_utc_iso() return result stage_start = datetime.now(timezone.utc) deploy_error = None try: for _ in controller.run_deployment_streaming(): pass except Exception as exc: # noqa: BLE001 deploy_error = str(exc) deployed_schema = getattr(controller, "_deployed_schema_name", None) schema_candidate = deployed_schema or getattr(controller, "_last_schema_name", None) 
result["stages"]["deploy_snowflake"] = { "ok": bool(deployed_schema) and deploy_error is None, "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "schema": schema_candidate, "error": deploy_error, } if schema_candidate: quality_report_path = getattr(controller, "_last_population_quality_report_path", None) result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path) if not result["stages"]["quality_gate"]["ok"]: result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}" elif deploy_error and not result.get("error"): result["error"] = deploy_error stage_start = datetime.now(timezone.utc) if result["stages"]["quality_gate"]["ok"]: sanity = _run_realism_sanity_checks(schema_candidate, case) result["stages"]["realism_sanity"] = { "ok": bool(sanity.get("ok")), "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "checks": sanity.get("checks", []), "failures": sanity.get("failures", []), } if not result["stages"]["realism_sanity"]["ok"] and not result.get("error"): result["error"] = "Realism sanity checks failed before ThoughtSpot deployment" if not skip_thoughtspot and deployed_schema and not result.get("error"): stage_start = datetime.now(timezone.utc) ts_ok = True ts_last = None try: for ts_update in controller._run_thoughtspot_deployment(deployed_schema, company, use_case): ts_last = ts_update except Exception as exc: # noqa: BLE001 ts_ok = False ts_last = str(exc) # Some deployment paths return a structured failure payload rather than # raising; treat those as failures so pass/fail reporting is accurate. 
ts_preview_text = str(ts_last) if ts_last is not None else "" if ts_ok and ( "THOUGHTSPOT DEPLOYMENT FAILED" in ts_preview_text.upper() or "MODEL VALIDATION FAILED" in ts_preview_text.upper() or "LIVEBOARD CREATION FAILED" in ts_preview_text.upper() ): ts_ok = False result["stages"]["deploy_thoughtspot"] = { "ok": ts_ok, "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "preview": ts_preview_text[:1000], } result["schema_name"] = schema_candidate result["success"] = all(stage.get("ok") for stage in result["stages"].values()) result["finished_at"] = _now_utc_iso() return result def _run_case_offline( case: dict[str, Any], default_llm: str, user_email: str, skip_thoughtspot: bool = False, ) -> dict[str, Any]: from cdw_connector import SnowflakeDeployer from demo_prep import generate_demo_base_name from legitdata_bridge import populate_demo_data from thoughtspot_deployer import deploy_to_thoughtspot company = case["company"] use_case = case["use_case"] result: dict[str, Any] = { "name": case.get("name") or f"{company}_{use_case}", "company": company, "use_case": use_case, "mode": "offline_ddl", "started_at": _now_utc_iso(), "success": False, "stages": {}, "ddl_template": "offline_star_schema_v1", } deployer = SnowflakeDeployer() # 1) Snowflake schema + DDL deploy stage_start = datetime.now(timezone.utc) ok, msg = deployer.connect() if not ok: result["stages"]["snowflake_connect"] = { "ok": False, "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "message": msg, } result["error"] = msg result["finished_at"] = _now_utc_iso() return result base_name = generate_demo_base_name("", company) ok, schema_name, ddl_msg = deployer.create_demo_schema_and_deploy(base_name, OFFLINE_DEMO_DDL) result["stages"]["snowflake_ddl"] = { "ok": ok, "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "schema": schema_name, "message": ddl_msg, } if not ok or not schema_name: result["error"] = ddl_msg result["finished_at"] = 
_now_utc_iso() return result # 2) Data population via LegitData stage_start = datetime.now(timezone.utc) pop_ok, pop_msg, pop_results = populate_demo_data( ddl_content=OFFLINE_DEMO_DDL, company_url=company, use_case=use_case, schema_name=schema_name, llm_model=default_llm, user_email=user_email, size="medium", ) result["stages"]["populate_data"] = { "ok": pop_ok, "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "rows": pop_results, "quality_report": _parse_quality_report_path(pop_msg), } if not pop_ok: result["error"] = pop_msg result["finished_at"] = _now_utc_iso() return result quality_report_path = _parse_quality_report_path(pop_msg) result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path) if not result["stages"]["quality_gate"]["ok"]: result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}" result["schema_name"] = schema_name result["finished_at"] = _now_utc_iso() return result stage_start = datetime.now(timezone.utc) sanity = _run_realism_sanity_checks(schema_name, case) result["stages"]["realism_sanity"] = { "ok": bool(sanity.get("ok")), "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(), "checks": sanity.get("checks", []), "failures": sanity.get("failures", []), } if not result["stages"]["realism_sanity"]["ok"]: result["error"] = "Realism sanity checks failed before ThoughtSpot deployment" result["schema_name"] = schema_name result["finished_at"] = _now_utc_iso() return result # 3) ThoughtSpot model + liveboard if not skip_thoughtspot: stage_start = datetime.now(timezone.utc) ts_result = deploy_to_thoughtspot( ddl=OFFLINE_DEMO_DDL, database=os.getenv("SNOWFLAKE_DATABASE", "DEMOBUILD"), schema=schema_name, base_name=base_name, connection_name=f"{base_name}_conn", company_name=company, use_case=use_case, llm_model=default_llm, ) result["stages"]["deploy_thoughtspot"] = { "ok": bool(ts_result and not ts_result.get("errors")), "duration_s": 
def _artifact_filename(name: Any) -> str:
    """Return a ``<name>.json`` file name that is safe to create in one directory.

    Path separators, characters reserved on Windows and control characters
    are replaced with ``_`` so a case name such as a company URL cannot escape
    the output directory or make the write fail. Names that contain none of
    those characters are kept unchanged.
    """
    unsafe = set('/\\:*?"<>|')
    cleaned = "".join(
        "_" if ch in unsafe or ord(ch) < 32 else ch for ch in str(name or "")
    ).strip()
    if cleaned in {"", ".", ".."}:
        cleaned = "case"
    return f"{cleaned}.json"


def main() -> None:
    """Run every case in the cases file and write per-case and summary JSON.

    Command-line flags select the cases file, offline DDL mode and whether
    the ThoughtSpot stage is skipped. Required settings are validated before
    any case runs. Each case result is written to
    ``results/newvision_samples/<run_id>/<case name>.json`` and the aggregate
    to ``summary.json`` in the same directory.

    Raises:
        SystemExit: With code 1 when at least one case did not succeed.
    """
    parser = argparse.ArgumentParser(description="Run New Vision sample set")
    parser.add_argument(
        "--cases-file",
        default="tests/newvision_test_cases.yaml",
        help="Path to YAML test case file",
    )
    parser.add_argument(
        "--skip-thoughtspot",
        action="store_true",
        help="Run through data generation only and skip ThoughtSpot object creation",
    )
    parser.add_argument(
        "--offline-ddl",
        action="store_true",
        help="Force offline DDL mode (no LLM dependency)",
    )
    args = parser.parse_args()

    user_email, default_llm = _resolve_runtime_settings()

    from startup_validation import validate_required_pipeline_settings_or_raise

    validate_required_pipeline_settings_or_raise(
        default_llm=default_llm,
        require_thoughtspot=not args.skip_thoughtspot,
        require_snowflake=True,
    )

    cases_file = (PROJECT_ROOT / args.cases_file).resolve()
    cases = _load_cases(cases_file)

    run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    out_dir = PROJECT_ROOT / "results" / "newvision_samples" / run_id
    out_dir.mkdir(parents=True, exist_ok=True)

    use_offline = bool(args.offline_ddl)
    mode = "offline_ddl" if use_offline else "chat"
    run_case = _run_case_offline if use_offline else _run_case_chat
    print(f"Mode: {mode}", flush=True)
    print(f"default_llm: {default_llm}", flush=True)

    results = []
    for idx, case in enumerate(cases, start=1):
        # Use .get so a case missing a key is reported as a failed case
        # below instead of crashing the whole run.
        company = case.get("company")
        use_case = case.get("use_case")
        fallback_name = f"{company}_{use_case}"
        print(
            f"\n[{idx}/{len(cases)}] {case.get('name', company)} -> {use_case}",
            flush=True,
        )

        # Captured before the case runs so a failure record keeps the real
        # start time rather than the time of the exception.
        started_at = _now_utc_iso()
        try:
            case_result = run_case(
                case,
                default_llm=default_llm,
                user_email=user_email,
                skip_thoughtspot=args.skip_thoughtspot,
            )
        except Exception as exc:  # noqa: BLE001
            case_result = {
                "name": case.get("name") or fallback_name,
                "company": company,
                "use_case": use_case,
                "mode": mode,
                "started_at": started_at,
                "finished_at": _now_utc_iso(),
                "success": False,
                "error": f"Runner exception: {exc}",
                "stages": {
                    "runner_exception": {
                        "ok": False,
                        "message": str(exc),
                    }
                },
            }

        results.append(case_result)
        artifact_path = out_dir / _artifact_filename(
            case_result.get("name") or fallback_name
        )
        # default=str: stage payloads come from external deployers and may
        # hold values json cannot encode; stringify them rather than lose
        # the artifact and abort the remaining cases.
        artifact_path.write_text(
            json.dumps(case_result, indent=2, default=str),
            encoding="utf-8",
        )
        print(
            f"  success={case_result.get('success')} schema={case_result.get('schema_name')}",
            flush=True,
        )

    summary = {
        "run_id": run_id,
        "mode": mode,
        "cases_file": str(cases_file),
        "total": len(results),
        "passed": sum(1 for r in results if r.get("success")),
        "failed": sum(1 for r in results if not r.get("success")),
        "results": results,
    }
    (out_dir / "summary.json").write_text(
        json.dumps(summary, indent=2, default=str),
        encoding="utf-8",
    )

    print("\nSaved sample artifacts:", out_dir)
    print(f"Passed: {summary['passed']} / {summary['total']}")

    if summary["failed"]:
        raise SystemExit(1)


if __name__ == "__main__":
    main()