Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import tensorflow as tf | |
| from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score | |
| from sklearn.model_selection import train_test_split | |
| from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau | |
| from src.ai_image_detector.config import ( | |
| ARTIFACTS_DIR, | |
| METRICS_PATH, | |
| MODEL_PATH, | |
| PROCESSED_DATA_DIR, | |
| SEED, | |
| THRESHOLD_PATH, | |
| TRAINING_PLOT_PATH, | |
| ) | |
| from src.ai_image_detector.data import load_dataset | |
| from src.ai_image_detector.model import build_model, unfreeze_for_fine_tuning | |
def get_env_int(name: str, default: int) -> int:
    """Read a strictly positive integer from the environment variable *name*.

    Falls back to *default* when the variable is unset, cannot be parsed by
    ``int()``, or parses to zero or a negative number.
    """
    raw = os.environ.get(name)
    try:
        number = None if raw is None else int(raw)
    except ValueError:
        number = None
    # Non-positive sizes/epoch counts make no sense here, so treat them as unset.
    if number is None or number <= 0:
        return default
    return number
def make_datasets(
    x: np.ndarray,
    y: np.ndarray,
    batch_size: int,
) -> tuple[tf.data.Dataset, tf.data.Dataset, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split the data 70/15/15 (stratified) and build the train/val pipelines.

    Args:
        x: Input samples, indexed along the first axis.
        y: Labels aligned with *x*; also used for stratification.
        batch_size: Batch size for both tf.data pipelines.

    Returns:
        ``(train_ds, val_ds, x_val, y_val, x_test, y_test)``:

        - ``train_ds`` is shuffled, batched, randomly augmented per batch and
          prefetched.
        - ``val_ds`` is batched and prefetched only.
        - The validation and test splits are also returned as raw arrays for
          threshold calibration and final evaluation.
    """
    # First carve off a 30% hold-out pool, then halve it into val and test.
    x_train, x_holdout, y_train, y_holdout = train_test_split(
        x, y, test_size=0.3, random_state=SEED, stratify=y,
    )
    x_val, x_test, y_val, y_test = train_test_split(
        x_holdout, y_holdout, test_size=0.5, random_state=SEED, stratify=y_holdout,
    )

    augment = tf.keras.Sequential(
        [
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.05),
            tf.keras.layers.RandomZoom(0.1),
            tf.keras.layers.RandomContrast(0.1),
        ]
    )

    def _augment_batch(images, labels):
        # training=True forces the random layers to actually perturb inputs.
        return augment(images, training=True), labels

    train_ds = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(len(x_train), seed=SEED)
        .batch(batch_size)
        .map(_augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
        .prefetch(tf.data.AUTOTUNE)
    )
    val_ds = (
        tf.data.Dataset.from_tensor_slices((x_val, y_val))
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return train_ds, val_ds, x_val, y_val, x_test, y_test
def combine_histories(
    first_history: tf.keras.callbacks.History,
    second_history: tf.keras.callbacks.History,
) -> dict[str, list[float]]:
    """Concatenate the per-epoch logs of two consecutive ``fit`` runs.

    For every metric name, the second run's values are appended after the
    first run's. Key order follows first appearance. The input histories are
    not modified.
    """
    merged: dict[str, list[float]] = {}
    for logs in (first_history.history, second_history.history):
        for metric_name, per_epoch in logs.items():
            if metric_name not in merged:
                merged[metric_name] = []
            merged[metric_name].extend(per_epoch)
    return merged
def save_training_plot(history_data: dict[str, list[float]]) -> None:
    """Draw side-by-side accuracy and loss curves and save them as an image.

    *history_data* must contain the keys ``accuracy``, ``val_accuracy``,
    ``loss`` and ``val_loss``. The figure is written to ``TRAINING_PLOT_PATH``
    at 150 dpi and then closed to free its memory.
    """
    # (train key, validation key, panel title / y-axis label)
    panels = (
        ("accuracy", "val_accuracy", "Accuracy"),
        ("loss", "val_loss", "Loss"),
    )
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, (train_key, val_key, caption) in zip(axes, panels):
        ax.plot(history_data[train_key], label="Train")
        ax.plot(history_data[val_key], label="Validation")
        ax.set_title(caption)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(caption)
        ax.legend()
    fig.tight_layout()
    fig.savefig(TRAINING_PLOT_PATH, dpi=150)
    plt.close(fig)
def predict_probabilities(
    model: tf.keras.Model,
    x: np.ndarray,
    batch_size: int = 32,
) -> np.ndarray:
    """Run batched, silent inference and return the model outputs as a 1-D array.

    Callers compare these scores against a threshold, with label 1 meaning
    "fake". This presumes a single-output score model; confirm against
    ``build_model``.
    """
    batches = (
        tf.data.Dataset.from_tensor_slices(x)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    scores = model.predict(batches, verbose=0)
    return np.ravel(scores)
def evaluate_and_save_metrics(
    model: tf.keras.Model,
    x_test: np.ndarray,
    y_test: np.ndarray,
    threshold_info: dict[str, float],
) -> None:
    """Evaluate *model* on the test split and write a JSON report to METRICS_PATH.

    The report contains:

    - the raw ``model.evaluate`` metrics;
    - test accuracy at the default 0.5 cut-off versus the calibrated threshold;
    - fake-class F1 at the calibrated threshold;
    - the confusion matrix and classification report at the calibrated
      threshold.

    Args:
        model: Trained classifier.
        x_test: Test inputs.
        y_test: Test labels; 0 is reported as "real" and 1 as "fake".
        threshold_info: Output of ``calibrate_threshold``; only its
            ``"threshold"`` entry is used here.
    """
    threshold = float(threshold_info["threshold"])
    # Pin the label set so the report and matrix stay two-class even when the
    # test labels and predictions together contain only one class. Without
    # it, sklearn raises because target_names would have more names than
    # observed labels, and the confusion matrix would collapse to 1x1.
    labels = [0, 1]

    results = model.evaluate(x_test, y_test, verbose=0, return_dict=True)
    predictions = predict_probabilities(model, x_test, batch_size=32)
    predicted_classes = (predictions >= threshold).astype(int)
    predicted_classes_default = (predictions >= 0.5).astype(int)

    report = classification_report(
        y_test,
        predicted_classes,
        labels=labels,
        target_names=["real", "fake"],
        output_dict=True,
        zero_division=0,
    )
    matrix = confusion_matrix(y_test, predicted_classes, labels=labels).tolist()

    metrics = {
        "evaluation": {key: float(value) for key, value in results.items()},
        "thresholding": {
            "default_threshold": 0.5,
            "calibrated_threshold": threshold,
            "test_accuracy_default": float(accuracy_score(y_test, predicted_classes_default)),
            "test_accuracy_calibrated": float(accuracy_score(y_test, predicted_classes)),
            # zero_division=0 matches calibrate_threshold and avoids a warning
            # when nothing is predicted as fake.
            "test_f1_fake_calibrated": float(
                f1_score(y_test, predicted_classes, pos_label=1, zero_division=0)
            ),
        },
        "confusion_matrix": matrix,
        "classification_report": report,
    }
    METRICS_PATH.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
def _threshold_scores(
    is_fake: np.ndarray,
    probs: np.ndarray,
    thresholds: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Return (accuracy, fake-class F1) per threshold, vectorised over thresholds.

    *is_fake* is a boolean array (True for label 1) and *probs* a float array
    of the same length. A sample counts as predicted fake when
    ``prob >= threshold``. F1 is 0 when it is undefined (no true and no
    predicted positives), matching ``f1_score(..., zero_division=0)``.
    """
    predicted_fake = probs[None, :] >= thresholds[:, None]  # (n_thresholds, n_samples)
    accuracy = (predicted_fake == is_fake[None, :]).mean(axis=1)
    tp = np.sum(predicted_fake & is_fake, axis=1)
    fp = np.sum(predicted_fake & ~is_fake, axis=1)
    fn = np.sum(~predicted_fake & is_fake, axis=1)
    denom = 2 * tp + fp + fn
    f1 = np.divide(2 * tp, denom, out=np.zeros(len(thresholds)), where=denom > 0)
    return accuracy, f1


def calibrate_threshold(
    y_val: np.ndarray,
    val_probs: np.ndarray,
    *,
    search_range: tuple[float, float] = (0.2, 0.8),
    num_thresholds: int = 241,
    clip_range: tuple[float, float] = (0.35, 0.65),
    min_gain: float = 0.02,
    margin: float = 0.10,
) -> dict[str, float]:
    """Choose a decision threshold on the validation split.

    Candidate thresholds are ``num_thresholds`` evenly spaced points over
    ``search_range``, clamped into ``clip_range``, so every scored candidate
    is one that could actually be returned.

    The candidate with the highest accuracy wins. Ties go to the higher F1 of
    the "fake" class (label 1), then to the lowest threshold. The winner is
    adopted only if its accuracy beats the 0.5 baseline by at least
    ``min_gain``; otherwise 0.5 (clamped into ``clip_range``) is kept.

    Args:
        y_val: Validation labels (1 = fake). Flattened before use.
        val_probs: Model scores for the same samples. Flattened before use.
        search_range: (low, high) bounds of the threshold grid.
        num_thresholds: Number of grid points; must be >= 1.
        clip_range: (low, high) bounds every returned threshold must lie in.
        min_gain: Minimum accuracy gain over 0.5 required to move off 0.5.
        margin: Half-width of the ``uncertain_low``/``uncertain_high`` band.

    Returns:
        A dict with these keys:

        - ``threshold``: the chosen threshold.
        - ``uncertain_low`` / ``uncertain_high``: ``threshold ± margin``,
          clamped to [0, 1].
        - ``validation_accuracy_default_0_5``: validation accuracy at 0.5.
        - ``validation_accuracy`` / ``validation_f1_fake``: accuracy and
          fake-class F1 at the chosen threshold.

    Raises:
        ValueError: If the inputs are empty, their lengths differ, or
            ``num_thresholds`` < 1.
    """
    is_fake = np.asarray(y_val).ravel() == 1
    probs = np.asarray(val_probs, dtype=float).ravel()
    if is_fake.size == 0:
        raise ValueError("calibrate_threshold needs at least one validation sample")
    if is_fake.size != probs.size:
        raise ValueError(
            f"y_val has {is_fake.size} samples but val_probs has {probs.size}"
        )
    if num_thresholds < 1:
        raise ValueError("num_thresholds must be >= 1")

    default_threshold = 0.5
    low, high = clip_range

    # Clamp the grid rather than the winner. Clamping only the winner can turn
    # a threshold that passed the min_gain check (e.g. 0.2) into a clipped one
    # (0.35) that is *worse* than 0.5.
    candidates = np.unique(
        np.clip(np.linspace(search_range[0], search_range[1], num_thresholds), low, high)
    )
    accuracies, f1_scores = _threshold_scores(is_fake, probs, candidates)
    default_acc = float(
        _threshold_scores(is_fake, probs, np.array([default_threshold]))[0][0]
    )

    best_acc = accuracies.max()
    tied = np.flatnonzero(accuracies == best_acc)  # ascending threshold order
    best_threshold = float(candidates[tied[np.argmax(f1_scores[tied])]])
    # Only move off 0.5 when it buys a meaningful accuracy improvement.
    if best_acc < default_acc + min_gain:
        best_threshold = default_threshold
    best_threshold = float(np.clip(best_threshold, low, high))

    chosen_acc, chosen_f1 = _threshold_scores(is_fake, probs, np.array([best_threshold]))
    return {
        "threshold": best_threshold,
        "uncertain_low": float(np.clip(best_threshold - margin, 0.0, 1.0)),
        "uncertain_high": float(np.clip(best_threshold + margin, 0.0, 1.0)),
        "validation_accuracy_default_0_5": default_acc,
        "validation_accuracy": float(chosen_acc[0]),
        "validation_f1_fake": float(chosen_f1[0]),
    }
def _phase_callbacks(
    patience: int,
    best_so_far: float | None = None,
) -> list[tf.keras.callbacks.Callback]:
    """Build the early-stopping / LR-decay / checkpoint callbacks for one phase.

    Args:
        patience: EarlyStopping patience, in epochs without a val_auc gain.
        best_so_far: Best val_auc already saved at MODEL_PATH by an earlier
            phase. When given, the checkpoint only overwrites MODEL_PATH if
            this phase beats it. A fresh ModelCheckpoint has no memory of
            earlier phases and would otherwise save this phase's first epoch
            unconditionally.
    """
    checkpoint_kwargs = {}
    if best_so_far is not None:
        checkpoint_kwargs["initial_value_threshold"] = best_so_far
    return [
        EarlyStopping(
            monitor="val_auc",
            mode="max",
            patience=patience,
            restore_best_weights=True,
        ),
        ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
        ModelCheckpoint(
            MODEL_PATH,
            monitor="val_auc",
            mode="max",
            save_best_only=True,
            **checkpoint_kwargs,
        ),
    ]


def _best_logged_value(history: tf.keras.callbacks.History, key: str) -> float | None:
    """Return the maximum finite value logged under *key*, or None if there is none."""
    finite = [float(v) for v in history.history.get(key, []) if np.isfinite(v)]
    return max(finite) if finite else None


def main() -> None:
    """Train, fine-tune, calibrate and evaluate the detector end to end.

    Steps:

    1. Read BATCH_SIZE / FROZEN_EPOCHS / FINETUNE_EPOCHS from the environment
       (defaults 32 / 10 / 18).
    2. Train the model from ``build_model``.
    3. Reload the best checkpoint, call ``unfreeze_for_fine_tuning`` (45
       trainable layers) and train again. Across both phases, MODEL_PATH keeps
       the checkpoint with the best val_auc.
    4. Calibrate the decision threshold on the validation split.
    5. Write the threshold JSON, the metrics JSON and the training-curve plot.

    Raises:
        FileNotFoundError: If PROCESSED_DATA_DIR does not exist.
    """
    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    if not PROCESSED_DATA_DIR.exists():
        raise FileNotFoundError(
            f"Dataset folder not found at {PROCESSED_DATA_DIR}. "
            "Create data/processed/real and data/processed/fake and put images there."
        )

    batch_size = get_env_int("BATCH_SIZE", 32)
    frozen_epochs = get_env_int("FROZEN_EPOCHS", 10)
    finetune_epochs = get_env_int("FINETUNE_EPOCHS", 18)

    x, y, _ = load_dataset(PROCESSED_DATA_DIR)
    train_ds, val_ds, x_val, y_val, x_test, y_test = make_datasets(x, y, batch_size=batch_size)

    # Phase 1: train the model as returned by build_model.
    model = build_model()
    frozen_history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=frozen_epochs,
        callbacks=_phase_callbacks(patience=4),
        verbose=1,
    )

    # Phase 2: resume from the best phase-1 checkpoint, unfreeze, fine-tune.
    # The checkpoint is seeded with phase 1's best val_auc so a fine-tune run
    # that never improves cannot overwrite the better phase-1 model.
    model = tf.keras.models.load_model(MODEL_PATH)
    unfreeze_for_fine_tuning(model, trainable_layers=45)
    finetune_history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=finetune_epochs,
        callbacks=_phase_callbacks(
            patience=5,
            best_so_far=_best_logged_value(frozen_history, "val_auc"),
        ),
        verbose=1,
    )

    # MODEL_PATH now holds the best val_auc checkpoint across both phases.
    model = tf.keras.models.load_model(MODEL_PATH)
    val_predictions = predict_probabilities(model, x_val, batch_size=32)
    threshold_info = calibrate_threshold(y_val, val_predictions)
    THRESHOLD_PATH.write_text(json.dumps(threshold_info, indent=2), encoding="utf-8")

    save_training_plot(combine_histories(frozen_history, finetune_history))
    evaluate_and_save_metrics(model, x_test, y_test, threshold_info)

    print(f"Training complete. Model saved to: {MODEL_PATH}")
    print(f"Threshold config saved to: {THRESHOLD_PATH}")
    print(f"Metrics saved to: {METRICS_PATH}")
    print(f"Training plot saved to: {TRAINING_PLOT_PATH}")


if __name__ == "__main__":
    main()