# SDK-Docker / train.py
# Lucifer9907's picture
# Prepare Hugging Face Docker Space
# ff0c419
from __future__ import annotations
import json
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from src.ai_image_detector.config import (
ARTIFACTS_DIR,
METRICS_PATH,
MODEL_PATH,
PROCESSED_DATA_DIR,
SEED,
THRESHOLD_PATH,
TRAINING_PLOT_PATH,
)
from src.ai_image_detector.data import load_dataset
from src.ai_image_detector.model import build_model, unfreeze_for_fine_tuning
def get_env_int(name: str, default: int) -> int:
    """Read a strictly positive integer from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset, is not a valid
            integer literal, or parses to zero or a negative number.

    Returns:
        The parsed positive integer, or ``default``.
    """
    raw = os.environ.get(name)
    if raw is not None:
        try:
            candidate = int(raw)
        except ValueError:
            # Malformed values are treated the same as "unset".
            pass
        else:
            if candidate > 0:
                return candidate
    return default
def make_datasets(
    x: np.ndarray,
    y: np.ndarray,
    batch_size: int,
) -> tuple[tf.data.Dataset, tf.data.Dataset, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split the data into train/validation/test and build the tf.data pipelines.

    30% of the samples are held out and then halved into validation and test
    sets. Both splits are stratified on the labels and seeded with ``SEED``.
    Only the training pipeline is shuffled and augmented.

    Args:
        x: Input samples, sliced along the first axis.
        y: Labels aligned with ``x``.
        batch_size: Batch size used for both returned datasets.

    Returns:
        ``(train_ds, val_ds, x_val, y_val, x_test, y_test)``. The validation
        and test arrays are returned raw so callers can run their own
        prediction and evaluation passes.
    """
    # Stage 1: carve off the 30% hold-out pool.
    x_train, x_holdout, y_train, y_holdout = train_test_split(
        x, y, test_size=0.3, random_state=SEED, stratify=y
    )
    # Stage 2: divide the hold-out pool evenly into validation and test.
    x_val, x_test, y_val, y_test = train_test_split(
        x_holdout, y_holdout, test_size=0.5, random_state=SEED, stratify=y_holdout
    )

    augmentation_layers = [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.05),
        tf.keras.layers.RandomZoom(0.1),
        tf.keras.layers.RandomContrast(0.1),
    ]
    augmenter = tf.keras.Sequential(augmentation_layers)

    def _augment_batch(images, labels):
        # training=True keeps the random layers active inside the map call.
        return augmenter(images, training=True), labels

    train_ds = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(len(x_train), seed=SEED)  # buffer spans the whole training set
        .batch(batch_size)
        .map(_augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )

    val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)
    val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds, x_val, y_val, x_test, y_test
def combine_histories(
    first_history: tf.keras.callbacks.History,
    second_history: tf.keras.callbacks.History,
) -> dict[str, list[float]]:
    """Concatenate the per-metric series of two training runs.

    Args:
        first_history: History of the earlier run; its values come first.
        second_history: History of the later run; its values are appended.

    Returns:
        A new dict mapping each metric name found in either run to the
        combined list of values. The input histories are not modified.
    """
    merged: dict[str, list[float]] = {}
    runs = (first_history.history, second_history.history)
    for run in runs:
        for metric, series in run.items():
            if metric not in merged:
                merged[metric] = []
            merged[metric].extend(series)
    return merged
def save_training_plot(history_data: dict[str, list[float]]) -> None:
    """Draw accuracy and loss curves side by side and save the figure.

    Args:
        history_data: Per-epoch metric series. Must contain the keys
            ``accuracy``, ``val_accuracy``, ``loss`` and ``val_loss``.

    The figure is written to ``TRAINING_PLOT_PATH`` at 150 dpi and then closed.
    """
    # (panel title / y-label, training key, validation key), left to right.
    panels = (
        ("Accuracy", "accuracy", "val_accuracy"),
        ("Loss", "loss", "val_loss"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for axis, (title, train_key, val_key) in zip(axes, panels):
        axis.plot(history_data[train_key], label="Train")
        axis.plot(history_data[val_key], label="Validation")
        axis.set_title(title)
        axis.set_xlabel("Epoch")
        axis.set_ylabel(title)
        axis.legend()

    fig.tight_layout()
    fig.savefig(TRAINING_PLOT_PATH, dpi=150)
    plt.close(fig)
def predict_probabilities(
    model: tf.keras.Model,
    x: np.ndarray,
    batch_size: int = 32,
) -> np.ndarray:
    """Run batched inference and return the model outputs as a flat array.

    Args:
        model: Model whose ``predict`` method is called.
        x: Input samples, sliced along the first axis.
        batch_size: Number of samples per inference batch.

    Returns:
        The flattened (1-D) result of ``model.predict``. Callers in this
        file treat the values as scores to compare against a threshold.
    """
    batches = (
        tf.data.Dataset.from_tensor_slices(x)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    raw_outputs = model.predict(batches, verbose=0)
    return raw_outputs.ravel()
def evaluate_and_save_metrics(
    model: tf.keras.Model,
    x_test: np.ndarray,
    y_test: np.ndarray,
    threshold_info: dict[str, float],
) -> None:
    """Evaluate the model on the test split and write a JSON metrics report.

    Args:
        model: Trained model to evaluate.
        x_test: Test inputs.
        y_test: Test labels, where 0 is "real" and 1 is "fake".
        threshold_info: Calibration result; only its ``"threshold"`` entry
            is read here.

    The report written to ``METRICS_PATH`` contains the values returned by
    ``model.evaluate``, accuracy at both the default 0.5 cut-off and the
    calibrated threshold, F1 for the "fake" class, the confusion matrix and
    the full classification report (the last three at the calibrated
    threshold).
    """
    # Pin the label set so the report and matrix always describe both classes,
    # in target_names order, even when one class is absent from the test
    # labels and predictions. Without this, classification_report raises on a
    # label/target_names length mismatch and the matrix can collapse below 2x2.
    class_labels = [0, 1]
    class_names = ["real", "fake"]

    threshold = float(threshold_info["threshold"])
    results = model.evaluate(x_test, y_test, verbose=0, return_dict=True)

    predictions = predict_probabilities(model, x_test, batch_size=32)
    predicted_classes = (predictions >= threshold).astype(int)
    predicted_classes_default = (predictions >= 0.5).astype(int)

    report = classification_report(
        y_test,
        predicted_classes,
        labels=class_labels,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    matrix = confusion_matrix(y_test, predicted_classes, labels=class_labels).tolist()

    metrics = {
        "evaluation": {key: float(value) for key, value in results.items()},
        "thresholding": {
            "default_threshold": 0.5,
            "calibrated_threshold": threshold,
            "test_accuracy_default": float(accuracy_score(y_test, predicted_classes_default)),
            "test_accuracy_calibrated": float(accuracy_score(y_test, predicted_classes)),
            # zero_division=0 matches calibrate_threshold and avoids the
            # undefined-metric warning when no sample is predicted as fake.
            "test_f1_fake_calibrated": float(
                f1_score(y_test, predicted_classes, pos_label=1, zero_division=0)
            ),
        },
        "confusion_matrix": matrix,
        "classification_report": report,
    }
    METRICS_PATH.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
def calibrate_threshold(y_val: np.ndarray, val_probs: np.ndarray) -> dict[str, float]:
    """Pick a decision threshold from validation scores.

    241 evenly spaced candidates in [0.2, 0.8] are scored by accuracy, with
    ties broken by F1 on the positive class (label 1). On a full tie the
    earliest candidate is kept. The winner is replaced by 0.5 unless its
    accuracy reaches the 0.5-threshold accuracy plus 0.02, and the result is
    finally clamped to [0.35, 0.65].

    Args:
        y_val: Validation labels.
        val_probs: Model scores for the validation samples.

    Returns:
        A dict with the final ``threshold``, an "uncertain" band of +/-0.10
        around it (clipped to [0, 1]), the validation accuracy at 0.5, and
        the validation accuracy and positive-class F1 at the final threshold.
    """

    def _labels_at(cutoff: float) -> np.ndarray:
        # Scores at or above the cut-off count as the positive class.
        return (val_probs >= cutoff).astype(int)

    default_acc = float(accuracy_score(y_val, _labels_at(0.5)))

    # (accuracy, f1) compared lexicographically: accuracy first, F1 on ties.
    best_score = (-1.0, -1.0)
    chosen = 0.5
    for candidate in np.linspace(0.2, 0.8, 241):
        labels = _labels_at(candidate)
        score = (
            accuracy_score(y_val, labels),
            f1_score(y_val, labels, pos_label=1, zero_division=0),
        )
        if score > best_score:
            best_score = score
            chosen = float(candidate)

    # Only move away from 0.5 for a gain of at least two accuracy points.
    if best_score[0] < default_acc + 0.02:
        chosen = 0.5
    chosen = float(np.clip(chosen, 0.35, 0.65))

    margin = 0.10
    final_labels = _labels_at(chosen)
    return {
        "threshold": chosen,
        "uncertain_low": float(np.clip(chosen - margin, 0.0, 1.0)),
        "uncertain_high": float(np.clip(chosen + margin, 0.0, 1.0)),
        "validation_accuracy_default_0_5": default_acc,
        "validation_accuracy": float(accuracy_score(y_val, final_labels)),
        "validation_f1_fake": float(
            f1_score(y_val, final_labels, pos_label=1, zero_division=0)
        ),
    }
def main() -> None:
    """Train the detector in two phases and write all artifacts.

    Steps:
        1. Create ``ARTIFACTS_DIR`` and check that ``PROCESSED_DATA_DIR`` exists.
        2. Read ``BATCH_SIZE`` (default 32), ``FROZEN_EPOCHS`` (default 10) and
           ``FINETUNE_EPOCHS`` (default 18) from the environment.
        3. Fit the model from ``build_model()``, reload the best checkpoint,
           call ``unfreeze_for_fine_tuning`` with 45 trainable layers and fit
           again.
        4. Reload the best checkpoint, calibrate the decision threshold on the
           validation split, then save the threshold config, training plot
           and test metrics.

    Raises:
        FileNotFoundError: If ``PROCESSED_DATA_DIR`` does not exist.
    """
    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    if not PROCESSED_DATA_DIR.exists():
        raise FileNotFoundError(
            f"Dataset folder not found at {PROCESSED_DATA_DIR}. "
            "Create data/processed/real and data/processed/fake and put images there."
        )

    batch_size = get_env_int("BATCH_SIZE", 32)
    frozen_epochs = get_env_int("FROZEN_EPOCHS", 10)
    finetune_epochs = get_env_int("FINETUNE_EPOCHS", 18)

    def _phase_callbacks(patience: int) -> list:
        # Both phases use the same callback trio; only the early-stopping
        # patience differs between them.
        return [
            EarlyStopping(
                monitor="val_auc",
                mode="max",
                patience=patience,
                restore_best_weights=True,
            ),
            ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
            ModelCheckpoint(
                MODEL_PATH,
                monitor="val_auc",
                mode="max",
                save_best_only=True,
            ),
        ]

    x, y, _ = load_dataset(PROCESSED_DATA_DIR)
    train_ds, val_ds, x_val, y_val, x_test, y_test = make_datasets(
        x, y, batch_size=batch_size
    )

    # Phase 1: train the model exactly as build_model() returns it.
    model = build_model()
    frozen_history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=frozen_epochs,
        callbacks=_phase_callbacks(patience=4),
        verbose=1,
    )

    # Phase 2: continue from the best phase-1 checkpoint after unfreezing.
    model = tf.keras.models.load_model(MODEL_PATH)
    unfreeze_for_fine_tuning(model, trainable_layers=45)
    finetune_history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=finetune_epochs,
        callbacks=_phase_callbacks(patience=5),
        verbose=1,
    )

    # Calibrate and evaluate on the best checkpoint written to MODEL_PATH.
    model = tf.keras.models.load_model(MODEL_PATH)
    val_predictions = predict_probabilities(model, x_val, batch_size=32)
    threshold_info = calibrate_threshold(y_val, val_predictions)
    THRESHOLD_PATH.write_text(json.dumps(threshold_info, indent=2), encoding="utf-8")

    full_history = combine_histories(frozen_history, finetune_history)
    save_training_plot(full_history)
    evaluate_and_save_metrics(model, x_test, y_test, threshold_info)

    print(f"Training complete. Model saved to: {MODEL_PATH}")
    print(f"Threshold config saved to: {THRESHOLD_PATH}")
    print(f"Metrics saved to: {METRICS_PATH}")
    print(f"Training plot saved to: {TRAINING_PLOT_PATH}")
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()