Update inference.py
Browse files- inference.py +5 -1
inference.py
CHANGED
|
@@ -21,7 +21,12 @@ def lstm_forecast(ts_data, look_back=60, steps=5, epochs=20):
|
|
| 21 |
"""
|
| 22 |
ts_data: list of historical stock prices
|
| 23 |
steps: number of future steps to forecast
|
|
|
|
| 24 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# Normalize
|
| 26 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
| 27 |
scaled_data = scaler.fit_transform(np.array(ts_data).reshape(-1,1))
|
|
@@ -61,7 +66,6 @@ def lstm_forecast(ts_data, look_back=60, steps=5, epochs=20):
|
|
| 61 |
predictions = scaler.inverse_transform(np.array(predictions).reshape(-1,1))
|
| 62 |
return predictions.flatten().tolist()
|
| 63 |
|
| 64 |
-
|
| 65 |
def infer(model_type: str, input_data: list, steps: int = 5):
|
| 66 |
"""
|
| 67 |
model_type: 'arima' or 'lstm'
|
|
|
|
| 21 |
"""
|
| 22 |
ts_data: list of historical stock prices
|
| 23 |
steps: number of future steps to forecast
|
| 24 |
+
Automatically adjusts look_back if input is shorter than look_back.
|
| 25 |
"""
|
| 26 |
+
# Adjust look_back if input is too short
|
| 27 |
+
if len(ts_data) < look_back + 1:
|
| 28 |
+
look_back = max(1, len(ts_data) - 1)
|
| 29 |
+
|
| 30 |
# Normalize
|
| 31 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
| 32 |
scaled_data = scaler.fit_transform(np.array(ts_data).reshape(-1,1))
|
|
|
|
| 66 |
predictions = scaler.inverse_transform(np.array(predictions).reshape(-1,1))
|
| 67 |
return predictions.flatten().tolist()
|
| 68 |
|
|
|
|
| 69 |
def infer(model_type: str, input_data: list, steps: int = 5):
|
| 70 |
"""
|
| 71 |
model_type: 'arima' or 'lstm'
|