# Source: demoprep/tests/batch_runner.py
# Commit cd91248 by mikeboone — "refactor: centralize LLM config and env-backed client setup"
#!/usr/bin/env python3
"""
Batch Test Runner for DemoPrep
Runs a series of test cases through the demo pipeline and collects results.
Supports selective stage execution (research-only, full pipeline, etc.)
Usage:
source ./demoprep/bin/activate
python tests/batch_runner.py # Run all cases, full pipeline
python tests/batch_runner.py --stages research,ddl # Research + DDL only
python tests/batch_runner.py --cases 1,3,5 # Specific cases only
python tests/batch_runner.py --cases-file custom.yaml # Custom cases file
"""
import os
import sys
import yaml
import json
import time
import argparse
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
# Add project root to path
PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
from dotenv import load_dotenv
load_dotenv(PROJECT_ROOT / '.env')
from llm_config import DEFAULT_LLM_MODEL
# Available pipeline stages (in order)
STAGES = ['research', 'ddl', 'population', 'deploy_snowflake', 'deploy_thoughtspot', 'liveboard']
# Default YAML file of test cases, resolved relative to this script's directory.
DEFAULT_CASES_FILE = Path(__file__).parent / 'test_cases.yaml'
def load_test_cases(cases_file: Optional[str] = None) -> List[Dict]:
    """Load test cases from a YAML file.

    Args:
        cases_file: Optional path to a YAML file; when omitted, falls back
            to tests/test_cases.yaml next to this script.

    Returns:
        The list under the top-level 'test_cases' key ([] if absent).

    Exits:
        With status 1 when the file does not exist.
    """
    path = Path(cases_file) if cases_file else DEFAULT_CASES_FILE
    if not path.exists():
        print(f"ERROR: Test cases file not found: {path}")
        sys.exit(1)
    with open(path) as f:
        # safe_load returns None for an empty file; guard so .get() below
        # cannot raise AttributeError.
        data = yaml.safe_load(f) or {}
    return data.get('test_cases', [])
def run_research(company: str, use_case: str, model: str = DEFAULT_LLM_MODEL) -> tuple:
    """Run the research stage for a test case.

    Args:
        company: Target company name.
        use_case: Demo use case description.
        model: LLM model id; defaults to the project-wide default.

    Returns:
        (result, controller): ``result`` is a stage-result dict with keys
        stage/success/duration_ms/details (plus 'error' on failure);
        ``controller`` is the ChatDemoInterface instance so later stages
        can reuse its research state.  (The original annotation claimed a
        bare Dict, but the function has always returned this 2-tuple.)
    """
    from chat_interface import ChatDemoInterface
    controller = ChatDemoInterface()
    controller.settings['model'] = model
    result = {
        'stage': 'research',
        'success': False,
        'duration_ms': 0,
        'details': {},
    }
    start = time.time()
    try:
        # Run research (non-streaming version); populates controller.demo_builder.
        research_result = controller.run_research(company, use_case)
        result['duration_ms'] = int((time.time() - start) * 1000)
        if research_result and controller.demo_builder:
            result['success'] = True
            # Lengths only — the raw research text is too large to log.
            result['details'] = {
                'company_analysis_len': len(controller.demo_builder.company_analysis_results or ''),
                'industry_research_len': len(controller.demo_builder.industry_research_results or ''),
            }
        else:
            result['error'] = 'Research returned None or empty'
    except Exception as e:
        # Record how long we ran before failing, plus the error message.
        result['duration_ms'] = int((time.time() - start) * 1000)
        result['error'] = str(e)
    return result, controller
