"""
Test script for Query Expansion implementation in CogniChat

Tests all components of the query expansion system:
1. QueryAnalyzer - Intent and entity extraction
2. QueryRephraser - Variation generation
3. MultiQueryExpander - Complete expansion
4. MultiHopReasoner - Sub-query generation
5. FallbackStrategies - Edge case handling

It additionally checks that ``rag_processor.create_rag_chain`` exposes the
query expansion parameters and that ``app.py`` references the expansion
settings.

Each ``test_*`` function prints its own report and returns True/False;
``run_all_tests()`` aggregates those results into the process exit code.
pytest ignores test-function return values, so run this file directly as a
script rather than through pytest:

Run: python3 test_query_expansion.py
"""


import sys
import os


# Put the directory containing this script first on sys.path so that
# ``utils.query_expansion`` and ``rag_processor`` (presumably located next to
# this file) resolve regardless of the current working directory. Inserting
# at index 0 also gives the project modules priority over any installed
# package with the same name (``utils`` is a common one).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
def test_imports():
    """Check that every public Query Expansion name can be imported.

    Imports the classes and the ``expand_query_simple`` helper from
    ``utils.query_expansion`` and prints a pass/fail line.

    Returns:
        bool: True if the import succeeded, False if it raised ImportError.
    """
    print("=" * 70)
    print("TEST 1: Importing Query Expansion Modules")
    print("=" * 70)

    try:
        # Imported only to prove the names exist; they are not used here.
        from utils.query_expansion import (  # noqa: F401
            QueryAnalyzer,
            QueryRephraser,
            MultiQueryExpander,
            MultiHopReasoner,
            FallbackStrategies,
            QueryStrategy,
            expand_query_simple,
        )
    except ImportError as e:
        print(f"❌ Import failed: {e}")
        return False

    print("✅ All modules imported successfully")
    return True
|
|
|
|
def test_query_analyzer():
    """Check that QueryAnalyzer classifies the intent of sample queries.

    For each (query, expected_intent) pair, runs ``QueryAnalyzer.analyze`` and
    compares ``analysis.intent`` with the expectation. On a match it also
    prints the extracted keywords and complexity.

    Returns:
        bool: True only if every sample query got its expected intent.
    """
    print("\n" + "=" * 70)
    print("TEST 2: QueryAnalyzer - Intent and Entity Detection")
    print("=" * 70)

    from utils.query_expansion import QueryAnalyzer

    analyzer = QueryAnalyzer()

    # (query, intent label QueryAnalyzer is expected to assign)
    test_cases = [
        ("What is machine learning?", "definition"),
        ("How do I debug Python code?", "how_to"),
        ("Compare Python and Java", "comparison"),
        ("Why does this error occur?", "explanation"),
        ("List all API endpoints", "listing"),
        ("Show me an example of recursion", "example"),
    ]

    passed = 0
    for query, expected_intent in test_cases:
        analysis = analyzer.analyze(query)

        if analysis.intent == expected_intent:
            print(f"✅ Query: '{query}'")
            print(f" Intent: {analysis.intent} (expected: {expected_intent})")
            print(f" Keywords: {analysis.keywords}")
            print(f" Complexity: {analysis.complexity}")
            passed += 1
        else:
            print(f"❌ Query: '{query}'")
            print(f" Got intent: {analysis.intent}, expected: {expected_intent}")

    print(f"\n{passed}/{len(test_cases)} tests passed")
    return passed == len(test_cases)
