Spaces:
Runtime error
Runtime error
| from langchain_community.vectorstores import FAISS | |
| from smolagents import Tool | |
| from rag.settings import get_vector_store | |
class ObjectDetectionModelRetrieverTool(Tool):
    """Agent tool that finds object-detection models for a class of objects.

    The tool runs a similarity search for the query against the vector store
    returned by ``get_vector_store()`` and reports, for each hit, the
    ``model_id`` and ``model_labels`` stored in that document's metadata.
    """

    name = "object_detection_model_retriever"
    description = """
    For a given class of objects, retrieve the models that can detect that class.
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "query": {
            # Declared as "string" so the advertised schema matches forward(),
            # which only accepts ``str`` queries.
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        }
    }
    output_type = "object"

    def __init__(self):
        super().__init__()

    def setup(self):
        """Load the vector store used by :meth:`forward`.

        Must have run before ``forward`` is called, since ``forward`` reads
        ``self.vector_store``.
        NOTE(review): presumably invoked by the Tool base class before first
        use -- confirm against the smolagents version in use.
        """
        self.vector_store = get_vector_store()
        print("Loaded vector store")

    def forward(self, query: str) -> dict:
        """Return the models whose indexed documents best match ``query``.

        Args:
            query: Text describing the class of objects to detect.

        Returns:
            A dict mapping each retrieved document's ``model_id`` metadata to
            its ``model_labels`` metadata. At most 7 documents are retrieved;
            if several share a ``model_id``, the last one retrieved wins.

        Raises:
            TypeError: If ``query`` is not a string.
            KeyError: If a retrieved document lacks the ``model_id`` or
                ``model_labels`` metadata key.
        """
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        docs = self.vector_store.similarity_search(query, k=7)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"]
            for doc in docs
        }