def run_ddl(controller) -> Dict:
    """Run the DDL generation stage.

    Args:
        controller: Object exposing run_ddl_creation() -> (response, ddl_code).

    Returns:
        Stage-result dict with keys stage/success/duration_ms/details,
        plus 'error' on failure; details lists the parsed table names.
    """
    import re

    result = {
        'stage': 'ddl',
        'success': False,
        'duration_ms': 0,
        'details': {},
    }
    start = time.time()
    try:
        response, ddl_code = controller.run_ddl_creation()
        result['duration_ms'] = int((time.time() - start) * 1000)
        # Gate success on the same pattern used to extract table names, so
        # whitespace variants such as "CREATE\tTABLE" are not rejected by a
        # plain substring check while still being counted below.
        tables = re.findall(
            r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)',
            ddl_code or '', re.IGNORECASE)
        if tables:
            result['success'] = True
            result['details'] = {
                'tables': tables,
                'table_count': len(tables),
                'ddl_length': len(ddl_code),
            }
        else:
            result['error'] = 'No valid DDL generated'
    except Exception as e:
        result['duration_ms'] = int((time.time() - start) * 1000)
        result['error'] = str(e)
    return result
def run_test_case(case: Dict, stages: List[str]) -> Dict:
    """Run a single test case through the requested pipeline stages.

    Args:
        case: Test-case dict with required 'company' and 'use_case' keys
            and optional 'model' / 'context' overrides.
        stages: Stage names to execute (subset of STAGES).

    Returns:
        Result dict: company/use_case/model, started_at/finished_at ISO
        timestamps, per-stage results under 'stages', and overall_success.
    """
    company = case['company']
    use_case = case['use_case']
    model = case.get('model', DEFAULT_LLM_MODEL)
    context = case.get('context', '')
    print(f"\n{'='*60}")
    print(f" {company} / {use_case}")
    print(f" Model: {model}")
    if context:
        print(f" Context: {context[:80]}...")
    print(f"{'='*60}")
    case_result = {
        'company': company,
        'use_case': use_case,
        'model': model,
        'started_at': datetime.now().isoformat(),
        'stages': {},
        'overall_success': True,
    }
    controller = None
    # Research
    if 'research' in stages:
        print(f" [1/6] Research...", end=' ', flush=True)
        res, controller = run_research(company, use_case, model)
        case_result['stages']['research'] = res
        status = 'OK' if res['success'] else f"FAIL: {res.get('error', '')[:60]}"
        print(f"{status} ({res['duration_ms']}ms)")
        if not res['success']:
            # Later stages depend on research output; abort this case early,
            # stamping finished_at so the result schema stays consistent.
            case_result['overall_success'] = False
            case_result['finished_at'] = datetime.now().isoformat()
            return case_result
    # DDL
    if 'ddl' in stages:
        if controller is None:
            # DDL needs the controller produced by the research stage. Record
            # the skip explicitly instead of silently reporting success when
            # 'ddl' is requested without 'research'.
            case_result['stages']['ddl'] = {
                'stage': 'ddl',
                'success': False,
                'duration_ms': 0,
                'details': {},
                'error': 'Skipped: DDL stage requires research to run first',
            }
            case_result['overall_success'] = False
        else:
            print(f" [2/6] DDL Generation...", end=' ', flush=True)
            res = run_ddl(controller)
            case_result['stages']['ddl'] = res
            tables = res.get('details', {}).get('tables', [])
            status = f"OK ({len(tables)} tables: {', '.join(tables)})" if res['success'] else f"FAIL: {res.get('error', '')[:60]}"
            print(f"{status} ({res['duration_ms']}ms)")
            if not res['success']:
                case_result['overall_success'] = False
    # TODO: Add population, deploy_snowflake, deploy_thoughtspot, liveboard stages
    # These require actual infrastructure connections
    case_result['finished_at'] = datetime.now().isoformat()
    return case_result
