# Source: MilkSpoilageClassifier/scripts/train_variants.py
# Commit 59b15f5 ("Add multi-variant FastAPI support (code and configs)")
"""
Train Multiple Model Variants for Milk Spoilage Classification
This script:
1. Loads training data from CSV files
2. Trains 10 RandomForest model variants with different feature subsets
3. Exports all model artifacts (*.joblib, variants_config.json)
"""
import json
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
import joblib
import os
from pathlib import Path
# Define all model variants with feature subsets.
# Each variant trains the same RandomForest on a different subset of the six
# microbiological measurements (SPC/TGN at days 7, 14, 21), simulating what
# a classifier can do when only some time points or assays are available.
MODEL_VARIANTS = {
    'baseline': {
        'name': 'Baseline (All Features)',
        'description': 'Uses all 6 microbiological measurements across all time points',
        'features': ['SPC_D7', 'SPC_D14', 'SPC_D21', 'TGN_D7', 'TGN_D14', 'TGN_D21']
    },
    'scenario_1_days14_21': {
        'name': 'Days 14 & 21',
        'description': 'Uses measurements from days 14 and 21 only',
        'features': ['SPC_D14', 'SPC_D21', 'TGN_D14', 'TGN_D21']
    },
    'scenario_2_days7_14': {
        'name': 'Days 7 & 14',
        'description': 'Uses measurements from days 7 and 14 only',
        'features': ['SPC_D7', 'SPC_D14', 'TGN_D7', 'TGN_D14']
    },
    'scenario_3_day21': {
        'name': 'Day 21 Only',
        'description': 'Uses only day 21 measurements',
        'features': ['SPC_D21', 'TGN_D21']
    },
    'scenario_4_day14': {
        'name': 'Day 14 Only',
        'description': 'Uses only day 14 measurements',
        'features': ['SPC_D14', 'TGN_D14']
    },
    'scenario_5_day7': {
        'name': 'Day 7 Only',
        'description': 'Uses only day 7 measurements',
        'features': ['SPC_D7', 'TGN_D7']
    },
    'scenario_6_spc_all': {
        'name': 'SPC Only (All Days)',
        'description': 'Uses only Standard Plate Count measurements across all days',
        'features': ['SPC_D7', 'SPC_D14', 'SPC_D21']
    },
    'scenario_7_tgn_all': {
        'name': 'TGN Only (All Days)',
        'description': 'Uses only Total Gram-Negative measurements across all days',
        'features': ['TGN_D7', 'TGN_D14', 'TGN_D21']
    },
    'scenario_8_spc_7_14': {
        'name': 'SPC Days 7 & 14',
        'description': 'Uses only SPC measurements from days 7 and 14',
        'features': ['SPC_D7', 'SPC_D14']
    },
    'scenario_9_tgn_7_14': {
        'name': 'TGN Days 7 & 14',
        'description': 'Uses only TGN measurements from days 7 and 14',
        'features': ['TGN_D7', 'TGN_D14']
    }
}
# Feature descriptions (constant across all variants); embedded verbatim in
# variants_config.json by create_variants_config().
FEATURE_DESCRIPTIONS = {
    "SPC_D7": "Standard Plate Count at Day 7 (log CFU/mL)",
    "SPC_D14": "Standard Plate Count at Day 14 (log CFU/mL)",
    "SPC_D21": "Standard Plate Count at Day 21 (log CFU/mL)",
    "TGN_D7": "Total Gram-Negative count at Day 7 (log CFU/mL)",
    "TGN_D14": "Total Gram-Negative count at Day 14 (log CFU/mL)",
    "TGN_D21": "Total Gram-Negative count at Day 21 (log CFU/mL)"
}
# Class descriptions (constant across all variants); keys must match the
# values of the 'spoilagetype' target column in the training CSVs.
CLASS_DESCRIPTIONS = {
    "PPC": "Post-Pasteurization Contamination",
    "no spoilage": "No spoilage detected",
    "spore spoilage": "Spore-forming bacteria spoilage"
}
def load_data():
    """Load training and test data from CSV files.

    Looks for ``train_df.csv`` and ``test_df.csv`` in ``../data`` relative
    to this script, falling back to ``./data`` relative to the current
    working directory.

    Returns:
        tuple: (train_df, test_df) as pandas DataFrames.

    Raises:
        FileNotFoundError: if either CSV file cannot be found, with a
            message naming the missing path (instead of letting pandas
            raise a less helpful error later).
    """
    print("Loading data...")
    # Adjust path based on where script is run from
    data_dir = Path(__file__).parent.parent / "data"
    if not data_dir.exists():
        # Try alternate path (e.g. script invoked from the project root)
        data_dir = Path("data")
    train_path = data_dir / "train_df.csv"
    test_path = data_dir / "test_df.csv"
    # Fail early with a clear, actionable message if the data is missing
    for path in (train_path, test_path):
        if not path.exists():
            raise FileNotFoundError(
                f"Required data file not found: {path}. "
                "Run this script from the project root or scripts/ directory."
            )
    train_df = pd.read_csv(train_path)
    test_df = pd.read_csv(test_path)
    print(f"✓ Loaded {len(train_df)} training samples and {len(test_df)} test samples")
    return train_df, test_df
