""" Test script to verify the NL→SQL Leaderboard system works correctly. """ import os import sys import time # Add src to path for imports sys.path.append('src') from evaluator import evaluator, DatasetManager from models_registry import models_registry from scoring import scoring_engine def test_dataset_discovery(): """Test that datasets are discovered correctly.""" print("Testing dataset discovery...") dataset_manager = DatasetManager() datasets = dataset_manager.get_datasets() print(f"Found datasets: {list(datasets.keys())}") if "nyc_taxi_small" in datasets: print("✓ NYC Taxi dataset found") return True else: print("✗ NYC Taxi dataset not found") return False def test_models_loading(): """Test that models are loaded correctly.""" print("\nTesting models loading...") models = models_registry.get_models() print(f"Found models: {[model.name for model in models]}") if len(models) > 0: print("✓ Models loaded successfully") return True else: print("✗ No models found") return False def test_database_creation(): """Test database creation for NYC Taxi dataset.""" print("\nTesting database creation...") try: dataset_manager = DatasetManager() db_path = dataset_manager.create_database("nyc_taxi_small") if os.path.exists(db_path): print("✓ Database created successfully") # Clean up os.remove(db_path) return True else: print("✗ Database file not created") return False except Exception as e: print(f"✗ Database creation failed: {e}") return False def test_cases_loading(): """Test loading test cases.""" print("\nTesting cases loading...") try: dataset_manager = DatasetManager() cases = dataset_manager.load_cases("nyc_taxi_small") print(f"Found {len(cases)} test cases") if len(cases) > 0: print("✓ Test cases loaded successfully") return True else: print("✗ No test cases found") return False except Exception as e: print(f"✗ Cases loading failed: {e}") return False def test_prompt_templates(): """Test that prompt templates exist.""" print("\nTesting prompt templates...") dialects = ["presto", "bigquery", "snowflake"] all_exist = True for dialect in dialects: template_path = f"prompts/template_{dialect}.txt" if os.path.exists(template_path): print(f"✓ {dialect} template found") else: print(f"✗ {dialect} template not found") all_exist = False return all_exist def test_scoring_engine(): """Test the scoring engine.""" print("\nTesting scoring engine...") try: from scoring import Metrics # Test with sample metrics metrics = Metrics( correctness_exact=1.0, result_match_f1=0.8, exec_success=1.0, latency_ms=100.0, readability=0.9, dialect_ok=1.0 ) score = scoring_engine.compute_composite_score(metrics) print(f"✓ Composite score computed: {score}") if 0.0 <= score <= 1.0: print("✓ Score is in valid range") return True else: print("✗ Score is out of valid range") return False except Exception as e: print(f"✗ Scoring engine test failed: {e}") return False def test_sql_execution(): """Test SQL execution with DuckDB.""" print("\nTesting SQL execution...") try: import duckdb # Create a simple test database conn = duckdb.connect(":memory:") conn.execute("CREATE TABLE test (id INTEGER, name VARCHAR(10))") conn.execute("INSERT INTO test VALUES (1, 'Alice'), (2, 'Bob')") # Test query result = conn.execute("SELECT COUNT(*) FROM test").fetchdf() print(f"✓ SQL execution successful: {result.iloc[0, 0]} rows") conn.close() return True except Exception as e: print(f"✗ SQL execution failed: {e}") return False def test_sqlglot_transpilation(): """Test SQL transpilation with sqlglot.""" print("\nTesting SQL transpilation...") try: import sqlglot # Test simple query sql = "SELECT COUNT(*) FROM trips" parsed = sqlglot.parse_one(sql) # Transpile to different dialects dialects = ["presto", "bigquery", "snowflake"] for dialect in dialects: transpiled = parsed.sql(dialect=dialect) print(f"✓ {dialect} transpilation: {transpiled}") return True except Exception as e: print(f"✗ SQL transpilation failed: {e}") return False def main(): """Run all tests.""" print("NL→SQL Leaderboard System Test") print("=" * 40) tests = [ test_dataset_discovery, test_models_loading, test_database_creation, test_cases_loading, test_prompt_templates, test_scoring_engine, test_sql_execution, test_sqlglot_transpilation ] passed = 0 total = len(tests) for test in tests: try: if test(): passed += 1 except Exception as e: print(f"✗ Test {test.__name__} failed with exception: {e}") print("\n" + "=" * 40) print(f"Test Results: {passed}/{total} tests passed") if passed == total: print("🎉 All tests passed! The system is ready to use.") return True else: print("❌ Some tests failed. Please check the issues above.") return False if __name__ == "__main__": success = main() sys.exit(0 if success else 1)