# Scraped Hugging Face file-page metadata (not part of the script itself):
# uploaded by sasmithasarasi ("Upload 22 files"), commit ddec2b7 (verified), 952 Bytes.
import tensorflow as tf
import numpy as np
import os
# Transfer-learning demo: freeze an ImageNet-pretrained MobileNetV3Small
# backbone, train a small classification head on top of it, and save the
# resulting model to data/models/transfer_learning_model.h5.
#
# NOTE(review): X / y below are randomly generated placeholder data, so the
# saved model has learned nothing meaningful. Replace them with a real
# dataset before using the model.

# Run configuration.
IMG_SHAPE = (224, 224, 3)
NUM_CLASSES = 2
NUM_SAMPLES = 50
EPOCHS = 3
BATCH_SIZE = 8
SEED = 42
MODEL_PATH = "data/models/transfer_learning_model.h5"

print("=" * 60)
print("TRANSFER LEARNING TRAINING")
print("=" * 60)

# Seed Python, NumPy and TensorFlow RNGs so the head initialisation and
# dropout masks are reproducible from run to run.
tf.keras.utils.set_random_seed(SEED)

print("Building MobileNetV3 with ImageNet weights...")
# With the default include_preprocessing=True, MobileNetV3 rescales its input
# internally and therefore expects raw pixel values in [0, 255].
base_model = tf.keras.applications.MobileNetV3Small(
    weights="imagenet",
    include_top=False,
    input_shape=IMG_SHAPE,
)
base_model.trainable = False  # only the new head below is trained

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
# Softmax output + one-hot targets -> categorical cross-entropy.
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
print("Model built!")

# Synthetic placeholder data. Pixels are drawn from [0, 255] to match the
# range the backbone's built-in preprocessing expects (values in [0, 1)
# would look like near-black images), and generated directly as float32
# instead of float64 to halve memory.
rng = np.random.default_rng(SEED)
X = rng.uniform(0.0, 255.0, size=(NUM_SAMPLES, *IMG_SHAPE)).astype(np.float32)
y = tf.keras.utils.to_categorical(rng.integers(0, NUM_CLASSES, NUM_SAMPLES), NUM_CLASSES)

print("Training...")
model.fit(X, y, epochs=EPOCHS, batch_size=BATCH_SIZE)

os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
# NOTE: recent Keras versions treat HDF5 (.h5) as a legacy save format and
# recommend ".keras"; the filename is kept so existing loaders still work.
model.save(MODEL_PATH)
print("Model saved!")