# ProjectMemory/backend/tests/test_smart_query.py
# Amal Nimmy Lal
# feat: Project Memory
# 35765b5
"""
Smart Query Test Suite - Run with: python test_smart_query.py
Tests the smart query tool functions and integration.
Requires GEMINI_API_KEY for full integration tests.
"""
import asyncio
import sys
from datetime import datetime, timedelta
sys.path.insert(0, '.')
from app.database import init_db, SessionLocal
from app.models import User, Task, LogEntry, TaskStatus, ActorType, ActionType, ProjectMembership
from app.tools.projects import create_project
from app.tools.tasks import create_task
from app.vectorstore import init_vectorstore, delete_by_project
# Module-level tallies of the checks recorded so far: test() increments one of
# them per call, and main() reads them for the final report and exit status.
passed = 0
failed = 0
def test(name, condition, details=""):
    """Record the outcome of a single check.

    Prints a ``[PASS]`` or ``[FAIL]`` line for ``name`` and increments the
    matching module-level counter. ``condition`` is judged by truthiness;
    ``details`` is appended to the line only when the check fails.
    """
    global passed, failed
    if not condition:
        print(f" [FAIL] {name} - {details}")
        failed += 1
        return
    print(f" [PASS] {name}")
    passed += 1
def create_test_user(db, first_name: str, last_name: str) -> User:
    """Insert a new ``User`` row through ``db`` and return the refreshed object.

    The primary key comes from ``generate_user_id(first_name)``; the row is
    committed before it is returned.
    """
    from app.models import generate_user_id

    new_user = User(
        id=generate_user_id(first_name),
        first_name=first_name,
        last_name=last_name,
    )
    db.add(new_user)
    db.commit()
    # Reload so attributes reflect what was actually persisted.
    db.refresh(new_user)
    return new_user
def create_test_log_entry(db, project_id: str, task_id: str, user_id: str, raw_input: str, hours_ago: int = 0) -> LogEntry:
    """Persist a human ``task_completed`` log entry and return it.

    The entry is tagged ``["test"]``, gets a generated doc string derived from
    ``raw_input``, and is back-dated by ``hours_ago`` hours relative to the
    local clock. The row is committed and refreshed before being returned.
    """
    fields = dict(
        project_id=project_id,
        task_id=task_id,
        user_id=user_id,
        actor_type=ActorType.human,
        action_type=ActionType.task_completed,
        raw_input=raw_input,
        generated_doc=f"Documentation for: {raw_input}",
        tags=["test"],
        created_at=datetime.now() - timedelta(hours=hours_ago),
    )
    log_entry = LogEntry(**fields)
    db.add(log_entry)
    db.commit()
    db.refresh(log_entry)
    return log_entry
def setup_test_data():
    """Seed a project with two users, two tasks and three log entries.

    Returns a dict of the identifiers (and user display names) that the test
    functions in this file look up by key. The session opened here is always
    closed, even when seeding fails part-way.
    """
    session = SessionLocal()
    try:
        # Two users: Alice owns the project, Bob joins as a member.
        alice = create_test_user(session, "Alice", "Developer")
        bob = create_test_user(session, "Bob", "Developer")

        # Timestamp suffix keeps the project name distinct between runs.
        project_id = create_project(
            name=f"Smart Query Test Project {datetime.now().timestamp()}",
            description="Testing smart query functionality",
            user_id=alice.id,
        )["id"]

        session.add(
            ProjectMembership(project_id=project_id, user_id=bob.id, role="member")
        )
        session.commit()

        # One task per user.
        auth_task_id = create_task(
            project_id=project_id,
            title="Implement Authentication API",
            description="Build JWT auth system",
            assigned_to=alice.id,
        )["id"]
        tests_task_id = create_task(
            project_id=project_id,
            title="Write Unit Tests",
            description="Create test coverage",
            assigned_to=bob.id,
        )["id"]

        # The authentication task is marked as finished two hours ago.
        auth_task = session.query(Task).filter(Task.id == auth_task_id).first()
        auth_task.status = TaskStatus.done
        auth_task.completed_at = datetime.now() - timedelta(hours=2)
        session.commit()

        # (task, author, text, age in hours); 25h puts the second entry on the
        # previous day.
        log_specs = (
            (auth_task_id, alice.id, "Implemented JWT authentication with refresh tokens", 2),
            (auth_task_id, alice.id, "Added password hashing with bcrypt", 25),
            (tests_task_id, bob.id, "Created unit tests for auth module", 1),
        )
        auth_log, hashing_log, tests_log = [
            create_test_log_entry(
                session, project_id, task_id, author_id, text, hours_ago=age
            )
            for task_id, author_id, text, age in log_specs
        ]

        # NOTE(review): users are built from first_name/last_name, yet `.name`
        # is read here - presumably a derived attribute on the model; confirm.
        return {
            "project_id": project_id,
            "user1_id": alice.id,
            "user1_name": alice.name,
            "user2_id": bob.id,
            "user2_name": bob.name,
            "task1_id": auth_task_id,
            "task2_id": tests_task_id,
            "log1_id": auth_log.id,
            "log2_id": hashing_log.id,
            "log3_id": tests_log.id,
        }
    finally:
        session.close()
