Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import matplotlib.pyplot as plt | |
| from tensorflow.keras.optimizers import Adam | |
| from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping | |
| from tensorflow.keras.metrics import AUC | |
| import tensorflow as tf | |
# --- 1. Import our project files ---
# These modules are project-local; if any of them (or anything they import
# in turn) cannot be loaded, the script reports the problem and exits with
# status 1.
try:
    import config
    # Import our new generator and model
    from video_data_generator import VideoDataGenerator
    from video_model import build_video_model
except ImportError as exc:
    print("Error: Could not import config.py, video_data_generator.py, or video_model.py.")
    print("Make sure they are all in the 'src/' directory.")
    # Show the underlying failure as well: an ImportError raised by a
    # dependency of these modules would otherwise look like a missing file.
    print(f"Details: {exc}")
    sys.exit(1)
# --- 2. BEGIN ROBUST GPU FIX ---
# (Same as your original train.py)
# Turn on memory growth for every GPU TensorFlow reports. This step is
# best-effort: any failure is printed and the script carries on.
print("Applying robust GPU configuration...")
try:
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print(" > No GPUs found by TensorFlow. Will run on CPU.")
    # With an empty list this loop simply does nothing.
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
        print(f" > Enabled memory growth for: {device.name}")
except Exception as err:
    print(f" > Error applying GPU configuration: {err}")
# --- END ROBUST GPU FIX ---
| # --- 3. Plot History Function --- | |
| # (Copied from train.py) | |
def plot_history(history, save_path):
    """
    Plot the training history (accuracy and loss) and save it to a file.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch values (as returned by ``model.fit``). The keys
            'accuracy' and 'loss' are required; 'val_accuracy' and
            'val_loss' are plotted only when present.
        save_path: Destination path of the image file.

    Raises:
        KeyError: If 'accuracy' or 'loss' is missing from the history.
    """
    print(f"Saving training history plot to {save_path}...")
    metrics = history.history
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
    try:
        # One subplot per metric: accuracy on top, loss below.
        for axis, key, label in ((ax1, 'accuracy', 'Accuracy'),
                                 (ax2, 'loss', 'Loss')):
            axis.plot(metrics[key], label='Train')
            val_key = f'val_{key}'
            # A run without validation data has no 'val_*' entries.
            if val_key in metrics:
                axis.plot(metrics[val_key], label='Validation')
            axis.set_title(f'Model {label}')
            axis.set_ylabel(label)
            axis.set_xlabel('Epoch')
            axis.legend(loc='upper left')
        # Work on this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure even if saving fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
    print("History plot saved.")
| # --- 4. Main Training Function --- | |
def main(num_epochs=50):
    """
    Main training function for the new CNN-LSTM video model.

    Builds the train/validation generators, compiles the model, trains it
    with checkpointing and early stopping on validation AUC, then saves a
    plot of the training history.

    Args:
        num_epochs: Upper bound on the number of training epochs
            (default 50). EarlyStopping may end training sooner.
    """
    print("--- Phase 1: Starting CNN-LSTM Model Training ---")

    # Make sure the output locations exist even when main() is called from
    # another module rather than through the script entry point.
    for output_dir in (config.MODEL_DIR, config.RESULTS_DIR):
        os.makedirs(output_dir, exist_ok=True)

    # 1. Instantiate Data Generators
    # Both generators share everything except the data directory and
    # whether samples are shuffled.
    print("Initializing data generators...")
    shared_gen_args = dict(
        batch_size=config.VIDEO_BATCH_SIZE,  # Using the new, smaller batch size
        sequence_length=config.SEQUENCE_LENGTH,
        img_size=config.TARGET_IMAGE_SIZE,
    )
    train_gen = VideoDataGenerator(
        data_dir=config.TRAIN_SEQ_DIR,
        shuffle=True,
        **shared_gen_args,
    )
    # NOTE(review): the TEST set is used for validation, and it also drives
    # checkpoint selection and early stopping below, so metrics measured on
    # it afterwards are optimistically biased. Consider a separate split.
    val_gen = VideoDataGenerator(
        data_dir=config.TEST_SEQ_DIR,
        shuffle=False,  # No need to shuffle validation data
        **shared_gen_args,
    )

    # 2. Build the model
    print("Building model...")
    model = build_video_model()

    # 3. Compile the model
    model.compile(
        optimizer=Adam(learning_rate=config.LEARNING_RATE),
        loss='binary_crossentropy',
        # The metric is named 'auc', so Keras reports it on the validation
        # data as 'val_auc', which the callbacks below monitor.
        metrics=['accuracy', AUC(name='auc')],
    )
    model.summary()

    # 4. Define Callbacks
    # Both callbacks watch the same quantity and want it maximised.
    monitor_args = dict(monitor='val_auc', mode='max', verbose=1)
    checkpoint_path = os.path.join(config.MODEL_DIR, "cnn_lstm_video_model.h5")
    model_checkpoint = ModelCheckpoint(
        filepath=checkpoint_path,
        save_best_only=True,
        **monitor_args,
    )
    early_stopping = EarlyStopping(
        patience=7,  # Stop after 7 epochs of no improvement
        restore_best_weights=True,  # Restore the best model
        **monitor_args,
    )

    # 5. Start Training
    print("Starting model training...")
    # No 'steps_per_epoch' is passed: per the original author's note, Keras
    # takes the number of batches from the generator's `__len__` method.
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=num_epochs,
        callbacks=[model_checkpoint, early_stopping],
    )
    print("Training complete.")

    # 6. Save history plot with a new name
    plot_path = os.path.join(config.RESULTS_DIR, "cnn_lstm_training_history.png")
    plot_history(history, plot_path)

    print("\n--- CNN-LSTM Model Training Finished ---")
    print(f"Best video model saved to: {checkpoint_path}")
# --- 5. Run the Script ---
if __name__ == "__main__":
    # Create the model and results directories named in config before
    # training starts, so that writing the outputs does not fail with
    # FileNotFoundError.
    for output_dir in (config.MODEL_DIR, config.RESULTS_DIR):
        os.makedirs(output_dir, exist_ok=True)
    main()