Spaces:
Runtime error
Runtime error
| import modal | |
| from smolagents import Tool | |
| from modal_apps.app import app | |
| from modal_apps.task_model_retriever import TaskModelRetrieverModalApp | |
class TaskModelRetrieverTool(Tool):
    """smolagents tool that looks up models able to perform a vision task.

    The actual search runs remotely in the Modal class
    ``TaskModelRetrieverModalApp`` deployed under ``app.name``; this tool only
    validates the arguments and dispatches to the matching remote method.
    """

    name = "task_model_retriever"
    description = """
    For a given task, retrieve the models that can perform that task.
    The supported tasks are:
    - object-detection
    - image-segmentation
    The query is a string that describes the class of objects the model needs to detect.
    The output is a dictionary with the model id as the key and the labels that the model can detect as the value.
    """
    inputs = {
        "task": {
            "type": "string",
            "description": "The task the model needs to perform.",
        },
        "query": {
            "type": "string",
            "description": "The class of objects the model needs to detect.",
        },
    }
    output_type = "object"

    def __init__(self):
        super().__init__()
        # Tasks accepted by forward(); each one maps to a remote search method.
        self.tasks = ["object-detection", "image-segmentation"]
        # Handle to the deployed Modal class; it is instantiated in setup().
        self.tool_class = modal.Cls.from_name(app.name, TaskModelRetrieverModalApp.__name__)

    def setup(self):
        """Instantiate the remote Modal class once, before the first search."""
        self.tool: TaskModelRetrieverModalApp = self.tool_class()
        # The base Tool.setup() flags the tool as initialized. Without this,
        # Tool.__call__ would re-run setup() (and re-create the Modal
        # instance) on every call.
        super().setup()

    def forward(self, task: str, query: str) -> dict:
        """Retrieve the models that can perform ``task`` for ``query``.

        Args:
            task: One of ``self.tasks`` ("object-detection" or
                "image-segmentation").
            query: The class of objects the model needs to detect.

        Returns:
            The value returned by the remote search method. Per the tool
            description, this is a dict mapping model id to the labels that
            model can detect.

        Raises:
            ValueError: If ``task`` is not a supported task.
            TypeError: If ``query`` is not a string.
        """
        # Explicit raises rather than asserts: asserts are stripped under -O,
        # which would let an unsupported task fall through both branches below.
        if task not in self.tasks:
            raise ValueError(f"Task {task} is not supported, supported tasks are: {self.tasks}")
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        # Tool.__call__ runs setup() lazily; do the same when forward() is
        # invoked directly, so that self.tool exists.
        if not getattr(self, "is_initialized", False):
            self.setup()

        print(f"Retrieving models for task {task} with query {query}")
        if task == "object-detection":
            return self.tool.object_detection_search.remote(query)
        # The only remaining supported task is "image-segmentation".
        return self.tool.image_segmentation_search.remote(query)