def test_tool_functions(test_data):
    """Exercise the individual smart-query tool helpers against the seeded data.

    Covers user-activity lookup (by id and by name), task status (by id, by
    title, and a miss), completion checks (with and without a user), and the
    project user listing. Outcomes are recorded through ``test()``.
    """
    print("\n[1/4] Testing Smart Query Tool Functions")
    print("-" * 50)
    from app.smart_query import (
        QueryContext,
        _tool_check_completion,
        _tool_get_task_status,
        _tool_get_user_activity,
        _tool_list_users,
    )

    db = SessionLocal()
    alice_id = test_data["user1_id"]
    context = QueryContext(
        current_user_id=alice_id,
        current_datetime=datetime.now(),
        project_id=test_data["project_id"],
    )
    try:
        project_id = test_data["project_id"]

        # --- user activity -------------------------------------------------
        print("\n Testing _tool_get_user_activity...")
        today = datetime.now().date()
        yesterday = (datetime.now() - timedelta(days=1)).date()
        # Window runs from yesterday up to tomorrow so today's entries fall
        # inside it.
        window = {
            "date_from": yesterday.isoformat(),
            "date_to": (today + timedelta(days=1)).isoformat(),
        }
        activity_by_id = _tool_get_user_activity(
            db, context, {"user_id": alice_id, **window}
        )
        test("get_user_activity returns dict", isinstance(activity_by_id, dict))
        test("get_user_activity has count", "count" in activity_by_id)
        test("get_user_activity has activities", "activities" in activity_by_id)
        test(
            "get_user_activity finds entries",
            activity_by_id["count"] >= 1,
            f"got {activity_by_id['count']}",
        )

        activity_by_name = _tool_get_user_activity(
            db, context, {"user_name": "Alice", **window}
        )
        test("get_user_activity resolves name", activity_by_name["user_id"] == alice_id)

        # --- task status ---------------------------------------------------
        print("\n Testing _tool_get_task_status...")
        status_by_id = _tool_get_task_status(
            db, project_id, {"task_id": test_data["task1_id"]}
        )
        test("get_task_status returns dict", isinstance(status_by_id, dict))
        test("get_task_status found=True", status_by_id.get("found") == True)
        test("get_task_status has task", "task" in status_by_id)
        test("get_task_status correct title", "Authentication" in status_by_id["task"]["title"])
        test("get_task_status status is done", status_by_id["task"]["status"] == "done")

        status_by_title = _tool_get_task_status(
            db, project_id, {"task_title": "Unit Tests"}
        )
        test("get_task_status by title works", status_by_title.get("found") == True)
        test("get_task_status by title correct", "Unit Tests" in status_by_title["task"]["title"])

        status_missing = _tool_get_task_status(
            db, project_id, {"task_title": "Nonexistent Task XYZ"}
        )
        test("get_task_status not found", status_missing.get("found") == False)

        # --- completion checks ---------------------------------------------
        print("\n Testing _tool_check_completion...")
        completion = _tool_check_completion(
            db, project_id, {"task_title": "Authentication"}
        )
        test("check_completion returns dict", isinstance(completion, dict))
        test("check_completion found", completion.get("found") == True)
        test("check_completion is_completed", completion.get("is_completed") == True)
        test("check_completion has details", completion.get("completion_details") is not None)

        completion_alice = _tool_check_completion(
            db, project_id, {"task_title": "Authentication", "user_name": "Alice"}
        )
        test(
            "check_completion by user works",
            completion_alice.get("completed_by_specified_user") == True,
        )
        completion_bob = _tool_check_completion(
            db, project_id, {"task_title": "Authentication", "user_name": "Bob"}
        )
        test(
            "check_completion wrong user",
            completion_bob.get("completed_by_specified_user") == False,
        )

        # --- user listing --------------------------------------------------
        print("\n Testing _tool_list_users...")
        listing = _tool_list_users(db, project_id)
        test("list_users returns dict", isinstance(listing, dict))
        test("list_users has users", "users" in listing)
        test("list_users has count", "count" in listing)
        test("list_users finds 2 users", listing["count"] == 2, f"got {listing['count']}")
        listed_names = [entry["name"] for entry in listing["users"]]
        for first_name in ("Alice", "Bob"):
            test(
                f"list_users includes {first_name}",
                any(first_name in listed for listed in listed_names),
            )
    finally:
        db.close()
