#!/usr/bin/env python3
| """ | |
| scripts/train.py | |
| Full two-phase training pipeline for waste classifier. | |
| Usage: | |
| python scripts/train.py --data_dir data/processed --output_dir models | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| from pathlib import Path | |
# Point matplotlib at a writable config/cache directory BEFORE importing
# pyplot below: matplotlib resolves MPLCONFIGDIR at import time, and the
# default per-user location may not be writable (e.g. inside containers).
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
MPL_CONFIG_DIR = os.path.join(PROJECT_ROOT, ".cache", "matplotlib")
os.makedirs(MPL_CONFIG_DIR, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", MPL_CONFIG_DIR)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras import Model, layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Label order fixes the class-index mapping shared by every generator,
# the model head, and the evaluation report.
CLASS_NAMES = ["plastic", "paper", "organic", "metal", "glass"]
INPUT_SIZE = (224, 224)  # MobileNetV2's expected input resolution
BATCH_SIZE = 32
SEED = 42  # single seed for reproducible sampling and shuffling
# MobileNetV2-specific pixel scaling, applied by every ImageDataGenerator.
PREPROCESS_INPUT = tf.keras.applications.mobilenet_v2.preprocess_input
def build_train_dataframe(data_dir: str) -> pd.DataFrame:
    """Collect training image paths and oversample minority classes.

    Scans ``<data_dir>/train/<class>/`` for files, then samples each class
    up to the size of the largest class (with replacement for minority
    classes) so the returned frame is class-balanced.

    Args:
        data_dir: Root of the processed dataset (must contain ``train/``).

    Returns:
        A shuffled DataFrame with ``filepath`` and ``class`` columns.

    Raises:
        ValueError: If no training images are found at all, or if any
            class directory is missing or empty — sampling an empty class
            with ``replace=True`` would otherwise fail with an opaque
            pandas error.
    """
    train_root = Path(data_dir) / "train"
    rows = [
        {"filepath": str(image_path.resolve()), "class": class_name}
        for class_name in CLASS_NAMES
        for image_path in sorted((train_root / class_name).glob("*"))
        if image_path.is_file()
    ]
    train_df = pd.DataFrame(rows)
    if train_df.empty:
        raise ValueError(f"No training images found under {train_root}")

    class_counts = train_df["class"].value_counts()
    # Fail fast with a clear message if any declared class has zero images.
    missing = [name for name in CLASS_NAMES if class_counts.get(name, 0) == 0]
    if missing:
        raise ValueError(
            f"No training images for classes {missing} under {train_root}"
        )

    target_count = int(class_counts.max())
    balanced_parts = []
    for class_name in CLASS_NAMES:
        class_rows = train_df[train_df["class"] == class_name]
        # Sample with replacement only when the class is under-represented.
        sampled = class_rows.sample(
            n=target_count,
            replace=len(class_rows) < target_count,
            random_state=SEED,
        )
        balanced_parts.append(sampled)
    balanced_df = pd.concat(balanced_parts, ignore_index=True)
    # Final shuffle so batches mix classes instead of arriving grouped.
    return balanced_df.sample(frac=1.0, random_state=SEED).reset_index(drop=True)
def build_generators(data_dir: str, balance_strategy: str):
    """Create the train/val/test Keras image generators.

    Args:
        data_dir: Dataset root containing ``train``, ``val`` and ``test``.
        balance_strategy: ``"oversample"`` draws a balanced dataframe via
            build_train_dataframe; anything else reads ``train/`` directly
            (class imbalance is then handled via class weights elsewhere).

    Returns:
        Tuple of (train_gen, val_gen, test_gen). Only the training
        generator applies augmentation; val/test are preprocess-only and
        unshuffled so predictions line up with ``.classes``.
    """
    augmenter = ImageDataGenerator(
        preprocessing_function=PREPROCESS_INPUT,
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip=True,
        zoom_range=0.1,
        brightness_range=[0.7, 1.3],
        shear_range=0.1,
    )
    plain = ImageDataGenerator(preprocessing_function=PREPROCESS_INPUT)

    # Settings shared by every flow_from_* call below.
    common = dict(
        target_size=INPUT_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="categorical",
        classes=CLASS_NAMES,
    )

    if balance_strategy == "oversample":
        train_gen = augmenter.flow_from_dataframe(
            build_train_dataframe(data_dir),
            x_col="filepath",
            y_col="class",
            seed=SEED,
            shuffle=True,
            **common,
        )
    else:
        train_gen = augmenter.flow_from_directory(
            os.path.join(data_dir, "train"),
            seed=SEED,
            **common,
        )

    val_gen = plain.flow_from_directory(
        os.path.join(data_dir, "val"),
        seed=SEED,
        shuffle=False,
        **common,
    )
    test_gen = plain.flow_from_directory(
        os.path.join(data_dir, "test"),
        shuffle=False,
        **common,
    )
    return train_gen, val_gen, test_gen
| def build_class_weights(train_gen) -> dict[int, float] | None: | |
| classes = getattr(train_gen, "classes", None) | |
| if classes is None: | |
| return None | |
| counts = np.bincount(classes) | |
| total = counts.sum() | |
| num_classes = len(counts) | |
| return { | |
| index: float(total / (num_classes * count)) | |
| for index, count in enumerate(counts) | |
| if count > 0 | |
| } | |
def build_model(num_classes: int = 5) -> Model:
    """Assemble a MobileNetV2-based classifier with a frozen backbone.

    Args:
        num_classes: Size of the softmax output layer.

    Returns:
        An uncompiled functional Model named "waste_classifier". The
        backbone is called with ``training=False`` so its BatchNorm layers
        stay in inference mode even after later unfreezing.
    """
    backbone = tf.keras.applications.MobileNetV2(
        input_shape=(224, 224, 3),
        include_top=False,
        weights="imagenet",
    )
    backbone.trainable = False

    image_input = tf.keras.Input(shape=(224, 224, 3))
    features = backbone(image_input, training=False)
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.Dropout(0.3)(features)
    features = layers.Dense(256, activation="relu")(features)
    features = layers.BatchNormalization()(features)
    probabilities = layers.Dense(num_classes, activation="softmax")(features)
    return Model(image_input, probabilities, name="waste_classifier")
def phase1(model, train_gen, val_gen, output_dir: str, epochs: int, class_weights: dict[int, float] | None):
    """Train the classification head while the backbone stays frozen.

    Compiles with a relatively high LR (1e-3) appropriate for randomly
    initialized head layers, early-stops on val_accuracy, and checkpoints
    the best weights to ``phase1_best.h5``. Returns the Keras History.
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    stop_early = tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=3, restore_best_weights=True
    )
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(output_dir, "phase1_best.h5"),
        save_best_only=True,
        monitor="val_accuracy",
    )
    print("\nPhase 1: training head (backbone frozen)")
    return model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=[stop_early, checkpoint],
        class_weight=class_weights,
    )
