# Source: FairRelay / ops / Model / training.py
# Committed by: MouleeswaranM
# Commit message: "Upload folder using huggingface_hub"
# Commit: fcf8749 (verified)
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
# Log the TensorFlow version so a run can be matched to its runtime.
print("TensorFlow:", tf.__version__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
IMG_SIZE = 224, 224  # (height, width) every image is resized to on load
BATCH_SIZE = 32
# Maximum epochs for phase 1 (frozen backbone) and phase 2 (fine-tuning).
EPOCHS_HEAD, EPOCHS_FINE = 15, 20
# Root folder with one sub-directory per class. The path looks like a
# Colab Google Drive mount -- adjust when running elsewhere.
DATA_DIR = "/content/drive/MyDrive/Images"
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data pick prefetch buffer sizes
def _load_split(subset):
    """Load one side of the seeded 80/20 split of DATA_DIR as a batched dataset.

    Both subsets are requested with the same ``validation_split`` and
    ``seed``; Keras requires this so "training" and "validation" do not
    overlap. Labels are binary (0/1 as float32).
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        DATA_DIR,
        validation_split=0.2,
        subset=subset,
        seed=42,
        image_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        label_mode="binary",
    )


train_ds = _load_split("training")
val_ds = _load_split("validation")

# cache() keeps the loaded batches in memory after the first full pass;
# prefetch() overlaps input loading with training.
# NOTE(review): the dataset is already batched, so shuffle(1000) reorders
# whole batches rather than individual images, and caching after the
# loader's own shuffle likely fixes which images share a batch after the
# first epoch. Consider unbatch/shuffle/batch if per-sample reshuffling
# matters.
train_ds = train_ds.cache().shuffle(1000).prefetch(AUTOTUNE)
val_ds = val_ds.cache().prefetch(AUTOTUNE)
# Balanced class weights from the training split:
#     weight_c = total / (n_classes * count_c)
# so each class contributes equally to the weighted loss.
labels = np.concatenate([y.numpy() for _, y in train_ds])
# With label_mode="binary" each label batch has shape (batch_size, 1), so
# `labels` is 2-D. np.bincount only accepts 1-D non-negative ints, hence the
# cast + ravel. minlength=2 guarantees two counts even if a class is absent.
neg, pos = np.bincount(labels.astype(int).ravel(), minlength=2)
if neg == 0 or pos == 0:
    # Otherwise one weight would be infinite (division by zero).
    raise ValueError(
        f"Training split must contain both classes; got neg={neg}, pos={pos}"
    )
total = neg + pos
# Plain Python floats keep the printed dict readable (NumPy 2 would show
# np.float64(...)); Keras accepts either.
class_weight = {
    0: float(total / (2 * neg)),
    1: float(total / (2 * pos)),
}
print("Class weights:", class_weight)
# Light, label-preserving augmentation. Keras random-augmentation layers
# only perturb inputs in training mode; at inference they pass data through.
data_augmentation = tf.keras.Sequential(
    [
        layers.RandomFlip("horizontal"),  # mirror left/right only
        layers.RandomRotation(0.05),  # up to +/-5% of a full turn (~18 deg)
        layers.RandomZoom(0.1),  # zoom in or out by up to 10%
        layers.RandomContrast(0.1),  # contrast factor drawn from [0.9, 1.1]
    ]
)
# Backbone: ImageNet-pretrained MobileNetV2 without its classification head,
# frozen for phase 1 so only the new head is trained.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=(*IMG_SIZE, 3),
    include_top=False,
    weights="imagenet",
)
base_model.trainable = False

# Pipeline: raw RGB pixels -> augmentation -> MobileNetV2 input scaling
# -> backbone -> pooled features -> small dense head -> one sigmoid output.
inputs = layers.Input(shape=(*IMG_SIZE, 3))
augmented = data_augmentation(inputs)
scaled = tf.keras.applications.mobilenet_v2.preprocess_input(augmented)
# training=False keeps the backbone's BatchNormalization layers in inference
# mode, including later when its top layers are unfrozen for fine-tuning.
features = base_model(scaled, training=False)

head = layers.GlobalAveragePooling2D()(features)
head = layers.BatchNormalization()(head)
head = layers.Dense(256, activation="relu")(head)
head = layers.Dropout(0.4)(head)
outputs = layers.Dense(1, activation="sigmoid")(head)

model = models.Model(inputs, outputs)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # Label smoothing squeezes the 0/1 targets slightly toward 0.5.
    loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.05),
    metrics=["accuracy"],
)
model.summary()
# Phase 1: train only the new head on top of the frozen backbone.
# EarlyStopping watches val_loss (its default monitor); when it triggers,
# restore_best_weights=True rolls back to the best epoch's weights.
head_early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss", patience=4, restore_best_weights=True
)
history_head = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_HEAD,
    class_weight=class_weight,
    callbacks=[head_early_stop],
)
# Phase 2: fine-tune the top of the backbone together with the head.
# Unfreeze MobileNetV2, then re-freeze everything except its last 40 layers.
base_model.trainable = True
for frozen_layer in base_model.layers[:-40]:
    frozen_layer.trainable = False

# Changes to `trainable` only take effect after recompiling. The learning rate
# is 10x lower than in phase 1 to limit drift of the pretrained weights.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.05),
    metrics=["accuracy"],
)

# Both callbacks monitor val_loss by default: the learning rate is cut to 30%
# after 3 stagnant epochs; training stops after 5 (restoring best weights).
fine_tune_callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(factor=0.3, patience=3),
]
history_fine = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS_FINE,
    class_weight=class_weight,
    callbacks=fine_tune_callbacks,
)