Vinh Vu committed on
Commit
b06ef27
·
1 Parent(s): 4e15caf

Update train cnn

Browse files
Files changed (1) hide show
  1. 03-train_cnn.py +110 -34
03-train_cnn.py CHANGED
@@ -27,14 +27,16 @@ def get_filename_only(file_path):
27
  filename_only = file_basename.split('.')[0]
28
  return filename_only
29
 
 
 
30
  from tensorflow.keras.preprocessing.image import ImageDataGenerator
31
  from tensorflow.keras import applications
32
  from tensorflow.keras.applications import EfficientNetB0
33
- from tensorflow.keras.models import Sequential
 
34
  from tensorflow.keras.layers import Dense, Dropout
35
  from tensorflow.keras.optimizers import Adam
36
- from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
37
- from tensorflow.keras.models import load_model
38
 
39
  input_size = 128
40
  batch_size_num = 32
@@ -43,13 +45,15 @@ val_path = os.path.join(dataset_path, 'val')
43
  test_path = os.path.join(dataset_path, 'test')
44
 
45
  train_datagen = ImageDataGenerator(
46
- rescale = 1/255, #rescale the tensor values to [0,1]
47
- rotation_range = 10,
48
- width_shift_range = 0.1,
49
- height_shift_range = 0.1,
50
  shear_range = 0.2,
51
- zoom_range = 0.1,
52
  horizontal_flip = True,
 
 
53
  fill_mode = 'nearest'
54
  )
55
 
@@ -57,28 +61,33 @@ train_generator = train_datagen.flow_from_directory(
57
  directory = train_path,
58
  target_size = (input_size, input_size),
59
  color_mode = "rgb",
60
- class_mode = "binary", #"categorical", "binary", "sparse", "input"
61
  batch_size = batch_size_num,
62
  shuffle = True
63
- #save_to_dir = tmp_debug_path
64
  )
65
 
 
 
 
 
 
 
 
66
  val_datagen = ImageDataGenerator(
67
- rescale = 1/255 #rescale the tensor values to [0,1]
68
  )
69
 
70
  val_generator = val_datagen.flow_from_directory(
71
  directory = val_path,
72
  target_size = (input_size, input_size),
73
  color_mode = "rgb",
74
- class_mode = "binary", #"categorical", "binary", "sparse", "input"
75
  batch_size = batch_size_num,
76
  shuffle = True
77
- #save_to_dir = tmp_debug_path
78
  )
79
 
80
  test_datagen = ImageDataGenerator(
81
- rescale = 1/255 #rescale the tensor values to [0,1]
82
  )
83
 
84
  test_generator = test_datagen.flow_from_directory(
@@ -86,28 +95,29 @@ test_generator = test_datagen.flow_from_directory(
86
  classes=['fake', 'real'],
87
  target_size = (input_size, input_size),
88
  color_mode = "rgb",
89
- class_mode = None,
90
  batch_size = 1,
91
  shuffle = False
92
  )
93
 
94
- # Train a CNN classifier
95
  efficient_net = EfficientNetB0(
96
  weights = 'imagenet',
97
  input_shape = (input_size, input_size, 3),
98
  include_top = False,
99
  pooling = 'max'
100
  )
 
101
 
102
  model = Sequential()
103
  model.add(efficient_net)
104
  model.add(Dense(units = 512, activation = 'relu'))
105
  model.add(Dropout(0.5))
106
  model.add(Dense(units = 128, activation = 'relu'))
 
107
  model.add(Dense(units = 1, activation = 'sigmoid'))
108
  model.summary()
109
 
110
- # Compile model
111
  model.compile(optimizer = Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])
112
 
113
  checkpoint_filepath = '.\\tmp_checkpoint'
@@ -116,46 +126,112 @@ os.makedirs(checkpoint_filepath, exist_ok=True)
116
 
117
  custom_callbacks = [
118
  EarlyStopping(
119
- monitor = 'val_loss',
120
- mode = 'min',
121
  patience = 5,
122
- verbose = 1
 
123
  ),
124
  ModelCheckpoint(
125
- filepath = os.path.join(checkpoint_filepath, 'best_model.h5'),
126
- monitor = 'val_loss',
127
- mode = 'min',
128
  verbose = 1,
129
  save_best_only = True
 
 
 
 
 
 
 
 
130
  )
131
  ]
132
 
