#!/usr/bin/env python3
"""
GAN Training for Road Anomaly Patch Generation.

Trains a GAN on extracted patches from gan_dataset/ to generate synthetic
training data for improving YOLO model performance through augmentation.

Usage:
    python gan_train.py [--epochs 50] [--batch-size 32] [--class pothole]
"""

import argparse
import sys
from pathlib import Path

import cv2
import numpy as np

# Configuration
GAN_DATASET_ROOT = Path("/home/pragadeesh/ARM/model/gan_dataset")
OUTPUT_DIR = Path("/home/pragadeesh/ARM/model/gan_output")
SYNTHETIC_DIR = Path("/home/pragadeesh/ARM/model/dataset/synthetic")
PATCH_SIZE = 64    # generated patches are PATCH_SIZE x PATCH_SIZE, 3 channels
LATENT_DIM = 100   # dimensionality of the generator's noise input


class SimpleGAN:
    """Fully-connected GAN for 64x64 RGB patches, implemented in NumPy.

    Generator:      z(latent_dim) -> 256 -> 512 -> H*W*3, tanh output in [-1, 1].
    Discriminator:  H*W*3 -> 512 -> 256 -> 1, sigmoid "probability real".

    Both networks are trained with plain SGD on the standard binary
    cross-entropy GAN objective (non-saturating generator loss).
    NOTE(review): the previous version computed losses but never updated
    any weights; this version performs real gradient steps.
    """

    def __init__(self, latent_dim=LATENT_DIM, patch_size=PATCH_SIZE):
        """Initialize generator and discriminator weight matrices."""
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.channels = 3
        flat_dim = patch_size * patch_size * self.channels

        # Generator weights (small random init, zero biases)
        self.gen_dense1_w = np.random.randn(latent_dim, 256) * 0.02
        self.gen_dense1_b = np.zeros(256)
        self.gen_dense2_w = np.random.randn(256, 512) * 0.02
        self.gen_dense2_b = np.zeros(512)
        self.gen_dense3_w = np.random.randn(512, flat_dim) * 0.02
        self.gen_dense3_b = np.zeros(flat_dim)

        # Discriminator weights
        self.dis_dense1_w = np.random.randn(flat_dim, 512) * 0.02
        self.dis_dense1_b = np.zeros(512)
        self.dis_dense2_w = np.random.randn(512, 256) * 0.02
        self.dis_dense2_b = np.zeros(256)
        self.dis_dense3_w = np.random.randn(256, 1) * 0.02
        self.dis_dense3_b = np.zeros(1)

        self.learning_rate = 0.0002

    @staticmethod
    def relu(x):
        """ReLU activation."""
        return np.maximum(0, x)

    @staticmethod
    def relu_derivative(x):
        """ReLU derivative (valid on either pre- or post-activation values)."""
        return (x > 0).astype(float)

    @staticmethod
    def sigmoid(x):
        """Sigmoid activation, input clipped for numerical stability."""
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

    @staticmethod
    def tanh(x):
        """Tanh activation."""
        return np.tanh(x)

    @staticmethod
    def tanh_derivative(x):
        """Tanh derivative."""
        return 1 - np.tanh(x) ** 2

    def generate(self, batch_size):
        """Generate a batch of synthetic patches.

        Returns:
            (images, z, h1, h2): images shaped (B, H, W, C) with values in
            [-1, 1], plus the latent codes and hidden activations (these
            are required by the generator's backward pass).
        """
        z = np.random.randn(batch_size, self.latent_dim)

        # Generator forward pass
        h1 = self.relu(np.dot(z, self.gen_dense1_w) + self.gen_dense1_b)
        h2 = self.relu(np.dot(h1, self.gen_dense2_w) + self.gen_dense2_b)
        output = self.tanh(np.dot(h2, self.gen_dense3_w) + self.gen_dense3_b)

        # Reshape to image format
        images = output.reshape(batch_size, self.patch_size, self.patch_size, self.channels)
        return images, z, h1, h2

    def discriminate(self, images):
        """Score images as real (toward 1) or fake (toward 0).

        Returns:
            (output, h1, h2, flat): sigmoid scores shaped (B, 1), the hidden
            activations, and the flattened input (required for backprop).
        """
        batch_size = images.shape[0]
        flat = images.reshape(batch_size, -1)

        # Discriminator forward pass
        h1 = self.relu(np.dot(flat, self.dis_dense1_w) + self.dis_dense1_b)
        h2 = self.relu(np.dot(h1, self.dis_dense2_w) + self.dis_dense2_b)
        output = self.sigmoid(np.dot(h2, self.dis_dense3_w) + self.dis_dense3_b)
        return output, h1, h2, flat

    def train_discriminator(self, real_images, batch_size):
        """Run one SGD step on the discriminator (labels: real=1, fake=0).

        Handles a real batch smaller than `batch_size` (last partial batch).

        Returns:
            (total_loss, real_loss, fake_loss) as floats.
        """
        # Generate fake images (generator weights are not touched here).
        fake_images, _, _, _ = self.generate(batch_size)

        # Single combined forward pass over real + fake samples.
        n_real = real_images.shape[0]
        images = np.concatenate([real_images, fake_images], axis=0)
        labels = np.concatenate(
            [np.ones((n_real, 1)), np.zeros((batch_size, 1))], axis=0
        )
        preds, h1, h2, flat = self.discriminate(images)

        # Binary cross-entropy losses (epsilon avoids log(0)).
        real_preds = preds[:n_real]
        fake_preds = preds[n_real:]
        real_loss = -np.mean(np.log(real_preds + 1e-8))
        fake_loss = -np.mean(np.log(1 - fake_preds + 1e-8))
        total_loss = real_loss + fake_loss

        # Backward pass. For sigmoid + BCE, d(loss)/d(logit) = p - y.
        n = images.shape[0]
        d_logit = (preds - labels) / n
        grad_w3 = np.dot(h2.T, d_logit)
        grad_b3 = d_logit.sum(axis=0)
        d_h2 = np.dot(d_logit, self.dis_dense3_w.T) * self.relu_derivative(h2)
        grad_w2 = np.dot(h1.T, d_h2)
        grad_b2 = d_h2.sum(axis=0)
        d_h1 = np.dot(d_h2, self.dis_dense2_w.T) * self.relu_derivative(h1)
        grad_w1 = np.dot(flat.T, d_h1)
        grad_b1 = d_h1.sum(axis=0)

        # SGD update
        lr = self.learning_rate
        self.dis_dense3_w -= lr * grad_w3
        self.dis_dense3_b -= lr * grad_b3
        self.dis_dense2_w -= lr * grad_w2
        self.dis_dense2_b -= lr * grad_b2
        self.dis_dense1_w -= lr * grad_w1
        self.dis_dense1_b -= lr * grad_b1

        return total_loss, real_loss, fake_loss

    def train_generator(self, batch_size):
        """Run one SGD step on the generator.

        Uses the non-saturating loss -log D(G(z)); gradients flow through
        the discriminator (whose weights are left unchanged) back into the
        generator weights.

        Returns:
            gen_loss as a float.
        """
        fake_images, z, g_h1, g_h2 = self.generate(batch_size)
        preds, d_h1, d_h2, flat = self.discriminate(fake_images)

        # Loss: how well the generator fools the discriminator.
        gen_loss = -np.mean(np.log(preds + 1e-8))

        # For a sigmoid output, d(-log p)/d(logit) = p - 1.
        d_logit = (preds - 1.0) / batch_size

        # Backprop through the frozen discriminator down to the pixels.
        dd_h2 = np.dot(d_logit, self.dis_dense3_w.T) * self.relu_derivative(d_h2)
        dd_h1 = np.dot(dd_h2, self.dis_dense2_w.T) * self.relu_derivative(d_h1)
        d_flat = np.dot(dd_h1, self.dis_dense1_w.T)

        # `flat` is exactly the generator's tanh output (flattened), so the
        # tanh derivative can be computed directly from it: 1 - tanh^2.
        d_out = d_flat * (1.0 - flat ** 2)

        # Backprop through the generator layers.
        grad_w3 = np.dot(g_h2.T, d_out)
        grad_b3 = d_out.sum(axis=0)
        dg_h2 = np.dot(d_out, self.gen_dense3_w.T) * self.relu_derivative(g_h2)
        grad_w2 = np.dot(g_h1.T, dg_h2)
        grad_b2 = dg_h2.sum(axis=0)
        dg_h1 = np.dot(dg_h2, self.gen_dense2_w.T) * self.relu_derivative(g_h1)
        grad_w1 = np.dot(z.T, dg_h1)
        grad_b1 = dg_h1.sum(axis=0)

        # SGD update
        lr = self.learning_rate
        self.gen_dense3_w -= lr * grad_w3
        self.gen_dense3_b -= lr * grad_b3
        self.gen_dense2_w -= lr * grad_w2
        self.gen_dense2_b -= lr * grad_b2
        self.gen_dense1_w -= lr * grad_w1
        self.gen_dense1_b -= lr * grad_b1

        return gen_loss