|
|
|
|
def test_query_rephraser():
    """Check that QueryRephraser yields enough variations per strategy.

    Analyzes one sample query, then calls ``generate_variations`` under the
    QUICK, BALANCED and COMPREHENSIVE strategies. These must produce at least
    2, 4 and 6 variations respectively. The first three variations are
    printed for each passing strategy.

    Returns:
        bool: True only if every strategy met its minimum.
    """
    print("\n" + "=" * 70)
    print("TEST 3: QueryRephraser - Generating Variations")
    print("=" * 70)

    from utils.query_expansion import QueryAnalyzer, QueryRephraser, QueryStrategy

    analyzer = QueryAnalyzer()
    rephraser = QueryRephraser()

    query = "How do I fix memory leaks in Python?"
    analysis = analyzer.analyze(query)

    # (strategy, minimum number of variations it must produce)
    strategies = [
        (QueryStrategy.QUICK, 2),
        (QueryStrategy.BALANCED, 4),
        (QueryStrategy.COMPREHENSIVE, 6),
    ]

    passed = 0
    for strategy, expected_min in strategies:
        variations = rephraser.generate_variations(query, analysis, strategy)

        if len(variations) >= expected_min:
            print(f"✅ Strategy: {strategy.value}")
            print(f" Generated {len(variations)} variations (expected ≥{expected_min})")
            for i, var in enumerate(variations[:3], 1):
                print(f" {i}. {var}")
            passed += 1
        else:
            print(f"❌ Strategy: {strategy.value}")
            print(f" Generated {len(variations)}, expected ≥{expected_min}")

    print(f"\n{passed}/{len(strategies)} strategy tests passed")
    return passed == len(strategies)
|
|
|
|
def test_multi_query_expander():
    """Check the end-to-end MultiQueryExpander workflow.

    Expands each sample query with the BALANCED strategy. The result must
    echo the original query, carry at least three variations, and include a
    non-None analysis.

    Returns:
        bool: True only if every sample query satisfied all three checks.
    """
    print("\n" + "=" * 70)
    print("TEST 4: MultiQueryExpander - Complete Expansion")
    print("=" * 70)

    from utils.query_expansion import MultiQueryExpander, QueryStrategy

    expander = MultiQueryExpander()

    test_queries = [
        "What is neural network architecture?",
        "How do I optimize database queries?",
        "Compare REST and GraphQL APIs",
    ]

    passed = 0
    for query in test_queries:
        result = expander.expand(query, strategy=QueryStrategy.BALANCED)

        has_original = result.original == query
        has_variations = len(result.variations) >= 3
        has_analysis = result.analysis is not None

        if has_original and has_variations and has_analysis:
            print(f"✅ Query: '{query}'")
            print(f" Original: {result.original}")
            print(f" Variations: {len(result.variations)}")
            print(f" Intent: {result.analysis.intent}")
            print(f" Keywords: {result.analysis.keywords[:3]}")
            passed += 1
        else:
            # Report each check separately so the failing one is obvious.
            print(f"❌ Query: '{query}'")
            print(f" Original match: {has_original}")
            print(f" Has variations: {has_variations}")
            print(f" Has analysis: {has_analysis}")

    print(f"\n{passed}/{len(test_queries)} queries expanded successfully")
    return passed == len(test_queries)
|
|
|
|
def test_multi_hop_reasoner():
    """Check that MultiHopReasoner decomposes queries into sub-queries.

    Each case pairs a query with the minimum number of sub-queries that
    ``generate_sub_queries`` must return. All generated sub-queries are
    printed for passing cases.

    Returns:
        bool: True only if every case met its minimum.
    """
    print("\n" + "=" * 70)
    print("TEST 5: MultiHopReasoner - Sub-Query Generation")
    print("=" * 70)

    from utils.query_expansion import MultiHopReasoner, QueryAnalyzer

    reasoner = MultiHopReasoner()
    analyzer = QueryAnalyzer()

    # (query, minimum number of sub-queries expected)
    test_cases = [
        ("Compare Python and Java for web development", 3),
        ("Simple query", 1),
        ("How do I implement authentication with security and performance?", 3),
    ]

    passed = 0
    for query, expected_min in test_cases:
        analysis = analyzer.analyze(query)
        sub_queries = reasoner.generate_sub_queries(query, analysis)

        if len(sub_queries) >= expected_min:
            print(f"✅ Query: '{query}'")
            print(f" Generated {len(sub_queries)} sub-queries (expected ≥{expected_min})")
            for i, sq in enumerate(sub_queries, 1):
                print(f" {i}. {sq}")
            passed += 1
        else:
            print(f"❌ Query: '{query}'")
            print(f" Generated {len(sub_queries)}, expected ≥{expected_min}")

    print(f"\n{passed}/{len(test_cases)} multi-hop tests passed")
    return passed == len(test_cases)
