hardin009 commited on
Commit
1402cea
·
verified ·
1 Parent(s): 3a563a6

Upload potato_price_model.py

Browse files
Files changed (1) hide show
  1. potato_price_model.py +73 -0
potato_price_model.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import joblib
3
+ import pandas as pd
4
+ import numpy as np
5
+ from datetime import datetime, timedelta
6
+
7
+ class PotatoPricePredictor:
8
+ def __init__(self):
9
+ self.model = joblib.load('best_potato_price_model_Ridge.joblib')
10
+ self.sentiment_analyzer = pipeline("sentiment-analysis")
11
+
12
+ def preprocess(self, data):
13
+ df = pd.DataFrame([data])
14
+ df['Date'] = pd.to_datetime(df['Date'])
15
+
16
+ df['DayOfWeek'] = df['Date'].dt.dayofweek
17
+ df['Month'] = df['Date'].dt.month
18
+ df['Quarter'] = df['Date'].dt.quarter
19
+ df['Year'] = df['Date'].dt.year
20
+
21
+ df['Events_Sentiment'] = df['Events'].apply(lambda x: self.sentiment_analyzer(x)[0]['score'] if x else 0)
22
+ df['Impacts_Sentiment'] = df['Impacts'].apply(lambda x: self.sentiment_analyzer(x)[0]['score'] if x else 0)
23
+
24
+ return df
25
+
26
+ def predict(self, data):
27
+ processed_data = self.preprocess(data)
28
+ features = ['ArrivalQuantity', 'Temperature', 'Humidity', 'Wind direction',
29
+ 'Events_Sentiment', 'Impacts_Sentiment', 'DayOfWeek', 'Month', 'Quarter', 'Year',
30
+ 'PriceLag1', 'PriceLag7', 'PriceRollingMean7', 'PriceRollingStd7', 'PrevWeekAvgPrice']
31
+
32
+ X = processed_data[features]
33
+ prediction = self.model.predict(X)
34
+
35
+ return {'predicted_price': float(prediction[0])}
36
+
37
+ def predict_future(self, days=30):
38
+ last_date = datetime.now().date()
39
+ future_dates = [last_date + timedelta(days=i) for i in range(1, days + 1)]
40
+
41
+ future_prices = []
42
+ last_price = 50 # You may want to adjust this initial value
43
+
44
+ for date in future_dates:
45
+ data = {
46
+ 'Date': date.strftime('%Y-%m-%d'),
47
+ 'ArrivalQuantity': 1000, # You may want to randomize or adjust these values
48
+ 'Temperature': 25,
49
+ 'Humidity': 60,
50
+ 'Wind direction': 180,
51
+ 'Events': 'Normal day',
52
+ 'Impacts': 'No significant impacts',
53
+ 'PriceLag1': last_price,
54
+ 'PriceLag7': last_price,
55
+ 'PriceRollingMean7': last_price,
56
+ 'PriceRollingStd7': 2,
57
+ 'PrevWeekAvgPrice': last_price
58
+ }
59
+
60
+ prediction = self.predict(data)
61
+ future_prices.append(prediction['predicted_price'])
62
+ last_price = prediction['predicted_price']
63
+
64
+ return {'future_prices': [{'date': date.strftime('%Y-%m-%d'), 'price': price} for date, price in zip(future_dates, future_prices)]}
65
+
66
+ predictor = PotatoPricePredictor()
67
+
68
+ def query(payload):
69
+ if payload.get('predict_future'):
70
+ days = payload.get('days', 30)
71
+ return predictor.predict_future(days)
72
+ else:
73
+ return predictor.predict(payload)