from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers.util import cos_sim
from smolagents import Tool
import numpy as np
import datasets


class HybridRetriever:
    """Retrieve documents by combining BM25 keyword search with embeddings.

    Two strategies are supported, selected by ``mode``:

    * ``"ensemble"`` - merge BM25 and FAISS results through an
      ``EnsembleRetriever`` with equal weights, then keep the first ``k``.
    * ``"rerank"`` - fetch BM25 candidates, then order them by cosine
      similarity between the query embedding and cached document embeddings.
    """

    SUPPORTED_MODES = ("ensemble", "rerank")

    def __init__(self, docs, mode="rerank", k=5, candidate_k=20):
        """Build the BM25 and FAISS indexes and cache document embeddings.

        Args:
            docs: Documents exposing ``page_content`` (LangChain ``Document``).
            mode: ``"ensemble"`` or ``"rerank"``.
            k: Number of top documents returned per query.
            candidate_k: Number of candidates each underlying retriever
                fetches before the final cut. Raised to ``k`` when smaller,
                so that asking for more than ``candidate_k`` results is not
                silently truncated.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        if mode not in self.SUPPORTED_MODES:
            # Fail before the (expensive) model load and index construction.
            raise ValueError(f"Unsupported mode: {mode}")

        self.docs = docs
        self.mode = mode
        self.k = k
        self.candidate_k = max(candidate_k, k)
        self.embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Initialize BM25 retriever
        self.bm25 = BM25Retriever.from_documents(docs)
        self.bm25.k = self.candidate_k

        # Initialize FAISS retriever
        self.faiss = FAISS.from_documents(docs, self.embedding_model)
        self.faiss_retriever = self.faiss.as_retriever(
            search_kwargs={"k": self.candidate_k}
        )

        # For reranker mode, cache doc embeddings. All texts are embedded in
        # a single batch call instead of one model call per document.
        # Documents with identical page_content share one cache entry.
        texts = [doc.page_content for doc in docs]
        self.doc_embeddings = dict(
            zip(texts, self.embedding_model.embed_documents(texts))
        )

        # Ensemble retriever setup
        if mode == "ensemble":
            self.retriever = EnsembleRetriever(
                retrievers=[self.bm25, self.faiss_retriever],
                weights=[0.5, 0.5],
            )

    def get_relevant_documents(self, query: str):
        """Return up to ``self.k`` documents relevant to ``query``.

        Args:
            query: Free-text search string.

        Returns:
            A list of documents, best match first.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        if self.mode == "ensemble":
            # ``invoke`` is the current retriever entry point; the retrievers'
            # own ``get_relevant_documents`` is deprecated.
            return self.retriever.invoke(query)[: self.k]

        if self.mode == "rerank":
            candidates = self.bm25.invoke(query)
            query_vec = np.asarray(
                self.embedding_model.embed_query(query), dtype=float
            )
            query_norm = np.linalg.norm(query_vec)  # loop-invariant

            scored = []
            for doc in candidates:
                cached = self.doc_embeddings.get(doc.page_content)
                if cached is None:
                    continue
                doc_vec = np.asarray(cached, dtype=float)
                denom = query_norm * np.linalg.norm(doc_vec)
                # A zero-norm vector would give NaN, which makes the sort
                # order unreliable; treat it as "no similarity" instead.
                sim = float(np.dot(query_vec, doc_vec) / denom) if denom else 0.0
                scored.append((sim, doc))

            # Stable sort: ties keep their BM25 order.
            scored.sort(key=lambda pair: pair[0], reverse=True)
            return [doc for _, doc in scored[: self.k]]

        raise ValueError(f"Unsupported mode: {self.mode}")


class GuestInfoHybridTool(Tool):
    """smolagents tool that looks up gala guests through a HybridRetriever."""

    name = "guest_info_retriever"
    description = (
        "Retrieves detailed information about gala guests based on their name or relation "
        "using a hybrid of BM25 and embeddings. Supports ensemble or reranking."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, mode="rerank"):
        """Create the tool and build its retriever.

        Args:
            docs: Guest documents to index.
            mode: Retrieval strategy passed to ``HybridRetriever``
                (``"ensemble"`` or ``"rerank"``).
        """
        super().__init__()
        self.is_initialized = False  # Flag to check if the tool is initialized
        self.retriever = HybridRetriever(docs, mode=mode)

    def forward(self, query: str):
        """Return the matching guests' text, or a fallback message if none.

        Args:
            query: The name or relation of the guest to look up.

        Returns:
            The ``page_content`` of every retrieved document joined by blank
            lines, or a fixed "not found" message when nothing is retrieved.
        """
        results = self.retriever.get_relevant_documents(query)
        if not results:
            return "No matching guest information found."
        return "\n\n".join(doc.page_content for doc in results)


def load_guest_dataset():
    """Load the invitee dataset and return a ready-to-use guest lookup tool.

    Each row of the ``agents-course/unit3-invitees`` train split becomes one
    ``Document`` whose text lists the guest's name, relation, description and
    email, with the name also stored in the metadata.

    Returns:
        A ``GuestInfoHybridTool`` built over those documents in "rerank" mode.
    """
    guest_dataset = datasets.load_dataset(
        "agents-course/unit3-invitees", split="train"
    )

    docs = [
        Document(
            page_content="\n".join([
                f"Name: {guest['name']}",
                f"Relation: {guest['relation']}",
                f"Description: {guest['description']}",
                f"Email: {guest['email']}"
            ]),
            metadata={"name": guest["name"]}
        )
        for guest in guest_dataset
    ]

    return GuestInfoHybridTool(docs, mode="rerank")