# Source: MilkSpoilageClassifier/scripts/prepare_model.py
# Uploaded by chenhaoq87 via huggingface_hub (revision 63603f7, verified)
"""
Prepare Milk Spoilage Classification Model for Hugging Face Deployment
This script:
1. Loads training data from CSV files
2. Trains a RandomForest model with tuned hyperparameters
3. Exports model artifacts (model.joblib, config.json, requirements.txt, README.md)
"""
import json
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
def load_and_prepare_data(train_path="data/train_df.csv", test_path="data/test_df.csv"):
    """Load and prepare training data from CSV files.

    Args:
        train_path: Path to the training CSV (defaults to the repo layout).
        test_path: Path to the held-out test CSV.

    Returns:
        Tuple of (X_train, y_train, X_test, y_test, feature_cols) where the
        X/y values are pandas objects and feature_cols is the ordered list
        of feature column names.
    """
    print("Loading training data...")
    train_df = pd.read_csv(train_path)
    test_df = pd.read_csv(test_path)
    # Select relevant columns and drop rows with missing values so the
    # model never sees NaNs at fit/predict time.
    feature_cols = ['SPC_D7', 'SPC_D14', 'SPC_D21', 'TGN_D7', 'TGN_D14', 'TGN_D21']
    target_col = 'spoilagetype'
    train_set = train_df[feature_cols + [target_col]].dropna()
    test_set = test_df[feature_cols + [target_col]].dropna()
    X_train = train_set[feature_cols]
    y_train = train_set[target_col]
    X_test = test_set[feature_cols]
    y_test = test_set[target_col]
    print(f"Training samples: {len(X_train)}")
    print(f"Test samples: {len(X_test)}")
    return X_train, y_train, X_test, y_test, feature_cols
def train_model(X_train, y_train):
    """Fit a RandomForest classifier with the tuned hyperparameters.

    The hyperparameter values were chosen by GridSearchCV in the
    accompanying notebook; they are fixed here for reproducibility.
    """
    print("\nTraining RandomForest model...")
    # Best hyperparameters from GridSearchCV in notebook
    hyperparams = {
        "n_estimators": 100,
        "max_depth": None,
        "min_samples_split": 5,
        "min_samples_leaf": 1,
        "random_state": 42,
    }
    classifier = RandomForestClassifier(**hyperparams)
    classifier.fit(X_train, y_train)
    print("Model training complete!")
    return classifier
def evaluate_model(model, X_test, y_test):
    """Report test-set accuracy and a per-class classification report.

    Returns the accuracy as a float so callers can embed it in artifacts.
    """
    print("\nEvaluating model on test set...")
    predictions = model.predict(X_test)
    test_accuracy = accuracy_score(y_test, predictions)
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_test, predictions, digits=4))
    return test_accuracy
def save_model(model, filepath="model/model.joblib"):
    """Persist the trained model to disk with joblib.

    Args:
        model: Fitted estimator to serialize.
        filepath: Destination path; parent directories are created as needed.
    """
    print(f"\nSaving model to {filepath}...")
    # Bug fix: joblib.dump does not create directories, so a fresh checkout
    # (no model/ folder yet) would crash with FileNotFoundError.
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, filepath)
    print("Model saved!")
def create_config(model, feature_cols, filepath="model/config.json"):
    """Create config.json with model metadata.

    Serializes the model type, feature/class documentation, and training
    hyperparameters so downstream consumers can interpret predictions
    without the training code.

    Args:
        model: Fitted classifier; its ``classes_`` attribute is recorded.
        feature_cols: Ordered list of feature column names.
        filepath: Destination path; parent directories are created as needed.
    """
    print(f"\nCreating {filepath}...")
    config = {
        "model_type": "RandomForestClassifier",
        "framework": "sklearn",
        "task": "classification",
        "features": feature_cols,
        "feature_descriptions": {
            "SPC_D7": "Standard Plate Count at Day 7 (log CFU/mL)",
            "SPC_D14": "Standard Plate Count at Day 14 (log CFU/mL)",
            "SPC_D21": "Standard Plate Count at Day 21 (log CFU/mL)",
            "TGN_D7": "Total Gram-Negative count at Day 7 (log CFU/mL)",
            "TGN_D14": "Total Gram-Negative count at Day 14 (log CFU/mL)",
            "TGN_D21": "Total Gram-Negative count at Day 21 (log CFU/mL)"
        },
        "classes": list(model.classes_),
        "class_descriptions": {
            "PPC": "Post-Pasteurization Contamination",
            "no spoilage": "No spoilage detected",
            "spore spoilage": "Spore-forming bacteria spoilage"
        },
        # Mirrors the hyperparameters used in train_model().
        "hyperparameters": {
            "n_estimators": 100,
            "max_depth": None,
            "min_samples_split": 5,
            "min_samples_leaf": 1,
            "random_state": 42
        }
    }
    # Bug fix: ensure the target directory exists before writing; a fresh
    # checkout has no model/ folder and open(..., 'w') would fail.
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    with open(filepath, 'w') as f:
        json.dump(config, f, indent=2)
    print("Config saved!")
