| |
| |
|
|
| print("π Starting Court Keypoint Setup...") |
|
|
| |
| !pip install -q ultralytics roboflow |
|
|
| |
| import os |
| import shutil |
| import yaml |
| import torch |
| from roboflow import Roboflow |
| from google.colab import drive |
|
|
| |
# Report whether PyTorch can see a CUDA device.
# This only prints a message; it deliberately does not abort, so the rest of
# the cell still runs when no GPU is present.
if not torch.cuda.is_available():
    print("❌ ERROR: GPU not detected! Go to 'Runtime' -> 'Change runtime type' and select 'T4 GPU' before running.")
else:
    # Device index 0 is the first visible CUDA device.
    print("✅ GPU detected: " + torch.cuda.get_device_name(0))

# Mount Google Drive at /content/drive; the trained weights are copied to a
# path under this mount point at the end of the script.
drive.mount('/content/drive')
|
|
| |
| |
| |
# Download version 1 of the Roboflow project in "yolov8" export format.
#
# SECURITY: a Roboflow API key was previously hard-coded on this line. Any key
# that has been stored in a notebook or repository should be treated as leaked:
# revoke/rotate it in the Roboflow dashboard. The key is now read from the
# ROBOFLOW_API_KEY environment variable, falling back to an interactive prompt
# that does not echo the typed value.
from getpass import getpass

ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY") or getpass("Roboflow API key: ")
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("fyp-3bwmg").project("reloc2-den7l")
version = project.version(1)
# `dataset.location` (the download directory) is used by the later steps.
dataset = version.download("yolov8")
|
|
| |
def fix_data_yaml(dataset_location):
    """Rewrite the split paths inside ``<dataset_location>/data.yaml`` in place.

    The ``train`` and ``val`` entries are always set to
    ``<dataset_location>/train/images`` and ``<dataset_location>/valid/images``.
    The ``test`` entry is set to ``<dataset_location>/test/images`` only when
    the file already contains a ``test`` key. All other keys are written back
    unchanged. A confirmation line is printed when done.

    Presumably this is needed because the exported file carries relative split
    paths that the trainer resolves against a different base directory --
    confirm against the exported ``data.yaml``.

    Args:
        dataset_location: Directory that contains ``data.yaml`` and the split
            sub-directories.

    Raises:
        FileNotFoundError: If ``data.yaml`` does not exist in that directory.
    """
    yaml_path = os.path.join(dataset_location, 'data.yaml')
    with open(yaml_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty file; fall back to an empty
        # mapping so the key assignments below do not raise TypeError.
        data = yaml.safe_load(f) or {}

    # Absolute paths do not depend on the working directory or on any base
    # directory the consumer of the file might prepend.
    root = os.path.abspath(dataset_location)
    data['train'] = os.path.join(root, 'train', 'images')
    data['val'] = os.path.join(root, 'valid', 'images')
    if 'test' in data:
        data['test'] = os.path.join(root, 'test', 'images')

    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.dump(data, f)
    print(f"✅ Paths corrected in {yaml_path}")
|
|
| fix_data_yaml(dataset.location) |
|
|
| |
| |
| MODEL = "yolov8x-pose.pt" |
| IMG_SIZE = 640 |
| EPOCHS = 500 |
| BATCH_SIZE = 16 |
|
|
| |
| print(f"π₯ Starting training for {EPOCHS} epochs... this may take several hours.") |
| !yolo task=pose mode=train model={MODEL} data={dataset.location}/data.yaml epochs={EPOCHS} imgsz={IMG_SIZE} batch={BATCH_SIZE} plots=True |
|
|
| |
# Copy the best checkpoint produced by training to Google Drive.
import glob

print("💾 Saving weights to Google Drive...")

# Ultralytics does not overwrite an existing run directory by default: a second
# run is written to `train2`, a third to `train3`, and so on. Pick the most
# recently modified `best.pt` so a re-run does not silently copy stale weights
# from the first run. Fall back to the original fixed path so the error message
# below still names a concrete location when nothing was produced.
candidate_paths = glob.glob("runs/pose/train*/weights/best.pt")
if candidate_paths:
    source_path = max(candidate_paths, key=os.path.getmtime)
else:
    source_path = "runs/pose/train/weights/best.pt"
dest_path = "/content/drive/MyDrive/basketball_court_keypoints_best.pt"

if os.path.exists(source_path):
    shutil.copy(source_path, dest_path)
    print(f"✅ SUCCESS: Weights saved to {dest_path}")
else:
    print("❌ ERROR: Could not find trained weights at " + source_path)
|
|