def train_model(X_train, y_train):
    """Fit and return a RandomForest classifier on the given data.

    The hyperparameters below were selected via GridSearchCV in the
    original analysis notebook; random_state is pinned so retraining
    reproduces the same forest.
    """
    tuned_hyperparameters = {
        'n_estimators': 100,
        'max_depth': None,
        'min_samples_split': 5,
        'min_samples_leaf': 1,
        'random_state': 42,
    }
    classifier = RandomForestClassifier(**tuned_hyperparameters)
    classifier.fit(X_train, y_train)
    return classifier
def evaluate_model(model, X_test, y_test):
    """Evaluate model performance on the test set.

    Args:
        model: fitted classifier exposing ``predict`` and ``classes_``.
        X_test: test feature matrix.
        y_test: true labels for X_test.

    Returns:
        tuple: (accuracy, report) where ``report`` is the dict form of
        sklearn's classification report, containing one entry per class
        the model was trained on.
    """
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    # Pass labels explicitly: without it, classification_report omits any
    # class absent from both y_test and y_pred, so a rare class missing
    # from the test split would later raise KeyError when callers index
    # the report by model.classes_. zero_division=0 silences the warning
    # (and returns 0.0) for such empty classes.
    report = classification_report(
        y_test,
        y_pred,
        labels=model.classes_,
        output_dict=True,
        zero_division=0,
    )
    return accuracy, report
def train_all_variants(train_df, test_df, output_dir="model/variants"):
    """Train all model variants and save artifacts.

    For each entry in MODEL_VARIANTS: subsets the data to the variant's
    feature columns, drops incomplete rows, trains a RandomForest, evaluates
    it on the test split, and dumps the fitted model to
    ``<output_dir>/<variant_id>.joblib``.

    Args:
        train_df: training DataFrame with all feature columns + target.
        test_df: test DataFrame with the same columns.
        output_dir: directory where .joblib artifacts are written
            (created if missing).

    Returns:
        dict: per-variant metadata (accuracies, sample counts, classes,
        per-class precision/recall/f1/support) keyed by variant id.
    """
    # Create output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    target_col = 'spoilagetype'
    variants_metadata = {}
    print("\n" + "=" * 70)
    print("Training All Model Variants")
    print("=" * 70)
    for variant_id, variant_config in MODEL_VARIANTS.items():
        print(f"\n{variant_id}")
        print(f" Name: {variant_config['name']}")
        print(f" Features: {', '.join(variant_config['features'])}")
        # Prepare data for this variant. dropna() is applied per feature
        # subset, so each variant may see a different number of samples.
        features = variant_config['features']
        train_set = train_df[features + [target_col]].dropna()
        test_set = test_df[features + [target_col]].dropna()
        X_train = train_set[features]
        y_train = train_set[target_col]
        X_test = test_set[features]
        y_test = test_set[target_col]
        print(f" Training samples: {len(X_train)}, Test samples: {len(X_test)}")
        # Train model
        model = train_model(X_train, y_train)
        # Evaluate
        test_accuracy, test_report = evaluate_model(model, X_test, y_test)
        train_accuracy = accuracy_score(y_train, model.predict(X_train))
        print(f" Train accuracy: {train_accuracy:.4f}")
        print(f" Test accuracy: {test_accuracy:.4f}")
        # Save model
        model_path = output_path / f"{variant_id}.joblib"
        joblib.dump(model, model_path)
        print(f" ✓ Saved to {model_path}")
        # Store metadata. Numpy scalars are cast to builtins so the dict is
        # JSON-serializable by create_variants_config().
        variants_metadata[variant_id] = {
            'name': variant_config['name'],
            'description': variant_config['description'],
            'features': features,
            'train_accuracy': float(train_accuracy),
            'test_accuracy': float(test_accuracy),
            'n_train_samples': len(X_train),
            'n_test_samples': len(X_test),
            'classes': list(model.classes_),
            # Guard with `in test_report`: classification_report (without an
            # explicit labels= argument) omits classes absent from both
            # y_test and y_pred, so indexing unconditionally by
            # model.classes_ could raise KeyError for a rare class missing
            # from the test split.
            'class_metrics': {
                cls: {
                    'precision': float(test_report[cls]['precision']),
                    'recall': float(test_report[cls]['recall']),
                    'f1-score': float(test_report[cls]['f1-score']),
                    'support': int(test_report[cls]['support'])
                }
                for cls in model.classes_
                if cls in test_report
            }
        }
    return variants_metadata
