jeshwanth93 commited on
Commit
be55486
·
1 Parent(s): a893300

Update predict.py to auto-load/train model

Browse files
Files changed (1) hide show
  1. src/predict.py +33 -23
src/predict.py CHANGED
@@ -2,33 +2,43 @@
2
 
3
  import os
4
  import pickle
5
- import numpy as np
 
6
 
7
- # Load trained model and label encoder
8
- model_path = os.path.join(os.path.dirname(__file__), '..', 'models', 'crop_model.pkl')
9
 
10
- with open(model_path, 'rb') as f:
11
- saved = pickle.load(f)
12
-
13
- model = saved['model']
14
- label_encoder = saved['label_encoder']
15
-
16
- def predict_crop(features):
17
  """
18
- Predict the best crop based on input features.
19
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  Args:
21
- features (list or array): [N, P, K, temperature, humidity, ph, rainfall]
22
-
23
  Returns:
24
- str: Predicted crop name
25
  """
26
- features = np.array(features).reshape(1, -1) # reshape for single sample
27
- pred_encoded = model.predict(features)[0] # get encoded label
28
- crop_name = label_encoder.inverse_transform([pred_encoded])[0] # decode label
29
- return crop_name
30
 
31
- # Quick test
32
- if __name__ == "__main__":
33
- test_features = [90, 42, 43, 20, 80, 6.5, 200]
34
- print("Predicted Crop:", predict_crop(test_features))
 
2
 
3
  import os
4
  import pickle
5
+ import pandas as pd
6
+ from model_training import train_model # Your training function
7
 
8
+ # Path to the model file
9
+ MODEL_PATH = os.path.join(os.path.dirname(__file__), '../models/crop_model.pkl')
10
 
11
+ def load_model():
 
 
 
 
 
 
12
  """
13
+ Loads the crop prediction model from disk.
14
+ If the model file does not exist, trains a new model and saves it.
15
+ """
16
+ if os.path.exists(MODEL_PATH):
17
+ # Load existing model
18
+ with open(MODEL_PATH, 'rb') as f:
19
+ model = pickle.load(f)
20
+ else:
21
+ # Train new model and save it
22
+ print("Model not found. Training a new model...")
23
+ model = train_model() # Make sure your model_training.py has train_model() returning a fitted model
24
+ os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
25
+ with open(MODEL_PATH, 'wb') as f:
26
+ pickle.dump(model, f)
27
+ print("Model trained and saved.")
28
+ return model
29
+
30
+ # Load model once when the app starts
31
+ model = load_model()
32
+
33
+ def predict_crop(input_features: pd.DataFrame):
34
+ """
35
+ Predicts the crop based on input features.
36
  Args:
37
+ input_features: pd.DataFrame with columns ['N', 'P', 'K', 'temperature', 'humidity', 'ph', 'rainfall']
 
38
  Returns:
39
+ List of predicted crop names
40
  """
41
+ predictions = model.predict(input_features)
42
+ return predictions.tolist()
43
+
 
44