SDK-Docker / train_generator.py
Lucifer9907's picture
Prepare Hugging Face Docker Space
ff0c419
"""Memory-efficient training using TensorFlow data generators."""
from __future__ import annotations
import json
import os
from pathlib import Path
import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from src.ai_image_detector.config import (
ARTIFACTS_DIR,
IMAGE_SIZE,
METRICS_PATH,
MODEL_PATH,
PROCESSED_DATA_DIR,
SEED,
THRESHOLD_PATH,
TRAINING_PLOT_PATH,
)
from src.ai_image_detector.model import build_model, unfreeze_for_fine_tuning
def get_env_int(name: str, default: int) -> int:
    """Read a strictly positive integer from an environment variable.

    Args:
        name: Name of the environment variable to look up.
        default: Value returned when the variable is unset, is not a valid
            integer literal, or parses to zero or a negative number.

    Returns:
        The parsed positive integer, or ``default`` in every other case.
    """
    raw = os.getenv(name)
    if raw is not None:
        try:
            candidate = int(raw)
        except ValueError:
            # Not an integer literal: treat it like a non-positive value.
            candidate = 0
        if candidate > 0:
            return candidate
    return default
def create_dataset(
    data_dir: Path,
    batch_size: int,
    augment: bool = False,
    shuffle: bool = False,
    subset: str | None = None,
    validation_split: float = 0.0,
    seed: int = SEED,
) -> tuple[tf.data.Dataset, int]:
    """Build a streaming ``tf.data`` pipeline over ``data_dir/real`` and ``data_dir/fake``.

    Images are read and decoded lazily inside the pipeline, so only file
    paths are held in memory up front. Files under ``real`` get label 0 and
    files under ``fake`` get label 1.

    Args:
        data_dir: Directory containing the ``real`` and ``fake`` sub-folders.
        batch_size: Number of images per batch.
        augment: When true, apply random flip/brightness/contrast per image.
        shuffle: When true, shuffle the file paths using ``seed``.
        subset: Accepted for interface compatibility; not used here.
        validation_split: Accepted for interface compatibility; not used here.
        seed: Seed passed to the shuffle step.

    Returns:
        A ``(dataset, image_count)`` tuple: the batched, prefetched dataset of
        ``(image, label)`` pairs and the number of image files found.

    Raises:
        ValueError: If no image files are found in either sub-folder.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

    def list_images(class_dir: Path) -> list[str]:
        # Sorted so the listing (and therefore the seeded shuffle) does not
        # depend on the order in which the filesystem returns entries.
        return sorted(
            str(path)
            for path in class_dir.glob("*")
            if path.suffix.lower() in image_suffixes
        )

    def parse_image(file_path, label):
        # Read and decode image
        img = tf.io.read_file(file_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, IMAGE_SIZE)
        img = tf.cast(img, tf.float32)
        # MobileNetV2 preprocessing
        img = tf.keras.applications.mobilenet_v2.preprocess_input(img)
        return img, label

    def augment_image(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        image = tf.clip_by_value(image, -1.0, 1.0)  # Keep in MobileNetV2 range
        return image, label

    # Get file paths and labels
    real_files = list_images(data_dir / "real")
    fake_files = list_images(data_dir / "fake")
    file_paths = real_files + fake_files
    labels = [0] * len(real_files) + [1] * len(fake_files)

    print(f"Found {len(real_files)} real images")
    print(f"Found {len(fake_files)} fake images")
    print(f"Total: {len(file_paths)} images")

    if not file_paths:
        # Fail here with a readable message instead of a dtype error raised
        # from inside the TensorFlow pipeline.
        raise ValueError(f"No images found under {data_dir}")

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices((file_paths, labels))
    if shuffle:
        # Paths are listed class by class, so a buffer smaller than the
        # dataset would barely mix the two classes. Elements are still path
        # strings at this point, so buffering all of them is cheap.
        dataset = dataset.shuffle(buffer_size=len(file_paths), seed=seed)

    dataset = dataset.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    if augment:
        dataset = dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset, len(file_paths)
def split_dataset(
    data_dir: Path,
    batch_size: int,
    validation_split: float = 0.3,
    test_split: float = 0.15,
    seed: int = SEED,
) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, int, int]:
    """Build balanced train/validation/test pipelines from ``data_dir``.

    Files under ``data_dir/real`` get label 0 and files under
    ``data_dir/fake`` get label 1. The larger class is randomly down-sampled
    to the size of the smaller one, the combined list is shuffled, and it is
    then cut into train, validation and test slices. Only the training
    pipeline is shuffled and augmented.

    Args:
        data_dir: Directory containing the ``real`` and ``fake`` sub-folders.
        batch_size: Number of images per batch for all three datasets.
        validation_split: Fraction of the balanced data used for validation.
        test_split: Fraction of the balanced data used for testing.
        seed: Seed for class balancing, the split shuffle, and the training
            shuffle buffer.

    Returns:
        ``(train_ds, val_ds, test_ds, val_count, test_count)`` where the two
        counts are the number of validation and test images.

    Raises:
        ValueError: If the split fractions are negative or leave no room for
            training data, if either class folder has no images, or if any
            of the three resulting splits is empty.
    """
    if validation_split < 0 or test_split < 0 or validation_split + test_split >= 1:
        # A sum of 1 or more would give a non-positive train_count, and a
        # negative slice bound below would make the splits overlap.
        raise ValueError(
            "validation_split and test_split must be non-negative and sum to "
            f"less than 1 (got {validation_split} and {test_split})"
        )

    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}

    def list_images(class_dir: Path) -> list[str]:
        # Sorted so the seeded sampling below is reproducible.
        return sorted(
            str(path)
            for path in class_dir.glob("*")
            if path.suffix.lower() in image_suffixes
        )

    def parse_image(file_path, label):
        img = tf.io.read_file(file_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, IMAGE_SIZE)
        img = tf.cast(img, tf.float32)
        img = tf.keras.applications.mobilenet_v2.preprocess_input(img)
        return img, label

    def augment_image(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.1)
        image = tf.image.random_contrast(image, 0.9, 1.1)
        image = tf.clip_by_value(image, -1.0, 1.0)  # Keep in MobileNetV2 range
        return image, label

    def build_eval_dataset(files, file_labels):
        # Validation/test pipeline: deterministic order, no augmentation.
        ds = tf.data.Dataset.from_tensor_slices((files, file_labels))
        ds = ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
        return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    # Get file paths and labels
    real_files = list_images(data_dir / "real")
    fake_files = list_images(data_dir / "fake")

    min_count = min(len(real_files), len(fake_files))
    if min_count == 0:
        raise ValueError(
            f"Need at least one image per class under {data_dir} "
            f"(found {len(real_files)} real, {len(fake_files)} fake)"
        )

    # Balance and shuffle. A local RandomState yields the same stream that
    # np.random.seed(seed) would, so splits stay identical for a given seed,
    # but the process-wide NumPy RNG is left untouched.
    rng = np.random.RandomState(seed)
    real_files = rng.choice(real_files, min_count, replace=False).tolist()
    fake_files = rng.choice(fake_files, min_count, replace=False).tolist()

    file_paths = real_files + fake_files
    labels = [0] * len(real_files) + [1] * len(fake_files)

    # Shuffle together
    indices = rng.permutation(len(file_paths))
    file_paths = [file_paths[i] for i in indices]
    labels = [labels[i] for i in indices]

    # Calculate split indices
    total = len(file_paths)
    test_count = int(total * test_split)
    val_count = int(total * validation_split)
    train_count = total - val_count - test_count
    val_end = train_count + val_count

    train_files, train_labels = file_paths[:train_count], labels[:train_count]
    val_files, val_labels = file_paths[train_count:val_end], labels[train_count:val_end]
    test_files, test_labels = file_paths[val_end:], labels[val_end:]

    print(f"Train: {len(train_files)} | Val: {len(val_files)} | Test: {len(test_files)}")

    if not (train_files and val_files and test_files):
        # An empty slice would otherwise fail inside the TensorFlow pipeline
        # with a dtype error that does not mention the real cause.
        raise ValueError(
            "Dataset is too small for the requested splits: "
            f"train={len(train_files)}, val={len(val_files)}, test={len(test_files)}"
        )

    # Create datasets
    train_ds = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    # The paths were already permuted above, so a capped buffer is enough to
    # vary the order between epochs.
    train_ds = train_ds.shuffle(buffer_size=min(len(train_files), 5000), seed=seed)
    train_ds = train_ds.map(parse_image, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
    train_ds = train_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    val_ds = build_eval_dataset(val_files, val_labels)
    test_ds = build_eval_dataset(test_files, test_labels)

    return train_ds, val_ds, test_ds, len(val_files), len(test_files)
def save_training_plot(history) -> None:
    """Write side-by-side accuracy and loss curves to ``TRAINING_PLOT_PATH``.

    Args:
        history: Object whose ``history`` attribute maps the keys
            ``"accuracy"``, ``"val_accuracy"``, ``"loss"`` and ``"val_loss"``
            to per-epoch value sequences.
    """
    import matplotlib

    # Pick the Agg backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    # One entry per subplot: (train key, validation key, title and y-label).
    panels = (
        ("accuracy", "val_accuracy", "Accuracy"),
        ("loss", "val_loss", "Loss"),
    )

    figure, axes = plt.subplots(1, 2, figsize=(12, 4))
    for axis, (train_key, val_key, caption) in zip(axes, panels):
        axis.plot(history.history[train_key], label="Train")
        axis.plot(history.history[val_key], label="Validation")
        axis.set_title(caption)
        axis.set_xlabel("Epoch")
        axis.set_ylabel(caption)
        axis.legend()

    figure.tight_layout()
    figure.savefig(TRAINING_PLOT_PATH, dpi=150)
    plt.close(figure)
    print(f"Saved training plot to {TRAINING_PLOT_PATH}")
def evaluate_model(model, test_ds, test_count, threshold=0.5):
    """Evaluate ``model`` on the test set and write the metrics to ``METRICS_PATH``.

    Args:
        model: Keras model; ``predict_on_batch`` is called once per batch and
            its output is flattened to one probability per image.
        test_ds: Iterable of ``(images, labels)`` batches where ``labels``
            exposes ``.numpy()``; label 0 is "real" and label 1 is "fake".
        test_count: Unused; kept so existing call sites keep working.
        threshold: Probability at or above which an image is predicted "fake".

    Returns:
        Dict with ``test_accuracy``, ``test_f1_fake``, ``threshold``,
        ``confusion_matrix`` and ``classification_report``. The same dict is
        written to ``METRICS_PATH`` as JSON.

    Raises:
        ValueError: If ``test_ds`` yields no batches.
    """
    class_ids = [0, 1]  # 0 = real, 1 = fake

    # Collect predictions
    true_batches = []
    pred_batches = []
    for images, labels in test_ds:
        # predict_on_batch is the per-batch entry point; calling predict()
        # inside a Python loop sets up a fresh prediction run every iteration.
        probs = np.asarray(model.predict_on_batch(images)).reshape(-1)
        pred_batches.append((probs >= threshold).astype(int))
        true_batches.append(np.asarray(labels.numpy()).reshape(-1))

    if not true_batches:
        raise ValueError("Test dataset produced no samples to evaluate")

    y_true = np.concatenate(true_batches)
    y_pred = np.concatenate(pred_batches)

    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, pos_label=1, zero_division=0)
    # Pin the label set so the matrix is always 2x2 and target_names always
    # lines up, even if one class never appears in y_true or y_pred.
    cm = confusion_matrix(y_true, y_pred, labels=class_ids).tolist()
    report = classification_report(
        y_true,
        y_pred,
        labels=class_ids,
        target_names=["real", "fake"],
        output_dict=True,
        zero_division=0,
    )

    metrics = {
        "test_accuracy": float(acc),
        "test_f1_fake": float(f1),
        "threshold": float(threshold),
        "confusion_matrix": cm,
        "classification_report": report,
    }

    METRICS_PATH.write_text(json.dumps(metrics, indent=2), encoding="utf-8")

    print(f"\nTest Accuracy: {acc:.4f}")
    print(f"Test F1 (fake): {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")

    return metrics
def main():
    """Train the detector in two stages, evaluate it, and save the artifacts.

    Stage 1 trains the model returned by ``build_model()``. Stage 2 reloads
    the best stage-1 checkpoint, calls ``unfreeze_for_fine_tuning`` on it and
    keeps training. The best checkpoint at ``MODEL_PATH`` is then evaluated
    on the test split and the combined training curves are plotted.

    Environment variables (positive integers; anything else falls back to
    the default): ``BATCH_SIZE`` (32), ``FROZEN_EPOCHS`` (10),
    ``FINETUNE_EPOCHS`` (15).

    Raises:
        FileNotFoundError: If ``PROCESSED_DATA_DIR`` does not exist.
    """
    from types import SimpleNamespace

    def banner(title: str) -> None:
        # Section header printed between pipeline stages.
        rule = "=" * 50
        print(f"\n{rule}")
        print(title)
        print(rule)

    ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

    if not PROCESSED_DATA_DIR.exists():
        raise FileNotFoundError(f"Dataset not found at {PROCESSED_DATA_DIR}")

    batch_size = get_env_int("BATCH_SIZE", 32)
    frozen_epochs = get_env_int("FROZEN_EPOCHS", 10)
    finetune_epochs = get_env_int("FINETUNE_EPOCHS", 15)

    print("Creating datasets...")
    train_ds, val_ds, test_ds, val_count, test_count = split_dataset(
        PROCESSED_DATA_DIR, batch_size=batch_size
    )

    print("\nBuilding model...")
    model = build_model()

    # Stage 1: Train with frozen base
    banner("Stage 1: Training with frozen base")

    callbacks_frozen = [
        EarlyStopping(monitor="val_auc", mode="max", patience=4, restore_best_weights=True),
        ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
        ModelCheckpoint(str(MODEL_PATH), monitor="val_auc", mode="max", save_best_only=True),
    ]

    history1 = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=frozen_epochs,
        callbacks=callbacks_frozen,
        verbose=1,
    )

    # Best stage-1 validation AUC, if that metric was recorded. It seeds the
    # stage-2 checkpoint so fine-tuning only overwrites MODEL_PATH when it
    # actually beats the stage-1 checkpoint.
    stage1_auc = history1.history.get("val_auc")
    best_stage1_auc = float(max(stage1_auc)) if stage1_auc else None

    # Stage 2: Fine-tune
    banner("Stage 2: Fine-tuning")

    model = tf.keras.models.load_model(str(MODEL_PATH))
    unfreeze_for_fine_tuning(model, trainable_layers=45)

    callbacks_finetune = [
        EarlyStopping(monitor="val_auc", mode="max", patience=5, restore_best_weights=True),
        ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=2),
        ModelCheckpoint(
            str(MODEL_PATH),
            monitor="val_auc",
            mode="max",
            save_best_only=True,
            # Without this the first fine-tuning epoch always replaces the
            # stage-1 checkpoint, even when its val_auc is lower.
            initial_value_threshold=best_stage1_auc,
        ),
    ]

    history2 = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=finetune_epochs,
        callbacks=callbacks_finetune,
        verbose=1,
    )

    # Evaluate
    banner("Final Evaluation")

    model = tf.keras.models.load_model(str(MODEL_PATH))
    evaluate_model(model, test_ds, test_count)

    # Save plots: concatenate the per-epoch series of both stages. A key that
    # is missing from stage 2 contributes nothing instead of raising KeyError.
    combined = {
        key: list(values) + list(history2.history.get(key, []))
        for key, values in history1.history.items()
    }
    save_training_plot(SimpleNamespace(history=combined))

    print(f"\n{'='*50}")
    print("Training complete!")
    print(f"Model saved to: {MODEL_PATH}")
    print(f"Metrics saved to: {METRICS_PATH}")
    print(f"{'='*50}")


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()