def test_extract_sources():
    """Check ``extract_sources`` on hand-built tool results.

    Feeds it activity, semantic-search and task payloads, plus a pair of
    identical activity payloads to confirm duplicates collapse to one source.
    """
    print("\n[2/4] Testing extract_sources Helper")
    print("-" * 50)
    from app.smart_query import extract_sources

    def tool_call(tool, result):
        # Shape of a single entry in the list handed to extract_sources.
        return {"tool": tool, "result": result}

    # Activity results
    activity_sources = extract_sources([
        tool_call("get_user_activity", {
            "activities": [
                {"id": "log-1", "what_was_done": "Built auth", "timestamp": "2024-01-01T10:00:00"},
                {"id": "log-2", "what_was_done": "Fixed bug", "timestamp": "2024-01-01T11:00:00"},
            ]
        }),
    ])
    test("extract_sources returns list", isinstance(activity_sources, list))
    test("extract_sources extracts activities", len(activity_sources) == 2)
    test("extract_sources has type=activity", activity_sources[0]["type"] == "activity")

    # Semantic search results
    memory_sources = extract_sources([
        tool_call("semantic_search", {
            "results": [
                {"id": "mem-1", "what_was_done": "Implemented login", "relevance_score": 0.95},
            ]
        }),
    ])
    test("extract_sources handles search results", len(memory_sources) == 1)
    test("extract_sources has type=memory", memory_sources[0]["type"] == "memory")

    # Task results
    task_sources = extract_sources([
        tool_call("get_task_status", {
            "found": True,
            "task": {"id": "task-1", "title": "Auth API", "status": "done"},
        }),
    ])
    test("extract_sources handles task results", len(task_sources) == 1)
    test("extract_sources has type=task", task_sources[0]["type"] == "task")

    # Deduplication: the same activity id reported by two separate tool calls.
    repeated_calls = [
        tool_call("get_user_activity", {"activities": [{"id": "dup-1", "what_was_done": "Test"}]})
        for _ in range(2)
    ]
    deduped_sources = extract_sources(repeated_calls)
    test("extract_sources deduplicates", len(deduped_sources) == 1)
def test_query_context():
    """Confirm ``QueryContext`` stores the three fields it is constructed with."""
    print("\n[3/4] Testing QueryContext")
    print("-" * 50)
    from app.smart_query import QueryContext

    fixed_moment = datetime(2024, 1, 15, 10, 30, 0)
    context = QueryContext(
        current_user_id="user-123",
        current_datetime=fixed_moment,
        project_id="proj-456",
    )
    test("QueryContext has current_user_id", context.current_user_id == "user-123")
    test("QueryContext has current_datetime", context.current_datetime.year == 2024)
    test("QueryContext has project_id", context.project_id == "proj-456")
