# Source: suggestion/src/predict.py (author: jeshwanth93)
# Commit be55486: "Update predict.py to auto-load/train model"
# src/predict.py
import os
import pickle
import pandas as pd
from model_training import train_model # Your training function
# Path to the pickled model file: models/crop_model.pkl in the directory one
# level above this source file. Built from separate path components (no
# hard-coded separators), anchored on the absolute location of this file, and
# normalized so it contains no '..' segments.
MODEL_PATH = os.path.normpath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        os.pardir,
        'models',
        'crop_model.pkl',
    )
)
def load_model(model_path=None):
    """
    Loads the crop prediction model from disk.
    If the model file does not exist, trains a new model and saves it.

    Args:
        model_path: Path of the pickled model file. Defaults to MODEL_PATH
            when omitted or None.

    Returns:
        The unpickled model object, or the freshly trained one returned by
        train_model() when no file existed at the path.
    """
    path = MODEL_PATH if model_path is None else model_path

    # EAFP: just try to open the file instead of checking exists() first,
    # which avoids a check-then-open race.
    try:
        model_file = open(path, 'rb')
    except FileNotFoundError:
        pass
    else:
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # model files written by this application, never user-supplied ones.
        with model_file:
            return pickle.load(model_file)

    # Train new model and save it
    print("Model not found. Training a new model...")
    model = train_model()  # Make sure your model_training.py has train_model() returning a fitted model
    _save_model_atomically(model, path)
    print("Model trained and saved.")
    return model


def _save_model_atomically(model, path):
    """
    Pickles `model` to `path` via a temporary file plus os.replace.

    An interrupted or failed write therefore never leaves a truncated model
    file at `path`. Such a file would make every later load_model() call
    crash while unpickling, instead of retraining.
    """
    directory = os.path.dirname(path) or os.curdir
    os.makedirs(directory, exist_ok=True)
    # Temp file sits in the same directory so os.replace is a same-filesystem rename.
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(model, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
# Load model once when the app starts.
# This runs at import time, so importing this module reads MODEL_PATH or, if
# that file is absent, trains a new model and writes it to MODEL_PATH.
# predict_crop() uses this module-level `model`.
model = load_model()
# Feature columns required by predict_crop, in the documented order.
# NOTE(review): confirm this matches the column order used when the model is
# fitted in model_training.train_model().
FEATURE_COLUMNS = ('N', 'P', 'K', 'temperature', 'humidity', 'ph', 'rainfall')


def predict_crop(input_features: pd.DataFrame):
    """
    Predicts the crop based on input features.

    Args:
        input_features: pd.DataFrame with columns ['N', 'P', 'K', 'temperature', 'humidity', 'ph', 'rainfall'].
            Column order does not matter and extra columns are ignored. The
            required columns are selected in the order above before prediction.
            Non-DataFrame inputs are passed to the model unchanged.

    Returns:
        List of predicted crop names

    Raises:
        ValueError: If a DataFrame is given that lacks any required column.
    """
    if isinstance(input_features, pd.DataFrame):
        missing = [col for col in FEATURE_COLUMNS if col not in input_features.columns]
        if missing:
            raise ValueError(f"input_features is missing required columns: {missing}")
        # Select in canonical order (a list, not a tuple: pandas treats a tuple
        # as a single key). Predictions then don't depend on how the caller
        # ordered the columns, and stray columns never reach the model.
        input_features = input_features[list(FEATURE_COLUMNS)]
    predictions = model.predict(input_features)
    return predictions.tolist()