|
|
from langchain_community.retrievers import BM25Retriever |
|
|
from langchain.retrievers import EnsembleRetriever |
|
|
from langchain.vectorstores import FAISS |
|
|
from langchain.docstore.document import Document |
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from sentence_transformers.util import cos_sim |
|
|
from smolagents import Tool |
|
|
import numpy as np |
|
|
import datasets |
|
|
|
|
|
|
|
|
class HybridRetriever:
    """Hybrid sparse + dense document retriever.

    Two strategies:
      - "ensemble": blend BM25 and FAISS dense retrieval with equal weights
        via langchain's EnsembleRetriever.
      - "rerank": fetch BM25 candidates, then re-order them by cosine
        similarity between the query embedding and precomputed document
        embeddings.
    """

    def __init__(self, docs, mode="rerank", k=5):
        """
        docs: list of langchain Documents to index.
        mode: "ensemble" or "rerank"
        k: number of top docs to return

        Raises:
            ValueError: if mode is not one of the supported strategies.
        """
        # Fail fast on a bad mode instead of erroring at query time.
        if mode not in ("ensemble", "rerank"):
            raise ValueError(f"Unsupported mode: {mode}")

        self.docs = docs
        self.mode = mode
        self.k = k
        self.embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Sparse retriever is used by both modes; over-fetch 20 candidates
        # so the downstream stage has a pool to blend/rerank from.
        self.bm25 = BM25Retriever.from_documents(docs)
        self.bm25.k = 20

        # Only build the structures the chosen mode actually uses.
        # Previously both the FAISS index and the embeddings dict were built
        # unconditionally, embedding the whole corpus twice on every init.
        self.faiss = None
        self.faiss_retriever = None
        self.doc_embeddings = {}

        if mode == "ensemble":
            self.faiss = FAISS.from_documents(docs, self.embedding_model)
            self.faiss_retriever = self.faiss.as_retriever(search_kwargs={"k": 20})
            self.retriever = EnsembleRetriever(
                retrievers=[self.bm25, self.faiss_retriever],
                weights=[0.5, 0.5],
            )
        else:
            # Batch-embed once with embed_documents (document-side method)
            # instead of one embed_query call per document.
            texts = [doc.page_content for doc in docs]
            vectors = self.embedding_model.embed_documents(texts)
            self.doc_embeddings = dict(zip(texts, vectors))

    def get_relevant_documents(self, query: str):
        """Return the top-k documents for `query` according to `self.mode`."""
        if self.mode == "ensemble":
            return self.retriever.get_relevant_documents(query)[:self.k]

        # rerank mode: score BM25 candidates by cosine similarity to the query.
        bm25_candidates = self.bm25.get_relevant_documents(query)
        query_embedding = np.asarray(self.embedding_model.embed_query(query))
        query_norm = np.linalg.norm(query_embedding)

        scores = []
        for doc in bm25_candidates:
            doc_vec = self.doc_embeddings.get(doc.page_content)
            if doc_vec is None:
                # Candidate text not in the precomputed map (shouldn't happen
                # when BM25 and the embedding map share the same corpus).
                continue
            doc_vec = np.asarray(doc_vec)
            denom = query_norm * np.linalg.norm(doc_vec)
            # Guard against zero-norm vectors to avoid a ZeroDivisionError.
            sim = float(np.dot(query_embedding, doc_vec) / denom) if denom else 0.0
            scores.append((sim, doc))

        top_docs = sorted(scores, key=lambda x: x[0], reverse=True)[:self.k]
        return [doc for _, doc in top_docs]
|
|
|
|
|
|
|
|
class GuestInfoHybridTool(Tool):
    """smolagents Tool exposing hybrid (BM25 + embedding) guest lookup."""

    name = "guest_info_retriever"
    description = (
        "Retrieves detailed information about gala guests based on their name or relation "
        "using a hybrid of BM25 and embeddings. Supports ensemble or reranking."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, mode="rerank"):
        """
        docs: langchain Documents describing the guests.
        mode: retrieval strategy passed through to HybridRetriever
              ("ensemble" or "rerank").
        """
        # Bug fix: run the base Tool.__init__ so smolagents performs its
        # attribute validation/setup. The previous code hand-set
        # `self.is_initialized = False` instead, skipping that setup.
        super().__init__()
        self.retriever = HybridRetriever(docs, mode=mode)

    def forward(self, query: str):
        """Return matching guest records separated by blank lines,
        or a fallback message when nothing matches."""
        results = self.retriever.get_relevant_documents(query)
        if results:
            return "\n\n".join(doc.page_content for doc in results)
        return "No matching guest information found."
|
|
|
|
|
|
|
|
def load_guest_dataset():
    """Build a GuestInfoHybridTool over the unit3-invitees guest list.

    Loads the "agents-course/unit3-invitees" train split and converts each
    guest record into one Document whose page_content lists the name,
    relation, description, and email on separate lines; the guest's name is
    kept in metadata. Returns the tool configured for rerank mode.
    """
    guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    docs = []
    for guest in guest_dataset:
        content = "\n".join([
            f"Name: {guest['name']}",
            f"Relation: {guest['relation']}",
            f"Description: {guest['description']}",
            f"Email: {guest['email']}"
        ])
        docs.append(Document(page_content=content, metadata={"name": guest["name"]}))

    return GuestInfoHybridTool(docs, mode="rerank")
|
|
|
|
|
|