# deepfake-api / src / train_video_model.py
# Last commit by piyushnaula: "Initial commit - Deepfake Detector API" (45742a7)
import os
import sys
import matplotlib.pyplot as plt
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.metrics import AUC
import tensorflow as tf
# --- 1. Import our project files ---
try:
    import config
    # Data generator and model builder used by the video training pipeline
    from video_data_generator import VideoDataGenerator
    from video_model import build_video_model
except ImportError as import_error:
    # Report on stderr and include the underlying exception: the failure may
    # come from a missing dependency *inside* one of these modules, in which
    # case the "wrong directory" hint alone would be misleading.
    print(
        "Error: Could not import config.py, video_data_generator.py, or video_model.py.",
        file=sys.stderr,
    )
    print("Make sure they are all in the 'src/' directory.", file=sys.stderr)
    print(f"Details: {import_error}", file=sys.stderr)
    sys.exit(1)
# --- 2. BEGIN ROBUST GPU FIX ---
# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
# The whole step is best-effort: any failure is reported and the script
# carries on.
print("Applying robust GPU configuration...")
try:
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print(" > No GPUs found by TensorFlow. Will run on CPU.")
    # Iterating an empty list is a no-op, so no `else` branch is required.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
        print(f" > Enabled memory growth for: {gpu.name}")
except Exception as e:
    # Deliberately broad: a configuration problem must not abort training.
    print(f" > Error applying GPU configuration: {e}")
# --- END ROBUST GPU FIX ---
# --- 3. Plot History Function ---
# (Copied from train.py)
def plot_history(history, save_path):
    """
    Plot training/validation accuracy and loss curves and save them to a file.

    Args:
        history: Object whose ``history`` attribute is a mapping holding the
            'accuracy', 'val_accuracy', 'loss' and 'val_loss' series (in this
            script, the value returned by ``model.fit``).
        save_path: Destination path passed to ``Figure.savefig``.

    Raises:
        KeyError: If one of the four expected series is missing.
    """
    print(f"Saving training history plot to {save_path}...")
    series = history.history
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
    try:
        # Top panel: training & validation accuracy
        ax1.plot(series['accuracy'])
        ax1.plot(series['val_accuracy'])
        ax1.set_title('Model Accuracy')
        ax1.set_ylabel('Accuracy')
        ax1.set_xlabel('Epoch')
        ax1.legend(['Train', 'Validation'], loc='upper left')

        # Bottom panel: training & validation loss
        ax2.plot(series['loss'])
        ax2.plot(series['val_loss'])
        ax2.set_title('Model Loss')
        ax2.set_ylabel('Loss')
        ax2.set_xlabel('Epoch')
        ax2.legend(['Train', 'Validation'], loc='upper left')

        # Operate on the figure handle rather than pyplot's implicit
        # "current figure" so the right figure is always the one saved.
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # pyplot keeps a reference to every figure until it is closed;
        # release it even when plotting or saving fails.
        plt.close(fig)
    print("History plot saved.")
# --- 4. Main Training Function ---
def main(num_epochs=50):
    """
    Train the CNN-LSTM video model and save the best checkpoint and a plot.

    Builds the training/validation generators from the paths in ``config``,
    compiles the model returned by ``build_video_model()``, trains it with
    checkpointing and early stopping on ``val_auc``, then writes the
    training-history plot.

    Args:
        num_epochs: Upper bound on the number of training epochs. Early
            stopping may end training sooner. Defaults to 50.
    """
    print("--- Phase 1: Starting CNN-LSTM Model Training ---")

    # Make sure the output directories exist even when main() is invoked
    # from another module instead of through the `__main__` guard.
    os.makedirs(config.MODEL_DIR, exist_ok=True)
    os.makedirs(config.RESULTS_DIR, exist_ok=True)

    # 1. Instantiate Data Generators
    print("Initializing data generators...")
    train_gen = VideoDataGenerator(
        data_dir=config.TRAIN_SEQ_DIR,
        batch_size=config.VIDEO_BATCH_SIZE,
        sequence_length=config.SEQUENCE_LENGTH,
        img_size=config.TARGET_IMAGE_SIZE,
        shuffle=True
    )
    # NOTE(review): the TEST split doubles as the validation set, so the
    # checkpoint and early-stopping decisions are made on test data and any
    # metric later reported on that split will be optimistic.
    val_gen = VideoDataGenerator(
        data_dir=config.TEST_SEQ_DIR,
        batch_size=config.VIDEO_BATCH_SIZE,
        sequence_length=config.SEQUENCE_LENGTH,
        img_size=config.TARGET_IMAGE_SIZE,
        shuffle=False  # Order does not matter for evaluation
    )

    # 2. Build the model
    print("Building model...")
    model = build_video_model()

    # 3. Compile the model
    model.compile(
        optimizer=Adam(learning_rate=config.LEARNING_RATE),
        loss='binary_crossentropy',
        # The metric name 'auc' is what makes 'val_auc' available to the
        # callbacks below.
        metrics=['accuracy', AUC(name='auc')]
    )
    model.summary()

    # 4. Define Callbacks
    checkpoint_path = os.path.join(config.MODEL_DIR, "cnn_lstm_video_model.h5")
    model_checkpoint = ModelCheckpoint(
        filepath=checkpoint_path,
        save_best_only=True,
        monitor='val_auc',
        mode='max',
        verbose=1
    )
    early_stopping = EarlyStopping(
        monitor='val_auc',
        mode='max',
        patience=7,  # Stop after 7 epochs without improvement
        verbose=1,
        restore_best_weights=True
    )

    # 5. Start Training
    print("Starting model training...")
    # No `steps_per_epoch` is passed: the original author relies on Keras
    # taking the epoch length from the generator's `__len__`.
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=num_epochs,
        callbacks=[model_checkpoint, early_stopping]
    )
    print("Training complete.")

    # 6. Save history plot
    plot_path = os.path.join(config.RESULTS_DIR, "cnn_lstm_training_history.png")
    plot_history(history, plot_path)

    print("\n--- CNN-LSTM Model Training Finished ---")
    print(f"Best video model saved to: {checkpoint_path}")
# --- 5. Run the Script ---
if __name__ == "__main__":
    # The checkpoint and the history plot are written into these folders;
    # create them up front so a missing directory cannot cause a
    # FileNotFoundError.
    for output_dir in (config.MODEL_DIR, config.RESULTS_DIR):
        os.makedirs(output_dir, exist_ok=True)
    main()