async def test_semantic_search_tool(test_data):
    """Index one embedding and query it through ``_tool_semantic_search``.

    Any exception (for example a missing API key for the embedding call) is
    recorded as a single failed check instead of propagating.
    """
    print("\n[4/4] Testing Semantic Search Tool (requires API)")
    print("-" * 50)
    try:
        from app.smart_query import _tool_semantic_search
        from app.llm import get_embedding
        from app.vectorstore import add_embedding

        # Index a document the query below should be able to find.
        sample_text = "Implemented JWT authentication with refresh tokens"
        vector = await get_embedding(sample_text)
        add_embedding(
            log_entry_id=test_data["log1_id"],
            text=sample_text,
            embedding=vector,
            metadata={
                "project_id": test_data["project_id"],
                "user_id": test_data["user1_id"],
                "task_id": test_data["task1_id"],
                "created_at": datetime.now().isoformat(),
            },
        )

        search = await _tool_semantic_search(
            test_data["project_id"],
            {"search_query": "authentication JWT"},
        )
        test("semantic_search returns dict", isinstance(search, dict))
        test("semantic_search has results", "results" in search)
        test("semantic_search has count", "count" in search)
        test("semantic_search finds results", search["count"] > 0, f"got {search['count']}")

        hits = search["results"]
        if hits:
            top_hit = hits[0]
            test("semantic_search result has id", "id" in top_hit)
            test("semantic_search result has relevance", "relevance_score" in top_hit)
    except Exception as e:
        test("semantic_search (API required)", False, str(e))
async def test_full_smart_query(test_data):
    """Run one end-to-end ``smart_query`` call and check the response shape.

    Any exception (for example a missing API key) is recorded as a single
    failed check instead of propagating.
    """
    print("\n[BONUS] Testing Full smart_query Integration")
    print("-" * 50)
    try:
        from app.smart_query import smart_query

        question = "What tasks are done?"
        response = await smart_query(
            project_id=test_data["project_id"],
            query=question,
            current_user_id=test_data["user1_id"],
            current_datetime=datetime.now().isoformat(),
        )
        test("smart_query returns dict", isinstance(response, dict))
        for key in ("answer", "tools_used", "sources"):
            test(f"smart_query has {key}", key in response)
        test("smart_query answer not empty", len(response.get("answer", "")) > 0)

        # Echo a short preview so a human can eyeball the model's answer.
        print(f"\n Query: '{question}'")
        print(f" Answer: {response.get('answer', '')[:200]}...")
        print(f" Tools used: {response.get('tools_used', [])}")
    except Exception as e:
        test("smart_query integration (API required)", False, str(e))
