sborhade commited on
Commit
b8d885c
·
verified ·
1 Parent(s): 1e8bcd6

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +15 -15
inference.py CHANGED
@@ -1,18 +1,18 @@
1
- import sklearn # Explicit import
2
- import pandas as pd
3
- import pickle
4
- from datetime import datetime
5
 
6
- class ExpenseForecaster:
7
- def __init__(self, model_path="model/expense_forecaster_model.pkl"):
8
- with open(model_path, "rb") as model_file:
9
- self.model = pickle.load(model_file)
10
- self.min_date = pd.to_datetime("2024-01-01") # Reference date
11
 
12
- def __call__(self, input_date_str):
13
- input_date = pd.to_datetime(input_date_str)
14
- numerical_date = (input_date - self.min_date) / pd.Timedelta(days=30)
15
- prediction = self.model.predict(pd.DataFrame({"ds": [numerical_date]}))
16
- return prediction[0].tolist() # Return as a list
17
 
18
- model = ExpenseForecaster() # Instantiate the model
 
 
 
 
1
+ import pickle
2
+ import pandas as pd
 
 
3
 
4
+ def load_model():
5
+ with open("model/expense_forecaster_model.pkl", "rb") as f:
6
+ model = pickle.load(f)
7
+ return model
 
8
 
9
+ def predict(data):
10
+ model = load_model()
11
+ df = pd.DataFrame([data])
12
+ prediction = model.predict(df)
13
+ return prediction.tolist()
14
 
15
+ if __name__ == "__main__":
16
+ example_input = {"income": 5000, "previous_expenses": 3000, "month": 12} #example data, change this to match your feature names.
17
+ prediction = predict(example_input)
18
+ print(f"Prediction: {prediction}")