File size: 10,484 Bytes
5b86813
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
#!/usr/bin/env python3
"""
GAN Training for Road Anomaly Patch Generation.

Trains a GAN on extracted patches from gan_dataset/ to generate synthetic
training data for improving YOLO model performance through augmentation.

Usage:
    python gan_train.py [--epochs 50] [--batch-size 32] [--class pothole]
"""

import os
import sys
import argparse
from pathlib import Path
import numpy as np
import cv2
from collections import defaultdict

# Configuration
GAN_DATASET_ROOT = Path("/home/pragadeesh/ARM/model/gan_dataset")  # input: patch images (extracted beforehand by gan.py)
OUTPUT_DIR = Path("/home/pragadeesh/ARM/model/gan_output")  # per-class sample images written during training
SYNTHETIC_DIR = Path("/home/pragadeesh/ARM/model/dataset/synthetic")  # final generated patches for YOLO augmentation

PATCH_SIZE = 64  # side length (pixels) of the square RGB patches the GAN works with
LATENT_DIM = 100  # size of the generator's input noise vector


class SimpleGAN:
    """Simple GAN for generating 64x64 RGB patches using NumPy.

    Generator and discriminator are both 3-layer fully connected networks.
    Training runs hand-written backpropagation with vanilla SGD on the
    standard binary cross-entropy GAN objectives (non-saturating generator
    loss).
    """

    def __init__(self, latent_dim=100, patch_size=64):
        """Initialize generator and discriminator weights.

        Args:
            latent_dim: Size of the generator's noise input. The default
                mirrors the module-level LATENT_DIM constant.
            patch_size: Side length of generated square patches. The default
                mirrors the module-level PATCH_SIZE constant.
        """
        self.latent_dim = latent_dim
        self.patch_size = patch_size
        self.channels = 3

        # Generator weights: latent -> 256 -> 512 -> H*W*C
        self.gen_dense1_w = np.random.randn(latent_dim, 256) * 0.02
        self.gen_dense1_b = np.zeros(256)
        self.gen_dense2_w = np.random.randn(256, 512) * 0.02
        self.gen_dense2_b = np.zeros(512)
        self.gen_dense3_w = np.random.randn(512, patch_size * patch_size * self.channels) * 0.02
        self.gen_dense3_b = np.zeros(patch_size * patch_size * self.channels)

        # Discriminator weights: H*W*C -> 512 -> 256 -> 1
        self.dis_dense1_w = np.random.randn(patch_size * patch_size * self.channels, 512) * 0.02
        self.dis_dense1_b = np.zeros(512)
        self.dis_dense2_w = np.random.randn(512, 256) * 0.02
        self.dis_dense2_b = np.zeros(256)
        self.dis_dense3_w = np.random.randn(256, 1) * 0.02
        self.dis_dense3_b = np.zeros(1)

        self.learning_rate = 0.0002

    @staticmethod
    def relu(x):
        """ReLU activation."""
        return np.maximum(0, x)

    @staticmethod
    def relu_derivative(x):
        """ReLU derivative; valid whether x is the pre- or post-activation."""
        return (x > 0).astype(float)

    @staticmethod
    def sigmoid(x):
        """Sigmoid activation; input clipped to avoid overflow in exp."""
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

    @staticmethod
    def tanh(x):
        """Tanh activation."""
        return np.tanh(x)

    @staticmethod
    def tanh_derivative(x):
        """Tanh derivative (1 - tanh(x)^2)."""
        return 1 - np.tanh(x) ** 2

    def generate(self, batch_size):
        """Generate synthetic patches from random noise.

        Returns:
            (images, z, h1, h2): patches in [-1, 1] of shape
            (batch, patch_size, patch_size, 3), plus the latent input and
            hidden activations (kept so train_generator can backprop).
        """
        z = np.random.randn(batch_size, self.latent_dim)

        # Generator forward pass
        h1 = self.relu(np.dot(z, self.gen_dense1_w) + self.gen_dense1_b)
        h2 = self.relu(np.dot(h1, self.gen_dense2_w) + self.gen_dense2_b)
        output = self.tanh(np.dot(h2, self.gen_dense3_w) + self.gen_dense3_b)

        # Reshape the flat tanh output to image format
        images = output.reshape(batch_size, self.patch_size, self.patch_size, self.channels)
        return images, z, h1, h2

    def discriminate(self, images):
        """Score images as real (toward 1) or fake (toward 0).

        Returns:
            (output, h1, h2, flat): sigmoid probabilities of shape
            (batch, 1), hidden activations, and the flattened input
            (kept so the training methods can backprop).
        """
        batch_size = images.shape[0]
        flat = images.reshape(batch_size, -1)

        # Discriminator forward pass
        h1 = self.relu(np.dot(flat, self.dis_dense1_w) + self.dis_dense1_b)
        h2 = self.relu(np.dot(h1, self.dis_dense2_w) + self.dis_dense2_b)
        output = self.sigmoid(np.dot(h2, self.dis_dense3_w) + self.dis_dense3_b)

        return output, h1, h2, flat

    def _dis_backward(self, d_logit, h1, h2, flat):
        """Backprop one batch through the discriminator.

        Args:
            d_logit: Gradient of the loss w.r.t. the pre-sigmoid logits,
                shape (batch, 1).
            h1, h2, flat: Activations saved from discriminate().

        Returns:
            (grads, d_flat): parameter gradients as a tuple
            (w1, b1, w2, b2, w3, b3), and the gradient w.r.t. the flattened
            input (needed when training the generator through a frozen D).
        """
        g_w3 = np.dot(h2.T, d_logit)
        g_b3 = d_logit.sum(axis=0)
        d_h2 = np.dot(d_logit, self.dis_dense3_w.T) * self.relu_derivative(h2)
        g_w2 = np.dot(h1.T, d_h2)
        g_b2 = d_h2.sum(axis=0)
        d_h1 = np.dot(d_h2, self.dis_dense2_w.T) * self.relu_derivative(h1)
        g_w1 = np.dot(flat.T, d_h1)
        g_b1 = d_h1.sum(axis=0)
        d_flat = np.dot(d_h1, self.dis_dense1_w.T)
        return (g_w1, g_b1, g_w2, g_b2, g_w3, g_b3), d_flat

    def _apply_dis_grads(self, grads):
        """Apply one SGD step to the discriminator parameters."""
        g_w1, g_b1, g_w2, g_b2, g_w3, g_b3 = grads
        lr = self.learning_rate
        self.dis_dense1_w -= lr * g_w1
        self.dis_dense1_b -= lr * g_b1
        self.dis_dense2_w -= lr * g_w2
        self.dis_dense2_b -= lr * g_b2
        self.dis_dense3_w -= lr * g_w3
        self.dis_dense3_b -= lr * g_b3

    def train_discriminator(self, real_images, batch_size):
        """Run one SGD step on the discriminator over real and fake batches.

        BUG FIX: the previous version computed the losses but never updated
        any weights, so the GAN could not learn at all.

        Args:
            real_images: Batch of real patches normalized to [-1, 1].
            batch_size: Number of fake patches to generate.

        Returns:
            (total_loss, real_loss, fake_loss): binary cross-entropy terms.
        """
        # Generate fake images
        fake_images, _, _, _ = self.generate(batch_size)

        # Forward passes (keep activations for backprop)
        real_preds, r_h1, r_h2, r_flat = self.discriminate(real_images)
        fake_preds, f_h1, f_h2, f_flat = self.discriminate(fake_images)

        # Binary cross-entropy: real target 1, fake target 0
        real_loss = -np.mean(np.log(real_preds + 1e-8))
        fake_loss = -np.mean(np.log(1 - fake_preds + 1e-8))
        total_loss = real_loss + fake_loss

        # Sigmoid + BCE gives dL/d(logit) = (prediction - target) / batch
        d_real = (real_preds - 1.0) / real_preds.shape[0]  # target = 1
        d_fake = fake_preds / fake_preds.shape[0]          # target = 0

        real_grads, _ = self._dis_backward(d_real, r_h1, r_h2, r_flat)
        fake_grads, _ = self._dis_backward(d_fake, f_h1, f_h2, f_flat)

        # Sum both halves before stepping so each half is computed against
        # the same (pre-update) weights.
        self._apply_dis_grads([gr + gf for gr, gf in zip(real_grads, fake_grads)])

        return total_loss, real_loss, fake_loss

    def train_generator(self, batch_size):
        """Run one SGD step on the generator (non-saturating GAN loss).

        BUG FIX: the previous version computed the loss but never updated
        any weights.

        Returns:
            gen_loss: -mean(log D(G(z))) — how badly G fools D.
        """
        fake_images, z, g_h1, g_h2 = self.generate(batch_size)
        fake_preds, d_h1, d_h2, flat = self.discriminate(fake_images)

        # Loss: How well generator fools discriminator
        gen_loss = -np.mean(np.log(fake_preds + 1e-8))

        # Generator wants D(fake) -> 1, so the logit gradient uses target 1.
        d_logit = (fake_preds - 1.0) / fake_preds.shape[0]

        # Backprop through the discriminator (frozen: its grads are dropped)
        # down to its flattened input, which is the generator's output.
        _, d_flat = self._dis_backward(d_logit, d_h1, d_h2, flat)

        # `flat` is exactly the generator's tanh output, so tanh' = 1 - flat^2.
        d_pre3 = d_flat * (1.0 - flat ** 2)
        g_w3 = np.dot(g_h2.T, d_pre3)
        g_b3 = d_pre3.sum(axis=0)
        d_gh2 = np.dot(d_pre3, self.gen_dense3_w.T) * self.relu_derivative(g_h2)
        g_w2 = np.dot(g_h1.T, d_gh2)
        g_b2 = d_gh2.sum(axis=0)
        d_gh1 = np.dot(d_gh2, self.gen_dense2_w.T) * self.relu_derivative(g_h1)
        g_w1 = np.dot(z.T, d_gh1)
        g_b1 = d_gh1.sum(axis=0)

        lr = self.learning_rate
        self.gen_dense1_w -= lr * g_w1
        self.gen_dense1_b -= lr * g_b1
        self.gen_dense2_w -= lr * g_w2
        self.gen_dense2_b -= lr * g_b2
        self.gen_dense3_w -= lr * g_w3
        self.gen_dense3_b -= lr * g_b3

        return gen_loss

