Spaces:
Sleeping
Sleeping
| """Test script for Agentic RAG functionality.""" | |
| import asyncio | |
| import os | |
| import sys | |
# Put this script's own directory (treated as the project root) at the front
# of the import search path so the project packages imported inside the tests
# resolve when the script is launched directly.
_PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _PROJECT_ROOT)

# Populate os.environ from a .env file when python-dotenv is installed;
# without it the script relies on variables already present in the environment.
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    pass
async def test_rag_tools():
    """Test individual RAG tools.

    Prints the registered RAG tools, then, if ``SEALION_ENDPOINT`` and
    ``SUPABASE_DB_HOST`` are set, builds the encoder, the database pool and
    the vector store, registers them with ``set_rag_dependencies`` and invokes
    ``list_available_categories``, ``get_statistics`` and ``semantic_search``
    once each, printing the (truncated) results.

    When a required environment variable is missing the live part is skipped.
    Exceptions raised while initializing or invoking the tools are printed
    together with their traceback and are not re-raised. The import of
    ``tools.rag_tools`` itself happens outside that handler, so a failure
    there propagates to the caller.
    """
    print("\n" + "=" * 60)
    print("Testing Agentic RAG Tools")
    print("=" * 60)

    from tools.rag_tools import (
        RAG_TOOLS,
        set_rag_dependencies,
        list_available_categories,
        get_statistics,
        semantic_search,
    )

    # Check available tools
    print("\nπ¦ Available RAG Tools:")
    for tool in RAG_TOOLS:
        print(f" - {tool.name}: {tool.description[:60]}...")

    # Initialize dependencies
    print("\nπ§ Initializing dependencies...")
    try:
        from encoders.sealion import SeaLionEncoder
        from recommender.vector_store import DonorVectorStore
        from psycopg_pool import AsyncConnectionPool

        # Check for required env vars
        sealion_endpoint = os.getenv("SEALION_ENDPOINT")
        db_host = os.getenv("SUPABASE_DB_HOST")
        if not sealion_endpoint:
            print(" β οΈ SEALION_ENDPOINT not set, skipping live tests")
            return
        if not db_host:
            print(" β οΈ Database credentials not set, skipping live tests")
            return

        # Initialize encoder
        encoder = SeaLionEncoder(endpoint_url=sealion_endpoint)
        print(f" β SeaLion encoder initialized (dim: {encoder.embedding_dimension})")

        # Initialize database pool
        db_port = os.getenv("SUPABASE_DB_PORT", "6543")
        db_name = os.getenv("SUPABASE_DB_NAME", "postgres")
        db_user = os.getenv("SUPABASE_DB_USER")
        db_password = os.getenv("SUPABASE_DB_PASSWORD")
        db_sslmode = os.getenv("SUPABASE_DB_SSLMODE", "require")
        conn_string = (
            f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
            f"?sslmode={db_sslmode}"
        )
        # open=False: the pool is opened explicitly below rather than in the
        # constructor, which psycopg_pool deprecates for async pools.
        pool = AsyncConnectionPool(
            conninfo=conn_string,
            max_size=5,
            kwargs={"autocommit": True, "prepare_threshold": None},
            open=False,
        )
        try:
            await pool.open()
            print(" β Database pool connected")

            vector_store = DonorVectorStore(pool)
            print(" β Vector store initialized")

            # Set dependencies for tools
            set_rag_dependencies(encoder, vector_store)
            print(" β RAG tools configured")

            # Test list_available_categories
            print("\nπ Testing list_available_categories...")
            categories_result = await list_available_categories.ainvoke({})
            print(f" Result: {categories_result[:200]}...")

            # Test get_statistics
            print("\nπ Testing get_statistics...")
            stats_result = await get_statistics.ainvoke({})
            print(f" Result: {stats_result}")

            # Test semantic_search (if there's data)
            print("\nπ Testing semantic_search...")
            search_result = await semantic_search.ainvoke({
                "query": "education donors in Singapore",
                "limit": 3,
            })
            print(f" Result: {search_result[:300]}...")
        finally:
            # Cleanup: close the pool even when a tool call above raised, so
            # no connections or worker tasks are left behind.
            await pool.close()

        print("\nβ All tool tests completed!")
    except Exception as e:
        import traceback

        print(f" β Error: {e}")
        traceback.print_exc()
