# Source: arm-model/model/gan_train.py — uploaded via huggingface_hub by pragadeeshv23 (revision 5b86813, verified)
#!/usr/bin/env python3
"""
GAN Training for Road Anomaly Patch Generation.
Trains a GAN on extracted patches from gan_dataset/ to generate synthetic
training data for improving YOLO model performance through augmentation.
Usage:
python gan_train.py [--epochs 50] [--batch-size 32] [--class pothole]
"""
import os
import sys
import argparse
from pathlib import Path
import numpy as np
import cv2
from collections import defaultdict
# Configuration
# Directory of real patches extracted by gan.py (one subdirectory per class).
GAN_DATASET_ROOT = Path("/home/pragadeesh/ARM/model/gan_dataset")
# Where periodic generator sample images are written during training.
OUTPUT_DIR = Path("/home/pragadeesh/ARM/model/gan_output")
# Destination for the final synthetic patches used to augment the YOLO dataset.
SYNTHETIC_DIR = Path("/home/pragadeesh/ARM/model/dataset/synthetic")
# Side length (pixels) of the square RGB patches the GAN generates.
PATCH_SIZE = 64
# Dimensionality of the generator's latent noise vector.
LATENT_DIM = 100
class SimpleGAN:
    """Fully-connected GAN for square RGB patches, implemented in pure NumPy.

    Generator:     z (latent_dim) -> 256 -> 512 -> patch pixels, tanh in [-1, 1].
    Discriminator: patch pixels -> 512 -> 256 -> 1, sigmoid "real" probability.

    BUG FIX: the original ``train_discriminator`` / ``train_generator`` only
    computed losses and never updated any weights, so training was a no-op and
    the "GAN" could only ever emit noise from its random initialization.  Both
    methods now perform proper backpropagation and a vanilla SGD step on the
    standard binary cross-entropy GAN objective (non-saturating -log D(G(z))
    loss for the generator).  Method signatures and return values are unchanged.
    """

    def __init__(self, latent_dim=100, patch_size=64):
        """Initialize generator and discriminator weights.

        Defaults mirror the module-level LATENT_DIM / PATCH_SIZE constants.
        """
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.channels = 3
        # Generator weights (small gaussian init, scale 0.02).
        self.gen_dense1_w = np.random.randn(latent_dim, 256) * 0.02
        self.gen_dense1_b = np.zeros(256)
        self.gen_dense2_w = np.random.randn(256, 512) * 0.02
        self.gen_dense2_b = np.zeros(512)
        self.gen_dense3_w = np.random.randn(512, patch_size * patch_size * self.channels) * 0.02
        self.gen_dense3_b = np.zeros(patch_size * patch_size * self.channels)
        # Discriminator weights.
        self.dis_dense1_w = np.random.randn(patch_size * patch_size * self.channels, 512) * 0.02
        self.dis_dense1_b = np.zeros(512)
        self.dis_dense2_w = np.random.randn(512, 256) * 0.02
        self.dis_dense2_b = np.zeros(256)
        self.dis_dense3_w = np.random.randn(256, 1) * 0.02
        self.dis_dense3_b = np.zeros(1)
        self.learning_rate = 0.0002

    @staticmethod
    def relu(x):
        """ReLU activation."""
        return np.maximum(0, x)

    @staticmethod
    def relu_derivative(x):
        """ReLU derivative (w.r.t. the pre-activation)."""
        return (x > 0).astype(float)

    @staticmethod
    def sigmoid(x):
        """Sigmoid activation (input clipped to avoid overflow in exp)."""
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

    @staticmethod
    def tanh(x):
        """Tanh activation."""
        return np.tanh(x)

    @staticmethod
    def tanh_derivative(x):
        """Tanh derivative (w.r.t. the pre-activation)."""
        return 1 - np.tanh(x) ** 2

    def generate(self, batch_size):
        """Sample `batch_size` synthetic patches from random latent vectors.

        Returns (images, z, h1, h2) where the hidden activations are kept so
        the generator update can reuse this forward pass.
        """
        z = np.random.randn(batch_size, self.latent_dim)
        h1 = self.relu(np.dot(z, self.gen_dense1_w) + self.gen_dense1_b)
        h2 = self.relu(np.dot(h1, self.gen_dense2_w) + self.gen_dense2_b)
        output = self.tanh(np.dot(h2, self.gen_dense3_w) + self.gen_dense3_b)
        # Reshape the flat pixel vector to (B, H, W, C) image format.
        images = output.reshape(batch_size, self.patch_size, self.patch_size, self.channels)
        return images, z, h1, h2

    def discriminate(self, images):
        """Score images as real (->1) or fake (->0).

        Returns (probabilities, h1, h2, flat); activations and the flattened
        input are kept for backpropagation.
        """
        batch_size = images.shape[0]
        flat = images.reshape(batch_size, -1)
        h1 = self.relu(np.dot(flat, self.dis_dense1_w) + self.dis_dense1_b)
        h2 = self.relu(np.dot(h1, self.dis_dense2_w) + self.dis_dense2_b)
        output = self.sigmoid(np.dot(h2, self.dis_dense3_w) + self.dis_dense3_b)
        return output, h1, h2, flat

    def _dis_gradients(self, d_pre3, h1, h2, flat):
        """Backprop a dL/d(pre-sigmoid) signal through the discriminator.

        Returns (grads, d_flat): `grads` maps parameter names to gradients,
        `d_flat` is the loss gradient w.r.t. the flattened input image (needed
        when training the generator through a frozen discriminator).
        """
        d_w3 = np.dot(h2.T, d_pre3)
        d_b3 = d_pre3.sum(axis=0)
        # h > 0 is equivalent to the ReLU pre-activation being positive.
        d_h2 = np.dot(d_pre3, self.dis_dense3_w.T) * (h2 > 0)
        d_w2 = np.dot(h1.T, d_h2)
        d_b2 = d_h2.sum(axis=0)
        d_h1 = np.dot(d_h2, self.dis_dense2_w.T) * (h1 > 0)
        d_w1 = np.dot(flat.T, d_h1)
        d_b1 = d_h1.sum(axis=0)
        d_flat = np.dot(d_h1, self.dis_dense1_w.T)
        grads = {"w1": d_w1, "b1": d_b1, "w2": d_w2, "b2": d_b2, "w3": d_w3, "b3": d_b3}
        return grads, d_flat

    def train_discriminator(self, real_images, batch_size):
        """One SGD step on the discriminator using a real and a fake batch.

        Returns (total_loss, real_loss, fake_loss) as before; unlike the
        original, the discriminator weights are actually updated.
        """
        fake_images, _, _, _ = self.generate(batch_size)
        real_preds, r_h1, r_h2, real_flat = self.discriminate(real_images)
        fake_preds, f_h1, f_h2, fake_flat = self.discriminate(fake_images)
        # Binary cross-entropy: real targets 1, fake targets 0.
        real_loss = -np.mean(np.log(real_preds + 1e-8))
        fake_loss = -np.mean(np.log(1 - fake_preds + 1e-8))
        total_loss = real_loss + fake_loss
        # For sigmoid + BCE, dL/d(pre-sigmoid) = (prediction - target) / batch.
        d_real = (real_preds - 1.0) / real_images.shape[0]
        d_fake = fake_preds / batch_size
        g_real, _ = self._dis_gradients(d_real, r_h1, r_h2, real_flat)
        g_fake, _ = self._dis_gradients(d_fake, f_h1, f_h2, fake_flat)
        lr = self.learning_rate
        self.dis_dense1_w -= lr * (g_real["w1"] + g_fake["w1"])
        self.dis_dense1_b -= lr * (g_real["b1"] + g_fake["b1"])
        self.dis_dense2_w -= lr * (g_real["w2"] + g_fake["w2"])
        self.dis_dense2_b -= lr * (g_real["b2"] + g_fake["b2"])
        self.dis_dense3_w -= lr * (g_real["w3"] + g_fake["w3"])
        self.dis_dense3_b -= lr * (g_real["b3"] + g_fake["b3"])
        return total_loss, real_loss, fake_loss

    def train_generator(self, batch_size):
        """One SGD step on the generator (non-saturating -log D(G(z)) loss).

        Gradients flow through the discriminator back into the generator, but
        only generator weights are updated.  Returns the generator loss.
        """
        fake_images, z, g_h1, g_h2 = self.generate(batch_size)
        fake_preds, d_h1, d_h2, fake_flat = self.discriminate(fake_images)
        gen_loss = -np.mean(np.log(fake_preds + 1e-8))
        # Generator wants D(G(z)) -> 1, so dL/d(pre-sigmoid) = (p - 1) / batch.
        d_pre3 = (fake_preds - 1.0) / batch_size
        _, d_flat = self._dis_gradients(d_pre3, d_h1, d_h2, fake_flat)
        # Through the generator's tanh output: fake_flat == tanh(pre-activation).
        d_gpre3 = d_flat * (1.0 - fake_flat ** 2)
        d_gw3 = np.dot(g_h2.T, d_gpre3)
        d_gb3 = d_gpre3.sum(axis=0)
        d_gh2 = np.dot(d_gpre3, self.gen_dense3_w.T) * (g_h2 > 0)
        d_gw2 = np.dot(g_h1.T, d_gh2)
        d_gb2 = d_gh2.sum(axis=0)
        d_gh1 = np.dot(d_gh2, self.gen_dense2_w.T) * (g_h1 > 0)
        d_gw1 = np.dot(z.T, d_gh1)
        d_gb1 = d_gh1.sum(axis=0)
        lr = self.learning_rate
        self.gen_dense1_w -= lr * d_gw1
        self.gen_dense1_b -= lr * d_gb1
        self.gen_dense2_w -= lr * d_gw2
        self.gen_dense2_b -= lr * d_gb2
        self.gen_dense3_w -= lr * d_gw3
        self.gen_dense3_b -= lr * d_gb3
        return gen_loss
