Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Run the New Vision 4-case sample set without UI login flow. | |
| Modes: | |
| 1) Full chat pipeline mode (uses configured default_llm and required settings) | |
| 2) Offline DDL mode (deterministic schema template, still validates settings up front) | |
| Usage: | |
| source ./demoprep/bin/activate | |
| python tests/newvision_sample_runner.py | |
| python tests/newvision_sample_runner.py --offline-ddl | |
| python tests/newvision_sample_runner.py --skip-thoughtspot | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import sys | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any | |
| import yaml | |
# Resolve the repository root (tests/ lives one level below it) so sibling
# project modules resolve when this script is run directly.
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from dotenv import load_dotenv
# Load local .env before anything reads configuration from the environment.
load_dotenv(PROJECT_ROOT / ".env")
# Bypass the UI login flow by default; callers may still override this var.
os.environ.setdefault("DEMOPREP_NO_AUTH", "true")
# Pull admin settings into environment when available.
try:
    from supabase_client import inject_admin_settings_to_env
    inject_admin_settings_to_env()
except Exception as exc:  # noqa: BLE001
    # Best-effort: the runner can still work with plain env vars.
    print(f"[newvision_runner] Admin setting injection unavailable: {exc}")
# Deterministic retail star schema used by --offline-ddl mode: three dimension
# tables plus one daily fact table with FK links back to each dimension.
OFFLINE_DEMO_DDL = """
CREATE TABLE DIM_DATE (
    DATE_KEY INT PRIMARY KEY,
    ORDER_DATE DATE,
    MONTH_NAME VARCHAR(30),
    QUARTER_NAME VARCHAR(10),
    YEAR_NUM INT,
    IS_WEEKEND BOOLEAN
);
CREATE TABLE DIM_LOCATION (
    LOCATION_KEY INT PRIMARY KEY,
    COUNTRY VARCHAR(100),
    REGION VARCHAR(100),
    STATE VARCHAR(100),
    CITY VARCHAR(100),
    SALES_CHANNEL VARCHAR(100),
    CUSTOMER_SEGMENT VARCHAR(100)
);
CREATE TABLE DIM_PRODUCT (
    PRODUCT_KEY INT PRIMARY KEY,
    PRODUCT_NAME VARCHAR(200),
    BRAND_NAME VARCHAR(100),
    CATEGORY VARCHAR(100),
    SUB_CATEGORY VARCHAR(100),
    PRODUCT_TIER VARCHAR(50),
    UNIT_PRICE DECIMAL(12,2)
);
CREATE TABLE FACT_RETAIL_DAILY (
    TRANSACTION_KEY INT PRIMARY KEY,
    DATE_KEY INT,
    LOCATION_KEY INT,
    PRODUCT_KEY INT,
    ORDER_DATE DATE,
    ORDER_COUNT INT,
    UNITS_SOLD INT,
    UNIT_PRICE DECIMAL(12,2),
    GROSS_REVENUE DECIMAL(14,2),
    NET_REVENUE DECIMAL(14,2),
    SALES_AMOUNT DECIMAL(14,2),
    DISCOUNT_PCT DECIMAL(5,2),
    INVENTORY_ON_HAND INT,
    LOST_SALES_USD DECIMAL(14,2),
    IS_OOS BOOLEAN,
    FOREIGN KEY (DATE_KEY) REFERENCES DIM_DATE(DATE_KEY),
    FOREIGN KEY (LOCATION_KEY) REFERENCES DIM_LOCATION(LOCATION_KEY),
    FOREIGN KEY (PRODUCT_KEY) REFERENCES DIM_PRODUCT(PRODUCT_KEY)
);
""".strip()
| def _now_utc_iso() -> str: | |
| return datetime.now(timezone.utc).isoformat() | |
def _load_cases(cases_file: Path) -> list[dict[str, Any]]:
    """Read the YAML cases file and return its ``test_cases`` entries as a list."""
    raw_text = cases_file.read_text(encoding="utf-8")
    parsed = yaml.safe_load(raw_text)
    if not parsed:
        parsed = {}
    return list(parsed.get("test_cases", []))
| def _parse_quality_report_path(message: str) -> str | None: | |
| for line in (message or "").splitlines(): | |
| if "Report:" in line: | |
| return line.split("Report:", 1)[1].strip() | |
| if "See report:" in line: | |
| return line.split("See report:", 1)[1].strip() | |
| return None | |
| def _load_quality_report(report_path: str | None) -> dict[str, Any]: | |
| if not report_path: | |
| return {} | |
| path = Path(report_path) | |
| json_path = path.with_suffix(".json") | |
| if json_path.exists(): | |
| try: | |
| return json.loads(json_path.read_text(encoding="utf-8")) | |
| except Exception: # noqa: BLE001 | |
| return {} | |
| return {} | |
def _build_quality_gate_stage(report_path: str | None) -> dict[str, Any]:
    """Summarize the population quality report into a pipeline-stage record.

    The stage is ``ok`` only when the report explicitly says ``passed`` is
    truthy; a missing or unreadable report fails the gate.
    """
    report = _load_quality_report(report_path)
    if isinstance(report, dict):
        summary = report.get("summary", {})
        passed = report.get("passed")
    else:
        summary = {}
        passed = None
    metric_keys = (
        "semantic_pass_ratio",
        "categorical_junk_count",
        "fk_orphan_count",
        "temporal_violations",
        "numeric_violations",
        "volatility_breaches",
        "smoothness_score",
        "outlier_explainability",
        "kpi_consistency",
    )
    return {
        "ok": bool(passed) if passed is not None else False,
        "report_path": report_path,
        "passed": passed,
        "summary": {key: summary.get(key) for key in metric_keys},
    }
