|
|
""" |
|
|
Train Multiple Model Variants for Milk Spoilage Classification

This script:
1. Loads training and test data from CSV files
2. Trains 10 RandomForest model variants with different feature subsets
3. Evaluates each variant on the test set
4. Exports all model artifacts (*.joblib, variants_config.json)
|
|
""" |
|
|
|
|
|
import json |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
from sklearn.ensemble import RandomForestClassifier |
|
|
from sklearn.metrics import accuracy_score, classification_report |
|
|
import joblib |
|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
# Registry of model variants to train, keyed by variant id.
# Each entry carries a display name, a one-line description, and the list of
# feature columns that the variant's model is fitted on.
MODEL_VARIANTS = {
    "baseline": {
        "name": "Baseline (All Features)",
        "description": "Uses all 6 microbiological measurements across all time points",
        "features": ["SPC_D7", "SPC_D14", "SPC_D21", "TGN_D7", "TGN_D14", "TGN_D21"],
    },
    "scenario_1_days14_21": {
        "name": "Days 14 & 21",
        "description": "Uses measurements from days 14 and 21 only",
        "features": ["SPC_D14", "SPC_D21", "TGN_D14", "TGN_D21"],
    },
    "scenario_2_days7_14": {
        "name": "Days 7 & 14",
        "description": "Uses measurements from days 7 and 14 only",
        "features": ["SPC_D7", "SPC_D14", "TGN_D7", "TGN_D14"],
    },
    "scenario_3_day21": {
        "name": "Day 21 Only",
        "description": "Uses only day 21 measurements",
        "features": ["SPC_D21", "TGN_D21"],
    },
    "scenario_4_day14": {
        "name": "Day 14 Only",
        "description": "Uses only day 14 measurements",
        "features": ["SPC_D14", "TGN_D14"],
    },
    "scenario_5_day7": {
        "name": "Day 7 Only",
        "description": "Uses only day 7 measurements",
        "features": ["SPC_D7", "TGN_D7"],
    },
    "scenario_6_spc_all": {
        "name": "SPC Only (All Days)",
        "description": "Uses only Standard Plate Count measurements across all days",
        "features": ["SPC_D7", "SPC_D14", "SPC_D21"],
    },
    "scenario_7_tgn_all": {
        "name": "TGN Only (All Days)",
        "description": "Uses only Total Gram-Negative measurements across all days",
        "features": ["TGN_D7", "TGN_D14", "TGN_D21"],
    },
    "scenario_8_spc_7_14": {
        "name": "SPC Days 7 & 14",
        "description": "Uses only SPC measurements from days 7 and 14",
        "features": ["SPC_D7", "SPC_D14"],
    },
    "scenario_9_tgn_7_14": {
        "name": "TGN Days 7 & 14",
        "description": "Uses only TGN measurements from days 7 and 14",
        "features": ["TGN_D7", "TGN_D14"],
    },
}
|
|
|
|
|
|
|
|
# Human-readable label for every feature column. Keys follow the pattern
# "<measurement>_D<day>"; entries are generated measurement-by-measurement,
# then day-by-day, so the order is SPC_D7, SPC_D14, SPC_D21, TGN_D7, ...
FEATURE_DESCRIPTIONS = {
    f"{prefix}_D{day}": f"{label} at Day {day} (log CFU/mL)"
    for prefix, label in (
        ("SPC", "Standard Plate Count"),
        ("TGN", "Total Gram-Negative count"),
    )
    for day in (7, 14, 21)
}
|
|
|
|
|
|
|
|
# Human-readable explanation of each target class label, keyed by the label
# string itself.
CLASS_DESCRIPTIONS = {
    "PPC": "Post-Pasteurization Contamination",
    "no spoilage": "No spoilage detected",
    "spore spoilage": "Spore-forming bacteria spoilage",
}
|
|
|
|
|
|
|
|
def load_data():
    """Read the training and test splits from CSV.

    The ``data`` directory is looked up two levels above this file first; if
    that directory does not exist, ``data`` relative to the current working
    directory is used instead.

    Returns:
        A ``(train_df, test_df)`` tuple of DataFrames read from
        ``train_df.csv`` and ``test_df.csv``.
    """
    print("Loading data...")

    preferred_dir = Path(__file__).parent.parent / "data"
    data_dir = preferred_dir if preferred_dir.exists() else Path("data")

    # Training file is read first, then the test file.
    train_df, test_df = (
        pd.read_csv(data_dir / csv_name)
        for csv_name in ("train_df.csv", "test_df.csv")
    )

    print(f"✓ Loaded {len(train_df)} training samples and {len(test_df)} test samples")

    return train_df, test_df
