#!/usr/bin/env python3
# Source: arm-model/model/integrate_gan.py, uploaded by pragadeeshv23 via
# huggingface_hub (commit 5b86813, verified).
"""
Integrate GAN synthetic data into YOLO training dataset.
This script converts synthetic patches from GAN training into complete
YOLO-format images with annotations for dataset augmentation.
Usage:
python integrate_gan.py [--num-synthetic 1000]
"""
import os
import sys
import argparse
import random
from pathlib import Path
import cv2
import numpy as np
# Paths. The dataset root defaults to the original author's machine; set the
# ARM_DATASET_ROOT environment variable to run the script against another
# checkout without editing the file. All other paths are derived from it.
DATASET_ROOT = Path(
    os.environ.get("ARM_DATASET_ROOT", "/home/pragadeesh/ARM/model/dataset")
).expanduser()
SYNTHETIC_ROOT = DATASET_ROOT / "synthetic"  # one sub-directory of patches per class
TRAIN_IMAGES = DATASET_ROOT / "train" / "images"
TRAIN_LABELS = DATASET_ROOT / "train" / "labels"

# Class mapping: synthetic sub-directory name -> YOLO class id.
CLASS_NAMES = {
    "pothole": 0,
    "cracks": 1,
    "open_manhole": 2,
}

PATCH_SIZE = 64  # side length (pixels) of each square GAN patch
SYNTHETIC_IMAGE_SIZE = 416  # YOLO training size (square output image side)
def _overlaps_any(x, y, size, boxes):
    """Return True if the size x size square with top-left (x, y) intersects any box.

    Boxes are ``(x1, y1, x2, y2)`` with exclusive right/bottom edges, so two
    squares that merely share an edge are not considered overlapping.
    """
    return any(
        x < bx2 and bx1 < x + size and y < by2 and by1 < y + size
        for bx1, by1, bx2, by2 in boxes
    )


def create_synthetic_yolo_image(patches, class_id, image_size=SYNTHETIC_IMAGE_SIZE,
                                max_placement_attempts=10):
    """
    Create a YOLO-format image by composing GAN patches on a gray canvas.

    Between 2 and 5 patches (fewer if ``patches`` is shorter) are drawn
    without replacement and pasted at random, non-overlapping positions.
    A patch that still collides with an already-placed patch after
    ``max_placement_attempts`` random positions is dropped.

    Args:
        patches: List of patch arrays. Each must be assignable into a
            (PATCH_SIZE, PATCH_SIZE, 3) uint8 slice of the canvas.
        class_id: YOLO class id recorded for every placed patch.
        image_size: Side length of the square output image in pixels.
        max_placement_attempts: Random positions tried per patch before it
            is skipped (values below 1 are treated as 1).

    Returns:
        Tuple ``(image, label_bboxes)``: ``image`` is an
        (image_size, image_size, 3) uint8 array, and ``label_bboxes`` is a list
        of ``(class_id, x_center, y_center, width, height)`` tuples with
        coordinates normalized by ``image_size``.

    Raises:
        ValueError: If ``image_size`` is smaller than ``PATCH_SIZE``.
    """
    if image_size < PATCH_SIZE:
        raise ValueError(
            f"image_size ({image_size}) must be at least PATCH_SIZE ({PATCH_SIZE})"
        )

    image = np.full((image_size, image_size, 3), 128, dtype=np.uint8)  # gray background
    label_bboxes = []
    pixel_bboxes = []  # placed patches as (x1, y1, x2, y2) in pixels

    # Largest valid top-left coordinate; 0 when a patch exactly fills the image
    # (the old max(1, ...) could place a patch one pixel off the canvas).
    max_offset = image_size - PATCH_SIZE
    norm_size = PATCH_SIZE / image_size  # normalized width/height, same for every patch
    attempts = max(1, max_placement_attempts)

    num_patches = random.randint(2, 5)
    selected_patches = random.sample(patches, min(num_patches, len(patches)))

    for patch in selected_patches:
        for _ in range(attempts):
            x = random.randint(0, max_offset)
            y = random.randint(0, max_offset)
            if not _overlaps_any(x, y, PATCH_SIZE, pixel_bboxes):
                break
        else:
            # Every attempted position collided with a placed patch; drop it.
            continue

        image[y:y + PATCH_SIZE, x:x + PATCH_SIZE] = patch

        # YOLO labels use the box centre, normalized to [0, 1].
        x_center = (x + PATCH_SIZE / 2) / image_size
        y_center = (y + PATCH_SIZE / 2) / image_size
        label_bboxes.append((class_id, x_center, y_center, norm_size, norm_size))
        pixel_bboxes.append((x, y, x + PATCH_SIZE, y + PATCH_SIZE))

    return image, label_bboxes
def _read_patches(patch_files):
    """Decode *patch_files* with OpenCV and return PATCH_SIZE x PATCH_SIZE BGR arrays.

    Unreadable files are skipped and counted (reported once) rather than
    silently ignored. Patches of any other size are resized, because
    compositing pastes them into fixed PATCH_SIZE slots.
    """
    patches = []
    unreadable = 0
    for patch_file in patch_files:
        try:
            patch = cv2.imread(str(patch_file))
        except cv2.error:
            patch = None
        if patch is None:
            unreadable += 1
            continue
        if patch.shape[:2] != (PATCH_SIZE, PATCH_SIZE):
            patch = cv2.resize(patch, (PATCH_SIZE, PATCH_SIZE), interpolation=cv2.INTER_AREA)
        patches.append(patch)
    if unreadable:
        print(f"  ⚠ Skipped {unreadable} unreadable file(s)")
    return patches


