# smart-election-verification / src/train_optimize.py
# Author: selvaneyas
# Commit message: Upload 7 files
# Commit: e619b9a (verified)
#src/train_optimize.py
import tensorflow as tf
from dataloader import create_dataset
from siamese_model import build_siamese_model
import matplotlib.pyplot as plt
import os
# Resolve the project root (the parent of this file's directory) from the
# file's *absolute* location. Without abspath, a relative __file__ such as
# "train_optimize.py" (possible when the script is launched from its own
# directory on Python < 3.9) makes both dirname() calls collapse to "", and
# the data/model paths would then be resolved against the working directory.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# CSV handed to create_dataset() below.
CSV_FILE = os.path.join(BASE_DIR, "pairs", "iris_pairs.csv")
# Target path used by the ModelCheckpoint callback.
MODEL_NAME = os.path.join(BASE_DIR, "models", "iris_siamese.keras")
# CSV_FILE = "pairs/iris_pairs.csv"
# MODEL_NAME = "models/iris_siamese.h5"
# 1. Load Data
# create_dataset() is a project helper; it is expected to hand back a
# (training, validation) pair. NOTE(review): validation_split=0.2 presumably
# holds out 20% of the pairs for validation -- confirm in dataloader.py.
train_ds, val_ds = create_dataset(CSV_FILE, validation_split=0.2, batch_size=16)
# 3. Build & Train
# Build the siamese network via the project helper (with its visualize flag
# switched on), then compile it for binary classification of pairs.
model = build_siamese_model(visualize=True)
model.compile(
    loss="binary_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    # The explicit name fixes the history keys to "auc" / "val_auc", which
    # plot_history() reads later.
    metrics=[tf.keras.metrics.AUC(name="auc")],
)
# model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss="binary_crossentropy", metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])
# history = model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=callbacks)
# Keep only the checkpoint with the lowest validation loss on disk.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    MODEL_NAME,
    monitor="val_loss",
    save_best_only=True,
)
# Stop once validation loss has not improved for 3 epochs and roll the model
# back to its best weights.
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    restore_best_weights=True,
    patience=3,
)
# Halve the learning rate after 2 stagnant epochs, never going below 1e-6.
lr_decay_cb = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    patience=2,
    factor=0.5,
    min_lr=1e-6,
    verbose=1,
)
callbacks = [checkpoint_cb, early_stop_cb, lr_decay_cb]
# Run training; the returned History object feeds plot_history() below.
# NOTE(review): with epochs=3 the EarlyStopping callback (patience=3) can
# never fire -- raise the epoch count (the author noted 8 and 20) if early
# stopping is meant to take effect.
history = model.fit(
    x=train_ds,
    epochs=3,  # other values noted by the author: 8, 20
    callbacks=callbacks,
    validation_data=val_ds,
)
# 4. Plot Training History
def plot_history(history, metric="auc"):
    """Plot training/validation loss and one metric in two side-by-side panels.

    Args:
        history: Object exposing a ``history`` dict that maps series names
            (e.g. ``"loss"``, ``"val_loss"``) to per-epoch value lists, such
            as the object returned by ``model.fit``.
        metric: Key of the metric drawn in the right-hand panel; its
            validation counterpart is looked up as ``"val_" + metric``.
            Defaults to ``"auc"``.

    Series that are absent from ``history.history`` (for instance the
    ``val_*`` curves when no validation data was used) are skipped instead of
    raising ``KeyError``.
    """
    logs = history.history
    fig, (loss_ax, metric_ax) = plt.subplots(1, 2, figsize=(12, 4))

    metric_name = metric.upper()
    panels = (
        (loss_ax, "loss", "Loss", "Loss"),
        (metric_ax, metric, f"{metric_name} Score", metric_name),
    )
    for ax, key, title, series_name in panels:
        drawn = False
        for prefix, split in (("", "Train"), ("val_", "Val")):
            values = logs.get(prefix + key)
            if values:
                ax.plot(values, label=f"{split} {series_name}")
                drawn = True
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        # Only request a legend when something was plotted; an empty legend
        # just produces a warning.
        if drawn:
            ax.legend()

    fig.tight_layout()
    plt.show()
# Display the loss/AUC curves (plot_history ends with plt.show()), then
# report that the script finished.
plot_history(history)
print("✅ Training completed")