Vinh Vu commited on
Commit
686e5bb
·
1 Parent(s): a879ae6

Update the training model

Browse files
Files changed (1) hide show
  1. 03-train_cnn.py +92 -18
03-train_cnn.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import pandas as pd
3
  import numpy as np
4
 
5
  # TensorFlow and tf.keras
@@ -16,22 +15,27 @@ from tensorflow.keras.preprocessing.image import ImageDataGenerator
16
  from tensorflow.keras.applications import EfficientNetB0
17
  from tensorflow.keras.applications.efficientnet import preprocess_input
18
  from tensorflow.keras.models import Sequential, load_model
19
- from tensorflow.keras.layers import Dense, Dropout
20
  from tensorflow.keras.optimizers import Adam
21
  from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
22
 
23
- input_size = 128
 
24
  batch_size_num = 32
25
  train_path = os.path.join(dataset_path, 'train')
26
  val_path = os.path.join(dataset_path, 'val')
27
  test_path = os.path.join(dataset_path, 'test')
28
 
29
  # preprocess_input scales pixels to [-1, 1] which EfficientNet expects
 
30
  train_datagen = ImageDataGenerator(
31
  preprocessing_function = preprocess_input,
32
- rotation_range = 10,
33
  horizontal_flip = True,
34
- zoom_range = 0.1,
 
 
 
35
  fill_mode = 'nearest'
36
  )
37
 
