Spaces:
Build error
Build error
# Ultralytics YOLO 🚀, AGPL-3.0 license
| from pathlib import Path | |
| from ultralytics.engine.model import Model | |
| from ultralytics.models import yolo | |
| from ultralytics.nn.tasks import ClassificationModel, DetectionModel, OBBModel, PoseModel, SegmentationModel, WorldModel | |
| from ultralytics.utils import yaml_load, ROOT | |
class YOLO(Model):
    """
    YOLO (You Only Look Once) object detection model.

    Depending on the model filename, construction may switch the new object's class to ``YOLOWorld``
    (``-world`` weights/configs) or ``YOLOv10`` (``yolov10`` in the filename) instead of running the
    default ``Model`` initialization.
    """

    def __init__(self, model="yolov8n.pt", task=None, verbose=False):
        """
        Initialize a YOLO model, dispatching to YOLOWorld or YOLOv10 based on the model filename.

        Args:
            model (str | Path): Model weights or config path. Defaults to 'yolov8n.pt'.
            task (str | None): Task forwarded to ``Model``. Only used on the default path; the
                YOLOWorld and YOLOv10 paths do not receive it.
            verbose (bool): Verbosity flag forwarded to ``Model``. Only used on the default path.
        """
        path = Path(model)
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model or config
            new_instance = YOLOWorld(path)
            # Morph this object into the freshly built instance: adopt its class and its full attribute state.
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        elif "yolov10" in path.stem:
            # Imported lazily; presumably to avoid a circular import with the top-level package — TODO confirm.
            from ultralytics import YOLOv10

            new_instance = YOLOv10(path)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)

    @property
    def task_map(self):
        """
        Map each task name to its model, trainer, validator, and predictor classes.

        Exposed as a property (as in upstream ultralytics) so it can be indexed directly,
        e.g. ``self.task_map[self.task]["model"]``; a plain method would not be subscriptable.

        Returns:
            (dict): ``{task: {"model": ..., "trainer": ..., "validator": ..., "predictor": ...}}``.
        """
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }
class YOLOWorld(Model):
    """YOLO-World open-vocabulary object detection model."""

    def __init__(self, model="yolov8s-world.pt") -> None:
        """
        Initializes the YOLOv8-World model with the given pre-trained model file. Supports *.pt and *.yaml formats.

        Args:
            model (str | Path): Path to the pre-trained model. Defaults to 'yolov8s-world.pt'.
        """
        super().__init__(model=model, task="detect")

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = yaml_load(ROOT / "cfg/datasets/coco8.yaml").get("names")

    @property
    def task_map(self):
        """
        Map the 'detect' task to its model, validator, and predictor classes.

        No trainer entry is registered here. Exposed as a property (as in upstream ultralytics)
        so it can be indexed directly, e.g. ``self.task_map["detect"]["model"]``.

        Returns:
            (dict): ``{"detect": {"model": ..., "validator": ..., "predictor": ...}}``.
        """
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            }
        }

    def set_classes(self, classes):
        """
        Set the class prompts used by the model.

        Args:
            classes (List[str]): A list of categories, e.g. ["person"]. A single-space entry " " is
                treated as a background prompt: the full sequence (including it) is passed to the
                underlying model, but it is excluded from the reported class names. The caller's
                sequence is not modified.
        """
        # The underlying model receives the prompts exactly as given, background included.
        self.model.set_classes(classes)

        # Build the visible names without the background prompt. A new list is created instead of
        # calling classes.remove(), which mutated the caller's argument and failed on tuples.
        background = " "
        names = [name for name in classes if name != background]
        self.model.names = names

        # Keep an existing predictor in sync so its results report the new names
        # (resetting self.predictor to None would be the alternative).
        if self.predictor:
            self.predictor.model.names = names