def load_patches(class_name):
    """Load every extracted patch for one class, normalized to [-1, 1].

    Returns an (N, H, W, 3) float32 array, or None when the class directory
    is missing or contains no readable .jpg files.  Assumes all patches share
    the same dimensions (they are extracted at a fixed size) — verify against
    gan.py if that changes.
    """
    class_dir = GAN_DATASET_ROOT / class_name
    if not class_dir.exists():
        print(f"✗ Class directory not found: {class_dir}")
        return None
    patch_files = sorted(class_dir.glob("*.jpg"))
    print(f"Loading {len(patch_files)} patches for {class_name}...")
    decoded = (cv2.imread(str(path)) for path in patch_files)
    # Map pixel values from [0, 255] into the generator's tanh range [-1, 1];
    # unreadable files (imread -> None) are skipped.
    normalized = [img.astype(np.float32) / 127.5 - 1.0 for img in decoded if img is not None]
    return np.array(normalized) if normalized else None
def save_sample_images(gan, epoch, class_name):
    """Write four generated sample patches to OUTPUT_DIR/<class>/samples/."""
    sample_dir = OUTPUT_DIR / class_name / "samples"
    sample_dir.mkdir(parents=True, exist_ok=True)
    generated, _, _, _ = gan.generate(4)
    for idx, sample in enumerate(generated):
        # Map tanh output [-1, 1] back to uint8 pixel values [0, 255].
        pixels = ((sample + 1.0) * 127.5).astype(np.uint8)
        target = sample_dir / f"epoch_{epoch:04d}_sample_{idx}.jpg"
        cv2.imwrite(str(target), pixels)