|
|
|
|
def test_fallback_strategies():
    """Check the FallbackStrategies helpers on an over-specified query.

    Runs ``simplify_query``, ``broaden_query`` and ``focus_entities`` on a
    single query. A helper passes when its result is truthy and differs from
    the original query string.

    All three helpers are evaluated before anything is reported, so an
    exception in any one of them aborts this test immediately.

    Returns:
        bool: True only if all three helpers passed.
    """
    print("\n" + "=" * 70)
    print("TEST 6: FallbackStrategies - Edge Case Handling")
    print("=" * 70)

    from utils.query_expansion import FallbackStrategies, QueryAnalyzer

    analyzer = QueryAnalyzer()

    query = "What is the specific exact difference between supervised and unsupervised learning?"
    analysis = analyzer.analyze(query)

    tests = [
        ("simplify_query", FallbackStrategies.simplify_query(query)),
        ("broaden_query", FallbackStrategies.broaden_query(query, analysis)),
        ("focus_entities", FallbackStrategies.focus_entities(analysis)),
    ]

    print(f"Original: {query}\n")

    passed = 0
    for name, result in tests:
        # An empty result, or one identical to the input, means the
        # fallback did not actually transform the query.
        if result and result != query:
            print(f"✅ {name}: {result}")
            passed += 1
        else:
            print(f"❌ {name}: {result}")

    print(f"\n{passed}/{len(tests)} fallback strategies working")
    return passed == len(tests)
|
|
|
|
def test_convenience_function():
    """Check ``expand_query_simple`` for each strategy name.

    Calls the helper with the string strategy names "quick", "balanced" and
    "comprehensive". These must return at least 2, 4 and 6 queries
    respectively.

    Returns:
        bool: True only if every strategy met its minimum.
    """
    print("\n" + "=" * 70)
    print("TEST 7: Convenience Function - expand_query_simple")
    print("=" * 70)

    from utils.query_expansion import expand_query_simple

    query = "How do I deploy a machine learning model?"

    # (strategy name, minimum number of queries expected). Keeping each
    # strategy next to its threshold avoids two parallel lists drifting.
    cases = [
        ("quick", 2),
        ("balanced", 4),
        ("comprehensive", 6),
    ]

    passed = 0
    for strategy, expected_min in cases:
        queries = expand_query_simple(query, strategy=strategy)

        if len(queries) >= expected_min:
            print(f"✅ Strategy: {strategy}")
            print(f" Generated {len(queries)} queries (expected ≥{expected_min})")
            passed += 1
        else:
            print(f"❌ Strategy: {strategy}")
            print(f" Generated {len(queries)}, expected ≥{expected_min}")

    print(f"\n{passed}/{len(cases)} convenience function tests passed")
    return passed == len(cases)
|
|
|
|
def test_rag_integration():
    """Check that rag_processor exposes the query expansion hooks.

    Imports ``create_multi_query_retriever`` and ``create_rag_chain`` from
    ``rag_processor``. It then inspects ``create_rag_chain``'s signature for
    the ``enable_query_expansion`` and ``expansion_strategy`` parameters.

    Returns:
        bool: True if both functions import and both parameters are present.
        False on ImportError or if either parameter is missing.
    """
    import inspect

    print("\n" + "=" * 70)
    print("TEST 8: RAG Integration - Checking rag_processor.py")
    print("=" * 70)

    try:
        from rag_processor import create_multi_query_retriever, create_rag_chain  # noqa: F401
    except ImportError as e:
        print(f"❌ Import failed: {e}")
        return False

    print("✅ create_multi_query_retriever imported")
    print("✅ create_rag_chain imported")

    params = list(inspect.signature(create_rag_chain).parameters)
    required = ("enable_query_expansion", "expansion_strategy")
    missing = [name for name in required if name not in params]

    if not missing:
        print("✅ create_rag_chain has query expansion parameters")
        print(f" Parameters: {params}")
        return True

    print("❌ create_rag_chain missing query expansion parameters")
    print(f" Missing: {missing}")
    print(f" Found parameters: {params}")
    return False
