File size: 3,903 Bytes
b7316d5
 
d565e36
38812af
d565e36
 
 
 
38812af
 
 
d565e36
 
 
 
 
 
 
 
 
03d0bbd
 
 
d565e36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d0bbd
 
d565e36
 
 
 
 
 
 
 
 
 
 
 
 
 
38812af
d565e36
 
 
 
38812af
 
 
 
 
 
 
 
d565e36
03d0bbd
d565e36
38812af
 
 
 
d565e36
38812af
 
 
 
 
8507438
38812af
 
 
 
 
 
 
 
 
 
 
 
d565e36
38812af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers.util import cos_sim
from smolagents import Tool
import numpy as np
import datasets


class HybridRetriever:
    """Hybrid lexical + dense retriever over a fixed document set.

    Two modes:
      * "ensemble": merge BM25 and FAISS results with EnsembleRetriever
        (equal 0.5/0.5 weights) and truncate to the top k.
      * "rerank": fetch BM25 candidates, re-score them by cosine similarity
        between the query embedding and cached document embeddings, and
        return the k highest-scoring documents.
    """

    def __init__(self, docs, mode="rerank", k=5):
        """
        Args:
            docs: langchain Documents to index.
            mode: "ensemble" or "rerank".
            k: number of top documents returned by get_relevant_documents.
        """
        self.docs = docs
        self.mode = mode
        self.k = k
        self.embedding_model = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

        # Lexical retriever; keep a wide candidate pool (20) so the
        # downstream rerank/ensemble step has enough material to rank.
        self.bm25 = BM25Retriever.from_documents(docs)
        self.bm25.k = 20

        # Dense retriever over the same documents.
        self.faiss = FAISS.from_documents(docs, self.embedding_model)
        self.faiss_retriever = self.faiss.as_retriever(search_kwargs={"k": 20})

        # Cache document embeddings for rerank mode. embed_documents()
        # embeds all texts in one batch instead of one embed_query() call
        # per document, which the original did.
        texts = [doc.page_content for doc in docs]
        vectors = self.embedding_model.embed_documents(texts)
        self.doc_embeddings = dict(zip(texts, vectors))

        # Ensemble retriever setup (only needed in ensemble mode).
        if mode == "ensemble":
            self.retriever = EnsembleRetriever(
                retrievers=[self.bm25, self.faiss_retriever],
                weights=[0.5, 0.5]
            )

    def get_relevant_documents(self, query: str):
        """Return the top-k documents for `query` per the configured mode.

        Raises:
            ValueError: if self.mode is neither "ensemble" nor "rerank".
        """
        if self.mode == "ensemble":
            return self.retriever.get_relevant_documents(query)[:self.k]

        if self.mode == "rerank":
            candidates = self.bm25.get_relevant_documents(query)
            query_vec = np.asarray(self.embedding_model.embed_query(query))
            query_norm = np.linalg.norm(query_vec)

            scored = []
            for doc in candidates:
                doc_vec = self.doc_embeddings.get(doc.page_content)
                # Skip candidates we have no cached embedding for.
                if doc_vec is None:
                    continue
                doc_vec = np.asarray(doc_vec)
                denom = query_norm * np.linalg.norm(doc_vec)
                # Guard against zero-norm vectors (division by zero / NaN).
                sim = float(np.dot(query_vec, doc_vec) / denom) if denom else 0.0
                scored.append((sim, doc))

            scored.sort(key=lambda pair: pair[0], reverse=True)
            return [doc for _, doc in scored[:self.k]]

        raise ValueError(f"Unsupported mode: {self.mode}")


class GuestInfoHybridTool(Tool):
    """smolagents Tool that answers guest-lookup queries via HybridRetriever."""

    name = "guest_info_retriever"
    description = (
        "Retrieves detailed information about gala guests based on their name or relation "
        "using a hybrid of BM25 and embeddings. Supports ensemble or reranking."
    )
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest you want information about."
        }
    }
    output_type = "string"

    def __init__(self, docs, mode="rerank"):
        """
        Args:
            docs: langchain Documents describing the guests.
            mode: retrieval mode forwarded to HybridRetriever
                ("ensemble" or "rerank").
        """
        # Run Tool.__init__ so smolagents validates name/description/inputs
        # and sets is_initialized itself, rather than faking the flag by
        # hand as the original code did.
        super().__init__()
        self.retriever = HybridRetriever(docs, mode=mode)

    def forward(self, query: str):
        """Return matching guest records joined by blank lines, or a fallback message."""
        results = self.retriever.get_relevant_documents(query)
        if results:
            return "\n\n".join([doc.page_content for doc in results])
        else:
            return "No matching guest information found."


def load_guest_dataset():
    """Load the gala invitee dataset and wrap it in a GuestInfoHybridTool.

    Pulls the "agents-course/unit3-invitees" train split from the Hugging
    Face Hub, renders each guest record into a plain-text Document, and
    returns a rerank-mode retrieval tool over those documents.
    """
    dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")

    docs = []
    for guest in dataset:
        # One text block per guest; metadata keeps the name for lookups.
        content = (
            f"Name: {guest['name']}\n"
            f"Relation: {guest['relation']}\n"
            f"Description: {guest['description']}\n"
            f"Email: {guest['email']}"
        )
        docs.append(Document(page_content=content, metadata={"name": guest["name"]}))

    return GuestInfoHybridTool(docs, mode="rerank")