Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| Batch Test Runner for DemoPrep | |
| Runs a series of test cases through the demo pipeline and collects results. | |
| Supports selective stage execution (research-only, full pipeline, etc.) | |
| Usage: | |
| source ./demoprep/bin/activate | |
| python tests/batch_runner.py # Run all cases, full pipeline | |
| python tests/batch_runner.py --stages research,ddl # Research + DDL only | |
| python tests/batch_runner.py --cases 1,3,5 # Specific cases only | |
| python tests/batch_runner.py --cases-file custom.yaml # Custom cases file | |
| """ | |
| import os | |
| import sys | |
| import yaml | |
| import json | |
| import time | |
| import argparse | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, List, Optional | |
| # Add project root to path | |
| PROJECT_ROOT = Path(__file__).parent.parent | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| from dotenv import load_dotenv | |
| load_dotenv(PROJECT_ROOT / '.env') | |
| from llm_config import DEFAULT_LLM_MODEL | |
# Available pipeline stages (in order); --stages values are validated against
# this list by convention, and run_test_case executes them in this order.
STAGES = ['research', 'ddl', 'population', 'deploy_snowflake', 'deploy_thoughtspot', 'liveboard']
# Default YAML file holding the batch test cases, located next to this script.
DEFAULT_CASES_FILE = Path(__file__).parent / 'test_cases.yaml'
def load_test_cases(cases_file: Optional[str] = None) -> List[Dict]:
    """Load test cases from a YAML file.

    Args:
        cases_file: Optional path to a YAML cases file; falls back to
            DEFAULT_CASES_FILE when not given.

    Returns:
        The list stored under the top-level 'test_cases' key, or an empty
        list when the key is missing or the file is empty.

    Exits the process with status 1 if the file does not exist.
    """
    path = Path(cases_file) if cases_file else DEFAULT_CASES_FILE
    if not path.exists():
        print(f"ERROR: Test cases file not found: {path}")
        sys.exit(1)
    with open(path) as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty file; guard before calling .get()
    return (data or {}).get('test_cases', [])
def run_research(company: str, use_case: str, model: str = DEFAULT_LLM_MODEL) -> tuple:
    """Run the research stage for a test case.

    Args:
        company: Target company name.
        use_case: Demo use case description.
        model: LLM model identifier (defaults to DEFAULT_LLM_MODEL).

    Returns:
        A (result_dict, controller) tuple — NOT just a dict (the previous
        ``-> Dict`` annotation was wrong); the controller is reused by
        later stages such as DDL generation.
    """
    # Imported lazily so merely importing this module does not pull in the
    # full chat stack.
    from chat_interface import ChatDemoInterface
    controller = ChatDemoInterface()
    controller.settings['model'] = model
    result = {
        'stage': 'research',
        'success': False,
        'duration_ms': 0,
        'details': {},
    }
    start = time.time()
    try:
        # Run research (non-streaming version)
        research_result = controller.run_research(company, use_case)
        result['duration_ms'] = int((time.time() - start) * 1000)
        if research_result and controller.demo_builder:
            result['success'] = True
            result['details'] = {
                # `or ''` guards against None result fields on the builder.
                'company_analysis_len': len(controller.demo_builder.company_analysis_results or ''),
                'industry_research_len': len(controller.demo_builder.industry_research_results or ''),
            }
        else:
            result['error'] = 'Research returned None or empty'
    except Exception as e:
        # Record elapsed time even on failure so timing stats stay useful.
        result['duration_ms'] = int((time.time() - start) * 1000)
        result['error'] = str(e)
    return result, controller
def run_ddl(controller) -> Dict:
    """Run the DDL-generation stage against an already-researched controller.

    Args:
        controller: A ChatDemoInterface (or compatible) object exposing
            ``run_ddl_creation() -> (response_text, ddl_code)``.

    Returns:
        A stage-result dict with keys 'stage', 'success', 'duration_ms',
        'details' (tables found, table count, DDL length on success) and
        'error' on failure.
    """
    result = {
        'stage': 'ddl',
        'success': False,
        'duration_ms': 0,
        'details': {},
    }
    start = time.time()
    try:
        # The chat response text is not needed here; only the DDL matters.
        _, ddl_code = controller.run_ddl_creation()
        result['duration_ms'] = int((time.time() - start) * 1000)
        if ddl_code and 'CREATE TABLE' in ddl_code.upper():
            result['success'] = True
            # Count tables by extracting names from CREATE TABLE statements.
            import re
            tables = re.findall(r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)', ddl_code, re.IGNORECASE)
            result['details'] = {
                'tables': tables,
                'table_count': len(tables),
                'ddl_length': len(ddl_code),
            }
        else:
            result['error'] = 'No valid DDL generated'
    except Exception as e:
        # Record elapsed time even on failure so timing stats stay useful.
        result['duration_ms'] = int((time.time() - start) * 1000)
        result['error'] = str(e)
    return result
def run_test_case(case: Dict, stages: List[str]) -> Dict:
    """Execute one test case, running only the requested pipeline stages.

    Returns a per-case result dict; on a research failure the remaining
    stages are skipped and the partial result is returned immediately.
    """
    company = case['company']
    use_case = case['use_case']
    model = case.get('model', DEFAULT_LLM_MODEL)
    context = case.get('context', '')

    banner = '=' * 60
    print(f"\n{banner}")
    print(f" {company} / {use_case}")
    print(f" Model: {model}")
    if context:
        print(f" Context: {context[:80]}...")
    print(f"{banner}")

    case_result = {
        'company': company,
        'use_case': use_case,
        'model': model,
        'started_at': datetime.now().isoformat(),
        'stages': {},
        'overall_success': True,
    }
    controller = None

    # Research stage — also produces the controller used by later stages.
    if 'research' in stages:
        print(f" [1/6] Research...", end=' ', flush=True)
        stage_result, controller = run_research(company, use_case, model)
        case_result['stages']['research'] = stage_result
        if stage_result['success']:
            label = 'OK'
        else:
            label = f"FAIL: {stage_result.get('error', '')[:60]}"
        print(f"{label} ({stage_result['duration_ms']}ms)")
        if not stage_result['success']:
            case_result['overall_success'] = False
            # Nothing downstream can run without research; bail out early.
            return case_result

    # DDL stage — requires a controller produced by the research stage.
    if 'ddl' in stages and controller:
        print(f" [2/6] DDL Generation...", end=' ', flush=True)
        stage_result = run_ddl(controller)
        case_result['stages']['ddl'] = stage_result
        table_names = stage_result.get('details', {}).get('tables', [])
        if stage_result['success']:
            label = f"OK ({len(table_names)} tables: {', '.join(table_names)})"
        else:
            label = f"FAIL: {stage_result.get('error', '')[:60]}"
        print(f"{label} ({stage_result['duration_ms']}ms)")
        if not stage_result['success']:
            case_result['overall_success'] = False

    # TODO: Add population, deploy_snowflake, deploy_thoughtspot, liveboard stages
    # These require actual infrastructure connections
    case_result['finished_at'] = datetime.now().isoformat()
    return case_result
def print_summary(results: List[Dict]):
    """Print a human-readable pass/fail summary table for a batch run."""
    rule = '=' * 80
    print(f"\n\n{rule}")
    print(f" BATCH TEST RESULTS SUMMARY")
    print(f"{rule}\n")

    total = len(results)
    pass_count = sum(1 for entry in results if entry['overall_success'])
    print(f" Total: {total} | Passed: {pass_count} | Failed: {total - pass_count}\n")

    for idx, entry in enumerate(results, 1):
        verdict = "PASS" if entry['overall_success'] else "FAIL"
        # Build a one-line note per executed stage.
        note_parts = []
        for stage_name, stage_data in entry.get('stages', {}).items():
            if not stage_data['success']:
                note_parts.append(f"{stage_name}(FAILED)")
                continue
            info = stage_data.get('details', {})
            if stage_name == 'research':
                note_parts.append(f"research({info.get('company_analysis_len', 0)} chars)")
            elif stage_name == 'ddl':
                table_list = info.get('tables', [])
                note_parts.append(f"ddl({len(table_list)} tables: {', '.join(table_list[:3])})")
        notes_line = ' | '.join(note_parts) if note_parts else 'No stages run'
        print(f" [{verdict}] {idx}. {entry['company']} / {entry['use_case']}")
        print(f" {notes_line}")
        print()
    print(rule)
def main():
    """CLI entry point: parse args, run the selected cases and stages,
    print a summary, and save a JSON results file.

    Exits with status 1 if any case failed, so CI can gate on the result.
    """
    parser = argparse.ArgumentParser(description='DemoPrep Batch Test Runner')
    parser.add_argument('--cases-file', '-f', default=None,
                        help='Path to test cases YAML file (default: tests/test_cases.yaml)')
    parser.add_argument('--cases', '-c', default=None,
                        help='Comma-separated case numbers to run (e.g., 1,3,5)')
    parser.add_argument('--stages', '-s', default='research,ddl',
                        help=f'Comma-separated stages to run: {",".join(STAGES)} (default: research,ddl)')
    parser.add_argument('--output', '-o', default=None,
                        help='Output JSON file for results')
    args = parser.parse_args()

    # Load test cases
    cases = load_test_cases(args.cases_file)

    # Filter cases if specified (CLI numbers are 1-based).
    if args.cases:
        indices = [int(x.strip()) - 1 for x in args.cases.split(',')]
        # Guard BOTH bounds: without `0 <= i`, a CLI value of 0 becomes
        # index -1 and silently selects the last case instead of none.
        cases = [cases[i] for i in indices if 0 <= i < len(cases)]

    # Parse stages
    stages = [s.strip() for s in args.stages.split(',')]

    print(f"\n{'#'*60}")
    print(f" DemoPrep Batch Test Runner")
    print(f" Cases: {len(cases)}")
    print(f" Stages: {', '.join(stages)}")
    print(f" Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'#'*60}")

    # Run all cases
    results = []
    for i, case in enumerate(cases, 1):
        print(f"\n--- Case {i}/{len(cases)} ---")
        result = run_test_case(case, stages)
        results.append(result)

    # Print summary
    print_summary(results)

    # Save results (timestamped default filename avoids clobbering old runs).
    output_file = args.output or f"tests/batch_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    output_path = PROJECT_ROOT / output_file
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump({
            'run_at': datetime.now().isoformat(),
            'stages': stages,
            'total_cases': len(results),
            'passed': sum(1 for r in results if r['overall_success']),
            'failed': sum(1 for r in results if not r['overall_success']),
            'results': results,
        }, f, indent=2)
    print(f"\nResults saved to: {output_path}")

    # Exit with error code if any failed
    if any(not r['overall_success'] for r in results):
        sys.exit(1)


if __name__ == '__main__':
    main()