def phase2(model, train_gen, val_gen, output_dir: str, epochs: int, class_weights: dict[int, float] | None):
    """Unfreeze the top MobileNetV2 layers and fine-tune end to end.

    Uses a much lower LR (1e-5) than phase 1 to avoid destroying the
    pretrained features, plus LR reduction on val_loss plateaus. Best
    weights are checkpointed to ``phase2_best.h5``. Returns the History.
    """
    # NOTE(review): assumes layers[1] is the MobileNetV2 base, as produced
    # by build_model (layers[0] is the Input layer) — confirm if the model
    # architecture ever changes.
    backbone = model.layers[1]
    backbone.trainable = True
    # Keep everything except the top 30 backbone layers frozen.
    for frozen_layer in backbone.layers[:-30]:
        frozen_layer.trainable = False

    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-5),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    stop_early = tf.keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=5, restore_best_weights=True
    )
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.3, patience=3, min_lr=1e-7
    )
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        os.path.join(output_dir, "phase2_best.h5"),
        save_best_only=True,
        monitor="val_accuracy",
    )
    print("\nPhase 2: fine-tuning top-30 layers")
    return model.fit(
        train_gen,
        epochs=epochs,
        validation_data=val_gen,
        callbacks=[stop_early, reduce_lr, checkpoint],
        class_weight=class_weights,
    )
