# Spaces: Runtime error
# (status banner captured along with this script; it is not part of the program)
import os

# Point HOME at the current working directory *before* ultralytics is imported,
# so anything the library writes under the home directory ends up in the
# project tree instead (presumably ultralytics derives its user config/settings
# location from HOME at import time -- confirm for the installed version).
# Trust me, you don't want to flood your meager home dir.
os.environ['HOME'] = os.getcwd()
import sys
import argparse
from ultralytics import settings  # Quite frankly, this doesn't seem to do anything, but another failsafe wouldn't hurt.

# Send datasets, weights and run outputs to relative directories; these are
# assumed to resolve against the working directory -- TODO confirm.
settings.update({"datasets_dir":"data/", "weights_dir": "weights/", "runs_dir": "runs/"})
# Parse command line arguments. Each option is forwarded to YOLOWorld.train()
# further down in the script.
parser = argparse.ArgumentParser(description="Train YOLOWorld model with specified parameters.")
parser.add_argument("--data", type=str, required=True, help="Path to the dataset YAML file.")
parser.add_argument("--project", type=str, default="runs/train", help="Project directory for training results.")
parser.add_argument("--name", type=str, default="yoloworld_train_", help="Name of the training run.")
parser.add_argument("--epochs", type=int, default=400, help="Number of training epochs.")
# float (not int) so fractional budgets such as 0.5 hours are accepted; whole
# numbers parse exactly as before. If time is set, it overrides epochs.
parser.add_argument("--time", type=float, default=0, help="Maximum training time in hours (0 for no limit).")
parser.add_argument("--patience", type=int, default=50, help="Number of epochs with no improvement before early stopping. Default is 50.")
parser.add_argument("--batch", type=int, default=-1, help="Batch size for training. -1 means adaptive, default.")
parser.add_argument("--workers", type=int, default=8, help="Number of workers for data loading.")
# The three --no* flags are inverted later: caching, checkpointing and plots
# are ON unless the corresponding flag is given.
parser.add_argument("--nocache", action="store_true", help="Disable image caching (caching speeds up training).")
parser.add_argument("--nosave", action="store_true", help="Don't save model checkpoints during training.")
parser.add_argument("--save_period", type=int, default=10, help="Save model every N epochs. Default is 10.")
parser.add_argument("--noplots", action="store_true", help="Don't save training plots.")
# 0 (the default) means imgsz is not passed to training at all, leaving the
# image size to Ultralytics' own default / checkpoint setting.
parser.add_argument("--imgsz", type=int, default=0, help="Image size for training.")
# store_true: rectangular training stays OFF unless explicitly requested.
parser.add_argument("--rect", action="store_true", help="Use rectangular training images.")
args = parser.parse_args()
from ultralytics import YOLOWorld
import torch

# Translate the parsed CLI options into YOLOWorld.train() keyword arguments.
# There's probably a better way to do this, but it's readable and easy to change.
#   * --nocache / --nosave / --noplots are inverted into cache / save / plots.
#   * time, imgsz and save_period resolve to None when disabled (time <= 0,
#     imgsz <= 0, or --nosave); every None entry is dropped so the library
#     falls back to its own default instead of receiving None.
#   * classes is a fixed training-class filter; if the dataset uses COCO ids
#     these are person, bicycle, car, motorcycle, bus, truck, stop sign --
#     confirm against the dataset YAML.
kwargs = {
    key: value
    for key, value in (
        ("data", args.data),
        ("project", args.project),
        ("name", args.name),
        ("epochs", args.epochs),
        ("time", args.time if args.time > 0 else None),
        ("patience", args.patience),
        ("batch", args.batch),
        ("workers", args.workers),
        ("cache", not args.nocache),
        ("save", not args.nosave),
        ("plots", not args.noplots),
        ("save_period", None if args.nosave else args.save_period),
        ("imgsz", args.imgsz if args.imgsz > 0 else None),
        ("rect", args.rect),
        ("classes", [0, 1, 2, 3, 5, 7, 11]),
    )
    if value is not None
}
print(f"Training with parameters: {kwargs}")
def _read_class_names(path):
    """Return the class names listed in *path*, in file order.

    Useful lines have the form ``<id>: <name>`` (see the file for the exact
    format). Only the text after the first ``": "`` is kept and stripped; the
    id itself is ignored, so the file is assumed to be ordered by id. Lines
    with no separator, or with nothing after it (e.g. ``"5: "``), are skipped
    instead of raising IndexError.

    Raises:
        ValueError: if the file yields no class names at all, which would
            otherwise hand an empty vocabulary to ``set_classes``.
    """
    names = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            # Strip BEFORE looking for the separator so a trailing ": " that
            # strip() removes can never produce a missing right-hand side.
            _, sep, name = raw.strip().partition(": ")
            name = name.strip()
            if sep and name:
                names.append(name)
    if not names:
        raise ValueError(f"No '<id>: <name>' entries found in {path!r}")
    return names


# Read the vocabulary first so a missing or malformed class list fails fast,
# before the large pretrained checkpoint is loaded.
class_names = _read_class_names("assets/coco_class_list.txt")
print(class_names)
model = YOLOWorld("yolov8x-worldv2.pt")
model.set_classes(class_names)
# Report which CUDA devices PyTorch can see before training starts.
if not torch.cuda.is_available():
    print("No GPUs available.")
else:
    num_gpus = torch.cuda.device_count()
    print(f"Available GPUs ({num_gpus}):")
    for i in range(num_gpus):
        print(f" [{i}] {torch.cuda.get_device_name(i)}")
results = model.train(**kwargs)

# Export the model after training. ONNX export is always attempted. TensorRT
# ("engine") export needs a CUDA device, so on a CPU-only host it is skipped
# with a message. Otherwise it would raise after a possibly very long training
# run and the validation step below would never execute.
model.export(format="onnx", dynamic=True)
if torch.cuda.is_available():
    model.export(format="engine", dynamic=True)
else:
    print("No CUDA GPU available; skipping TensorRT (engine) export.")

# Validate the model on the same dataset, storing results under the same
# project / run name used for training.
metrics = model.val(data=args.data, project=args.project, name=args.name, rect=args.rect)
print(metrics)