Spaces:
Running
Running
| """Memory-efficient training using TensorFlow data generators.""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| import tensorflow as tf | |
| from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score | |
| from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau | |
| from src.ai_image_detector.config import ( | |
| ARTIFACTS_DIR, | |
| IMAGE_SIZE, | |
| METRICS_PATH, | |
| MODEL_PATH, | |
| PROCESSED_DATA_DIR, | |
| SEED, | |
| THRESHOLD_PATH, | |
| TRAINING_PLOT_PATH, | |
| ) | |
| from src.ai_image_detector.model import build_model, unfreeze_for_fine_tuning | |
def get_env_int(name: str, default: int) -> int:
    """Read a strictly positive integer from an environment variable.

    Args:
        name: Name of the environment variable to look up.
        default: Value returned when the variable cannot be used.

    Returns:
        The parsed integer when the variable is set, parses with ``int()``
        and is greater than zero; otherwise ``default``.
    """
    raw = os.environ.get(name)
    if raw is not None:
        try:
            number = int(raw)
        except ValueError:
            # Not an integer literal: fall through to the default.
            pass
        else:
            if number > 0:
                return number
    return default
def create_dataset(
    data_dir: Path,
    batch_size: int,
    augment: bool = False,
    shuffle: bool = False,
    subset: str | None = None,
    validation_split: float = 0.0,
    seed: int = SEED,
) -> tuple[tf.data.Dataset, int]:
    """Build a streaming, batched dataset from ``data_dir/real`` and ``data_dir/fake``.

    Images are read and decoded lazily inside the ``tf.data`` pipeline, so only
    the file paths are held in memory up front.

    Args:
        data_dir: Directory containing the ``real`` and ``fake`` sub-folders.
        batch_size: Number of images per batch.
        augment: Apply random flip/brightness/contrast to every image.
        shuffle: Shuffle the file list (buffer capped at 10000 entries).
        subset: Accepted for backward compatibility; currently ignored.
        validation_split: Accepted for backward compatibility; currently ignored.
        seed: Seed passed to the shuffle operation.

    Returns:
        ``(dataset, image_count)`` where ``dataset`` yields
        ``(image_batch, label_batch)`` pairs with label 0 for ``real`` and
        1 for ``fake``, and ``image_count`` is the number of files found.

    Raises:
        ValueError: If no image with a supported extension is found.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

    def list_images(folder: Path) -> list[str]:
        # Sorted so the file order (and therefore a seeded shuffle) does not
        # depend on the order in which the filesystem happens to list entries.
        return sorted(
            str(p) for p in folder.glob("*") if p.suffix.lower() in image_suffixes
        )

    def parse_image(file_path, label):
        # Read and decode image
        img = tf.io.read_file(file_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, IMAGE_SIZE)
        img = tf.cast(img, tf.float32)
        # MobileNetV2 preprocessing
        img = tf.keras.applications.mobilenet_v2.preprocess_input(img)
        return img, label

    def augment_image(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        image = tf.clip_by_value(image, -1.0, 1.0)  # Keep in MobileNetV2 range
        return image, label

    # Get file paths and labels
    real_files = list_images(data_dir / "real")
    fake_files = list_images(data_dir / "fake")
    file_paths = real_files + fake_files
    labels = [0] * len(real_files) + [1] * len(fake_files)

    print(f"Found {len(real_files)} real images")
    print(f"Found {len(fake_files)} fake images")
    print(f"Total: {len(file_paths)} images")

    if not file_paths:
        # An empty Python list becomes a float32 tensor, which would otherwise
        # surface later as an obscure dtype error inside ``parse_image``.
        raise ValueError(f"No images found under {data_dir}")

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices((file_paths, labels))
    if shuffle:
        # Shuffle paths (cheap) before decoding images (expensive).
        dataset = dataset.shuffle(buffer_size=min(len(file_paths), 10000), seed=seed)
    dataset = dataset.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        dataset = dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset, len(file_paths)
def split_dataset(
    data_dir: Path,
    batch_size: int,
    validation_split: float = 0.3,
    test_split: float = 0.15,
    seed: int = SEED,
) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int]:
    """Build class-balanced train/validation/test datasets from ``data_dir``.

    Files are listed from ``data_dir/real`` (label 0) and ``data_dir/fake``
    (label 1), the larger class is randomly down-sampled to the size of the
    smaller one, the combined list is shuffled and then cut into three
    consecutive slices. Only the training dataset is shuffled and augmented.

    Args:
        data_dir: Directory containing the ``real`` and ``fake`` sub-folders.
        batch_size: Number of images per batch for all three datasets.
        validation_split: Fraction of the balanced data used for validation.
        test_split: Fraction of the balanced data used for testing.
        seed: Seed for the down-sampling, the split shuffle and the training
            shuffle buffer.

    Returns:
        ``(train_ds, val_ds, test_ds, val_count, test_count)`` where the
        counts are the number of files in the validation and test splits.

    Raises:
        ValueError: If the split fractions are negative or sum to 1 or more,
            if either class folder has no images, or if any resulting split
            is empty.
    """
    if validation_split < 0 or test_split < 0 or validation_split + test_split >= 1:
        # Otherwise train_count becomes <= 0 and the slices below overlap.
        raise ValueError(
            "validation_split and test_split must be non-negative and sum to less than 1"
        )

    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

    def list_images(folder: Path) -> list[str]:
        return sorted(
            str(p) for p in folder.glob("*") if p.suffix.lower() in image_suffixes
        )

    def parse_image(file_path, label):
        img = tf.io.read_file(file_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, IMAGE_SIZE)
        img = tf.cast(img, tf.float32)
        # MobileNetV2 preprocessing
        img = tf.keras.applications.mobilenet_v2.preprocess_input(img)
        return img, label

    def augment_image(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        image = tf.clip_by_value(image, -1.0, 1.0)  # Keep in MobileNetV2 range
        return image, label

    def build(files, file_labels, training=False):
        ds = tf.data.Dataset.from_tensor_slices((files, file_labels))
        if training:
            # Shuffle paths before decoding so the buffer holds strings only.
            ds = ds.shuffle(buffer_size=min(len(files), 5000), seed=seed)
        ds = ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
        if training:
            ds = ds.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
        return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # Get file paths and labels
    real_files = list_images(data_dir / "real")
    fake_files = list_images(data_dir / "fake")

    min_count = min(len(real_files), len(fake_files))
    if min_count == 0:
        raise ValueError(
            f"Need at least one image in both {data_dir / 'real'} and {data_dir / 'fake'}"
        )

    # Balance and shuffle. A local RandomState yields the same stream as
    # np.random.seed(seed) would, without reseeding NumPy's global generator.
    rng = np.random.RandomState(seed)
    real_files = rng.choice(real_files, min_count, replace=False).tolist()
    fake_files = rng.choice(fake_files, min_count, replace=False).tolist()
    file_paths = real_files + fake_files
    labels = [0] * len(real_files) + [1] * len(fake_files)

    # Shuffle paths and labels together
    order = rng.permutation(len(file_paths))
    file_paths = [file_paths[i] for i in order]
    labels = [labels[i] for i in order]

    # Calculate split indices
    total = len(file_paths)
    test_count = int(total * test_split)
    val_count = int(total * validation_split)
    train_count = total - val_count - test_count
    val_end = train_count + val_count

    train_files, train_labels = file_paths[:train_count], labels[:train_count]
    val_files, val_labels = file_paths[train_count:val_end], labels[train_count:val_end]
    test_files, test_labels = file_paths[val_end:], labels[val_end:]
    print(f"Train: {len(train_files)} | Val: {len(val_files)} | Test: {len(test_files)}")

    if not (train_files and val_files and test_files):
        # An empty list becomes a float32 tensor and fails inside parse_image.
        raise ValueError(
            f"Not enough images ({total}) to produce non-empty train/val/test splits"
        )

    # Create datasets
    train_ds = build(train_files, train_labels, training=True)
    val_ds = build(val_files, val_labels)
    test_ds = build(test_files, test_labels)
    return train_ds, val_ds, test_ds, len(val_files), len(test_files)
def save_training_plot(history) -> None:
    """Write accuracy and loss curves to ``TRAINING_PLOT_PATH``.

    Args:
        history: Object whose ``history`` attribute maps ``"accuracy"``,
            ``"val_accuracy"``, ``"loss"`` and ``"val_loss"`` to per-epoch
            value sequences.
    """
    import matplotlib

    # Select the non-interactive backend before pyplot is imported.
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    figure, (accuracy_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))
    # (axis, training key, validation key, title which doubles as y-label)
    panels = (
        (accuracy_ax, "accuracy", "val_accuracy", "Accuracy"),
        (loss_ax, "loss", "val_loss", "Loss"),
    )
    for axis, train_key, val_key, title in panels:
        axis.plot(history.history[train_key], label="Train")
        axis.plot(history.history[val_key], label="Validation")
        axis.set_title(title)
        axis.set_xlabel("Epoch")
        axis.set_ylabel(title)
        axis.legend()

    figure.tight_layout()
    figure.savefig(TRAINING_PLOT_PATH, dpi=150)
    plt.close(figure)
    print(f"Saved training plot to {TRAINING_PLOT_PATH}")
