"""
GAN Training for Road Anomaly Patch Generation.

Trains a GAN on extracted patches from gan_dataset/ to generate synthetic
training data for improving YOLO model performance through augmentation.

Usage:
    python gan_train.py [--epochs 50] [--batch-size 32] [--class pothole]
"""
|
|
| import os |
| import sys |
| import argparse |
| from pathlib import Path |
| import numpy as np |
| import cv2 |
| from collections import defaultdict |
|
|
| |
| GAN_DATASET_ROOT = Path("/home/pragadeesh/ARM/model/gan_dataset") |
| OUTPUT_DIR = Path("/home/pragadeesh/ARM/model/gan_output") |
| SYNTHETIC_DIR = Path("/home/pragadeesh/ARM/model/dataset/synthetic") |
|
|
| PATCH_SIZE = 64 |
| LATENT_DIM = 100 |
|
|
|
|
| class SimpleGAN: |
| """Simple GAN for generating 64x64 patches using NumPy.""" |
| |
| def __init__(self, latent_dim=LATENT_DIM, patch_size=PATCH_SIZE): |
| """Initialize GAN with generator and discriminator.""" |
| self.latent_dim = latent_dim |
| self.patch_size = patch_size |
| self.channels = 3 |
| |
| |
| self.gen_dense1_w = np.random.randn(latent_dim, 256) * 0.02 |
| self.gen_dense1_b = np.zeros(256) |
| self.gen_dense2_w = np.random.randn(256, 512) * 0.02 |
| self.gen_dense2_b = np.zeros(512) |
| self.gen_dense3_w = np.random.randn(512, patch_size * patch_size * self.channels) * 0.02 |
| self.gen_dense3_b = np.zeros(patch_size * patch_size * self.channels) |
| |
| |
| self.dis_dense1_w = np.random.randn(patch_size * patch_size * self.channels, 512) * 0.02 |
| self.dis_dense1_b = np.zeros(512) |
| self.dis_dense2_w = np.random.randn(512, 256) * 0.02 |
| self.dis_dense2_b = np.zeros(256) |
| self.dis_dense3_w = np.random.randn(256, 1) * 0.02 |
| self.dis_dense3_b = np.zeros(1) |
| |
| self.learning_rate = 0.0002 |
| |
| @staticmethod |
| def relu(x): |
| """ReLU activation.""" |
| return np.maximum(0, x) |
| |
| @staticmethod |
| def relu_derivative(x): |
| """ReLU derivative.""" |
| return (x > 0).astype(float) |
| |
| @staticmethod |
| def sigmoid(x): |
| """Sigmoid activation.""" |
| return 1 / (1 + np.exp(-np.clip(x, -500, 500))) |
| |
| @staticmethod |
| def tanh(x): |
| """Tanh activation.""" |
| return np.tanh(x) |
| |
| @staticmethod |
| def tanh_derivative(x): |
| """Tanh derivative.""" |
| return 1 - np.tanh(x) ** 2 |
| |
| def generate(self, batch_size): |
| """Generate synthetic patches.""" |
| z = np.random.randn(batch_size, self.latent_dim) |
| |
| |
| h1 = self.relu(np.dot(z, self.gen_dense1_w) + self.gen_dense1_b) |
| h2 = self.relu(np.dot(h1, self.gen_dense2_w) + self.gen_dense2_b) |
| output = self.tanh(np.dot(h2, self.gen_dense3_w) + self.gen_dense3_b) |
| |
| |
| images = output.reshape(batch_size, self.patch_size, self.patch_size, self.channels) |
| return images, z, h1, h2 |
| |
| def discriminate(self, images): |
| """Discriminate real vs fake images.""" |
| batch_size = images.shape[0] |
| flat = images.reshape(batch_size, -1) |
| |
| |
| h1 = self.relu(np.dot(flat, self.dis_dense1_w) + self.dis_dense1_b) |
| h2 = self.relu(np.dot(h1, self.dis_dense2_w) + self.dis_dense2_b) |
| output = self.sigmoid(np.dot(h2, self.dis_dense3_w) + self.dis_dense3_b) |
| |
| return output, h1, h2, flat |
| |
| def train_discriminator(self, real_images, batch_size): |
| """Train discriminator on real and fake images.""" |
| |
| fake_images, _, _, _ = self.generate(batch_size) |
| |
| |
| real_preds, _, _, real_flat = self.discriminate(real_images) |
| fake_preds, _, _, fake_flat = self.discriminate(fake_images) |
| |
| |
| real_loss = -np.mean(np.log(real_preds + 1e-8)) |
| fake_loss = -np.mean(np.log(1 - fake_preds + 1e-8)) |
| total_loss = real_loss + fake_loss |
| |
| return total_loss, real_loss, fake_loss |
| |
| def train_generator(self, batch_size): |
| """Train generator to fool discriminator.""" |
| fake_images, _, _, _ = self.generate(batch_size) |
| fake_preds, _, _, _ = self.discriminate(fake_images) |
| |
| |
| gen_loss = -np.mean(np.log(fake_preds + 1e-8)) |
| |
| return gen_loss |
|
|
|
|
| def load_patches(class_name): |
| """Load all patches for a class.""" |
| class_dir = GAN_DATASET_ROOT / class_name |
| |
| if not class_dir.exists(): |
| print(f"✗ Class directory not found: {class_dir}") |
| return None |
| |
| patches = [] |
| patch_files = sorted(class_dir.glob("*.jpg")) |
| |
| print(f"Loading {len(patch_files)} patches for {class_name}...") |
| |
| for patch_file in patch_files: |
| patch = cv2.imread(str(patch_file)) |
| if patch is not None: |
| |
| patch = patch.astype(np.float32) / 127.5 - 1.0 |
| patches.append(patch) |
| |
| return np.array(patches) if patches else None |
|
|
|
|
| def save_sample_images(gan, epoch, class_name): |
| """Save sample generated images.""" |
| output_class_dir = OUTPUT_DIR / class_name / "samples" |
| output_class_dir.mkdir(parents=True, exist_ok=True) |
| |
| |
| fake_images, _, _, _ = gan.generate(4) |
| |
| for i, img in enumerate(fake_images): |
| |
| img_uint8 = ((img + 1.0) * 127.5).astype(np.uint8) |
| |
| output_path = output_class_dir / f"epoch_{epoch:04d}_sample_{i}.jpg" |
| cv2.imwrite(str(output_path), img_uint8) |
|
|
|
|
| def train_gan(class_name, epochs, batch_size): |
| """Train GAN for a specific class.""" |
| print(f"\n{'='*70}") |
| print(f"Training GAN for: {class_name.upper()}") |
| print(f"{'='*70}") |
| |
| |
| patches = load_patches(class_name) |
| if patches is None or len(patches) == 0: |
| print(f"✗ No patches found for {class_name}") |
| return |
| |
| print(f"✓ Loaded {len(patches)} patches") |
| print(f" Shape: {patches.shape}") |
| print(f" Range: [{patches.min():.2f}, {patches.max():.2f}]") |
| |
| |
| gan = SimpleGAN(latent_dim=LATENT_DIM, patch_size=PATCH_SIZE) |
| |
| |
| print(f"\nTraining for {epochs} epochs...") |
| |
| for epoch in range(epochs): |
| |
| indices = np.random.permutation(len(patches)) |
| epoch_d_loss = 0 |
| epoch_g_loss = 0 |
| num_batches = len(patches) // batch_size |
| |
| for batch_idx in range(num_batches): |
| |
| batch_indices = indices[batch_idx * batch_size:(batch_idx + 1) * batch_size] |
| real_batch = patches[batch_indices] |
| |
| |
| d_loss, d_real_loss, d_fake_loss = gan.train_discriminator(real_batch, batch_size) |
| epoch_d_loss += d_loss |
| |
| |
| g_loss = gan.train_generator(batch_size) |
| epoch_g_loss += g_loss |
| |
| |
| epoch_d_loss /= num_batches |
| epoch_g_loss /= num_batches |
| |
| |
| if (epoch + 1) % 5 == 0: |
| print(f"Epoch {epoch + 1}/{epochs} | D Loss: {epoch_d_loss:.4f} | G Loss: {epoch_g_loss:.4f}") |
| |
| |
| if (epoch + 1) % 10 == 0: |
| save_sample_images(gan, epoch + 1, class_name) |
| |
| print(f"✓ Training complete for {class_name}") |
| |
| |
| print(f"\nGenerating synthetic patches for {class_name}...") |
| num_synthetic = len(patches) |
| |
| synthetic_dir = SYNTHETIC_DIR / class_name |
| synthetic_dir.mkdir(parents=True, exist_ok=True) |
| |
| |
| num_batches = (num_synthetic + batch_size - 1) // batch_size |
| saved_count = 0 |
| |
| for batch_idx in range(num_batches): |
| batch_count = min(batch_size, num_synthetic - batch_idx * batch_size) |
| fake_images, _, _, _ = gan.generate(batch_count) |
| |
| for i, img in enumerate(fake_images): |
| |
| img_uint8 = ((img + 1.0) * 127.5).astype(np.uint8) |
| |
| output_path = synthetic_dir / f"synthetic_{saved_count:06d}.jpg" |
| cv2.imwrite(str(output_path), img_uint8) |
| saved_count += 1 |
| |
| print(f"✓ Generated {saved_count} synthetic patches for {class_name}") |
| print(f" Saved to: {synthetic_dir}") |
|
|
|
|
| def main(): |
| """Main entry point.""" |
| parser = argparse.ArgumentParser( |
| description="Train GAN on road anomaly patches" |
| ) |
| parser.add_argument( |
| "--epochs", |
| type=int, |
| default=50, |
| help="Number of training epochs (default: 50)" |
| ) |
| parser.add_argument( |
| "--batch-size", |
| type=int, |
| default=32, |
| help="Batch size (default: 32)" |
| ) |
| parser.add_argument( |
| "--class", |
| dest="class_name", |
| choices=["pothole", "cracks", "open_manhole", "all"], |
| default="all", |
| help="Which class to train (default: all)" |
| ) |
| |
| args = parser.parse_args() |
| |
| print("\n" + "="*70) |
| print("GAN TRAINING FOR ROAD ANOMALY PATCH GENERATION") |
| print("="*70) |
| print(f"Dataset: {GAN_DATASET_ROOT}") |
| print(f"Output: {OUTPUT_DIR}") |
| print(f"Synthetic: {SYNTHETIC_DIR}") |
| print(f"Epochs: {args.epochs}") |
| print(f"Batch size: {args.batch_size}") |
| |
| |
| if not GAN_DATASET_ROOT.exists(): |
| print(f"\n✗ GAN dataset not found: {GAN_DATASET_ROOT}") |
| print(" Run 'python gan.py' first to extract patches") |
| sys.exit(1) |
| |
| |
| classes = ["pothole", "cracks", "open_manhole"] if args.class_name == "all" else [args.class_name] |
| |
| for class_name in classes: |
| train_gan(class_name, args.epochs, args.batch_size) |
| |
| print("\n" + "="*70) |
| print("SYNTHETIC DATA GENERATION COMPLETE") |
| print("="*70) |
| print(f"\nTo augment YOLO dataset:") |
| print(f" 1. Synthetic patches saved to: {SYNTHETIC_DIR}") |
| print(f" 2. Copy to dataset/train/images/ for augmentation") |
| print(f" 3. Run: python train_road_anomaly_model.py") |
| print("="*70 + "\n") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|