133
- # Train network
134
- num_epochs = 20
135
  history = model.fit(
136
  train_generator,
137
  epochs = num_epochs,
138
  steps_per_epoch = len(train_generator),
139
  validation_data = val_generator,
140
  validation_steps = len(val_generator),
141
- callbacks = custom_callbacks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  )
143
- print(history.history)
144
 
 
 
145
 
146
- # load the saved model that is considered the best
147
- best_model = load_model(os.path.join(checkpoint_filepath, 'best_model.h5'))
 
 
 
 
148
 
149
  # Generate predictions
150
  test_generator.reset()
 
 
 
151
 
152
- preds = best_model.predict(
153
- test_generator,
154
- verbose = 1
155
- )
 
156
 
157
  test_results = pd.DataFrame({
158
  "Filename": test_generator.filenames,
159
- "Prediction": preds.flatten()
 
 
160
  })
161
  print(test_results)
 
27
  filename_only = file_basename.split('.')[0]
28
  return filename_only
29
 
30
+ import numpy as np
31
+ from sklearn.utils.class_weight import compute_class_weight
32
  from tensorflow.keras.preprocessing.image import ImageDataGenerator
33
  from tensorflow.keras import applications
34
  from tensorflow.keras.applications import EfficientNetB0
35
+ from tensorflow.keras.applications.efficientnet import preprocess_input
36
+ from tensorflow.keras.models import Sequential, load_model
37
  from tensorflow.keras.layers import Dense, Dropout
38
  from tensorflow.keras.optimizers import Adam
39
+ from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
 
40
 
41
  input_size = 128
42
  batch_size_num = 32
 
45
  test_path = os.path.join(dataset_path, 'test')
46
 
47
  train_datagen = ImageDataGenerator(
48
+ preprocessing_function = preprocess_input,
49
+ rotation_range = 15,
50
+ width_shift_range = 0.15,
51
+ height_shift_range = 0.15,
52
  shear_range = 0.2,
53
+ zoom_range = 0.15,
54
  horizontal_flip = True,
55
+ brightness_range = [0.8, 1.2],
56
+ channel_shift_range = 30,
57
  fill_mode = 'nearest'
58
  )
59
 
 
61
  directory = train_path,
62
  target_size = (input_size, input_size),
63
  color_mode = "rgb",
64
+ class_mode = "binary",
65
  batch_size = batch_size_num,
66
  shuffle = True
 
67
  )
68
 
69
+ # Compute class weights to handle imbalance
70
+ class_weights = compute_class_weight('balanced', classes=np.unique(train_generator.classes), y=train_generator.classes)
71
+ class_weight_dict = dict(enumerate(class_weights))
72
+ print(f'Class mapping: {train_generator.class_indices}')
73
+ print(f'Class weights: {class_weight_dict}')
74
+ print(f'Train samples - fake: {np.sum(train_generator.classes == 0)}, real: {np.sum(train_generator.classes == 1)}')
75
+
76
  val_datagen = ImageDataGenerator(
77
+ preprocessing_function = preprocess_input
78
  )
79
 
80
  val_generator = val_datagen.flow_from_directory(
81
  directory = val_path,
82
  target_size = (input_size, input_size),
83
  color_mode = "rgb",
84
+ class_mode = "binary",
85
  batch_size = batch_size_num,
86
  shuffle = True
 
87
  )
88
 
89
  test_datagen = ImageDataGenerator(
90
+ preprocessing_function = preprocess_input
91
  )
92
 
93
  test_generator = test_datagen.flow_from_directory(
 
95
  classes=['fake', 'real'],
96
  target_size = (input_size, input_size),
97
  color_mode = "rgb",
98
+ class_mode = "binary",
99
  batch_size = 1,
100
  shuffle = False
101
  )
102
 
103
+ # --- Phase 1: Train with frozen base ---
104
  efficient_net = EfficientNetB0(
105
  weights = 'imagenet',
106
  input_shape = (input_size, input_size, 3),
107
  include_top = False,
108
  pooling = 'max'
109
  )
110
+ efficient_net.trainable = False # freeze base initially
111
 
112
  model = Sequential()
113
  model.add(efficient_net)
114
  model.add(Dense(units = 512, activation = 'relu'))
115
  model.add(Dropout(0.5))
116
  model.add(Dense(units = 128, activation = 'relu'))
117
+ model.add(Dropout(0.3))
118
  model.add(Dense(units = 1, activation = 'sigmoid'))
119
  model.summary()
120
 
 
121
  model.compile(optimizer = Adam(learning_rate=0.0001), loss='binary_crossentropy', metrics=['accuracy'])
