"""
scripts/train.py

Full two-phase training pipeline for the waste classifier.

Phase 1 trains a new classification head on top of a frozen, ImageNet-pretrained
MobileNetV2 backbone; phase 2 unfreezes the top backbone layers and fine-tunes
with a small learning rate. The final model is evaluated on the held-out test
split and exported.

Usage:
    python scripts/train.py --data_dir data/processed --output_dir models
"""

import argparse
import json
import os
from pathlib import Path

# Give matplotlib a project-local, writable config/cache directory (unless the
# caller already set MPLCONFIGDIR). This must run before matplotlib is
# imported, which is why the third-party imports below come after it.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
MPL_CONFIG_DIR = os.path.join(PROJECT_ROOT, ".cache", "matplotlib")
os.makedirs(MPL_CONFIG_DIR, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", MPL_CONFIG_DIR)

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import tensorflow as tf  # noqa: E402
from sklearn.metrics import classification_report, confusion_matrix  # noqa: E402
from tensorflow.keras import Model, layers  # noqa: E402
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # noqa: E402

CLASS_NAMES = ["plastic", "paper", "organic", "metal", "glass"]
INPUT_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 42
PREPROCESS_INPUT = tf.keras.applications.mobilenet_v2.preprocess_input

# File extensions treated as images. Mirrors the white-list used by Keras'
# flow_from_directory, so both balancing strategies train on the same files.
IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".bmp", ".ppm", ".tif", ".tiff"})


def build_train_dataframe(data_dir: str) -> pd.DataFrame:
    """Return a shuffled, class-balanced index of the training images.

    Scans ``<data_dir>/train/<class_name>/`` for every class in CLASS_NAMES
    and resamples each class to the size of the largest one (with replacement
    for the smaller classes), so every class contributes equally per epoch.

    Args:
        data_dir: Root of the processed dataset; must contain ``train/``.

    Returns:
        DataFrame with columns ``filepath`` (absolute path) and ``class``.

    Raises:
        ValueError: if no training images are found at all, or if some class
            has no images (an empty class cannot be oversampled).
    """
    train_root = Path(data_dir) / "train"
    rows = []
    for class_name in CLASS_NAMES:
        class_dir = train_root / class_name
        # Sorted so that the seeded sampling below is reproducible; raw glob
        # order depends on the filesystem.
        for image_path in sorted(class_dir.glob("*")):
            if image_path.is_file() and image_path.suffix.lower() in IMAGE_EXTENSIONS:
                rows.append({"filepath": str(image_path.resolve()), "class": class_name})

    train_df = pd.DataFrame(rows)
    if train_df.empty:
        raise ValueError(f"No training images found under {train_root}")

    class_counts = train_df["class"].value_counts()
    missing = [name for name in CLASS_NAMES if int(class_counts.get(name, 0)) == 0]
    if missing:
        raise ValueError(f"No training images found for class(es) {missing} under {train_root}")
    target_count = int(class_counts.max())

    balanced_parts = [
        train_df[train_df["class"] == class_name].sample(
            n=target_count,
            # Only the largest class(es) can be drawn without replacement.
            replace=int(class_counts[class_name]) < target_count,
            random_state=SEED,
        )
        for class_name in CLASS_NAMES
    ]
    balanced_df = pd.concat(balanced_parts, ignore_index=True)
    return balanced_df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)


def build_generators(data_dir: str, balance_strategy: str):
    """Create the train / validation / test image iterators.

    Training images are augmented (rotation, shifts, horizontal flip, zoom,
    brightness, shear). Validation and test images only receive MobileNetV2
    preprocessing and are never shuffled, so predictions stay aligned with
    ``iterator.classes``.

    Args:
        data_dir: Dataset root containing ``train/``, ``val/`` and ``test/``,
            each with one sub-folder per class in CLASS_NAMES.
        balance_strategy: ``"oversample"`` trains from the class-balanced
            index built by ``build_train_dataframe``; any other value streams
            ``train/`` as-is (imbalance is then handled via class weights).

    Returns:
        Tuple ``(train_gen, val_gen, test_gen)`` of Keras image iterators.
    """
    train_datagen = ImageDataGenerator(
        preprocessing_function=PREPROCESS_INPUT,
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip=True,
        zoom_range=0.1,
        brightness_range=[0.7, 1.3],
        shear_range=0.1,
    )
    eval_datagen = ImageDataGenerator(preprocessing_function=PREPROCESS_INPUT)

    flow_kwargs = {
        "target_size": INPUT_SIZE,
        "batch_size": BATCH_SIZE,
        "class_mode": "categorical",
        # An explicit class list pins the label index order to CLASS_NAMES.
        "classes": CLASS_NAMES,
        "seed": SEED,
    }

    if balance_strategy == "oversample":
        train_gen = train_datagen.flow_from_dataframe(
            build_train_dataframe(data_dir),
            x_col="filepath",
            y_col="class",
            shuffle=True,
            **flow_kwargs,
        )
    else:
        train_gen = train_datagen.flow_from_directory(
            os.path.join(data_dir, "train"),
            shuffle=True,
            **flow_kwargs,
        )

    # Evaluation splits keep a fixed file order so predictions can be compared
    # element-wise against ``iterator.classes``.
    val_gen = eval_datagen.flow_from_directory(
        os.path.join(data_dir, "val"),
        shuffle=False,
        **flow_kwargs,
    )
    test_gen = eval_datagen.flow_from_directory(
        os.path.join(data_dir, "test"),
        shuffle=False,
        **flow_kwargs,
    )
    return train_gen, val_gen, test_gen