def test_user_resolution(test_data):
    """Test user-name resolution, including the ambiguous (duplicate name) case.

    Steps:
      1. ``_get_recent_work_hint`` returns a non-empty string.
      2. ``_resolve_user_in_project`` finds a unique match and reports a miss.
      3. A second "Alice" is temporarily added to the project; resolution and
         the tools built on it must then return a disambiguation payload.
      4. After the temporary user is removed, a lowercase lookup resolves again.

    The temporary user and its membership are removed in a ``finally`` block so
    that a failure in step 3 does not leave stray rows behind.
    """
    print("\n[NEW] Testing User Resolution & Duplicate Name Handling")
    print("-" * 50)
    from app.smart_query import (
        _resolve_user_in_project,
        _get_recent_work_hint,
        _tool_get_user_activity,
        _tool_check_completion,
        QueryContext,
    )
    db = SessionLocal()
    try:
        # Test _get_recent_work_hint
        print("\n Testing _get_recent_work_hint...")
        hint = _get_recent_work_hint(db, str(test_data["user1_id"]), test_data["project_id"])
        test("get_recent_work_hint returns string", isinstance(hint, str))
        test("get_recent_work_hint has content", len(hint) > 0)
        test("get_recent_work_hint format correct", "worked on:" in hint or hint == "no recent activity")

        # Test _resolve_user_in_project - single match
        print("\n Testing _resolve_user_in_project (single match)...")
        result = _resolve_user_in_project(db, test_data["project_id"], "Alice")
        test("resolve single user found=True", result.get("found") == True)
        test("resolve single user has user_id", "user_id" in result)
        test("resolve single user correct id", result.get("user_id") == str(test_data["user1_id"]))

        # Test _resolve_user_in_project - not found
        print("\n Testing _resolve_user_in_project (not found)...")
        result2 = _resolve_user_in_project(db, test_data["project_id"], "NonExistentUser")
        test("resolve not found returns found=False", result2.get("found") == False)
        test("resolve not found reason=not_found", result2.get("reason") == "not_found")
        test("resolve not found has message", "message" in result2)

        # Create duplicate name scenario
        print("\n Testing duplicate name handling...")
        # NOTE(review): this constructs User with name/email, whereas
        # create_test_user() in this file uses id/first_name/last_name -
        # confirm the model accepts both forms.
        duplicate_user = User(name="Dev Alice Smith", email=f"alice.smith_{datetime.now().timestamp()}@test.com")
        db.add(duplicate_user)
        db.commit()
        db.refresh(duplicate_user)
        # Captured now so cleanup does not depend on the instance state after
        # a rollback.
        duplicate_user_id = duplicate_user.id
        try:
            # Add to project
            membership = ProjectMembership(project_id=test_data["project_id"], user_id=duplicate_user_id, role="member")
            db.add(membership)
            db.commit()

            # Now search for "Alice" - should find 2 users
            result3 = _resolve_user_in_project(db, test_data["project_id"], "Alice")
            test("resolve ambiguous found=False", result3.get("found") == False)
            test("resolve ambiguous reason=ambiguous", result3.get("reason") == "ambiguous")
            test("resolve ambiguous has options", "options" in result3)
            test("resolve ambiguous options is list", isinstance(result3.get("options"), list))
            test("resolve ambiguous has 2 options", len(result3.get("options", [])) == 2, f"got {len(result3.get('options', []))}")
            if result3.get("options"):
                opt = result3["options"][0]
                test("option has user_id", "user_id" in opt)
                test("option has name", "name" in opt)
                test("option has email", "email" in opt)
                test("option has role", "role" in opt)
                test("option has recent_work", "recent_work" in opt)

            # Test that _tool_get_user_activity returns disambiguation
            print("\n Testing tool returns disambiguation...")
            context = QueryContext(
                current_user_id=str(test_data["user1_id"]),
                current_datetime=datetime.now(),
                project_id=test_data["project_id"]
            )
            today = datetime.now().date()
            yesterday = (datetime.now() - timedelta(days=1)).date()
            activity_result = _tool_get_user_activity(db, context, {
                "user_name": "Alice",
                "date_from": yesterday.isoformat(),
                "date_to": (today + timedelta(days=1)).isoformat()
            })
            test("tool returns ambiguous response", activity_result.get("found") == False)
            test("tool has options", "options" in activity_result)

            # Test that _tool_check_completion returns disambiguation
            completion_result = _tool_check_completion(db, test_data["project_id"], {
                "task_title": "Authentication",
                "user_name": "Alice"
            })
            test("check_completion returns ambiguous", completion_result.get("found") == False)
            test("check_completion has options", "options" in completion_result)
        finally:
            # Always remove the temporary user, even if a check above raised.
            # Roll back first so a failed transaction cannot block the deletes.
            db.rollback()
            db.query(ProjectMembership).filter(ProjectMembership.user_id == duplicate_user_id).delete()
            db.query(User).filter(User.id == duplicate_user_id).delete()
            db.commit()

        # Test email fallback
        print("\n Testing email fallback...")
        result4 = _resolve_user_in_project(db, test_data["project_id"], "alice")  # Part of email
        # This might match by name or email depending on data
        test("email fallback works", result4.get("found") == True or result4.get("reason") == "ambiguous")
    finally:
        db.close()
