ecomcp / src /core /async_knowledge_base.py
vinhnx90's picture
feat: Implement LlamaIndex integration with new core modules for knowledge base, document loading, vector search, and comprehensive documentation and tests.
108d8af
"""
Async wrapper for knowledge base operations.
Provides non-blocking async/await interface for knowledge base operations,
suitable for async MCP server and concurrent requests.
"""
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import monotonic, time
from typing import Any, Callable, Dict, List, Optional

from .knowledge_base import KnowledgeBase
from .vector_search import SearchResult
from .response_models import SearchResponse, QueryResponse, SearchResultItem
# Module-level logger named after this module (logging.getLogger(__name__)).
logger = logging.getLogger(__name__)
class AsyncKnowledgeBase:
    """
    Async wrapper for KnowledgeBase operations.

    Every blocking ``KnowledgeBase`` call runs on a private
    ``ThreadPoolExecutor``, so the event loop is never blocked. Responses
    from :meth:`search` are kept in an in-memory cache with a TTL and an
    LRU size bound.

    Public search/query methods do not raise on failure. They return a
    response whose ``status`` is ``"error"`` and whose ``error`` field holds
    ``str(exc)``.

    Can be used as ``async with AsyncKnowledgeBase(kb) as akb: ...`` to
    guarantee the thread pool is shut down.
    """

    def __init__(
        self,
        kb: KnowledgeBase,
        max_workers: int = 4,
        cache_ttl: float = 300,
        max_cache_size: Optional[int] = 1024,
    ):
        """
        Initialize async knowledge base

        Args:
            kb: Underlying KnowledgeBase instance
            max_workers: Max thread pool workers
            cache_ttl: Seconds a cached ``search`` response stays valid
                (default 300, i.e. 5 minutes)
            max_cache_size: Maximum number of cached ``search`` responses.
                Least recently used entries are evicted beyond this limit,
                and ``0`` disables caching. ``None`` leaves the cache
                unbounded.

        Raises:
            ValueError: If ``max_cache_size`` is negative.
        """
        # Validate before creating the executor so a bad argument
        # doesn't leave an orphaned thread pool behind.
        if max_cache_size is not None and max_cache_size < 0:
            raise ValueError("max_cache_size must be >= 0 or None")
        self.kb = kb
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        # Maps (query, top_k) to (response, monotonic() timestamp).
        # Dict insertion order doubles as the LRU order (oldest first).
        self._search_cache = {}
        self._cache_ttl = cache_ttl
        self._max_cache_size = max_cache_size

    async def __aenter__(self) -> "AsyncKnowledgeBase":
        """Enter an ``async with`` block; returns ``self``."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Shut the thread pool down when leaving an ``async with`` block."""
        await self.shutdown()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds since ``start`` (a ``monotonic()`` reading), rounded to 2 places."""
        return round((monotonic() - start) * 1000, 2)

    async def _run_blocking(self, func: Callable[..., Any], *args: Any) -> Any:
        """Run ``func(*args)`` on this instance's thread pool and await the result."""
        # The asyncio docs recommend get_running_loop() inside coroutines
        # rather than get_event_loop().
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, partial(func, *args))

    @staticmethod
    def _format_results(results: List[SearchResult]) -> List[SearchResultItem]:
        """Convert raw results into ``SearchResultItem``s ranked from 1, with scores rounded to 3 places."""
        return [
            SearchResultItem(
                rank=rank,
                score=round(result.score, 3),
                content=result.content,
                source=result.source,
                metadata=result.metadata,
            )
            for rank, result in enumerate(results, 1)
        ]

    def _cache_get(self, key) -> Optional[SearchResponse]:
        """
        Return the live cached response for ``key``, or ``None``.

        An expired entry is removed on access. A hit moves the entry to
        the most-recently-used end of the cache.
        """
        entry = self._search_cache.pop(key, None)
        if entry is None:
            return None
        response, stored_at = entry
        if monotonic() - stored_at >= self._cache_ttl:
            return None  # expired: leave it removed
        self._search_cache[key] = entry  # re-insert as most recently used
        return response

    def _cache_put(self, key, response: SearchResponse) -> None:
        """Store ``response`` under ``key``, evicting LRU entries beyond the size bound."""
        self._search_cache.pop(key, None)
        self._search_cache[key] = (response, monotonic())
        if self._max_cache_size is None:
            return
        while len(self._search_cache) > self._max_cache_size:
            # The first key in insertion order is the least recently used.
            del self._search_cache[next(iter(self._search_cache))]

    async def _run_search(
        self,
        method: str,
        query: str,
        top_k: int,
        label: str,
        cache_key=None,
    ) -> SearchResponse:
        """
        Call ``self.kb.<method>(query, top_k)`` off-loop and wrap the outcome.

        Args:
            method: Name of the blocking KnowledgeBase search method.
            query: Search query.
            top_k: Number of results.
            label: Prefix of the error log message (e.g. ``"Search"``).
            cache_key: If given, a successful response is cached under it.

        Returns:
            A ``"success"`` SearchResponse. If looking up the method, the
            search itself, or building the response raises, returns an
            ``"error"`` SearchResponse carrying ``str(exc)`` instead.
        """
        start = monotonic()
        try:
            func = getattr(self.kb, method)
            results = await self._run_blocking(func, query, top_k)
            items = self._format_results(results)
            response = SearchResponse(
                status="success",
                query=query,
                result_count=len(items),
                results=items,
                elapsed_ms=self._elapsed_ms(start),
            )
        except Exception as e:
            # logger.exception keeps the traceback, which the bare
            # message alone would lose.
            logger.exception("%s error: %s", label, e)
            return SearchResponse(
                status="error",
                query=query,
                result_count=0,
                results=[],
                elapsed_ms=self._elapsed_ms(start),
                error=str(e),
            )
        if cache_key is not None:
            self._cache_put(cache_key, response)
        return response

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def search(
        self,
        query: str,
        top_k: int = 5,
        use_cache: bool = True,
    ) -> SearchResponse:
        """
        Async search operation

        Args:
            query: Search query
            top_k: Number of results
            use_cache: Serve from, and store into, the TTL cache. Only
                successful responses are cached. A hit returns the stored
                response object unchanged, including its original
                ``elapsed_ms``.

        Returns:
            SearchResponse with results (``status="error"`` on failure)
        """
        cache_key = (query, top_k)
        if use_cache:
            cached = self._cache_get(cache_key)
            if cached is not None:
                logger.debug("Cache hit for query: %s", query)
                return cached
        return await self._run_search(
            "search",
            query,
            top_k,
            "Search",
            cache_key=cache_key if use_cache else None,
        )

    async def search_products(
        self,
        query: str,
        top_k: int = 10,
    ) -> SearchResponse:
        """
        Async product search (not cached)

        Args:
            query: Search query
            top_k: Number of results

        Returns:
            SearchResponse with product results (``status="error"`` on failure)
        """
        return await self._run_search(
            "search_products", query, top_k, "Product search"
        )

    async def search_documentation(
        self,
        query: str,
        top_k: int = 5,
    ) -> SearchResponse:
        """
        Async documentation search (not cached)

        Args:
            query: Search query
            top_k: Number of results

        Returns:
            SearchResponse with documentation results (``status="error"`` on failure)
        """
        return await self._run_search(
            "search_documentation", query, top_k, "Documentation search"
        )

    async def query(
        self,
        question: str,
        top_k: Optional[int] = None,
    ) -> QueryResponse:
        """
        Async query with natural language

        Args:
            question: Natural language question
            top_k: Number of sources to use (passed through to
                ``KnowledgeBase.query``; ``None`` is forwarded as-is)

        Returns:
            QueryResponse with answer (``status="error"`` on failure)
        """
        start = monotonic()
        try:
            answer = await self._run_blocking(self.kb.query, question, top_k)
            return QueryResponse(
                status="success",
                question=question,
                answer=answer,
                # NOTE(review): reports the requested count, not how many
                # sources KnowledgeBase.query actually used. The fallback of 5
                # presumably mirrors its default; confirm.
                source_count=top_k or 5,
                confidence=0.85,  # Placeholder: not computed from the answer
                elapsed_ms=self._elapsed_ms(start),
            )
        except Exception as e:
            logger.exception("Query error: %s", e)
            return QueryResponse(
                status="error",
                question=question,
                answer="",
                source_count=0,
                confidence=0.0,
                elapsed_ms=self._elapsed_ms(start),
                error=str(e),
            )

    async def batch_search(
        self,
        queries: List[str],
        top_k: int = 5,
    ) -> List[SearchResponse]:
        """
        Async batch search multiple queries

        All searches are scheduled concurrently, and the thread pool caps how
        many blocking calls run in parallel. Duplicate queries in one batch
        all miss the cache, because none has completed when the others
        start.

        Args:
            queries: List of search queries
            top_k: Number of results per query

        Returns:
            List of SearchResponse objects, in the same order as ``queries``
        """
        return await asyncio.gather(*(self.search(q, top_k) for q in queries))

    def clear_cache(self):
        """Clear search result cache"""
        self._search_cache.clear()
        logger.info("Search cache cleared")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics

        ``cached_queries`` counts only entries that have not yet expired.
        """
        now = monotonic()
        live = sum(
            1
            for _, stored_at in self._search_cache.values()
            if now - stored_at < self._cache_ttl
        )
        return {
            "cached_queries": live,
            "cache_ttl_seconds": self._cache_ttl,
            "max_cache_size": self._max_cache_size,
        }

    async def shutdown(self):
        """Shut down the executor, waiting for in-flight jobs without blocking the event loop."""
        loop = asyncio.get_running_loop()
        # executor.shutdown(wait=True) blocks until running jobs finish.
        # Run it on the loop's default executor rather than inline, so the
        # event loop stays responsive. It must not run on self.executor,
        # whose worker would then wait on itself.
        await loop.run_in_executor(None, partial(self.executor.shutdown, wait=True))
        logger.info("AsyncKnowledgeBase shut down")