Spaces:
Build error
Build error
| """ | |
| Smart Query Test Suite - Run with: python test_smart_query.py | |
| Tests the smart query tool functions and integration. | |
| Requires GEMINI_API_KEY for full integration tests. | |
| """ | |
| import asyncio | |
| import sys | |
| from datetime import datetime, timedelta | |
| sys.path.insert(0, '.') | |
| from app.database import init_db, SessionLocal | |
| from app.models import User, Task, LogEntry, TaskStatus, ActorType, ActionType, ProjectMembership | |
| from app.tools.projects import create_project | |
| from app.tools.tasks import create_task | |
| from app.vectorstore import init_vectorstore, delete_by_project | |
| # Test counters | |
| passed = 0 | |
| failed = 0 | |
| def test(name, condition, details=""): | |
| global passed, failed | |
| if condition: | |
| print(f" [PASS] {name}") | |
| passed += 1 | |
| else: | |
| print(f" [FAIL] {name} - {details}") | |
| failed += 1 | |
| def create_test_user(db, first_name: str, last_name: str) -> User: | |
| """Create a test user and return the user object.""" | |
| from app.models import generate_user_id | |
| user_id = generate_user_id(first_name) | |
| user = User(id=user_id, first_name=first_name, last_name=last_name) | |
| db.add(user) | |
| db.commit() | |
| db.refresh(user) | |
| return user | |
| def create_test_log_entry(db, project_id: str, task_id: str, user_id: str, raw_input: str, hours_ago: int = 0) -> LogEntry: | |
| """Create a test log entry.""" | |
| entry = LogEntry( | |
| project_id=project_id, | |
| task_id=task_id, | |
| user_id=user_id, | |
| actor_type=ActorType.human, | |
| action_type=ActionType.task_completed, | |
| raw_input=raw_input, | |
| generated_doc=f"Documentation for: {raw_input}", | |
| tags=["test"], | |
| created_at=datetime.now() - timedelta(hours=hours_ago) | |
| ) | |
| db.add(entry) | |
| db.commit() | |
| db.refresh(entry) | |
| return entry | |
| def setup_test_data(): | |
| """Set up test data and return IDs.""" | |
| db = SessionLocal() | |
| try: | |
| # Create users | |
| user1 = create_test_user(db, "Alice", "Developer") | |
| user2 = create_test_user(db, "Bob", "Developer") | |
| # Create project | |
| project_result = create_project( | |
| name=f"Smart Query Test Project {datetime.now().timestamp()}", | |
| description="Testing smart query functionality", | |
| user_id=user1.id | |
| ) | |
| project_id = project_result["id"] | |
| # Add user2 to project | |
| membership = ProjectMembership(project_id=project_id, user_id=user2.id, role="member") | |
| db.add(membership) | |
| db.commit() | |
| # Create tasks | |
| task1_result = create_task( | |
| project_id=project_id, | |
| title="Implement Authentication API", | |
| description="Build JWT auth system", | |
| assigned_to=user1.id | |
| ) | |
| task1_id = task1_result["id"] | |
| task2_result = create_task( | |
| project_id=project_id, | |
| title="Write Unit Tests", | |
| description="Create test coverage", | |
| assigned_to=user2.id | |
| ) | |
| task2_id = task2_result["id"] | |
| # Mark task1 as done | |
| task1 = db.query(Task).filter(Task.id == task1_id).first() | |
| task1.status = TaskStatus.done | |
| task1.completed_at = datetime.now() - timedelta(hours=2) | |
| db.commit() | |
| # Create log entries | |
| log1 = create_test_log_entry( | |
| db, project_id, task1_id, user1.id, | |
| "Implemented JWT authentication with refresh tokens", | |
| hours_ago=2 | |
| ) | |
| log2 = create_test_log_entry( | |
| db, project_id, task1_id, user1.id, | |
| "Added password hashing with bcrypt", | |
| hours_ago=25 # Yesterday | |
| ) | |
| log3 = create_test_log_entry( | |
| db, project_id, task2_id, user2.id, | |
| "Created unit tests for auth module", | |
| hours_ago=1 | |
| ) | |
| return { | |
| "project_id": project_id, | |
| "user1_id": user1.id, | |
| "user1_name": user1.name, | |
| "user2_id": user2.id, | |
| "user2_name": user2.name, | |
| "task1_id": task1_id, | |
| "task2_id": task2_id, | |
| "log1_id": log1.id, | |
| "log2_id": log2.id, | |
| "log3_id": log3.id, | |
| } | |
| finally: | |
| db.close() | |
| def test_tool_functions(test_data): | |
| """Test individual tool functions.""" | |
| print("\n[1/4] Testing Smart Query Tool Functions") | |
| print("-" * 50) | |
| from app.smart_query import ( | |
| _tool_get_user_activity, | |
| _tool_get_task_status, | |
| _tool_check_completion, | |
| _tool_list_users, | |
| QueryContext, | |
| ) | |
| db = SessionLocal() | |
| context = QueryContext( | |
| current_user_id=test_data["user1_id"], | |
| current_datetime=datetime.now(), | |
| project_id=test_data["project_id"] | |
| ) | |
| try: | |
| # Test _tool_get_user_activity with user_id | |
| print("\n Testing _tool_get_user_activity...") | |
| today = datetime.now().date() | |
| yesterday = (datetime.now() - timedelta(days=1)).date() | |
| result = _tool_get_user_activity(db, context, { | |
| "user_id": test_data["user1_id"], | |
| "date_from": yesterday.isoformat(), | |
| "date_to": (today + timedelta(days=1)).isoformat() | |
| }) | |
| test("get_user_activity returns dict", isinstance(result, dict)) | |
| test("get_user_activity has count", "count" in result) | |
| test("get_user_activity has activities", "activities" in result) | |
| test("get_user_activity finds entries", result["count"] >= 1, f"got {result['count']}") | |
| # Test _tool_get_user_activity with user_name | |
| result2 = _tool_get_user_activity(db, context, { | |
| "user_name": "Alice", | |
| "date_from": yesterday.isoformat(), | |
| "date_to": (today + timedelta(days=1)).isoformat() | |
| }) | |
| test("get_user_activity resolves name", result2["user_id"] == test_data["user1_id"]) | |
| # Test _tool_get_task_status by ID | |
| print("\n Testing _tool_get_task_status...") | |
| result = _tool_get_task_status(db, test_data["project_id"], { | |
| "task_id": test_data["task1_id"] | |
| }) | |
| test("get_task_status returns dict", isinstance(result, dict)) | |
| test("get_task_status found=True", result.get("found") == True) | |
| test("get_task_status has task", "task" in result) | |
| test("get_task_status correct title", "Authentication" in result["task"]["title"]) | |
| test("get_task_status status is done", result["task"]["status"] == "done") | |
| # Test _tool_get_task_status by title | |
| result2 = _tool_get_task_status(db, test_data["project_id"], { | |
| "task_title": "Unit Tests" | |
| }) | |
| test("get_task_status by title works", result2.get("found") == True) | |
| test("get_task_status by title correct", "Unit Tests" in result2["task"]["title"]) | |
| # Test _tool_get_task_status not found | |
| result3 = _tool_get_task_status(db, test_data["project_id"], { | |
| "task_title": "Nonexistent Task XYZ" | |
| }) | |
| test("get_task_status not found", result3.get("found") == False) | |
| # Test _tool_check_completion | |
| print("\n Testing _tool_check_completion...") | |
| result = _tool_check_completion(db, test_data["project_id"], { | |
| "task_title": "Authentication" | |
| }) | |
| test("check_completion returns dict", isinstance(result, dict)) | |
| test("check_completion found", result.get("found") == True) | |
| test("check_completion is_completed", result.get("is_completed") == True) | |
| test("check_completion has details", result.get("completion_details") is not None) | |
| # Test _tool_check_completion with user | |
| result2 = _tool_check_completion(db, test_data["project_id"], { | |
| "task_title": "Authentication", | |
| "user_name": "Alice" | |
| }) | |
| test("check_completion by user works", result2.get("completed_by_specified_user") == True) | |
| result3 = _tool_check_completion(db, test_data["project_id"], { | |
| "task_title": "Authentication", | |
| "user_name": "Bob" | |
| }) | |
| test("check_completion wrong user", result3.get("completed_by_specified_user") == False) | |
| # Test _tool_list_users | |
| print("\n Testing _tool_list_users...") | |
| result = _tool_list_users(db, test_data["project_id"]) | |
| test("list_users returns dict", isinstance(result, dict)) | |
| test("list_users has users", "users" in result) | |
| test("list_users has count", "count" in result) | |
| test("list_users finds 2 users", result["count"] == 2, f"got {result['count']}") | |
| user_names = [u["name"] for u in result["users"]] | |
| test("list_users includes Alice", any("Alice" in n for n in user_names)) | |
| test("list_users includes Bob", any("Bob" in n for n in user_names)) | |
| finally: | |
| db.close() | |
| def test_extract_sources(): | |
| """Test the extract_sources helper.""" | |
| print("\n[2/4] Testing extract_sources Helper") | |
| print("-" * 50) | |
| from app.smart_query import extract_sources | |
| # Test with activity results | |
| tool_results = [ | |
| { | |
| "tool": "get_user_activity", | |
| "result": { | |
| "activities": [ | |
| {"id": "log-1", "what_was_done": "Built auth", "timestamp": "2024-01-01T10:00:00"}, | |
| {"id": "log-2", "what_was_done": "Fixed bug", "timestamp": "2024-01-01T11:00:00"}, | |
| ] | |
| } | |
| } | |
| ] | |
| sources = extract_sources(tool_results) | |
| test("extract_sources returns list", isinstance(sources, list)) | |
| test("extract_sources extracts activities", len(sources) == 2) | |
| test("extract_sources has type=activity", sources[0]["type"] == "activity") | |
| # Test with semantic search results | |
| tool_results2 = [ | |
| { | |
| "tool": "semantic_search", | |
| "result": { | |
| "results": [ | |
| {"id": "mem-1", "what_was_done": "Implemented login", "relevance_score": 0.95}, | |
| ] | |
| } | |
| } | |
| ] | |
| sources2 = extract_sources(tool_results2) | |
| test("extract_sources handles search results", len(sources2) == 1) | |
| test("extract_sources has type=memory", sources2[0]["type"] == "memory") | |
| # Test with task results | |
| tool_results3 = [ | |
| { | |
| "tool": "get_task_status", | |
| "result": { | |
| "found": True, | |
| "task": {"id": "task-1", "title": "Auth API", "status": "done"} | |
| } | |
| } | |
| ] | |
| sources3 = extract_sources(tool_results3) | |
| test("extract_sources handles task results", len(sources3) == 1) | |
| test("extract_sources has type=task", sources3[0]["type"] == "task") | |
| # Test deduplication | |
| tool_results4 = [ | |
| {"tool": "get_user_activity", "result": {"activities": [{"id": "dup-1", "what_was_done": "Test"}]}}, | |
| {"tool": "get_user_activity", "result": {"activities": [{"id": "dup-1", "what_was_done": "Test"}]}}, | |
| ] | |
| sources4 = extract_sources(tool_results4) | |
| test("extract_sources deduplicates", len(sources4) == 1) | |
| def test_query_context(): | |
| """Test QueryContext dataclass.""" | |
| print("\n[3/4] Testing QueryContext") | |
| print("-" * 50) | |
| from app.smart_query import QueryContext | |
| ctx = QueryContext( | |
| current_user_id="user-123", | |
| current_datetime=datetime(2024, 1, 15, 10, 30, 0), | |
| project_id="proj-456" | |
| ) | |
| test("QueryContext has current_user_id", ctx.current_user_id == "user-123") | |
| test("QueryContext has current_datetime", ctx.current_datetime.year == 2024) | |
| test("QueryContext has project_id", ctx.project_id == "proj-456") | |
| async def test_semantic_search_tool(test_data): | |
| """Test semantic search tool (requires API key).""" | |
| print("\n[4/4] Testing Semantic Search Tool (requires API)") | |
| print("-" * 50) | |
| try: | |
| from app.smart_query import _tool_semantic_search | |
| from app.llm import get_embedding | |
| from app.vectorstore import add_embedding | |
| # Add test embedding | |
| text = "Implemented JWT authentication with refresh tokens" | |
| emb = await get_embedding(text) | |
| add_embedding( | |
| log_entry_id=test_data["log1_id"], | |
| text=text, | |
| embedding=emb, | |
| metadata={ | |
| "project_id": test_data["project_id"], | |
| "user_id": test_data["user1_id"], | |
| "task_id": test_data["task1_id"], | |
| "created_at": datetime.now().isoformat() | |
| } | |
| ) | |
| # Test semantic search | |
| result = await _tool_semantic_search(test_data["project_id"], { | |
| "search_query": "authentication JWT" | |
| }) | |
| test("semantic_search returns dict", isinstance(result, dict)) | |
| test("semantic_search has results", "results" in result) | |
| test("semantic_search has count", "count" in result) | |
| test("semantic_search finds results", result["count"] > 0, f"got {result['count']}") | |
| if result["results"]: | |
| test("semantic_search result has id", "id" in result["results"][0]) | |
| test("semantic_search result has relevance", "relevance_score" in result["results"][0]) | |
| except Exception as e: | |
| test("semantic_search (API required)", False, str(e)) | |
| async def test_full_smart_query(test_data): | |
| """Test the full smart_query function (requires API key).""" | |
| print("\n[BONUS] Testing Full smart_query Integration") | |
| print("-" * 50) | |
| try: | |
| from app.smart_query import smart_query | |
| # Test simple query | |
| result = await smart_query( | |
| project_id=test_data["project_id"], | |
| query="What tasks are done?", | |
| current_user_id=test_data["user1_id"], | |
| current_datetime=datetime.now().isoformat() | |
| ) | |
| test("smart_query returns dict", isinstance(result, dict)) | |
| test("smart_query has answer", "answer" in result) | |
| test("smart_query has tools_used", "tools_used" in result) | |
| test("smart_query has sources", "sources" in result) | |
| test("smart_query answer not empty", len(result.get("answer", "")) > 0) | |
| print(f"\n Query: 'What tasks are done?'") | |
| print(f" Answer: {result.get('answer', '')[:200]}...") | |
| print(f" Tools used: {result.get('tools_used', [])}") | |
| except Exception as e: | |
| test("smart_query integration (API required)", False, str(e)) | |
| def test_user_resolution(test_data): | |
| """Test duplicate name handling in user resolution.""" | |
| print("\n[NEW] Testing User Resolution & Duplicate Name Handling") | |
| print("-" * 50) | |
| from app.smart_query import ( | |
| _resolve_user_in_project, | |
| _get_recent_work_hint, | |
| _tool_get_user_activity, | |
| _tool_check_completion, | |
| QueryContext, | |
| ) | |
| db = SessionLocal() | |
| try: | |
| # Test _get_recent_work_hint | |
| print("\n Testing _get_recent_work_hint...") | |
| hint = _get_recent_work_hint(db, str(test_data["user1_id"]), test_data["project_id"]) | |
| test("get_recent_work_hint returns string", isinstance(hint, str)) | |
| test("get_recent_work_hint has content", len(hint) > 0) | |
| test("get_recent_work_hint format correct", "worked on:" in hint or hint == "no recent activity") | |
| # Test _resolve_user_in_project - single match | |
| print("\n Testing _resolve_user_in_project (single match)...") | |
| result = _resolve_user_in_project(db, test_data["project_id"], "Alice") | |
| test("resolve single user found=True", result.get("found") == True) | |
| test("resolve single user has user_id", "user_id" in result) | |
| test("resolve single user correct id", result.get("user_id") == str(test_data["user1_id"])) | |
| # Test _resolve_user_in_project - not found | |
| print("\n Testing _resolve_user_in_project (not found)...") | |
| result2 = _resolve_user_in_project(db, test_data["project_id"], "NonExistentUser") | |
| test("resolve not found returns found=False", result2.get("found") == False) | |
| test("resolve not found reason=not_found", result2.get("reason") == "not_found") | |
| test("resolve not found has message", "message" in result2) | |
| # Create duplicate name scenario | |
| print("\n Testing duplicate name handling...") | |
| # Add another user with similar name to the project | |
| duplicate_user = User(name="Dev Alice Smith", email=f"alice.smith_{datetime.now().timestamp()}@test.com") | |
| db.add(duplicate_user) | |
| db.commit() | |
| db.refresh(duplicate_user) | |
| # Add to project | |
| membership = ProjectMembership(project_id=test_data["project_id"], user_id=duplicate_user.id, role="member") | |
| db.add(membership) | |
| db.commit() | |
| # Now search for "Alice" - should find 2 users | |
| result3 = _resolve_user_in_project(db, test_data["project_id"], "Alice") | |
| test("resolve ambiguous found=False", result3.get("found") == False) | |
| test("resolve ambiguous reason=ambiguous", result3.get("reason") == "ambiguous") | |
| test("resolve ambiguous has options", "options" in result3) | |
| test("resolve ambiguous options is list", isinstance(result3.get("options"), list)) | |
| test("resolve ambiguous has 2 options", len(result3.get("options", [])) == 2, f"got {len(result3.get('options', []))}") | |
| if result3.get("options"): | |
| opt = result3["options"][0] | |
| test("option has user_id", "user_id" in opt) | |
| test("option has name", "name" in opt) | |
| test("option has email", "email" in opt) | |
| test("option has role", "role" in opt) | |
| test("option has recent_work", "recent_work" in opt) | |
| # Test that _tool_get_user_activity returns disambiguation | |
| print("\n Testing tool returns disambiguation...") | |
| context = QueryContext( | |
| current_user_id=str(test_data["user1_id"]), | |
| current_datetime=datetime.now(), | |
| project_id=test_data["project_id"] | |
| ) | |
| today = datetime.now().date() | |
| yesterday = (datetime.now() - timedelta(days=1)).date() | |
| activity_result = _tool_get_user_activity(db, context, { | |
| "user_name": "Alice", | |
| "date_from": yesterday.isoformat(), | |
| "date_to": (today + timedelta(days=1)).isoformat() | |
| }) | |
| test("tool returns ambiguous response", activity_result.get("found") == False) | |
| test("tool has options", "options" in activity_result) | |
| # Test that _tool_check_completion returns disambiguation | |
| completion_result = _tool_check_completion(db, test_data["project_id"], { | |
| "task_title": "Authentication", | |
| "user_name": "Alice" | |
| }) | |
| test("check_completion returns ambiguous", completion_result.get("found") == False) | |
| test("check_completion has options", "options" in completion_result) | |
| # Cleanup the duplicate user | |
| db.query(ProjectMembership).filter(ProjectMembership.user_id == duplicate_user.id).delete() | |
| db.query(User).filter(User.id == duplicate_user.id).delete() | |
| db.commit() | |
| # Test email fallback | |
| print("\n Testing email fallback...") | |
| result4 = _resolve_user_in_project(db, test_data["project_id"], "alice") # Part of email | |
| # This might match by name or email depending on data | |
| test("email fallback works", result4.get("found") == True or result4.get("reason") == "ambiguous") | |
| finally: | |
| db.close() | |
| async def test_memory_search_filters(test_data): | |
| """Test memory_search with filters (requires API key).""" | |
| print("\n[FILTERS] Testing memory_search Filters") | |
| print("-" * 50) | |
| try: | |
| from app.tools.memory import memory_search | |
| from app.llm import get_embedding | |
| from app.vectorstore import add_embedding | |
| # Add test embeddings with different users and dates | |
| text1 = "User1 implemented login feature yesterday" | |
| text2 = "User2 added payment processing today" | |
| emb1 = await get_embedding(text1) | |
| emb2 = await get_embedding(text2) | |
| yesterday = (datetime.now() - timedelta(days=1)).isoformat() | |
| today = datetime.now().isoformat() | |
| add_embedding( | |
| log_entry_id=str(test_data["log1_id"]) + "-filter1", | |
| text=text1, | |
| embedding=emb1, | |
| metadata={ | |
| "project_id": test_data["project_id"], | |
| "user_id": str(test_data["user1_id"]), | |
| "task_id": str(test_data["task1_id"]), | |
| "created_at": yesterday | |
| } | |
| ) | |
| add_embedding( | |
| log_entry_id=str(test_data["log2_id"]) + "-filter2", | |
| text=text2, | |
| embedding=emb2, | |
| metadata={ | |
| "project_id": test_data["project_id"], | |
| "user_id": str(test_data["user2_id"]), | |
| "task_id": str(test_data["task2_id"]), | |
| "created_at": today | |
| } | |
| ) | |
| # Test 1: Search without filters | |
| print("\n Testing search without filters...") | |
| result_no_filter = await memory_search( | |
| project_id=test_data["project_id"], | |
| query="feature implementation" | |
| ) | |
| test("search without filters returns answer", "answer" in result_no_filter) | |
| # Test 2: Search with userId filter | |
| print("\n Testing search with userId filter...") | |
| result_user_filter = await memory_search( | |
| project_id=test_data["project_id"], | |
| query="feature implementation", | |
| filters={"userId": str(test_data["user1_id"])} | |
| ) | |
| test("search with userId filter returns answer", "answer" in result_user_filter) | |
| # Test 3: Search with date range filter | |
| print("\n Testing search with date filters...") | |
| tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d") | |
| result_date_filter = await memory_search( | |
| project_id=test_data["project_id"], | |
| query="feature implementation", | |
| filters={ | |
| "dateFrom": (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d"), | |
| "dateTo": tomorrow | |
| } | |
| ) | |
| test("search with date filter returns answer", "answer" in result_date_filter) | |
| # Test 4: Search with combined filters | |
| print("\n Testing search with combined filters...") | |
| result_combined = await memory_search( | |
| project_id=test_data["project_id"], | |
| query="feature", | |
| filters={ | |
| "userId": str(test_data["user1_id"]), | |
| "dateFrom": (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d") | |
| } | |
| ) | |
| test("search with combined filters returns answer", "answer" in result_combined) | |
| print("\n Filter wiring tests passed!") | |
| except Exception as e: | |
| test("memory_search filters (API required)", False, str(e)) | |
| def cleanup_test_data(test_data): | |
| """Clean up test data.""" | |
| try: | |
| delete_by_project(test_data["project_id"]) | |
| except: | |
| pass | |
| async def main(): | |
| print("=" * 60) | |
| print(" SMART QUERY TEST SUITE") | |
| print("=" * 60) | |
| # Initialize | |
| print("\nInitializing database and vectorstore...") | |
| init_db() | |
| init_vectorstore() | |
| # Setup test data | |
| print("Setting up test data...") | |
| test_data = setup_test_data() | |
| print(f" Project ID: {test_data['project_id']}") | |
| print(f" User 1: {test_data['user1_name']} ({test_data['user1_id']})") | |
| print(f" User 2: {test_data['user2_name']} ({test_data['user2_id']})") | |
| try: | |
| # Run tests | |
| test_tool_functions(test_data) | |
| test_extract_sources() | |
| test_query_context() | |
| test_user_resolution(test_data) # New: duplicate name handling tests | |
| # API-dependent tests (will fail gracefully without API key) | |
| await test_semantic_search_tool(test_data) | |
| await test_memory_search_filters(test_data) # New: filter wiring tests | |
| await test_full_smart_query(test_data) | |
| finally: | |
| # Cleanup | |
| print("\nCleaning up test data...") | |
| cleanup_test_data(test_data) | |
| # Results | |
| print("\n" + "=" * 60) | |
| print(f" RESULTS: {passed} passed, {failed} failed") | |
| print("=" * 60) | |
| if failed > 0: | |
| print("\nNote: Some tests require GEMINI_API_KEY to be set.") | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| asyncio.run(main()) | |