"""
Async wrapper for knowledge base operations.

Provides a non-blocking async/await interface for knowledge base operations,
suitable for an async MCP server and concurrent requests.
"""

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import monotonic, perf_counter, time
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from .knowledge_base import KnowledgeBase
from .vector_search import SearchResult
from .response_models import SearchResponse, QueryResponse, SearchResultItem

logger = logging.getLogger(__name__)


class AsyncKnowledgeBase:
    """
    Async wrapper for KnowledgeBase operations.

    Blocking KnowledgeBase calls run in a dedicated thread pool so they do not
    block the event loop. Successful ``search`` responses are kept in a small
    in-memory cache bounded both by age (TTL) and by entry count.

    The cache is a plain dict that is only read and written from coroutines on
    the event loop thread (never from the worker threads).
    """

    def __init__(
        self,
        kb: KnowledgeBase,
        max_workers: int = 4,
        cache_ttl: float = 300,
        max_cache_entries: int = 1024,
    ):
        """
        Initialize async knowledge base.

        Args:
            kb: Underlying KnowledgeBase instance.
            max_workers: Max thread pool workers.
            cache_ttl: Seconds a cached search response stays valid
                (default 300 = 5 minutes).
            max_cache_entries: Maximum number of cached search responses. When
                exceeded, the oldest entry is evicted. A value <= 0 disables
                storing new cache entries.
        """
        self.kb = kb
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # Maps (query, top_k) -> (response, monotonic() time it was cached).
        # Dict insertion order doubles as age order for size-based eviction.
        self._search_cache: Dict[Tuple[str, int], Tuple[SearchResponse, float]] = {}
        self._cache_ttl = cache_ttl
        self._max_cache_entries = max_cache_entries

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _run_blocking(self, func: Callable[..., Any], *args: Any) -> Any:
        """Run a blocking callable in the thread pool and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, partial(func, *args))

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds since ``start`` (a perf_counter() value), rounded to 2 dp."""
        return round((perf_counter() - start) * 1000, 2)

    @staticmethod
    def _format_results(results: Iterable[SearchResult]) -> List[SearchResultItem]:
        """
        Convert raw KB results into SearchResultItem objects.

        Ranks start at 1 in the order the KB returned them; scores are rounded
        to 3 decimal places. Each result must expose ``score``, ``content``,
        ``source`` and ``metadata`` attributes.
        """
        return [
            SearchResultItem(
                rank=rank,
                score=round(result.score, 3),
                content=result.content,
                source=result.source,
                metadata=result.metadata,
            )
            for rank, result in enumerate(results, 1)
        ]

    def _get_cached(self, key: Tuple[str, int]) -> Optional[SearchResponse]:
        """Return a non-expired cached response for ``key``, dropping it if expired."""
        entry = self._search_cache.get(key)
        if entry is None:
            return None
        response, cached_at = entry
        if monotonic() - cached_at < self._cache_ttl:
            return response
        # Expired: remove it so stale entries do not accumulate forever.
        del self._search_cache[key]
        return None

    def _store_cached(self, key: Tuple[str, int], response: SearchResponse) -> None:
        """Cache ``response`` under ``key``, evicting the oldest entries past the size bound."""
        if self._max_cache_entries <= 0:
            return
        # Pop first so a refreshed key moves to the newest end of insertion order.
        self._search_cache.pop(key, None)
        self._search_cache[key] = (response, monotonic())
        while len(self._search_cache) > self._max_cache_entries:
            # Dicts iterate in insertion order, so the first key is the oldest.
            del self._search_cache[next(iter(self._search_cache))]

    def _purge_expired(self) -> None:
        """Remove every cache entry whose TTL has elapsed."""
        now = monotonic()
        expired = [
            key
            for key, (_, cached_at) in self._search_cache.items()
            if now - cached_at >= self._cache_ttl
        ]
        for key in expired:
            del self._search_cache[key]

    async def _search_with(
        self,
        label: str,
        func: Callable[[str, int], Any],
        query: str,
        top_k: int,
        cache_key: Optional[Tuple[str, int]] = None,
    ) -> SearchResponse:
        """
        Run one blocking KB search callable and wrap the outcome.

        Args:
            label: Prefix for the error log message (e.g. "Search").
            func: Blocking search callable invoked as ``func(query, top_k)``.
            query: Search query.
            top_k: Number of results requested.
            cache_key: When given, a successful response is cached under it.

        Returns:
            A SearchResponse with status "success", or status "error" carrying
            the exception text if the search or formatting raised.
        """
        start = perf_counter()
        try:
            results = await self._run_blocking(func, query, top_k)
            formatted_results = self._format_results(results)
            response = SearchResponse(
                status="success",
                query=query,
                result_count=len(formatted_results),
                results=formatted_results,
                elapsed_ms=self._elapsed_ms(start),
            )
            # Only successful responses are cached; errors are always retried.
            if cache_key is not None:
                self._store_cached(cache_key, response)
            return response
        except Exception as e:
            # API boundary: convert any failure into an error response, but
            # keep the traceback in the log.
            logger.exception("%s error: %s", label, e)
            return SearchResponse(
                status="error",
                query=query,
                result_count=0,
                results=[],
                elapsed_ms=self._elapsed_ms(start),
                error=str(e),
            )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def search(
        self,
        query: str,
        top_k: int = 5,
        use_cache: bool = True,
    ) -> SearchResponse:
        """
        Async search operation.

        Args:
            query: Search query.
            top_k: Number of results.
            use_cache: Read from and write to the response cache.

        Returns:
            SearchResponse with results. A cache hit returns the same response
            object that was cached earlier, so callers should not mutate it.
        """
        cache_key = (query, top_k)
        if use_cache:
            cached = self._get_cached(cache_key)
            if cached is not None:
                logger.debug("Cache hit for query: %s", query)
                return cached
        return await self._search_with(
            "Search",
            self.kb.search,
            query,
            top_k,
            cache_key=cache_key if use_cache else None,
        )

    async def search_products(
        self,
        query: str,
        top_k: int = 10,
    ) -> SearchResponse:
        """
        Async product search (not cached).

        Args:
            query: Search query.
            top_k: Number of results.

        Returns:
            SearchResponse with product results.
        """
        return await self._search_with(
            "Product search", self.kb.search_products, query, top_k
        )

    async def search_documentation(
        self,
        query: str,
        top_k: int = 5,
    ) -> SearchResponse:
        """
        Async documentation search (not cached).

        Args:
            query: Search query.
            top_k: Number of results.

        Returns:
            SearchResponse with documentation results.
        """
        return await self._search_with(
            "Documentation search", self.kb.search_documentation, query, top_k
        )

    async def query(
        self,
        question: str,
        top_k: Optional[int] = None,
    ) -> QueryResponse:
        """
        Async query with natural language.

        Args:
            question: Natural language question.
            top_k: Number of sources to use; None lets the KB pick its default.

        Returns:
            QueryResponse with the answer, or status "error" with the
            exception text if the query raised.
        """
        start = perf_counter()
        try:
            answer = await self._run_blocking(self.kb.query, question, top_k)
            return QueryResponse(
                status="success",
                question=question,
                answer=answer,
                # NOTE(review): assumes KnowledgeBase.query uses 5 sources when
                # top_k is None — confirm against KnowledgeBase.
                source_count=top_k if top_k is not None else 5,
                confidence=0.85,  # Placeholder: not derived from the answer.
                elapsed_ms=self._elapsed_ms(start),
            )
        except Exception as e:
            logger.exception("Query error: %s", e)
            return QueryResponse(
                status="error",
                question=question,
                answer="",
                source_count=0,
                confidence=0.0,
                elapsed_ms=self._elapsed_ms(start),
                error=str(e),
            )

    async def batch_search(
        self,
        queries: List[str],
        top_k: int = 5,
    ) -> List[SearchResponse]:
        """
        Async batch search over multiple queries.

        All searches are started concurrently; actual parallelism is bounded
        by the thread pool size. Failures surface as per-query error
        responses, because ``search`` never raises for search errors.

        Args:
            queries: List of search queries.
            top_k: Number of results per query.

        Returns:
            List of SearchResponse objects, in the same order as ``queries``.
        """
        tasks = [self.search(query, top_k) for query in queries]
        return list(await asyncio.gather(*tasks))

    def clear_cache(self):
        """Clear search result cache."""
        self._search_cache.clear()
        logger.info("Search cache cleared")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Expired entries are purged first, so ``cached_queries`` counts only
        entries that would still be served.
        """
        self._purge_expired()
        return {
            "cached_queries": len(self._search_cache),
            "cache_ttl_seconds": self._cache_ttl,
        }

    async def shutdown(self):
        """Shut down the thread pool, waiting for in-flight work to finish."""
        loop = asyncio.get_running_loop()
        # executor.shutdown(wait=True) blocks until pending jobs complete; run
        # it on the loop's default executor so the event loop stays responsive.
        await loop.run_in_executor(None, partial(self.executor.shutdown, wait=True))
        logger.info("AsyncKnowledgeBase shut down")