def print_summary(results: List[Dict]):
    """Print a formatted pass/fail summary table for a batch run.

    Args:
        results: Per-case result dicts as produced by run_test_case.
    """
    print(f"\n\n{'='*80}")
    print(f" BATCH TEST RESULTS SUMMARY")
    print(f"{'='*80}\n")
    n_passed = sum(1 for entry in results if entry['overall_success'])
    n_failed = len(results) - n_passed
    print(f" Total: {len(results)} | Passed: {n_passed} | Failed: {n_failed}\n")
    for idx, entry in enumerate(results, 1):
        label = "PASS" if entry['overall_success'] else "FAIL"
        # One short note per executed stage (failures flagged explicitly).
        note_parts = []
        for name, data in entry.get('stages', {}).items():
            if not data['success']:
                note_parts.append(f"{name}(FAILED)")
                continue
            info = data.get('details', {})
            if name == 'research':
                note_parts.append(f"research({info.get('company_analysis_len', 0)} chars)")
            elif name == 'ddl':
                tbls = info.get('tables', [])
                note_parts.append(f"ddl({len(tbls)} tables: {', '.join(tbls[:3])})")
        notes = ' | '.join(note_parts) if note_parts else 'No stages run'
        print(f" [{label}] {idx}. {entry['company']} / {entry['use_case']}")
        print(f" {notes}")
        print()
    print(f"{'='*80}")
def main():
    """CLI entry point: parse args, run selected cases/stages, save results.

    Exits with status 1 when any case fails or when arguments are invalid.
    """
    parser = argparse.ArgumentParser(description='DemoPrep Batch Test Runner')
    parser.add_argument('--cases-file', '-f', default=None,
                        help='Path to test cases YAML file (default: tests/test_cases.yaml)')
    parser.add_argument('--cases', '-c', default=None,
                        help='Comma-separated case numbers to run (e.g., 1,3,5)')
    parser.add_argument('--stages', '-s', default='research,ddl',
                        help=f'Comma-separated stages to run: {",".join(STAGES)} (default: research,ddl)')
    parser.add_argument('--output', '-o', default=None,
                        help='Output JSON file for results')
    args = parser.parse_args()
    # Load test cases
    cases = load_test_cases(args.cases_file)
    # Filter cases if specified. Case numbers are 1-based; require the
    # zero-based index to be in range, because "--cases 0" previously slipped
    # through "i < len(cases)" and selected the *last* case via index -1.
    if args.cases:
        indices = [int(x.strip()) - 1 for x in args.cases.split(',')]
        cases = [cases[i] for i in indices if 0 <= i < len(cases)]
    # Parse and validate stages so a typo fails fast instead of being
    # silently ignored by run_test_case.
    stages = [s.strip() for s in args.stages.split(',')]
    unknown = [s for s in stages if s not in STAGES]
    if unknown:
        print(f"ERROR: Unknown stage(s): {', '.join(unknown)}. Valid stages: {', '.join(STAGES)}")
        sys.exit(1)
    print(f"\n{'#'*60}")
    print(f" DemoPrep Batch Test Runner")
    print(f" Cases: {len(cases)}")
    print(f" Stages: {', '.join(stages)}")
    print(f" Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'#'*60}")
    # Run all cases
    results = []
    for i, case in enumerate(cases, 1):
        print(f"\n--- Case {i}/{len(cases)} ---")
        result = run_test_case(case, stages)
        results.append(result)
    # Print summary
    print_summary(results)
    # Save results (timestamped default path under tests/)
    output_file = args.output or f"tests/batch_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    output_path = PROJECT_ROOT / output_file
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump({
            'run_at': datetime.now().isoformat(),
            'stages': stages,
            'total_cases': len(results),
            'passed': sum(1 for r in results if r['overall_success']),
            'failed': sum(1 for r in results if not r['overall_success']),
            'results': results,
        }, f, indent=2)
    print(f"\nResults saved to: {output_path}")
    # Exit with error code if any failed, so CI can detect the failure.
    if any(not r['overall_success'] for r in results):
        sys.exit(1)
# Run only when executed as a script (not on import).
if __name__ == '__main__':
    main()