def build_class_weights(train_gen) -> dict[int, float] | None:
    """Return "balanced" per-class loss weights for ``model.fit``.

    Class index ``i`` gets ``total / (num_classes * count_i)``, so rarer
    classes weigh more in the loss. A class with no training samples gets a
    neutral 1.0 (it is never applied), which keeps the keys contiguous from 0
    to ``num_classes - 1``.

    Args:
        train_gen: Training iterator. Its ``classes`` attribute (one integer
            label per sample) and, if present, ``class_indices`` are read.

    Returns:
        ``{class_index: weight}``, or None when the iterator exposes no labels.
    """
    classes = getattr(train_gen, "classes", None)
    if classes is None:
        return None
    classes = np.asarray(classes, dtype=np.int64)
    if classes.size == 0:
        return None

    # Count every declared class, not just up to the highest label observed,
    # so a trailing class with no images still gets a key.
    declared = len(getattr(train_gen, "class_indices", None) or {})
    num_classes = max(declared, int(classes.max()) + 1)
    counts = np.bincount(classes, minlength=num_classes)
    total = counts.sum()
    return {
        index: float(total / (num_classes * count)) if count > 0 else 1.0
        for index, count in enumerate(counts)
    }


def build_model(num_classes: int = 5) -> Model:
    """Build MobileNetV2 (ImageNet weights, frozen) plus a small softmax head.

    Args:
        num_classes: Number of output classes.

    Returns:
        Uncompiled Keras ``Model`` named ``"waste_classifier"``.
    """
    input_shape = (*INPUT_SIZE, 3)
    base = tf.keras.applications.MobileNetV2(
        input_shape=input_shape,
        include_top=False,
        weights="imagenet",
    )
    base.trainable = False

    inputs = tf.keras.Input(shape=input_shape)
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, including after they are unfrozen in phase 2.
    x = base(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation="relu")(x)
    x = layers.BatchNormalization()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return Model(inputs, outputs, name="waste_classifier")


def phase1(model, train_gen, val_gen, output_dir: str, epochs: int,
           class_weights: dict[int, float] | None):
    """Train the classification head while the backbone stays frozen.

    Uses Adam(1e-3). Early-stops on ``val_accuracy`` (patience 3, restoring
    the best weights) and checkpoints the best model to
    ``<output_dir>/phase1_best.h5``.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor="val_accuracy", patience=3, restore_best_weights=True
        ),
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, "phase1_best.h5"),
            save_best_only=True,
            monitor="val_accuracy",
        ),
    ]
    print("\nPhase 1: training head (backbone frozen)")
    return model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=callbacks,
        class_weight=class_weights,
    )


def _find_backbone(model: Model) -> Model:
    """Return the first nested sub-model of *model* (the backbone built by build_model)."""
    for layer in model.layers:
        if isinstance(layer, tf.keras.Model):
            return layer
    raise ValueError(f"Model {model.name!r} has no nested backbone sub-model")


def phase2(model, train_gen, val_gen, output_dir: str, epochs: int,
           class_weights: dict[int, float] | None, fine_tune_layers: int = 30):
    """Unfreeze the top backbone layers and fine-tune end to end.

    All backbone layers except the last ``fine_tune_layers`` stay frozen. The
    model is recompiled with Adam(1e-5), since trainability changes only take
    effect after compiling. Early-stops on ``val_accuracy`` (patience 5,
    restoring the best weights), reduces the LR on ``val_loss`` plateaus and
    checkpoints the best model to ``<output_dir>/phase2_best.h5``.

    Args:
        fine_tune_layers: How many of the backbone's top layers to train.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    backbone = _find_backbone(model)
    backbone.trainable = True
    # Explicit cutoff instead of ``layers[:-n]``, which would freeze nothing
    # (i.e. train everything) when n == 0.
    frozen_count = max(len(backbone.layers) - fine_tune_layers, 0)
    for layer in backbone.layers[:frozen_count]:
        layer.trainable = False

    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-5),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor="val_accuracy", patience=5, restore_best_weights=True
        ),
        tf.keras.callbacks.ReduceLROnPlateau(
            monitor="val_loss", factor=0.3, patience=3, min_lr=1e-7
        ),
        tf.keras.callbacks.ModelCheckpoint(
            os.path.join(output_dir, "phase2_best.h5"),
            save_best_only=True,
            monitor="val_accuracy",
        ),
    ]
    print(f"\nPhase 2: fine-tuning top-{fine_tune_layers} layers")
    return model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=callbacks,
        class_weight=class_weights,
    )