def load_patches(class_name):
    """Load all patches for a class, resized to PATCH_SIZE x PATCH_SIZE.

    Args:
        class_name: Subdirectory of GAN_DATASET_ROOT to read *.jpg from.

    Returns:
        float32 array of shape (N, PATCH_SIZE, PATCH_SIZE, 3) normalized to
        [-1, 1], or None if the directory is missing or holds no readable
        images.
    """
    class_dir = GAN_DATASET_ROOT / class_name

    if not class_dir.exists():
        print(f"✗ Class directory not found: {class_dir}")
        return None

    patches = []
    patch_files = sorted(class_dir.glob("*.jpg"))

    print(f"Loading {len(patch_files)} patches for {class_name}...")

    for patch_file in patch_files:
        patch = cv2.imread(str(patch_file))
        if patch is None:
            continue  # unreadable/corrupt file — skip rather than crash
        # BUG FIX: patches on disk are not guaranteed to be PATCH_SIZE
        # square. Without this resize, mixed sizes make np.array() fail
        # (or silently build a ragged object array), and SimpleGAN needs a
        # fixed input dimension anyway.
        if patch.shape[:2] != (PATCH_SIZE, PATCH_SIZE):
            patch = cv2.resize(patch, (PATCH_SIZE, PATCH_SIZE))
        # Normalize uint8 [0, 255] to [-1, 1] to match the tanh output range
        patches.append(patch.astype(np.float32) / 127.5 - 1.0)

    return np.array(patches) if patches else None


