Spaces:
Runtime error
Runtime error
| #src/train_optimize.py | |
| import tensorflow as tf | |
| from dataloader import create_dataset | |
| from siamese_model import build_siamese_model | |
| import matplotlib.pyplot as plt | |
| import os | |
# Resolve project paths relative to the repository root (the parent of src/).
# abspath() is required for correctness: if __file__ is relative (e.g. the
# script is launched as `python train_optimize.py` from inside src/ on
# Python < 3.9), the nested dirname() calls collapse to "" and every path
# below would silently resolve against the current working directory instead.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CSV_FILE = os.path.join(BASE_DIR, "pairs", "iris_pairs.csv")
MODEL_NAME = os.path.join(BASE_DIR, "models", "iris_siamese.keras")
# 1. Load Data
# Build batched training/validation datasets from the pairs CSV.
# validation_split=0.2 presumably holds out 20% for validation — the exact
# split semantics live in dataloader.create_dataset (not visible here).
train_ds, val_ds = create_dataset(CSV_FILE, batch_size=16, validation_split=0.2)
# 3. Build & Train
model = build_siamese_model(visualize=True)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="binary_crossentropy",
    # Named "auc" so the history keys are "auc" / "val_auc" (read by plot_history).
    metrics=[tf.keras.metrics.AUC(name="auc")],
)

# Create the checkpoint directory up front. ModelCheckpoint first writes at the
# end of an epoch, so a missing "models/" directory would otherwise surface only
# after training time has already been spent (not every Keras version creates
# missing parent directories itself).
checkpoint_dir = os.path.dirname(MODEL_NAME)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

callbacks = [
    # Overwrite MODEL_NAME only when val_loss improves on the best seen so far.
    tf.keras.callbacks.ModelCheckpoint(
        MODEL_NAME,
        save_best_only=True,
        monitor="val_loss",
    ),
    # Stop after 3 epochs without val_loss improvement and restore the best
    # weights. NOTE: with epochs=3 below this can never trigger; it only
    # matters once the epoch count is raised.
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=3,
        restore_best_weights=True,
    ),
    # Halve the learning rate after 2 epochs without val_loss improvement,
    # never going below 1e-6.
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=2,
        min_lr=1e-6,
        verbose=1,
    ),
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=3,  # author's notes list 8 and 20 as alternative values
    callbacks=callbacks,
)
# 4. Plot Training History
def plot_history(history, save_path=None, show=True):
    """Plot loss and AUC curves from a Keras training run.

    Draws two side-by-side panels: train/validation loss and train/validation
    AUC. Each curve is drawn only if its key exists in ``history.history``.
    The ``val_*`` keys are absent when ``fit`` ran without validation data, and
    in that case the original implementation raised ``KeyError``.

    Args:
        history: The object returned by ``model.fit``; its ``history``
            attribute maps metric names to per-epoch value lists.
        save_path: Optional file path. When given, the figure is also written
            there, which is useful with non-interactive (headless) matplotlib
            backends where ``plt.show()`` displays nothing.
        show: Whether to call ``plt.show()``. Defaults to ``True``, which
            matches the original behaviour.

    Returns:
        The matplotlib ``Figure`` holding both panels.
    """
    metrics = history.history
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    # (axis, history key, legend label suffix, panel title)
    panels = (
        (ax1, "loss", "Loss", "Loss"),
        (ax2, "auc", "AUC", "AUC Score"),
    )
    for ax, key, label, title in panels:
        if key in metrics:
            ax.plot(metrics[key], label=f"Train {label}")
        val_key = f"val_{key}"
        if val_key in metrics:
            ax.plot(metrics[val_key], label=f"Val {label}")
        ax.set_title(title)
        ax.legend()

    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")
    if show:
        plt.show()
    return fig
# Visualize the finished run, then report completion on stdout.
plot_history(history=history)
print("✅ Training completed")