Spaces:
Configuration error
Configuration error
import asyncio
import functools
import logging

from datasets import load_dataset
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from rank_bm25 import BM25Okapi
# Per-module logger; handlers and levels are configured by the application.
logger = logging.getLogger(name=__name__)
class GuestInfoInput(BaseModel):
    # Argument schema handed to StructuredTool as ``args_schema``: the agent
    # supplies one required free-text ``query`` string.
    # NOTE: intentionally no class docstring — pydantic would surface it as
    # the JSON-schema ``description`` and change what the tool advertises.
    query: str = Field(
        description="Query about guest information",
    )
@functools.lru_cache(maxsize=1)
def _load_guest_index():
    """Load the invitee dataset and build a BM25 index over it, once per process.

    The result is memoized so repeated tool calls neither re-download the
    dataset nor rebuild the index. ``lru_cache`` does not cache raised
    exceptions, so a failed load (e.g. a network error) is retried on the
    next call.

    Returns:
        tuple: ``(dataset, bm25)`` — the ``train`` split of
        ``agents-course/unit3-invitees`` and a ``BM25Okapi`` index whose
        i-th document is the lower-cased, whitespace-split
        ``"<name> <relation>"`` text of row i.

    Raises:
        ValueError: If the dataset has no rows (BM25Okapi cannot index an
            empty corpus).
    """
    dataset = load_dataset("agents-course/unit3-invitees", split="train")
    logger.info("Loaded %d guests from Hugging Face dataset", len(dataset))
    tokenized_docs = [
        f"{row['name']} {row['relation']}".lower().split() for row in dataset
    ]
    if not tokenized_docs:
        raise ValueError("Guest dataset is empty; nothing to search")
    return dataset, BM25Okapi(tokenized_docs)


def _retrieve_guest_info(query: str) -> str:
    """Synchronously find the guest whose name/relation best matches *query*.

    Args:
        query: Free-text query about a guest.

    Returns:
        ``"Guest: <name>, Relation: <relation>"`` for the best-scoring row,
        ``"No matching guest found"`` when no row scores above zero (including
        an empty or whitespace-only query), or ``"Error: <message>"`` if
        loading, indexing or scoring fails.
    """
    try:
        logger.info("Retrieving guest info for query: %s", query)
        dataset, bm25 = _load_guest_index()
        scores = bm25.get_scores(query.lower().split())
        # argmax yields a numpy integer; use a plain int to index the dataset.
        best_idx = int(scores.argmax())
        if scores[best_idx] > 0:
            guest = dataset[best_idx]
            return f"Guest: {guest['name']}, Relation: {guest['relation']}"
        return "No matching guest found"
    except Exception as e:
        # Tool boundary: report the failure to the agent as text rather than
        # raising, but keep the traceback in the logs.
        logger.exception("Error retrieving guest info for query '%s': %s", query, e)
        return f"Error: {str(e)}"


async def guest_info_func(query: str) -> str:
    """
    Retrieve guest information based on a query.

    The lookup (dataset download on first use, BM25 scoring) is blocking, so
    it is run in the event loop's default executor instead of on the loop.

    Args:
        query (str): Query about guest information.
    Returns:
        str: Guest information or error message.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _retrieve_guest_info, query)


# ``func`` must be synchronous: the tool's sync ``invoke``/``run`` path calls it
# directly, so passing the coroutine function there would return an un-awaited
# coroutine instead of a string. ``coroutine`` serves ``ainvoke``/``arun``.
guest_info_retriever_tool = StructuredTool.from_function(
    func=_retrieve_guest_info,
    coroutine=guest_info_func,
    name="guest_info_retriever_tool",
    description=(
        "Retrieve guest information (name and relation) that best matches "
        "a query about a guest."
    ),
    args_schema=GuestInfoInput,
)