|
|
|
|
|
|
|
|
def train_model(X_train, y_train):
    """Fit a RandomForestClassifier on the given features and labels.

    The hyperparameters are fixed; per the original note they are the best
    values found in the accompanying notebook.

    Args:
        X_train: Feature matrix passed to ``fit``.
        y_train: Target labels passed to ``fit``.

    Returns:
        The fitted RandomForestClassifier.
    """
    hyperparameters = {
        "n_estimators": 100,
        "max_depth": None,
        "min_samples_split": 5,
        "min_samples_leaf": 1,
        "random_state": 42,  # fixed seed for reproducible forests
    }

    # ``fit`` returns the estimator itself, so the fitted model is returned.
    return RandomForestClassifier(**hyperparameters).fit(X_train, y_train)
|
|
|
|
|
|
|
|
def evaluate_model(model, X_test, y_test):
    """Score a fitted model against held-out data.

    Args:
        model: Fitted estimator exposing ``predict``.
        X_test: Feature matrix to predict on.
        y_test: True labels for ``X_test``.

    Returns:
        A ``(accuracy, report)`` tuple: the accuracy score and the
        classification report in dictionary form.
    """
    predictions = model.predict(X_test)

    return (
        accuracy_score(y_test, predictions),
        classification_report(y_test, predictions, output_dict=True),
    )
|
|
|
|
|
|
|
|
def _class_metrics(report, classes):
    """Collect per-class precision/recall/F1/support from a report dict.

    ``classification_report`` only emits rows for labels that occur in the
    true or predicted test labels, and keys each row by ``str(label)``. A
    class the model learned from the training set can therefore be missing
    from the report; such classes get zeroed metrics instead of a KeyError.

    Args:
        report: Dictionary produced by ``classification_report`` with
            ``output_dict=True``.
        classes: Iterable of class labels known to the model.

    Returns:
        Dict mapping each class label to its metric dict.
    """
    absent = {'precision': 0.0, 'recall': 0.0, 'f1-score': 0.0, 'support': 0}
    metrics = {}
    for cls in classes:
        row = report.get(str(cls), absent)
        metrics[cls] = {
            'precision': float(row['precision']),
            'recall': float(row['recall']),
            'f1-score': float(row['f1-score']),
            'support': int(row['support']),
        }
    return metrics


