arm-model / model /gan.py
pragadeeshv23's picture
Upload folder using huggingface_hub
5b86813 verified
#!/usr/bin/env python3
"""
Convert YOLO-format object detection dataset to GAN-ready patch dataset.
This script reads YOLO annotations and extracts object patches from images,
resizing them to 64×64 for use in GAN training.
Structure:
Input: dataset/{train,valid}/{images,labels}
Output: gan_dataset/{pothole,cracks,open_manhole}/
"""
import os
import sys
from pathlib import Path
import cv2
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Where the YOLO dataset lives and where the extracted patches are written.
DATASET_ROOT = Path("/home/pragadeesh/ARM/model/dataset")
OUTPUT_ROOT = Path("/home/pragadeesh/ARM/model/gan_dataset")

# YOLO class index -> human-readable class name (the index is the position
# of the name in this tuple).
CLASS_NAMES = dict(
    enumerate(("pothole", "cracks", "open_manhole", "good_road"))
)

# Class indices that are never exported as patches (3 == "good_road").
IGNORE_CLASSES = {3}

# Edge length, in pixels, of the square patches produced for GAN training.
PATCH_SIZE = 64

# Bounding boxes narrower or shorter than this many pixels are skipped.
MIN_WIDTH = MIN_HEIGHT = 10
def setup_output_directories():
    """Create OUTPUT_ROOT and one sub-directory per exported class.

    Classes whose id is in IGNORE_CLASSES get no directory. Directories that
    already exist are accepted as-is, and the path of every class directory
    is echoed to stdout.
    """
    OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

    exported_names = (
        name
        for cid, name in CLASS_NAMES.items()
        if cid not in IGNORE_CLASSES
    )
    for name in exported_names:
        target = OUTPUT_ROOT / name
        target.mkdir(parents=True, exist_ok=True)
        print(f"✓ Created: {target}")
def parse_yolo_label(label_path, img_width, img_height):
    """
    Parse a YOLO-format label file into pixel-space bounding boxes.

    Each non-empty line is expected to read
    ``class_id x_center y_center width height`` with the four geometry
    values normalized to [0, 1]. Only the first five fields of a line are
    used; any extra fields are ignored.

    Malformed lines (too few fields, non-numeric values, NaN/inf) are
    reported on stdout and skipped individually, so one bad line never
    discards the rest of the file. If the file cannot be read, a warning is
    printed and the boxes parsed up to that point are returned.

    Args:
        label_path: Path to the .txt label file.
        img_width: Image width in pixels.
        img_height: Image height in pixels.

    Returns:
        List of tuples (class_id, x1, y1, x2, y2) in pixel coordinates.
        x1/y1 are clamped to >= 0 and x2/y2 to the image dimensions. The
        list is empty when the label file does not exist.
    """
    bboxes = []

    # An image without a label file simply has no annotations.
    if not label_path.exists():
        return bboxes

    try:
        with open(label_path, 'r') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                parts = line.split()
                if len(parts) < 5:
                    print(
                        f"⚠ {label_path}:{line_num}: expected 5 fields, "
                        f"got {len(parts)}; line skipped"
                    )
                    continue

                try:
                    class_id = int(parts[0])

                    # Normalized -> pixel units.
                    x_center_px = float(parts[1]) * img_width
                    y_center_px = float(parts[2]) * img_height
                    half_w = float(parts[3]) * img_width / 2
                    half_h = float(parts[4]) * img_height / 2

                    # Corner coordinates, clamped to the image bounds.
                    x1 = max(0, int(x_center_px - half_w))
                    y1 = max(0, int(y_center_px - half_h))
                    x2 = min(img_width, int(x_center_px + half_w))
                    y2 = min(img_height, int(y_center_px + half_h))
                except (ValueError, OverflowError) as e:
                    # ValueError: non-numeric text or NaN.
                    # OverflowError: int() of an infinite value.
                    print(
                        f"⚠ {label_path}:{line_num}: malformed annotation "
                        f"({e}); line skipped"
                    )
                    continue

                bboxes.append((class_id, x1, y1, x2, y2))
    except (OSError, UnicodeDecodeError) as e:
        # Best effort: report the read failure and keep what was parsed.
        print(f"⚠ Could not read label file {label_path}: {e}")

    return bboxes