def create_requirements(filepath="requirements.txt"):
    """Write the minimal requirements.txt needed for inference."""
    print(f"\nCreating {filepath}...")
    # One pinned-minimum line per runtime dependency.
    packages = [
        "scikit-learn>=1.0",
        "joblib>=1.0",
        "numpy>=1.20",
        "pandas>=1.3",
    ]
    with open(filepath, 'w') as f:
        f.write("\n".join(packages) + "\n")
    print("Requirements saved!")
def create_readme(model, accuracy, feature_cols, filepath="README.md"):
    """Create README.md model card.

    Writes a Hugging Face model card (YAML front matter + markdown) that
    documents the classes, input features, measured test accuracy, and
    usage examples.

    Args:
        model: Trained classifier (accepted for signature symmetry; the
            card text is static apart from accuracy).
        accuracy: Test-set accuracy, interpolated into the Performance section.
        feature_cols: Feature column names (not interpolated; the feature
            table below is written out literally).
        filepath: Destination path for the card.
    """
    print(f"\nCreating {filepath}...")
    # NOTE: this is an f-string, so literal braces in the embedded code
    # examples are escaped as {{ }}; only {accuracy:.2%} is interpolated.
    readme_content = f"""---
license: mit
library_name: sklearn
tags:
- sklearn
- classification
- random-forest
- food-science
- milk-quality
pipeline_tag: tabular-classification
---
# Milk Spoilage Classification Model
A Random Forest classifier for predicting milk spoilage type based on microbial count data.
## Model Description
This model classifies milk samples into three spoilage categories based on Standard Plate Count (SPC) and Total Gram-Negative (TGN) bacterial counts measured at days 7, 14, and 21 of shelf life.
### Classes
- **PPC**: Post-Pasteurization Contamination
- **no spoilage**: No spoilage detected
- **spore spoilage**: Spore-forming bacteria spoilage
### Input Features
| Feature | Description |
|---------|-------------|
| SPC_D7 | Standard Plate Count at Day 7 (log CFU/mL) |
| SPC_D14 | Standard Plate Count at Day 14 (log CFU/mL) |
| SPC_D21 | Standard Plate Count at Day 21 (log CFU/mL) |
| TGN_D7 | Total Gram-Negative count at Day 7 (log CFU/mL) |
| TGN_D14 | Total Gram-Negative count at Day 14 (log CFU/mL) |
| TGN_D21 | Total Gram-Negative count at Day 21 (log CFU/mL) |
## Performance
- **Test Accuracy**: {accuracy:.2%}
## Usage
### Using the Inference API
```python
import requests
API_URL = "https://api-inference.huggingface.co/models/chenhaoq87/MilkSpoilageClassifier"
headers = {{"Authorization": "Bearer YOUR_HF_TOKEN"}}
# Input: [SPC_D7, SPC_D14, SPC_D21, TGN_D7, TGN_D14, TGN_D21]
payload = {{"inputs": [[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]]}}
response = requests.post(API_URL, headers=headers, json=payload)
print(response.json())
```
### Local Usage
```python
import joblib
import numpy as np
# Load the model
model = joblib.load("model.joblib")
# Prepare input features
# [SPC_D7, SPC_D14, SPC_D21, TGN_D7, TGN_D14, TGN_D21]
features = np.array([[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]])
# Make prediction
prediction = model.predict(features)
probabilities = model.predict_proba(features)
print(f"Predicted class: {{prediction[0]}}")
print(f"Class probabilities: {{dict(zip(model.classes_, probabilities[0]))}}")
```
## Model Details
- **Model Type**: Random Forest Classifier
- **Framework**: scikit-learn
- **Number of Estimators**: 100
- **Max Depth**: None (unlimited)
- **Min Samples Split**: 5
- **Min Samples Leaf**: 1
## Citation
If you use this model, please cite the original research on milk spoilage classification.
## License
MIT License
"""
    with open(filepath, 'w') as f:
        f.write(readme_content)
    print("README saved!")
def main():
    """Build every deployment artifact: model, config, requirements, README."""
    banner = "=" * 60
    print(banner)
    print("Milk Spoilage Classification Model - Artifact Preparation")
    print(banner)
    # Pipeline: load -> train -> evaluate -> export artifacts.
    X_train, y_train, X_test, y_test, feature_cols = load_and_prepare_data()
    model = train_model(X_train, y_train)
    accuracy = evaluate_model(model, X_test, y_test)
    save_model(model)
    create_config(model, feature_cols)
    create_requirements()
    create_readme(model, accuracy, feature_cols)
    print("\n" + banner)
    print("All artifacts created successfully!")
    print(banner)
    print("\nGenerated files:")
    for artifact in ("model/model.joblib", "model/config.json",
                     "requirements.txt", "README.md"):
        print(f" - {artifact}")
    print("\nNext step: Run 'python scripts/upload_to_hf.py' to upload to Hugging Face")


if __name__ == "__main__":
    main()