Instructions to use ViettNguyen21/brain_msi_segmentation_effunet with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use ViettNguyen21/brain_msi_segmentation_effunet with Keras:
```python
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

model = keras.saving.load_model("hf://ViettNguyen21/brain_msi_segmentation_effunet")
```
- Notebooks
- Google Colab
- Kaggle
# --- Environment setup. The XLA / logging variables are exported before
# TensorFlow is imported further down, so the statement order here matters. ---
import os

os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import glob

# Look for the CUDA libdevice bitcode inside the "deeplab" conda environment.
libdevice_paths = glob.glob(
    '/root/miniconda3/envs/deeplab/lib/python3.11/site-packages'
    '/nvidia/cuda_nvcc/nvvm/libdevice/libdevice.10.bc'
)
if not libdevice_paths:
    # Fallback: turn XLA off completely.
    os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices=false'
    print("[WARN] libdevice not found, XLA disabled")
else:
    # The directory handed to XLA is three levels above libdevice.10.bc.
    cuda_dir = libdevice_paths[0]
    for _ in range(3):
        cuda_dir = os.path.dirname(cuda_dir)
    os.environ['XLA_FLAGS'] = f'--xla_gpu_cuda_data_dir={cuda_dir}'
    print(f"[OK] libdevice found: {libdevice_paths[0]}")

import warnings

# Hide the UserWarning whose message mentions "input_shape".
warnings.filterwarnings("ignore", category=UserWarning, message=".*input_shape.*")

import tensorflow as tf

tf.random.set_seed(42)
tf.config.optimizer.set_jit(False)
tf.config.experimental.set_synchronous_execution(True)

from tensorflow.keras import mixed_precision
from tensorflow.keras.applications import MobileNetV2

# Let TensorFlow grow GPU memory on demand on every visible GPU.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("[OK] Đã kích hoạt Memory Growth")
    except RuntimeError as e:
        print(f"Lỗi khởi tạo GPU: {e}")

# NOTE(review): the original indentation was lost, so it is unclear whether this
# call sat inside the `if gpus:` block. 'float32' is the default policy, so the
# placement should not change behaviour — confirm.
mixed_precision.set_global_policy('float32')

import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from CLMR import CLMRCallback