|
|
|
|
def test_app_configuration():
    """Check that app.py wires up the query expansion settings.

    Reads ``app.py`` from the directory containing this script, not from the
    current working directory. It looks for the two expansion config names
    and the two keyword arguments. This is a plain substring search, so a
    match inside a comment also counts.

    Returns:
        bool: True if all four markers are found. False if any marker is
        missing or app.py cannot be read or decoded.
    """
    print("\n" + "=" * 70)
    print("TEST 9: App Configuration - Checking app.py")
    print("=" * 70)

    # Resolve relative to this file so the check works from any cwd, just as
    # the imports do via the script directory placed on sys.path.
    app_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "app.py")

    try:
        # Explicit UTF-8: the platform default (e.g. cp1252) may fail to
        # decode non-ASCII source such as emoji in strings.
        with open(app_path, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        print("❌ app.py not found")
        return False
    except (OSError, UnicodeDecodeError) as e:
        print(f"❌ Error checking app.py: {e}")
        return False

    checks = {
        'ENABLE_QUERY_EXPANSION': 'ENABLE_QUERY_EXPANSION' in content,
        'QUERY_EXPANSION_STRATEGY': 'QUERY_EXPANSION_STRATEGY' in content,
        'enable_query_expansion param': 'enable_query_expansion=' in content,
        'expansion_strategy param': 'expansion_strategy=' in content,
    }

    passed = sum(checks.values())
    total = len(checks)

    for check_name, check_result in checks.items():
        status = "✅" if check_result else "❌"
        print(f"{status} {check_name}: {'Found' if check_result else 'Missing'}")

    print(f"\n{passed}/{total} configuration checks passed")
    return passed == total
|
|
|
|
def run_all_tests():
    """Run every test suite in order and print a pass/fail summary.

    A suite that raises is reported with its traceback and counted as a
    failure, so the remaining suites still run.

    Returns:
        int: Process exit code. 0 if every suite returned a truthy result,
        1 otherwise.
    """
    import traceback

    print("\n" + "=" * 70)
    print("COGNICHAT QUERY EXPANSION - COMPREHENSIVE TEST SUITE")
    print("=" * 70)

    tests = [
        ("Import Test", test_imports),
        ("QueryAnalyzer", test_query_analyzer),
        ("QueryRephraser", test_query_rephraser),
        ("MultiQueryExpander", test_multi_query_expander),
        ("MultiHopReasoner", test_multi_hop_reasoner),
        ("FallbackStrategies", test_fallback_strategies),
        ("Convenience Function", test_convenience_function),
        ("RAG Integration", test_rag_integration),
        ("App Configuration", test_app_configuration),
    ]

    results = []
    for test_name, test_func in tests:
        try:
            result = test_func()
        except Exception as e:
            # Boundary handler: one crashing suite must not abort the run.
            # Flush stdout first so the traceback (stderr) lands after this
            # suite's report when output is piped to a log.
            print(f"\n❌ {test_name} crashed: {e}", flush=True)
            traceback.print_exc()
            result = False
        results.append((test_name, result))

    print("\n" + "=" * 70)
    print("TEST SUMMARY")
    print("=" * 70)

    passed = sum(1 for _, result in results if result)
    total = len(results)

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"{status}: {test_name}")

    print("\n" + "=" * 70)
    print(f"OVERALL: {passed}/{total} test suites passed ({passed / total:.1%})")
    print("=" * 70)

    if passed == total:
        print("\n🎉 ALL TESTS PASSED! Query Expansion is ready to use!")
        return 0

    print(f"\n⚠️ {total - passed} test suite(s) failed. Review output above.")
    return 1
|
|
|
|
if __name__ == "__main__":
    # The report prints emoji status marks. When stdout's encoding cannot
    # represent them (e.g. output redirected to a file/pipe using cp1252 on
    # Windows), print would raise UnicodeEncodeError. Even the crash handler
    # in run_all_tests would then fail, so substitute a replacement character
    # instead. reconfigure() is absent on non-TextIOWrapper streams.
    reconfigure = getattr(sys.stdout, "reconfigure", None)
    if reconfigure is not None:
        reconfigure(errors="replace")
    sys.exit(run_all_tests())
|
|