def load_patches(class_name):
    """Load and normalize all .jpg patches for a class.

    Each patch is resized to PATCH_SIZE x PATCH_SIZE (the extractor may emit
    varied sizes; without this, np.array() would build a ragged object array
    that breaks batching) and normalized to [-1, 1] to match the generator's
    tanh output range.

    Returns:
        float32 array shaped (N, PATCH_SIZE, PATCH_SIZE, 3), or None when
        the class directory is missing or no patch could be read.
    """
    class_dir = GAN_DATASET_ROOT / class_name
    if not class_dir.exists():
        print(f"✗ Class directory not found: {class_dir}")
        return None

    patches = []
    patch_files = sorted(class_dir.glob("*.jpg"))
    print(f"Loading {len(patch_files)} patches for {class_name}...")

    for patch_file in patch_files:
        patch = cv2.imread(str(patch_file))
        if patch is not None:
            # Enforce a uniform spatial size before stacking.
            if patch.shape[:2] != (PATCH_SIZE, PATCH_SIZE):
                patch = cv2.resize(patch, (PATCH_SIZE, PATCH_SIZE))
            # Normalize to [-1, 1]
            patch = patch.astype(np.float32) / 127.5 - 1.0
            patches.append(patch)

    return np.array(patches) if patches else None


def save_sample_images(gan, epoch, class_name):
    """Save a few generated sample images for visual progress inspection."""
    output_class_dir = OUTPUT_DIR / class_name / "samples"
    output_class_dir.mkdir(parents=True, exist_ok=True)

    # Generate samples
    fake_images, _, _, _ = gan.generate(4)

    for i, img in enumerate(fake_images):
        # Denormalize from [-1, 1] to [0, 255]
        img_uint8 = ((img + 1.0) * 127.5).astype(np.uint8)
        output_path = output_class_dir / f"epoch_{epoch:04d}_sample_{i}.jpg"
        cv2.imwrite(str(output_path), img_uint8)