# --- Hyper-parameters ---
BATCH_SIZE = 8
EPOCHS_STAGE_1 = 20
EPOCHS_STAGE_2 = 80
TARGET_SIZE = (128, 128)
NUM_CLASSES = 2
| # --- 3. CÁC HÀM TIỀN XỬ LÝ DỮ LIỆU --- | |
def parse_tfrecord(serialized):
    """Decode one serialized example into an ``(image, mask)`` pair.

    The record must carry two byte-string features: ``'image'`` (decoded as a
    3-channel JPEG) and ``'mask'`` (decoded as a 1-channel PNG).

    Args:
        serialized: A scalar string tensor holding one serialized example.

    Returns:
        A tuple ``(image, mask)`` of float32 tensors resized to
        ``TARGET_SIZE``: the image has 3 channels and is divided by 255, the
        mask has 1 channel and keeps its raw decoded values.
    """
    height, width = TARGET_SIZE
    example = tf.io.parse_single_example(
        serialized,
        {
            'image': tf.io.FixedLenFeature([], tf.string),
            'mask': tf.io.FixedLenFeature([], tf.string),
        },
    )

    rgb = tf.image.decode_jpeg(example['image'], channels=3)
    rgb = tf.image.resize(tf.cast(rgb, tf.float32) / 255.0, TARGET_SIZE)
    rgb.set_shape([height, width, 3])

    # Nearest-neighbour keeps the mask values unblended while resizing.
    labels = tf.cast(tf.image.decode_png(example['mask'], channels=1), tf.float32)
    labels = tf.image.resize(
        labels, TARGET_SIZE, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    labels.set_shape([height, width, 1])
    return rgb, labels
def augment(image, mask):
    """Randomly crop, flip and colour-jitter one ``(image, mask)`` pair.

    With equal probability the pair is cropped either at a random square
    window or around a randomly chosen pixel labelled 2 or 3 (the "rare"
    classes). The crop is then scaled back to ``TARGET_SIZE``. Brightness,
    contrast and saturation jitter is applied to the image only.

    Two fixes over the previous version:
      * The mask is resized with nearest-neighbour instead of sharing the
        image's bilinear resize, which blended neighbouring class ids into
        fractional label values.
      * The rare-class pixels are located on the *flipped* mask, so the
        focused crop actually lands on the rare region after a random flip.

    Args:
        image: Float tensor ``(H, W, 3)``; the final clip assumes values in
            [0, 1].
        mask: Float tensor ``(H, W, 1)`` holding class ids.

    Returns:
        ``(image, mask)`` at ``TARGET_SIZE`` with 3 and 1 channels.
    """
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    def _stack_and_flip(image, mask):
        # Flip image and mask together so they stay aligned.
        stacked = tf.concat([image, mask], axis=-1)
        return tf.image.random_flip_left_right(stacked)

    def _resize_pair(stacked):
        # Bilinear for the image, nearest for the labels.
        img = tf.image.resize(stacked[..., :3], TARGET_SIZE)
        lbl = tf.image.resize(stacked[..., 3:], TARGET_SIZE, method=nearest)
        return img, lbl

    def random_crop(image, mask):
        combined = _stack_and_flip(image, mask)
        # Square window with a side in [50%, 100%) of the target height.
        min_crop = tf.cast(TARGET_SIZE[0] * 0.5, tf.int32)
        crop_size = tf.random.uniform(
            [], minval=min_crop, maxval=TARGET_SIZE[0], dtype=tf.int32
        )
        combined = tf.image.random_crop(combined, [crop_size, crop_size, 4])
        return _resize_pair(combined)

    def focused_crop(image, mask):
        """Crop a window centred on a pixel of class 2 or 3, if any exists."""
        combined = _stack_and_flip(image, mask)
        # Search the flipped labels so coordinates match the tensor we crop.
        labels = combined[..., 3]
        rare_mask = tf.logical_or(tf.equal(labels, 2), tf.equal(labels, 3))
        rare_indices = tf.where(rare_mask)
        has_rare = tf.greater(tf.shape(rare_indices)[0], 0)

        def crop_around_rare():
            # Pick one rare pixel at random as the crop centre.
            idx = tf.random.uniform(
                [], 0, tf.shape(rare_indices)[0], dtype=tf.int32
            )
            center = tf.cast(rare_indices[idx], tf.int32)
            cy, cx = center[0], center[1]
            h, w = TARGET_SIZE
            crop_size = tf.cast(h * 0.5, tf.int32)
            # Clamp the box so it stays fully inside the frame.
            y1 = tf.clip_by_value(cy - crop_size // 2, 0, h - crop_size)
            x1 = tf.clip_by_value(cx - crop_size // 2, 0, w - crop_size)
            return _resize_pair(
                combined[y1:y1 + crop_size, x1:x1 + crop_size, :]
            )

        # Without any rare pixel the (flipped) frame is kept uncropped.
        return tf.cond(has_rare, crop_around_rare, lambda: _resize_pair(combined))

    # 50% random crop, 50% focused crop.
    use_focused = tf.random.uniform([]) > 0.5
    image, mask = tf.cond(
        use_focused,
        lambda: focused_crop(image, mask),
        lambda: random_crop(image, mask),
    )

    # Colour augmentation (image only), then clamp back to [0, 1].
    image = tf.image.random_brightness(image, max_delta=0.15)
    image = tf.image.random_contrast(image, lower=0.85, upper=1.15)
    image = tf.image.random_saturation(image, lower=0.85, upper=1.15)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, mask
def load_tfrecord_dataset(tfrecord_pattern, batch_size=4, training=False):
    """Build a batched ``tf.data`` pipeline from GZIP-compressed TFRecords.

    Args:
        tfrecord_pattern: Glob pattern matching the TFRecord files.
        batch_size: Number of examples per batch.
        training: When True, file order and examples are shuffled, ``augment``
            is applied, and the final partial batch is dropped.

    Returns:
        A prefetched dataset yielding ``(image, mask)`` batches.
    """
    files = tf.data.Dataset.list_files(tfrecord_pattern, shuffle=training)
    dataset = files.interleave(
        lambda f: tf.data.TFRecordDataset(f, compression_type='GZIP'),
        cycle_length=4,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    dataset = dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    if training:
        dataset = dataset.shuffle(buffer_size=500)
        dataset = dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    # Only training drops the trailing partial batch. Dropping it for
    # validation/test would silently exclude up to batch_size - 1 samples
    # from every evaluation.
    dataset = dataset.batch(batch_size, drop_remainder=training)
    return dataset.prefetch(tf.data.AUTOTUNE)
| # --- 5. ĐỊNH NGHĨA MODEL --- | |
def aspp_block(x, filters=128):
    """Two-branch atrous bridge: a 1x1 branch and a 3x3 dilation-6 branch.

    Each branch is Conv (no bias) -> BatchNorm -> ReLU. The branches are
    concatenated along channels and fused by one more 1x1 Conv-BN-ReLU.

    Args:
        x: Input feature map.
        filters: Number of output channels of every convolution.

    Returns:
        The fused feature map with ``filters`` channels.
    """
    layers = tf.keras.layers

    def conv_bn_relu(tensor, kernel, **conv_kwargs):
        tensor = layers.Conv2D(filters, kernel, use_bias=False, **conv_kwargs)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Activation('relu')(tensor)

    pointwise = conv_bn_relu(x, 1, padding='same')
    dilated = conv_bn_relu(x, 3, padding='same', dilation_rate=6)
    fused = layers.Concatenate()([pointwise, dilated])
    return conv_bn_relu(fused, 1)
def decodeBlock(prev_layer_input, skip_layer_input, n_filters=32):
    """Decoder stage: upsample 2x, merge the skip connection, refine twice.

    Args:
        prev_layer_input: Feature map coming from the deeper stage.
        skip_layer_input: Encoder feature map to concatenate after upsampling;
            it must match the upsampled map spatially.
        n_filters: Channel count of the transpose conv and both 3x3 convs.

    Returns:
        The refined feature map with ``n_filters`` channels.
    """
    layers = tf.keras.layers
    l2 = tf.keras.regularizers.l2

    upsampled = layers.Conv2DTranspose(
        n_filters,
        (3, 3),
        strides=(2, 2),
        padding='same',
        kernel_regularizer=l2(1e-4),
    )(prev_layer_input)
    features = layers.concatenate([upsampled, skip_layer_input], axis=3)

    # Two identical refinement steps; ReLU is applied inside the conv, so
    # BatchNorm comes after the activation.
    for _ in range(2):
        features = layers.Conv2D(
            n_filters,
            3,
            activation='relu',
            padding='same',
            kernel_initializer='he_normal',
            kernel_regularizer=l2(1e-4),
        )(features)
        features = layers.BatchNormalization()(features)
    return features
| from tensorflow.keras.applications import EfficientNetB1 | |
def efficientnet_unet(input_size=(512, 512, 3), n_filter=16, num_classes=5):
    """Build a U-Net style segmenter on a frozen EfficientNetB1 encoder.

    Args:
        input_size: ``(height, width, channels)`` of the input images.
        n_filter: Base channel width; decoder stages use multiples of it.
        num_classes: Number of softmax output channels.

    Returns:
        A ``tf.keras.Model``. The encoder is also attached as
        ``model.backbone`` so callers can toggle its ``trainable`` flag.
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=input_size)

    # Inputs are multiplied by 255 before entering the backbone.
    # NOTE(review): presumably because the pipeline feeds [0, 1] images while
    # Keras EfficientNet expects raw 0-255 pixels — confirm.
    backbone = EfficientNetB1(
        input_tensor=layers.Rescaling(scale=255.0)(inputs),
        include_top=False,
        weights='imagenet',
    )
    backbone.trainable = False

    # Skip connections, shallow to deep. Resolutions for a 512x512 input,
    # per the original annotations: 256, 128, 64, 32, with the bottleneck at 16.
    skip_layer_names = (
        'block2a_expand_activation',
        'block3a_expand_activation',
        'block4a_expand_activation',
        'block6a_expand_activation',
    )
    s1, s2, s3, s4 = (backbone.get_layer(name).output for name in skip_layer_names)
    s5 = backbone.get_layer('top_activation').output

    # ASPP bridge on the bottleneck.
    decoded = aspp_block(s5, filters=n_filter * 8)

    # Decoder: each stage doubles the resolution and merges one skip.
    for skip, width in ((s4, 16), (s3, 8), (s2, 4), (s1, 2)):
        decoded = decodeBlock(decoded, skip, n_filter * width)

    # One last 2x upsampling after the final decoder stage.
    head = layers.Conv2DTranspose(n_filter, 3, strides=2, padding='same')(decoded)
    head = layers.BatchNormalization()(head)
    head = layers.Activation('relu')(head)
    outputs = layers.Conv2D(
        num_classes, 1, activation='softmax', dtype='float32'
    )(head)

    model = tf.keras.Model(inputs, outputs)
    model.backbone = backbone
    return model
def deeplab_model(input_size=(512, 512, 3), n_filter=16, num_classes=5):
    """Build a DeepLabV3+ style segmenter on a frozen MobileNetV2 encoder.

    The image-pooling ASPP branch is broadcast back to the feature-map size
    with ``UpSampling2D`` rather than an anonymous ``Lambda`` layer. On the
    1x1 pooled map the two are numerically equivalent, but the built-in layer
    serializes cleanly when the model is saved and reloaded.

    Args:
        input_size: ``(height, width, channels)`` of the input images. Height
            and width must be static because the pooled branch is upsampled by
            a fixed factor.
        n_filter: Channel width of the ASPP and decoder convolutions.
        num_classes: Number of softmax output channels.

    Returns:
        A ``tf.keras.Model``. The encoder is also attached as
        ``model.backbone`` so callers can toggle its ``trainable`` flag.

    Raises:
        ValueError: If the high-level feature map has no static spatial size.
    """
    layers = tf.keras.layers

    def conv_bn_relu(tensor, filters, kernel, **conv_kwargs):
        # Conv -> BatchNorm -> ReLU, the stanza repeated throughout the model.
        tensor = layers.Conv2D(filters, kernel, **conv_kwargs)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Activation('relu')(tensor)

    inputs = tf.keras.Input(shape=input_size)
    # Maps inputs from [0, 1] to [-1, 1] (x * 2 - 1) before the backbone.
    x = layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
    backbone = MobileNetV2(
        input_tensor=x,
        include_top=False,
        weights='imagenet',
    )
    backbone.trainable = False

    # Resolutions for a 512x512 input, per the original annotations.
    low_level = backbone.get_layer('block_2_add').output             # 128x128
    high_level = backbone.get_layer('block_13_expand_relu').output   # 32x32

    # ASPP: one 1x1 branch, three dilated 3x3 branches, one image-pooling branch.
    aspp_b1 = conv_bn_relu(high_level, n_filter, (1, 1), padding='same')
    aspp_b2 = conv_bn_relu(high_level, n_filter, (3, 3), dilation_rate=6, padding='same')
    aspp_b3 = conv_bn_relu(high_level, n_filter, (3, 3), dilation_rate=12, padding='same')
    aspp_b4 = conv_bn_relu(high_level, n_filter, (3, 3), dilation_rate=18, padding='same')

    aspp_b5 = layers.GlobalAveragePooling2D(keepdims=True)(high_level)
    aspp_b5 = conv_bn_relu(aspp_b5, n_filter, (1, 1), padding='same')
    pool_size = (high_level.shape[1], high_level.shape[2])
    if None in pool_size:
        raise ValueError(
            "deeplab_model needs a static input height and width, got "
            f"input_size={input_size!r}"
        )
    # The pooled map is 1x1, so upsampling by (H, W) broadcasts it to HxW.
    aspp_b5 = layers.UpSampling2D(size=pool_size, interpolation='bilinear')(aspp_b5)

    aspp_out = layers.Concatenate()([aspp_b1, aspp_b2, aspp_b3, aspp_b4, aspp_b5])
    aspp_out = conv_bn_relu(aspp_out, n_filter, (1, 1), padding='same')

    # Decoder: bring ASPP features up to the low-level resolution and merge.
    high_branch = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(aspp_out)  # 32 -> 128
    low_branch = conv_bn_relu(low_level, 48, 1, use_bias=False)

    encode = layers.Concatenate()([low_branch, high_branch])  # 128x128
    encode = conv_bn_relu(encode, n_filter, (3, 3), padding='same')
    encode = conv_bn_relu(encode, n_filter, (3, 3), padding='same')
    encode = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(encode)  # 128 -> 512

    outputs = layers.Conv2D(num_classes, 1, activation='softmax', dtype='float32')(encode)
    model = tf.keras.Model(inputs, outputs)
    model.backbone = backbone
    return model
| # --- 6. LOSS & METRICS & OPTIMIZER & CALLBACKS--- | |
class SparseMeanIoU(tf.keras.metrics.MeanIoU):
    """Mean IoU for sparse integer masks and per-class score predictions.

    Predictions are reduced with ``argmax`` over the last axis. Label 5 is
    remapped to 1, and both labels and predictions are then clamped into
    ``[0, num_classes - 1]``. A model that emits more channels than the metric
    tracks therefore folds the extra classes into the last tracked class,
    instead of feeding out-of-range ids to the confusion matrix.
    """

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate one batch.

        Args:
            y_true: Integer-valued masks, ``(B, H, W)`` or ``(B, H, W, 1)``.
            y_pred: Per-class scores with the classes on the last axis.
            sample_weight: Optional weights forwarded to ``MeanIoU``.
        """
        y_pred = tf.cast(tf.argmax(y_pred, axis=-1), tf.int64)
        if len(y_true.shape) == 4:
            y_true = tf.squeeze(y_true, axis=-1)
        y_true = tf.cast(y_true, tf.int64)
        y_true = tf.where(y_true == 5, tf.cast(1, tf.int64), y_true)

        # Clamp BOTH sides; previously only y_true was clamped, so argmax ids
        # >= num_classes reached the confusion matrix unchanged.
        max_label = self.num_classes - 1
        y_true = tf.clip_by_value(y_true, 0, max_label)
        y_pred = tf.clip_by_value(y_pred, 0, max_label)
        return super().update_state(y_true, y_pred, sample_weight)
def tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3):
    """Tversky loss averaged over batch and classes.

    The one-hot depth follows the channel count of ``y_pred`` instead of a
    hard-coded 5, so the loss works for any number of output classes.
    Behaviour is unchanged for 5-channel models.

    Args:
        y_true: Sparse masks ``(B, H, W, 1)`` holding class ids. Ids outside
            ``[0, C)`` one-hot encode to all zeros.
        y_pred: Class probabilities ``(B, H, W, C)``.
        alpha: Weight of false negatives.
        beta: Weight of false positives.

    Returns:
        Scalar tensor ``1 - mean(tversky index)``.
    """
    num_classes = tf.shape(y_pred)[-1]
    y_true_sparse = tf.cast(y_true[..., 0], tf.int32)
    y_true_oh = tf.one_hot(y_true_sparse, depth=num_classes)

    # Sum over the spatial axes -> per-sample, per-class statistics.
    tp = tf.reduce_sum(y_true_oh * y_pred, axis=[1, 2])
    fp = tf.reduce_sum((1 - y_true_oh) * y_pred, axis=[1, 2])
    fn = tf.reduce_sum(y_true_oh * (1 - y_pred), axis=[1, 2])

    # 1e-6 keeps the ratio defined when a class is absent from a sample.
    tversky = (tp + 1e-6) / (tp + alpha * fn + beta * fp + 1e-6)
    return 1.0 - tf.reduce_mean(tversky)
# Cosine-decay learning-rate schedule with warm restarts.
lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
    1e-4,        # initial learning rate
    1000,        # number of steps in the first restart cycle
    t_mul=2.0,   # every following cycle is twice as long
    m_mul=0.9,   # the peak learning rate shrinks by 10% after each restart
    alpha=1e-6,  # learning-rate floor; the TF docs define it as a fraction of the initial rate
)
| # --- 7. TRAINING --- | |
def _compile_for_training(net):
    """(Re)compile ``net`` with a fresh Adam optimizer, Tversky loss and metrics."""
    net.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=lr_schedule,
            beta_1=0.9,     # momentum (default, unchanged)
            beta_2=0.999,   # RMSprop term (default, unchanged)
            epsilon=1e-7,   # avoids division by zero (default)
            clipnorm=1.0,   # gradient clipping against exploding gradients
        ),
        loss=tversky_loss,
        metrics=['accuracy', SparseMeanIoU(num_classes=2, name='mean_iou')],
    )


