Spaces:
Sleeping
Sleeping
| """ | |
| src/train_model.py | |
| Unified training script for DeepFake detection models. | |
| Supports training: | |
| - Xception model (for images) | |
| - Hybrid model (Xception + EfficientNetB4 + ResNet50 for images) | |
| - Video sequence model (CNN-RNN for videos) | |
| Usage: | |
| # Train Xception model | |
| python src/train_model.py --model xception | |
| # Train Hybrid model | |
| python src/train_model.py --model hybrid | |
| # Train Video model | |
| python src/train_model.py --model video | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| # Ensure local src/ package is importable even if a global 'src' package exists | |
| ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| SRC_DIR = os.path.join(ROOT_DIR, "src") | |
| if SRC_DIR not in sys.path: | |
| sys.path.insert(0, SRC_DIR) | |
| import model as model_module | |
| def train_xception_model(): | |
| """Train Xception-based image model.""" | |
| print("=" * 70) | |
| print("DeepFake Detection Model Training") | |
| print("Using Xception Transfer Learning for 90+ Accuracy") | |
| print("=" * 70) | |
| # Configuration | |
| data_folder = "data/image_data" | |
| sample_size = 16000 # 8000 per class | |
| batch_size = 32 | |
| initial_epochs = 5 # Epochs with frozen base | |
| fine_tune_epochs = 10 # Epochs for fine-tuning | |
| checkpoint_path = "xception_deepfake_model.h5" | |
| # Step 1: Load dataset | |
| print("\n[Step 1] Loading dataset from folder...") | |
| try: | |
| (X_train, y_train), (X_val, y_val), (X_test, y_test) = model_module.load_dataset_from_folder( | |
| data_folder=data_folder, | |
| sample_size=sample_size, | |
| random_state=42 | |
| ) | |
| print(f"β Dataset loaded successfully!") | |
| print(f" Training: {X_train.shape[0]} samples") | |
| print(f" Validation: {X_val.shape[0]} samples") | |
| print(f" Test: {X_test.shape[0]} samples") | |
| except Exception as e: | |
| print(f"β Error loading dataset: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 2: Build Xception model | |
| print("\n[Step 2] Building Xception model with transfer learning...") | |
| try: | |
| model, base_model = model_module.build_xception_model( | |
| input_shape=(224, 224, 3), | |
| num_classes=1, | |
| use_binary=True | |
| ) | |
| print(f"β Model built successfully!") | |
| print(f" Model name: {model.name}") | |
| print(f" Total parameters: {model.count_params():,}") | |
| model.summary() | |
| except Exception as e: | |
| print(f"β Error building model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 3: Train model | |
| print("\n[Step 3] Training model...") | |
| try: | |
| history1, history2, trained_model = model_module.train_model_with_dataset( | |
| model=model, | |
| X_train=X_train, | |
| y_train=y_train, | |
| X_val=X_val, | |
| y_val=y_val, | |
| epochs=initial_epochs, | |
| batch_size=batch_size, | |
| use_callbacks=True, | |
| checkpoint_path=checkpoint_path, | |
| fine_tune_epochs=fine_tune_epochs, | |
| unfreeze_from_layer=56, | |
| base_model=base_model | |
| ) | |
| print(f"β Training completed!") | |
| # Update model reference | |
| model = trained_model | |
| except Exception as e: | |
| print(f"β Error during training: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 4: Evaluate on test set | |
| print("\n[Step 4] Evaluating on test set...") | |
| try: | |
| test_dataset = model_module.prepare_tf_dataset( | |
| X_test, y_test, batch_size=batch_size, shuffle=False, use_xception_preprocess=True | |
| ) | |
| test_results = model.evaluate(test_dataset, verbose=1) | |
| print(f"β Test evaluation completed!") | |
| print(f" Test Loss: {test_results[0]:.4f}") | |
| print(f" Test Accuracy: {test_results[1]:.4f} ({test_results[1]*100:.2f}%)") | |
| if test_results[1] >= 0.90: | |
| print(f"\nπ SUCCESS! Model achieved 90+ accuracy: {test_results[1]*100:.2f}%") | |
| else: | |
| print(f"\nβ Model accuracy is {test_results[1]*100:.2f}%. Consider:") | |
| print(f" - Training for more epochs") | |
| print(f" - Increasing sample_size") | |
| print(f" - Adjusting learning rates") | |
| except Exception as e: | |
| print(f"β Error during evaluation: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 5: Save final model | |
| print(f"\n[Step 5] Saving final model to {checkpoint_path}...") | |
| try: | |
| model.save(checkpoint_path) | |
| print(f"β Model saved successfully!") | |
| print(f"\nπ Model saved at: {os.path.abspath(checkpoint_path)}") | |
| print(f"\nTo use this model in deployment, set MODEL_PATH environment variable:") | |
| print(f" export MODEL_PATH={os.path.abspath(checkpoint_path)}") | |
| print(f" Or update app/main.py to use: MODEL_PATH = '{checkpoint_path}'") | |
| except Exception as e: | |
| print(f"β Error saving model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| print("\n" + "=" * 70) | |
| print("Training pipeline completed successfully!") | |
| print("=" * 70) | |
| def train_hybrid_model(): | |
| """Train Hybrid model (Xception + EfficientNetB4 + ResNet50).""" | |
| print("=" * 70) | |
| print("DeepFake Detection - HYBRID MODEL Training") | |
| print("Combining Xception + EfficientNetB4 + ResNet50") | |
| print("Target: 99%+ Accuracy") | |
| print("=" * 70) | |
| # Configuration | |
| data_folder = "data/image_data" | |
| sample_size = 8000 # Reduced to avoid memory issues (4000 per class) - can increase if more RAM available | |
| batch_size = 8 # Smaller batch size to reduce memory usage | |
| initial_epochs = 8 # More epochs with frozen base | |
| fine_tune_epochs = 15 # More epochs for fine-tuning | |
| checkpoint_path = "hybrid_deepfake_model.h5" | |
| # Step 1: Load dataset | |
| print("\n[Step 1] Loading dataset from folder...") | |
| try: | |
| (X_train, y_train), (X_val, y_val), (X_test, y_test) = model_module.load_dataset_from_folder( | |
| data_folder=data_folder, | |
| sample_size=sample_size, | |
| random_state=42 | |
| ) | |
| print(f"β Dataset loaded successfully!") | |
| print(f" Training: {X_train.shape[0]} samples") | |
| print(f" Validation: {X_val.shape[0]} samples") | |
| print(f" Test: {X_test.shape[0]} samples") | |
| except Exception as e: | |
| print(f"β Error loading dataset: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 2: Check for existing checkpoint and load or build model | |
| print("\n[Step 2] Checking for existing checkpoint...") | |
| checkpoint_exists = os.path.exists(checkpoint_path) | |
| if checkpoint_exists: | |
| print(f"β Found existing checkpoint: {checkpoint_path}") | |
| print(" Loading model from checkpoint to resume training...") | |
| try: | |
| model = model_module.load_model_from_checkpoint(checkpoint_path) | |
| print(f"β Model loaded successfully from checkpoint!") | |
| print(f" Model name: {model.name}") | |
| print(f" Total parameters: {model.count_params():,}") | |
| # Rebuild base_models_dict for fine-tuning (needed for unfreeze) | |
| print(" Rebuilding base models for fine-tuning...") | |
| _, base_models_dict = model_module.build_hybrid_model( | |
| input_shape=(224, 224, 3), | |
| num_classes=1, | |
| use_binary=True | |
| ) | |
| print("β Ready to resume training from checkpoint!") | |
| resume_training = True | |
| except Exception as e: | |
| print(f"β Warning: Could not load checkpoint ({e})") | |
| print(" Building new model instead...") | |
| resume_training = False | |
| model, base_models_dict = model_module.build_hybrid_model( | |
| input_shape=(224, 224, 3), | |
| num_classes=1, | |
| use_binary=True | |
| ) | |
| print(f"β Hybrid model built successfully!") | |
| else: | |
| print(" No checkpoint found. Building new model...") | |
| resume_training = False | |
| try: | |
| model, base_models_dict = model_module.build_hybrid_model( | |
| input_shape=(224, 224, 3), | |
| num_classes=1, | |
| use_binary=True | |
| ) | |
| print(f"β Hybrid model built successfully!") | |
| print(f" Model name: {model.name}") | |
| print(f" Total parameters: {model.count_params():,}") | |
| print(f" Base models: {list(base_models_dict.keys())}") | |
| print("\nModel Architecture Summary:") | |
| model.summary() | |
| except Exception as e: | |
| print(f"β Error building hybrid model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 3: Train model | |
| print("\n[Step 3] Training hybrid model...") | |
| if resume_training: | |
| print(" β Resuming training from checkpoint!") | |
| print(" Note: Will skip Phase 1 and proceed directly to Phase 2 (Fine-tuning)") | |
| print(" If you want to restart from Phase 1, delete the checkpoint file first.") | |
| try: | |
| history1, history2, trained_model = model_module.train_model_with_dataset( | |
| model=model, | |
| X_train=X_train, | |
| y_train=y_train, | |
| X_val=X_val, | |
| y_val=y_val, | |
| epochs=initial_epochs if not resume_training else 0, # Skip Phase 1 if resuming | |
| batch_size=batch_size, | |
| use_callbacks=True, | |
| checkpoint_path=checkpoint_path, | |
| fine_tune_epochs=fine_tune_epochs, | |
| unfreeze_from_layer=100, | |
| base_models_dict=base_models_dict, | |
| resume_from_checkpoint=resume_training | |
| ) | |
| print(f"β Training completed!") | |
| # Update model reference | |
| model = trained_model | |
| # Print training history | |
| if history1: | |
| print(f"\nPhase 1 (Frozen Base) - Final Accuracy: {max(history1.history.get('val_accuracy', history1.history.get('accuracy', [0]))):.4f}") | |
| if history2: | |
| print(f"Phase 2 (Fine-tuning) - Final Accuracy: {max(history2.history.get('val_accuracy', history2.history.get('accuracy', [0]))):.4f}") | |
| except Exception as e: | |
| print(f"β Error during training: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 4: Evaluate on test set | |
| print("\n[Step 4] Evaluating hybrid model on test set...") | |
| try: | |
| test_dataset = model_module.prepare_tf_dataset( | |
| X_test, y_test, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| use_hybrid=True | |
| ) | |
| test_results = model.evaluate(test_dataset, verbose=1) | |
| print(f"β Test evaluation completed!") | |
| print(f" Test Loss: {test_results[0]:.4f}") | |
| print(f" Test Accuracy: {test_results[1]:.4f} ({test_results[1]*100:.2f}%)") | |
| if len(test_results) > 2: | |
| print(f" Test Precision: {test_results[2]:.4f}") | |
| print(f" Test Recall: {test_results[3]:.4f}") | |
| if test_results[1] >= 0.99: | |
| print(f"\nπππ EXCELLENT! Hybrid model achieved 99%+ accuracy: {test_results[1]*100:.2f}% πππ") | |
| elif test_results[1] >= 0.95: | |
| print(f"\nπ GREAT! Hybrid model achieved 95%+ accuracy: {test_results[1]*100:.2f}%") | |
| print(f" Consider training for more epochs or increasing sample_size for 99%+") | |
| elif test_results[1] >= 0.90: | |
| print(f"\nβ Good! Model accuracy is {test_results[1]*100:.2f}%") | |
| print(f" To reach 99%+, try:") | |
| print(f" - Training for more epochs") | |
| print(f" - Increasing sample_size (currently {sample_size})") | |
| print(f" - Adjusting learning rates") | |
| else: | |
| print(f"\nβ Model accuracy is {test_results[1]*100:.2f}%") | |
| print(f" Consider:") | |
| print(f" - Training for more epochs") | |
| print(f" - Increasing sample_size") | |
| print(f" - Checking data quality") | |
| except Exception as e: | |
| print(f"β Error during evaluation: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 5: Save final model | |
| print(f"\n[Step 5] Saving hybrid model to {checkpoint_path}...") | |
| try: | |
| model.save(checkpoint_path) | |
| print(f"β Hybrid model saved successfully!") | |
| print(f"\nπ Model saved at: {os.path.abspath(checkpoint_path)}") | |
| print(f"\nTo use this model in deployment, set MODEL_PATH environment variable:") | |
| print(f" export MODEL_PATH={os.path.abspath(checkpoint_path)}") | |
| print(f" Or update app/main.py to use: MODEL_PATH = '{checkpoint_path}'") | |
| print(f"\nπ‘ The hybrid model will automatically be detected and used correctly!") | |
| except Exception as e: | |
| print(f"β Error saving model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| print("\n" + "=" * 70) | |
| print("Hybrid model training pipeline completed successfully!") | |
| print("=" * 70) | |
| print("\nπ Model Performance Summary:") | |
| print(f" - Architecture: Xception + EfficientNetB4 + ResNet50") | |
| print(f" - Parameters: {model.count_params():,}") | |
| print(f" - Test Accuracy: {test_results[1]*100:.2f}%") | |
| print(f" - Model saved: {checkpoint_path}") | |
| print("\nπ Your hybrid model is ready for deployment!") | |
| def train_video_model(): | |
| """Train CNN-RNN video sequence model.""" | |
| print("=" * 70) | |
| print("DeepFake Video Detection Model Training") | |
| print("Using CNN-RNN Architecture for Video Classification") | |
| print("=" * 70) | |
| # Configuration | |
| data_folder = "data/videos_data/train_sample_videos" | |
| metadata_file = "metadata.json" | |
| sample_size = 200 # Number of videos to use (None = all) | |
| batch_size = 8 | |
| epochs = 10 | |
| max_seq_length = 20 | |
| num_features = 2048 # InceptionV3 output features | |
| img_size = 224 | |
| checkpoint_path = "video_deepfake_model.h5" | |
| # Step 1: Load video dataset | |
| print("\n[Step 1] Loading video dataset from folder...") | |
| try: | |
| (train_paths, y_train), (val_paths, y_val), (test_paths, y_test) = model_module.load_video_dataset_from_folder( | |
| data_folder=data_folder, | |
| metadata_file=metadata_file, | |
| sample_size=sample_size, | |
| random_state=42 | |
| ) | |
| print(f"β Dataset loaded successfully!") | |
| print(f" Training: {len(train_paths)} videos") | |
| print(f" Validation: {len(val_paths)} videos") | |
| print(f" Test: {len(test_paths)} videos") | |
| except Exception as e: | |
| print(f"β Error loading dataset: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 2: Build feature extractor | |
| print("\n[Step 2] Building InceptionV3 feature extractor...") | |
| try: | |
| feature_extractor = model_module.build_video_feature_extractor( | |
| input_shape=(img_size, img_size, 3) | |
| ) | |
| print(f"β Feature extractor built successfully!") | |
| print(f" Output features: {feature_extractor.output_shape[-1]}") | |
| except Exception as e: | |
| print(f"β Error building feature extractor: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 3: Extract features from videos | |
| print("\n[Step 3] Extracting features from training videos...") | |
| print(" This may take a while...") | |
| try: | |
| train_data, train_labels = model_module.prepare_all_videos_for_training( | |
| train_paths, | |
| y_train, | |
| feature_extractor, | |
| max_seq_length=max_seq_length, | |
| img_size=img_size | |
| ) | |
| print(f"β Training features extracted!") | |
| print(f" Frame features shape: {train_data[0].shape}") | |
| print(f" Frame masks shape: {train_data[1].shape}") | |
| except Exception as e: | |
| print(f"β Error extracting training features: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| print("\n[Step 4] Extracting features from validation videos...") | |
| try: | |
| val_data, val_labels = model_module.prepare_all_videos_for_training( | |
| val_paths, | |
| y_val, | |
| feature_extractor, | |
| max_seq_length=max_seq_length, | |
| img_size=img_size | |
| ) | |
| print(f"β Validation features extracted!") | |
| except Exception as e: | |
| print(f"β Error extracting validation features: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 4: Build video sequence model | |
| print("\n[Step 5] Building CNN-RNN video sequence model...") | |
| try: | |
| video_model = model_module.build_video_sequence_model( | |
| max_seq_length=max_seq_length, | |
| num_features=num_features, | |
| num_classes=1, | |
| use_binary=True | |
| ) | |
| print(f"β Video model built successfully!") | |
| video_model.summary() | |
| except Exception as e: | |
| print(f"β Error building video model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 5: Train model | |
| print("\n[Step 6] Training video model...") | |
| try: | |
| from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint | |
| callbacks = [ | |
| EarlyStopping( | |
| monitor='val_accuracy', | |
| patience=5, | |
| restore_best_weights=True, | |
| verbose=1 | |
| ), | |
| ReduceLROnPlateau( | |
| monitor='val_accuracy', | |
| factor=0.5, | |
| patience=3, | |
| min_lr=1e-7, | |
| verbose=1 | |
| ), | |
| ModelCheckpoint( | |
| checkpoint_path, | |
| monitor='val_accuracy', | |
| save_best_only=True, | |
| verbose=1 | |
| ) | |
| ] | |
| history = video_model.fit( | |
| [train_data[0], train_data[1]], | |
| train_labels, | |
| validation_data=([val_data[0], val_data[1]], val_labels), | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| callbacks=callbacks, | |
| verbose=1 | |
| ) | |
| print(f"β Training completed!") | |
| except Exception as e: | |
| print(f"β Error during training: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 6: Evaluate on test set | |
| print("\n[Step 7] Evaluating on test set...") | |
| try: | |
| print(" Extracting test features...") | |
| test_data, test_labels = model_module.prepare_all_videos_for_training( | |
| test_paths, | |
| y_test, | |
| feature_extractor, | |
| max_seq_length=max_seq_length, | |
| img_size=img_size | |
| ) | |
| print(" Evaluating model...") | |
| test_results = video_model.evaluate( | |
| [test_data[0], test_data[1]], | |
| test_labels, | |
| verbose=1 | |
| ) | |
| print(f"β Test evaluation completed!") | |
| print(f" Test Loss: {test_results[0]:.4f}") | |
| print(f" Test Accuracy: {test_results[1]:.4f}") | |
| except Exception as e: | |
| print(f"β Error during evaluation: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| # Step 7: Save model | |
| print(f"\n[Step 8] Saving model to {checkpoint_path}...") | |
| try: | |
| video_model.save(checkpoint_path) | |
| print(f"β Model saved successfully!") | |
| print(f"\nπ Model saved at: {os.path.abspath(checkpoint_path)}") | |
| print(f"\nTo use this model in deployment, set VIDEO_MODEL_PATH environment variable:") | |
| print(f" export VIDEO_MODEL_PATH={os.path.abspath(checkpoint_path)}") | |
| except Exception as e: | |
| print(f"β Error saving model: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| return | |
| print("\n" + "=" * 70) | |
| print("Training completed successfully!") | |
| print(f"Model saved to: {checkpoint_path}") | |
| print("=" * 70) | |
| def main(): | |
| """Main function with argument parsing.""" | |
| parser = argparse.ArgumentParser( | |
| description='Train DeepFake detection models', | |
| formatter_class=argparse.RawDescriptionHelpFormatter, | |
| epilog=""" | |
| Examples: | |
| python src/train_model.py --model xception # Train Xception image model | |
| python src/train_model.py --model hybrid # Train Hybrid image model | |
| python src/train_model.py --model video # Train Video sequence model | |
| """ | |
| ) | |
| parser.add_argument( | |
| '--model', | |
| type=str, | |
| choices=['xception', 'hybrid', 'video'], | |
| default='xception', | |
| help='Model type to train (default: xception)' | |
| ) | |
| args = parser.parse_args() | |
| if args.model == 'xception': | |
| train_xception_model() | |
| elif args.model == 'hybrid': | |
| train_hybrid_model() | |
| elif args.model == 'video': | |
| train_video_model() | |
| else: | |
| print(f"Unknown model type: {args.model}") | |
| print("Available options: xception, hybrid, video") | |
| sys.exit(1) | |
| if __name__ == "__main__": | |
| main() | |