File size: 5,103 Bytes
45742a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import sys
import matplotlib.pyplot as plt
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.metrics import AUC
import tensorflow as tf

# --- 1. Import our project files ---
# These modules are project-local. If any of them (or anything they import in
# turn) cannot be loaded, stop immediately with a readable message.
try:
    import config
    # Import our new generator and model
    from video_data_generator import VideoDataGenerator
    from video_model import build_video_model
except ImportError as import_error:
    print("Error: Could not import config.py, video_data_generator.py, or video_model.py.", file=sys.stderr)
    print("Make sure they are all in the 'src/' directory.", file=sys.stderr)
    # Show the underlying failure too: an ImportError raised *inside* one of
    # these modules (for example a missing dependency) would otherwise be
    # indistinguishable from a misplaced file.
    print(f"Details: {import_error}", file=sys.stderr)
    sys.exit(1)

# --- 2. GPU memory-growth configuration (best effort) ---
# Turn on memory growth for every GPU TensorFlow reports. Any failure is only
# printed; the script then carries on with whatever configuration is in effect.
print("Applying robust GPU configuration...")
try:
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        print("  > No GPUs found by TensorFlow. Will run on CPU.")
    # With no GPUs the loop body simply never runs.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
        print(f"  > Enabled memory growth for: {gpu.name}")
except Exception as gpu_config_error:
    print(f"  > Error applying GPU configuration: {gpu_config_error}")
# --- end of GPU configuration ---

# --- 3. Plot History Function ---
def plot_history(history, save_path):
    """
    Plot the training history (accuracy and loss) and save it to a file.

    Args:
        history: Object whose ``history`` attribute is a mapping that holds
            the 'accuracy', 'val_accuracy', 'loss' and 'val_loss' series
            (in this script, the value returned by ``model.fit``).
        save_path: Path of the image file to write.

    Raises:
        KeyError: If any of the four series is missing from
            ``history.history``.
    """
    print(f"Saving training history plot to {save_path}...")

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10))
    try:
        # Top panel: training & validation accuracy
        ax1.plot(history.history['accuracy'])
        ax1.plot(history.history['val_accuracy'])
        ax1.set_title('Model Accuracy')
        ax1.set_ylabel('Accuracy')
        ax1.set_xlabel('Epoch')
        ax1.legend(['Train', 'Validation'], loc='upper left')

        # Bottom panel: training & validation loss
        ax2.plot(history.history['loss'])
        ax2.plot(history.history['val_loss'])
        ax2.set_title('Model Loss')
        ax2.set_ylabel('Loss')
        ax2.set_xlabel('Epoch')
        ax2.legend(['Train', 'Validation'], loc='upper left')

        # Operate on this figure explicitly instead of pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

    print("History plot saved.")

# --- 4. Main Training Function ---
def main(num_epochs=50):
    """
    Train the CNN-LSTM video model, checkpoint the best weights and save a
    plot of the training history.

    Args:
        num_epochs: Maximum number of epochs passed to ``model.fit``.
            Defaults to 50; the EarlyStopping callback may end training
            sooner.
    """
    print("--- Phase 1: Starting CNN-LSTM Model Training ---")

    # Create the output directories here as well, so main() also works when it
    # is called from another module or a notebook rather than via the
    # __main__ guard. exist_ok=True makes this a no-op if they already exist.
    os.makedirs(config.MODEL_DIR, exist_ok=True)
    os.makedirs(config.RESULTS_DIR, exist_ok=True)

    # 1. Instantiate Data Generators
    # Both generators read their settings from config.
    print("Initializing data generators...")
    train_gen = VideoDataGenerator(
        data_dir=config.TRAIN_SEQ_DIR,
        batch_size=config.VIDEO_BATCH_SIZE,
        sequence_length=config.SEQUENCE_LENGTH,
        img_size=config.TARGET_IMAGE_SIZE,
        shuffle=True
    )

    # NOTE(review): validation uses the TEST split, so checkpoint selection and
    # early stopping are tuned on test data — confirm this is intended.
    val_gen = VideoDataGenerator(
        data_dir=config.TEST_SEQ_DIR,
        batch_size=config.VIDEO_BATCH_SIZE,
        sequence_length=config.SEQUENCE_LENGTH,
        img_size=config.TARGET_IMAGE_SIZE,
        shuffle=False  # No need to shuffle validation data
    )

    # 2. Build the model
    print("Building model...")
    model = build_video_model()

    # 3. Compile the model
    # Binary cross-entropy loss; accuracy and AUC are tracked as metrics.
    model.compile(
        optimizer=Adam(learning_rate=config.LEARNING_RATE),
        loss='binary_crossentropy',
        metrics=['accuracy', AUC(name='auc')]
    )

    model.summary()

    # 4. Define Callbacks
    checkpoint_path = os.path.join(config.MODEL_DIR, "cnn_lstm_video_model.h5")

    # Keep only the checkpoint with the highest validation AUC.
    model_checkpoint = ModelCheckpoint(
        filepath=checkpoint_path,
        save_best_only=True,
        monitor='val_auc',
        mode='max',
        verbose=1
    )

    # Stop after 7 epochs without a val_auc improvement and restore the best
    # weights seen so far.
    early_stopping = EarlyStopping(
        monitor='val_auc',
        mode='max',
        patience=7,
        verbose=1,
        restore_best_weights=True
    )

    # 5. Start Training
    print("Starting model training...")

    # No 'steps_per_epoch' is passed; presumably VideoDataGenerator exposes
    # `__len__` so Keras can derive it — confirm against the generator.
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=num_epochs,
        callbacks=[model_checkpoint, early_stopping]
    )

    print("Training complete.")

    # 6. Save history plot with a new name
    plot_path = os.path.join(config.RESULTS_DIR, "cnn_lstm_training_history.png")
    plot_history(history, plot_path)

    print("\n--- CNN-LSTM Model Training Finished ---")
    print(f"Best video model saved to: {checkpoint_path}")

# --- 5. Run the Script ---
if __name__ == "__main__":
    # Training writes the checkpoint into MODEL_DIR and the history plot into
    # RESULTS_DIR, so create whichever of the two is missing before starting.
    for output_dir in (config.MODEL_DIR, config.RESULTS_DIR):
        os.makedirs(output_dir, exist_ok=True)

    main()