| def _resolve_runtime_settings() -> tuple[str, str]: | |
| user_email = ( | |
| os.getenv("USER_EMAIL") | |
| or os.getenv("INITIAL_USER") | |
| or os.getenv("THOUGHTSPOT_ADMIN_USER") | |
| or "default@user.com" | |
| ).strip() | |
| default_llm = (os.getenv("DEFAULT_LLM") or os.getenv("OPENAI_MODEL") or "").strip() | |
| if not default_llm: | |
| raise ValueError("Missing required env var: DEFAULT_LLM or OPENAI_MODEL") | |
| return user_email, default_llm | |
def _run_realism_sanity_checks(schema_name: str, case: dict[str, Any]) -> dict[str, Any]:
    """Fast, opinionated sanity checks for demo realism.

    These checks intentionally target user-visible demo breakages that can slip
    through structural quality gates (e.g., null dimensions in top-N charts).

    Args:
        schema_name: Snowflake schema the demo dataset was deployed into.
        case: Test-case definition; reads ``use_case`` (vertical routing) and
            ``name`` (case-specific narrative checks).

    Returns:
        Dict with ``ok`` (True iff no failures), ``checks`` (one record per
        executed check: name/value/threshold/ok), and ``failures`` (messages).
    """
    checks: list[dict[str, Any]] = []
    failures: list[str] = []
    if not schema_name:
        return {"ok": False, "checks": checks, "failures": ["Missing schema name for sanity checks"]}
    use_case = str(case.get("use_case", "") or "")
    use_case_lower = (use_case or "").lower()
    case_name = str(case.get("name", "") or "").lower()
    # Vertical routing flags: only these verticals have realism checks defined.
    is_legal = "legal" in use_case_lower
    is_private_equity = any(
        marker in use_case_lower
        for marker in ("private equity", "lp reporting", "state street")
    )
    is_saas_finance = any(
        marker in use_case_lower
        for marker in ("saas finance", "unit economics", "financial analytics", "fp&a", "fpa")
    )
    if not is_legal and not is_private_equity and not is_saas_finance:
        # Keep runtime fast by evaluating only scoped vertical checks.
        return {"ok": True, "checks": checks, "failures": []}
    # Imported lazily so non-scoped runs never touch Snowflake/Supabase deps.
    from supabase_client import inject_admin_settings_to_env
    from snowflake_auth import get_snowflake_connection
    inject_admin_settings_to_env()
    conn = None
    cur = None
    try:
        db_name = (os.getenv("SNOWFLAKE_DATABASE") or "DEMOBUILD").strip()
        # Strip embedded quotes so the identifier can be safely double-quoted.
        safe_schema = schema_name.replace('"', "")
        conn = get_snowflake_connection()
        cur = conn.cursor()
        cur.execute(f'USE DATABASE "{db_name}"')
        cur.execute(f'USE SCHEMA "{safe_schema}"')
        if is_legal:
            # Legal demos deploy in one of two shapes: a split dimensional
            # model or a single denormalized event table. Detect which landed.
            cur.execute("SHOW TABLES")
            # SHOW TABLES returns the table name in column index 1.
            legal_tables = {str(row[1]).upper() for row in cur.fetchall()}
            has_split_legal = {"LEGAL_MATTERS", "OUTSIDE_COUNSEL_INVOICES", "ATTORNEYS", "MATTER_TYPES"}.issubset(legal_tables)
            has_event_legal = "LEGAL_SPEND_EVENTS" in legal_tables
            if has_split_legal:
                # 1) Invoice -> matter -> attorney join coverage.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows
                    FROM OUTSIDE_COUNSEL_INVOICES oci
                    LEFT JOIN LEGAL_MATTERS lm ON oci.MATTER_ID = lm.MATTER_ID
                    LEFT JOIN ATTORNEYS a ON lm.ASSIGNED_ATTORNEY_ID = a.ATTORNEY_ID
                    """
                )
                total_rows, null_rows = cur.fetchone()
                # An empty table is treated as 100% null (fails the check).
                null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
                checks.append(
                    {
                        "name": "legal_attorney_join_null_pct",
                        "value": round(null_pct, 2),
                        "threshold": "<= 5.0",
                        "ok": null_pct <= 5.0,
                    }
                )
                if null_pct > 5.0:
                    failures.append(
                        f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)"
                    )
                # 1b) Invoice MATTER_ID linkage must be complete.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(MATTER_ID IS NULL) AS null_rows
                    FROM OUTSIDE_COUNSEL_INVOICES
                    """
                )
                total_rows, null_rows = cur.fetchone()
                null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
                checks.append(
                    {
                        "name": "legal_invoice_matter_id_null_pct",
                        "value": round(null_pct, 2),
                        "threshold": "== 0.0",
                        "ok": null_pct == 0.0,
                    }
                )
                if null_pct != 0.0:
                    failures.append(
                        f"Invoice MATTER_ID null rate is {null_pct:.2f}% (expected 0%)"
                    )
                # 2) Region cardinality should be compact for legal executive demos.
                cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_MATTERS WHERE REGION IS NOT NULL")
                region_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_region_distinct_count",
                        "value": region_cardinality,
                        "threshold": "<= 6",
                        "ok": region_cardinality <= 6,
                    }
                )
                if region_cardinality > 6:
                    failures.append(
                        f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)"
                    )
                # 3) Firm names should not contain obvious cross-vertical banking/org jargon.
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM OUTSIDE_COUNSEL
                    WHERE REGEXP_LIKE(
                        LOWER(FIRM_NAME),
                        'retail banking|consumer lending|digital channels|enterprise operations|regional service'
                    )
                    """
                )
                bad_firm_count = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_firm_name_cross_vertical_count",
                        "value": bad_firm_count,
                        "threshold": "== 0",
                        "ok": bad_firm_count == 0,
                    }
                )
                if bad_firm_count != 0:
                    failures.append(
                        f"Detected {bad_firm_count} cross-vertical/non-legal firm names"
                    )
                # 4) Matter type taxonomy should remain concise and demo-friendly.
                cur.execute(
                    """
                    SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME)
                    FROM LEGAL_MATTERS lm
                    LEFT JOIN MATTER_TYPES mt ON lm.MATTER_TYPE_ID = mt.MATTER_TYPE_ID
                    WHERE mt.MATTER_TYPE_NAME IS NOT NULL
                    """
                )
                matter_type_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_matter_type_distinct_count",
                        "value": matter_type_cardinality,
                        "threshold": "<= 15",
                        "ok": matter_type_cardinality <= 15,
                    }
                )
                if matter_type_cardinality > 15:
                    failures.append(
                        f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)"
                    )
            elif has_event_legal:
                # Denormalized shape: same four checks against the event table.
                # 1) Attorney dimension join coverage (critical for "Top Attorney by Cost").
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(a.ATTORNEY_NAME IS NULL) AS null_rows
                    FROM LEGAL_SPEND_EVENTS lse
                    LEFT JOIN ATTORNEYS a ON lse.ATTORNEY_ID = a.ATTORNEY_ID
                    """
                )
                total_rows, null_rows = cur.fetchone()
                null_pct = (float(null_rows) * 100.0 / float(total_rows)) if total_rows else 100.0
                checks.append(
                    {
                        "name": "legal_attorney_join_null_pct",
                        "value": round(null_pct, 2),
                        "threshold": "<= 5.0",
                        "ok": null_pct <= 5.0,
                    }
                )
                if null_pct > 5.0:
                    failures.append(
                        f"Attorney join null rate too high: {null_pct:.2f}% (expected <= 5%)"
                    )
                # 2) Region cardinality should be compact for legal executive demos.
                cur.execute("SELECT COUNT(DISTINCT REGION) FROM LEGAL_SPEND_EVENTS WHERE REGION IS NOT NULL")
                region_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_region_distinct_count",
                        "value": region_cardinality,
                        "threshold": "<= 6",
                        "ok": region_cardinality <= 6,
                    }
                )
                if region_cardinality > 6:
                    failures.append(
                        f"Region cardinality too high: {region_cardinality} distinct values (expected <= 6)"
                    )
                # 3) Firm names should not contain obvious cross-vertical banking/org jargon.
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM OUTSIDE_COUNSEL_FIRMS
                    WHERE REGEXP_LIKE(
                        LOWER(FIRM_NAME),
                        'retail banking|consumer lending|digital channels|enterprise operations|regional service'
                    )
                    """
                )
                bad_firm_count = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_firm_name_cross_vertical_count",
                        "value": bad_firm_count,
                        "threshold": "== 0",
                        "ok": bad_firm_count == 0,
                    }
                )
                if bad_firm_count != 0:
                    failures.append(
                        f"Detected {bad_firm_count} cross-vertical/non-legal firm names"
                    )
                # 4) Matter type taxonomy should remain concise and demo-friendly.
                cur.execute(
                    """
                    SELECT COUNT(DISTINCT mt.MATTER_TYPE_NAME)
                    FROM LEGAL_SPEND_EVENTS lse
                    LEFT JOIN MATTER_TYPES mt ON lse.MATTER_TYPE_ID = mt.MATTER_TYPE_ID
                    WHERE mt.MATTER_TYPE_NAME IS NOT NULL
                    """
                )
                matter_type_cardinality = int(cur.fetchone()[0] or 0)
                checks.append(
                    {
                        "name": "legal_matter_type_distinct_count",
                        "value": matter_type_cardinality,
                        "threshold": "<= 15",
                        "ok": matter_type_cardinality <= 15,
                    }
                )
                if matter_type_cardinality > 15:
                    failures.append(
                        f"Matter type cardinality too high: {matter_type_cardinality} distinct values (expected <= 15)"
                    )
            else:
                failures.append("Could not find supported legal schema shape for realism checks")
        if is_private_equity:
            # Guard against semantic leakage where sector/strategy dimensions are
            # accidentally populated with company names.
            cur.execute(
                """
                WITH dim_companies AS (
                    SELECT DISTINCT COMPANY_NAME
                    FROM PORTFOLIO_COMPANIES
                    WHERE COMPANY_NAME IS NOT NULL
                ),
                dim_sectors AS (
                    SELECT DISTINCT SECTOR_NAME
                    FROM SECTORS
                    WHERE SECTOR_NAME IS NOT NULL
                ),
                dim_strategies AS (
                    SELECT DISTINCT FUND_STRATEGY
                    FROM FUNDS
                    WHERE FUND_STRATEGY IS NOT NULL
                )
                SELECT
                    (SELECT COUNT(*) FROM dim_sectors),
                    (SELECT COUNT(*) FROM dim_strategies),
                    (SELECT COUNT(*) FROM dim_sectors s JOIN dim_companies c ON s.SECTOR_NAME = c.COMPANY_NAME),
                    (SELECT COUNT(*) FROM dim_strategies f JOIN dim_companies c ON f.FUND_STRATEGY = c.COMPANY_NAME)
                """
            )
            sector_distinct, strategy_distinct, sector_overlap, strategy_overlap = cur.fetchone()
            sector_distinct = int(sector_distinct or 0)
            strategy_distinct = int(strategy_distinct or 0)
            sector_overlap = int(sector_overlap or 0)
            strategy_overlap = int(strategy_overlap or 0)
            # Any overlap means a dimension was mislabeled with company names.
            checks.append(
                {
                    "name": "pe_sector_name_company_overlap_count",
                    "value": sector_overlap,
                    "threshold": "== 0",
                    "ok": sector_overlap == 0,
                }
            )
            if sector_overlap != 0:
                failures.append(
                    f"Sector names overlap company names ({sector_overlap} overlaps); likely mislabeled dimensions"
                )
            checks.append(
                {
                    "name": "pe_fund_strategy_company_overlap_count",
                    "value": strategy_overlap,
                    "threshold": "== 0",
                    "ok": strategy_overlap == 0,
                }
            )
            if strategy_overlap != 0:
                failures.append(
                    f"Fund strategy values overlap company names ({strategy_overlap} overlaps); likely mislabeled dimensions"
                )
            # Dimension cardinalities should fall in a chart-friendly band.
            checks.append(
                {
                    "name": "pe_sector_distinct_count",
                    "value": sector_distinct,
                    "threshold": ">= 4 and <= 20",
                    "ok": 4 <= sector_distinct <= 20,
                }
            )
            if not (4 <= sector_distinct <= 20):
                failures.append(
                    f"Sector distinct count out of expected demo range: {sector_distinct} (expected 4-20)"
                )
            checks.append(
                {
                    "name": "pe_fund_strategy_distinct_count",
                    "value": strategy_distinct,
                    "threshold": ">= 4 and <= 20",
                    "ok": 4 <= strategy_distinct <= 20,
                }
            )
            if not (4 <= strategy_distinct <= 20):
                failures.append(
                    f"Fund strategy distinct count out of expected demo range: {strategy_distinct} (expected 4-20)"
                )
            # Narrative checks specific to the State Street LP reporting case.
            if case_name == "statestreet_private_equity_lp_reporting":
                # Accounting identity: TOTAL_VALUE = REPORTED_VALUE + DISTRIBUTIONS.
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(ABS(TOTAL_VALUE_USD - (REPORTED_VALUE_USD + DISTRIBUTIONS_USD)) > 0.01) AS bad_rows
                    FROM PORTFOLIO_PERFORMANCE
                    """
                )
                total_rows, bad_rows = cur.fetchone()
                total_rows = int(total_rows or 0)
                bad_rows = int(bad_rows or 0)
                identity_ok = total_rows > 0 and bad_rows == 0
                checks.append(
                    {
                        "name": "pe_total_value_identity_bad_rows",
                        "value": bad_rows,
                        "threshold": "== 0",
                        "ok": identity_ok,
                    }
                )
                if not identity_ok:
                    failures.append(
                        f"Total value identity broken in {bad_rows} PE fact rows"
                    )
                # Subscription-line IRR impact must sit in the 80-210 bps band
                # and match the gross-vs-without-sub-line IRR delta (within 5 bps).
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS total_rows,
                        COUNT_IF(IRR_SUB_LINE_IMPACT_BPS BETWEEN 80 AND 210) AS in_band_rows,
                        COUNT_IF(ABS(IRR_SUB_LINE_IMPACT_BPS - ((GROSS_IRR - GROSS_IRR_WITHOUT_SUB_LINE) * 10000)) <= 5) AS identity_rows
                    FROM PORTFOLIO_PERFORMANCE
                    """
                )
                total_rows, in_band_rows, identity_rows = cur.fetchone()
                total_rows = int(total_rows or 0)
                in_band_rows = int(in_band_rows or 0)
                identity_rows = int(identity_rows or 0)
                irr_band_ok = total_rows > 0 and in_band_rows == total_rows and identity_rows == total_rows
                checks.append(
                    {
                        "name": "pe_subscription_line_impact_rows_valid",
                        "value": {"total": total_rows, "in_band": in_band_rows, "identity": identity_rows},
                        "threshold": "all rows in 80-210 bps band and identity holds",
                        "ok": irr_band_ok,
                    }
                )
                if not irr_band_ok:
                    failures.append("Subscription line impact rows do not consistently satisfy PE IRR delta rules")
                # Apex Industrial Solutions must exist and carry the overall
                # max sub-line impact at ~210 bps (the demo's planted outlier).
                cur.execute(
                    """
                    SELECT
                        COUNT(*) AS apex_rows,
                        MAX(pp.IRR_SUB_LINE_IMPACT_BPS) AS apex_max_bps,
                        (
                            SELECT MAX(IRR_SUB_LINE_IMPACT_BPS)
                            FROM PORTFOLIO_PERFORMANCE
                        ) AS overall_max_bps
                    FROM PORTFOLIO_PERFORMANCE pp
                    JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                    WHERE LOWER(pc.COMPANY_NAME) = 'apex industrial solutions'
                    """
                )
                apex_rows, apex_max_bps, overall_max_bps = cur.fetchone()
                apex_ok = int(apex_rows or 0) > 0 and apex_max_bps is not None and abs(float(apex_max_bps) - 210.0) <= 1.0 and overall_max_bps is not None and abs(float(overall_max_bps) - 210.0) <= 1.0
                checks.append(
                    {
                        "name": "pe_apex_subscription_line_outlier",
                        "value": {"rows": int(apex_rows or 0), "apex_max_bps": apex_max_bps, "overall_max_bps": overall_max_bps},
                        "threshold": "Apex exists and max impact == 210 bps",
                        "ok": apex_ok,
                    }
                )
                if not apex_ok:
                    failures.append("Apex Industrial Solutions outlier is missing or not set to the expected 210 bps impact")
                # Meridian Specialty Chemicals must be the sole covenant
                # exception, with status 'waived' — no other non-compliant rows.
                cur.execute(
                    """
                    WITH covenant_exceptions AS (
                        SELECT
                            LOWER(pc.COMPANY_NAME) AS company_name,
                            LOWER(pp.COVENANT_STATUS) AS covenant_status,
                            COUNT(*) AS row_count
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                        WHERE LOWER(pp.COVENANT_STATUS) <> 'compliant'
                        GROUP BY 1, 2
                    )
                    SELECT
                        COUNT_IF(company_name = 'meridian specialty chemicals' AND covenant_status = 'waived') AS meridian_waived_groups,
                        COUNT_IF(company_name <> 'meridian specialty chemicals' OR covenant_status <> 'waived') AS invalid_groups
                    FROM covenant_exceptions
                    """
                )
                meridian_groups, invalid_groups = cur.fetchone()
                meridian_ok = int(meridian_groups or 0) > 0 and int(invalid_groups or 0) == 0
                checks.append(
                    {
                        "name": "pe_meridian_covenant_exception",
                        "value": {"meridian_groups": int(meridian_groups or 0), "invalid_groups": int(invalid_groups or 0)},
                        "threshold": "Meridian only, status waived",
                        "ok": meridian_ok,
                    }
                )
                if not meridian_ok:
                    failures.append("Meridian Specialty Chemicals is not the sole waived/non-compliant covenant exception")
                # Technology must lead both average entry multiple and TVPI
                # across sectors, per the State Street narrative.
                cur.execute(
                    """
                    WITH sector_perf AS (
                        SELECT
                            s.SECTOR_NAME,
                            AVG(pp.ENTRY_EV_EBITDA_MULTIPLE) AS avg_entry_multiple,
                            AVG(pp.TOTAL_RETURN_MULTIPLE) AS avg_tvpi
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                        JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID
                        GROUP BY 1
                    )
                    SELECT
                        MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN avg_entry_multiple END) AS tech_entry,
                        MAX(CASE WHEN LOWER(SECTOR_NAME) = 'technology' THEN avg_tvpi END) AS tech_tvpi,
                        MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_entry_multiple END) AS other_entry_max,
                        MAX(CASE WHEN LOWER(SECTOR_NAME) <> 'technology' THEN avg_tvpi END) AS other_tvpi_max
                    FROM sector_perf
                    """
                )
                tech_entry, tech_tvpi, other_entry_max, other_tvpi_max = cur.fetchone()
                tech_sector_ok = (
                    tech_entry is not None
                    and tech_tvpi is not None
                    and other_entry_max is not None
                    and other_tvpi_max is not None
                    and float(tech_entry) >= float(other_entry_max)
                    and float(tech_tvpi) >= float(other_tvpi_max)
                )
                checks.append(
                    {
                        "name": "pe_technology_sector_leads_multiples",
                        "value": {
                            "tech_entry": tech_entry,
                            "tech_tvpi": tech_tvpi,
                            "other_entry_max": other_entry_max,
                            "other_tvpi_max": other_tvpi_max,
                        },
                        "threshold": "Technology leads average entry and return multiples",
                        "ok": tech_sector_ok,
                    }
                )
                if not tech_sector_ok:
                    failures.append("Technology sector does not lead entry and return multiples as required by the State Street narrative")
                # The two highest reported-value vintages must be exactly 2021/2022.
                cur.execute(
                    """
                    WITH vintage_rank AS (
                        SELECT
                            VINTAGE_YEAR,
                            SUM(REPORTED_VALUE_USD) AS total_reported_value,
                            DENSE_RANK() OVER (ORDER BY SUM(REPORTED_VALUE_USD) DESC) AS value_rank
                        FROM PORTFOLIO_PERFORMANCE
                        GROUP BY 1
                    )
                    SELECT LISTAGG(TO_VARCHAR(VINTAGE_YEAR), ',') WITHIN GROUP (ORDER BY VINTAGE_YEAR)
                    FROM vintage_rank
                    WHERE value_rank <= 2
                    """
                )
                top_vintages = cur.fetchone()[0] or ""
                top_vintage_set = {part.strip() for part in str(top_vintages).split(",") if part.strip()}
                vintage_ok = top_vintage_set == {"2021", "2022"}
                checks.append(
                    {
                        "name": "pe_top_vintages_reported_value",
                        "value": sorted(top_vintage_set),
                        "threshold": "top 2 vintages are 2021 and 2022",
                        "ok": vintage_ok,
                    }
                )
                if not vintage_ok:
                    failures.append("2021 and 2022 are not the top reported-value vintages")
                # The planted healthcare dip: Q4 2024 must average below Q3 2024.
                cur.execute(
                    """
                    WITH healthcare_quarters AS (
                        SELECT
                            DATE_TRUNC('quarter', pp.FULL_DATE) AS quarter_start,
                            AVG(pp.TOTAL_VALUE_USD) AS avg_total_value
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                        JOIN SECTORS s ON pc.SECTOR_ID = s.SECTOR_ID
                        WHERE LOWER(s.SECTOR_NAME) = 'healthcare'
                        GROUP BY 1
                    )
                    SELECT
                        MAX(CASE WHEN quarter_start = DATE '2024-07-01' THEN avg_total_value END) AS q3_2024_value,
                        MAX(CASE WHEN quarter_start = DATE '2024-10-01' THEN avg_total_value END) AS q4_2024_value
                    FROM healthcare_quarters
                    """
                )
                q3_2024_value, q4_2024_value = cur.fetchone()
                healthcare_dip_ok = (
                    q3_2024_value is not None
                    and q4_2024_value is not None
                    and float(q4_2024_value) < float(q3_2024_value)
                )
                checks.append(
                    {
                        "name": "pe_healthcare_q4_2024_dip",
                        "value": {"q3_2024": q3_2024_value, "q4_2024": q4_2024_value},
                        "threshold": "Q4 2024 healthcare total value lower than Q3 2024",
                        "ok": healthcare_dip_ok,
                    }
                )
                if not healthcare_dip_ok:
                    failures.append("Healthcare Q4 2024 performance dip is missing")
                # At least one company must show revenue growing while EBITDA
                # margin compresses (first vs last observation per company).
                cur.execute(
                    """
                    WITH company_trends AS (
                        SELECT
                            pc.COMPANY_NAME,
                            FIRST_VALUE(pp.REVENUE_USD) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_revenue,
                            LAST_VALUE(pp.REVENUE_USD) OVER (
                                PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE
                                ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                            ) AS last_revenue,
                            FIRST_VALUE(pp.EBITDA_MARGIN_PCT) OVER (PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE) AS first_margin,
                            LAST_VALUE(pp.EBITDA_MARGIN_PCT) OVER (
                                PARTITION BY pc.COMPANY_NAME ORDER BY pp.FULL_DATE
                                ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
                            ) AS last_margin
                        FROM PORTFOLIO_PERFORMANCE pp
                        JOIN PORTFOLIO_COMPANIES pc ON pp.COMPANY_ID = pc.COMPANY_ID
                    )
                    SELECT COUNT(DISTINCT COMPANY_NAME)
                    FROM company_trends
                    WHERE last_revenue > first_revenue AND last_margin < first_margin
                    """
                )
                trend_company_count = int(cur.fetchone()[0] or 0)
                trend_ok = trend_company_count >= 1
                checks.append(
                    {
                        "name": "pe_revenue_up_margin_down_company_count",
                        "value": trend_company_count,
                        "threshold": ">= 1",
                        "ok": trend_ok,
                    }
                )
                if not trend_ok:
                    failures.append("No portfolio company shows the required revenue-up / EBITDA-margin-down trend")
        if is_saas_finance:
            # Coverage: at least two full years of months for trend questions.
            cur.execute("SELECT COUNT(DISTINCT MONTH_KEY) FROM DATES")
            month_count = int(cur.fetchone()[0] or 0)
            checks.append(
                {
                    "name": "saas_month_count",
                    "value": month_count,
                    "threshold": ">= 24",
                    "ok": month_count >= 24,
                }
            )
            if month_count < 24:
                failures.append(f"SaaS finance month count too low: {month_count} (expected >= 24)")
            # Segmentation dimensions must be populated enough to slice by.
            cur.execute("SELECT COUNT(DISTINCT SEGMENT) FROM CUSTOMERS WHERE SEGMENT IS NOT NULL")
            segment_count = int(cur.fetchone()[0] or 0)
            checks.append(
                {
                    "name": "saas_segment_distinct_count",
                    "value": segment_count,
                    "threshold": ">= 3",
                    "ok": segment_count >= 3,
                }
            )
            if segment_count < 3:
                failures.append(f"SaaS finance segment count too low: {segment_count} (expected >= 3)")
            cur.execute("SELECT COUNT(DISTINCT REGION) FROM LOCATIONS WHERE REGION IS NOT NULL")
            region_count = int(cur.fetchone()[0] or 0)
            checks.append(
                {
                    "name": "saas_region_distinct_count",
                    "value": region_count,
                    "threshold": ">= 3",
                    "ok": region_count >= 3,
                }
            )
            if region_count < 3:
                failures.append(f"SaaS finance region count too low: {region_count} (expected >= 3)")
            # ARR roll-forward identity (ending = starting + new + expansion
            # - contraction - churn) and MRR*12 == ending ARR, with tolerances.
            cur.execute(
                """
                SELECT
                    COUNT(*) AS total_rows,
                    COUNT_IF(
                        ABS(
                            ENDING_ARR_USD - (
                                STARTING_ARR_USD + NEW_LOGO_ARR_USD + EXPANSION_ARR_USD
                                - CONTRACTION_ARR_USD - CHURNED_ARR_USD
                            )
                        ) > 1.0
                    ) AS bad_arr_rows,
                    COUNT_IF(ABS((MRR_USD * 12.0) - ENDING_ARR_USD) > 12.0) AS bad_mrr_rows
                FROM SAAS_CUSTOMER_MONTHLY
                """
            )
            total_rows, bad_arr_rows, bad_mrr_rows = cur.fetchone()
            total_rows = int(total_rows or 0)
            bad_arr_rows = int(bad_arr_rows or 0)
            bad_mrr_rows = int(bad_mrr_rows or 0)
            arr_identity_ok = total_rows > 0 and bad_arr_rows == 0 and bad_mrr_rows == 0
            checks.append(
                {
                    "name": "saas_arr_rollforward_bad_rows",
                    "value": {"total": total_rows, "bad_arr": bad_arr_rows, "bad_mrr": bad_mrr_rows},
                    "threshold": "all rows reconcile",
                    "ok": arr_identity_ok,
                }
            )
            if not arr_identity_ok:
                failures.append(
                    f"SaaS finance ARR identities broken (bad_arr={bad_arr_rows}, bad_mrr={bad_mrr_rows})"
                )
            # Customer-month density: sparse panels make retention charts look broken.
            cur.execute(
                """
                WITH month_counts AS (
                    SELECT CUSTOMER_KEY, COUNT(DISTINCT MONTH_KEY) AS active_months
                    FROM SAAS_CUSTOMER_MONTHLY
                    GROUP BY 1
                )
                SELECT AVG(active_months), MIN(active_months), MAX(active_months)
                FROM month_counts
                """
            )
            avg_active_months, min_active_months, max_active_months = cur.fetchone()
            avg_active_months = float(avg_active_months or 0.0)
            min_active_months = int(min_active_months or 0)
            max_active_months = int(max_active_months or 0)
            density_ok = avg_active_months >= 12.0 and max_active_months >= 20
            checks.append(
                {
                    "name": "saas_customer_month_density",
                    "value": {
                        "avg_active_months": round(avg_active_months, 2),
                        "min_active_months": min_active_months,
                        "max_active_months": max_active_months,
                    },
                    "threshold": "avg >= 12.0 and max >= 20",
                    "ok": density_ok,
                }
            )
            if not density_ok:
                failures.append(
                    f"SaaS finance customer-month density too sparse (avg={avg_active_months:.2f}, max={max_active_months})"
                )
            # S&M spend identity: total must equal sales + marketing components.
            cur.execute(
                """
                SELECT
                    COUNT(*) AS total_rows,
                    COUNT_IF(ABS(TOTAL_S_AND_M_SPEND_USD - (SALES_SPEND_USD + MARKETING_SPEND_USD)) > 1.0) AS bad_rows
                FROM SALES_MARKETING_SPEND_MONTHLY
                """
            )
            spend_total_rows, bad_spend_rows = cur.fetchone()
            spend_total_rows = int(spend_total_rows or 0)
            bad_spend_rows = int(bad_spend_rows or 0)
            spend_ok = spend_total_rows > 0 and bad_spend_rows == 0
            checks.append(
                {
                    "name": "saas_spend_identity_bad_rows",
                    "value": {"total": spend_total_rows, "bad_rows": bad_spend_rows},
                    "threshold": "== 0",
                    "ok": spend_ok,
                }
            )
            if not spend_ok:
                failures.append(f"SaaS finance spend identity broken in {bad_spend_rows} rows")
    except Exception as exc:  # noqa: BLE001
        # Checks are advisory: surface execution problems as a failure entry
        # rather than crashing the whole runner.
        failures.append(f"Realism sanity checks failed to execute: {exc}")
    finally:
        # Best-effort cleanup; close() errors must not mask check results.
        try:
            if cur is not None:
                cur.close()
        except Exception:
            pass
        try:
            if conn is not None:
                conn.close()
        except Exception:
            pass
    return {"ok": len(failures) == 0, "checks": checks, "failures": failures}
def _run_case_chat(
    case: dict[str, Any],
    default_llm: str,
    user_email: str,
    skip_thoughtspot: bool = False,
) -> dict[str, Any]:
    """Run a single sample case through the full chat pipeline.

    Stages (each recorded under ``result["stages"]`` with ``ok`` and
    ``duration_s``): research, DDL generation, Snowflake deployment,
    quality gate, realism sanity checks, and (optionally) ThoughtSpot
    deployment. The first failing stage sets ``result["error"]`` and the
    remaining stages are skipped.

    Args:
        case: Test-case dict; requires ``company`` and ``use_case``,
            optionally ``name`` and ``context``.
        default_llm: Model identifier applied to the controller settings.
        user_email: Email used to construct the chat controller session.
        skip_thoughtspot: When True, stop after the realism sanity checks.

    Returns:
        Result dict with per-stage status, overall ``success``, and
        ``started_at`` / ``finished_at`` UTC timestamps.
    """
    from chat_interface import ChatDemoInterface
    from demo_personas import get_use_case_config, parse_use_case

    company = case["company"]
    use_case = case["use_case"]
    model = default_llm
    context = case.get("context", "")

    controller = ChatDemoInterface(user_email=user_email)
    controller.settings["model"] = model
    controller.vertical, controller.function = parse_use_case(use_case or "")
    controller.use_case_config = get_use_case_config(
        controller.vertical or "Generic",
        controller.function or "Generic",
    )

    result: dict[str, Any] = {
        "name": case.get("name") or f"{company}_{use_case}",
        "company": company,
        "use_case": use_case,
        "mode": "chat",
        "started_at": _now_utc_iso(),
        "success": False,
        "stages": {},
    }

    # --- Stage 1: company research (streaming; keep last update for preview)
    stage_start = datetime.now(timezone.utc)
    last_research = None
    for update in controller.run_research_streaming(company, use_case, generic_context=context):
        last_research = update
    result["stages"]["research"] = {
        "ok": bool(controller.demo_builder and controller.demo_builder.company_analysis_results),
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "preview": str(last_research)[:500] if last_research else "",
    }

    # --- Stage 2: DDL generation (reuse research output when available)
    stage_start = datetime.now(timezone.utc)
    ddl_text = (controller.demo_builder.schema_generation_results or "") if controller.demo_builder else ""
    if not ddl_text:
        _, ddl_text = controller.run_ddl_creation()
    result["stages"]["ddl"] = {
        "ok": bool(ddl_text and "CREATE TABLE" in ddl_text.upper()),
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "ddl_length": len(ddl_text or ""),
    }
    if not result["stages"]["ddl"]["ok"]:
        result["error"] = "DDL generation failed"
        result["finished_at"] = _now_utc_iso()
        return result

    # --- Stage 3: Snowflake deployment
    stage_start = datetime.now(timezone.utc)
    deploy_error = None
    try:
        for _ in controller.run_deployment_streaming():
            pass
    except Exception as exc:  # noqa: BLE001
        deploy_error = str(exc)
    deployed_schema = getattr(controller, "_deployed_schema_name", None)
    # A schema may exist even when full deployment failed partway through.
    schema_candidate = deployed_schema or getattr(controller, "_last_schema_name", None)
    result["stages"]["deploy_snowflake"] = {
        "ok": bool(deployed_schema) and deploy_error is None,
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "schema": schema_candidate,
        "error": deploy_error,
    }

    # --- Stage 4: quality gate (only meaningful when a schema was created)
    if schema_candidate:
        quality_report_path = getattr(controller, "_last_population_quality_report_path", None)
        result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path)
        if not result["stages"]["quality_gate"]["ok"]:
            result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}"
    elif deploy_error and not result.get("error"):
        result["error"] = deploy_error

    # --- Stage 5: realism sanity checks
    # BUGFIX: the quality_gate stage only exists when a schema was created;
    # use .get() so a failed deployment without a schema does not raise
    # KeyError here (which previously lost the whole per-stage report).
    stage_start = datetime.now(timezone.utc)
    if result["stages"].get("quality_gate", {}).get("ok"):
        sanity = _run_realism_sanity_checks(schema_candidate, case)
        result["stages"]["realism_sanity"] = {
            "ok": bool(sanity.get("ok")),
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "checks": sanity.get("checks", []),
            "failures": sanity.get("failures", []),
        }
        if not result["stages"]["realism_sanity"]["ok"] and not result.get("error"):
            result["error"] = "Realism sanity checks failed before ThoughtSpot deployment"

    # --- Stage 6: ThoughtSpot deployment (optional, requires clean run so far)
    if not skip_thoughtspot and deployed_schema and not result.get("error"):
        stage_start = datetime.now(timezone.utc)
        ts_ok = True
        ts_last = None
        try:
            for ts_update in controller._run_thoughtspot_deployment(deployed_schema, company, use_case):
                ts_last = ts_update
        except Exception as exc:  # noqa: BLE001
            ts_ok = False
            ts_last = str(exc)
        # Some deployment paths return a structured failure payload rather than
        # raising; treat those as failures so pass/fail reporting is accurate.
        ts_preview_text = str(ts_last) if ts_last is not None else ""
        if ts_ok and (
            "THOUGHTSPOT DEPLOYMENT FAILED" in ts_preview_text.upper()
            or "MODEL VALIDATION FAILED" in ts_preview_text.upper()
            or "LIVEBOARD CREATION FAILED" in ts_preview_text.upper()
        ):
            ts_ok = False
        result["stages"]["deploy_thoughtspot"] = {
            "ok": ts_ok,
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "preview": ts_preview_text[:1000],
        }

    result["schema_name"] = schema_candidate
    # Success requires every recorded stage to have passed.
    result["success"] = all(stage.get("ok") for stage in result["stages"].values())
    result["finished_at"] = _now_utc_iso()
    return result
def _run_case_offline(
    case: dict[str, Any],
    default_llm: str,
    user_email: str,
    skip_thoughtspot: bool = False,
) -> dict[str, Any]:
    """Run a single sample case using the deterministic offline DDL template.

    Skips LLM-driven research/schema generation: the schema comes from
    ``OFFLINE_DEMO_DDL``. Stages recorded: Snowflake connect/DDL deploy,
    LegitData population, quality gate, realism sanity checks, and
    (optionally) ThoughtSpot deployment. The first failing stage sets
    ``result["error"]`` and returns early.

    Args:
        case: Test-case dict; requires ``company`` and ``use_case``,
            optionally ``name``.
        default_llm: Model identifier forwarded to population/ThoughtSpot.
        user_email: Attribution email for data population.
        skip_thoughtspot: When True, stop after the realism sanity checks.

    Returns:
        Result dict with per-stage status, overall ``success``, and
        ``started_at`` / ``finished_at`` UTC timestamps.
    """
    from cdw_connector import SnowflakeDeployer
    from demo_prep import generate_demo_base_name
    from legitdata_bridge import populate_demo_data
    from thoughtspot_deployer import deploy_to_thoughtspot

    company = case["company"]
    use_case = case["use_case"]
    result: dict[str, Any] = {
        "name": case.get("name") or f"{company}_{use_case}",
        "company": company,
        "use_case": use_case,
        "mode": "offline_ddl",
        "started_at": _now_utc_iso(),
        "success": False,
        "stages": {},
        "ddl_template": "offline_star_schema_v1",
    }
    deployer = SnowflakeDeployer()

    # 1) Snowflake schema + DDL deploy
    stage_start = datetime.now(timezone.utc)
    ok, msg = deployer.connect()
    if not ok:
        result["stages"]["snowflake_connect"] = {
            "ok": False,
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "message": msg,
        }
        result["error"] = msg
        result["finished_at"] = _now_utc_iso()
        return result
    base_name = generate_demo_base_name("", company)
    # BUGFIX: restart the stage timer so snowflake_ddl duration measures only
    # the DDL deployment, not the connection time above.
    stage_start = datetime.now(timezone.utc)
    ok, schema_name, ddl_msg = deployer.create_demo_schema_and_deploy(base_name, OFFLINE_DEMO_DDL)
    result["stages"]["snowflake_ddl"] = {
        "ok": ok,
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "schema": schema_name,
        "message": ddl_msg,
    }
    if not ok or not schema_name:
        result["error"] = ddl_msg
        result["finished_at"] = _now_utc_iso()
        return result

    # 2) Data population via LegitData
    stage_start = datetime.now(timezone.utc)
    pop_ok, pop_msg, pop_results = populate_demo_data(
        ddl_content=OFFLINE_DEMO_DDL,
        company_url=company,
        use_case=use_case,
        schema_name=schema_name,
        llm_model=default_llm,
        user_email=user_email,
        size="medium",
    )
    result["stages"]["populate_data"] = {
        "ok": pop_ok,
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "rows": pop_results,
        "quality_report": _parse_quality_report_path(pop_msg),
    }
    if not pop_ok:
        result["error"] = pop_msg
        result["finished_at"] = _now_utc_iso()
        return result

    # Quality gate: parsed from the population message's report path.
    quality_report_path = _parse_quality_report_path(pop_msg)
    result["stages"]["quality_gate"] = _build_quality_gate_stage(quality_report_path)
    if not result["stages"]["quality_gate"]["ok"]:
        result["error"] = f"Quality gate failed: {quality_report_path or 'missing quality report'}"
        result["schema_name"] = schema_name
        result["finished_at"] = _now_utc_iso()
        return result

    # Realism sanity checks must pass before touching ThoughtSpot.
    stage_start = datetime.now(timezone.utc)
    sanity = _run_realism_sanity_checks(schema_name, case)
    result["stages"]["realism_sanity"] = {
        "ok": bool(sanity.get("ok")),
        "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
        "checks": sanity.get("checks", []),
        "failures": sanity.get("failures", []),
    }
    if not result["stages"]["realism_sanity"]["ok"]:
        result["error"] = "Realism sanity checks failed before ThoughtSpot deployment"
        result["schema_name"] = schema_name
        result["finished_at"] = _now_utc_iso()
        return result

    # 3) ThoughtSpot model + liveboard
    if not skip_thoughtspot:
        stage_start = datetime.now(timezone.utc)
        ts_result = deploy_to_thoughtspot(
            ddl=OFFLINE_DEMO_DDL,
            database=os.getenv("SNOWFLAKE_DATABASE", "DEMOBUILD"),
            schema=schema_name,
            base_name=base_name,
            connection_name=f"{base_name}_conn",
            company_name=company,
            use_case=use_case,
            llm_model=default_llm,
        )
        result["stages"]["deploy_thoughtspot"] = {
            "ok": bool(ts_result and not ts_result.get("errors")),
            "duration_s": (datetime.now(timezone.utc) - stage_start).total_seconds(),
            "result": ts_result,
        }

    result["schema_name"] = schema_name
    # Success requires every recorded stage to have passed.
    result["success"] = all(stage.get("ok") for stage in result["stages"].values())
    result["finished_at"] = _now_utc_iso()
    return result
def main() -> None:
    """CLI entry point: run every sample case and write JSON artifacts.

    Parses flags, validates required pipeline settings up front, executes
    each case in chat or offline-DDL mode, writes one JSON artifact per
    case plus a run-level ``summary.json``, and exits non-zero when any
    case fails.
    """
    parser = argparse.ArgumentParser(description="Run New Vision sample set")
    parser.add_argument(
        "--cases-file",
        default="tests/newvision_test_cases.yaml",
        help="Path to YAML test case file",
    )
    parser.add_argument(
        "--skip-thoughtspot",
        action="store_true",
        help="Run through data generation only and skip ThoughtSpot object creation",
    )
    parser.add_argument(
        "--offline-ddl",
        action="store_true",
        help="Force offline DDL mode (no LLM dependency)",
    )
    args = parser.parse_args()

    user_email, default_llm = _resolve_runtime_settings()

    # Fail fast before running any case if required settings are missing.
    from startup_validation import validate_required_pipeline_settings_or_raise

    validate_required_pipeline_settings_or_raise(
        default_llm=default_llm,
        require_thoughtspot=not args.skip_thoughtspot,
        require_snowflake=True,
    )

    cases_file = (PROJECT_ROOT / args.cases_file).resolve()
    cases = _load_cases(cases_file)

    run_id = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    out_dir = PROJECT_ROOT / "results" / "newvision_samples" / run_id
    out_dir.mkdir(parents=True, exist_ok=True)

    use_offline = bool(args.offline_ddl)
    print(f"Mode: {'offline_ddl' if use_offline else 'chat'}", flush=True)
    print(f"default_llm: {default_llm}", flush=True)

    results = []
    for idx, case in enumerate(cases, start=1):
        print(f"\n[{idx}/{len(cases)}] {case.get('name', case['company'])} -> {case['use_case']}", flush=True)
        try:
            if use_offline:
                case_result = _run_case_offline(
                    case,
                    default_llm=default_llm,
                    user_email=user_email,
                    skip_thoughtspot=args.skip_thoughtspot,
                )
            else:
                case_result = _run_case_chat(
                    case,
                    default_llm=default_llm,
                    user_email=user_email,
                    skip_thoughtspot=args.skip_thoughtspot,
                )
        except Exception as exc:  # noqa: BLE001
            # A case crash must not abort the rest of the run; record it as
            # a synthetic failed stage instead.
            case_result = {
                "name": case.get("name") or f"{case['company']}_{case['use_case']}",
                "company": case["company"],
                "use_case": case["use_case"],
                "mode": "offline_ddl" if use_offline else "chat",
                "started_at": _now_utc_iso(),
                "finished_at": _now_utc_iso(),
                "success": False,
                "error": f"Runner exception: {exc}",
                "stages": {
                    "runner_exception": {
                        "ok": False,
                        "message": str(exc),
                    }
                },
            }
        results.append(case_result)
        # BUGFIX: case names derive from the company value, which may be a
        # URL (offline mode passes it as company_url) containing '/' or other
        # path-hostile characters; sanitize so the artifact stays inside
        # out_dir and the write cannot fail on an invalid path.
        safe_name = "".join(
            ch if ch.isalnum() or ch in "._-" else "_" for ch in str(case_result["name"])
        )
        (out_dir / f"{safe_name}.json").write_text(
            json.dumps(case_result, indent=2),
            encoding="utf-8",
        )
        print(f"  success={case_result['success']} schema={case_result.get('schema_name')}", flush=True)

    summary = {
        "run_id": run_id,
        "mode": "offline_ddl" if use_offline else "chat",
        "cases_file": str(cases_file),
        "total": len(results),
        "passed": sum(1 for r in results if r.get("success")),
        "failed": sum(1 for r in results if not r.get("success")),
        "results": results,
    }
    (out_dir / "summary.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print("\nSaved sample artifacts:", out_dir)
    print(f"Passed: {summary['passed']} / {summary['total']}")
    # Non-zero exit so CI treats any failed case as a failed run.
    if summary["failed"]:
        raise SystemExit(1)
# Script entry point: only run the sample set when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()