def evaluate_model(model, test_ds, test_count, threshold=0.5):
    """Evaluate ``model`` on ``test_ds`` and write the metrics to ``METRICS_PATH``.

    Args:
        model: Model whose per-image output is treated as the probability of
            the ``fake`` class (label 1).
        test_ds: Iterable of ``(images, labels)`` batches.
        test_count: Number of test images. Kept for backward compatibility;
            not used by the computation.
        threshold: Probabilities greater than or equal to this value are
            predicted as ``fake``.

    Returns:
        Dict with keys ``test_accuracy``, ``test_f1_fake``, ``threshold``,
        ``confusion_matrix`` and ``classification_report``; the same dict is
        written as JSON to ``METRICS_PATH``.
    """
    # Fix the label set so the confusion matrix is always 2x2 and the report
    # always has both rows, even if one class is absent from the test data
    # and the predictions.
    class_ids = [0, 1]

    # Collect predictions
    y_true = []
    y_pred = []
    for images, labels in test_ds:
        # predict_on_batch avoids the per-call setup of predict(), which is
        # meant for whole datasets rather than being called inside a loop.
        probs = np.asarray(model.predict_on_batch(images)).ravel()
        y_pred.extend((probs >= threshold).astype(int).tolist())
        y_true.extend(labels.numpy().tolist())
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, pos_label=1, zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=class_ids).tolist()
    report = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=["real", "fake"],
        output_dict=True,
        zero_division=0,
    )

    metrics = {
        "test_accuracy": float(acc),
        "test_f1_fake": float(f1),
        "threshold": float(threshold),
        "confusion_matrix": cm,
        "classification_report": report,
    }
    METRICS_PATH.write_text(json.dumps(metrics, indent=2), encoding="utf-8")

    print(f"\nTest Accuracy: {acc:.4f}")
    print(f"Test F1 (fake): {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")
    return metrics
