|
|
"""
|
|
|
Prepare Milk Spoilage Classification Model for Hugging Face Deployment
|
|
|
|
|
|
This script:
|
|
|
1. Loads training data from CSV files
|
|
|
2. Trains a RandomForest model with tuned hyperparameters
|
|
|
3. Exports model artifacts (model.joblib, config.json, requirements.txt, README.md)
|
|
|
"""
|
|
|
|
|
|
import json
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
|
|
def load_and_prepare_data(
    train_path="data/train_df.csv",
    test_path="data/test_df.csv",
):
    """Load the train/test CSV files and split them into features and target.

    Only the six microbial-count feature columns and the ``spoilagetype``
    target column are kept. Rows with a missing value in any of those
    columns are dropped, independently for each split.

    Args:
        train_path: Path of the training CSV. Defaults to the location the
            script originally read from.
        test_path: Path of the test CSV. Defaults to the location the
            script originally read from.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test, feature_cols)`` where the
        ``X_*`` values are DataFrames restricted to ``feature_cols``, the
        ``y_*`` values are the matching target Series, and ``feature_cols``
        is the ordered list of feature column names.

    Raises:
        KeyError: If either CSV lacks one of the required columns.
    """
    print("Loading training data...")
    train_df = pd.read_csv(train_path)
    test_df = pd.read_csv(test_path)

    feature_cols = ['SPC_D7', 'SPC_D14', 'SPC_D21', 'TGN_D7', 'TGN_D14', 'TGN_D21']
    target_col = 'spoilagetype'
    required_cols = feature_cols + [target_col]

    # Check the schema up front so a bad file is reported by name instead of
    # surfacing as an anonymous pandas indexing error.
    for path, frame in ((train_path, train_df), (test_path, test_df)):
        missing = [col for col in required_cols if col not in frame.columns]
        if missing:
            raise KeyError(f"{path} is missing required column(s): {missing}")

    train_set = train_df[required_cols].dropna()
    test_set = test_df[required_cols].dropna()

    X_train = train_set[feature_cols]
    y_train = train_set[target_col]
    X_test = test_set[feature_cols]
    y_test = test_set[target_col]

    print(f"Training samples: {len(X_train)}")
    print(f"Test samples: {len(X_test)}")

    return X_train, y_train, X_test, y_test, feature_cols