def save_sample_images(gan, epoch, class_name):
    """Write a handful of generator samples to disk for visual inspection."""
    sample_dir = OUTPUT_DIR / class_name / "samples"
    sample_dir.mkdir(parents=True, exist_ok=True)

    # Draw a small fixed number of samples from the generator
    samples, _, _, _ = gan.generate(4)

    for i, sample in enumerate(samples):
        # Map the tanh range [-1, 1] back to pixel values [0, 255]
        pixels = ((sample + 1.0) * 127.5).astype(np.uint8)
        target = sample_dir / f"epoch_{epoch:04d}_sample_{i}.jpg"
        cv2.imwrite(str(target), pixels)


def train_gan(class_name, epochs, batch_size):
    """Train a GAN for one class and save synthetic patches to disk.

    Args:
        class_name: Anomaly class whose patches to train on.
        epochs: Number of passes over the patch dataset.
        batch_size: Requested mini-batch size (clamped to the dataset size).
    """
    print(f"\n{'='*70}")
    print(f"Training GAN for: {class_name.upper()}")
    print(f"{'='*70}")

    # Load patches
    patches = load_patches(class_name)
    if patches is None or len(patches) == 0:
        print(f"✗ No patches found for {class_name}")
        return

    print(f"✓ Loaded {len(patches)} patches")
    print(f"  Shape: {patches.shape}")
    print(f"  Range: [{patches.min():.2f}, {patches.max():.2f}]")

    # BUG FIX: if the dataset is smaller than the requested batch size,
    # num_batches below would be 0 and the per-epoch loss averaging would
    # raise ZeroDivisionError. Clamping guarantees at least one batch.
    batch_size = min(batch_size, len(patches))

    # Initialize GAN
    gan = SimpleGAN(latent_dim=LATENT_DIM, patch_size=PATCH_SIZE)

    # Training loop
    print(f"\nTraining for {epochs} epochs...")

    for epoch in range(epochs):
        # Shuffle patches each epoch so batches differ between passes
        indices = np.random.permutation(len(patches))
        epoch_d_loss = 0.0
        epoch_g_loss = 0.0
        # NOTE: a trailing partial batch (< batch_size samples) is dropped
        num_batches = len(patches) // batch_size

        for batch_idx in range(num_batches):
            # Get batch
            batch_indices = indices[batch_idx * batch_size:(batch_idx + 1) * batch_size]
            real_batch = patches[batch_indices]

            # Train discriminator
            d_loss, d_real_loss, d_fake_loss = gan.train_discriminator(real_batch, batch_size)
            epoch_d_loss += d_loss

            # Train generator
            g_loss = gan.train_generator(batch_size)
            epoch_g_loss += g_loss

        # Average losses over the epoch
        epoch_d_loss /= num_batches
        epoch_g_loss /= num_batches

        # Print progress every 5 epochs
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch + 1}/{epochs} | D Loss: {epoch_d_loss:.4f} | G Loss: {epoch_g_loss:.4f}")

        # Save sample images every 10 epochs
        if (epoch + 1) % 10 == 0:
            save_sample_images(gan, epoch + 1, class_name)

    print(f"✓ Training complete for {class_name}")

    # Generate synthetic data
    print(f"\nGenerating synthetic patches for {class_name}...")
    num_synthetic = len(patches)  # Generate same number as originals

    synthetic_dir = SYNTHETIC_DIR / class_name
    synthetic_dir.mkdir(parents=True, exist_ok=True)

    # Generate in batches (ceiling division covers the final partial batch)
    num_batches = (num_synthetic + batch_size - 1) // batch_size
    saved_count = 0

    for batch_idx in range(num_batches):
        batch_count = min(batch_size, num_synthetic - batch_idx * batch_size)
        fake_images, _, _, _ = gan.generate(batch_count)

        for i, img in enumerate(fake_images):
            # Denormalize tanh range [-1, 1] to pixel values [0, 255]
            img_uint8 = ((img + 1.0) * 127.5).astype(np.uint8)

            output_path = synthetic_dir / f"synthetic_{saved_count:06d}.jpg"
            cv2.imwrite(str(output_path), img_uint8)
            saved_count += 1

    print(f"✓ Generated {saved_count} synthetic patches for {class_name}")
    print(f"  Saved to: {synthetic_dir}")


