File size: 1,333 Bytes
98aa770
277590a
 
 
 
98aa770
 
 
 
 
 
 
277590a
98aa770
 
 
277590a
 
 
 
 
 
98aa770
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Literal
from functools import partial

from .utils import load_documents, agentbase_indexing


class BaseRetriever(ABC):
    """
    Common interface shared by every AgentBase retriever.

    Concrete subclasses supply ``build_index`` and ``retrieve``;
    ``batch_retrieve`` is provided here on top of ``retrieve``.
    """

    def __init__(self, db_path: str, index_config: Literal["naive", "v1"]):
        """
        Args:
            db_path: Path to the agent database. It is stored on the instance
                and pre-bound as the first positional argument of every
                indexing callable in ``indexing_func``.
            index_config: Name of the indexing strategy ("naive" or "v1").
                Stored as-is; it is not validated here.
        """
        self.db_path = db_path
        # Start empty; presumably filled in by a subclass's build_index()
        # (NOTE(review): confirm against concrete retrievers).
        self.agent_ids = []
        self.documents = []
        self.index_config = index_config
        # Config name -> callable with db_path already bound as its first
        # argument. NOTE(review): presumably subclasses dispatch through
        # self.indexing_func[self.index_config] — confirm.
        loaders = (("naive", load_documents), ("v1", agentbase_indexing))
        self.indexing_func = {
            name: partial(loader, self.db_path) for name, loader in loaders
        }

    @abstractmethod
    def build_index(self) -> None:
        """Construct the retrieval index from the database at ``db_path``."""

    @abstractmethod
    def retrieve(self, query: str, top_k: int = 10) -> List[Tuple[str, float]]:
        """
        Return the best-matching agents for one query.

        Args:
            query: Query text.
            top_k: Maximum number of agents to return.

        Returns:
            List of (agent_id, score) tuples
        """

    def batch_retrieve(self, queries: Dict[str, str], top_k: int = 10) -> Dict[str, List[Tuple[str, float]]]:
        """
        Run ``retrieve`` once per query (used for evaluation).

        Args:
            queries: Mapping of query id -> query text.
            top_k: Forwarded unchanged to every ``retrieve`` call.

        Returns:
            Mapping of query id -> that query's ``retrieve`` result, in the
            iteration order of ``queries``.
        """
        return {qid: self.retrieve(text, top_k) for qid, text in queries.items()}