| import os |
| import random |
| from pathlib import Path |
| import shutil |
|
|
| |
| |
| |
|
|
| |
# Location key (as used in `splits`) -> source directory holding that
# location's session folders. Paths are relative, i.e. resolved against the
# current working directory.
SOURCE_DIRS = dict(
    location_1='mpala',
    location_2='opc',
    location_3='wilds',
)

# Root of the generated dataset: images/<split>/, labels/<split>/ and
# dataset.yaml are all written beneath it.
DEST_DIR = "/data"

# Class id -> class name, emitted under `names:` in dataset.yaml.
CLASS_LABELS = dict(enumerate(["Zebra", "Giraffe", "Onager", "Dog"]))

# After sorting a video's frames by file name, keep every SAMPLING_RATE-th one.
SAMPLING_RATE = 10
|
|
| |
# Train/test assignment, nested as split -> location -> session -> video names.
# Location keys are looked up in SOURCE_DIRS (unknown ones are skipped with a
# warning); session keys are sub-directory names under that location's source
# directory. A session value is normally a list of video names; the copy loop
# also accepts the boolean True to auto-discover the session's videos.
# NOTE(review): 'DJI_0035' appears as '_part1' in train and '_part2' in test —
# presumably two halves of one recording, so near-identical frames may land
# in both splits. Confirm this is intended.
splits = {
    'train': {
        'location_3': {
            'session_1': ['DJI_0034', 'DJI_0035_part1'],
            'session_2': ['P0140018'],
            'session_3': ['P0100010', 'P0110011', 'P0080008', 'P0090009'],

        },
        'location_1': {
            'session_1': ['DJI_0001', 'DJI_0002'],
            'session_2': ['DJI_0005', 'DJI_0006'],
            'session_3': ['DJI_0068', 'DJI_0069'],
            'session_4': ['DJI_0142', 'DJI_0143', 'DJI_0144'],
            'session_5': ['DJI_0206', 'DJI_0208'],
        },
        'location_2': {
            'session_1': ['P0800081', 'P0830086', 'P0840087', 'P0870091'],
            'session_2': ['P0910095'],
        }
    },
    'test': {
        'location_3': {
            'session_1': ['DJI_0035_part2'],
            'session_3': ['P0070007', 'P0160016', 'P0120012'],
            'session_2': ['P0150019'],
            # session_4 of location_3 is used only in the test split.
            'session_4': ['P0070010'],
        },
        'location_1': {
            'session_3': ['DJI_0070', 'DJI_0071'],
            'session_4': ['DJI_0145', 'DJI_0146', 'DJI_0147'],
            'session_5': ['DJI_0210', 'DJI_0211'],
        },
        'location_2': {
            'session_1': ['P0860090'],
            'session_2': ['P0940098'],
        }
    }
}
|
|
| |
| |
| |
|
|
| |
# Prepare the output layout <DEST_DIR>/{images,labels}/{train,test};
# exist_ok makes re-runs harmless.
for split in ['train', 'test']:
    for kind in ('images', 'labels'):
        os.makedirs(f"{DEST_DIR}/{kind}/{split}", exist_ok=True)
