Spaces:
Sleeping
Sleeping
| """ | |
| Async wrapper for knowledge base operations. | |
| Provides non-blocking async/await interface for knowledge base operations, | |
| suitable for async MCP server and concurrent requests. | |
| """ | |
| import asyncio | |
| import logging | |
| from typing import List, Dict, Any, Optional | |
| from functools import partial | |
| from concurrent.futures import ThreadPoolExecutor | |
| from time import time | |
| from .knowledge_base import KnowledgeBase | |
| from .vector_search import SearchResult | |
| from .response_models import SearchResponse, QueryResponse, SearchResultItem | |
| logger = logging.getLogger(__name__) | |
class AsyncKnowledgeBase:
    """
    Async wrapper for KnowledgeBase operations.

    Every call into the wrapped ``KnowledgeBase`` is submitted to a private
    ``ThreadPoolExecutor`` so the blocking work does not stall the event loop.
    Failures are not raised to the caller: they are logged (with traceback)
    and reported as a response object whose ``status`` is ``"error"``.

    Results of ``search`` are kept in an in-memory cache keyed by
    ``(query, top_k)``; entries expire after ``_cache_ttl`` seconds.
    The cache is a plain dict with no locking.
    NOTE(review): this assumes the cache is only touched from the event-loop
    thread (coroutines and the sync helpers below) - confirm against callers.
    """

    def __init__(self, kb: KnowledgeBase, max_workers: int = 4):
        """
        Initialize async knowledge base

        Args:
            kb: Underlying KnowledgeBase instance
            max_workers: Max thread pool workers
        """
        self.kb = kb
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # Maps (query, top_k) -> (SearchResponse, time() when it was stored).
        self._search_cache = {}
        self._cache_ttl = 300  # 5 minutes

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _run_blocking(self, func, *args):
        """Run ``func(*args)`` on this instance's thread pool and await it."""
        # get_running_loop() is the documented way to obtain the loop from
        # inside a coroutine; get_event_loop() has policy-dependent behavior.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, partial(func, *args))

    @staticmethod
    def _elapsed_ms(start_time: float) -> float:
        """Return milliseconds elapsed since ``start_time``, rounded to 2 dp."""
        return round((time() - start_time) * 1000, 2)

    @staticmethod
    def _format_results(results) -> List[SearchResultItem]:
        """Convert raw search results into 1-based ranked response items."""
        return [
            SearchResultItem(
                rank=rank,
                score=round(result.score, 3),
                content=result.content,
                source=result.source,
                metadata=result.metadata,
            )
            for rank, result in enumerate(results, 1)
        ]

    def _success_response(self, query: str, results, start_time: float) -> SearchResponse:
        """Build the ``status="success"`` response for a finished search."""
        formatted_results = self._format_results(results)
        return SearchResponse(
            status="success",
            query=query,
            result_count=len(formatted_results),
            results=formatted_results,
            elapsed_ms=self._elapsed_ms(start_time),
        )

    def _error_response(self, query: str, error: Exception, start_time: float) -> SearchResponse:
        """Build the ``status="error"`` response for a failed search."""
        return SearchResponse(
            status="error",
            query=query,
            result_count=0,
            results=[],
            elapsed_ms=self._elapsed_ms(start_time),
            error=str(error),
        )

    def _get_cached(self, cache_key) -> Optional[SearchResponse]:
        """Return the live cached response for ``cache_key``, else ``None``."""
        entry = self._search_cache.get(cache_key)
        if entry is None:
            return None
        cached_response, cached_at = entry
        if time() - cached_at < self._cache_ttl:
            return cached_response
        # Expired: drop it now so a stale entry does not linger if the
        # refresh that follows fails.
        self._search_cache.pop(cache_key, None)
        return None

    def _store_cached(self, cache_key, response: SearchResponse) -> None:
        """Cache ``response`` and purge every entry whose TTL has elapsed."""
        now = time()
        # Without this purge, keys that are never queried again would stay
        # in the dict forever and the cache would grow without bound.
        expired_keys = [
            key
            for key, (_, cached_at) in self._search_cache.items()
            if now - cached_at >= self._cache_ttl
        ]
        for key in expired_keys:
            del self._search_cache[key]
        self._search_cache[cache_key] = (response, now)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def search(
        self,
        query: str,
        top_k: int = 5,
        use_cache: bool = True,
    ) -> SearchResponse:
        """
        Async search operation

        On a cache hit the stored response object itself is returned, so its
        ``elapsed_ms`` is the timing of the original search and callers
        should treat it as read-only. Only successful responses are cached.

        Args:
            query: Search query
            top_k: Number of results
            use_cache: Use cache if available

        Returns:
            SearchResponse with results (``status="error"`` on failure)
        """
        start_time = time()
        try:
            cache_key = (query, top_k)
            if use_cache:
                cached_response = self._get_cached(cache_key)
                if cached_response is not None:
                    logger.debug("Cache hit for query: %s", query)
                    return cached_response

            # Run search in thread pool (non-blocking)
            results = await self._run_blocking(self.kb.search, query, top_k)
            response = self._success_response(query, results, start_time)

            if use_cache:
                self._store_cached(cache_key, response)
            return response
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Search error: %s", e)
            return self._error_response(query, e, start_time)

    async def search_products(
        self,
        query: str,
        top_k: int = 10,
    ) -> SearchResponse:
        """
        Async product search (not cached)

        Args:
            query: Search query
            top_k: Number of results

        Returns:
            SearchResponse with product results (``status="error"`` on failure)
        """
        start_time = time()
        try:
            results = await self._run_blocking(self.kb.search_products, query, top_k)
            return self._success_response(query, results, start_time)
        except Exception as e:
            logger.exception("Product search error: %s", e)
            return self._error_response(query, e, start_time)

    async def search_documentation(
        self,
        query: str,
        top_k: int = 5,
    ) -> SearchResponse:
        """
        Async documentation search (not cached)

        Args:
            query: Search query
            top_k: Number of results

        Returns:
            SearchResponse with documentation results
            (``status="error"`` on failure)
        """
        start_time = time()
        try:
            results = await self._run_blocking(
                self.kb.search_documentation, query, top_k
            )
            return self._success_response(query, results, start_time)
        except Exception as e:
            logger.exception("Documentation search error: %s", e)
            return self._error_response(query, e, start_time)

    async def query(
        self,
        question: str,
        top_k: Optional[int] = None,
    ) -> QueryResponse:
        """
        Async query with natural language

        Args:
            question: Natural language question
            top_k: Number of sources to use; passed through unchanged
                (including ``None``) to the underlying knowledge base

        Returns:
            QueryResponse with answer (``status="error"`` on failure)
        """
        start_time = time()
        try:
            answer = await self._run_blocking(self.kb.query, question, top_k)
            return QueryResponse(
                status="success",
                question=question,
                answer=answer,
                # Reported, not measured: falls back to 5 when top_k is unset.
                # NOTE(review): presumably mirrors the knowledge base's own
                # default - confirm.
                source_count=top_k or 5,
                confidence=0.85,  # Placeholder
                elapsed_ms=self._elapsed_ms(start_time),
            )
        except Exception as e:
            logger.exception("Query error: %s", e)
            return QueryResponse(
                status="error",
                question=question,
                answer="",
                source_count=0,
                confidence=0.0,
                elapsed_ms=self._elapsed_ms(start_time),
                error=str(e),
            )

    async def batch_search(
        self,
        queries: List[str],
        top_k: int = 5,
    ) -> List[SearchResponse]:
        """
        Async batch search multiple queries

        All queries are started concurrently through ``search`` (cache
        enabled); the returned list is in the same order as ``queries``.

        Args:
            queries: List of search queries
            top_k: Number of results per query

        Returns:
            List of SearchResponse objects
        """
        tasks = [self.search(query, top_k) for query in queries]
        return await asyncio.gather(*tasks)

    def clear_cache(self):
        """Clear search result cache"""
        self._search_cache.clear()
        logger.info("Search cache cleared")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        ``cached_queries`` counts stored entries; entries whose TTL has
        elapsed are included until the next lookup or insert removes them.
        """
        return {
            "cached_queries": len(self._search_cache),
            "cache_ttl_seconds": self._cache_ttl,
        }

    async def shutdown(self):
        """Shutdown executor, waiting for in-flight work to finish"""
        # executor.shutdown(wait=True) blocks until every worker is done, so
        # it is run on the loop's default executor instead of directly in
        # this coroutine, which would freeze the event loop meanwhile.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(
            None, partial(self.executor.shutdown, wait=True)
        )
        logger.info("AsyncKnowledgeBase shut down")