"""
scripts/train.py
Full two-phase training pipeline for the waste classifier.
Usage:
python scripts/train.py --data_dir data/processed --output_dir models
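
Expected data layout (inferred from the generator calls below):
    <data_dir>/{train,val,test}/<class>/<image files>
where <class> is one of: plastic, paper, organic, metal, glass.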
"""
import argparse
import json
import os
from pathlib import Path
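
# Keep matplotlib's config/cache inside the repo so the script also runs where
# the default location (e.g. a read-only $HOME) is not writable; matplotlib
# must only be imported after MPLCONFIGDIR is set.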
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
MPL_CONFIG_DIR = os.path.join(PROJECT_ROOT, ".cache", "matplotlib")
os.makedirs(MPL_CONFIG_DIR, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", MPL_CONFIG_DIR)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras import Model, layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
CLASS_NAMES = ["plastic", "paper", "organic", "metal", "glass"]
INPUT_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 42
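# MobileNetV2's preprocess_input rescales raw RGB pixels to the [-1, 1] range.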
PREPROCESS_INPUT = tf.keras.applications.mobilenet_v2.preprocess_input
def build_train_dataframe(data_dir: str) -> pd.DataFrame:
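    """Oversample every class up to the majority-class count, then return a
    deterministically shuffled, balanced training dataframe."""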
rows = []
train_root = Path(data_dir) / "train"
for class_name in CLASS_NAMES:
class_dir = train_root / class_name
for image_path in class_dir.glob("*"):
if image_path.is_file():
rows.append({"filepath": str(image_path.resolve()), "class": class_name})
train_df = pd.DataFrame(rows)
if train_df.empty:
raise ValueError(f"No training images found under {train_root}")
class_counts = train_df["class"].value_counts()
target_count = int(class_counts.max())
balanced_parts = []
for class_name in CLASS_NAMES:
class_rows = train_df[train_df["class"] == class_name]
replace = len(class_rows) < target_count
sampled = class_rows.sample(
n=target_count,
replace=replace,
random_state=SEED,
)
balanced_parts.append(sampled)
balanced_df = pd.concat(balanced_parts, ignore_index=True)
return balanced_df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
def build_generators(data_dir: str, balance_strategy: str):
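    """Build the train/val/test generators.

    Augmentation is applied to training images only. With the "oversample"
    strategy the train generator reads a pre-balanced dataframe; otherwise
    it reads the raw class directories.
    """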
train_datagen = ImageDataGenerator(
preprocessing_function=PREPROCESS_INPUT,
rotation_range=20,
width_shift_range=0.1,
height_shift_range=0.1,
horizontal_flip=True,
zoom_range=0.1,
brightness_range=[0.7, 1.3],
shear_range=0.1,
)
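    # Validation and test images get only the backbone preprocessing, no augmentation.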
eval_datagen = ImageDataGenerator(preprocessing_function=PREPROCESS_INPUT)
if balance_strategy == "oversample":
train_df = build_train_dataframe(data_dir)
train_gen = train_datagen.flow_from_dataframe(
train_df,
x_col="filepath",
y_col="class",
target_size=INPUT_SIZE,
batch_size=BATCH_SIZE,
class_mode="categorical",
classes=CLASS_NAMES,
seed=SEED,
shuffle=True,
)
else:
train_gen = train_datagen.flow_from_directory(
os.path.join(data_dir, "train"),
target_size=INPUT_SIZE,
batch_size=BATCH_SIZE,
class_mode="categorical",
classes=CLASS_NAMES,
seed=SEED,
)
val_gen = eval_datagen.flow_from_directory(
os.path.join(data_dir, "val"),
target_size=INPUT_SIZE,
batch_size=BATCH_SIZE,
class_mode="categorical",
classes=CLASS_NAMES,
seed=SEED,
shuffle=False,
)
test_gen = eval_datagen.flow_from_directory(
os.path.join(data_dir, "test"),
target_size=INPUT_SIZE,
batch_size=BATCH_SIZE,
class_mode="categorical",
classes=CLASS_NAMES,
shuffle=False,
)
return train_gen, val_gen, test_gen
def build_class_weights(train_gen) -> dict[int, float] | None:
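    """Inverse-frequency class weights, total / (num_classes * count) per class,
    the same heuristic as scikit-learn's compute_class_weight("balanced")."""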
classes = getattr(train_gen, "classes", None)
if classes is None:
return None
counts = np.bincount(classes)
total = counts.sum()
num_classes = len(counts)
return {
index: float(total / (num_classes * count))
for index, count in enumerate(counts)
if count > 0
}
def build_model(num_classes: int = len(CLASS_NAMES)) -> Model:
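    """Frozen MobileNetV2 backbone topped with a small dense classification head."""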
base = tf.keras.applications.MobileNetV2(
input_shape=(224, 224, 3),
include_top=False,
weights="imagenet",
)
base.trainable = False
inputs = tf.keras.Input(shape=(224, 224, 3))
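    # Calling the backbone with training=False keeps its BatchNorm layers in
    # inference mode, the recommended setting even once layers are unfrozen
    # for fine-tuning in phase 2.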
x = base(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.3)(x)
x = layers.Dense(256, activation="relu")(x)
x = layers.BatchNormalization()(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
return Model(inputs, outputs, name="waste_classifier")
def phase1(model, train_gen, val_gen, output_dir: str, epochs: int, class_weights: dict[int, float] | None):
"""Train the classification head while the backbone stays frozen."""
model.compile(
optimizer=tf.keras.optimizers.Adam(1e-3),
loss="categorical_crossentropy",
metrics=["accuracy"],
)
callbacks = [
tf.keras.callbacks.EarlyStopping(
monitor="val_accuracy", patience=3, restore_best_weights=True
),
tf.keras.callbacks.ModelCheckpoint(
os.path.join(output_dir, "phase1_best.h5"),
save_best_only=True,
monitor="val_accuracy",
),
]
print("\nPhase 1: training head (backbone frozen)")
history = model.fit(
train_gen,
epochs=epochs,
validation_data=val_gen,
callbacks=callbacks,
class_weight=class_weights,
)
return history
def phase2(model, train_gen, val_gen, output_dir: str, epochs: int, class_weights: dict[int, float] | None):
"""Unfreeze the top MobileNetV2 layers and fine-tune end to end."""
backbone = model.layers[1]
backbone.trainable = True
for layer in backbone.layers[:-30]:
layer.trainable = False
model.compile(
optimizer=tf.keras.optimizers.Adam(1e-5),
loss="categorical_crossentropy",
metrics=["accuracy"],
)
callbacks = [
tf.keras.callbacks.EarlyStopping(
monitor="val_accuracy", patience=5, restore_best_weights=True
),
tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.3, patience=3, min_lr=1e-7
),
tf.keras.callbacks.ModelCheckpoint(
os.path.join(output_dir, "phase2_best.h5"),
save_best_only=True,
monitor="val_accuracy",
),
]
print("\nPhase 2: fine-tuning top-30 layers")
history = model.fit(
train_gen,
epochs=epochs,
validation_data=val_gen,
callbacks=callbacks,
class_weight=class_weights,
)
return history
def evaluate(model, test_gen, output_dir: str):
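    """Print test-set metrics, then save a normalized confusion matrix plot and
    metrics.json. Assumes test_gen was built with shuffle=False so prediction
    order lines up with test_gen.classes.
    """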
predictions = model.predict(test_gen)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = test_gen.classes
report = classification_report(
true_classes, predicted_classes, target_names=CLASS_NAMES, output_dict=True
)
cm = confusion_matrix(true_classes, predicted_classes, normalize="true")
print("\nClassification Report")
print(classification_report(true_classes, predicted_classes, target_names=CLASS_NAMES))
fig, ax = plt.subplots(figsize=(7, 6))
im = ax.imshow(cm, cmap="Greens")
    ax.set_xticks(range(len(CLASS_NAMES)))
    ax.set_yticks(range(len(CLASS_NAMES)))
ax.set_xticklabels(CLASS_NAMES, rotation=45, ha="right")
ax.set_yticklabels(CLASS_NAMES)
plt.colorbar(im, ax=ax)
    for i in range(len(CLASS_NAMES)):
        for j in range(len(CLASS_NAMES)):
ax.text(
j,
i,
f"{cm[i, j]:.2f}",
ha="center",
va="center",
fontsize=8,
color="white" if cm[i, j] > 0.5 else "black",
)
ax.set_title("Confusion Matrix (normalized)")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=150)
plt.close()
with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as file:
json.dump(report, file, indent=2)
print(f"Confusion matrix saved -> {output_dir}/confusion_matrix.png")
print(f"Metrics JSON saved -> {output_dir}/metrics.json")
return report
def main():
parser = argparse.ArgumentParser(description="Train waste classifier")
parser.add_argument("--data_dir", default="data/processed")
parser.add_argument("--output_dir", default="models")
parser.add_argument("--phase1_epochs", type=int, default=10)
parser.add_argument("--phase2_epochs", type=int, default=20)
parser.add_argument(
"--balance_strategy",
choices=["class_weight", "oversample"],
default="class_weight",
)
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)
train_gen, val_gen, test_gen = build_generators(args.data_dir, args.balance_strategy)
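    # Oversampling already balances the training data, so adding class weights
    # on top would double-correct.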
class_weights = None if args.balance_strategy == "oversample" else build_class_weights(train_gen)
    model = build_model(num_classes=len(CLASS_NAMES))
model.summary()
phase1(model, train_gen, val_gen, args.output_dir, args.phase1_epochs, class_weights)
phase2(model, train_gen, val_gen, args.output_dir, args.phase2_epochs, class_weights)
print("\nFinal evaluation on held-out test set")
evaluate(model, test_gen, args.output_dir)
saved_path = os.path.join(args.output_dir, "waste_classifier_v1")
model.export(saved_path)
print(f"\nSavedModel exported -> {saved_path}")
if __name__ == "__main__":
main()