async def test_memory_search_filters(test_data):
    """Test that ``memory_search`` accepts user and date filters (requires API key).

    Two embeddings with different users and dates are indexed, then the search
    runs four ways: unfiltered, by user, by date range, and combined. Each run
    must return a dict containing ``"answer"``.

    The closing "passed" banner is printed only when none of the checks in this
    section failed; ``test()`` records failures without raising, so the banner
    would otherwise appear after failed checks too. Any exception is recorded
    as a single failed check.
    """
    print("\n[FILTERS] Testing memory_search Filters")
    print("-" * 50)
    # Snapshot of the global failure counter, used to decide on the banner.
    failures_before = failed
    try:
        from app.tools.memory import memory_search
        from app.llm import get_embedding
        from app.vectorstore import add_embedding

        # Add test embeddings with different users and dates
        text1 = "User1 implemented login feature yesterday"
        text2 = "User2 added payment processing today"
        emb1 = await get_embedding(text1)
        emb2 = await get_embedding(text2)
        yesterday = (datetime.now() - timedelta(days=1)).isoformat()
        today = datetime.now().isoformat()
        # Suffixed ids keep these entries distinct from the plain log ids.
        add_embedding(
            log_entry_id=str(test_data["log1_id"]) + "-filter1",
            text=text1,
            embedding=emb1,
            metadata={
                "project_id": test_data["project_id"],
                "user_id": str(test_data["user1_id"]),
                "task_id": str(test_data["task1_id"]),
                "created_at": yesterday
            }
        )
        add_embedding(
            log_entry_id=str(test_data["log2_id"]) + "-filter2",
            text=text2,
            embedding=emb2,
            metadata={
                "project_id": test_data["project_id"],
                "user_id": str(test_data["user2_id"]),
                "task_id": str(test_data["task2_id"]),
                "created_at": today
            }
        )

        # Test 1: Search without filters
        print("\n Testing search without filters...")
        result_no_filter = await memory_search(
            project_id=test_data["project_id"],
            query="feature implementation"
        )
        test("search without filters returns answer", "answer" in result_no_filter)

        # Test 2: Search with userId filter
        print("\n Testing search with userId filter...")
        result_user_filter = await memory_search(
            project_id=test_data["project_id"],
            query="feature implementation",
            filters={"userId": str(test_data["user1_id"])}
        )
        test("search with userId filter returns answer", "answer" in result_user_filter)

        # Test 3: Search with date range filter
        print("\n Testing search with date filters...")
        tomorrow = (datetime.now() + timedelta(days=1)).strftime("%Y-%m-%d")
        result_date_filter = await memory_search(
            project_id=test_data["project_id"],
            query="feature implementation",
            filters={
                "dateFrom": (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d"),
                "dateTo": tomorrow
            }
        )
        test("search with date filter returns answer", "answer" in result_date_filter)

        # Test 4: Search with combined filters
        print("\n Testing search with combined filters...")
        result_combined = await memory_search(
            project_id=test_data["project_id"],
            query="feature",
            filters={
                "userId": str(test_data["user1_id"]),
                "dateFrom": (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d")
            }
        )
        test("search with combined filters returns answer", "answer" in result_combined)

        # Only claim success if no check in this section was recorded as failed.
        if failed == failures_before:
            print("\n Filter wiring tests passed!")
    except Exception as e:
        test("memory_search filters (API required)", False, str(e))
def cleanup_test_data(test_data):
    """Best-effort removal of the test project's vector-store entries.

    Calls ``delete_by_project`` with ``test_data["project_id"]``. Cleanup must
    never abort the run, so ordinary errors are reported as a warning and
    swallowed. ``KeyboardInterrupt`` and ``SystemExit`` are deliberately not
    caught, so the user can still interrupt the script.
    """
    try:
        delete_by_project(test_data["project_id"])
    except Exception as exc:
        print(f" [WARN] Cleanup of test data failed: {exc}")
async def main():
    """Run the whole suite: initialise, seed data, run tests, clean up, report.

    Exits the process with status 1 when at least one check failed.
    """
    banner = "=" * 60
    print(banner)
    print(" SMART QUERY TEST SUITE")
    print(banner)

    # Initialize
    print("\nInitializing database and vectorstore...")
    init_db()
    init_vectorstore()

    # Setup test data
    print("Setting up test data...")
    test_data = setup_test_data()
    print(f" Project ID: {test_data['project_id']}")
    for label, prefix in (("User 1", "user1"), ("User 2", "user2")):
        print(f" {label}: {test_data[prefix + '_name']} ({test_data[prefix + '_id']})")

    # API-dependent tests record a failure instead of raising when no key is set.
    api_tests = (
        test_semantic_search_tool,
        test_memory_search_filters,
        test_full_smart_query,
    )
    try:
        test_tool_functions(test_data)
        test_extract_sources()
        test_query_context()
        test_user_resolution(test_data)
        for api_test in api_tests:
            await api_test(test_data)
    finally:
        # Cleanup runs whether or not a test function raised.
        print("\nCleaning up test data...")
        cleanup_test_data(test_data)

    # Results
    print("\n" + banner)
    print(f" RESULTS: {passed} passed, {failed} failed")
    print(banner)
    if failed == 0:
        return
    print("\nNote: Some tests require GEMINI_API_KEY to be set.")
    sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())