# NOTE: the following header is residue from the web file viewer this source
# was copied from; it is not part of the program.
# Spaces: Running
# File size: 5,818 Bytes
# Revisions: 15995ba 7ffa2fb
# train.py
from .yolo_manager import YOLOManager
from .utils import get_abs_path, backup_file
import os
from .config import Config
import yaml
import os
from pathlib import Path
import shutil
def _label_has_content(label_file):
    """Return True when the label file contains anything besides whitespace."""
    with open(label_file, 'r', encoding='utf-8', errors='replace') as f:
        return bool(f.read().strip())


def create_filtered_dataset(original_dataset_path, output_base_path):
    """
    Create a filtered dataset with only images that have non-empty labels.

    For each of the ``train``, ``val`` and ``test`` splits, images are read
    from ``<original_dataset_path>/images/<split>`` and their labels from
    ``<original_dataset_path>/labels/<split>`` (same file stem, ``.txt``
    suffix). An image is copied, together with its label, into
    ``<output_base_path>/images/<split>`` and
    ``<output_base_path>/labels/<split>`` only when the label file exists
    and is not blank.

    Args:
        original_dataset_path: Root of the source dataset.
        output_base_path: Root under which the filtered ``images`` and
            ``labels`` trees are created.

    Returns:
        dict mapping each split name to the number of images copied
        (0 for a split whose source directories are missing).
    """
    splits = ['train', 'val', 'test']
    image_suffixes = {'.jpg', '.jpeg', '.png', '.bmp'}

    # The source is the dataset root itself. Previously the source root was
    # the very `filtered_dataset` folder that had just been deleted, so every
    # split was reported as missing and nothing was ever copied.
    original_path = Path(original_dataset_path)
    output_path = Path(output_base_path)

    # Clear any stale `filtered_dataset` folder under the source dataset
    # (best effort; a missing folder is not an error).
    shutil.rmtree(original_path / 'filtered_dataset', ignore_errors=True)

    # Create output directory structure
    output_images = output_path / "images"
    output_labels = output_path / "labels"
    for split in splits:
        (output_images / split).mkdir(parents=True, exist_ok=True)
        (output_labels / split).mkdir(parents=True, exist_ok=True)

    filtered_counts = {}
    for split in splits:
        original_images_dir = original_path / 'images' / split
        original_labels_dir = original_path / 'labels' / split
        output_images_dir = output_images / split
        output_labels_dir = output_labels / split

        if not original_images_dir.exists() or not original_labels_dir.exists():
            print(f"Skipping {split} - source directory not found")
            filtered_counts[split] = 0
            continue

        total_count = 0
        copied_count = 0

        # Process each image
        for img_file in original_images_dir.glob('*'):
            # Only regular files with a known image extension are considered.
            if not img_file.is_file() or img_file.suffix.lower() not in image_suffixes:
                continue
            total_count += 1

            label_file = original_labels_dir / f"{img_file.stem}.txt"
            if not label_file.exists():
                print(f"Skipping {img_file.name} - no label file")
                continue
            if not _label_has_content(label_file):
                print(f"Skipping {img_file.name} - empty label file")
                continue

            # Copy the image and its label side by side.
            shutil.copy2(img_file, output_images_dir / img_file.name)
            shutil.copy2(label_file, output_labels_dir / label_file.name)
            copied_count += 1

        filtered_counts[split] = copied_count
        print(f"{split.upper()} split: {copied_count}/{total_count} images copied")

    return filtered_counts
def create_filtered_yaml(output_base_path, filtered_counts):
    """
    Create the YAML file for the filtered dataset.

    The file is written to ``<Config.current_path>/filtered_comic.yaml`` and
    describes a single-class (``panel``) dataset whose ``train`` and ``val``
    image folders live under ``<output_base_path>/images``. A ``test`` entry
    is added only when ``filtered_counts`` reports at least one test image.

    Args:
        output_base_path: Root of the filtered dataset.
        filtered_counts: Mapping of split name to number of copied images.

    Returns:
        The path of the YAML file that was written, as a string.
    """
    output_path = Path(output_base_path)
    images_root = output_path / 'images'
    yaml_path = f'{Config.current_path}/filtered_comic.yaml'

    # Create YAML structure
    yaml_data = {
        'names': ['panel'],
        'nc': 1,
        'path': str(output_path),
        'train': str(images_root / 'train'),
        'val': str(images_root / 'val'),
    }

    # Only add test if it has images
    if filtered_counts.get('test', 0) > 0:
        yaml_data['test'] = str(images_root / 'test')

    # Write YAML file; sort_keys=False keeps the key order defined above.
    with open(yaml_path, 'w', encoding='utf-8') as f:
        yaml.dump(yaml_data, f, default_flow_style=False, sort_keys=False)

    # The status literal had been split over two lines by a mis-decoded
    # check-mark character; it is restored here as a single-line string.
    print(f"\n✅ Created filtered dataset YAML: {yaml_path}")
    return yaml_path
def main():
    """
    Main training function.

    Builds a YOLOManager, checks that the dataset YAML at
    ``<Config.current_path>/filtered_comic.yaml`` exists, then trains,
    validates and backs up the best weights to ``<Config.YOLO_MODEL_NAME>.pt``.
    Any exception is reported on stdout and re-raised to the caller.
    """
    try:
        manager = YOLOManager()

        # Dataset description expected to have been generated beforehand.
        dataset_yaml = f'{Config.current_path}/filtered_comic.yaml'
        if not os.path.isfile(dataset_yaml):
            raise FileNotFoundError(f"β Dataset YAML not found: {dataset_yaml}")

        print(f"π― Training model: {Config.YOLO_MODEL_NAME}")

        # Train, then validate; the returned objects are not used here.
        manager.train(
            data_yaml_path=dataset_yaml,
            run_name=Config.YOLO_MODEL_NAME,
        )
        manager.validate()

        # Keep a copy of the best weights under the model's name.
        backup_file(
            manager.get_best_weights_path(),
            f'{Config.YOLO_MODEL_NAME}.pt',
        )

        print("π Training completed successfully!")
    except Exception as e:
        # Report the failure, then let it propagate unchanged.
        print(f"β Training failed: {e!s}")
        raise
if __name__ == "__main__":
    # Configuration
    # NOTE(review): these are absolute paths on one specific machine; confirm
    # they exist before running elsewhere.
    original_dataset_path = "/home/jebineinstein/git/comic-panel-extractor/comic_panel_extractor/dataset"
    output_base_path = "/home/jebineinstein/git/comic-panel-extractor/comic_panel_extractor"

    print("π Starting dataset filtering...")
    print(f"π Source: {original_dataset_path}")
    print(f"π Output: {output_base_path}")

    # Build the filtered dataset, then the YAML file that describes it.
    filtered_counts = create_filtered_dataset(original_dataset_path, output_base_path)
    yaml_path = create_filtered_yaml(output_base_path, filtered_counts)

    # Summary: list only the splits that actually received images.
    total_filtered = sum(filtered_counts.values())
    print("\nπ Filtering Summary:")
    for split_name, copied in filtered_counts.items():
        if copied <= 0:
            continue
        print(f" {split_name.upper()}: {copied} images")
    print(f" TOTAL: {total_filtered} images with labels")
    print(f"\nπ― Use this YAML for training: {yaml_path}")

    # Display the created YAML content between two separator rules.
    separator = "β" * 50
    yaml_content = Path(yaml_path).read_text()
    print("\nπ Generated YAML content:")
    print(separator)
    print(yaml_content)
    print(separator)

    # Train on the dataset that was just prepared.
    main()