@@ -47,6 +51,16 @@ train_generator = train_datagen.flow_from_directory(
47
  print(f'Class mapping: {train_generator.class_indices}')
48
  print(f'Train samples - fake: {np.sum(train_generator.classes == 0)}, real: {np.sum(train_generator.classes == 1)}')
49
 
 
 
 
 
 
 
 
 
 
 
50
  val_datagen = ImageDataGenerator(
51
  preprocessing_function = preprocess_input
52
  )
@@ -74,29 +88,41 @@ test_generator = test_datagen.flow_from_directory(
74
  shuffle = False
75
  )
76
 
77
- # Build model - entire EfficientNetB0 is trainable
78
  efficient_net = EfficientNetB0(
79
  weights = 'imagenet',
80
  input_shape = (input_size, input_size, 3),
81
  include_top = False,
82
- pooling = 'max'
83
  )
84
 
 
 
 
85
  model = Sequential()
86
  model.add(efficient_net)
87
- model.add(Dense(units = 512, activation = 'relu'))
 
 
88
  model.add(Dropout(0.5))
89
- model.add(Dense(units = 128, activation = 'relu'))
90
  model.add(Dense(units = 1, activation = 'sigmoid'))
91
  model.summary()
92
 
93
- model.compile(optimizer = Adam(learning_rate=1e-4), loss='binary_crossentropy', metrics=['accuracy'])
94
-
95
  checkpoint_filepath = '.\\tmp_checkpoint'
96
  print('Creating Directory: ' + checkpoint_filepath)
97
  os.makedirs(checkpoint_filepath, exist_ok=True)
98
 
99
- callbacks = [
 
 
 
 
 
 
 
 
 
 
100
  EarlyStopping(
101
  monitor = 'val_loss',
102
  mode = 'min',
@@ -104,6 +130,52 @@ callbacks = [
104
  verbose = 1,
105
  restore_best_weights = True
106
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  ModelCheckpoint(
108
  filepath = os.path.join(checkpoint_filepath, 'best_model.keras'),
109
  monitor = 'val_loss',
@@ -120,20 +192,22 @@ callbacks = [
120
  )
121
  ]
122
 
123
- print('\n=== Training ===')
124
- num_epochs = 20
125
- history = model.fit(
126
  train_generator,
127
- epochs = num_epochs,
128
  steps_per_epoch = len(train_generator),
129
  validation_data = val_generator,
130
  validation_steps = len(val_generator),
131
- callbacks = callbacks
 
132
  )
133
 
134
- # Load the best model
135
  best_model = load_model(os.path.join(checkpoint_filepath, 'best_model.keras'))
136
 
 
 
 
137
  # Evaluate on test set
138
  print('\n=== Evaluation on Test Set ===')
139
  test_generator.reset()
 
1
  import os
 
2
  import numpy as np
3
 
4
  # TensorFlow and tf.keras
 
15
  from tensorflow.keras.applications import EfficientNetB0
16
  from tensorflow.keras.applications.efficientnet import preprocess_input
17
  from tensorflow.keras.models import Sequential, load_model
18
+ from tensorflow.keras.layers import Dense, Dropout, BatchNormalization, GlobalAveragePooling2D
19
  from tensorflow.keras.optimizers import Adam
20
  from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
21
 
22
+ # 224 is EfficientNetB0's native resolution — much better feature extraction than 128
23
+ input_size = 224
24
  batch_size_num = 32
25
  train_path = os.path.join(dataset_path, 'train')
26
  val_path = os.path.join(dataset_path, 'val')
27
  test_path = os.path.join(dataset_path, 'test')
28
 
29
  # preprocess_input scales pixels to [-1, 1] which EfficientNet expects
30
+ # Stronger augmentation for deepfake detection
31
  train_datagen = ImageDataGenerator(
32
  preprocessing_function = preprocess_input,
33
+ rotation_range = 15,
34
  horizontal_flip = True,
35
+ zoom_range = 0.15,
36
+ width_shift_range = 0.1,
37
+ height_shift_range = 0.1,
38
+ brightness_range = [0.8, 1.2],
39
  fill_mode = 'nearest'
40
  )
41
 
 
51
  print(f'Class mapping: {train_generator.class_indices}')
52
  print(f'Train samples - fake: {np.sum(train_generator.classes == 0)}, real: {np.sum(train_generator.classes == 1)}')
53
 
54
+ # Compute class weights to handle imbalance
55
+ num_fake = np.sum(train_generator.classes == 0)
56
+ num_real = np.sum(train_generator.classes == 1)
57
+ total = num_fake + num_real
58
+ class_weight = {
59
+ 0: total / (2.0 * num_fake),
60
+ 1: total / (2.0 * num_real)
61
+ }
62
+ print(f'Class weights: {class_weight}')
63
+
64
  val_datagen = ImageDataGenerator(
65
  preprocessing_function = preprocess_input
66
  )
 
88
  shuffle = False
89
  )
90
 
91
+ # Build model with frozen base for Phase 1
92
  efficient_net = EfficientNetB0(
93
  weights = 'imagenet',
94
  input_shape = (input_size, input_size, 3),
95
  include_top = False,
96
+ pooling = None # We'll add our own pooling
97
  )
98
 
99
+ # Freeze the base model for Phase 1
100
+ efficient_net.trainable = False
101
+
102
  model = Sequential()
103
  model.add(efficient_net)
104
+ model.add(GlobalAveragePooling2D())
105
+ model.add(BatchNormalization())
106
+ model.add(Dense(units = 256, activation = 'relu'))
107
  model.add(Dropout(0.5))
 
108
  model.add(Dense(units = 1, activation = 'sigmoid'))
109
  model.summary()
110
 
 
 
111
  checkpoint_filepath = '.\\tmp_checkpoint'
112
  print('Creating Directory: ' + checkpoint_filepath)
113
  os.makedirs(checkpoint_filepath, exist_ok=True)
114
 
115
+ # ============================================================
116
+ # Phase 1: Train head only (base frozen), higher learning rate
117
+ # ============================================================
118
+ print('\n=== Phase 1: Training head (base frozen) ===')
119
+ model.compile(
120
+ optimizer = Adam(learning_rate=1e-3),
121
+ loss='binary_crossentropy',
122
+ metrics=['accuracy']
123
+ )
124
+
125
+ phase1_callbacks = [
126
  EarlyStopping(
127
  monitor = 'val_loss',
128
  mode = 'min',
 
130
  verbose = 1,
131
  restore_best_weights = True
132
  ),
133
+ ModelCheckpoint(
134
+ filepath = os.path.join(checkpoint_filepath, 'best_model_phase1.keras'),
135
+ monitor = 'val_loss',
136
+ mode = 'min',
137
+ verbose = 1,
138
+ save_best_only = True
139
+ ),
140
+ ReduceLROnPlateau(
141
+ monitor = 'val_loss',
142
+ factor = 0.5,
143
+ patience = 2,
144
+ min_lr = 1e-5,
145
+ verbose = 1
146
+ )
147
+ ]
148
+
149
+ history_phase1 = model.fit(
150
+ train_generator,
151
+ epochs = 15,
152
+ steps_per_epoch = len(train_generator),
153
+ validation_data = val_generator,
154
+ validation_steps = len(val_generator),
155
+ class_weight = class_weight,
156
+ callbacks = phase1_callbacks
157
+ )
158
+
159
+ # ============================================================
160
+ # Phase 2: Unfreeze all layers, fine-tune with very low lr
161
+ # ============================================================
162
+ print('\n=== Phase 2: Fine-tuning entire model ===')
163
+ efficient_net.trainable = True
164
+
165
+ model.compile(
166
+ optimizer = Adam(learning_rate=1e-5),
167
+ loss='binary_crossentropy',
168
+ metrics=['accuracy']
169
+ )
170
+
171
+ phase2_callbacks = [
172
+ EarlyStopping(
173
+ monitor = 'val_loss',
174
+ mode = 'min',
175
+ patience = 7,
176
+ verbose = 1,
177
+ restore_best_weights = True
178
+ ),
179
  ModelCheckpoint(
180
  filepath = os.path.join(checkpoint_filepath, 'best_model.keras'),
181
  monitor = 'val_loss',
 
192
  )
193
  ]
194
 
195
+ history_phase2 = model.fit(
 
 
196
  train_generator,
197
+ epochs = 30,
198
  steps_per_epoch = len(train_generator),
199
  validation_data = val_generator,
200
  validation_steps = len(val_generator),
201
+ class_weight = class_weight,
202
+ callbacks = phase2_callbacks
203
  )
204
 
205
+ # Load the best model from Phase 2
206
  best_model = load_model(os.path.join(checkpoint_filepath, 'best_model.keras'))
207
 
208
+ # Also save a copy for the app
209
+ best_model.save('best_model.keras')
210
+
211
  # Evaluate on test set
212
  print('\n=== Evaluation on Test Set ===')
213
  test_generator.reset()