| import datasets |
| from langchain_core.documents import Document |
|
|
# Load the "train" split of the gala invitee dataset.
guest_dataset = datasets.load_dataset('agents-course/unit3-invitees', split='train')


def _guest_to_document(guest):
    """Convert one invitee record into a LangChain ``Document``.

    The page content lists the guest's name, relation, description and
    email, one "Label: value" pair per line. The guest's name is also
    stored in the document metadata.
    """
    lines = [
        f"Name: {guest['name']}",
        f"Relation: {guest['relation']}",
        f"Description: {guest['description']}",
        f"Email: {guest['email']}",
    ]
    return Document(page_content="\n".join(lines), metadata={'name': guest['name']})


# One searchable document per invitee, in dataset order.
docs = list(map(_guest_to_document, guest_dataset))
|
|
| from smolagents import Tool |
| from langchain_community.retrievers import BM25Retriever |
|
|
class GuestRetrieverTool(Tool):
    """smolagents tool that looks up gala guests with BM25 keyword search.

    A BM25 retriever is built once, at construction time, over the supplied
    LangChain documents. Each call to ``forward`` returns the page content
    of the best-matching documents.
    """

    name = "guest_info_retriever"
    description = "Retrieves information about gala guests by name or relation."
    inputs = {
        "query": {
            "type": "string",
            "description": "The name or relation of the guest."
        }
    }
    output_type = "string"

    def __init__(self, docs, k=3):
        """Index ``docs`` for BM25 retrieval.

        Args:
            docs: Iterable of LangChain ``Document`` objects to search over.
            k: Maximum number of documents returned per query. Defaults to
                3, the value that was previously hard-coded.

        Raises:
            ValueError: If ``docs`` is empty or ``k`` is not positive.
        """
        # Let the smolagents base class initialise its own state. Tool.__call__
        # checks ``is_initialized``, which Tool.__init__ sets. The previous
        # ``self.initialized = True`` set an attribute the base class never reads.
        super().__init__()
        docs = list(docs)  # materialise so an empty iterable can be detected
        if not docs:
            raise ValueError("GuestRetrieverTool needs at least one document to index.")
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}.")
        self.k = k
        self.retriever = BM25Retriever.from_documents(docs)
        self.retriever.k = k

    def forward(self, query: str) -> str:
        """Return the best BM25 matches for ``query``.

        Args:
            query: Guest name or relation to search for.

        Returns:
            Page contents of up to ``k`` matching documents, separated by
            blank lines, or a fixed "not found" message when nothing matches.
        """
        # ``invoke`` is the current retriever entry point;
        # ``get_relevant_documents`` is deprecated in langchain-core.
        results = self.retriever.invoke(query)
        if not results:
            return "No matching guest info found..."
        return "\n\n".join(doc.page_content for doc in results[: self.k])
|
|
# Shared tool instance built from the module-level ``docs`` list.
guest_info_tool = GuestRetrieverTool(docs)


def _run_demo():
    """Ask a CodeAgent equipped with the guest tool about one guest and print the reply."""
    from smolagents import CodeAgent, InferenceClientModel

    llm = InferenceClientModel(model_id="HuggingFaceH4/zephyr-7b-beta")
    agent = CodeAgent(tools=[guest_info_tool], model=llm)
    answer = agent.run("Tell me about guest named 'Lady Ada Lovelace'")
    print(answer)


if __name__ == "__main__":
    _run_demo()
|
|