def main():
    """Parse CLI options, validate the dataset, and train the selected GANs."""
    parser = argparse.ArgumentParser(description="Train GAN on road anomaly patches")
    parser.add_argument("--epochs", type=int, default=50,
                        help="Number of training epochs (default: 50)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size (default: 32)")
    parser.add_argument("--class", dest="class_name",
                        choices=["pothole", "cracks", "open_manhole", "all"],
                        default="all",
                        help="Which class to train (default: all)")

    args = parser.parse_args()

    banner = "=" * 70
    print("\n" + banner)
    print("GAN TRAINING FOR ROAD ANOMALY PATCH GENERATION")
    print(banner)
    print(f"Dataset: {GAN_DATASET_ROOT}")
    print(f"Output: {OUTPUT_DIR}")
    print(f"Synthetic: {SYNTHETIC_DIR}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch_size}")

    # Bail out early if the patch dataset was never extracted
    if not GAN_DATASET_ROOT.exists():
        print(f"\n✗ GAN dataset not found: {GAN_DATASET_ROOT}")
        print("  Run 'python gan.py' first to extract patches")
        sys.exit(1)

    # "all" fans out to every known class; otherwise train just the one asked for
    if args.class_name == "all":
        targets = ["pothole", "cracks", "open_manhole"]
    else:
        targets = [args.class_name]

    for target in targets:
        train_gan(target, args.epochs, args.batch_size)

    print("\n" + banner)
    print("SYNTHETIC DATA GENERATION COMPLETE")
    print(banner)
    print(f"\nTo augment YOLO dataset:")
    print(f"  1. Synthetic patches saved to: {SYNTHETIC_DIR}")
    print(f"  2. Copy to dataset/train/images/ for augmentation")
    print(f"  3. Run: python train_road_anomaly_model.py")
    print(banner + "\n")

if __name__ == "__main__":
    main()