def train_all_variants(train_df, test_df, output_dir="model/variants"):
    """Train all model variants and save artifacts.

    For every entry in ``MODEL_VARIANTS`` this selects the variant's feature
    columns plus the target column, drops rows with missing values, fits a
    model, evaluates it on the test split, and dumps the model to
    ``<output_dir>/<variant_id>.joblib``.

    Args:
        train_df: Training DataFrame containing the feature columns and the
            ``spoilagetype`` target column.
        test_df: Test DataFrame with the same columns.
        output_dir: Directory for the ``.joblib`` files; created if missing.

    Returns:
        Dict mapping variant id to its metadata (name, description, features,
        train/test accuracy, sample counts, classes and per-class metrics).
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    target_col = 'spoilagetype'
    variants_metadata = {}

    print("\n" + "=" * 70)
    print("Training All Model Variants")
    print("=" * 70)

    for variant_id, variant_config in MODEL_VARIANTS.items():
        print(f"\n{variant_id}")
        print(f" Name: {variant_config['name']}")
        print(f" Features: {', '.join(variant_config['features'])}")

        # Rows with a missing value in any used column are dropped per
        # variant, so sample counts can differ between variants.
        features = variant_config['features']
        columns = features + [target_col]
        train_set = train_df[columns].dropna()
        test_set = test_df[columns].dropna()

        X_train = train_set[features]
        y_train = train_set[target_col]
        X_test = test_set[features]
        y_test = test_set[target_col]

        print(f" Training samples: {len(X_train)}, Test samples: {len(X_test)}")

        model = train_model(X_train, y_train)

        test_accuracy, test_report = evaluate_model(model, X_test, y_test)
        train_accuracy = accuracy_score(y_train, model.predict(X_train))

        print(f" Train accuracy: {train_accuracy:.4f}")
        print(f" Test accuracy: {test_accuracy:.4f}")

        model_path = output_path / f"{variant_id}.joblib"
        joblib.dump(model, model_path)
        print(f" ✓ Saved to {model_path}")

        # tolist() yields plain Python values so the metadata stays
        # JSON-serializable regardless of the label dtype.
        classes = model.classes_.tolist()

        variants_metadata[variant_id] = {
            'name': variant_config['name'],
            'description': variant_config['description'],
            'features': features,
            'train_accuracy': float(train_accuracy),
            'test_accuracy': float(test_accuracy),
            'n_train_samples': len(X_train),
            'n_test_samples': len(X_test),
            'classes': classes,
            'class_metrics': _class_metrics(test_report, classes),
        }

    return variants_metadata
|
|
|
|
|
|
|
|
def create_variants_config(variants_metadata, output_dir="model/variants"):
    """Create comprehensive config file for all variants.

    Writes ``variants_config.json`` into ``output_dir``. The file combines
    static model information, the feature and class descriptions, and the
    per-variant metadata.

    Args:
        variants_metadata: Dict of per-variant metadata; must be
            JSON-serializable.
        output_dir: Target directory. It is created if it does not exist, so
            this function also works when called on its own.

    Returns:
        The config dictionary that was written to disk.
    """
    config = {
        'model_type': 'RandomForestClassifier',
        'framework': 'sklearn',
        'task': 'classification',
        # NOTE: duplicates the values hard-coded in train_model(); keep the
        # two in sync when tuning.
        'hyperparameters': {
            'n_estimators': 100,
            'max_depth': None,
            'min_samples_split': 5,
            'min_samples_leaf': 1,
            'random_state': 42,
        },
        'feature_descriptions': FEATURE_DESCRIPTIONS,
        'class_descriptions': CLASS_DESCRIPTIONS,
        'variants': variants_metadata,
    }

    # Do not rely on a previous step having created the directory.
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    config_path = output_path / "variants_config.json"
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    print(f"\n✓ Variants config saved to {config_path}")

    return config
|
|
|
|
|
|
|
|
def print_summary(variants_metadata):
    """Print a ranking table of the trained variants.

    Variants are listed by descending test accuracy; ties keep their
    insertion order. The first three rows get a medal, and only the first two
    feature names are shown (followed by ``...`` when there are more).

    Args:
        variants_metadata: Dict mapping variant id to a metadata dict with at
            least ``test_accuracy`` and ``features``.
    """
    rule = "=" * 70

    print("\n" + rule)
    print("Training Summary")
    print(rule)

    ranking = sorted(
        variants_metadata.items(),
        key=lambda entry: entry[1]['test_accuracy'],
        reverse=True,
    )

    print(f"\n{'Rank':<6} {'Variant':<30} {'Test Acc':<12} {'Features'}")
    print("-" * 70)

    medals = ('🥇', '🥈', '🥉')
    for position, (variant_id, info) in enumerate(ranking):
        rank = position + 1
        medal = medals[position] if position < len(medals) else ' '

        feature_names = info['features']
        features_str = ', '.join(feature_names[:2])
        if len(feature_names) > 2:
            features_str += '...'

        print(f"{medal} {rank:<4} {variant_id:<30} {info['test_accuracy']:.4f} {features_str}")

    print("\n" + rule)
|
|
|
|
|
|
|
|
def main():
    """Main function to train all model variants.

    Runs the pipeline end to end: load the CSV data, train and save every
    variant, write the combined config file, print the ranking table, and
    finally list the generated files.
    """
    print("=" * 70)
    print("Milk Spoilage Classification - Multi-Variant Training")
    print("=" * 70)

    train_df, test_df = load_data()

    variants_metadata = train_all_variants(train_df, test_df)

    create_variants_config(variants_metadata)

    print_summary(variants_metadata)

    print("\n✓ All model variants trained successfully!")
    print("\nGenerated files:")
    # Report the number of models actually trained rather than a hard-coded
    # count that would go stale when variants are added or removed.
    print(f" - model/variants/*.joblib ({len(variants_metadata)} model files)")
    print(" - model/variants/variants_config.json")
    print("\nNext step: Update API to load all variants")
|
|
|
|
|
|
|
|
# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|