def train_gan(class_name, epochs, batch_size):
    """Train a GAN for one class, then dump synthetic patches to disk.

    Args:
        class_name: one of the gan_dataset/ subdirectory names.
        epochs: number of passes over the patch set.
        batch_size: mini-batch size for both networks.
    """
    print(f"\n{'='*70}")
    print(f"Training GAN for: {class_name.upper()}")
    print(f"{'='*70}")

    # Load patches
    patches = load_patches(class_name)
    if patches is None or len(patches) == 0:
        print(f"✗ No patches found for {class_name}")
        return

    print(f"✓ Loaded {len(patches)} patches")
    print(f"  Shape: {patches.shape}")
    print(f"  Range: [{patches.min():.2f}, {patches.max():.2f}]")

    # Initialize GAN
    gan = SimpleGAN(latent_dim=LATENT_DIM, patch_size=PATCH_SIZE)

    # Training loop
    print(f"\nTraining for {epochs} epochs...")
    for epoch in range(epochs):
        # Shuffle patches
        indices = np.random.permutation(len(patches))

        epoch_d_loss = 0
        epoch_g_loss = 0
        # max(1, ...) guards against ZeroDivisionError when the dataset is
        # smaller than one batch; the single batch then holds all patches.
        num_batches = max(1, len(patches) // batch_size)

        for batch_idx in range(num_batches):
            # Get batch
            batch_indices = indices[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            real_batch = patches[batch_indices]

            # Train discriminator
            d_loss, d_real_loss, d_fake_loss = gan.train_discriminator(real_batch, batch_size)
            epoch_d_loss += d_loss

            # Train generator
            g_loss = gan.train_generator(batch_size)
            epoch_g_loss += g_loss

        # Average losses
        epoch_d_loss /= num_batches
        epoch_g_loss /= num_batches

        # Print progress
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch + 1}/{epochs} | D Loss: {epoch_d_loss:.4f} | G Loss: {epoch_g_loss:.4f}")

        # Save samples
        if (epoch + 1) % 10 == 0:
            save_sample_images(gan, epoch + 1, class_name)

    print(f"✓ Training complete for {class_name}")

    # Generate synthetic data
    print(f"\nGenerating synthetic patches for {class_name}...")
    num_synthetic = len(patches)  # Generate same number as originals

    synthetic_dir = SYNTHETIC_DIR / class_name
    synthetic_dir.mkdir(parents=True, exist_ok=True)

    # Generate in batches
    num_batches = (num_synthetic + batch_size - 1) // batch_size
    saved_count = 0

    for batch_idx in range(num_batches):
        batch_count = min(batch_size, num_synthetic - batch_idx * batch_size)
        fake_images, _, _, _ = gan.generate(batch_count)

        for i, img in enumerate(fake_images):
            # Denormalize
            img_uint8 = ((img + 1.0) * 127.5).astype(np.uint8)
            output_path = synthetic_dir / f"synthetic_{saved_count:06d}.jpg"
            cv2.imwrite(str(output_path), img_uint8)
            saved_count += 1

    print(f"✓ Generated {saved_count} synthetic patches for {class_name}")
    print(f"  Saved to: {synthetic_dir}")


def main():
    """Main entry point: parse CLI args and train per-class GANs."""
    parser = argparse.ArgumentParser(
        description="Train GAN on road anomaly patches"
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=50,
        help="Number of training epochs (default: 50)"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size (default: 32)"
    )
    parser.add_argument(
        "--class",
        dest="class_name",
        choices=["pothole", "cracks", "open_manhole", "all"],
        default="all",
        help="Which class to train (default: all)"
    )

    args = parser.parse_args()

    print("\n" + "="*70)
    print("GAN TRAINING FOR ROAD ANOMALY PATCH GENERATION")
    print("="*70)
    print(f"Dataset:    {GAN_DATASET_ROOT}")
    print(f"Output:     {OUTPUT_DIR}")
    print(f"Synthetic:  {SYNTHETIC_DIR}")
    print(f"Epochs:     {args.epochs}")
    print(f"Batch size: {args.batch_size}")

    # Check dataset exists
    if not GAN_DATASET_ROOT.exists():
        print(f"\n✗ GAN dataset not found: {GAN_DATASET_ROOT}")
        print("  Run 'python gan.py' first to extract patches")
        sys.exit(1)

    # Train GANs
    classes = ["pothole", "cracks", "open_manhole"] if args.class_name == "all" else [args.class_name]

    for class_name in classes:
        train_gan(class_name, args.epochs, args.batch_size)

    print("\n" + "="*70)
    print("SYNTHETIC DATA GENERATION COMPLETE")
    print("="*70)
    print(f"\nTo augment YOLO dataset:")
    print(f"  1. Synthetic patches saved to: {SYNTHETIC_DIR}")
    print(f"  2. Copy to dataset/train/images/ for augmentation")
    print(f"  3. Run: python train_road_anomaly_model.py")
    print("="*70 + "\n")


if __name__ == "__main__":
    main()