Spaces:
Runtime error
Runtime error
| import modal | |
| from transformers import AutoModelForObjectDetection, AutoImageProcessor | |
| import torch | |
| from smolagents import Tool | |
| from .app import app | |
| from .image import image | |
class RemoteObjectDetectionModalApp:
    """Object-detection worker that runs a Hugging Face detection model.

    The model is selected per instance through the ``model_name`` Modal
    parameter.

    NOTE(review): no ``@app.cls`` / ``@modal.method`` decorators are visible
    on this class in this file, yet it is looked up with
    ``modal.Cls.from_name`` and called with ``.remote`` elsewhere --
    presumably it is registered with the Modal app somewhere else; confirm.
    """

    # Hugging Face model identifier passed to ``from_pretrained``.
    model_name: str = modal.parameter()

    def _ensure_loaded(self):
        """Load the model and processor once per instance and model name.

        ``from_pretrained`` is expensive, so the loaded objects are cached on
        the instance and only reloaded if ``model_name`` differs from the one
        they were loaded for.
        """
        if (
            hasattr(self, "_loaded_model_name")
            and self._loaded_model_name == self.model_name
        ):
            return
        self.model = AutoModelForObjectDetection.from_pretrained(self.model_name)
        self.processor = AutoImageProcessor.from_pretrained(self.model_name)
        self.model.eval()
        # Record the name last so a failed load is retried on the next call.
        self._loaded_model_name = self.model_name

    def forward(self, image):
        """Detect objects in ``image`` and return them as plain dictionaries.

        Args:
            image: The image to analyse. It must expose ``size`` as
                ``(width, height)`` (e.g. a PIL image -- TODO confirm with
                callers).

        Returns:
            A list of dicts, one per detection scoring at least 0.5, with keys
            ``"box"`` (``[xmin, ymin, xmax, ymax]``), ``"score"`` (float) and
            ``"label"`` (name looked up in the model's ``id2label`` mapping).
        """
        self._ensure_loaded()

        # Preprocess image
        inputs = self.processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)

        # ``image.size`` is (width, height); post-processing wants (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = self.processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=0.5
        )[0]

        id2label = self.model.config.id2label
        # Convert tensors to built-in Python types for the caller.
        return [
            {
                "box": box.tolist(),  # [xmin, ymin, xmax, ymax]
                "score": score.item(),
                "label": id2label[label.item()],
            }
            for score, label, box in zip(
                results["scores"], results["labels"], results["boxes"]
            )
        ]
class RemoteObjectDetectionTool(Tool):
    """Agent tool that delegates object detection to the remote Modal class.

    The tool looks up ``RemoteObjectDetectionModalApp`` on the Modal app by
    name when constructed, and on each call instantiates it with the requested
    model name and invokes its ``forward`` method remotely.
    """

    name = "object_detection"
    description = """
Given an image, detect objects and return bounding boxes.
The image is a PIL image.
The output is a list of dictionaries containing the bounding boxes with the following keys:
- box: a list of 4 numbers [xmin, ymin, xmax, ymax]
- score: a number between 0 and 1
- label: a string
The bounding boxes are in the format of [xmin, ymin, xmax, ymax].
You need to provide the model name to use for object detection.
The tool returns a list of bounding boxes for all the objects in the image.
"""
    inputs = {
        "image": {
            "type": "image",
            "description": "The image to detect objects in",
        },
        "model_name": {
            "type": "string",
            "description": "The name of the model to use for object detection",
        },
    }
    output_type = "object"

    def __init__(self):
        """Resolve the remote Modal class handle used by ``forward``."""
        super().__init__()
        # The remote class is addressed by the app's name plus the class name.
        remote_cls_name = RemoteObjectDetectionModalApp.__name__
        self.tool_class = modal.Cls.from_name(app.name, remote_cls_name)

    def forward(self, image, model_name: str):
        """Run remote detection on ``image`` with the given model.

        Each detection is printed, then the full list is returned unchanged.

        Args:
            image: The image to detect objects in.
            model_name: Name of the model to use for object detection.

        Returns:
            The list of detection dicts (``box``, ``score``, ``label``)
            produced by the remote ``forward`` call.
        """
        # Keep the parameterised remote instance on ``self`` as before.
        self.tool = self.tool_class(model_name=model_name)
        detections = self.tool.forward.remote(image)
        for detection in detections:
            label = detection["label"]
            score = detection["score"]
            box = detection["box"]
            print(f"Found {label} with score: {score} at box: {box}")
        return detections