# src/train_optimize.py
import os

import matplotlib.pyplot as plt
import tensorflow as tf

from dataloader import create_dataset
from siamese_model import build_siamese_model

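# Resolve paths relative to the repository root (the parent of src/), so the
# script works no matter which directory it is launched from. abspath() is
# needed because __file__ can be a bare relative name such as "train_optimize.py".
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

CSV_FILE = os.path.join(BASE_DIR, "pairs", "iris_pairs.csv")
MODEL_NAME = os.path.join(BASE_DIR, "models", "iris_siamese.keras")

# Make sure the checkpoint directory exists before ModelCheckpoint writes to it.
os.makedirs(os.path.dirname(MODEL_NAME), exist_ok=True)

# CSV_FILE = "pairs/iris_pairs.csv"
# MODEL_NAME = "models/iris_siamese.h5"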

# 1. Load Data
train_ds, val_ds = create_dataset(
    CSV_FILE,
    batch_size=16,
    validation_split=0.2
)


# 2. Build & Train
model = build_siamese_model(visualize=True)
#model = build_siamese_model()

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="binary_crossentropy",
    metrics=[tf.keras.metrics.AUC(name="auc")]
)
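
# Sketch (hypothetical helper, not called anywhere in this script): the ROC AUC
# tracked above equals the probability that a randomly chosen positive pair
# (assumed here to be label 1, i.e. a matching iris pair) scores higher than a
# randomly chosen negative pair (label 0), with ties counted as half.
# tf.keras.metrics.AUC approximates this with a fixed set of thresholds
# (200 by default), so its value can differ slightly from this exact version.
def rank_auc(labels, scores):
    """Exact ROC AUC from binary labels and predicted similarity scores."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative pair")
    wins = 0.0
    # Compare every positive score against every negative score.
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))
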

# model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss="binary_crossentropy", metrics=["accuracy", tf.keras.metrics.AUC(name="auc")])
# history = model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=callbacks)

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        MODEL_NAME,
        save_best_only=True,
        monitor="val_loss"
    ),
    tf.keras.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=3,
        restore_best_weights=True
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.5,
        patience=2,
        min_lr=1e-6,
        verbose=1
    )
]
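
# Sketch (hypothetical helper, not called anywhere in this script): the learning
# rate ReduceLROnPlateau would reach after `n_reductions` plateaus, assuming the
# Adam(1e-4) starting rate and the factor=0.5 / min_lr=1e-6 values configured
# above. It models only the "multiply by factor, floor at min_lr" rule; when a
# plateau actually occurs depends on val_loss and is not modelled here.
def lr_after_reductions(initial_lr=1e-4, factor=0.5, n_reductions=0, min_lr=1e-6):
    """Return the learning rate after `n_reductions` plateau-triggered cuts."""
    return max(initial_lr * factor ** n_reductions, min_lr)

# With these settings the floor is reached after 7 halvings (1e-4 / 2**7 < 1e-6).
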

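history = model.fit(
    train_ds,
    validation_data=val_ds,
    # Short run; 8 and 20 epochs are the alternatives noted by the author.
    # With only 3 epochs, EarlyStopping(patience=3) cannot stop training early,
    # and ReduceLROnPlateau(patience=2) can lower the LR at most once, at the
    # end of the final epoch, so neither callback changes this run.
    epochs=3,
    callbacks=callbacks
)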


# 3. Plot Training History
def plot_history(history):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    ax1.plot(history.history['loss'], label='Train Loss')
    ax1.plot(history.history['val_loss'], label='Val Loss')
    ax1.set_title('Loss')
    ax1.legend()
    
    ax2.plot(history.history['auc'], label='Train AUC')
    ax2.plot(history.history['val_auc'], label='Val AUC')
    ax2.set_title('AUC Score')
    ax2.legend()
    plt.show()

plot_history(history)


print("✅ Training completed")