def extract_patch(image, bbox, class_name):
    """
    Crop one bounding box out of an image and scale it to the patch size.

    Args:
        image: OpenCV image (BGR) to crop from
        bbox: Tuple (class_id, x1, y1, x2, y2) in pixel coordinates
        class_name: Name of the class; accepted but not used by this function

    Returns:
        The crop resized to PATCH_SIZE×PATCH_SIZE, or None when the box is
        degenerate, smaller than MIN_WIDTH×MIN_HEIGHT, yields an empty crop,
        or cropping/resizing raises.
    """
    _, left, top, right, bottom = bbox

    # Reject inverted or zero-area boxes.
    if left >= right or top >= bottom:
        return None

    # Reject boxes below the minimum size.
    too_narrow = (right - left) < MIN_WIDTH
    too_short = (bottom - top) < MIN_HEIGHT
    if too_narrow or too_short:
        return None

    try:
        crop = image[top:bottom, left:right]
        if crop.size == 0:
            return None
        return cv2.resize(crop, (PATCH_SIZE, PATCH_SIZE))
    except Exception:
        # Best effort: any failure while cropping/resizing means "no patch".
        return None
def process_dataset_split(split_name):
    """
    Extract and save object patches for every image of one dataset split.

    For each image under DATASET_ROOT/<split_name>/images, the matching
    label file (same stem, '.txt') under .../labels is parsed, every box of
    a known, non-ignored class is cropped and resized via extract_patch, and
    the patch is written to OUTPUT_ROOT/<class_name>/.

    Output files are named ``<class_name>_<split_name>_<NNNNNN>.jpg``. The
    split name is part of the filename because all splits share the same
    per-class output directory and the counter restarts for each split;
    without it a later split would overwrite the patches of an earlier one.

    Args:
        split_name: Name of the split directory under DATASET_ROOT
            (e.g. 'train' or 'valid')

    Returns:
        Dictionary with processing statistics:
            images_processed: images that were read successfully
            labels_read: images that had at least one parsed bounding box
            patches_extracted: patches returned by extract_patch
            patches_saved: patches successfully written to disk
            errors: unreadable images plus failed writes
            by_class: saved-patch count per class name
    """
    stats = {
        'images_processed': 0,
        'labels_read': 0,
        'patches_extracted': 0,
        'patches_saved': 0,
        'errors': 0,
        'by_class': {}
    }

    split_root = DATASET_ROOT / split_name
    images_dir = split_root / 'images'
    labels_dir = split_root / 'labels'

    if not images_dir.exists():
        print(f"✗ Images directory not found: {images_dir}")
        return stats

    print(f"\n{'='*70}")
    print(f"Processing {split_name.upper()} split")
    print(f"{'='*70}")
    print(f"Images dir: {images_dir}")
    print(f"Labels dir: {labels_dir}")

    # Suffixes are compared case-insensitively, so lower-case forms suffice.
    image_extensions = {'.jpg', '.jpeg', '.png'}
    image_paths = sorted(
        p for p in images_dir.iterdir()
        if p.suffix.lower() in image_extensions
    )
    print(f"Found {len(image_paths)} images")
    if not image_paths:
        print(f"⚠ No images found in {images_dir}")
        return stats

    # One filename counter per exported class, derived from the configuration
    # so that it always matches the classes that can reach the save step.
    patch_counters = {
        name: 0
        for cid, name in CLASS_NAMES.items()
        if cid not in IGNORE_CLASSES
    }

    for img_idx, image_path in enumerate(image_paths, 1):
        label_path = labels_dir / (image_path.stem + '.txt')

        # cv2.imread signals an unreadable file by returning None; the
        # try/except additionally guards against it raising.
        try:
            image = cv2.imread(str(image_path))
        except Exception:
            image = None
        if image is None:
            stats['errors'] += 1
            continue

        img_height, img_width = image.shape[:2]

        bboxes = parse_yolo_label(label_path, img_width, img_height)
        if bboxes:
            stats['labels_read'] += 1

        for bbox in bboxes:
            class_id = bbox[0]

            # Skip ignored classes and ids missing from the class mapping.
            if class_id in IGNORE_CLASSES or class_id not in CLASS_NAMES:
                continue
            class_name = CLASS_NAMES[class_id]

            patch = extract_patch(image, bbox, class_name)
            if patch is None:
                continue
            stats['patches_extracted'] += 1

            patch_counters[class_name] += 1
            output_filename = (
                f"{class_name}_{split_name}_"
                f"{patch_counters[class_name]:06d}.jpg"
            )
            output_path = OUTPUT_ROOT / class_name / output_filename

            try:
                success = cv2.imwrite(str(output_path), patch)
            except Exception:
                success = False

            if success:
                stats['patches_saved'] += 1
                stats['by_class'][class_name] = (
                    stats['by_class'].get(class_name, 0) + 1
                )
            else:
                stats['errors'] += 1

        stats['images_processed'] += 1

        # Progress indicator every 50 images
        if img_idx % 50 == 0:
            print(f" Progress: {img_idx}/{len(image_paths)} images")

    return stats
