# File size: 3,223 Bytes
# efb1801
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from pathlib import Path

class_name = 'strawberry'

def rewrite_single_class_data_yaml(dataset_dir, class_name='strawberry'):
    """Overwrite ``data.yaml`` in *dataset_dir* for single-class YOLO training.

    Writes the train/valid/test image-directory paths, ``nc: 1`` and a
    one-entry ``names`` list.  If ``data.yaml`` does not already exist the
    dataset is assumed not to be set up yet and nothing is written.

    Args:
        dataset_dir: Dataset root containing ``train``/``valid``/``test``
            sub-folders (str or Path).
        class_name: Label recorded as the single class name.
    """
    dataset_dir = Path(dataset_dir)
    data_yaml_path = dataset_dir / 'data.yaml'
    if not data_yaml_path.exists():
        print('⚠️  data.yaml not found, skipping rewrite.')
        return

    train_path = dataset_dir / 'train' / 'images'
    val_path = dataset_dir / 'valid' / 'images'
    test_path = dataset_dir / 'test' / 'images'

    content_lines = [
        '# Strawberry-only dataset',
        f'train: {train_path}',
        f'val: {val_path}',
        # Leave the test entry empty when that split is absent so the
        # trainer does not chase a nonexistent directory.
        f"test: {test_path if test_path.exists() else ''}",
        '',
        'nc: 1',
        f"names: ['{class_name}']",
    ]

    # Explicit utf-8: the platform default encoding (e.g. cp1252 on
    # Windows) could reject or mangle a non-ASCII class_name.
    data_yaml_path.write_text('\n'.join(content_lines) + '\n', encoding='utf-8')
    print(f"✅ data.yaml updated for single-class training ({class_name}).")


def enforce_single_class_dataset(dataset_dir, target_class=0, class_name='strawberry'):
    """Strip every annotation except *target_class* from a YOLO dataset.

    Walks the ``train``/``valid``/``test`` splits under *dataset_dir*; in each
    label file keeps only lines whose leading class id equals *target_class*.
    Label files left empty are deleted together with their matching image
    (first ``.jpg``/``.jpeg``/``.png`` found).  Finally rewrites ``data.yaml``
    for single-class training and prints a summary.

    Args:
        dataset_dir: Dataset root (str or Path).
        target_class: Class id to keep; all other annotations are dropped.
        class_name: Name written into the rewritten ``data.yaml``.

    Returns:
        dict with ``split_kept`` (annotations kept per split),
        ``labels_removed`` and ``images_removed`` counters.
    """
    root = Path(dataset_dir)
    stats = {'split_kept': {}, 'labels_removed': 0, 'images_removed': 0}
    image_suffixes = ('.jpg', '.jpeg', '.png')

    for split in ('train', 'valid', 'test'):
        labels_dir = root / split / 'labels'
        images_dir = root / split / 'images'
        if not labels_dir.exists():
            continue

        annotations_kept = 0
        for label_file in labels_dir.glob('*.txt'):
            surviving = []
            for raw in label_file.read_text().splitlines():
                entry = raw.strip()
                fields = entry.split()
                # Blank lines and unparseable class ids are simply dropped.
                if not fields:
                    continue
                try:
                    cid = int(fields[0])
                except ValueError:
                    continue
                if cid == target_class:
                    surviving.append(entry)

            if not surviving:
                # Nothing of the target class left: remove label and image.
                label_file.unlink()
                stats['labels_removed'] += 1
                for suffix in image_suffixes:
                    image_file = images_dir / (label_file.stem + suffix)
                    if image_file.exists():
                        image_file.unlink()
                        stats['images_removed'] += 1
                        break
            else:
                label_file.write_text('\n'.join(surviving) + '\n')
                annotations_kept += len(surviving)

        stats['split_kept'][split] = annotations_kept

    rewrite_single_class_data_yaml(root, class_name)

    print('\n🍓 Strawberry-only filtering summary:')
    for split, count in stats['split_kept'].items():
        print(f"   {split}: {count} annotations kept")
    print(f"   Label files removed: {stats['labels_removed']}")
    print(f"   Images removed (non-strawberry or empty labels): {stats['images_removed']}")

    return stats


if __name__ == "__main__":
    # Example usage - replace with your dataset path
    dataset_path = "path/to/your/dataset"  # Update this path

    if dataset_path and Path(dataset_path).exists():
        strawberry_stats = enforce_single_class_dataset(dataset_path, target_class=0, class_name=class_name)
    else:
        print("Please set a valid dataset_path variable.")