def _make_callbacks(patience):
    """Return fresh early-stopping, LR-reduction and checkpoint callbacks."""
    return [
        EarlyStopping(monitor="val_auc", mode="max", patience=patience, restore_best_weights=True),
        ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
        ModelCheckpoint(str(MODEL_PATH), monitor="val_auc", mode="max", save_best_only=True),
    ]


def _merge_histories(first, second):
    """Concatenate per-epoch metric lists of two training histories.

    Only keys present in ``first`` are kept. A key missing from ``second``
    contributes no epochs instead of raising ``KeyError``.
    """
    from types import SimpleNamespace

    merged = {
        key: list(values) + list(second.history.get(key, []))
        for key, values in first.history.items()
    }
    return SimpleNamespace(history=merged)


def main():
    """Run the two-stage training pipeline and save model, metrics and plot.

    Reads ``BATCH_SIZE``, ``FROZEN_EPOCHS`` and ``FINETUNE_EPOCHS`` from the
    environment (defaults 32, 10 and 15), trains the model built by
    ``build_model()``, reloads the checkpoint and continues training after
    ``unfreeze_for_fine_tuning``, then evaluates the reloaded checkpoint on
    the test split and writes the training plot.

    Raises:
        FileNotFoundError: If ``PROCESSED_DATA_DIR`` does not exist.
    """
    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)
    if not PROCESSED_DATA_DIR.exists():
        raise FileNotFoundError(f"Dataset not found at {PROCESSED_DATA_DIR}")

    batch_size = get_env_int("BATCH_SIZE", 32)
    frozen_epochs = get_env_int("FROZEN_EPOCHS", 10)
    finetune_epochs = get_env_int("FINETUNE_EPOCHS", 15)
    banner = "=" * 50

    print("Creating datasets...")
    train_ds, val_ds, test_ds, val_count, test_count = split_dataset(
        PROCESSED_DATA_DIR, batch_size=batch_size
    )

    print("\nBuilding model...")
    model = build_model()

    # Stage 1: Train with frozen base
    print(f"\n{banner}")
    print("Stage 1: Training with frozen base")
    print(f"{banner}")
    history1 = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=frozen_epochs,
        callbacks=_make_callbacks(patience=4),
        verbose=1,
    )

    # Stage 2: Fine-tune, starting from the checkpoint written in stage 1
    print(f"\n{banner}")
    print("Stage 2: Fine-tuning")
    print(f"{banner}")
    model = tf.keras.models.load_model(str(MODEL_PATH))
    unfreeze_for_fine_tuning(model, trainable_layers=45)
    # NOTE(review): this stage uses a new ModelCheckpoint on the same path, so
    # presumably its "best so far" tracking starts over and the first
    # fine-tuning epoch replaces the stage-1 file even if its val_auc is
    # lower. Confirm, and consider seeding it with the stage-1 best value.
    history2 = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=finetune_epochs,
        callbacks=_make_callbacks(patience=5),
        verbose=1,
    )

    # Evaluate the checkpoint on disk rather than the in-memory model
    print(f"\n{banner}")
    print("Final Evaluation")
    print(f"{banner}")
    model = tf.keras.models.load_model(str(MODEL_PATH))
    evaluate_model(model, test_ds, test_count)

    # Save plots covering both stages as one continuous run
    save_training_plot(_merge_histories(history1, history2))

    print(f"\n{banner}")
    print("Training complete!")
    print(f"Model saved to: {MODEL_PATH}")
    print(f"Metrics saved to: {METRICS_PATH}")
    print(f"{banner}")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()