yikesongcai commited on
Commit
3c56989
·
verified ·
1 Parent(s): 9cefd5f

Upload model.py

Browse files
Files changed (1) hide show
  1. ultralytics/engine/model.py +465 -465
ultralytics/engine/model.py CHANGED
@@ -1,465 +1,465 @@
1
- # Ultralytics YOLO 🚀, AGPL-3.0 license
2
-
3
- import inspect
4
- import sys
5
- from pathlib import Path
6
- from typing import Union
7
-
8
- from ultralytics.cfg import get_cfg
9
- from ultralytics.engine.exporter import Exporter
10
- from ultralytics.hub.utils import HUB_WEB_ROOT
11
- from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
12
- from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
13
- is_git_dir, yaml_load)
14
- from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
15
- from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
16
- from ultralytics.utils.torch_utils import smart_inference_mode
17
-
18
-
19
- class Model:
20
- """
21
- A base model class to unify apis for all the models.
22
-
23
- Args:
24
- model (str, Path): Path to the model file to load or create.
25
- task (Any, optional): Task type for the YOLO model. Defaults to None.
26
-
27
- Attributes:
28
- predictor (Any): The predictor object.
29
- model (Any): The model object.
30
- trainer (Any): The trainer object.
31
- task (str): The type of model task.
32
- ckpt (Any): The checkpoint object if the model loaded from *.pt file.
33
- cfg (str): The model configuration if loaded from *.yaml file.
34
- ckpt_path (str): The checkpoint file path.
35
- overrides (dict): Overrides for the trainer object.
36
- metrics (Any): The data for metrics.
37
-
38
- Methods:
39
- __call__(source=None, stream=False, **kwargs):
40
- Alias for the predict method.
41
- _new(cfg:str, verbose:bool=True) -> None:
42
- Initializes a new model and infers the task type from the model definitions.
43
- _load(weights:str, task:str='') -> None:
44
- Initializes a new model and infers the task type from the model head.
45
- _check_is_pytorch_model() -> None:
46
- Raises TypeError if the model is not a PyTorch model.
47
- reset() -> None:
48
- Resets the model modules.
49
- info(verbose:bool=False) -> None:
50
- Logs the model info.
51
- fuse() -> None:
52
- Fuses the model for faster inference.
53
- predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
54
- Performs prediction using the YOLO model.
55
-
56
- Returns:
57
- list(ultralytics.engine.results.Results): The prediction results.
58
- """
59
-
60
- def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
61
- """
62
- Initializes the YOLO model.
63
-
64
- Args:
65
- model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
66
- task (Any, optional): Task type for the YOLO model. Defaults to None.
67
- """
68
- self.callbacks = callbacks.get_default_callbacks()
69
- self.predictor = None # reuse predictor
70
- self.model = None # model object
71
- self.trainer = None # trainer object
72
- self.ckpt = None # if loaded from *.pt
73
- self.cfg = None # if loaded from *.yaml
74
- self.ckpt_path = None
75
- self.overrides = {} # overrides for trainer object
76
- self.metrics = None # validation/training metrics
77
- self.session = None # HUB session
78
- self.task = task # task type
79
- model = str(model).strip() # strip spaces
80
-
81
- # Check if Ultralytics HUB model from https://hub.ultralytics.com
82
- if self.is_hub_model(model):
83
- from ultralytics.hub.session import HUBTrainingSession
84
- self.session = HUBTrainingSession(model)
85
- model = self.session.model_file
86
-
87
- # Load or create new YOLO model
88
- suffix = Path(model).suffix
89
- if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
90
- model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
91
- if suffix in ('.yaml', '.yml'):
92
- self._new(model, task)
93
- else:
94
- self._load(model, task)
95
-
96
- def __call__(self, source=None, stream=False, **kwargs):
97
- """Calls the 'predict' function with given arguments to perform object detection."""
98
- return self.predict(source, stream, **kwargs)
99
-
100
- @staticmethod
101
- def is_hub_model(model):
102
- """Check if the provided model is a HUB model."""
103
- return any((
104
- model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
105
- [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
106
- len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
107
-
108
- def _new(self, cfg: str, task=None, model=None, verbose=True):
109
- """
110
- Initializes a new model and infers the task type from the model definitions.
111
-
112
- Args:
113
- cfg (str): model configuration file
114
- task (str | None): model task
115
- model (BaseModel): Customized model.
116
- verbose (bool): display model info on load
117
- """
118
- cfg_dict = yaml_model_load(cfg)
119
- self.cfg = cfg
120
- self.task = task or guess_model_task(cfg_dict)
121
- model = model or self.smart_load('model')
122
- self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
123
- self.overrides['model'] = self.cfg
124
-
125
- # Below added to allow export from yamls
126
- args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
127
- self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
128
- self.model.task = self.task
129
-
130
- def _load(self, weights: str, task=None):
131
- """
132
- Initializes a new model and infers the task type from the model head.
133
-
134
- Args:
135
- weights (str): model checkpoint to be loaded
136
- task (str | None): model task
137
- """
138
- suffix = Path(weights).suffix
139
- if suffix == '.pt':
140
- self.model, self.ckpt = attempt_load_one_weight(weights)
141
- self.task = self.model.args['task']
142
- self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
143
- self.ckpt_path = self.model.pt_path
144
- else:
145
- weights = check_file(weights)
146
- self.model, self.ckpt = weights, None
147
- self.task = task or guess_model_task(weights)
148
- self.ckpt_path = weights
149
- self.overrides['model'] = weights
150
- self.overrides['task'] = self.task
151
-
152
- def _check_is_pytorch_model(self):
153
- """
154
- Raises TypeError is model is not a PyTorch model
155
- """
156
- pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
157
- pt_module = isinstance(self.model, nn.Module)
158
- if not (pt_module or pt_str):
159
- raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
160
- f'PyTorch models can be used to train, val, predict and export, i.e. '
161
- f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
162
- f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
163
-
164
- @smart_inference_mode()
165
- def reset_weights(self):
166
- """
167
- Resets the model modules parameters to randomly initialized values, losing all training information.
168
- """
169
- self._check_is_pytorch_model()
170
- for m in self.model.modules():
171
- if hasattr(m, 'reset_parameters'):
172
- m.reset_parameters()
173
- for p in self.model.parameters():
174
- p.requires_grad = True
175
- return self
176
-
177
- @smart_inference_mode()
178
- def load(self, weights='yolov8n.pt'):
179
- """
180
- Transfers parameters with matching names and shapes from 'weights' to model.
181
- """
182
- self._check_is_pytorch_model()
183
- if isinstance(weights, (str, Path)):
184
- weights, self.ckpt = attempt_load_one_weight(weights)
185
- self.model.load(weights)
186
- return self
187
-
188
- def info(self, detailed=False, verbose=True):
189
- """
190
- Logs model info.
191
-
192
- Args:
193
- detailed (bool): Show detailed information about model.
194
- verbose (bool): Controls verbosity.
195
- """
196
- self._check_is_pytorch_model()
197
- return self.model.info(detailed=detailed, verbose=verbose)
198
-
199
- def fuse(self):
200
- """Fuse PyTorch Conv2d and BatchNorm2d layers."""
201
- self._check_is_pytorch_model()
202
- self.model.fuse()
203
-
204
- @smart_inference_mode()
205
- def predict(self, source=None, stream=False, predictor=None, **kwargs):
206
- """
207
- Perform prediction using the YOLO model.
208
-
209
- Args:
210
- source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
211
- Accepts all source types accepted by the YOLO model.
212
- stream (bool): Whether to stream the predictions or not. Defaults to False.
213
- predictor (BasePredictor): Customized predictor.
214
- **kwargs : Additional keyword arguments passed to the predictor.
215
- Check the 'configuration' section in the documentation for all available options.
216
-
217
- Returns:
218
- (List[ultralytics.engine.results.Results]): The prediction results.
219
- """
220
- if source is None:
221
- source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
222
- LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
223
- is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
224
- x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
225
- # Check prompts for SAM/FastSAM
226
- prompts = kwargs.pop('prompts', None)
227
- overrides = self.overrides.copy()
228
- overrides['conf'] = 0.25
229
- overrides.update(kwargs) # prefer kwargs
230
- overrides['mode'] = kwargs.get('mode', 'predict')
231
- assert overrides['mode'] in ['track', 'predict']
232
- if not is_cli:
233
- overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
234
- if not self.predictor:
235
- self.task = overrides.get('task') or self.task
236
- predictor = predictor or self.smart_load('predictor')
237
- self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
238
- self.predictor.setup_model(model=self.model, verbose=is_cli)
239
- else: # only update args if predictor is already setup
240
- self.predictor.args = get_cfg(self.predictor.args, overrides)
241
- if 'project' in overrides or 'name' in overrides:
242
- self.predictor.save_dir = self.predictor.get_save_dir()
243
- # Set prompts for SAM/FastSAM
244
- if len and hasattr(self.predictor, 'set_prompts'):
245
- self.predictor.set_prompts(prompts)
246
- return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
247
-
248
- def track(self, source=None, stream=False, persist=False, **kwargs):
249
- """
250
- Perform object tracking on the input source using the registered trackers.
251
-
252
- Args:
253
- source (str, optional): The input source for object tracking. Can be a file path or a video stream.
254
- stream (bool, optional): Whether the input source is a video stream. Defaults to False.
255
- persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
256
- **kwargs (optional): Additional keyword arguments for the tracking process.
257
-
258
- Returns:
259
- (List[ultralytics.engine.results.Results]): The tracking results.
260
-
261
- """
262
- if not hasattr(self.predictor, 'trackers'):
263
- from ultralytics.trackers import register_tracker
264
- register_tracker(self, persist)
265
- # ByteTrack-based method needs low confidence predictions as input
266
- conf = kwargs.get('conf') or 0.1
267
- kwargs['conf'] = conf
268
- kwargs['mode'] = 'track'
269
- return self.predict(source=source, stream=stream, **kwargs)
270
-
271
- @smart_inference_mode()
272
- def val(self, data=None, validator=None, **kwargs):
273
- """
274
- Validate a model on a given dataset.
275
-
276
- Args:
277
- data (str): The dataset to validate on. Accepts all formats accepted by yolo
278
- validator (BaseValidator): Customized validator.
279
- **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
280
- """
281
- overrides = self.overrides.copy()
282
- overrides['rect'] = True # rect batches as default
283
- overrides.update(kwargs)
284
- overrides['mode'] = 'val'
285
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
286
- args.data = data or args.data
287
- if 'task' in overrides:
288
- self.task = args.task
289
- else:
290
- args.task = self.task
291
- validator = validator or self.smart_load('validator')
292
- if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
293
- args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
294
- args.imgsz = check_imgsz(args.imgsz, max_dim=1)
295
-
296
- validator = validator(args=args, _callbacks=self.callbacks)
297
- validator(model=self.model)
298
- self.metrics = validator.metrics
299
-
300
- return validator.metrics
301
-
302
- @smart_inference_mode()
303
- def benchmark(self, **kwargs):
304
- """
305
- Benchmark a model on all export formats.
306
-
307
- Args:
308
- **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
309
- """
310
- self._check_is_pytorch_model()
311
- from ultralytics.utils.benchmarks import benchmark
312
- overrides = self.model.args.copy()
313
- overrides.update(kwargs)
314
- overrides['mode'] = 'benchmark'
315
- overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
316
- return benchmark(
317
- model=self,
318
- data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
319
- imgsz=overrides['imgsz'],
320
- half=overrides['half'],
321
- int8=overrides['int8'],
322
- device=overrides['device'],
323
- verbose=overrides['verbose'])
324
-
325
- def export(self, **kwargs):
326
- """
327
- Export model.
328
-
329
- Args:
330
- **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
331
- """
332
- self._check_is_pytorch_model()
333
- overrides = self.overrides.copy()
334
- overrides.update(kwargs)
335
- overrides['mode'] = 'export'
336
- if overrides.get('imgsz') is None:
337
- overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
338
- if 'batch' not in kwargs:
339
- overrides['batch'] = 1 # default to 1 if not modified
340
- if 'data' not in kwargs:
341
- overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
342
- args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
343
- args.task = self.task
344
- return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
345
-
346
- def train(self, trainer=None, **kwargs):
347
- """
348
- Trains the model on a given dataset.
349
-
350
- Args:
351
- trainer (BaseTrainer, optional): Customized trainer.
352
- **kwargs (Any): Any number of arguments representing the training configuration.
353
- """
354
- self._check_is_pytorch_model()
355
- if self.session: # Ultralytics HUB session
356
- if any(kwargs):
357
- LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
358
- kwargs = self.session.train_args
359
- check_pip_update_available()
360
- overrides = self.overrides.copy()
361
- if kwargs.get('cfg'):
362
- LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
363
- overrides = yaml_load(check_yaml(kwargs['cfg']))
364
- overrides.update(kwargs)
365
- overrides['mode'] = 'train'
366
- if not overrides.get('data'):
367
- raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
368
- if overrides.get('resume'):
369
- overrides['resume'] = self.ckpt_path
370
- self.task = overrides.get('task') or self.task
371
- trainer = trainer or self.smart_load('trainer')
372
- self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
373
- if not overrides.get('resume'): # manually set model only if not resuming
374
- self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
375
- self.model = self.trainer.model
376
- self.trainer.hub_session = self.session # attach optional HUB session
377
- self.trainer.train()
378
- # Update model and cfg after training
379
- if RANK in (-1, 0):
380
- self.model, _ = attempt_load_one_weight(str(self.trainer.best))
381
- self.overrides = self.model.args
382
- self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
383
-
384
- def to(self, device):
385
- """
386
- Sends the model to the given device.
387
-
388
- Args:
389
- device (str): device
390
- """
391
- self._check_is_pytorch_model()
392
- self.model.to(device)
393
-
394
- def tune(self, *args, **kwargs):
395
- """
396
- Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
397
-
398
- Returns:
399
- (dict): A dictionary containing the results of the hyperparameter search.
400
-
401
- Raises:
402
- ModuleNotFoundError: If Ray Tune is not installed.
403
- """
404
- self._check_is_pytorch_model()
405
- from ultralytics.utils.tuner import run_ray_tune
406
- return run_ray_tune(self, *args, **kwargs)
407
-
408
- @property
409
- def names(self):
410
- """Returns class names of the loaded model."""
411
- return self.model.names if hasattr(self.model, 'names') else None
412
-
413
- @property
414
- def device(self):
415
- """Returns device if PyTorch model."""
416
- return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
417
-
418
- @property
419
- def transforms(self):
420
- """Returns transform of the loaded model."""
421
- return self.model.transforms if hasattr(self.model, 'transforms') else None
422
-
423
- def add_callback(self, event: str, func):
424
- """Add a callback."""
425
- self.callbacks[event].append(func)
426
-
427
- def clear_callback(self, event: str):
428
- """Clear all event callbacks."""
429
- self.callbacks[event] = []
430
-
431
- @staticmethod
432
- def _reset_ckpt_args(args):
433
- """Reset arguments when loading a PyTorch model."""
434
- include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
435
- return {k: v for k, v in args.items() if k in include}
436
-
437
- def _reset_callbacks(self):
438
- """Reset all registered callbacks."""
439
- for event in callbacks.default_callbacks.keys():
440
- self.callbacks[event] = [callbacks.default_callbacks[event][0]]
441
-
442
- def __getattr__(self, attr):
443
- """Raises error if object has no requested attribute."""
444
- name = self.__class__.__name__
445
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
446
-
447
- def smart_load(self, key):
448
- """Load model/trainer/validator/predictor."""
449
- try:
450
- return self.task_map[self.task][key]
451
- except Exception:
452
- name = self.__class__.__name__
453
- mode = inspect.stack()[1][3] # get the function name.
454
- raise NotImplementedError(
455
- f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
456
-
457
- @property
458
- def task_map(self):
459
- """
460
- Map head to model, trainer, validator, and predictor classes.
461
-
462
- Returns:
463
- task_map (dict): The map of model task to mode classes.
464
- """
465
- raise NotImplementedError('Please provide task map for your model!')
 
1
+ # Ultralytics YOLO 🚀, AGPL-3.0 license
2
+
3
+ import inspect
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import Union
7
+
8
+ from ultralytics.cfg import get_cfg
9
+ from ultralytics.engine.exporter import Exporter
10
+ from ultralytics.hub.utils import HUB_WEB_ROOT
11
+ from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
12
+ from ultralytics.utils import (DEFAULT_CFG, DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, ROOT, callbacks,
13
+ is_git_dir, yaml_load)
14
+ from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
15
+ from ultralytics.utils.downloads import GITHUB_ASSET_STEMS
16
+ from ultralytics.utils.torch_utils import smart_inference_mode
17
+
18
+
19
+ class Model:
20
+ """
21
+ A base model class to unify apis for all the models.
22
+
23
+ Args:
24
+ model (str, Path): Path to the model file to load or create.
25
+ task (Any, optional): Task type for the YOLO model. Defaults to None.
26
+
27
+ Attributes:
28
+ predictor (Any): The predictor object.
29
+ model (Any): The model object.
30
+ trainer (Any): The trainer object.
31
+ task (str): The type of model task.
32
+ ckpt (Any): The checkpoint object if the model loaded from *.pt file.
33
+ cfg (str): The model configuration if loaded from *.yaml file.
34
+ ckpt_path (str): The checkpoint file path.
35
+ overrides (dict): Overrides for the trainer object.
36
+ metrics (Any): The data for metrics.
37
+
38
+ Methods:
39
+ __call__(source=None, stream=False, **kwargs):
40
+ Alias for the predict method.
41
+ _new(cfg:str, verbose:bool=True) -> None:
42
+ Initializes a new model and infers the task type from the model definitions.
43
+ _load(weights:str, task:str='') -> None:
44
+ Initializes a new model and infers the task type from the model head.
45
+ _check_is_pytorch_model() -> None:
46
+ Raises TypeError if the model is not a PyTorch model.
47
+ reset() -> None:
48
+ Resets the model modules.
49
+ info(verbose:bool=False) -> None:
50
+ Logs the model info.
51
+ fuse() -> None:
52
+ Fuses the model for faster inference.
53
+ predict(source=None, stream=False, **kwargs) -> List[ultralytics.engine.results.Results]:
54
+ Performs prediction using the YOLO model.
55
+
56
+ Returns:
57
+ list(ultralytics.engine.results.Results): The prediction results.
58
+ """
59
+
60
+ def __init__(self, model: Union[str, Path] = 'yolov8n.pt', task=None) -> None:
61
+ """
62
+ Initializes the YOLO model.
63
+
64
+ Args:
65
+ model (Union[str, Path], optional): Path or name of the model to load or create. Defaults to 'yolov8n.pt'.
66
+ task (Any, optional): Task type for the YOLO model. Defaults to None.
67
+ """
68
+ self.callbacks = callbacks.get_default_callbacks()
69
+ self.predictor = None # reuse predictor
70
+ self.model = None # model object
71
+ self.trainer = None # trainer object
72
+ self.ckpt = None # if loaded from *.pt
73
+ self.cfg = None # if loaded from *.yaml
74
+ self.ckpt_path = None
75
+ self.overrides = {} # overrides for trainer object
76
+ self.metrics = None # validation/training metrics
77
+ self.session = None # HUB session
78
+ self.task = task # task type
79
+ model = str(model).strip() # strip spaces
80
+
81
+ # Check if Ultralytics HUB model from https://hub.ultralytics.com
82
+ if self.is_hub_model(model):
83
+ from ultralytics.hub.session import HUBTrainingSession
84
+ self.session = HUBTrainingSession(model)
85
+ model = self.session.model_file
86
+
87
+ # Load or create new YOLO model
88
+ suffix = Path(model).suffix
89
+ if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
90
+ model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
91
+ if suffix in ('.yaml', '.yml'):
92
+ self._new(model, task)
93
+ else:
94
+ self._load(model, task)
95
+
96
+ def __call__(self, source=None, stream=False, **kwargs):
97
+ """Calls the 'predict' function with given arguments to perform object detection."""
98
+ return self.predict(source, stream, **kwargs)
99
+
100
+ @staticmethod
101
+ def is_hub_model(model):
102
+ """Check if the provided model is a HUB model."""
103
+ return any((
104
+ model.startswith(f'{HUB_WEB_ROOT}/models/'), # i.e. https://hub.ultralytics.com/models/MODEL_ID
105
+ [len(x) for x in model.split('_')] == [42, 20], # APIKEY_MODELID
106
+ len(model) == 20 and not Path(model).exists() and all(x not in model for x in './\\'))) # MODELID
107
+
108
+ def _new(self, cfg: str, task=None, model=None, verbose=True):
109
+ """
110
+ Initializes a new model and infers the task type from the model definitions.
111
+
112
+ Args:
113
+ cfg (str): model configuration file
114
+ task (str | None): model task
115
+ model (BaseModel): Customized model.
116
+ verbose (bool): display model info on load
117
+ """
118
+ cfg_dict = yaml_model_load(cfg)
119
+ self.cfg = cfg
120
+ self.task = task or guess_model_task(cfg_dict)
121
+ model = model or self.smart_load('model')
122
+ self.model = model(cfg_dict, verbose=verbose and RANK == -1) # build model
123
+ self.overrides['model'] = self.cfg
124
+
125
+ # Below added to allow export from yamls
126
+ args = {**DEFAULT_CFG_DICT, **self.overrides} # combine model and default args, preferring model args
127
+ self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
128
+ self.model.task = self.task
129
+
130
+ def _load(self, weights: str, task=None):
131
+ """
132
+ Initializes a new model and infers the task type from the model head.
133
+
134
+ Args:
135
+ weights (str): model checkpoint to be loaded
136
+ task (str | None): model task
137
+ """
138
+ suffix = Path(weights).suffix
139
+ if suffix == '.pt':
140
+ self.model, self.ckpt = attempt_load_one_weight(weights)
141
+ self.task = self.model.args['task']
142
+ self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
143
+ self.ckpt_path = self.model.pt_path
144
+ else:
145
+ weights = check_file(weights)
146
+ self.model, self.ckpt = weights, None
147
+ self.task = task or guess_model_task(weights)
148
+ self.ckpt_path = weights
149
+ self.overrides['model'] = weights
150
+ self.overrides['task'] = self.task
151
+
152
+ def _check_is_pytorch_model(self):
153
+ """
154
+ Raises TypeError is model is not a PyTorch model
155
+ """
156
+ pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == '.pt'
157
+ pt_module = isinstance(self.model, nn.Module)
158
+ if not (pt_module or pt_str):
159
+ raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
160
+ f'PyTorch models can be used to train, val, predict and export, i.e. '
161
+ f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
162
+ f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
163
+
164
+ @smart_inference_mode()
165
+ def reset_weights(self):
166
+ """
167
+ Resets the model modules parameters to randomly initialized values, losing all training information.
168
+ """
169
+ self._check_is_pytorch_model()
170
+ for m in self.model.modules():
171
+ if hasattr(m, 'reset_parameters'):
172
+ m.reset_parameters()
173
+ for p in self.model.parameters():
174
+ p.requires_grad = True
175
+ return self
176
+
177
+ @smart_inference_mode()
178
+ def load(self, weights='yolov8n.pt'):
179
+ """
180
+ Transfers parameters with matching names and shapes from 'weights' to model.
181
+ """
182
+ self._check_is_pytorch_model()
183
+ if isinstance(weights, (str, Path)):
184
+ weights, self.ckpt = attempt_load_one_weight(weights)
185
+ self.model.load(weights)
186
+ return self
187
+
188
+ def info(self, detailed=False, verbose=True):
189
+ """
190
+ Logs model info.
191
+
192
+ Args:
193
+ detailed (bool): Show detailed information about model.
194
+ verbose (bool): Controls verbosity.
195
+ """
196
+ self._check_is_pytorch_model()
197
+ return self.model.info(detailed=detailed, verbose=verbose)
198
+
199
+ def fuse(self):
200
+ """Fuse PyTorch Conv2d and BatchNorm2d layers."""
201
+ self._check_is_pytorch_model()
202
+ self.model.fuse()
203
+
204
+ @smart_inference_mode()
205
+ def predict(self, source=None, stream=False, predictor=None, **kwargs):
206
+ """
207
+ Perform prediction using the YOLO model.
208
+
209
+ Args:
210
+ source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
211
+ Accepts all source types accepted by the YOLO model.
212
+ stream (bool): Whether to stream the predictions or not. Defaults to False.
213
+ predictor (BasePredictor): Customized predictor.
214
+ **kwargs : Additional keyword arguments passed to the predictor.
215
+ Check the 'configuration' section in the documentation for all available options.
216
+
217
+ Returns:
218
+ (List[ultralytics.engine.results.Results]): The prediction results.
219
+ """
220
+ if source is None:
221
+ source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg'
222
+ LOGGER.warning(f"WARNING ⚠️ 'source' is missing. Using 'source={source}'.")
223
+ is_cli = (sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')) and any(
224
+ x in sys.argv for x in ('predict', 'track', 'mode=predict', 'mode=track'))
225
+ # Check prompts for SAM/FastSAM
226
+ prompts = kwargs.pop('prompts', None)
227
+ overrides = self.overrides.copy()
228
+ overrides['conf'] = 0.25
229
+ overrides.update(kwargs) # prefer kwargs
230
+ overrides['mode'] = kwargs.get('mode', 'predict')
231
+ assert overrides['mode'] in ['track', 'predict']
232
+ if not is_cli:
233
+ overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python
234
+ if not self.predictor:
235
+ self.task = overrides.get('task') or self.task
236
+ predictor = predictor or self.smart_load('predictor')
237
+ self.predictor = predictor(overrides=overrides, _callbacks=self.callbacks)
238
+ self.predictor.setup_model(model=self.model, verbose=is_cli)
239
+ else: # only update args if predictor is already setup
240
+ self.predictor.args = get_cfg(self.predictor.args, overrides)
241
+ if 'project' in overrides or 'name' in overrides:
242
+ self.predictor.save_dir = self.predictor.get_save_dir()
243
+ # Set prompts for SAM/FastSAM
244
+ if len and hasattr(self.predictor, 'set_prompts'):
245
+ self.predictor.set_prompts(prompts)
246
+ return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
247
+
248
+ def track(self, source=None, stream=False, persist=False, **kwargs):
249
+ """
250
+ Perform object tracking on the input source using the registered trackers.
251
+
252
+ Args:
253
+ source (str, optional): The input source for object tracking. Can be a file path or a video stream.
254
+ stream (bool, optional): Whether the input source is a video stream. Defaults to False.
255
+ persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False.
256
+ **kwargs (optional): Additional keyword arguments for the tracking process.
257
+
258
+ Returns:
259
+ (List[ultralytics.engine.results.Results]): The tracking results.
260
+
261
+ """
262
+ if not hasattr(self.predictor, 'trackers'):
263
+ from ultralytics.trackers import register_tracker
264
+ register_tracker(self, persist)
265
+ # ByteTrack-based method needs low confidence predictions as input
266
+ conf = kwargs.get('conf') or 0.1
267
+ kwargs['conf'] = conf
268
+ kwargs['mode'] = 'track'
269
+ return self.predict(source=source, stream=stream, **kwargs)
270
+
271
+ @smart_inference_mode()
272
+ def val(self, data=None, validator=None, **kwargs):
273
+ """
274
+ Validate a model on a given dataset.
275
+
276
+ Args:
277
+ data (str): The dataset to validate on. Accepts all formats accepted by yolo
278
+ validator (BaseValidator): Customized validator.
279
+ **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
280
+ """
281
+ overrides = self.overrides.copy()
282
+ overrides['rect'] = True # rect batches as default
283
+ overrides.update(kwargs)
284
+ overrides['mode'] = 'val'
285
+ args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
286
+ args.data = data or args.data
287
+ if 'task' in overrides:
288
+ self.task = args.task
289
+ else:
290
+ args.task = self.task
291
+ validator = validator or self.smart_load('validator')
292
+ if args.imgsz == DEFAULT_CFG.imgsz and not isinstance(self.model, (str, Path)):
293
+ args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
294
+ args.imgsz = check_imgsz(args.imgsz, max_dim=1)
295
+
296
+ validator = validator(args=args, _callbacks=self.callbacks)
297
+ validator(model=self.model)
298
+ self.metrics = validator.metrics
299
+
300
+ return validator.metrics
301
+
302
+ @smart_inference_mode()
303
+ def benchmark(self, **kwargs):
304
+ """
305
+ Benchmark a model on all export formats.
306
+
307
+ Args:
308
+ **kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
309
+ """
310
+ self._check_is_pytorch_model()
311
+ from ultralytics.utils.benchmarks import benchmark
312
+ overrides = self.model.args.copy()
313
+ overrides.update(kwargs)
314
+ overrides['mode'] = 'benchmark'
315
+ overrides = {**DEFAULT_CFG_DICT, **overrides} # fill in missing overrides keys with defaults
316
+ return benchmark(
317
+ model=self,
318
+ data=kwargs.get('data'), # if no 'data' argument passed set data=None for default datasets
319
+ imgsz=overrides['imgsz'],
320
+ half=overrides['half'],
321
+ int8=overrides['int8'],
322
+ device=overrides['device'],
323
+ verbose=overrides['verbose'])
324
+
325
+ def export(self, **kwargs):
326
+ """
327
+ Export model.
328
+
329
+ Args:
330
+ **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs
331
+ """
332
+ self._check_is_pytorch_model()
333
+ overrides = self.overrides.copy()
334
+ overrides.update(kwargs)
335
+ overrides['mode'] = 'export'
336
+ if overrides.get('imgsz') is None:
337
+ overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
338
+ if 'batch' not in kwargs:
339
+ overrides['batch'] = 1 # default to 1 if not modified
340
+ if 'data' not in kwargs:
341
+ overrides['data'] = None # default to None if not modified (avoid int8 calibration with coco.yaml)
342
+ args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
343
+ args.task = self.task
344
+ return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
345
+
346
+ def train(self, trainer=None, **kwargs):
347
+ """
348
+ Trains the model on a given dataset.
349
+
350
+ Args:
351
+ trainer (BaseTrainer, optional): Customized trainer.
352
+ **kwargs (Any): Any number of arguments representing the training configuration.
353
+ """
354
+ self._check_is_pytorch_model()
355
+ if self.session: # Ultralytics HUB session
356
+ if any(kwargs):
357
+ LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
358
+ kwargs = self.session.train_args
359
+ check_pip_update_available()
360
+ overrides = self.overrides.copy()
361
+ if kwargs.get('cfg'):
362
+ LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
363
+ overrides = yaml_load(check_yaml(kwargs['cfg']))
364
+ overrides.update(kwargs)
365
+ overrides['mode'] = 'train'
366
+ if not overrides.get('data'):
367
+ raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
368
+ if overrides.get('resume'):
369
+ overrides['resume'] = self.ckpt_path
370
+ self.task = overrides.get('task') or self.task
371
+ trainer = trainer or self.smart_load('trainer')
372
+ self.trainer = trainer(overrides=overrides, _callbacks=self.callbacks)
373
+ if not overrides.get('resume'): # manually set model only if not resuming
374
+ self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
375
+ self.model = self.trainer.model
376
+ self.trainer.hub_session = self.session # attach optional HUB session
377
+ self.trainer.train()
378
+ # Update model and cfg after training
379
+ if RANK in (-1, 0):
380
+ self.model, _ = attempt_load_one_weight(str(self.trainer.best))
381
+ self.overrides = self.model.args
382
+ self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
383
+
384
+ def to(self, device):
385
+ """
386
+ Sends the model to the given device.
387
+
388
+ Args:
389
+ device (str): device
390
+ """
391
+ self._check_is_pytorch_model()
392
+ self.model.to(device)
393
+
394
+ def tune(self, *args, **kwargs):
395
+ """
396
+ Runs hyperparameter tuning using Ray Tune. See ultralytics.utils.tuner.run_ray_tune for Args.
397
+
398
+ Returns:
399
+ (dict): A dictionary containing the results of the hyperparameter search.
400
+
401
+ Raises:
402
+ ModuleNotFoundError: If Ray Tune is not installed.
403
+ """
404
+ self._check_is_pytorch_model()
405
+ from ultralytics.utils.tuner import run_ray_tune
406
+ return run_ray_tune(self, *args, **kwargs)
407
+
408
+ @property
409
+ def names(self):
410
+ """Returns class names of the loaded model."""
411
+ return self.model.names if hasattr(self.model, 'names') else None
412
+
413
+ @property
414
+ def device(self):
415
+ """Returns device if PyTorch model."""
416
+ return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
417
+
418
+ @property
419
+ def transforms(self):
420
+ """Returns transform of the loaded model."""
421
+ return self.model.transforms if hasattr(self.model, 'transforms') else None
422
+
423
+ def add_callback(self, event: str, func):
424
+ """Add a callback."""
425
+ self.callbacks[event].append(func)
426
+
427
+ def clear_callback(self, event: str):
428
+ """Clear all event callbacks."""
429
+ self.callbacks[event] = []
430
+
431
+ @staticmethod
432
+ def _reset_ckpt_args(args):
433
+ """Reset arguments when loading a PyTorch model."""
434
+ include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
435
+ return {k: v for k, v in args.items() if k in include}
436
+
437
+ def _reset_callbacks(self):
438
+ """Reset all registered callbacks."""
439
+ for event in callbacks.default_callbacks.keys():
440
+ self.callbacks[event] = [callbacks.default_callbacks[event][0]]
441
+
442
+ def __getattr__(self, attr):
443
+ """Raises error if object has no requested attribute."""
444
+ name = self.__class__.__name__
445
+ raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
446
+
447
+ def smart_load(self, key):
448
+ """Load model/trainer/validator/predictor."""
449
+ try:
450
+ return self.task_map[self.task][key]
451
+ except Exception:
452
+ name = self.__class__.__name__
453
+ mode = inspect.stack()[1][3] # get the function name.
454
+ raise NotImplementedError(
455
+ f'WARNING ⚠️ `{name}` model does not support `{mode}` mode for `{self.task}` task yet.')
456
+
457
+ @property
458
+ def task_map(self):
459
+ """
460
+ Map head to model, trainer, validator, and predictor classes.
461
+
462
+ Returns:
463
+ task_map (dict): The map of model task to mode classes.
464
+ """
465
+ raise NotImplementedError('Please provide task map for your model!')