def _next_free_id():
    """Return one more than the largest numeric file stem in the train split.

    Both the image and the label directories are scanned (any extension), so
    a newly generated file never overwrites an existing image or an orphaned
    label.
    """
    largest = 0
    for directory in (TRAIN_IMAGES, TRAIN_LABELS):
        if not directory.exists():
            continue
        for path in directory.iterdir():
            stem = path.stem
            # isascii() guard: str.isdigit() also accepts e.g. superscripts,
            # which int() rejects.
            if stem.isascii() and stem.isdigit():
                largest = max(largest, int(stem))
    return largest + 1


def _write_yolo_label(label_path, bboxes):
    """Write *bboxes* as YOLO ``class x_center y_center width height`` lines."""
    with open(label_path, "w", encoding="utf-8") as f:
        for box_class, x_center, y_center, width, height in bboxes:
            f.write(f"{box_class} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")


def integrate_synthetic_data(num_synthetic_per_class):
    """Integrate synthetic data into the training dataset.

    Loads GAN patches from ``SYNTHETIC_ROOT/<class_name>/*.jpg``, composes
    ``num_synthetic_per_class`` images per class with
    ``create_synthetic_yolo_image``, and writes them plus YOLO label files
    into ``TRAIN_IMAGES`` / ``TRAIN_LABELS`` under fresh numeric ids.

    Exits the process with status 1 if the input directories are missing,
    no patches could be loaded, or an image cannot be written.
    """
    print("\n" + "=" * 70)
    print("INTEGRATING GAN SYNTHETIC DATA INTO YOLO DATASET")
    print("=" * 70)

    # Check directories
    if not SYNTHETIC_ROOT.exists():
        print(f"✗ Synthetic data not found: {SYNTHETIC_ROOT}")
        print(" Run 'python gan_train.py' first")
        sys.exit(1)
    if not TRAIN_IMAGES.exists():
        print(f"✗ Training images directory not found: {TRAIN_IMAGES}")
        sys.exit(1)
    # The labels directory may not exist yet on a fresh split; open() would fail.
    TRAIN_LABELS.mkdir(parents=True, exist_ok=True)

    # Load synthetic patches per class
    synthetic_patches = {}
    for class_name in CLASS_NAMES:
        class_dir = SYNTHETIC_ROOT / class_name
        if not class_dir.exists():
            print(f"⚠ No synthetic data for {class_name}")
            continue
        patch_files = sorted(class_dir.glob("*.jpg"))  # sorted for a stable order
        print(f"\nLoading synthetic patches for {class_name}...")
        print(f" Found: {len(patch_files)} patches")
        patches = _read_patches(patch_files)
        if patches:
            synthetic_patches[class_name] = patches
            print(f" Loaded: {len(patches)} patches")

    if not synthetic_patches:
        print("✗ No synthetic patches loaded")
        sys.exit(1)

    # Generate synthetic YOLO images
    print(f"\n{'=' * 70}")
    print("GENERATING SYNTHETIC YOLO IMAGES")
    print(f"{'=' * 70}")

    next_id = _next_free_id()
    total_generated = 0
    for class_name, class_id in CLASS_NAMES.items():
        patches = synthetic_patches.get(class_name)
        if not patches:
            continue
        print(f"\nGenerating for {class_name}...")
        for i in range(num_synthetic_per_class):
            image, bboxes = create_synthetic_yolo_image(patches, class_id)

            image_id = f"{next_id:06d}"
            image_path = TRAIN_IMAGES / f"{image_id}.jpg"
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(str(image_path), image):
                print(f"✗ Failed to write image: {image_path}")
                sys.exit(1)
            _write_yolo_label(TRAIN_LABELS / f"{image_id}.txt", bboxes)

            next_id += 1
            total_generated += 1
            if (i + 1) % 100 == 0:
                print(f" Progress: {i+1}/{num_synthetic_per_class}")
        print(f" ✓ Generated {num_synthetic_per_class} images for {class_name}")

    print(f"\n{'=' * 70}")
    print("INTEGRATION COMPLETE")
    print(f"{'=' * 70}")
    print(f"Total synthetic images generated: {total_generated}")
    print(f"New training set size: {len(list(TRAIN_IMAGES.glob('*')))}")
    print("\nTo train with synthetic data:")
    print(" python train_road_anomaly_model.py")
    print(f"{'=' * 70}\n")
def main():
    """Parse command-line arguments and run the integration.

    ``--num-synthetic`` must be non-negative. A negative value is rejected
    through ``parser.error``, which exits with status 2. The optional
    ``--seed`` seeds Python's ``random`` module so patch selection and
    placement are reproducible.
    """
    parser = argparse.ArgumentParser(
        description="Integrate GAN synthetic data into YOLO training dataset"
    )
    parser.add_argument(
        "--num-synthetic",
        type=int,
        default=100,
        help="Number of synthetic images per class (default: 100)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible patch placement (default: unseeded)"
    )
    args = parser.parse_args()
    if args.num_synthetic < 0:
        # Otherwise nothing is generated but the log claims "Generated -N images".
        parser.error("--num-synthetic must be non-negative")
    if args.seed is not None:
        random.seed(args.seed)
    integrate_synthetic_data(args.num_synthetic)
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script,
    # not when this module is imported.
    main()