File size: 8,639 Bytes
ecc16d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
#!/usr/bin/env python3
"""

CNN Model Training Script

========================



Standalone script to train the CNN deblurring model with comprehensive options.

"""

import os
import sys
import argparse
import logging
from datetime import datetime

# Add modules to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from modules.cnn_deblurring import CNNDeblurModel, train_new_model, quick_train, full_train

# Configure logging: every record goes both to a timestamped file in the
# current working directory and to the console. Note this runs at import
# time, so importing this module creates the log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: logged messages embed arbitrary exception text, and
        # FileHandler's default encoding is locale-dependent.
        logging.FileHandler(
            f'training_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log',
            encoding='utf-8',
        ),
        logging.StreamHandler(),
    ],
)

logger = logging.getLogger(__name__)

def _build_parser():
    """Build the command-line parser for the training script.

    Exactly one mode flag (``--quick``, ``--full``, ``--custom``, ``--test``)
    is required. The training/model/data options below are only read by the
    ``--custom`` mode.
    """
    parser = argparse.ArgumentParser(
        description='Train CNN Deblurring Model',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog='''
Examples:
  python train_cnn_model.py --quick                              # Quick training (500 samples, 10 epochs)
  python train_cnn_model.py --full                               # Full training (2000 samples, 30 epochs)
  python train_cnn_model.py --custom --samples 1500              # Custom samples with default epochs
  python train_cnn_model.py --custom --samples 1000 --epochs 25  # Custom training
  python train_cnn_model.py --test                               # Test existing model
''',
    )

    # Training modes (mutually exclusive, one is mandatory)
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument('--quick', action='store_true',
                            help='Quick training (500 samples, 10 epochs)')
    mode_group.add_argument('--full', action='store_true',
                            help='Full training (2000 samples, 30 epochs)')
    mode_group.add_argument('--custom', action='store_true',
                            help='Custom training (specify --samples and --epochs)')
    mode_group.add_argument('--test', action='store_true',
                            help='Test existing model performance')

    # Training parameters (used by --custom)
    parser.add_argument('--samples', type=int, default=1000,
                        help='Number of training samples for --custom (default: 1000)')
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of training epochs for --custom (default: 20)')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Training batch size for --custom (default: 16)')
    parser.add_argument('--validation-split', type=float, default=0.2,
                        help='Validation data split for --custom, in [0, 1) (default: 0.2)')

    # Model parameters (used by --custom)
    parser.add_argument('--image-size', type=int, default=256,
                        help='Input image size for --custom (default: 256x256)')

    # Data options. store_true combined with default=True can never be turned
    # off, so this flag is a no-op kept only so existing command lines still
    # parse; dataset reuse is actually controlled by --force-new-dataset.
    parser.add_argument('--use-existing-dataset', action='store_true', default=True,
                        help='Reuse an existing dataset if available (always on; '
                             'use --force-new-dataset to regenerate)')
    parser.add_argument('--force-new-dataset', action='store_true',
                        help='Force creation of new dataset')
    return parser


def _validate_custom_args(parser, args):
    """Exit via ``parser.error`` (status 2) on nonsensical --custom values."""
    positive_options = {
        '--samples': args.samples,
        '--epochs': args.epochs,
        '--batch-size': args.batch_size,
        '--image-size': args.image_size,
    }
    for flag, value in positive_options.items():
        if value <= 0:
            parser.error(f"{flag} must be a positive integer (got {value})")
    if not 0.0 <= args.validation_split < 1.0:
        parser.error(
            f"--validation-split must be in [0, 1) (got {args.validation_split})"
        )


def _run_test_mode():
    """Load the saved model, print its evaluation metrics and a verdict.

    Returns:
        True if the model loaded and evaluated; False otherwise.
    """
    print("🧪 Testing Existing Model")
    print("-" * 30)

    model = CNNDeblurModel()
    if not model.load_model(model.model_path):
        print("❌ No trained model found. Train a model first:")
        print("   python train_cnn_model.py --quick")
        return False

    print("✅ Successfully loaded trained model")
    print("📊 Evaluating model performance...")
    metrics = model.evaluate_model()
    if not metrics:
        print("❌ Failed to evaluate model")
        # Previously this path still returned True (exit status 0).
        return False

    print("\n📈 Model Performance Metrics:")
    print(f"   Loss: {metrics['loss']:.4f}")
    print(f"   Mean Absolute Error: {metrics['mae']:.4f}")
    print(f"   Mean Squared Error: {metrics['mse']:.4f}")

    # Performance interpretation based on the loss value alone.
    loss = metrics['loss']
    if loss < 0.01:
        print("🌟 Excellent performance!")
    elif loss < 0.05:
        print("👍 Good performance")
    elif loss < 0.1:
        print("⚠️ Fair performance - consider more training")
    else:
        print("🔄 Poor performance - retrain recommended")
    return True


