Spaces:
Sleeping
Sleeping
Commit ·
be55486
1
Parent(s): a893300
Update predict.py to auto-load/train model
Browse files- src/predict.py +33 -23
src/predict.py
CHANGED
|
@@ -2,33 +2,43 @@
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
import pickle
|
| 5 |
-
import
|
|
|
|
| 6 |
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
|
| 10 |
-
|
| 11 |
-
saved = pickle.load(f)
|
| 12 |
-
|
| 13 |
-
model = saved['model']
|
| 14 |
-
label_encoder = saved['label_encoder']
|
| 15 |
-
|
| 16 |
-
def predict_crop(features):
|
| 17 |
"""
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
Args:
|
| 21 |
-
|
| 22 |
-
|
| 23 |
Returns:
|
| 24 |
-
|
| 25 |
"""
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
return crop_name
|
| 30 |
|
| 31 |
-
# Quick test
|
| 32 |
-
if __name__ == "__main__":
|
| 33 |
-
test_features = [90, 42, 43, 20, 80, 6.5, 200]
|
| 34 |
-
print("Predicted Crop:", predict_crop(test_features))
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
import pickle
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from model_training import train_model # Your training function
|
| 7 |
|
| 8 |
+
# Path to the model file
|
| 9 |
+
MODEL_PATH = os.path.join(os.path.dirname(__file__), '../models/crop_model.pkl')
|
| 10 |
|
| 11 |
+
def load_model():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
+
Loads the crop prediction model from disk.
|
| 14 |
+
If the model file does not exist, trains a new model and saves it.
|
| 15 |
+
"""
|
| 16 |
+
if os.path.exists(MODEL_PATH):
|
| 17 |
+
# Load existing model
|
| 18 |
+
with open(MODEL_PATH, 'rb') as f:
|
| 19 |
+
model = pickle.load(f)
|
| 20 |
+
else:
|
| 21 |
+
# Train new model and save it
|
| 22 |
+
print("Model not found. Training a new model...")
|
| 23 |
+
model = train_model() # Make sure your model_training.py has train_model() returning a fitted model
|
| 24 |
+
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
|
| 25 |
+
with open(MODEL_PATH, 'wb') as f:
|
| 26 |
+
pickle.dump(model, f)
|
| 27 |
+
print("Model trained and saved.")
|
| 28 |
+
return model
|
| 29 |
+
|
| 30 |
+
# Load model once when the app starts
|
| 31 |
+
model = load_model()
|
| 32 |
+
|
| 33 |
+
def predict_crop(input_features: pd.DataFrame):
|
| 34 |
+
"""
|
| 35 |
+
Predicts the crop based on input features.
|
| 36 |
Args:
|
| 37 |
+
input_features: pd.DataFrame with columns ['N', 'P', 'K', 'temperature', 'humidity', 'ph', 'rainfall']
|
|
|
|
| 38 |
Returns:
|
| 39 |
+
List of predicted crop names
|
| 40 |
"""
|
| 41 |
+
predictions = model.predict(input_features)
|
| 42 |
+
return predictions.tolist()
|
| 43 |
+
|
|
|
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|