|
|
|
|
|
|
|
|
|
def train_model(X_train, y_train):
    """Fit a RandomForest classifier with the notebook's best hyperparameters.

    Args:
        X_train: Training feature matrix.
        y_train: Training target labels.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    print("\nTraining RandomForest model...")

    # Fixed settings; random_state is pinned so repeated runs build the
    # same forest.
    hyperparameters = {
        "n_estimators": 100,
        "max_depth": None,
        "min_samples_split": 5,
        "min_samples_leaf": 1,
        "random_state": 42,
    }
    classifier = RandomForestClassifier(**hyperparameters)
    classifier.fit(X_train, y_train)

    print("Model training complete!")
    return classifier
|
|
|
|
|
|
|
|
|
def evaluate_model(model, X_test, y_test):
    """Score the model on the test split and print a summary.

    Prints the overall accuracy followed by the per-class classification
    report (four decimal places).

    Args:
        model: Fitted estimator exposing ``predict``.
        X_test: Test feature matrix.
        y_test: True labels for ``X_test``.

    Returns:
        The test-set accuracy as returned by ``accuracy_score``.
    """
    print("\nEvaluating model on test set...")
    predicted_labels = model.predict(X_test)

    test_accuracy = accuracy_score(y_test, predicted_labels)
    print(f"Test Accuracy: {test_accuracy:.4f}")

    print("\nClassification Report:")
    per_class_report = classification_report(y_test, predicted_labels, digits=4)
    print(per_class_report)

    return test_accuracy
|
|
|
|
|
|
|
|
|
def save_model(model, filepath="model/model.joblib"):
    """Serialize the trained model to disk with joblib.

    Args:
        model: The fitted estimator to persist.
        filepath: Destination file. Missing parent directories are created.
    """
    print(f"\nSaving model to {filepath}...")

    # The default target lives in "model/", which does not exist in a fresh
    # checkout; without this the dump fails with FileNotFoundError.
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    joblib.dump(model, filepath)
    print("Model saved!")
|
|
|
|
|
|
|
|
|
def create_config(model, feature_cols, filepath="model/config.json"):
    """Write config.json describing the model, its inputs and its classes.

    Args:
        model: Fitted estimator. Its ``classes_`` attribute and
            ``get_params()`` result are recorded in the config.
        feature_cols: Ordered feature column names the model was trained on.
        filepath: Destination JSON file. Missing parent directories are
            created.

    Raises:
        KeyError: If ``model.get_params()`` lacks one of the recorded
            hyperparameters.
    """
    print(f"\nCreating {filepath}...")

    # Read the hyperparameters back from the estimator itself so the config
    # cannot drift from the values the model was actually built with.
    estimator_params = model.get_params()
    recorded_params = (
        "n_estimators",
        "max_depth",
        "min_samples_split",
        "min_samples_leaf",
        "random_state",
    )
    hyperparameters = {name: estimator_params[name] for name in recorded_params}

    config = {
        "model_type": "RandomForestClassifier",
        "framework": "sklearn",
        "task": "classification",
        "features": list(feature_cols),
        "feature_descriptions": {
            "SPC_D7": "Standard Plate Count at Day 7 (log CFU/mL)",
            "SPC_D14": "Standard Plate Count at Day 14 (log CFU/mL)",
            "SPC_D21": "Standard Plate Count at Day 21 (log CFU/mL)",
            "TGN_D7": "Total Gram-Negative count at Day 7 (log CFU/mL)",
            "TGN_D14": "Total Gram-Negative count at Day 14 (log CFU/mL)",
            "TGN_D21": "Total Gram-Negative count at Day 21 (log CFU/mL)",
        },
        # tolist() converts numpy scalars to native Python values; json
        # cannot serialize e.g. numpy.int64 class labels.
        "classes": np.asarray(model.classes_).tolist(),
        "class_descriptions": {
            "PPC": "Post-Pasteurization Contamination",
            "no spoilage": "No spoilage detected",
            "spore spoilage": "Spore-forming bacteria spoilage",
        },
        "hyperparameters": hyperparameters,
    }

    # The default target lives in "model/", which may not exist yet.
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)
    with open(filepath, 'w', encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    print("Config saved!")
|
|
|
|
|
|
|
|
|
def create_requirements(filepath="requirements.txt"):
    """Write the requirements.txt listing the packages needed for inference.

    Args:
        filepath: Destination file for the requirement specifiers.
    """
    print(f"\nCreating {filepath}...")

    minimum_versions = (
        "scikit-learn>=1.0",
        "joblib>=1.0",
        "numpy>=1.20",
        "pandas>=1.3",
    )
    # One specifier per line, each terminated by a newline.
    file_body = "".join(f"{spec}\n" for spec in minimum_versions)

    with open(filepath, 'w') as handle:
        handle.write(file_body)

    print("Requirements saved!")
|
|
|
|
|
|
|
|
|
def create_readme(model, accuracy, feature_cols, filepath="README.md"):
    """Write the README.md model card.

    The class list, the input-feature table and the "Model Details"
    hyperparameters are generated from ``model`` and ``feature_cols`` so the
    card cannot contradict the model it describes. The narrative text and
    the sample input vector in the usage examples are static.

    Args:
        model: Fitted estimator. ``classes_`` supplies the class list and
            ``get_params()`` supplies the hyperparameters.
        accuracy: Test accuracy as a fraction (rendered as a percentage).
        feature_cols: Ordered feature column names the model expects.
        filepath: Destination markdown file.

    Raises:
        KeyError: If ``model.get_params()`` lacks one of the reported
            hyperparameters.
    """
    print(f"\nCreating {filepath}...")

    feature_descriptions = {
        "SPC_D7": "Standard Plate Count at Day 7 (log CFU/mL)",
        "SPC_D14": "Standard Plate Count at Day 14 (log CFU/mL)",
        "SPC_D21": "Standard Plate Count at Day 21 (log CFU/mL)",
        "TGN_D7": "Total Gram-Negative count at Day 7 (log CFU/mL)",
        "TGN_D14": "Total Gram-Negative count at Day 14 (log CFU/mL)",
        "TGN_D21": "Total Gram-Negative count at Day 21 (log CFU/mL)",
    }
    class_descriptions = {
        "PPC": "Post-Pasteurization Contamination",
        "no spoilage": "No spoilage detected",
        "spore spoilage": "Spore-forming bacteria spoilage",
    }
    fallback = "No description available"

    class_lines = "\n".join(
        f"- **{label}**: {class_descriptions.get(str(label), fallback)}"
        for label in model.classes_
    )
    feature_rows = "\n".join(
        f"| {name} | {feature_descriptions.get(name, fallback)} |"
        for name in feature_cols
    )
    feature_order = ", ".join(str(name) for name in feature_cols)

    params = model.get_params()
    n_estimators = params["n_estimators"]
    min_samples_split = params["min_samples_split"]
    min_samples_leaf = params["min_samples_leaf"]
    # max_depth=None means the trees are grown without a depth limit.
    max_depth = params["max_depth"]
    max_depth_text = "None (unlimited)" if max_depth is None else str(max_depth)

    # Doubled braces below are literal braces in the generated markdown.
    readme_content = f"""---
