""" Smart Query Test Suite - Run with: python test_smart_query.py Tests the smart query tool functions and integration. Requires GEMINI_API_KEY for full integration tests. """ import asyncio import sys from datetime import datetime, timedelta sys.path.insert(0, '.') from app.database import init_db, SessionLocal from app.models import User, Task, LogEntry, TaskStatus, ActorType, ActionType, ProjectMembership from app.tools.projects import create_project from app.tools.tasks import create_task from app.vectorstore import init_vectorstore, delete_by_project # Test counters passed = 0 failed = 0 def test(name, condition, details=""): global passed, failed if condition: print(f" [PASS] {name}") passed += 1 else: print(f" [FAIL] {name} - {details}") failed += 1 def create_test_user(db, first_name: str, last_name: str) -> User: """Create a test user and return the user object.""" from app.models import generate_user_id user_id = generate_user_id(first_name) user = User(id=user_id, first_name=first_name, last_name=last_name) db.add(user) db.commit() db.refresh(user) return user def create_test_log_entry(db, project_id: str, task_id: str, user_id: str, raw_input: str, hours_ago: int = 0) -> LogEntry: """Create a test log entry.""" entry = LogEntry( project_id=project_id, task_id=task_id, user_id=user_id, actor_type=ActorType.human, action_type=ActionType.task_completed, raw_input=raw_input, generated_doc=f"Documentation for: {raw_input}", tags=["test"], created_at=datetime.now() - timedelta(hours=hours_ago) ) db.add(entry) db.commit() db.refresh(entry) return entry def setup_test_data(): """Set up test data and return IDs.""" db = SessionLocal() try: # Create users user1 = create_test_user(db, "Alice", "Developer") user2 = create_test_user(db, "Bob", "Developer") # Create project project_result = create_project( name=f"Smart Query Test Project {datetime.now().timestamp()}", description="Testing smart query functionality", user_id=user1.id ) project_id = project_result["id"] # Add user2 to project membership = ProjectMembership(project_id=project_id, user_id=user2.id, role="member") db.add(membership) db.commit() # Create tasks task1_result = create_task( project_id=project_id, title="Implement Authentication API", description="Build JWT auth system", assigned_to=user1.id ) task1_id = task1_result["id"] task2_result = create_task( project_id=project_id, title="Write Unit Tests", description="Create test coverage", assigned_to=user2.id ) task2_id = task2_result["id"] # Mark task1 as done task1 = db.query(Task).filter(Task.id == task1_id).first() task1.status = TaskStatus.done task1.completed_at = datetime.now() - timedelta(hours=2) db.commit() # Create log entries log1 = create_test_log_entry( db, project_id, task1_id, user1.id, "Implemented JWT authentication with refresh tokens", hours_ago=2 ) log2 = create_test_log_entry( db, project_id, task1_id, user1.id, "Added password hashing with bcrypt", hours_ago=25 # Yesterday ) log3 = create_test_log_entry( db, project_id, task2_id, user2.id, "Created unit tests for auth module", hours_ago=1 ) return { "project_id": project_id, "user1_id": user1.id, "user1_name": user1.name, "user2_id": user2.id, "user2_name": user2.name, "task1_id": task1_id, "task2_id": task2_id, "log1_id": log1.id, "log2_id": log2.id, "log3_id": log3.id, } finally: db.close() def test_tool_functions(test_data): """Test individual tool functions.""" print("\n[1/4] Testing Smart Query Tool Functions") print("-" * 50) from app.smart_query import ( _tool_get_user_activity, _tool_get_task_status, _tool_check_completion, _tool_list_users, QueryContext, ) db = SessionLocal() context = QueryContext( current_user_id=test_data["user1_id"], current_datetime=datetime.now(), project_id=test_data["project_id"] ) try: # Test _tool_get_user_activity with user_id print("\n Testing _tool_get_user_activity...") today = datetime.now().date() yesterday = (datetime.now() - timedelta(days=1)).date() result = _tool_get_user_activity(db, context, { "user_id": test_data["user1_id"], "date_from": yesterday.isoformat(), "date_to": (today + timedelta(days=1)).isoformat() }) test("get_user_activity returns dict", isinstance(result, dict)) test("get_user_activity has count", "count" in result) test("get_user_activity has activities", "activities" in result) test("get_user_activity finds entries", result["count"] >= 1, f"got {result['count']}") # Test _tool_get_user_activity with user_name result2 = _tool_get_user_activity(db, context, { "user_name": "Alice", "date_from": yesterday.isoformat(), "date_to": (today + timedelta(days=1)).isoformat() }) test("get_user_activity resolves name", result2["user_id"] == test_data["user1_id"]) # Test _tool_get_task_status by ID print("\n Testing _tool_get_task_status...") result = _tool_get_task_status(db, test_data["project_id"], { "task_id": test_data["task1_id"] }) test("get_task_status returns dict", isinstance(result, dict)) test("get_task_status found=True", result.get("found") == True) test("get_task_status has task", "task" in result) test("get_task_status correct title", "Authentication" in result["task"]["title"]) test("get_task_status status is done", result["task"]["status"] == "done") # Test _tool_get_task_status by title result2 = _tool_get_task_status(db, test_data["project_id"], { "task_title": "Unit Tests" }) test("get_task_status by title works", result2.get("found") == True) test("get_task_status by title correct", "Unit Tests" in result2["task"]["title"]) # Test _tool_get_task_status not found result3 = _tool_get_task_status(db, test_data["project_id"], { "task_title": "Nonexistent Task XYZ" }) test("get_task_status not found", result3.get("found") == False) # Test _tool_check_completion print("\n Testing _tool_check_completion...") result = _tool_check_completion(db, test_data["project_id"], { "task_title": "Authentication" }) test("check_completion returns dict", isinstance(result, dict)) test("check_completion found", result.get("found") == True) test("check_completion is_completed", result.get("is_completed") == True) test("check_completion has details", result.get("completion_details") is not None) # Test _tool_check_completion with user result2 = _tool_check_completion(db, test_data["project_id"], { "task_title": "Authentication", "user_name": "Alice" }) test("check_completion by user works", result2.get("completed_by_specified_user") == True) result3 = _tool_check_completion(db, test_data["project_id"], { "task_title": "Authentication", "user_name": "Bob" }) test("check_completion wrong user", result3.get("completed_by_specified_user") == False) # Test _tool_list_users print("\n Testing _tool_list_users...") result = _tool_list_users(db, test_data["project_id"]) test("list_users returns dict", isinstance(result, dict)) test("list_users has users", "users" in result) test("list_users has count", "count" in result) test("list_users finds 2 users", result["count"] == 2, f"got {result['count']}") user_names = [u["name"] for u in result["users"]] test("list_users includes Alice", any("Alice" in n for n in user_names)) test("list_users includes Bob", any("Bob" in n for n in user_names)) finally: db.close() def test_extract_sources(): """Test the extract_sources helper.""" print("\n[2/4] Testing extract_sources Helper") print("-" * 50) from app.smart_query import extract_sources # Test with activity results tool_results = [ { "tool": "get_user_activity", "result": { "activities": [ {"id": "log-1", "what_was_done": "Built auth", "timestamp": "2024-01-01T10:00:00"}, {"id": "log-2", "what_was_done": "Fixed bug", "timestamp": "2024-01-01T11:00:00"}, ] } } ] sources = extract_sources(tool_results) test("extract_sources returns list", isinstance(sources, list)) test("extract_sources extracts activities", len(sources) == 2) test("extract_sources has type=activity", sources[0]["type"] == "activity") # Test with semantic search results tool_results2 = [ { "tool": "semantic_search", "result": { "results": [ {"id": "mem-1", "what_was_done": "Implemented login", "relevance_score": 0.95}, ] } } ] sources2 = extract_sources(tool_results2) test("extract_sources handles search results", len(sources2) == 1) test("extract_sources has type=memory", sources2[0]["type"] == "memory") # Test with task results tool_results3 = [ { "tool": "get_task_status", "result": { "found": True, "task": {"id": "task-1", "title": "Auth API", "status": "done"} } } ] sources3 = extract_sources(tool_results3) test("extract_sources handles task results", len(sources3) == 1) test("extract_sources has type=task", sources3[0]["type"] == "task") # Test deduplication tool_results4 = [ {"tool": "get_user_activity", "result": {"activities": [{"id": "dup-1", "what_was_done": "Test"}]}}, {"tool": "get_user_activity", "result": {"activities": [{"id": "dup-1", "what_was_done": "Test"}]}}, ] sources4 = extract_sources(tool_results4) test("extract_sources deduplicates", len(sources4) == 1) def test_query_context(): """Test QueryContext dataclass.""" print("\n[3/4] Testing QueryContext") print("-" * 50) from app.smart_query import QueryContext ctx = QueryContext( current_user_id="user-123", current_datetime=datetime(2024, 1, 15, 10, 30, 0), project_id="proj-456" ) test("QueryContext has current_user_id", ctx.current_user_id == "user-123") test("QueryContext has current_datetime", ctx.current_datetime.year == 2024) test("QueryContext has project_id", ctx.project_id == "proj-456") async def test_semantic_search_tool(test_data): """Test semantic search tool (requires API key).""" print("\n[4/4] Testing Semantic Search Tool (requires API)") print("-" * 50) try: from app.smart_query import _tool_semantic_search from app.llm import get_embedding from app.vectorstore import add_embedding # Add test embedding text = "Implemented JWT authentication with refresh tokens" emb = await get_embedding(text) add_embedding( log_entry_id=test_data["log1_id"], text=text, embedding=emb, metadata={ "project_id": test_data["project_id"], "user_id": test_data["user1_id"], "task_id": test_data["task1_id"], "created_at": datetime.now().isoformat() } ) # Test semantic search result = await _tool_semantic_search(test_data["project_id"], { "search_query": "authentication JWT" }) test("semantic_search returns dict", isinstance(result, dict)) test("semantic_search has results", "results" in result) test("semantic_search has count", "count" in result) test("semantic_search finds results", result["count"] > 0, f"got {result['count']}") if result["results"]: test("semantic_search result has id", "id" in result["results"][0]) test("semantic_search result has relevance", "relevance_score" in result["results"][0]) except Exception as e: test("semantic_search (API required)", False, str(e)) async def test_full_smart_query(test_data): """Test the full smart_query function (requires API key).""" print("\n[BONUS] Testing Full smart_query Integration") print("-" * 50) try: from app.smart_query import smart_query # Test simple query result = await smart_query( project_id=test_data["project_id"], query="What tasks are done?", current_user_id=test_data["user1_id"], current_datetime=datetime.now().isoformat() ) test("smart_query returns dict", isinstance(result, dict)) test("smart_query has answer", "answer" in result) test("smart_query has tools_used", "tools_used" in result) test("smart_query has sources", "sources" in result) test("smart_query answer not empty", len(result.get("answer", "")) > 0) print(f"\n Query: 'What tasks are done?'") print(f" Answer: {result.get('answer', '')[:200]}...") print(f" Tools used: {result.get('tools_used', [])}") except Exception as e: test("smart_query integration (API required)", False, str(e)) def test_user_resolution(test_data): """Test duplicate name handling in user resolution.""" print("\n[NEW] Testing User Resolution & Duplicate Name Handling") print("-" * 50) from app.smart_query import ( _resolve_user_in_project, _get_recent_work_hint, _tool_get_user_activity, _tool_check_completion, QueryContext, ) db = SessionLocal() try: # Test _get_recent_work_hint print("\n Testing _get_recent_work_hint...") hint = _get_recent_work_hint(db, str(test_data["user1_id"]), test_data["project_id"]) test("get_recent_work_hint returns string", isinstance(hint, str)) test("get_recent_work_hint has content", len(hint) > 0) test("get_recent_work_hint format correct", "worked on:" in hint or hint == "no recent activity") # Test _resolve_user_in_project - single match print("\n Testing _resolve_user_in_project (single match)...") result = _resolve_user_in_project(db, test_data["project_id"], "Alice") test("resolve single user found=True", result.get("found") == True) test("resolve single user has user_id", "user_id" in result) test("resolve single user correct id", result.get("user_id") == str(test_data["user1_id"])) # Test _resolve_user_in_project - not found print("\n Testing _resolve_user_in_project (not found)...") result2 = _resolve_user_in_project(db, test_data["project_id"], "NonExistentUser") test("resolve not found returns found=False", result2.get("found") == False) test("resolve not found reason=not_found", result2.get("reason") == "not_found") test("resolve not found has message", "message" in result2) # Create duplicate name scenario print("\n Testing duplicate name handling...") # Add another user with similar name to the project duplicate_user = User(name="Dev Alice Smith", email=f"alice.smith_{datetime.now().timestamp()}@test.com") db.add(duplicate_user) db.commit() db.refresh(duplicate_user) # Add to project membership = ProjectMembership(project_id=test_data["project_id"], user_id=duplicate_user.id, role="member") db.add(membership) db.commit() # Now search for "Alice" - should find 2 users result3 = _resolve_user_in_project(db, test_data["project_id"], "Alice") test("resolve ambiguous found=False", result3.get("found") == False) test("resolve ambiguous reason=ambiguous", result3.get("reason") == "ambiguous") test("resolve ambiguous has options", "options" in result3) test("resolve ambiguous options is list", isinstance(result3.get("options"), list)) test("resolve ambiguous has 2 options", len(result3.get("options", [])) == 2, f"got {len(result3.get('options', []))}") if result3.get("options"): opt = result3["options"][0] test("option has user_id", "user_id" in opt) test("option has name", "name" in opt) test("option has email", "email" in opt) test("option has role", "role" in opt) test("option has recent_work", "recent_work" in opt) # Test that _tool_get_user_activity returns disambiguation print("\n Testing tool returns disambiguation...") context = QueryContext( current_user_id=str(test_data["user1_id"]), current_datetime=datetime.now(), project_id=test_data["project_id"] ) today = datetime.now().date() yesterday = (datetime.now() - timedelta(days=1)).date() activity_result = _tool_get_user_activity(db, context, { "user_name": "Alice", "date_from": yesterday.isoformat(), "date_to": (today + timedelta(days=1)).isoformat() }) test("tool returns ambiguous response", activity_result.get("found") == False) test("tool has options", "options" in activity_result) # Test that _tool_check_completion returns disambiguation completion_result = _tool_check_completion(db, test_data["project_id"], { "task_title": "Authentication", "user_name": "Alice" }) test("check_completion returns ambiguous", completion_result.get("found") == False) test("check_completion has options", "options" in completion_result) # Cleanup the duplicate user db.query(ProjectMembership).filter(ProjectMembership.user_id == duplicate_user.id).delete() db.query(User).filter(User.id == duplicate_user.id).delete() db.commit() # Test email fallback print("\n Testing email fallback...") result4 = _resolve_user_in_project(db, test_data["project_id"], "alice") # Part of email # This might match by name or email depending on data test("email fallback works", result4.get("found") == True or result4.get("reason") == "ambiguous") finally: db.close() async def test_memory_search_filters(test_data): """Test memory_search with filters (requires API key).""" print("\n[FILTERS] Testing memory_search Filters") print("-" * 50) try: from app.tools.memory import memory_search from app.llm import get_embedding from app.vectorstore import add_embedding # Add test embeddings with different users and dates text1 = "User1 implemented login feature yesterday" text2 = "User2 added payment processing today" emb1 = await get_embedding(text1) emb2 = await get_embedding(text2) yesterday = (datetime.now() - timedelta(days=1)).isoformat() today = datetime.now().isoformat() add_embedding( log_entry_id=str(test_data["log1_id"]) + "-filter1", text=text1, embedding=emb1, metadata={ "project_id": test_data["project_id"], "user_id": str(test_data["user1_id"]), "task_id": str(test_data["task1_id"]), "created_at": yesterday } ) add_embedding( log_entry_id=str(test_data["log2_id"]) + "-filter2", text=text2, embedding=emb2, metadata={ "project_id": test_data["project_id"], "user_id": str(test_data["user2_id"]), "task_id": str(test_data["task2_id"]), "created_at": today } ) # Test 1: Search without filters print("\n Testing search without filters...") result_no_filter = await memory_search( project_id=test_data["project_id"], query="feature implementation" ) test("search without filters returns answer", "answer" in result_no_filter) # Test 2: Search with userId filter print("\n Testing search with userId filter...") result_user_filter = await memory_search( project_id=test_data["project_id"], query="feature implementation", filters={"userId": str(test_data["user1_id"])} ) test("search with userId filter returns answer", "answer" in result_user_filter) # Test 3: Search with date range filter print("\n Testing search with date filters...") tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d") result_date_filter = await memory_search( project_id=test_data["project_id"], query="feature implementation", filters={ "dateFrom": (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d"), "dateTo": tomorrow } ) test("search with date filter returns answer", "answer" in result_date_filter) # Test 4: Search with combined filters print("\n Testing search with combined filters...") result_combined = await memory_search( project_id=test_data["project_id"], query="feature", filters={ "userId": str(test_data["user1_id"]), "dateFrom": (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d") } ) test("search with combined filters returns answer", "answer" in result_combined) print("\n Filter wiring tests passed!") except Exception as e: test("memory_search filters (API required)", False, str(e)) def cleanup_test_data(test_data): """Clean up test data.""" try: delete_by_project(test_data["project_id"]) except: pass async def main(): print("=" * 60) print(" SMART QUERY TEST SUITE") print("=" * 60) # Initialize print("\nInitializing database and vectorstore...") init_db() init_vectorstore() # Setup test data print("Setting up test data...") test_data = setup_test_data() print(f" Project ID: {test_data['project_id']}") print(f" User 1: {test_data['user1_name']} ({test_data['user1_id']})") print(f" User 2: {test_data['user2_name']} ({test_data['user2_id']})") try: # Run tests test_tool_functions(test_data) test_extract_sources() test_query_context() test_user_resolution(test_data) # New: duplicate name handling tests # API-dependent tests (will fail gracefully without API key) await test_semantic_search_tool(test_data) await test_memory_search_filters(test_data) # New: filter wiring tests await test_full_smart_query(test_data) finally: # Cleanup print("\nCleaning up test data...") cleanup_test_data(test_data) # Results print("\n" + "=" * 60) print(f" RESULTS: {passed} passed, {failed} failed") print("=" * 60) if failed > 0: print("\nNote: Some tests require GEMINI_API_KEY to be set.") sys.exit(1) if __name__ == "__main__": asyncio.run(main())