def train_gan(class_name, epochs, batch_size):
    """Train a GAN for one class, then write synthetic patches to disk.

    Args:
        class_name: anomaly class whose patches live under GAN_DATASET_ROOT.
        epochs: number of passes over the patch set.
        batch_size: requested minibatch size (clamped to the dataset size).

    Returns early (after a message) when no patches can be loaded.
    """
    print(f"\n{'='*70}")
    print(f"Training GAN for: {class_name.upper()}")
    print(f"{'='*70}")
    # Load patches
    patches = load_patches(class_name)
    if patches is None or len(patches) == 0:
        print(f"✗ No patches found for {class_name}")
        return
    print(f"✓ Loaded {len(patches)} patches")
    print(f" Shape: {patches.shape}")
    print(f" Range: [{patches.min():.2f}, {patches.max():.2f}]")
    # Initialize GAN
    gan = SimpleGAN(latent_dim=LATENT_DIM, patch_size=PATCH_SIZE)
    # BUG FIX: clamp the batch size so a dataset smaller than `batch_size`
    # still yields at least one batch — the original computed num_batches = 0
    # and crashed with ZeroDivisionError when averaging the epoch losses.
    batch_size = min(batch_size, len(patches))
    # Training loop
    print(f"\nTraining for {epochs} epochs...")
    for epoch in range(epochs):
        # Shuffle patch order every epoch.
        indices = np.random.permutation(len(patches))
        epoch_d_loss = 0.0
        epoch_g_loss = 0.0
        num_batches = len(patches) // batch_size
        for batch_idx in range(num_batches):
            batch_indices = indices[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            real_batch = patches[batch_indices]
            # Alternate discriminator and generator updates each batch.
            d_loss, d_real_loss, d_fake_loss = gan.train_discriminator(real_batch, batch_size)
            epoch_d_loss += d_loss
            g_loss = gan.train_generator(batch_size)
            epoch_g_loss += g_loss
        # Average losses over the epoch (num_batches >= 1 after the clamp).
        epoch_d_loss /= num_batches
        epoch_g_loss /= num_batches
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch + 1}/{epochs} | D Loss: {epoch_d_loss:.4f} | G Loss: {epoch_g_loss:.4f}")
        if (epoch + 1) % 10 == 0:
            save_sample_images(gan, epoch + 1, class_name)
    print(f"✓ Training complete for {class_name}")
    # Generate synthetic data
    print(f"\nGenerating synthetic patches for {class_name}...")
    num_synthetic = len(patches)  # Generate same number as originals
    synthetic_dir = SYNTHETIC_DIR / class_name
    synthetic_dir.mkdir(parents=True, exist_ok=True)
    # Generate in batches (last batch may be smaller).
    num_batches = (num_synthetic + batch_size - 1) // batch_size
    saved_count = 0
    for batch_idx in range(num_batches):
        batch_count = min(batch_size, num_synthetic - batch_idx * batch_size)
        fake_images, _, _, _ = gan.generate(batch_count)
        for img in fake_images:
            # Denormalize tanh output [-1, 1] back to uint8 [0, 255].
            img_uint8 = ((img + 1.0) * 127.5).astype(np.uint8)
            output_path = synthetic_dir / f"synthetic_{saved_count:06d}.jpg"
            cv2.imwrite(str(output_path), img_uint8)
            saved_count += 1
    print(f"✓ Generated {saved_count} synthetic patches for {class_name}")
    print(f" Saved to: {synthetic_dir}")