def create_variants_config(variants_metadata, output_dir="model/variants"):
    """Write variants_config.json describing every trained variant.

    Combines the shared model settings (type, framework, hyperparameters,
    feature/class descriptions) with the per-variant metadata produced by
    training, and serializes the result to ``<output_dir>/variants_config.json``.

    Returns:
        dict: the config object that was written to disk.
    """
    # Hyperparameters shared by every variant (must mirror train_model)
    shared_hyperparameters = {
        'n_estimators': 100,
        'max_depth': None,
        'min_samples_split': 5,
        'min_samples_leaf': 1,
        'random_state': 42,
    }
    config = {
        'model_type': 'RandomForestClassifier',
        'framework': 'sklearn',
        'task': 'classification',
        'hyperparameters': shared_hyperparameters,
        'feature_descriptions': FEATURE_DESCRIPTIONS,
        'class_descriptions': CLASS_DESCRIPTIONS,
        'variants': variants_metadata,
    }
    config_path = Path(output_dir) / "variants_config.json"
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"\n✓ Variants config saved to {config_path}")
    return config
def print_summary(variants_metadata):
    """Print a ranked summary table of all trained variants.

    Variants are listed in descending test-accuracy order; the top three
    rows are decorated with medal emoji.
    """
    divider = "=" * 70
    print("\n" + divider)
    print("Training Summary")
    print(divider)
    # Rank variants best-to-worst by held-out accuracy
    ranking = sorted(
        variants_metadata.items(),
        key=lambda item: item[1]['test_accuracy'],
        reverse=True,
    )
    print(f"\n{'Rank':<6} {'Variant':<30} {'Test Acc':<12} {'Features'}")
    print("-" * 70)
    medals = ['🥇', '🥈', '🥉']
    for rank, (variant_id, metadata) in enumerate(ranking, 1):
        medal = medals[rank - 1] if rank <= 3 else ' '
        # Show at most the first two feature names, eliding the rest
        feature_list = metadata['features']
        features_str = ', '.join(feature_list[:2])
        if len(feature_list) > 2:
            features_str += '...'
        print(f"{medal} {rank:<4} {variant_id:<30} {metadata['test_accuracy']:.4f} {features_str}")
    print("\n" + divider)
def main():
    """Entry point: train every variant, write artifacts, and summarize."""
    banner = "=" * 70
    print(banner)
    print("Milk Spoilage Classification - Multi-Variant Training")
    print(banner)
    # Load the train/test CSVs
    train_df, test_df = load_data()
    # Train every variant and collect its metadata
    variants_metadata = train_all_variants(train_df, test_df)
    # Persist the consolidated configuration file
    create_variants_config(variants_metadata)
    # Show the ranked results table
    print_summary(variants_metadata)
    print("\n✓ All model variants trained successfully!")
    print(f"\nGenerated files:")
    print(f" - model/variants/*.joblib (10 model files)")
    print(f" - model/variants/variants_config.json")
    print(f"\nNext step: Update API to load all variants")
# Allow the module to be imported (e.g. by the API) without side effects;
# training runs only when executed as a script.
if __name__ == "__main__":
    main()