def evaluate(model, test_gen, output_dir: str):
    """Evaluate the model on the test generator and persist artifacts.

    Prints a classification report, saves a row-normalized confusion
    matrix heatmap to ``confusion_matrix.png`` and the per-class metrics
    to ``metrics.json`` under *output_dir*.

    Args:
        model: Trained Keras model.
        test_gen: Unshuffled generator exposing ``classes`` ground truth.
        output_dir: Existing directory for the output files.

    Returns:
        The classification report as a dict (sklearn ``output_dict=True``).
    """
    predictions = model.predict(test_gen)
    predicted_classes = np.argmax(predictions, axis=1)
    true_classes = test_gen.classes
    report = classification_report(
        true_classes, predicted_classes, target_names=CLASS_NAMES, output_dict=True
    )
    # normalize="true" -> each row sums to 1 (per-class recall on the diagonal).
    cm = confusion_matrix(true_classes, predicted_classes, normalize="true")
    print("\nClassification Report")
    print(classification_report(true_classes, predicted_classes, target_names=CLASS_NAMES))

    # Derive the class count instead of hard-coding 5, so the plot stays
    # correct if CLASS_NAMES ever changes.
    num_classes = len(CLASS_NAMES)
    fig, ax = plt.subplots(figsize=(7, 6))
    im = ax.imshow(cm, cmap="Greens")
    ax.set_xticks(range(num_classes))
    ax.set_yticks(range(num_classes))
    ax.set_xticklabels(CLASS_NAMES, rotation=45, ha="right")
    ax.set_yticklabels(CLASS_NAMES)
    plt.colorbar(im, ax=ax)
    for i in range(num_classes):
        for j in range(num_classes):
            ax.text(
                j,
                i,
                f"{cm[i, j]:.2f}",
                ha="center",
                va="center",
                fontsize=8,
                # White text on dark (high-value) cells for readability.
                color="white" if cm[i, j] > 0.5 else "black",
            )
    ax.set_title("Confusion Matrix (normalized)")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=150)
    plt.close()

    with open(os.path.join(output_dir, "metrics.json"), "w", encoding="utf-8") as file:
        json.dump(report, file, indent=2)
    print(f"Confusion matrix saved -> {output_dir}/confusion_matrix.png")
    print(f"Metrics JSON saved -> {output_dir}/metrics.json")
    return report
def main():
    """Run the full two-phase training pipeline from the command line."""
    parser = argparse.ArgumentParser(description="Train waste classifier")
    parser.add_argument("--data_dir", default="data/processed")
    parser.add_argument("--output_dir", default="models")
    parser.add_argument("--phase1_epochs", type=int, default=10)
    parser.add_argument("--phase2_epochs", type=int, default=20)
    parser.add_argument(
        "--balance_strategy",
        choices=["class_weight", "oversample"],
        default="class_weight",
    )
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    train_gen, val_gen, test_gen = build_generators(args.data_dir, args.balance_strategy)
    # Oversampling already balances the data, so skip class weights then.
    class_weights = None if args.balance_strategy == "oversample" else build_class_weights(train_gen)

    # Derive the head size from CLASS_NAMES instead of hard-coding 5.
    model = build_model(num_classes=len(CLASS_NAMES))
    model.summary()

    phase1(model, train_gen, val_gen, args.output_dir, args.phase1_epochs, class_weights)
    phase2(model, train_gen, val_gen, args.output_dir, args.phase2_epochs, class_weights)

    print("\nFinal evaluation on held-out test set")
    evaluate(model, test_gen, args.output_dir)

    saved_path = os.path.join(args.output_dir, "waste_classifier_v1")
    model.export(saved_path)
    print(f"\nSavedModel exported -> {saved_path}")


if __name__ == "__main__":
    main()