def _plot_history(history, out_path):
    """Save train/validation curves for accuracy, mean IoU and loss to ``out_path``."""
    curves = history.history
    panels = (('accuracy', 'accuracy'), ('mean_iou', 'mean iou'), ('loss', 'loss'))
    plt.figure(figsize=(12, 6))
    for position, (key, label) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(curves[key], label=key)
        plt.plot(curves[f'val_{key}'], label=f'val_{key}')
        plt.title(f'Model {label}')
        plt.xlabel('Epoch')
        plt.ylabel(label)
        plt.legend()
    plt.savefig(out_path)
    plt.close()


def _train_two_stage(net, train_ds, val_ds, stage1_plot, checkpoint_path,
                     stage2_plot, final_path):
    """Train with a frozen backbone, then fine-tune everything.

    Stage 1 runs ``EPOCHS_STAGE_1`` epochs on the model as built. Stage 2 sets
    ``net.backbone.trainable = True``, recompiles, and runs ``EPOCHS_STAGE_2``
    epochs while checkpointing the best ``val_mean_iou``. The final-epoch
    model is saved to ``final_path``.

    Returns:
        ``(history_stage_1, history_stage_2)``.
    """
    # Stage 1
    _compile_for_training(net)
    print("Starting Stage 1...")
    stage_1 = net.fit(train_ds, epochs=EPOCHS_STAGE_1, validation_data=val_ds)
    _plot_history(stage_1, stage1_plot)

    # Stage 2: the trainable flag only takes effect after recompiling.
    net.backbone.trainable = True
    _compile_for_training(net)
    print("Starting Stage 2 (Fine-tuning)...")
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        save_best_only=True, monitor='val_mean_iou', mode='max'
    )
    stage_2 = net.fit(
        train_ds, epochs=EPOCHS_STAGE_2, validation_data=val_ds,
        callbacks=[checkpoint],
    )
    _plot_history(stage_2, stage2_plot)
    net.save(final_path)
    return stage_1, stage_2