async def test_agentic_rag_agent():
    """Test the full Agentic RAG agent.

    If ``SEALION_ENDPOINT`` and ``SUPABASE_DB_HOST`` are set, builds the
    encoder, database pool, vector store and a ``ChatOllama`` LLM, creates an
    ``AgenticRAGAgent`` from them and runs one search query, printing the
    truncated response, the tool calls made and the message count.

    When ``OLLAMA_API_KEY`` is set the LLM is pointed at ``https://ollama.com``
    with a bearer-token header; otherwise ``ChatOllama`` is created with only
    a model name.

    Any exception, including a failed import, is printed together with its
    traceback and is not re-raised.
    """
    print("\n" + "=" * 60)
    print("Testing Agentic RAG Agent")
    print("=" * 60)
    try:
        from agents.agentic_rag import AgenticRAGAgent
        from encoders.sealion import SeaLionEncoder
        from recommender.vector_store import DonorVectorStore
        from psycopg_pool import AsyncConnectionPool
        from langchain_ollama import ChatOllama

        # Check for required env vars
        sealion_endpoint = os.getenv("SEALION_ENDPOINT")
        db_host = os.getenv("SUPABASE_DB_HOST")
        ollama_api_key = os.getenv("OLLAMA_API_KEY")
        if not all([sealion_endpoint, db_host]):
            print(" β οΈ Required environment variables not set, skipping agent test")
            return

        print("\nπ§ Initializing agent components...")

        # Initialize encoder
        encoder = SeaLionEncoder(endpoint_url=sealion_endpoint)

        # Initialize database
        db_port = os.getenv("SUPABASE_DB_PORT", "6543")
        db_name = os.getenv("SUPABASE_DB_NAME", "postgres")
        db_user = os.getenv("SUPABASE_DB_USER")
        db_password = os.getenv("SUPABASE_DB_PASSWORD")
        db_sslmode = os.getenv("SUPABASE_DB_SSLMODE", "require")
        conn_string = (
            f"postgresql://{db_user}:{db_password}@{db_host}:{db_port}/{db_name}"
            f"?sslmode={db_sslmode}"
        )
        # open=False: the pool is opened explicitly below rather than in the
        # constructor, which psycopg_pool deprecates for async pools.
        pool = AsyncConnectionPool(
            conninfo=conn_string,
            max_size=5,
            kwargs={"autocommit": True, "prepare_threshold": None},
            open=False,
        )
        try:
            await pool.open()
            vector_store = DonorVectorStore(pool)

            # Initialize LLM
            if ollama_api_key:
                llm = ChatOllama(
                    model="gpt-oss:120b",
                    base_url="https://ollama.com",
                    client_kwargs={
                        "headers": {"Authorization": f"Bearer {ollama_api_key}"}
                    },
                )
            else:
                llm = ChatOllama(model="gpt-oss:120b-cloud")
            print(" β All components initialized")

            # Create agent
            agent = AgenticRAGAgent(llm, encoder, vector_store)
            print(" β Agentic RAG agent created")

            # Test a query; the banner prints the exact text sent to the agent
            # so the log cannot drift from the real query.
            query = "Find donors interested in education in Singapore"
            print(f"\nπ€ Running agent query: '{query}'")
            print("-" * 40)
            result = await agent.search(query)

            print(f"\nπ Response:\n{result['response'][:500]}...")
            print(f"\nπ§ Tool calls made: {len(result['tool_calls'])}")
            for tc in result['tool_calls']:
                print(f" - {tc['tool']}: {tc['args']}")
            print(f"\nπ Total messages: {result['message_count']}")
        finally:
            # Cleanup: close the pool even when agent setup or the query
            # raised, so no connections or worker tasks are left behind.
            await pool.close()

        print("\nβ Agent test completed!")
    except Exception as e:
        import traceback

        print(f" β Error: {e}")
        traceback.print_exc()
async def main():
    """Run the tool-level test, then the agent test, between printed banners.

    The two test coroutines are awaited one after the other, not concurrently.
    """
    rule = "=" * 60

    print("\nπ Agentic RAG Test Suite")
    print(rule)

    await test_rag_tools()
    await test_agentic_rag_agent()

    print(f"\n{rule}")
    print("All tests completed!")
    print(rule)
if __name__ == "__main__":
    running_on_windows = sys.platform == "win32"
    if running_on_windows:
        # Windows async fix: switch asyncio to the selector-based event loop.
        # NOTE(review): presumably needed because the async database driver
        # does not work on Windows' default Proactor loop; confirm.
        selector_policy = asyncio.WindowsSelectorEventLoopPolicy()
        asyncio.set_event_loop_policy(selector_policy)

    asyncio.run(main())