def main():
    """Run the conversion for every split and print a summary report.

    Exits with status 1 when DATASET_ROOT does not exist or when no patch
    was saved in any split.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70
    splits = ['train', 'valid']

    # Banner with the active configuration.
    print("\n" + heavy_rule)
    print("YOLO to GAN Dataset Converter")
    print(heavy_rule)
    print(f"Dataset root: {DATASET_ROOT}")
    print(f"Output root: {OUTPUT_ROOT}")
    print(f"Patch size: {PATCH_SIZE}×{PATCH_SIZE}")
    print(f"Min dimensions: {MIN_WIDTH}×{MIN_HEIGHT}")

    # Nothing to do without a dataset.
    if not DATASET_ROOT.exists():
        print(f"\n✗ Dataset directory not found: {DATASET_ROOT}")
        sys.exit(1)

    print("\nSetting up output directories...")
    setup_output_directories()

    # Convert the splits in order, keeping each split's statistics.
    all_stats = {name: process_dataset_split(name) for name in splits}

    # Per-split report.
    print("\n" + heavy_rule)
    print("CONVERSION COMPLETE - SUMMARY")
    print(heavy_rule)
    for name in splits:
        report = all_stats[name]
        print(f"\n{name.upper()} SET:")
        print(f" Images processed: {report['images_processed']}")
        print(f" Labels found: {report['labels_read']}")
        print(f" Patches extracted: {report['patches_extracted']}")
        print(f" Patches saved: {report['patches_saved']}")
        if report['errors'] > 0:
            print(f" Errors: {report['errors']}")
        per_class = report['by_class']
        if per_class:
            print(" By class:")
            for class_name, count in sorted(per_class.items()):
                print(f" • {class_name}: {count}")

    # Grand totals across all splits.
    total_images = sum(all_stats[name]['images_processed'] for name in splits)
    total_patches = sum(all_stats[name]['patches_saved'] for name in splits)

    print("\n" + light_rule)
    print(f"TOTAL IMAGES: {total_images}")
    print(f"TOTAL PATCHES: {total_patches}")
    print(f"OUTPUT: {OUTPUT_ROOT}")
    print(heavy_rule + "\n")

    if total_patches == 0:
        print("⚠ No patches extracted. Check dataset structure and labels.")
        sys.exit(1)

    print("✓ Done!\n")


# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()