File size: 1,368 Bytes
4fec82c
 
 
 
be55486
 
4fec82c
be55486
 
4fec82c
be55486
4fec82c
be55486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fec82c
be55486
4fec82c
be55486
4fec82c
be55486
 
 
4fec82c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# src/predict.py

import os
import pickle
import pandas as pd
from model_training import train_model  # Your training function

# Location of the pickled crop model. It is resolved relative to this source
# file's directory rather than the current working directory.
MODEL_PATH = os.path.join(
    os.path.dirname(__file__),
    '../models/crop_model.pkl',
)

def _save_model(model, path):
    """
    Pickle ``model`` to ``path`` without ever exposing a partial file.

    The model is first written to a process-unique temporary file next to
    ``path``. That file is then moved into place with ``os.replace``. A failed
    or interrupted write therefore cannot leave a truncated pickle at
    ``path``, which would break every later ``load_model()`` call.
    """
    directory = os.path.dirname(path) or '.'
    os.makedirs(directory, exist_ok=True)
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(model, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def load_model(path=None):
    """
    Load the crop prediction model from disk.

    If no model file exists at ``path``, a new model is trained with
    ``train_model()``, saved to ``path``, and returned.

    Args:
        path: Location of the pickled model. Defaults to ``MODEL_PATH``.

    Returns:
        The unpickled model, or the freshly trained model if none was saved.

    Raises:
        OSError: if the file exists but cannot be opened, or saving fails.
        Exceptions from ``pickle.load`` propagate if the saved file is
        corrupt or incompatible.
    """
    if path is None:
        path = MODEL_PATH

    # EAFP: open directly instead of calling os.path.exists() first. The
    # check-then-open form races if the file changes in between.
    try:
        f = open(path, 'rb')
    except FileNotFoundError:
        pass
    else:
        with f:
            # SECURITY: pickle.load can execute arbitrary code. Only load model
            # files written by this application from a trusted location.
            return pickle.load(f)

    # No saved model: train a new one and persist it.
    print("Model not found. Training a new model...")
    model = train_model()  # Make sure your model_training.py has train_model() returning a fitted model
    _save_model(model, path)
    print("Model trained and saved.")
    return model

# Load model once when the app starts.
# NOTE: this runs at import time. If no model file exists yet, importing this
# module calls load_model(), which trains a new model and writes it to MODEL_PATH.
model = load_model()

def predict_crop(input_features: pd.DataFrame):
    """
    Run the loaded crop model on a batch of soil/climate measurements.

    Args:
        input_features: pd.DataFrame with columns ['N', 'P', 'K', 'temperature', 'humidity', 'ph', 'rainfall']
    Returns:
        List of predicted crop names
    """
    # Uses the module-level ``model`` loaded at import time. The prediction
    # result (presumably a NumPy array) is turned into a list via .tolist().
    return model.predict(input_features).tolist()