license: mit
library_name: sklearn
tags:
- sklearn
- classification
- random-forest
- food-science
- milk-quality
pipeline_tag: tabular-classification
---

# Milk Spoilage Classification Model

A Random Forest classifier for predicting milk spoilage type based on microbial count data.

## Model Description

This model classifies milk samples into three spoilage categories based on Standard Plate Count (SPC) and Total Gram-Negative (TGN) bacterial counts measured at days 7, 14, and 21 of shelf life.

### Classes

{class_lines}

### Input Features

| Feature | Description |
|---------|-------------|
{feature_rows}

## Performance

- **Test Accuracy**: {accuracy:.2%}

## Usage

### Using the Inference API

```python
import requests

API_URL = "https://api-inference.huggingface.co/models/chenhaoq87/MilkSpoilageClassifier"
headers = {{"Authorization": "Bearer YOUR_HF_TOKEN"}}

# Input: [{feature_order}]
payload = {{"inputs": [[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]]}}

response = requests.post(API_URL, headers=headers, json=payload)
print(response.json())
```

### Local Usage

```python
import joblib
import numpy as np

# Load the model
model = joblib.load("model.joblib")

# Prepare input features
# [{feature_order}]
features = np.array([[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]])

# Make prediction
prediction = model.predict(features)
probabilities = model.predict_proba(features)

print(f"Predicted class: {{prediction[0]}}")
print(f"Class probabilities: {{dict(zip(model.classes_, probabilities[0]))}}")
```

## Model Details

- **Model Type**: Random Forest Classifier
- **Framework**: scikit-learn
- **Number of Estimators**: {n_estimators}
- **Max Depth**: {max_depth_text}
- **Min Samples Split**: {min_samples_split}
- **Min Samples Leaf**: {min_samples_leaf}

## Citation

If you use this model, please cite the original research on milk spoilage classification.

## License

MIT License
"""

    with open(filepath, 'w', encoding="utf-8") as f:
        f.write(readme_content)

    print("README saved!")
|
|
|
|
|
|
|
|
|
def main():
    """Run the whole pipeline: load data, train, evaluate, write artifacts."""
    banner = "=" * 60
    print(banner)
    print("Milk Spoilage Classification Model - Artifact Preparation")
    print(banner)

    # Data -> model -> held-out accuracy.
    X_train, y_train, X_test, y_test, feature_cols = load_and_prepare_data()
    model = train_model(X_train, y_train)
    accuracy = evaluate_model(model, X_test, y_test)

    # Write every deployment artifact to its default location.
    save_model(model)
    create_config(model, feature_cols)
    create_requirements()
    create_readme(model, accuracy, feature_cols)

    print("\n" + banner)
    print("All artifacts created successfully!")
    print(banner)

    print("\nGenerated files:")
    generated_files = (
        "model/model.joblib",
        "model/config.json",
        "requirements.txt",
        "README.md",
    )
    for artifact in generated_files:
        print(f" - {artifact}")

    print("\nNext step: Run 'python scripts/upload_to_hf.py' to upload to Hugging Face")
|
|
|
|
|
|
|
|
|
# Run the pipeline only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|
|
|
|