def _report_test_metrics(net, test_ds, label):
    """Evaluate ``net`` on ``test_ds`` and print loss, accuracy and mean IoU."""
    # NOTE(review): this scores the final-epoch weights held in memory, not
    # the best checkpoint written during stage 2.
    loss, accuracy, mean_iou = net.evaluate(test_ds)
    print(f"Loss {label} :{loss}")
    print(f"Accuracy {label} :{accuracy}")
    print(f"Mean IoU {label} :{mean_iou}")


os.makedirs('./model', exist_ok=True)

# DeepLabV3 model
print("Đang khởi tạo Dataset trên CPU...")
train_data = load_tfrecord_dataset('./tfrecord/train_*.tfrecord', batch_size=4, training=True)
val_data = load_tfrecord_dataset('./tfrecord/val_*.tfrecord', batch_size=4, training=False)
test_data = load_tfrecord_dataset('./tfrecord/test_*.tfrecord', batch_size=4, training=False)
print("[OK] Dataset đã sẵn sàng")

model = deeplab_model(input_size=(128, 128, 3), n_filter=32)
print("[OK] Đã dựng xong mô hình")
history_1, history_2 = _train_two_stage(
    model, train_data, val_data,
    stage1_plot='history_stage1.png',
    checkpoint_path='./model/best_deeplab_eff.keras',
    stage2_plot='./model/history_stage2_deeplab_mobilenet.png',
    final_path='./model/deeplab_mobilenet.h5',
)
print("Training Complete. Model Deeplab saved.")

# EfficientNet U-Net model
model_eff = efficientnet_unet(input_size=(128, 128, 3), n_filter=32)
print("[OK] Đã dựng xong mô hình effnet")
history_1, history_2 = _train_two_stage(
    model_eff, train_data, val_data,
    stage1_plot='history_stage1(effnet).png',
    checkpoint_path='./model/best_eff_net.keras',
    stage2_plot='./model/history_stage2_eff_net.png',
    final_path='./model/eff_net.h5',
)
print("Training Complete. Model Effnet saved.")

# Final comparison on the held-out test split.
_report_test_metrics(model, test_data, "DeepLabV3+")
_report_test_metrics(model_eff, test_data, "Effnet")