Spaces:
Sleeping
Sleeping
| # src/predict.py | |
| import os | |
| import pickle | |
| import pandas as pd | |
| from model_training import train_model # Your training function | |
# Absolute, normalized path to the pickled model (<project>/models/crop_model.pkl).
# Resolved once at import via abspath so a later os.chdir() or a relative
# __file__ cannot redirect it to the wrong directory.
MODEL_PATH = os.path.normpath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)), '..', 'models', 'crop_model.pkl'
    )
)
def load_model(model_path=None, trainer=None):
    """Load the crop prediction model from disk, training and caching it if absent.

    If a file exists at ``model_path`` it is unpickled and returned. Otherwise
    ``trainer()`` is called to produce a fitted model. That model is pickled
    to ``model_path`` (creating the parent directory if needed) and returned.

    Args:
        model_path: Location of the pickled model. Defaults to ``MODEL_PATH``.
        trainer: Zero-argument callable returning a fitted model. Defaults to
            ``train_model`` from ``model_training``.

    Returns:
        The loaded or freshly trained model object.
    """
    path = MODEL_PATH if model_path is None else model_path

    if os.path.exists(path):
        # SECURITY: pickle.load can execute arbitrary code. Only ever point
        # this at model files produced by this application, never user uploads.
        with open(path, 'rb') as f:
            return pickle.load(f)

    print("Model not found. Training a new model...")
    fit = train_model if trainer is None else trainer
    model = fit()

    # dirname() is '' for a bare filename, and os.makedirs('') would raise.
    directory = os.path.dirname(path) or os.curdir
    os.makedirs(directory, exist_ok=True)

    # Write to a sibling temp file, then atomically swap it into place. A crash
    # mid-dump therefore never leaves a truncated pickle at `path`, which would
    # otherwise make every subsequent start-up fail in pickle.load.
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(model, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error propagates.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    print("Model trained and saved.")
    return model
# Load the model once, when this module is first imported, so every
# predict_crop() call reuses the same instance. If the model file is missing,
# load_model() trains a new one here, so importing this module may trigger a
# full training run.
model = load_model()
def predict_crop(input_features: pd.DataFrame):
    """Run the module-level model on ``input_features`` and return its predictions.

    Args:
        input_features: DataFrame with columns
            ['N', 'P', 'K', 'temperature', 'humidity', 'ph', 'rainfall'].

    Returns:
        The predicted crop names, converted to a plain Python list via
        ``.tolist()`` on the model's ``predict`` output.
    """
    return model.predict(input_features).tolist()