122
 
123
  checkpoint_filepath = '.\\tmp_checkpoint'
 
126
 
127
  custom_callbacks = [
128
  EarlyStopping(
129
+ monitor = 'val_accuracy',
130
+ mode = 'max',
131
  patience = 5,
132
+ verbose = 1,
133
+ restore_best_weights = True
134
  ),
135
  ModelCheckpoint(
136
+ filepath = os.path.join(checkpoint_filepath, 'best_model.keras'),
137
+ monitor = 'val_accuracy',
138
+ mode = 'max',
139
  verbose = 1,
140
  save_best_only = True
141
+ ),
142
+ ReduceLROnPlateau(
143
+ monitor = 'val_accuracy',
144
+ factor = 0.5,
145
+ patience = 3,
146
+ min_lr = 1e-7,
147
+ verbose = 1,
148
+ mode = 'max'
149
  )
150
  ]
151
 
152
+ print('\n=== Phase 1: Training with frozen base ===')
153
+ num_epochs = 15
154
  history = model.fit(
155
  train_generator,
156
  epochs = num_epochs,
157
  steps_per_epoch = len(train_generator),
158
  validation_data = val_generator,
159
  validation_steps = len(val_generator),
160
+ callbacks = custom_callbacks,
161
+ class_weight = class_weight_dict
162
+ )
163
+
164
+ # --- Phase 2: Fine-tune top layers of base model ---
165
+ print('\n=== Phase 2: Fine-tuning top layers ===')
166
+ efficient_net.trainable = True
167
+ # Freeze all layers except the last 30
168
+ for layer in efficient_net.layers[:-30]:
169
+ layer.trainable = False
170
+
171
+ model.compile(optimizer = Adam(learning_rate=1e-5), loss='binary_crossentropy', metrics=['accuracy'])
172
+
173
+ fine_tune_callbacks = [
174
+ EarlyStopping(
175
+ monitor = 'val_accuracy',
176
+ mode = 'max',
177
+ patience = 5,
178
+ verbose = 1,
179
+ restore_best_weights = True
180
+ ),
181
+ ModelCheckpoint(
182
+ filepath = os.path.join(checkpoint_filepath, 'best_model.keras'),
183
+ monitor = 'val_accuracy',
184
+ mode = 'max',
185
+ verbose = 1,
186
+ save_best_only = True
187
+ ),
188
+ ReduceLROnPlateau(
189
+ monitor = 'val_accuracy',
190
+ factor = 0.5,
191
+ patience = 3,
192
+ min_lr = 1e-8,
193
+ verbose = 1,
194
+ mode = 'max'
195
+ )
196
+ ]
197
+
198
+ fine_tune_epochs = 30
199
+ history_fine = model.fit(
200
+ train_generator,
201
+ epochs = fine_tune_epochs,
202
+ steps_per_epoch = len(train_generator),
203
+ validation_data = val_generator,
204
+ validation_steps = len(val_generator),
205
+ callbacks = fine_tune_callbacks,
206
+ class_weight = class_weight_dict
207
  )
 
208
 
209
+ # Load the best model
210
+ best_model = load_model(os.path.join(checkpoint_filepath, 'best_model.keras'))
211
 
212
+ # Evaluate on test set
213
+ print('\n=== Evaluation on Test Set ===')
214
+ test_generator.reset()
215
+ test_loss, test_accuracy = best_model.evaluate(test_generator, steps=len(test_generator), verbose=1)
216
+ print(f'Test Loss: {test_loss:.4f}')
217
+ print(f'Test Accuracy: {test_accuracy:.4f}')
218
 
219
  # Generate predictions
220
  test_generator.reset()
221
+ preds = best_model.predict(test_generator, verbose=1)
222
+ pred_labels = (preds.flatten() > 0.5).astype(int)
223
+ true_labels = test_generator.classes
224
 
225
+ from sklearn.metrics import classification_report, confusion_matrix
226
+ print('\nClassification Report:')
227
+ print(classification_report(true_labels, pred_labels, target_names=['fake', 'real']))
228
+ print('Confusion Matrix:')
229
+ print(confusion_matrix(true_labels, pred_labels))
230
 
231
  test_results = pd.DataFrame({
232
  "Filename": test_generator.filenames,
233
+ "Prediction": preds.flatten(),
234
+ "Predicted_Label": pred_labels,
235
+ "True_Label": true_labels
236
  })
237
  print(test_results)