Basti-1995's picture
init flag
03d0bbd
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers.util import cos_sim
from smolagents import Tool
import numpy as np
import datasets
class HybridRetriever:
    """Retrieve documents by combining BM25 lexical search with embeddings.

    Two strategies are supported, selected by ``mode``:

    * ``"ensemble"`` -- BM25 and FAISS results are fused by an
      ``EnsembleRetriever`` with equal weights, then cut to ``k``.
    * ``"rerank"`` -- BM25 proposes candidates, which are re-ordered by cosine
      similarity between the query embedding and cached document embeddings.
      Only BM25 supplies candidates in this mode; the FAISS index is not
      consulted.
    """

    _MODES = ("ensemble", "rerank")

    def __init__(self, docs, mode="rerank", k=5, candidate_k=20):
        """Build the BM25 index, the FAISS index and the embedding cache.

        Args:
            docs: Documents exposing ``page_content`` (LangChain ``Document``).
            mode: ``"ensemble"`` or ``"rerank"``.
            k: Number of top documents to return per query.
            candidate_k: Number of candidates each underlying retriever
                fetches before fusion/reranking. Raised to ``k`` when smaller,
                so asking for more than ``candidate_k`` documents still works.

        Raises:
            ValueError: If ``mode`` is not a supported strategy.
        """
        # Fail fast: a bad mode would otherwise only surface on the first query.
        if mode not in self._MODES:
            raise ValueError(f"Unsupported mode: {mode}")

        self.docs = docs
        self.mode = mode
        self.k = k
        # The candidate pool must never be smaller than the requested top-k.
        self.candidate_k = max(candidate_k, k)

        self.embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Lexical retriever.
        self.bm25 = BM25Retriever.from_documents(docs)
        self.bm25.k = self.candidate_k

        # Dense retriever.
        self.faiss = FAISS.from_documents(docs, self.embedding_model)
        self.faiss_retriever = self.faiss.as_retriever(
            search_kwargs={"k": self.candidate_k}
        )

        # Cache of page_content -> embedding used by rerank mode. Texts are
        # de-duplicated (order preserved) and embedded in one batched call
        # instead of one model invocation per document.
        unique_texts = list(dict.fromkeys(doc.page_content for doc in docs))
        vectors = self.embedding_model.embed_documents(unique_texts)
        self.doc_embeddings = dict(zip(unique_texts, vectors))

        # Ensemble retriever setup.
        if mode == "ensemble":
            self.retriever = EnsembleRetriever(
                retrievers=[self.bm25, self.faiss_retriever],
                weights=[0.5, 0.5],
            )

    def get_relevant_documents(self, query: str):
        """Return up to ``k`` documents for ``query`` using the active mode.

        Raises:
            ValueError: If ``self.mode`` is not a supported strategy.
        """
        if self.mode == "ensemble":
            return self.retriever.get_relevant_documents(query)[: self.k]
        if self.mode == "rerank":
            return self._rerank(query)
        raise ValueError(f"Unsupported mode: {self.mode}")

    def _rerank(self, query: str):
        """Re-order BM25 candidates by cosine similarity to the query."""
        candidates = self.bm25.get_relevant_documents(query)
        query_vec = np.asarray(
            self.embedding_model.embed_query(query), dtype=float
        )

        scored = []
        for doc in candidates:
            cached = self.doc_embeddings.get(doc.page_content)
            # Candidates without a cached embedding cannot be scored; skip.
            if cached is None:
                continue
            scored.append((self._cosine(query_vec, cached), doc))

        # list.sort is stable, so equally scored documents keep BM25 order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [doc for _, doc in scored[: self.k]]

    @staticmethod
    def _cosine(query_vec, doc_vec) -> float:
        """Cosine similarity that yields 0.0 instead of NaN for zero vectors."""
        doc_arr = np.asarray(doc_vec, dtype=float)
        denom = np.linalg.norm(query_vec) * np.linalg.norm(doc_arr)
        if denom == 0:
            # A NaN score would make the sort order undefined.
            return 0.0
        return float(np.dot(query_vec, doc_arr) / denom)
class GuestInfoHybridTool(Tool):
    """smolagents tool that looks up gala guest records with a HybridRetriever."""

    name = "guest_info_retriever"
    description = (
        "Retrieves detailed information about gala guests based on their name or relation "
        "using a hybrid of BM25 and embeddings. Supports ensemble or reranking."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, mode="rerank"):
        """Create the tool and index ``docs``.

        Args:
            docs: Guest documents handed to ``HybridRetriever``.
            mode: Retrieval strategy forwarded to ``HybridRetriever``
                (``"ensemble"`` or ``"rerank"``).
        """
        # Let the base Tool set up its own state (including the
        # ``is_initialized`` flag) rather than patching that flag by hand.
        super().__init__()
        self.retriever = HybridRetriever(docs, mode=mode)

    def forward(self, query: str):
        """Return the matching guest records as one string.

        Records are separated by a blank line; a fixed fallback message is
        returned when the retriever finds nothing.
        """
        results = self.retriever.get_relevant_documents(query)
        if not results:
            return "No matching guest information found."
        return "\n\n".join(doc.page_content for doc in results)
def load_guest_dataset():
    """Build a ready-to-use guest lookup tool from the invitee dataset.

    Loads the ``train`` split of ``agents-course/unit3-invitees``, renders each
    row as a four-line text record (name, relation, description, email) wrapped
    in a ``Document`` whose metadata carries the guest name, and returns a
    ``GuestInfoHybridTool`` in rerank mode over those documents.
    """
    rows = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    documents = []
    for guest in rows:
        # One labelled field per line, in a fixed order.
        record_lines = [
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}",
        ]
        record = Document(
            page_content="\n".join(record_lines),
            metadata={"name": guest["name"]},
        )
        documents.append(record)

    return GuestInfoHybridTool(documents, mode="rerank")