| |
| """ |
| Integrate GAN synthetic data into YOLO training dataset. |
| |
| This script converts synthetic patches from GAN training into complete |
| YOLO-format images with annotations for dataset augmentation. |
| |
| Usage: |
| python integrate_gan.py [--num-synthetic 1000] |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import random |
| from pathlib import Path |
| import cv2 |
| import numpy as np |
|
|
| |
# Filesystem layout. The default matches the original development machine;
# set the ARM_DATASET_ROOT environment variable to use a dataset elsewhere
# without editing this file.
DATASET_ROOT = Path(
    os.environ.get("ARM_DATASET_ROOT") or "/home/pragadeesh/ARM/model/dataset"
)
# GAN output: one sub-directory of patch images per class name.
SYNTHETIC_ROOT = DATASET_ROOT / "synthetic"
TRAIN_IMAGES = DATASET_ROOT / "train" / "images"
TRAIN_LABELS = DATASET_ROOT / "train" / "labels"

# YOLO class IDs. The keys double as the per-class directory names under
# SYNTHETIC_ROOT.
CLASS_NAMES = {
    "pothole": 0,
    "cracks": 1,
    "open_manhole": 2,
}

# Side length in pixels of each pasted GAN patch and of the composed
# square output image.
PATCH_SIZE = 64
SYNTHETIC_IMAGE_SIZE = 416
|
|
|
|
def _fit_patch(patch, size):
    """Return *patch* as a ``(size, size, 3)`` array ready to paste.

    Grayscale patches (2-D or single-channel) are replicated to three
    channels and a fourth (alpha) channel is dropped. Patches whose spatial
    size differs from *size* are resized with nearest-neighbour sampling
    (pure NumPy, so no extra dependency is needed).

    Raises:
        ValueError: if the patch cannot be interpreted as an image.
    """
    arr = np.asarray(patch)
    if arr.ndim == 2:
        arr = arr[:, :, np.newaxis]
    if (arr.ndim != 3 or arr.shape[0] == 0 or arr.shape[1] == 0
            or arr.shape[2] not in (1, 3, 4)):
        raise ValueError(f"Unsupported patch shape: {np.shape(patch)}")
    if arr.shape[2] == 1:
        arr = np.repeat(arr, 3, axis=2)
    elif arr.shape[2] == 4:
        arr = arr[:, :, :3]

    height, width = arr.shape[:2]
    if (height, width) != (size, size):
        rows = np.arange(size) * height // size
        cols = np.arange(size) * width // size
        arr = arr[rows[:, np.newaxis], cols]
    return arr


def create_synthetic_yolo_image(patches, class_id, image_size=SYNTHETIC_IMAGE_SIZE,
                                max_attempts=10):
    """
    Create a YOLO-format image by composing GAN patches on a gray canvas.

    Between 2 and 5 patches (capped at ``len(patches)``) are chosen at random
    and pasted at random, non-overlapping positions. Each patch becomes one
    bounding box of ``PATCH_SIZE`` x ``PATCH_SIZE`` pixels.

    Args:
        patches: Sequence of patch arrays. Patches that are not exactly
            ``PATCH_SIZE`` square or not 3-channel are fitted before pasting.
        class_id: Class ID written for every box in this image.
        image_size: Side length of the square output image, in pixels.
        max_attempts: Random placements tried per patch before the patch is
            skipped because it would overlap an already placed one.

    Returns:
        Tuple ``(image, label_bboxes)``. ``image`` is a
        ``(image_size, image_size, 3)`` uint8 array. ``label_bboxes`` is a
        list of ``(class_id, x_center, y_center, width, height)`` tuples
        normalized to ``[0, 1]``.

    Raises:
        ValueError: if ``image_size`` is smaller than ``PATCH_SIZE`` or a
            patch has an unusable shape.
    """
    if image_size < PATCH_SIZE:
        raise ValueError(
            f"image_size ({image_size}) must be at least PATCH_SIZE ({PATCH_SIZE})"
        )

    # Neutral mid-gray background.
    image = np.full((image_size, image_size, 3), 128, dtype=np.uint8)
    label_bboxes = []
    pixel_bboxes = []

    num_patches = random.randint(2, 5)
    selected_patches = random.sample(patches, min(num_patches, len(patches)))

    # randint is inclusive at both ends, so the largest offset still keeps
    # the whole patch inside the canvas.
    max_offset = image_size - PATCH_SIZE
    half = PATCH_SIZE / 2
    norm_size = PATCH_SIZE / image_size

    for patch in selected_patches:
        position = None
        for _ in range(max(1, max_attempts)):
            x = random.randint(0, max_offset)
            y = random.randint(0, max_offset)
            # Boxes are half-open [x1, x2) x [y1, y2). Two boxes overlap only
            # when their intervals intersect on both axes, so touching boxes
            # do not count as overlapping.
            collides = any(
                x < bx2 and bx1 < x + PATCH_SIZE and y < by2 and by1 < y + PATCH_SIZE
                for bx1, by1, bx2, by2 in pixel_bboxes
            )
            if not collides:
                position = (x, y)
                break

        if position is None:
            # No free spot found; drop this patch rather than overlap.
            continue

        x, y = position
        image[y:y + PATCH_SIZE, x:x + PATCH_SIZE] = _fit_patch(patch, PATCH_SIZE)

        label_bboxes.append((
            class_id,
            (x + half) / image_size,
            (y + half) / image_size,
            norm_size,
            norm_size,
        ))
        pixel_bboxes.append((x, y, x + PATCH_SIZE, y + PATCH_SIZE))

    return image, label_bboxes
|
|
|
|
def _load_class_patches(class_name):
    """Load and decode every synthetic patch image for one class.

    Reads ``*.jpg`` and ``*.png`` files from ``SYNTHETIC_ROOT / class_name``
    in sorted order, so that a seeded run is reproducible regardless of
    filesystem ordering.

    Returns:
        List of decoded patch arrays (possibly empty), or ``None`` if the
        class directory does not exist.
    """
    class_dir = SYNTHETIC_ROOT / class_name
    if not class_dir.exists():
        print(f"⚠ No synthetic data for {class_name}")
        return None

    patch_files = sorted(
        list(class_dir.glob("*.jpg")) + list(class_dir.glob("*.png"))
    )
    print(f"\nLoading synthetic patches for {class_name}...")
    print(f" Found: {len(patch_files)} patches")

    patches = []
    unreadable = 0
    for patch_file in patch_files:
        # cv2.imread normally signals failure by returning None; catch only
        # OpenCV's own error type for the rare case where it raises.
        try:
            patch = cv2.imread(str(patch_file))
        except cv2.error:
            patch = None
        if patch is None:
            unreadable += 1
        else:
            patches.append(patch)

    if unreadable:
        print(f" ⚠ Skipped {unreadable} unreadable file(s)")
    if patches:
        print(f" Loaded: {len(patches)} patches")
    return patches


def _next_image_id():
    """Return an ID strictly greater than every numeric image or label stem.

    Label files are scanned too, so an orphan label can never be silently
    overwritten by a newly generated sample.
    """
    existing = list(TRAIN_IMAGES.glob("*.jpg")) + list(TRAIN_IMAGES.glob("*.png"))
    if TRAIN_LABELS.exists():
        existing += list(TRAIN_LABELS.glob("*.txt"))
    # isdecimal (unlike isdigit) only accepts characters int() can parse.
    numeric_ids = [int(p.stem) for p in existing if p.stem.isdecimal()]
    return max(numeric_ids, default=0) + 1


def _write_yolo_label(label_path, bboxes):
    """Write *bboxes* as YOLO lines: ``class x_center y_center width height``."""
    with open(label_path, "w", encoding="utf-8") as f:
        f.writelines(
            f"{box_class} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}\n"
            for box_class, xc, yc, w, h in bboxes
        )


def integrate_synthetic_data(num_synthetic_per_class):
    """Integrate synthetic data into training dataset.

    For each class in ``CLASS_NAMES`` that has GAN patches under
    ``SYNTHETIC_ROOT``, composes ``num_synthetic_per_class`` images with
    ``create_synthetic_yolo_image``. Images are written to ``TRAIN_IMAGES`` as
    ``<id>.jpg`` and labels to ``TRAIN_LABELS`` as ``<id>.txt``. IDs continue
    after the highest existing numeric ID.

    Args:
        num_synthetic_per_class: Number of images to generate per class.

    Exits the process with status 1 if the synthetic root or the training
    image directory is missing, or if no patches could be loaded.
    """
    print("\n" + "=" * 70)
    print("INTEGRATING GAN SYNTHETIC DATA INTO YOLO DATASET")
    print("=" * 70)

    if not SYNTHETIC_ROOT.exists():
        print(f"✗ Synthetic data not found: {SYNTHETIC_ROOT}")
        print(" Run 'python gan_train.py' first")
        sys.exit(1)

    if not TRAIN_IMAGES.exists():
        print(f"✗ Training images directory not found: {TRAIN_IMAGES}")
        sys.exit(1)

    synthetic_patches = {}
    for class_name in CLASS_NAMES:
        patches = _load_class_patches(class_name)
        if patches:
            synthetic_patches[class_name] = patches

    if not synthetic_patches:
        print("✗ No synthetic patches loaded")
        sys.exit(1)

    print(f"\n{'=' * 70}")
    print("GENERATING SYNTHETIC YOLO IMAGES")
    print(f"{'=' * 70}")

    # The labels directory may not exist yet in a freshly prepared dataset.
    TRAIN_LABELS.mkdir(parents=True, exist_ok=True)
    next_id = _next_image_id()
    total_generated = 0

    for class_name, class_id in CLASS_NAMES.items():
        patches = synthetic_patches.get(class_name)
        if not patches:
            continue

        print(f"\nGenerating for {class_name}...")
        generated = 0

        for i in range(num_synthetic_per_class):
            image, bboxes = create_synthetic_yolo_image(patches, class_id)

            image_id = f"{next_id:06d}"
            next_id += 1

            image_path = TRAIN_IMAGES / f"{image_id}.jpg"
            # cv2.imwrite reports failure via its return value, not an
            # exception. Skip the label so no orphan label file is written.
            if not cv2.imwrite(str(image_path), image):
                print(f" ✗ Failed to write {image_path}")
                continue

            _write_yolo_label(TRAIN_LABELS / f"{image_id}.txt", bboxes)
            generated += 1

            if (i + 1) % 100 == 0:
                print(f" Progress: {i + 1}/{num_synthetic_per_class}")

        total_generated += generated
        print(f" ✓ Generated {generated} images for {class_name}")

    image_count = (
        len(list(TRAIN_IMAGES.glob("*.jpg"))) + len(list(TRAIN_IMAGES.glob("*.png")))
    )
    print(f"\n{'=' * 70}")
    print("INTEGRATION COMPLETE")
    print(f"{'=' * 70}")
    print(f"Total synthetic images generated: {total_generated}")
    print(f"New training set size: {image_count}")
    print("\nTo train with synthetic data:")
    print(" python train_road_anomaly_model.py")
    print(f"{'=' * 70}\n")
|
|
|
|
def main():
    """Parse command-line arguments and run the synthetic-data integration.

    Options:
        --num-synthetic: images to generate per class (non-negative,
            default 100).
        --seed: optional seed for ``random``, making patch selection and
            placement reproducible.
    """
    parser = argparse.ArgumentParser(
        description="Integrate GAN synthetic data into YOLO training dataset"
    )
    parser.add_argument(
        "--num-synthetic",
        type=int,
        default=100,
        help="Number of synthetic images per class (default: 100)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible generation (default: unseeded)",
    )

    args = parser.parse_args()

    # A negative count would silently generate nothing; reject it up front.
    if args.num_synthetic < 0:
        parser.error("--num-synthetic must be non-negative")

    if args.seed is not None:
        random.seed(args.seed)

    integrate_synthetic_data(args.num_synthetic)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()
|
|