|
|
def find_images_in_directory(dir_path):
    """Return the names of image files located directly inside *dir_path*.

    Only regular files whose names end in ``.jpg``, ``.png`` or ``.jpeg``
    are returned. The match is case-sensitive, so e.g. ``.JPG`` is skipped.
    Sub-directories are ignored even when their names end in an image
    extension.

    Args:
        dir_path: Directory to scan, as ``str`` or ``Path``.

    Returns:
        Sorted list of bare file names (not full paths). Sorting makes the
        result independent of ``os.listdir``'s arbitrary order. If the
        directory is missing, is not a directory, or is unreadable, the error
        is printed and ``[]`` is returned.
    """
    # Accept plain strings too: the `/` join below needs a Path.
    dir_path = Path(dir_path)
    try:
        entries = os.listdir(dir_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {dir_path}: {e}")
        return []
    return sorted(
        f for f in entries
        if f.endswith(('.jpg', '.png', '.jpeg')) and os.path.isfile(dir_path / f)
    )
|
|
def find_partitions(session_path):
    """Return the names of ``partition_*`` sub-directories of *session_path*.

    Args:
        session_path: Directory to scan, as ``str`` or ``Path``.

    Returns:
        Sorted list of directory names. Sorting matters because callers
        iterate partitions in order and, during video auto-discovery, pick
        the "first" one; ``os.listdir`` order is arbitrary. Files named
        ``partition_*`` are ignored. A missing, non-directory or unreadable
        path is reported on stdout and yields ``[]``.
    """
    # Accept plain strings too: the `/` join below needs a Path.
    session_path = Path(session_path)
    try:
        entries = os.listdir(session_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {session_path}: {e}")
        return []
    return sorted(
        d for d in entries
        if d.startswith('partition_') and os.path.isdir(session_path / d)
    )
|
|
def find_video_images(session_path, video_name):
    """Collect every frame image that belongs to one video of a session.

    Two layouts are searched, and the results of both are concatenated:

    1. ``<session_path>/<video_name>/``: frames sit either directly in that
       directory or, if it has ``partition_*`` sub-directories, inside those.
    2. ``<session_path>/partition_*/``: session-level partitions. An image
       there is attributed to the video when ``video_name`` occurs anywhere
       in its file name.

    Args:
        session_path: Session directory, as ``str`` or ``Path``.
        video_name: Video identifier, e.g. ``'DJI_0034'``.

    Returns:
        List of ``(directory, image_name, partition_name)`` tuples.
        ``directory`` is a ``Path``. ``partition_name`` is ``""`` for frames
        stored directly in the video directory.
    """
    # Accept plain strings too: the `/` joins below need a Path.
    session_path = Path(session_path)
    all_images = []

    # Layout 1: a dedicated directory for this video.
    video_path = session_path / video_name
    if os.path.isdir(video_path):
        video_partitions = find_partitions(video_path)
        if video_partitions:
            for partition in video_partitions:
                partition_path = video_path / partition
                all_images.extend(
                    (partition_path, img, partition)
                    for img in find_images_in_directory(partition_path)
                )
        else:
            all_images.extend(
                (video_path, img, "") for img in find_images_in_directory(video_path)
            )

    # Layout 2: partitions shared by the whole session.
    # NOTE(review): this is a substring match, so a video name contained in
    # another video's name would also claim that video's frames. Confirm
    # whether a prefix match (img.startswith(video_name)) is intended.
    for partition in find_partitions(session_path):
        partition_path = session_path / partition
        all_images.extend(
            (partition_path, img, partition)
            for img in find_images_in_directory(partition_path)
            if video_name in img
        )

    return all_images
|
|
| |
def _resolve_session_videos(session_path, video_info):
    """Return the video names to process for one session, or ``None`` to skip it.

    Args:
        session_path: ``Path`` of an existing session directory.
        video_info: A list of video names, returned unchanged. ``True``
            auto-discovers videos from the directory layout; ``False`` means
            no videos.

    Returns:
        A list of video names, or ``None`` when auto-discovery cannot list
        ``session_path`` (a warning is printed).
    """
    if not isinstance(video_info, bool):
        return video_info
    if not video_info:
        return []

    try:
        # Every sub-directory that is not a partition is treated as a video.
        videos = sorted(
            v for v in os.listdir(session_path)
            if os.path.isdir(session_path / v) and not v.startswith('partition_')
        )
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Warning: Could not list directory {session_path}: {e}")
        return None

    if not videos:
        # Flat layout: the frames live in session-level partitions, so video
        # names are guessed from the image-name text before the first '_' in
        # the first partition.
        # NOTE(review): a name like 'DJI_0034_...' reduces to 'DJI' under this
        # heuristic. Confirm the naming scheme before relying on it.
        partitions = find_partitions(session_path)
        if partitions:
            first_partition_images = find_images_in_directory(session_path / partitions[0])
            videos = sorted({img.split('_')[0] for img in first_partition_images if '_' in img})
    return videos


def _label_name_for(image_name):
    """Return the label file name for *image_name*: the same stem with a ``.txt`` suffix.

    Only the final extension is replaced, whatever its case. The result can
    therefore never equal the image name itself, which would make the label
    lookup find, and copy, the image file as its own "label".
    """
    return os.path.splitext(image_name)[0] + '.txt'


def _find_label_file(frame_dir, session_path, video, partition, label_name):
    """Return the first existing label path for a frame, or ``None``.

    Candidates are tried in this order:

    1. next to the image;
    2. in a ``labels/`` folder next to the image;
    3. in the session's ``labels/<partition>/`` folder;
    4. in the session's ``labels/`` folder;
    5. in ``<session>/<video>/labels/``.
    """
    candidates = [
        frame_dir / label_name,
        frame_dir / "labels" / label_name,
        session_path / "labels" / partition / label_name,
        session_path / "labels" / label_name,
        session_path / video / "labels" / label_name,
    ]
    return next((path for path in candidates if os.path.exists(path)), None)


# Copy every SAMPLING_RATE-th frame of each configured video, plus its label
# when one is found, into <DEST_DIR>/{images,labels}/<split>/. Copied files
# are named "<location>_<session>_<video>[_<partition>]_<frame>".
for split_name, locations in splits.items():
    for location_name, sessions in locations.items():
        if location_name not in SOURCE_DIRS:
            print(f"Warning: No source directory defined for {location_name}. Skipping.")
            continue

        location_source_dir = Path(SOURCE_DIRS[location_name])

        for session_name, video_info in sessions.items():
            session_path = location_source_dir / session_name

            if not os.path.exists(session_path):
                print(f"Warning: Session path {session_path} does not exist. Skipping.")
                continue

            videos = _resolve_session_videos(session_path, video_info)
            if videos is None:
                continue

            for video in videos:
                print(f"Processing {location_name}/{session_name}/{video}...")

                frame_info = find_video_images(session_path, video)
                if not frame_info:
                    print(f"Warning: No frames found for {video} in {session_name}")
                    continue

                # Order frames by file name, then keep one in every SAMPLING_RATE.
                frame_info.sort(key=lambda x: x[1])
                sampled_frame_info = frame_info[::SAMPLING_RATE]

                for frame_dir, frame_name, partition in sampled_frame_info:
                    # The partition is part of the destination name, so
                    # same-named frames from different partitions cannot collide.
                    partition_str = f"_{partition}" if partition else ""

                    src_img = frame_dir / frame_name
                    dest_img_name = f"{location_name}_{session_name}_{video}{partition_str}_{frame_name}"
                    dest_img = Path(DEST_DIR) / "images" / split_name / dest_img_name

                    try:
                        shutil.copy(src_img, dest_img)
                    except OSError as e:
                        print(f"Error copying image {src_img}: {e}")
                        continue

                    src_label = _find_label_file(
                        frame_dir, session_path, video, partition, _label_name_for(frame_name)
                    )
                    if src_label is None:
                        print(f"Warning: No label found for {src_img}")
                        continue

                    dest_label = Path(DEST_DIR) / "labels" / split_name / _label_name_for(dest_img_name)
                    try:
                        shutil.copy(src_label, dest_label)
                    except OSError as e:
                        print(f"Error copying label {src_label}: {e}")


print("Dataset split completed successfully!")
|
|
| |
def create_dataset_yaml(dest_dir=None, class_labels=None):
    """Write ``dataset.yaml`` describing the generated dataset for YOLO training.

    The file records the absolute dataset root as ``path``, the image
    folders for ``train``/``val``/``test``, and the class ``names``. ``val``
    deliberately points at the training images, because this split layout
    has no separate validation set.

    Args:
        dest_dir: Dataset root; the file is written to
            ``<dest_dir>/dataset.yaml``. Defaults to ``DEST_DIR``.
        class_labels: Mapping of class id -> class name, written in
            iteration order. Defaults to ``CLASS_LABELS``.
    """
    # Resolve defaults at call time so the module-level config stays the
    # single source of truth when no arguments are given.
    if dest_dir is None:
        dest_dir = DEST_DIR
    if class_labels is None:
        class_labels = CLASS_LABELS

    lines = [
        "# YOLOv11 dataset config",
        f"path: {os.path.abspath(dest_dir)} # dataset root dir",
        "train: images/train # train images",
        "val: images/train # validation uses train images",
        "test: images/test # test images",
        "",
        "# Classes",
        "names:",
    ]
    lines.extend(f" {class_id}: {class_name}" for class_id, class_name in class_labels.items())

    with open(os.path.join(dest_dir, "dataset.yaml"), "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
|
|
create_dataset_yaml()


def _known_location_sessions():
    """Return every ``(location, session)`` pair named in ``splits``.

    Longer pairs come first, so a prefix match always picks the most
    specific pair. Ties are broken by value, which keeps the order stable.
    """
    pairs = set()
    for split_locations in splits.values():
        for location_name, sessions in split_locations.items():
            for session_name in sessions:
                pairs.add((location_name, session_name))
    return sorted(pairs, key=lambda pair: (-(len(pair[0]) + len(pair[1])), pair))


def _location_session_of(image_name, known_pairs):
    """Return the ``(location, session)`` encoded in a copied image name, or ``None``.

    Copied images are named ``{location}_{session}_{video}..._{frame}``.
    Location and session keys such as ``location_1`` / ``session_2``
    contain ``_`` themselves, so splitting on ``_`` cannot recover them.
    Instead, the name is matched against the known ``{location}_{session}_``
    prefixes.
    """
    for location_name, session_name in known_pairs:
        if image_name.startswith(f"{location_name}_{session_name}_"):
            return location_name, session_name
    return None


# Per-split summary: total image count, count per location, and count per
# "<location>_<session>" pair. The pair counts are stored under the
# "species" key, which is the key the report reads.
stats = {"train": {}, "test": {}}
_location_session_pairs = _known_location_sessions()

for split in ['train', 'test']:
    locations = {}
    species_count = {}
    total_count = 0
    # Fill in defaults first, so the report never hits a missing key even
    # when this split's image directory is absent.
    stats[split] = {"total": total_count, "locations": locations, "species": species_count}

    img_dir = Path(DEST_DIR) / "images" / split
    if not os.path.exists(img_dir):
        print(f"Warning: Directory {img_dir} does not exist.")
        continue

    # sorted() gives a stable report order; os.listdir order is arbitrary.
    for img in sorted(os.listdir(img_dir)):
        parsed = _location_session_of(img, _location_session_pairs)
        if parsed is None:
            continue  # not named like a copied image; leave it out of the counts
        location, session = parsed

        locations[location] = locations.get(location, 0) + 1

        species_key = f"{location}_{session}"
        species_count[species_key] = species_count.get(species_key, 0) + 1

        total_count += 1

    stats[split]["total"] = total_count
|
|
| |
# Report each split's image count and its breakdown by location and by
# location/session, as percentages of that split's total. `.get` defaults
# tolerate a split whose stats were never filled in (e.g. a missing image
# directory), which previously raised KeyError.
for split, data in stats.items():
    total = data.get("total", 0)
    print(f"\n{split.upper()} set:")
    print(f"Total images: {total}")

    print("Distribution by location:")
    for loc, count in data.get("locations", {}).items():
        percentage = (count / total * 100) if total > 0 else 0
        print(f" - {loc}: {count} ({percentage:.1f}%)")

    print("\nDistribution by location_session:")
    for species_key, count in data.get("species", {}).items():
        percentage = (count / total * 100) if total > 0 else 0
        print(f" - {species_key}: {count} ({percentage:.1f}%)")
|
|
# Share of all copied images that ended up in train vs. test. An empty
# dataset (e.g. no source directories found) would otherwise raise
# ZeroDivisionError here.
_train_total = stats['train'].get('total', 0)
_test_total = stats['test'].get('total', 0)
_grand_total = _train_total + _test_total
if _grand_total > 0:
    print("\nOverall train/test ratio:",
          f"{_train_total / _grand_total:.1%}",
          f"/ {_test_total / _grand_total:.1%}")
else:
    print("\nOverall train/test ratio: n/a (no images)")