# demoprep/tests/newvision_sample_runner.py
# NOTE(review): the three lines below are artifacts scraped from the hosting UI
# (author avatar caption, commit message, commit SHA). Comment-ified so the file
# parses; the shebang on the next line no longer sits on line 1 — confirm whether
# this header should be removed entirely so the shebang works when executed directly.
# mikeboone's picture
# feat: March 2026 sprint — new vision merge, pipeline improvements, settings refactor
# 5ac32c1
#!/usr/bin/env python3
"""Run the New Vision 4-case sample set without UI login flow.
Modes:
1) Full chat pipeline mode (uses configured default_llm and required settings)
2) Offline DDL mode (deterministic schema template, still validates settings up front)
Usage:
source ./demoprep/bin/activate
python tests/newvision_sample_runner.py
python tests/newvision_sample_runner.py --offline-ddl
python tests/newvision_sample_runner.py --skip-thoughtspot
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import yaml
# Make the project root importable when this script is run from tests/ directly.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from dotenv import load_dotenv

# Load developer-local settings before anything else reads the environment.
load_dotenv(PROJECT_ROOT / ".env")
# Bypass the interactive UI login flow for headless runs (only sets if unset).
os.environ.setdefault("DEMOPREP_NO_AUTH", "true")
# Pull admin settings into environment when available.
try:
    from supabase_client import inject_admin_settings_to_env
    inject_admin_settings_to_env()
except Exception as exc:  # noqa: BLE001
    # Best-effort: the runner can still operate without Supabase-managed settings.
    print(f"[newvision_runner] Admin setting injection unavailable: {exc}")
OFFLINE_DEMO_DDL = """
CREATE TABLE DIM_DATE (
DATE_KEY INT PRIMARY KEY,
ORDER_DATE DATE,
MONTH_NAME VARCHAR(30),
QUARTER_NAME VARCHAR(10),
YEAR_NUM INT,
IS_WEEKEND BOOLEAN
);
CREATE TABLE DIM_LOCATION (
LOCATION_KEY INT PRIMARY KEY,
COUNTRY VARCHAR(100),
REGION VARCHAR(100),
STATE VARCHAR(100),
CITY VARCHAR(100),
SALES_CHANNEL VARCHAR(100),
CUSTOMER_SEGMENT VARCHAR(100)
);
CREATE TABLE DIM_PRODUCT (
PRODUCT_KEY INT PRIMARY KEY,
PRODUCT_NAME VARCHAR(200),
BRAND_NAME VARCHAR(100),
CATEGORY VARCHAR(100),
SUB_CATEGORY VARCHAR(100),
PRODUCT_TIER VARCHAR(50),
UNIT_PRICE DECIMAL(12,2)
);
CREATE TABLE FACT_RETAIL_DAILY (
TRANSACTION_KEY INT PRIMARY KEY,
DATE_KEY INT,
LOCATION_KEY INT,
PRODUCT_KEY INT,
ORDER_DATE DATE,
ORDER_COUNT INT,
UNITS_SOLD INT,
UNIT_PRICE DECIMAL(12,2),
GROSS_REVENUE DECIMAL(14,2),
NET_REVENUE DECIMAL(14,2),
SALES_AMOUNT DECIMAL(14,2),
DISCOUNT_PCT DECIMAL(5,2),
INVENTORY_ON_HAND INT,
LOST_SALES_USD DECIMAL(14,2),
IS_OOS BOOLEAN,
FOREIGN KEY (DATE_KEY) REFERENCES DIM_DATE(DATE_KEY),
FOREIGN KEY (LOCATION_KEY) REFERENCES DIM_LOCATION(LOCATION_KEY),
FOREIGN KEY (PRODUCT_KEY) REFERENCES DIM_PRODUCT(PRODUCT_KEY)
);
""".strip()
def _now_utc_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _load_cases(cases_file: Path) -> list[dict[str, Any]]:
    """Read the YAML case file and return its ``test_cases`` entries as a list."""
    raw_text = cases_file.read_text(encoding="utf-8")
    parsed = yaml.safe_load(raw_text) or {}
    return list(parsed.get("test_cases", []))
def _parse_quality_report_path(message: str) -> str | None:
for line in (message or "").splitlines():
if "Report:" in line:
return line.split("Report:", 1)[1].strip()
if "See report:" in line:
return line.split("See report:", 1)[1].strip()
return None
def _load_quality_report(report_path: str | None) -> dict[str, Any]:
if not report_path:
return {}
path = Path(report_path)
json_path = path.with_suffix(".json")
if json_path.exists():
try:
return json.loads(json_path.read_text(encoding="utf-8"))
except Exception: # noqa: BLE001
return {}
return {}
def _build_quality_gate_stage(report_path: str | None) -> dict[str, Any]:
    """Assemble the quality-gate stage record from a report path.

    Loads the JSON report (if any) and flattens the metrics of interest into
    a fixed-shape stage dict. A missing report yields ok=False with all
    summary values set to None.
    """
    report = _load_quality_report(report_path)
    if isinstance(report, dict):
        summary_source = report.get("summary", {})
        passed = report.get("passed")
    else:
        summary_source = {}
        passed = None
    metric_keys = (
        "semantic_pass_ratio",
        "categorical_junk_count",
        "fk_orphan_count",
        "temporal_violations",
        "numeric_violations",
        "volatility_breaches",
        "smoothness_score",
        "outlier_explainability",
        "kpi_consistency",
    )
    return {
        "ok": bool(passed) if passed is not None else False,
        "report_path": report_path,
        "passed": passed,
        "summary": {key: summary_source.get(key) for key in metric_keys},
    }
def _resolve_runtime_settings() -> tuple[str, str]:
user_email = (
os.getenv("USER_EMAIL")
or os.getenv("INITIAL_USER")
or os.getenv("THOUGHTSPOT_ADMIN_USER")
or "default@user.com"
).strip()
default_llm = (os.getenv("DEFAULT_LLM") or os.getenv("OPENAI_MODEL") or "").strip()
if not default_llm:
raise ValueError("Missing required env var: DEFAULT_LLM or OPENAI_MODEL")
return user_email, default_llm
def _run_realism_sanity_checks(schema_name: str, case: dict[str, Any]) -> dict[str, Any]:
    """Fast, opinionated sanity checks for demo realism.

    These checks intentionally target user-visible demo breakages that can slip
    through structural quality gates (e.g., null dimensions in top-N charts).

    Args:
        schema_name: Snowflake schema holding the freshly populated demo data.
        case: Test-case definition; ``use_case`` selects which vertical's
            checks run, ``name`` gates narrative-specific checks.

    Returns:
        Dict with ``ok`` (True only when no failure was recorded), ``checks``
        (per-check records of name/value/threshold/ok), and ``failures``
        (human-readable failure strings).
    """
    checks: list[dict[str, Any]] = []
    failures: list[str] = []
    if not schema_name:
        return {"ok": False, "checks": checks, "failures": ["Missing schema name for sanity checks"]}
    use_case = str(case.get("use_case", "") or "")
    use_case_lower = (use_case or "").lower()
    case_name = str(case.get("name", "") or "").lower()
    # Vertical detection is keyword-based on the use-case description.
    is_legal = "legal" in use_case_lower
    is_private_equity = any(
        marker in use_case_lower
        for marker in ("private equity", "lp reporting", "state street")
    )
    is_saas_finance = any(
        marker in use_case_lower
        for marker in ("saas finance", "unit economics", "financial analytics", "fp&a", "fpa")
    )
    if not is_legal and not is_private_equity and not is_saas_finance:
        # Keep runtime fast by evaluating only scoped vertical checks.
        return {"ok": True, "checks": checks, "failures": []}
    # Deferred imports keep module load light when no checks run.
    from supabase_client import inject_admin_settings_to_env
    from snowflake_auth import get_snowflake_connection
    inject_admin_settings_to_env()
    conn = None
    cur = None
    try:
        db_name = (os.getenv("SNOWFLAKE_DATABASE") or "DEMOBUILD").strip()
        # Strip double quotes so the quoted identifier below stays valid.
        safe_schema = schema_name.replace('"', "")
        conn = get_snowflake_connection()
        cur = conn.cursor()
        cur.execute(f'USE DATABASE "{db_name}"')
        cur.execute(f'USE SCHEMA "{safe_schema}"')
        # ------------------------------------------------------------------
        # Legal vertical: two supported schema shapes (split tables vs. a
        # single spend-event table), each running the same four check themes.
        # ------------------------------------------------------------------
        if is_legal:
            cur.execute("SHOW TABLES")
            # Column 1 of SHOW TABLES output is the table name.
            legal_tables = {str(row[1]).upper() for row in cur.fetchall()}
            has_split_legal = {"LEGAL_MATTERS", "OUTSIDE_COUNSEL_INVOICES", "ATTORNEYS", "MATTER_TYPES"}.issubset(legal_tables)
            has_event_legal = "LEGAL_SPEND_EVENTS" in legal_tables
            if has_split_legal:
                # 1) Invoice -> matter -> attorney join coverage.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows
                    FROM OUTSIDE_COUNSEL_INVOICES oci
                    LEFT JOIN LEGAL_MATTERS lm ON oci.MATTER_ID = lm.MATTER_ID
                    LEFT JOIN ATTORNEYS a ON lm.ASSIGNED_ATTORNEY_ID = a.ATTORNEY_ID
                    """
                )
                total_rows, null_rows = cur.fetchone()
                # An empty table counts as 100% null so the check fails loudly.
                null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
                checks.append(
                    {
                        "name": "legal_attorney_join_null_pct",
                        "value": round(null_pct, 2),
                        "threshold": "<= 5.0",
                        "ok": null_pct <= 5.0,
                    }
                )
                if null_pct > 5.0:
                    failures.append(
                        f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)"
                    )
                # 1b) Invoice MATTER_ID linkage must be complete.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(MATTER_ID IS NULL) AS null_rows
                    FROM OUTSIDE_COUNSEL_INVOICES
                    """
                )
                total_rows, null_rows = cur.fetchone()
                null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
                checks.append(
                    {
                        "name": "legal_invoice_matter_id_null_pct",
                        "value": round(null_pct, 2),
                        "threshold": "== 0.0",
                        "ok": null_pct == 0.0,
                    }
                )
                if null_pct != 0.0:
                    failures.append(
                        f"Invoice MATTER_ID null rate is {null_pct:.2f}% (expected 0%)"
                    )
                # 2) Region cardinality should be compact for legal executive demos.
                cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_MATTERS WHERE REGION IS NOT NULL")
                region_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_region_distinct_count",
                        "value": region_cardinality,
                        "threshold": "<= 6",
                        "ok": region_cardinality <= 6,
                    }
                )
                if region_cardinality > 6:
                    failures.append(
                        f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)"
                    )
                # 3) Firm names should not contain obvious cross-vertical banking/org jargon.
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM OUTSIDE_COUNSEL
                    WHERE REGEXP_LIKE(
                        LOWER(FIRM_NAME),
                        'retail banking|consumer lending|digital channels|enterprise operations|regional service'
                    )
                    """
                )
                bad_firm_count = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_firm_name_cross_vertical_count",
                        "value": bad_firm_count,
                        "threshold": "== 0",
                        "ok": bad_firm_count == 0,
                    }
                )
                if bad_firm_count != 0:
                    failures.append(
                        f"Detected {bad_firm_count} cross-vertical/non-legal firm names"
                    )
                # 4) Matter type taxonomy should remain concise and demo-friendly.
                cur.execute(
                    """
                    SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME)
                    FROM LEGAL_MATTERS lm
                    LEFT JOIN MATTER_TYPES mt ON lm.MATTER_TYPE_ID = mt.MATTER_TYPE_ID
                    WHERE mt.MATTER_TYPE_NAME IS NOT NULL
                    """
                )
                matter_type_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_matter_type_distinct_count",
                        "value": matter_type_cardinality,
                        "threshold": "<= 15",
                        "ok": matter_type_cardinality <= 15,
                    }
                )
                if matter_type_cardinality > 15:
                    failures.append(
                        f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)"
                    )
            elif has_event_legal:
                # Event-table variant: same themes, sourced from LEGAL_SPEND_EVENTS.
                # 1) Attorney dimension join coverage (critical for "Top Attorney by Cost").
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows
                    FROM LEGAL_SPEND_EVENTS lse
                    LEFT JOIN ATTORNEYS a ON lse.ATTORNEY_ID = a.ATTORNEY_ID
                    """
                )
                total_rows, null_rows = cur.fetchone()
                null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
                checks.append(
                    {
                        "name": "legal_attorney_join_null_pct",
                        "value": round(null_pct, 2),
                        "threshold": "<= 5.0",
                        "ok": null_pct <= 5.0,
                    }
                )
                if null_pct > 5.0:
                    failures.append(
                        f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)"
                    )
                # 2) Region cardinality should be compact for legal executive demos.
                cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_SPEND_EVENTS WHERE REGION IS NOT NULL")
                region_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_region_distinct_count",
                        "value": region_cardinality,
                        "threshold": "<= 6",
                        "ok": region_cardinality <= 6,
                    }
                )
                if region_cardinality > 6:
                    failures.append(
                        f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)"
                    )
                # 3) Firm names should not contain obvious cross-vertical banking/org jargon.
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM OUTSIDE_COUNSEL_FIRMS
                    WHERE REGEXP_LIKE(
                        LOWER(FIRM_NAME),
                        'retail banking|consumer lending|digital channels|enterprise operations|regional service'
                    )
                    """
                )
                bad_firm_count = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_firm_name_cross_vertical_count",
                        "value": bad_firm_count,
                        "threshold": "== 0",
                        "ok": bad_firm_count == 0,
                    }
                )
                if bad_firm_count != 0:
                    failures.append(
                        f"Detected {bad_firm_count} cross-vertical/non-legal firm names"
                    )
                # 4) Matter type taxonomy should remain concise and demo-friendly.
                cur.execute(
                    """
                    SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME)
                    FROM LEGAL_SPEND_EVENTS lse
                    LEFT JOIN MATTER_TYPES mt ON lse.MATTER_TYPE_ID = mt.MATTER_TYPE_ID
                    WHERE mt.MATTER_TYPE_NAME IS NOT NULL
                    """
                )
                matter_type_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_matter_type_distinct_count",
                        "value": matter_type_cardinality,
                        "threshold": "<= 15",
                        "ok": matter_type_cardinality <= 15,
                    }
                )
                if matter_type_cardinality > 15:
                    failures.append(
                        f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)"
                    )
            else:
                failures.append("Could not find supported legal schema shape for realism checks")
        # ------------------------------------------------------------------
        # Private-equity vertical: dimension-labeling checks plus, for the
        # State Street case, narrative-specific invariants.
        # ------------------------------------------------------------------
        if is_private_equity:
            # Guard against semantic leakage where sector/strategy dimensions are
            # accidentally populated with company names.
            cur.execute(
                """
                WITH dim_companies AS (
                    SELECT DISTINCT COMPANY_NAME
                    FROM PORTFOLIO_COMPANIES
                    WHERE COMPANY_NAME IS NOT NULL
                ),
                dim_sectors AS (
                    SELECT DISTINCT SECTOR_NAME
                    FROM SECTORS
                    WHERE SECTOR_NAME IS NOT NULL
                ),
                dim_strategies AS (
                    SELECT DISTINCT FUND_STRATEGY
                    FROM FUNDS
                    WHERE FUND_STRATEGY IS NOT NULL
                )
                SELECT
                    (SELECT COUNT(*) FROM dim_sectors),
                    (SELECT COUNT(*) FROM dim_strategies),
                    (SELECT COUNT(*) FROM dim_sectors s JOIN dim_companies c ON s.SECTOR_NAME = c.COMPANY_NAME),
                    (SELECT COUNT(*) FROM dim_strategies f JOIN dim_companies c ON f.FUND_STRATEGY = c.COMPANY_NAME)
                """
            )
            sector_distinct, strategy_distinct, sector_overlap, strategy_overlap = cur.fetchone()
            sector_distinct = int(sector_distinct or 0)
            strategy_distinct = int(strategy_distinct or 0)
            sector_overlap = int(sector_overlap or 0)
            strategy_overlap = int(strategy_overlap or 0)
            checks.append(
                {
                    "name": "pe_sector_name_company_overlap_count",
                    "value": sector_overlap,
                    "threshold": "== 0",
                    "ok": sector_overlap == 0,
                }
            )
            if sector_overlap != 0:
                failures.append(
                    f"Sector names overlap company names ({sector_overlap} overlaps); likely mislabeled dimensions"
                )
            checks.append(
                {
                    "name": "pe_fund_strategy_company_overlap_count",
                    "value": strategy_overlap,
                    "threshold": "== 0",
                    "ok": strategy_overlap == 0,
                }
            )
            if strategy_overlap != 0:
                failures.append(
                    f"Fund strategy values overlap company names ({strategy_overlap} overlaps); likely mislabeled dimensions"
                )
            checks.append(
                {
                    "name": "pe_sector_distinct_count",
                    "value": sector_distinct,
                    "threshold": ">= 4 and <= 20",
                    "ok": 4 <= sector_distinct <= 20,
                }
            )
            if not (4 <= sector_distinct <= 20):
                failures.append(
                    f"Sector distinct count out of expected demo range: {sector_distinct} (expected 4-20)"
                )
            checks.append(
                {
                    "name": "pe_fund_strategy_distinct_count",
                    "value": strategy_distinct,
                    "threshold": ">= 4 and <= 20",
                    "ok": 4 <= strategy_distinct <= 20,
                }
            )
            if not (4 <= strategy_distinct <= 20):
                failures.append(
                    f"Fund strategy distinct count out of expected demo range: {strategy_distinct} (expected 4-20)"
                )
            # Narrative-specific invariants for the State Street LP-reporting case.
            if case_name == "statestreet_private_equity_lp_reporting":
                # TOTAL_VALUE must equal REPORTED_VALUE + DISTRIBUTIONS per row.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(ABS(TOTAL_VALUE_USD - (REPORTED_VALUE_USD + DISTRIBUTIONS_USD)) > 0.01) AS bad_rows
                    FROM PORTFOLIO_PERFORMANCE
                    """
                )
                total_rows, bad_rows = cur.fetchone()
                total_rows = int(total_rows or 0)
                bad_rows = int(bad_rows or 0)
                identity_ok = total_rows > 0 and bad_rows == 0
                checks.append(
                    {
                        "name": "pe_total_value_identity_bad_rows",
                        "value": bad_rows,
                        "threshold": "== 0",
                        "ok": identity_ok,
                    }
                )
                if not identity_ok:
                    failures.append(
                        f"Total value identity broken in {bad_rows} PE fact rows"
                    )
                # Subscription-line impact must stay in band and match the IRR delta.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(IRR_SUB_LINE_IMPACT_BPS BETWEEN 80 AND 210) AS in_band_rows,
                        COUNT_IF(ABS(IRR_SUB_LINE_IMPACT_BPS - ((GROSS_IRR - GROSS_IRR_WITHOUT_SUB_LINE) * 10000)) <= 5) AS identity_rows
                    FROM PORTFOLIO_PERFORMANCE
                    """
                )
                total_rows, in_band_rows, identity_rows = cur.fetchone()
                total_rows = int(total_rows or 0)
                in_band_rows = int(in_band_rows or 0)
                identity_rows = int(identity_rows or 0)
                irr_band_ok = total_rows > 0 and in_band_rows == total_rows and identity_rows == total_rows
                checks.append(
                    {
                        "name": "pe_subscription_line_impact_rows_valid",
                        "value": {"total": total_rows, "in_band": in_band_rows, "identity": identity_rows},
                        "threshold": "all rows in 80-210 bps band and identity holds",
                        "ok": irr_band_ok,
                    }
                )
                if not irr_band_ok:
                    failures.append("Subscription line impact rows do not consistently satisfy PE IRR delta rules")
                # Apex Industrial Solutions must be the scripted 210 bps outlier.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS apex_rows,
                        MAX(pp.IRR_SUB_LINE_IMPACT_BPS) AS apex_max_bps,
                        (
                            SELECT MAX(IRR_SUB_LINE_IMPACT_BPS)
                            FROM PORTFOLIO_PERFORMANCE
                        ) AS overall_max_bps
                    FROM PORTFOLIO_PERFORMANCE pp
                    JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                    WHERE LOWER(pc.COMPANY_NAME) = 'apex industrial solutions'
                    """
                )
                apex_rows, apex_max_bps, overall_max_bps = cur.fetchone()
                apex_ok = int(apex_rows or 0) > 0 and apex_max_bps is not None and abs(float(apex_max_bps) - 210.0) <= 1.0 and overall_max_bps is not None and abs(float(overall_max_bps) - 210.0) <= 1.0
                checks.append(
                    {
                        "name": "pe_apex_subscription_line_outlier",
                        "value": {"rows": int(apex_rows or 0), "apex_max_bps": apex_max_bps, "overall_max_bps": overall_max_bps},
                        "threshold": "Apex exists and max impact == 210 bps",
                        "ok": apex_ok,
                    }
                )
                if not apex_ok:
                    failures.append("Apex Industrial Solutions outlier is missing or not set to the expected 210 bps impact")
                # Meridian must be the only non-compliant covenant, with status 'waived'.
                cur.execute(
                    """
                    WITH covenant_exceptions AS (
                        SELECT
                            LOWER(pc.COMPANY_NAME) AS company_name,
                            LOWER(pp.COVENANT_STATUS) AS covenant_status,
                            COUNT(*) AS row_count
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                        WHERE LOWER(pp.COVENANT_STATUS) <> 'compliant'
                        GROUP BY 1, 2
                    )
                    SELECT
                        COUNT_IF(company_name = 'meridian specialty chemicals' AND covenant_status = 'waived') AS meridian_waived_groups,
                        COUNT_IF(company_name <> 'meridian specialty chemicals' OR covenant_status <> 'waived') AS invalid_groups
                    FROM covenant_exceptions
                    """
                )
                meridian_groups, invalid_groups = cur.fetchone()
                meridian_ok = int(meridian_groups or 0) > 0 and int(invalid_groups or 0) == 0
                checks.append(
                    {
                        "name": "pe_meridian_covenant_exception",
                        "value": {"meridian_groups": int(meridian_groups or 0), "invalid_groups": int(invalid_groups or 0)},
                        "threshold": "Meridian only, status waived",
                        "ok": meridian_ok,
                    }
                )
                if not meridian_ok:
                    failures.append("Meridian Specialty Chemicals is not the sole waived/non-compliant covenant exception")
                # Technology sector must lead both entry and return multiples.
                cur.execute(
                    """
                    WITH sector_perf AS (
                        SELECT
                            s.SECTOR_NAME,
                            AVG(pp.ENTRY_EV_EBITDA_MULTIPLE) AS avg_entry_multiple,
                            AVG(pp.TOTAL_RETURN_MULTIPLE) AS avg_tvpi
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                        JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID
                        GROUP BY 1
                    )
                    SELECT
                        MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN avg_entry_multiple END) AS tech_entry,
                        MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN avg_tvpi END) AS tech_tvpi,
                        MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_entry_multiple END) AS other_entry_max,
                        MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_tvpi END) AS other_tvpi_max
                    FROM sector_perf
                    """
                )
                tech_entry, tech_tvpi, other_entry_max, other_tvpi_max = cur.fetchone()
                tech_sector_ok = (
                    tech_entry is not None
                    and tech_tvpi is not None
                    and other_entry_max is not None
                    and other_tvpi_max is not None
                    and float(tech_entry) >= float(other_entry_max)
                    and float(tech_tvpi) >= float(other_tvpi_max)
                )
                checks.append(
                    {
                        "name": "pe_technology_sector_leads_multiples",
                        "value": {
                            "tech_entry": tech_entry,
                            "tech_tvpi": tech_tvpi,
                            "other_entry_max": other_entry_max,
                            "other_tvpi_max": other_tvpi_max,
                        },
                        "threshold": "Technology leads average entry and return multiples",
                        "ok": tech_sector_ok,
                    }
                )
                if not tech_sector_ok:
                    failures.append("Technology sector does not lead entry and return multiples as required by the State Street narrative")
                # 2021 and 2022 must be the two highest reported-value vintages.
                cur.execute(
                    """
                    WITH vintage_rank AS (
                        SELECT
                            VINTAGE_YEAR,
                            SUM(REPORTED_VALUE_USD) AS total_reported_value,
                            DENSE_RANK() OVER (ORDER BY SUM(REPORTED_VALUE_USD) DESC) AS value_rank
                        FROM PORTFOLIO_PERFORMANCE
                        GROUP BY 1
                    )
                    SELECT LISTAGG(TO_VARCHAR(VINTAGE_YEAR), ',') WITHIN GROUP (ORDER BY VINTAGE_YEAR)
                    FROM vintage_rank
                    WHERE value_rank <= 2
                    """
                )
                top_vintages = cur.fetchone()[0] or ""
                top_vintage_set = {part.strip() for part in str(top_vintages).split(",") if part.strip()}
                vintage_ok = top_vintage_set == {"2021", "2022"}
                checks.append(
                    {
                        "name": "pe_top_vintages_reported_value",
                        "value": sorted(top_vintage_set),
                        "threshold": "top 2 vintages are 2021 and 2022",
                        "ok": vintage_ok,
                    }
                )
                if not vintage_ok:
                    failures.append("2021 and 2022 are not the top reported-value vintages")
                # Healthcare must show a scripted Q4 2024 dip versus Q3 2024.
                cur.execute(
                    """
                    WITH healthcare_quarters AS (
                        SELECT
                            DATE_TRUNC('quarter', pp.FULL_DATE) AS quarter_start,
                            AVG(pp.TOTAL_VALUE_USD) AS avg_total_value
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                        JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID
                        WHERE LOWER(s.SECTOR_NAME) = 'healthcare'
                        GROUP BY 1
                    )
                    SELECT
                        MAX(CASE WHEN quarter_start = DATE '2024-07-01' THEN avg_total_value END) AS q3_2024_value,
                        MAX(CASE WHEN quarter_start = DATE '2024-10-01' THEN avg_total_value END) AS q4_2024_value
                    FROM healthcare_quarters
                    """
                )
                q3_2024_value, q4_2024_value = cur.fetchone()
                healthcare_dip_ok = (
                    q3_2024_value is not None
                    and q4_2024_value is not None
                    and float(q4_2024_value) < float(q3_2024_value)
                )
                checks.append(
                    {
                        "name": "pe_healthcare_q4_2024_dip",
                        "value": {"q3_2024": q3_2024_value, "q4_2024": q4_2024_value},
                        "threshold": "Q4 2024 healthcare total value lower than Q3 2024",
                        "ok": healthcare_dip_ok,
                    }
                )
                if not healthcare_dip_ok:
                    failures.append("Healthcare Q4 2024 performance dip is missing")
                # At least one company must show revenue up while EBITDA margin is down.
                cur.execute(
                    """
                    WITH company_trends AS (
                        SELECT
                            pc.COMPANY_NAME,
                            FIRST_VALUE(pp.REVENUE_USD) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_revenue,
                            LAST_VALUE(pp.REVENUE_USD) OVER (
                                PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE
                                ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                            ) AS last_revenue,
                            FIRST_VALUE(pp.EBITDA_MARGIN_PCT) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_margin,
                            LAST_VALUE(pp.EBITDA_MARGIN_PCT) OVER (
                                PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE
                                ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                            ) AS last_margin
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                    )
                    SELECT COUNT(DISTINCT COMPANY_NAME)
                    FROM company_trends
                    WHERE last_revenue > first_revenue AND last_margin < first_margin
                    """
                )
                trend_company_count = int(cur.fetchone()[0] or 0)
                trend_ok = trend_company_count >= 1
                checks.append(
                    {
                        "name": "pe_revenue_up_margin_down_company_count",
                        "value": trend_company_count,
                        "threshold": ">= 1",
                        "ok": trend_ok,
                    }
                )
                if not trend_ok:
                    failures.append("No portfolio company shows the required revenue-up / EBITDA-margin-down trend")
        # ------------------------------------------------------------------
        # SaaS finance vertical: history depth, dimension richness, and the
        # ARR/MRR and spend reconciliation identities.
        # ------------------------------------------------------------------
        if is_saas_finance:
            cur.execute("SELECT COUNT(DISTINCT MONTH_KEY) FROM DATES")
            month_count = int(cur.fetchone()[0] or 0)
            checks.append(
                {
                    "name": "saas_month_count",
                    "value": month_count,
                    "threshold": ">= 24",
                    "ok": month_count >= 24,
                }
            )
            if month_count < 24:
                failures.append(f"SaaS finance month count too low: {month_count} (expected >= 24)")
            cur.execute("SELECT COUNT(DISTINCT SEGMENT) FROM CUSTOMERS WHERE SEGMENT IS NOT NULL")
            segment_count = int(cur.fetchone()[0] or 0)
            checks.append(
                {
                    "name": "saas_segment_distinct_count",
                    "value": segment_count,
                    "threshold": ">= 3",
                    "ok": segment_count >= 3,
                }
            )
            if segment_count < 3:
                failures.append(f"SaaS finance segment count too low: {segment_count} (expected >= 3)")
            cur.execute("SELECT COUNT(DISTINCT REGION) FROM LOCATIONS WHERE REGION IS NOT NULL")
            region_count = int(cur.fetchone()[0] or 0)
            checks.append(
                {
                    "name": "saas_region_distinct_count",
                    "value": region_count,
                    "threshold": ">= 3",
                    "ok": region_count >= 3,
                }
            )
            if region_count < 3:
                failures.append(f"SaaS finance region count too low: {region_count} (expected >= 3)")
            # ARR roll-forward and MRR*12 == ARR identities (with small tolerances).
            cur.execute(
                """
                SELECT
                    COUNT(*) AS total_rows,
                    COUNT_IF(
                        ABS(
                            ENDING_ARR_USD - (
                                STARTING_ARR_USD + NEW_LOGO_ARR_USD + EXPANSION_ARR_USD
                                - CONTRACTION_ARR_USD - CHURNED_ARR_USD
                            )
                        ) > 1.0
                    ) AS bad_arr_rows,
                    COUNT_IF(ABS((MRR_USD * 12.0) - ENDING_ARR_USD) > 12.0) AS bad_mrr_rows
                FROM SAAS_CUSTOMER_MONTHLY
                """
            )
            total_rows, bad_arr_rows, bad_mrr_rows = cur.fetchone()
            total_rows = int(total_rows or 0)
            bad_arr_rows = int(bad_arr_rows or 0)
            bad_mrr_rows = int(bad_mrr_rows or 0)
            arr_identity_ok = total_rows > 0 and bad_arr_rows == 0 and bad_mrr_rows == 0
            checks.append(
                {
                    "name": "saas_arr_rollforward_bad_rows",
                    "value": {"total": total_rows, "bad_arr": bad_arr_rows, "bad_mrr": bad_mrr_rows},
                    "threshold": "all rows reconcile",
                    "ok": arr_identity_ok,
                }
            )
            if not arr_identity_ok:
                failures.append(
                    f"SaaS finance ARR identities broken (bad_arr={bad_arr_rows}, bad_mrr={bad_mrr_rows})"
                )
            # Customers should have dense month coverage, not sparse snapshots.
            cur.execute(
                """
                WITH month_counts AS (
                    SELECT CUSTOMER_KEY, COUNT(DISTINCT MONTH_KEY) AS active_months
                    FROM SAAS_CUSTOMER_MONTHLY
                    GROUP BY 1
                )
                SELECT AVG(active_months), MIN(active_months), MAX(active_months)
                FROM month_counts
                """
            )
            avg_active_months, min_active_months, max_active_months = cur.fetchone()
            avg_active_months = float(avg_active_months or 0.0)
            min_active_months = int(min_active_months or 0)
            max_active_months = int(max_active_months or 0)
            density_ok = avg_active_months >= 12.0 and max_active_months >= 20
            checks.append(
                {
                    "name": "saas_customer_month_density",
                    "value": {
                        "avg_active_months": round(avg_active_months, 2),
                        "min_active_months": min_active_months,
                        "max_active_months": max_active_months,
                    },
                    "threshold": "avg >= 12.0 and max >= 20",
                    "ok": density_ok,
                }
            )
            if not density_ok:
                failures.append(
                    f"SaaS finance customer-month density too sparse (avg={avg_active_months:.2f}, max={max_active_months})"
                )
            # Total S&M spend must equal sales + marketing components.
            cur.execute(
                """
                SELECT
                    COUNT(*) AS total_rows,
                    COUNT_IF(ABS(TOTAL_S_AND_M_SPEND_USD - (SALES_SPEND_USD + MARKETING_SPEND_USD)) > 1.0) AS bad_rows
                FROM SALES_MARKETING_SPEND_MONTHLY
                """
            )
            spend_total_rows, bad_spend_rows = cur.fetchone()
            spend_total_rows = int(spend_total_rows or 0)
            bad_spend_rows = int(bad_spend_rows or 0)
            spend_ok = spend_total_rows > 0 and bad_spend_rows == 0
            checks.append(
                {
                    "name": "saas_spend_identity_bad_rows",
                    "value": {"total": spend_total_rows, "bad_rows": bad_spend_rows},
                    "threshold": "== 0",
                    "ok": spend_ok,
                }
            )
            if not spend_ok:
                failures.append(f"SaaS finance spend identity broken in {bad_spend_rows} rows")
    except Exception as exc:  # noqa: BLE001
        # Any connection/SQL error fails the run visibly instead of crashing it.
        failures.append(f"Realism sanity checks failed to execute: {exc}")
    finally:
        # Close cursor then connection; swallow close errors so the result returns.
        try:
            if cur is not None:
                cur.close()
        except Exception:
            pass
        try:
            if conn is not None:
                conn.close()
        except Exception:
            pass
    return {"ok": len(failures) == 0, "checks": checks, "failures": failures}
def _run_case_chat(
    case: dict[str, Any],
    default_llm: str,
    user_email: str,
    skip_thoughtspot: bool = False,
) -> dict[str, Any]:
    """Run one sample case through the full chat pipeline.

    Stages: research -> DDL generation -> Snowflake deployment -> quality
    gate -> realism sanity checks -> (optionally) ThoughtSpot deployment.
    Each stage records ok/duration into ``result["stages"]``; the first
    failure sets ``result["error"]`` and later stages are skipped.

    Args:
        case: Case definition with ``company``, ``use_case`` and optional
            ``context``/``name`` keys.
        default_llm: Model identifier applied to the controller settings.
        user_email: Identity passed to ChatDemoInterface.
        skip_thoughtspot: When True, stop after data-side stages.

    Returns:
        Result dict with ``success`` True only when every recorded stage
        reported ok.
    """
    from chat_interface import ChatDemoInterface
    from demo_personas import get_use_case_config, parse_use_case
    company = case["company"]
    use_case = case["use_case"]
    model = default_llm
    context = case.get("context", "")
    controller = ChatDemoInterface(user_email=user_email)
    controller.settings["model"] = model
    controller.vertical, controller.function = parse_use_case(use_case or "")
    controller.use_case_config = get_use_case_config(
        controller.vertical or "Generic",
        controller.function or "Generic",
    )
    result: dict[str, Any] = {
        "name": case.get("name") or f"{company}_{use_case}",
        "company": company,
        "use_case": use_case,
        "mode": "chat",
        "started_at": _now_utc_iso(),
        "success": False,
        "stages": {},
    }
    # --- Research stage: drain the stream, keep the last update as preview.
    stage_start = datetime.now(timezone.utc)
    last_research = None
    for update in controller.run_research_streaming(company, use_case, generic_context=context):
        last_research = update
    result["stages"]["research"] = {
        "ok": bool(controller.demo_builder and controller.demo_builder.company_analysis_results),
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "preview": str(last_research)[:500] if last_research else "",
    }
    # --- DDL stage: reuse research-produced schema when present.
    stage_start = datetime.now(timezone.utc)
    ddl_text = (controller.demo_builder.schema_generation_results or "") if controller.demo_builder else ""
    if not ddl_text:
        _, ddl_text = controller.run_ddl_creation()
    result["stages"]["ddl"] = {
        "ok": bool(ddl_text and "CREATE TABLE" in ddl_text.upper()),
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "ddl_length": len(ddl_text or ""),
    }
    if not result["stages"]["ddl"]["ok"]:
        result["error"] = "DDL generation failed"
        result["finished_at"] = _now_utc_iso()
        return result
    # --- Snowflake deployment stage.
    stage_start = datetime.now(timezone.utc)
    deploy_error = None
    try:
        for _ in controller.run_deployment_streaming():
            pass
    except Exception as exc:  # noqa: BLE001
        deploy_error = str(exc)
    deployed_schema = getattr(controller, "_deployed_schema_name", None)
    schema_candidate = deployed_schema or getattr(controller, "_last_schema_name", None)
    result["stages"]["deploy_snowflake"] = {
        "ok": bool(deployed_schema) and deploy_error is None,
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "schema": schema_candidate,
        "error": deploy_error,
    }
    # --- Quality gate: only evaluable when deployment produced a schema.
    if schema_candidate:
        quality_report_path = getattr(controller, "_last_population_quality_report_path", None)
        result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path)
        if not result["stages"]["quality_gate"]["ok"]:
            result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}"
    elif deploy_error and not result.get("error"):
        result["error"] = deploy_error
    # --- Realism sanity checks, gated on the quality gate passing.
    # BUG FIX: the quality_gate stage only exists when a schema candidate was
    # found; a direct ["quality_gate"] lookup raised KeyError on failed deploys.
    stage_start = datetime.now(timezone.utc)
    if result["stages"].get("quality_gate", {}).get("ok"):
        sanity = _run_realism_sanity_checks(schema_candidate, case)
        result["stages"]["realism_sanity"] = {
            "ok": bool(sanity.get("ok")),
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "checks": sanity.get("checks", []),
            "failures": sanity.get("failures", []),
        }
        if not result["stages"]["realism_sanity"]["ok"] and not result.get("error"):
            result["error"] = "Realism sanity checks failed before ThoughtSpot deployment"
    # --- ThoughtSpot deployment (optional).
    if not skip_thoughtspot and deployed_schema and not result.get("error"):
        stage_start = datetime.now(timezone.utc)
        ts_ok = True
        ts_last = None
        try:
            for ts_update in controller._run_thoughtspot_deployment(deployed_schema, company, use_case):
                ts_last = ts_update
        except Exception as exc:  # noqa: BLE001
            ts_ok = False
            ts_last = str(exc)
        # Some deployment paths return a structured failure payload rather than
        # raising; treat those as failures so pass/fail reporting is accurate.
        ts_preview_text = str(ts_last) if ts_last is not None else ""
        if ts_ok and (
            "THOUGHTSPOT DEPLOYMENT FAILED" in ts_preview_text.upper()
            or "MODEL VALIDATION FAILED" in ts_preview_text.upper()
            or "LIVEBOARD CREATION FAILED" in ts_preview_text.upper()
        ):
            ts_ok = False
        result["stages"]["deploy_thoughtspot"] = {
            "ok": ts_ok,
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "preview": ts_preview_text[:1000],
        }
    result["schema_name"] = schema_candidate
    result["success"] = all(stage.get("ok") for stage in result["stages"].values())
    result["finished_at"] = _now_utc_iso()
    return result
def _run_case_offline(
    case: dict[str, Any],
    default_llm: str,
    user_email: str,
    skip_thoughtspot: bool = False,
) -> dict[str, Any]:
    """Run one sample case in offline DDL mode using the fixed star schema.

    Stages: Snowflake connect -> DDL deploy -> data population -> quality
    gate -> realism sanity checks -> (optionally) ThoughtSpot deployment.
    Each stage short-circuits on failure with ``result["error"]`` set.

    Args:
        case: Case definition with ``company``/``use_case`` (and optional name).
        default_llm: Model identifier forwarded to data population/ThoughtSpot.
        user_email: Identity forwarded to the population service.
        skip_thoughtspot: When True, stop after the data-side stages.

    Returns:
        Result dict; ``success`` is True only when every recorded stage is ok.
    """
    from cdw_connector import SnowflakeDeployer
    from demo_prep import generate_demo_base_name
    from legitdata_bridge import populate_demo_data
    from thoughtspot_deployer import deploy_to_thoughtspot
    company = case["company"]
    use_case = case["use_case"]
    result: dict[str, Any] = {
        "name": case.get("name") or f"{company}_{use_case}",
        "company": company,
        "use_case": use_case,
        "mode": "offline_ddl",
        "started_at": _now_utc_iso(),
        "success": False,
        "stages": {},
        "ddl_template": "offline_star_schema_v1",
    }
    # NOTE(review): the deployer connection is never explicitly closed on the
    # success/skip paths — presumably SnowflakeDeployer manages its own
    # lifetime; confirm against cdw_connector.
    deployer = SnowflakeDeployer()
    # 1) Snowflake schema + DDL deploy
    stage_start = datetime.now(timezone.utc)
    ok, msg = deployer.connect()
    if not ok:
        result["stages"]["snowflake_connect"] = {
            "ok": False,
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "message": msg,
        }
        result["error"] = msg
        result["finished_at"] = _now_utc_iso()
        return result
    base_name = generate_demo_base_name("", company)
    ok, schema_name, ddl_msg = deployer.create_demo_schema_and_deploy(base_name, OFFLINE_DEMO_DDL)
    result["stages"]["snowflake_ddl"] = {
        "ok": ok,
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "schema": schema_name,
        "message": ddl_msg,
    }
    if not ok or not schema_name:
        result["error"] = ddl_msg
        result["finished_at"] = _now_utc_iso()
        return result
    # 2) Data population via LegitData
    stage_start = datetime.now(timezone.utc)
    pop_ok, pop_msg, pop_results = populate_demo_data(
        ddl_content=OFFLINE_DEMO_DDL,
        company_url=company,
        use_case=use_case,
        schema_name=schema_name,
        llm_model=default_llm,
        user_email=user_email,
        size="medium",
    )
    result["stages"]["populate_data"] = {
        "ok": pop_ok,
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "rows": pop_results,
        "quality_report": _parse_quality_report_path(pop_msg),
    }
    if not pop_ok:
        result["error"] = pop_msg
        result["finished_at"] = _now_utc_iso()
        return result
    # Quality gate: the report path is parsed out of the population message.
    quality_report_path = _parse_quality_report_path(pop_msg)
    result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path)
    if not result["stages"]["quality_gate"]["ok"]:
        result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}"
        result["schema_name"] = schema_name
        result["finished_at"] = _now_utc_iso()
        return result
    # Realism sanity checks run only after the quality gate passes.
    stage_start = datetime.now(timezone.utc)
    sanity = _run_realism_sanity_checks(schema_name, case)
    result["stages"]["realism_sanity"] = {
        "ok": bool(sanity.get("ok")),
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "checks": sanity.get("checks", []),
        "failures": sanity.get("failures", []),
    }
    if not result["stages"]["realism_sanity"]["ok"]:
        result["error"] = "Realism sanity checks failed before ThoughtSpot deployment"
        result["schema_name"] = schema_name
        result["finished_at"] = _now_utc_iso()
        return result
    # 3) ThoughtSpot model + liveboard
    if not skip_thoughtspot:
        stage_start = datetime.now(timezone.utc)
        ts_result = deploy_to_thoughtspot(
            ddl=OFFLINE_DEMO_DDL,
            database=os.getenv("SNOWFLAKE_DATABASE", "DEMOBUILD"),
            schema=schema_name,
            base_name=base_name,
            connection_name=f"{base_name}_conn",
            company_name=company,
            use_case=use_case,
            llm_model=default_llm,
        )
        result["stages"]["deploy_thoughtspot"] = {
            "ok": bool(ts_result and not ts_result.get("errors")),
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "result": ts_result,
        }
    result["schema_name"] = schema_name
    result["success"] = all(stage.get("ok") for stage in result["stages"].values())
    result["finished_at"] = _now_utc_iso()
    return result
def _safe_artifact_filename(name: str) -> str:
    """Return *name* with path-hostile characters replaced so it is safe as a filename.

    Case names come from a user-edited YAML file and may contain slashes,
    spaces, or other characters that would break (or redirect) the artifact
    write. Alphanumerics plus ``-``, ``_`` and ``.`` pass through unchanged,
    so well-formed names keep their historical filenames.
    """
    cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in name)
    return cleaned or "unnamed_case"


def main() -> None:
    """Run every New Vision sample case and write per-case + summary JSON artifacts.

    Parses CLI flags, validates required pipeline settings up front, then runs
    each case in either chat mode or offline-DDL mode. Artifacts land under
    ``results/newvision_samples/<run_id>/``.

    Raises:
        SystemExit: with status 1 when any case fails (CI-friendly exit code).
    """
    parser = argparse.ArgumentParser(description="Run New Vision sample set")
    parser.add_argument(
        "--cases-file",
        default="tests/newvision_test_cases.yaml",
        help="Path to YAML test case file",
    )
    parser.add_argument(
        "--skip-thoughtspot",
        action="store_true",
        help="Run through data generation only and skip ThoughtSpot object creation",
    )
    parser.add_argument(
        "--offline-ddl",
        action="store_true",
        help="Force offline DDL mode (no LLM dependency)",
    )
    args = parser.parse_args()

    user_email, default_llm = _resolve_runtime_settings()

    # Imported lazily so module import does not require validation side effects.
    from startup_validation import validate_required_pipeline_settings_or_raise

    # Fail fast on missing settings before doing any per-case work.
    validate_required_pipeline_settings_or_raise(
        default_llm=default_llm,
        require_thoughtspot=not args.skip_thoughtspot,
        require_snowflake=True,
    )

    cases_file = (PROJECT_ROOT / args.cases_file).resolve()
    cases = _load_cases(cases_file)

    run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    out_dir = PROJECT_ROOT / "results" / "newvision_samples" / run_id
    out_dir.mkdir(parents=True, exist_ok=True)

    use_offline = bool(args.offline_ddl)
    print(f"Mode: {'offline_ddl' if use_offline else 'chat'}", flush=True)
    print(f"default_llm: {default_llm}", flush=True)

    results = []
    for idx, case in enumerate(cases, start=1):
        print(f"\n[{idx}/{len(cases)}] {case.get('name', case['company'])} -> {case['use_case']}", flush=True)
        # Both runners share the same keyword interface; pick one and call once.
        run_case = _run_case_offline if use_offline else _run_case_chat
        try:
            case_result = run_case(
                case,
                default_llm=default_llm,
                user_email=user_email,
                skip_thoughtspot=args.skip_thoughtspot,
            )
        except Exception as exc:  # noqa: BLE001
            # Keep the batch going: record the crash as a synthetic failed stage
            # so the summary still accounts for this case.
            case_result = {
                "name": case.get("name") or f"{case['company']}_{case['use_case']}",
                "company": case["company"],
                "use_case": case["use_case"],
                "mode": "offline_ddl" if use_offline else "chat",
                "started_at": _now_utc_iso(),
                "finished_at": _now_utc_iso(),
                "success": False,
                "error": f"Runner exception: {exc}",
                "stages": {
                    "runner_exception": {
                        "ok": False,
                        "message": str(exc),
                    }
                },
            }
        results.append(case_result)
        # Sanitize the case name: it is user-supplied and must not be able to
        # break the write or point outside out_dir.
        artifact_name = _safe_artifact_filename(case_result["name"])
        (out_dir / f"{artifact_name}.json").write_text(
            json.dumps(case_result, indent=2),
            encoding="utf-8",
        )
        print(f"  success={case_result['success']} schema={case_result.get('schema_name')}", flush=True)

    summary = {
        "run_id": run_id,
        "mode": "offline_ddl" if use_offline else "chat",
        "cases_file": str(cases_file),
        "total": len(results),
        "passed": sum(1 for r in results if r.get("success")),
        "failed": sum(1 for r in results if not r.get("success")),
        "results": results,
    }
    (out_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print("\nSaved sample artifacts:", out_dir)
    print(f"Passed: {summary['passed']} / {summary['total']}")
    if summary["failed"]:
        raise SystemExit(1)
# Script entry point: run the sample set only when executed directly,
# not when this module is imported (e.g. by other tests).
if __name__ == "__main__":
    main()