def main():
    """Parse CLI arguments, validate the dataset location, and run training."""
    parser = argparse.ArgumentParser(description="Train GAN on road anomaly patches")
    parser.add_argument("--epochs", type=int, default=50,
                        help="Number of training epochs (default: 50)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size (default: 32)")
    parser.add_argument("--class", dest="class_name",
                        choices=["pothole", "cracks", "open_manhole", "all"],
                        default="all",
                        help="Which class to train (default: all)")
    args = parser.parse_args()

    banner = "=" * 70
    print("\n" + banner)
    print("GAN TRAINING FOR ROAD ANOMALY PATCH GENERATION")
    print(banner)
    print(f"Dataset: {GAN_DATASET_ROOT}")
    print(f"Output: {OUTPUT_DIR}")
    print(f"Synthetic: {SYNTHETIC_DIR}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")

    # Patches must have been extracted (by gan.py) before training can start.
    if not GAN_DATASET_ROOT.exists():
        print(f"\n✗ GAN dataset not found: {GAN_DATASET_ROOT}")
        print(" Run 'python gan.py' first to extract patches")
        sys.exit(1)

    # "all" expands to every known class; otherwise train just the one chosen.
    if args.class_name == "all":
        classes = ["pothole", "cracks", "open_manhole"]
    else:
        classes = [args.class_name]
    for class_name in classes:
        train_gan(class_name, args.epochs, args.batch_size)

    print("\n" + banner)
    print("SYNTHETIC DATA GENERATION COMPLETE")
    print(banner)
    print(f"\nTo augment YOLO dataset:")
    print(f" 1. Synthetic patches saved to: {SYNTHETIC_DIR}")
    print(f" 2. Copy to dataset/train/images/ for augmentation")
    print(f" 3. Run: python train_road_anomaly_model.py")
    print(banner + "\n")
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()