# Strawberry-only YOLO dataset filtering utilities.
from pathlib import Path
# Default class name for the single-class dataset; used by the __main__ example below.
class_name = 'strawberry'
def rewrite_single_class_data_yaml(dataset_dir, class_name='strawberry'):
    """Overwrite data.yaml so the dataset declares exactly one class.

    Writes split image paths derived from *dataset_dir*, sets ``nc: 1``,
    and lists *class_name* as the only entry in ``names``. If data.yaml
    does not exist, prints a warning and leaves the directory untouched.
    """
    root = Path(dataset_dir)
    yaml_file = root / 'data.yaml'
    if not yaml_file.exists():
        print('⚠️ data.yaml not found, skipping rewrite.')
        return
    splits = {name: root / name / 'images' for name in ('train', 'valid', 'test')}
    # The test entry is left blank when the split is absent on disk.
    test_value = splits['test'] if splits['test'].exists() else ''
    body = (
        '# Strawberry-only dataset\n'
        f"train: {splits['train']}\n"
        f"val: {splits['valid']}\n"
        f"test: {test_value}\n"
        '\n'
        'nc: 1\n'
        f"names: ['{class_name}']\n"
    )
    yaml_file.write_text(body)
    print(f"✅ data.yaml updated for single-class training ({class_name}).")
def enforce_single_class_dataset(dataset_dir, target_class=0, class_name='strawberry'):
    """Filter a YOLO dataset in place so it contains only one class.

    Walks the ``train``/``valid``/``test`` splits under *dataset_dir*,
    keeps only annotation lines whose class id equals *target_class*,
    and remaps kept ids to 0 so the labels agree with the single-class
    data.yaml this function writes (``nc: 1``). Label files left with no
    annotations are deleted along with their matching image (first match
    among .jpg/.jpeg/.png).

    Args:
        dataset_dir: Root of the YOLO dataset (each split has ``labels``
            and ``images`` subdirectories).
        target_class: Original class id to keep. BUGFIX: kept lines are
            now rewritten with class id 0 — previously a non-zero
            target_class left ids out of range for ``nc: 1`` training.
        class_name: Human-readable class name written into data.yaml.

    Returns:
        dict with per-split kept-annotation counts ('split_kept') and
        removal counters ('labels_removed', 'images_removed').
    """
    dataset_dir = Path(dataset_dir)
    stats = {'split_kept': {}, 'labels_removed': 0, 'images_removed': 0}
    allowed_ext = ['.jpg', '.jpeg', '.png']
    for split in ['train', 'valid', 'test']:
        labels_dir = dataset_dir / split / 'labels'
        images_dir = dataset_dir / split / 'images'
        if not labels_dir.exists():
            continue
        kept = 0
        for label_path in labels_dir.glob('*.txt'):
            kept_lines = []
            for raw_line in label_path.read_text().splitlines():
                parts = raw_line.split()
                if not parts:
                    continue  # blank line
                try:
                    class_id = int(parts[0])
                except ValueError:
                    continue  # malformed row — skip rather than crash
                if class_id == target_class:
                    # Remap to 0: data.yaml declares nc: 1, so any other
                    # id would be rejected at training time.
                    kept_lines.append(' '.join(['0'] + parts[1:]))
            if kept_lines:
                label_path.write_text('\n'.join(kept_lines) + '\n')
                kept += len(kept_lines)
            else:
                # No target-class boxes left: drop the label file and the
                # first matching image so no unlabeled image remains.
                label_path.unlink()
                stats['labels_removed'] += 1
                for ext in allowed_ext:
                    candidate = images_dir / f"{label_path.stem}{ext}"
                    if candidate.exists():
                        candidate.unlink()
                        stats['images_removed'] += 1
                        break
        stats['split_kept'][split] = kept
    rewrite_single_class_data_yaml(dataset_dir, class_name)
    print('\n🍓 Strawberry-only filtering summary:')
    for split, count in stats['split_kept'].items():
        print(f" {split}: {count} annotations kept")
    print(f" Label files removed: {stats['labels_removed']}")
    print(f" Images removed (non-strawberry or empty labels): {stats['images_removed']}")
    return stats
if __name__ == "__main__":
# Example usage - replace with your dataset path
dataset_path = "path/to/your/dataset" # Update this path
if dataset_path and Path(dataset_path).exists():
strawberry_stats = enforce_single_class_dataset(dataset_path, target_class=0, class_name=class_name)
else:
print("Please set a valid dataset_path variable.") |