chenhaoq87 commited on
Commit
b6193b3
·
verified ·
1 Parent(s): 79ac555

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. .DS_Store +0 -0
  2. README.md +92 -0
  3. app.py +210 -0
  4. config.json +38 -0
  5. handler.py +160 -0
  6. model.joblib +3 -0
  7. requirements.txt +4 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
README.md CHANGED
@@ -1,3 +1,95 @@
1
  ---
2
  license: mit
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
+ library_name: sklearn
4
+ tags:
5
+ - sklearn
6
+ - classification
7
+ - random-forest
8
+ - food-science
9
+ - milk-quality
10
+ pipeline_tag: tabular-classification
11
  ---
12
+
13
+ # Milk Spoilage Classification Model
14
+
15
+ A Random Forest classifier for predicting milk spoilage type based on microbial count data.
16
+
17
+ ## Model Description
18
+
19
+ This model classifies milk samples into three spoilage categories based on Standard Plate Count (SPC) and Total Gram-Negative (TGN) bacterial counts measured at days 7, 14, and 21 of shelf life.
20
+
21
+ ### Classes
22
+
23
+ - **PPC**: Post-Pasteurization Contamination
24
+ - **no spoilage**: No spoilage detected
25
+ - **spore spoilage**: Spore-forming bacteria spoilage
26
+
27
+ ### Input Features
28
+
29
+ | Feature | Description |
30
+ |---------|-------------|
31
+ | SPC_D7 | Standard Plate Count at Day 7 (log CFU/mL) |
32
+ | SPC_D14 | Standard Plate Count at Day 14 (log CFU/mL) |
33
+ | SPC_D21 | Standard Plate Count at Day 21 (log CFU/mL) |
34
+ | TGN_D7 | Total Gram-Negative count at Day 7 (log CFU/mL) |
35
+ | TGN_D14 | Total Gram-Negative count at Day 14 (log CFU/mL) |
36
+ | TGN_D21 | Total Gram-Negative count at Day 21 (log CFU/mL) |
37
+
38
+ ## Performance
39
+
40
+ - **Test Accuracy**: 95.76%
41
+
42
+ ## Usage
43
+
44
+ ### Using the Inference API
45
+
46
+ ```python
47
+ import requests
48
+
49
+ API_URL = "https://api-inference.huggingface.co/models/chenhaoq87/MilkSpoilageClassifier"
50
+ headers = {"Authorization": "Bearer YOUR_HF_TOKEN"}
51
+
52
+ # Input: [SPC_D7, SPC_D14, SPC_D21, TGN_D7, TGN_D14, TGN_D21]
53
+ payload = {"inputs": [[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]]}
54
+
55
+ response = requests.post(API_URL, headers=headers, json=payload)
56
+ print(response.json())
57
+ ```
58
+
59
+ ### Local Usage
60
+
61
+ ```python
62
+ import joblib
63
+ import numpy as np
64
+
65
+ # Load the model
66
+ model = joblib.load("model.joblib")
67
+
68
+ # Prepare input features
69
+ # [SPC_D7, SPC_D14, SPC_D21, TGN_D7, TGN_D14, TGN_D21]
70
+ features = np.array([[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]])
71
+
72
+ # Make prediction
73
+ prediction = model.predict(features)
74
+ probabilities = model.predict_proba(features)
75
+
76
+ print(f"Predicted class: {prediction[0]}")
77
+ print(f"Class probabilities: {dict(zip(model.classes_, probabilities[0]))}")
78
+ ```
79
+
80
+ ## Model Details
81
+
82
+ - **Model Type**: Random Forest Classifier
83
+ - **Framework**: scikit-learn
84
+ - **Number of Estimators**: 100
85
+ - **Max Depth**: None (unlimited)
86
+ - **Min Samples Split**: 5
87
+ - **Min Samples Leaf**: 1
88
+
89
+ ## Citation
90
+
91
+ If you use this model, please cite the original research on milk spoilage classification.
92
+
93
+ ## License
94
+
95
+ MIT License
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio Web Interface for Milk Spoilage Classification
3
+
4
+ This app provides an interactive web interface for predicting
5
+ milk spoilage type based on microbial count data.
6
+ """
7
+
8
+ import gradio as gr
9
+ import joblib
10
+ import numpy as np
11
+
12
+
13
+ # Load the trained model
14
+ model = joblib.load("model.joblib")
15
+
16
+ # Feature information for the UI
17
+ FEATURE_INFO = {
18
+ "SPC_D7": ("Standard Plate Count - Day 7", "log CFU/mL", 0.0, 10.0, 4.0),
19
+ "SPC_D14": ("Standard Plate Count - Day 14", "log CFU/mL", 0.0, 10.0, 5.0),
20
+ "SPC_D21": ("Standard Plate Count - Day 21", "log CFU/mL", 0.0, 10.0, 6.0),
21
+ "TGN_D7": ("Total Gram-Negative - Day 7", "log CFU/mL", 0.0, 10.0, 3.0),
22
+ "TGN_D14": ("Total Gram-Negative - Day 14", "log CFU/mL", 0.0, 10.0, 4.0),
23
+ "TGN_D21": ("Total Gram-Negative - Day 21", "log CFU/mL", 0.0, 10.0, 5.0),
24
+ }
25
+
26
+ # Class descriptions
27
+ CLASS_DESCRIPTIONS = {
28
+ "PPC": "Post-Pasteurization Contamination - Bacteria introduced after pasteurization",
29
+ "no spoilage": "No significant spoilage detected in the sample",
30
+ "spore spoilage": "Spoilage caused by spore-forming bacteria"
31
+ }
32
+
33
+
34
+ def predict_spoilage(spc_d7, spc_d14, spc_d21, tgn_d7, tgn_d14, tgn_d21):
35
+ """
36
+ Predict milk spoilage type based on microbial counts.
37
+
38
+ Args:
39
+ spc_d7: Standard Plate Count at Day 7
40
+ spc_d14: Standard Plate Count at Day 14
41
+ spc_d21: Standard Plate Count at Day 21
42
+ tgn_d7: Total Gram-Negative count at Day 7
43
+ tgn_d14: Total Gram-Negative count at Day 14
44
+ tgn_d21: Total Gram-Negative count at Day 21
45
+
46
+ Returns:
47
+ Dictionary of class probabilities for Gradio Label component
48
+ """
49
+ # Prepare input features
50
+ features = np.array([[spc_d7, spc_d14, spc_d21, tgn_d7, tgn_d14, tgn_d21]])
51
+
52
+ # Get prediction and probabilities
53
+ prediction = model.predict(features)[0]
54
+ probabilities = model.predict_proba(features)[0]
55
+
56
+ # Create probability dictionary for Gradio Label
57
+ prob_dict = {
58
+ cls: float(prob)
59
+ for cls, prob in zip(model.classes_, probabilities)
60
+ }
61
+
62
+ return prob_dict
63
+
64
+
65
+ def create_interface():
66
+ """Create and configure the Gradio interface."""
67
+
68
+ # Custom CSS for styling
69
+ custom_css = """
70
+ .gradio-container {
71
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
72
+ }
73
+ .feature-group {
74
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
75
+ border-radius: 10px;
76
+ padding: 15px;
77
+ margin: 10px 0;
78
+ }
79
+ """
80
+
81
+ with gr.Blocks(
82
+ title="Milk Spoilage Classifier",
83
+ theme=gr.themes.Soft(
84
+ primary_hue="indigo",
85
+ secondary_hue="purple",
86
+ ),
87
+ css=custom_css
88
+ ) as demo:
89
+
90
+ # Header
91
+ gr.Markdown(
92
+ """
93
+ # 🥛 Milk Spoilage Classification Model
94
+
95
+ Predict milk spoilage type based on microbial count data measured at different time points.
96
+ Enter the Standard Plate Count (SPC) and Total Gram-Negative (TGN) values below.
97
+ """
98
+ )
99
+
100
+ with gr.Row():
101
+ # Input Section
102
+ with gr.Column(scale=1):
103
+ gr.Markdown("### 📊 Standard Plate Count (SPC)")
104
+ gr.Markdown("*Total bacterial count in log CFU/mL*")
105
+
106
+ spc_d7 = gr.Number(
107
+ label="Day 7",
108
+ value=4.0,
109
+ minimum=0.0,
110
+ maximum=10.0,
111
+ info="SPC measurement at day 7"
112
+ )
113
+ spc_d14 = gr.Number(
114
+ label="Day 14",
115
+ value=5.0,
116
+ minimum=0.0,
117
+ maximum=10.0,
118
+ info="SPC measurement at day 14"
119
+ )
120
+ spc_d21 = gr.Number(
121
+ label="Day 21",
122
+ value=6.0,
123
+ minimum=0.0,
124
+ maximum=10.0,
125
+ info="SPC measurement at day 21"
126
+ )
127
+
128
+ with gr.Column(scale=1):
129
+ gr.Markdown("### 🦠 Total Gram-Negative (TGN)")
130
+ gr.Markdown("*Gram-negative bacterial count in log CFU/mL*")
131
+
132
+ tgn_d7 = gr.Number(
133
+ label="Day 7",
134
+ value=3.0,
135
+ minimum=0.0,
136
+ maximum=10.0,
137
+ info="TGN measurement at day 7"
138
+ )
139
+ tgn_d14 = gr.Number(
140
+ label="Day 14",
141
+ value=4.0,
142
+ minimum=0.0,
143
+ maximum=10.0,
144
+ info="TGN measurement at day 14"
145
+ )
146
+ tgn_d21 = gr.Number(
147
+ label="Day 21",
148
+ value=5.0,
149
+ minimum=0.0,
150
+ maximum=10.0,
151
+ info="TGN measurement at day 21"
152
+ )
153
+
154
+ # Predict button
155
+ predict_btn = gr.Button("🔬 Classify Spoilage Type", variant="primary", size="lg")
156
+
157
+ # Output Section
158
+ gr.Markdown("### 📋 Prediction Results")
159
+
160
+ output_label = gr.Label(
161
+ label="Spoilage Classification",
162
+ num_top_classes=3
163
+ )
164
+
165
+ # Connect the prediction function
166
+ predict_btn.click(
167
+ fn=predict_spoilage,
168
+ inputs=[spc_d7, spc_d14, spc_d21, tgn_d7, tgn_d14, tgn_d21],
169
+ outputs=output_label
170
+ )
171
+
172
+ # Also trigger on any input change
173
+ for input_component in [spc_d7, spc_d14, spc_d21, tgn_d7, tgn_d14, tgn_d21]:
174
+ input_component.change(
175
+ fn=predict_spoilage,
176
+ inputs=[spc_d7, spc_d14, spc_d21, tgn_d7, tgn_d14, tgn_d21],
177
+ outputs=output_label
178
+ )
179
+
180
+ # Information Section
181
+ gr.Markdown(
182
+ """
183
+ ---
184
+ ### ℹ️ About the Classes
185
+
186
+ | Class | Description |
187
+ |-------|-------------|
188
+ | **PPC** | Post-Pasteurization Contamination - Bacteria introduced after pasteurization process |
189
+ | **no spoilage** | No significant spoilage detected in the sample |
190
+ | **spore spoilage** | Spoilage caused by spore-forming bacteria that survive pasteurization |
191
+
192
+ ---
193
+ ### 📖 How to Use
194
+
195
+ 1. Enter the microbial count values (in log CFU/mL) for each time point
196
+ 2. Click "Classify Spoilage Type" or wait for automatic prediction
197
+ 3. View the predicted spoilage category and confidence scores
198
+
199
+ ---
200
+ *Model: Random Forest Classifier trained on milk quality data*
201
+ """
202
+ )
203
+
204
+ return demo
205
+
206
+
207
+ # Create and launch the interface
208
+ if __name__ == "__main__":
209
+ demo = create_interface()
210
+ demo.launch()
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "RandomForestClassifier",
3
+ "framework": "sklearn",
4
+ "task": "classification",
5
+ "features": [
6
+ "SPC_D7",
7
+ "SPC_D14",
8
+ "SPC_D21",
9
+ "TGN_D7",
10
+ "TGN_D14",
11
+ "TGN_D21"
12
+ ],
13
+ "feature_descriptions": {
14
+ "SPC_D7": "Standard Plate Count at Day 7 (log CFU/mL)",
15
+ "SPC_D14": "Standard Plate Count at Day 14 (log CFU/mL)",
16
+ "SPC_D21": "Standard Plate Count at Day 21 (log CFU/mL)",
17
+ "TGN_D7": "Total Gram-Negative count at Day 7 (log CFU/mL)",
18
+ "TGN_D14": "Total Gram-Negative count at Day 14 (log CFU/mL)",
19
+ "TGN_D21": "Total Gram-Negative count at Day 21 (log CFU/mL)"
20
+ },
21
+ "classes": [
22
+ "PPC",
23
+ "no spoilage",
24
+ "spore spoilage"
25
+ ],
26
+ "class_descriptions": {
27
+ "PPC": "Post-Pasteurization Contamination",
28
+ "no spoilage": "No spoilage detected",
29
+ "spore spoilage": "Spore-forming bacteria spoilage"
30
+ },
31
+ "hyperparameters": {
32
+ "n_estimators": 100,
33
+ "max_depth": null,
34
+ "min_samples_split": 5,
35
+ "min_samples_leaf": 1,
36
+ "random_state": 42
37
+ }
38
+ }
handler.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom Inference Handler for Hugging Face Inference Endpoints
3
+
4
+ This handler loads the trained RandomForest model and provides
5
+ prediction functionality for the Hugging Face Inference API.
6
+ """
7
+
8
+ import joblib
9
+ import numpy as np
10
+ from typing import Dict, List, Any, Union
11
+ import os
12
+
13
+
14
+ class EndpointHandler:
15
+ """
16
+ Custom handler for Hugging Face Inference Endpoints.
17
+
18
+ This class is automatically instantiated by the Inference API
19
+ and handles incoming prediction requests.
20
+ """
21
+
22
+ def __init__(self, path: str = ""):
23
+ """
24
+ Initialize the handler by loading the model.
25
+
26
+ Args:
27
+ path: Path to the model directory (provided by HF Inference API)
28
+ """
29
+ model_path = os.path.join(path, "model.joblib") if path else "model.joblib"
30
+ self.model = joblib.load(model_path)
31
+
32
+ # Feature names in expected order
33
+ self.feature_names = [
34
+ "SPC_D7", "SPC_D14", "SPC_D21",
35
+ "TGN_D7", "TGN_D14", "TGN_D21"
36
+ ]
37
+
38
+ # Class names from the model
39
+ self.class_names = list(self.model.classes_)
40
+
41
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
42
+ """
43
+ Handle prediction requests.
44
+
45
+ Args:
46
+ data: Input data dictionary. Supports multiple formats:
47
+ - {"inputs": [[f1, f2, f3, f4, f5, f6], ...]} # List of feature arrays
48
+ - {"inputs": {"SPC_D7": 4.5, ...}} # Dict with feature names
49
+ - {"inputs": [{"SPC_D7": 4.5, ...}, ...]} # List of dicts
50
+
51
+ Returns:
52
+ List of prediction results with labels and probabilities
53
+ """
54
+ # Extract inputs from the data
55
+ inputs = data.get("inputs", data)
56
+
57
+ # Convert inputs to numpy array
58
+ X = self._process_inputs(inputs)
59
+
60
+ # Make predictions
61
+ predictions = self.model.predict(X)
62
+ probabilities = self.model.predict_proba(X)
63
+
64
+ # Format results
65
+ results = []
66
+ for pred, probs in zip(predictions, probabilities):
67
+ result = {
68
+ "label": str(pred),
69
+ "score": float(max(probs)),
70
+ "probabilities": {
71
+ cls: float(prob)
72
+ for cls, prob in zip(self.class_names, probs)
73
+ }
74
+ }
75
+ results.append(result)
76
+
77
+ return results
78
+
79
+ def _process_inputs(self, inputs: Union[List, Dict]) -> np.ndarray:
80
+ """
81
+ Process various input formats into a numpy array.
82
+
83
+ Args:
84
+ inputs: Input data in various formats
85
+
86
+ Returns:
87
+ Numpy array of shape (n_samples, n_features)
88
+ """
89
+ # Case 1: List of lists/arrays (direct feature values)
90
+ if isinstance(inputs, list) and len(inputs) > 0:
91
+ if isinstance(inputs[0], (list, tuple, np.ndarray)):
92
+ return np.array(inputs).reshape(-1, len(self.feature_names))
93
+
94
+ # Case 2: List of dictionaries with feature names
95
+ elif isinstance(inputs[0], dict):
96
+ return np.array([
97
+ [sample.get(feat, 0) for feat in self.feature_names]
98
+ for sample in inputs
99
+ ])
100
+
101
+ # Case 3: Single sample as flat list
102
+ else:
103
+ return np.array(inputs).reshape(1, -1)
104
+
105
+ # Case 4: Single dictionary with feature names
106
+ elif isinstance(inputs, dict):
107
+ return np.array([[
108
+ inputs.get(feat, 0) for feat in self.feature_names
109
+ ]])
110
+
111
+ # Fallback: try to convert directly
112
+ return np.array(inputs).reshape(-1, len(self.feature_names))
113
+
114
+
115
+ # For local testing
116
+ if __name__ == "__main__":
117
+ # Test the handler locally
118
+ print("Testing EndpointHandler locally...")
119
+
120
+ try:
121
+ handler = EndpointHandler()
122
+
123
+ # Test with list format
124
+ test_data_list = {
125
+ "inputs": [[4.5, 5.2, 6.1, 3.2, 4.0, 4.8]]
126
+ }
127
+ result = handler(test_data_list)
128
+ print(f"\nTest 1 (list format):")
129
+ print(f" Input: {test_data_list}")
130
+ print(f" Output: {result}")
131
+
132
+ # Test with dict format
133
+ test_data_dict = {
134
+ "inputs": {
135
+ "SPC_D7": 4.5, "SPC_D14": 5.2, "SPC_D21": 6.1,
136
+ "TGN_D7": 3.2, "TGN_D14": 4.0, "TGN_D21": 4.8
137
+ }
138
+ }
139
+ result = handler(test_data_dict)
140
+ print(f"\nTest 2 (dict format):")
141
+ print(f" Input: {test_data_dict}")
142
+ print(f" Output: {result}")
143
+
144
+ # Test batch prediction
145
+ test_data_batch = {
146
+ "inputs": [
147
+ [4.5, 5.2, 6.1, 3.2, 4.0, 4.8],
148
+ [2.0, 2.5, 3.0, 1.5, 2.0, 2.5],
149
+ [6.0, 7.0, 8.0, 5.0, 6.0, 7.0]
150
+ ]
151
+ }
152
+ result = handler(test_data_batch)
153
+ print(f"\nTest 3 (batch format):")
154
+ print(f" Input: {test_data_batch}")
155
+ print(f" Output: {result}")
156
+
157
+ print("\nAll tests passed!")
158
+
159
+ except FileNotFoundError:
160
+ print("Note: model.joblib not found. Run 'python prepare_model.py' first.")
model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f56cb2839629f726b040cf8fa19fbc7a61e5b47a6fdbd414b96cccbc8a83b876
3
+ size 302097
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ scikit-learn>=1.0
2
+ joblib>=1.0
3
+ numpy>=1.20
4
+ pandas>=1.3