def _run_preset_mode(label, samples, epochs, expected_time, train_fn):
    """Print a preset's configuration, run its trainer, and return the result.

    ``samples``/``epochs``/``expected_time`` are for display only; they are
    not passed to ``train_fn``, which takes no arguments.

    Returns:
        Whatever ``train_fn`` returned, or None if it signalled failure.
    """
    print(f"🚀 {label} Training Mode")
    print("-" * 30)
    print("Configuration:")
    print(f"   Samples: {samples}")
    print(f"   Epochs: {epochs}")
    print(f"   Expected time: {expected_time}")
    print()

    model = train_fn()
    # NOTE(review): assumes quick_train/full_train return None or False on
    # failure (their bodies are not visible here) - confirm. Any other value
    # is treated as the trained model, exactly as before.
    if model is None or model is False:
        print(f"❌ {label} training failed!")
        return None
    return model


def _run_custom_mode(args):
    """Train a model with user-supplied parameters and print its metrics.

    Returns:
        The trained ``CNNDeblurModel``, or None if training reported failure.
    """
    use_existing_dataset = not args.force_new_dataset

    print("🚀 Custom Training Mode")
    print("-" * 30)
    print("Configuration:")
    print(f"   Samples: {args.samples}")
    print(f"   Epochs: {args.epochs}")
    print(f"   Batch Size: {args.batch_size}")
    print(f"   Validation Split: {args.validation_split}")
    print(f"   Image Size: {args.image_size}x{args.image_size}")
    print(f"   Use Existing Dataset: {use_existing_dataset}")

    # Rough heuristic: about one minute per 1000 sample-epochs.
    estimated_minutes = (args.samples * args.epochs) / 1000
    print(f"   Estimated time: ~{estimated_minutes:.1f} minutes")
    print()

    # Square RGB input of the requested size.
    input_shape = (args.image_size, args.image_size, 3)
    model = CNNDeblurModel(input_shape=input_shape)

    success = model.train_model(
        epochs=args.epochs,
        batch_size=args.batch_size,
        validation_split=args.validation_split,
        use_existing_dataset=use_existing_dataset,
        num_training_samples=args.samples,
    )
    if not success:
        print("❌ Custom training failed!")
        return None

    print("✅ Custom training completed successfully!")
    metrics = model.evaluate_model()
    if metrics:
        print("📊 Final Model Performance:")
        print(f"   Loss: {metrics['loss']:.4f}")
        print(f"   MAE: {metrics['mae']:.4f}")
        print(f"   MSE: {metrics['mse']:.4f}")
    return model


def _training_log_location():
    """Return the path(s) of the root logger's file handlers, or a glob hint."""
    paths = [
        handler.baseFilename
        for handler in logging.getLogger().handlers
        if isinstance(handler, logging.FileHandler)
    ]
    return ", ".join(paths) if paths else "training_log_*.log"


def _print_completion(model):
    """Print where the model, dataset and log ended up after training."""
    # Prefer the model's own path (the same attribute --test loads from);
    # fall back to the historical default for results without one.
    model_path = getattr(model, 'model_path', None) or "models/cnn_deblur_model.h5"
    print("\n🎉 Training Process Completed!")
    print(f"📁 Model saved to: {model_path}")
    print("📁 Dataset saved to: data/training_dataset/")
    print(f"📁 Training log: {_training_log_location()}")
    print("\n🚀 You can now use the trained model in the main application!")


def main(argv=None):
    """Parse command-line options and run the selected training/testing mode.

    Args:
        argv: Optional argument list without the program name. Defaults to
            ``sys.argv[1:]``, so existing ``main()`` callers are unaffected.

    Returns:
        True if the selected mode completed successfully, False otherwise
        (including user interruption and unexpected errors).

    Raises:
        SystemExit: From argparse on missing, invalid or ``--help`` options.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.custom:
        _validate_custom_args(parser, args)

    # Print banner
    print("🎯 CNN Deblurring Model Training")
    print("=" * 40)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    # Ensure output directories exist before any mode runs.
    os.makedirs("models", exist_ok=True)
    os.makedirs("data/training_dataset", exist_ok=True)

    try:
        if args.test:
            return _run_test_mode()
        if args.quick:
            model = _run_preset_mode("Quick", 500, 10, "~10-15 minutes", quick_train)
        elif args.full:
            model = _run_preset_mode("Full", 2000, 30, "~45-60 minutes", full_train)
        else:
            # The mode group is required, so --custom is the only option left.
            model = _run_custom_mode(args)

        if model is None:
            return False
        _print_completion(model)
        return True

    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user")
        return False
    except Exception as e:
        # logger.exception also records the traceback, unlike logger.error.
        logger.exception("Training failed with error: %s", e)
        print(f"\n❌ Training failed: {e}")
        return False

if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)