def _plot_confusion_matrix(cm: np.ndarray, path: str) -> None:
    """Render a row-normalized confusion matrix with per-cell values and save it to *path*."""
    num_classes = len(CLASS_NAMES)
    fig, ax = plt.subplots(figsize=(7, 6))
    try:
        im = ax.imshow(cm, cmap="Greens")
        ax.set_xticks(range(num_classes))
        ax.set_yticks(range(num_classes))
        ax.set_xticklabels(CLASS_NAMES, rotation=45, ha="right")
        ax.set_yticklabels(CLASS_NAMES)
        plt.colorbar(im, ax=ax)
        for i in range(num_classes):
            for j in range(num_classes):
                ax.text(
                    j,
                    i,
                    f"{cm[i, j]:.2f}",
                    ha="center",
                    va="center",
                    fontsize=8,
                    # Light text on dark (high-value) cells for contrast.
                    color="white" if cm[i, j] > 0.5 else "black",
                )
        ax.set_title("Confusion Matrix (normalized)")
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)


def evaluate(model, test_gen, output_dir: str):
    """Evaluate *model* on the test iterator and write report artifacts.

    Prints a per-class classification report, saves a row-normalized
    confusion matrix to ``<output_dir>/confusion_matrix.png`` and the report
    as JSON to ``<output_dir>/metrics.json``.

    Args:
        model: Trained Keras model.
        test_gen: Non-shuffled test iterator; its ``classes`` must line up
            with the order of ``model.predict`` outputs.
        output_dir: Existing directory for the artifacts.

    Returns:
        The classification report as a dict (sklearn ``output_dict`` format).
    """
    predictions = model.predict(test_gen)
    predicted_classes = np.argmax(predictions, axis=1)
    true_classes = test_gen.classes

    # Pin the label set so the report and the matrix always cover every class,
    # even one absent from the test split. Otherwise sklearn rejects the
    # 5-entry target_names and the matrix shrinks below len(CLASS_NAMES).
    labels = list(range(len(CLASS_NAMES)))
    report_kwargs = {"labels": labels, "target_names": CLASS_NAMES, "zero_division": 0}
    report = classification_report(
        true_classes, predicted_classes, output_dict=True, **report_kwargs
    )
    cm = confusion_matrix(true_classes, predicted_classes, labels=labels, normalize="true")

    print("\nClassification Report")
    print(classification_report(true_classes, predicted_classes, **report_kwargs))

    _plot_confusion_matrix(cm, os.path.join(output_dir, "confusion_matrix.png"))

    with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as file:
        json.dump(report, file, indent=2)

    print(f"Confusion matrix saved -> {output_dir}/confusion_matrix.png")
    print(f"Metrics JSON saved -> {output_dir}/metrics.json")
    return report


def main():
    """Parse CLI arguments, run both training phases, evaluate and export the model."""
    parser = argparse.ArgumentParser(description="Train waste classifier")
    parser.add_argument("--data_dir", default="data/processed")
    parser.add_argument("--output_dir", default="models")
    parser.add_argument("--phase1_epochs", type=int, default=10)
    parser.add_argument("--phase2_epochs", type=int, default=20)
    parser.add_argument(
        "--balance_strategy",
        choices=["class_weight", "oversample"],
        default="class_weight",
    )
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    train_gen, val_gen, test_gen = build_generators(args.data_dir, args.balance_strategy)
    # The oversampled training set is already balanced, so class weights are
    # only used with the plain directory stream.
    class_weights = (
        None if args.balance_strategy == "oversample" else build_class_weights(train_gen)
    )

    model = build_model(num_classes=len(CLASS_NAMES))
    model.summary()

    phase1(model, train_gen, val_gen, args.output_dir, args.phase1_epochs, class_weights)
    phase2(model, train_gen, val_gen, args.output_dir, args.phase2_epochs, class_weights)

    print("\nFinal evaluation on held-out test set")
    evaluate(model, test_gen, args.output_dir)

    saved_path = os.path.join(args.output_dir, "waste_classifier_v1")
    model.export(saved_path)
    print(f"\nSavedModel exported -> {saved_path}")


if __name__ == "__main__":
    main()