Ti-sha commited on
Commit
788b276
·
verified ·
1 Parent(s): ca51b56

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -1
inference.py CHANGED
@@ -21,7 +21,12 @@ def lstm_forecast(ts_data, look_back=60, steps=5, epochs=20):
21
  """
22
  ts_data: list of historical stock prices
23
  steps: number of future steps to forecast
 
24
  """
 
 
 
 
25
  # Normalize
26
  scaler = MinMaxScaler(feature_range=(0, 1))
27
  scaled_data = scaler.fit_transform(np.array(ts_data).reshape(-1,1))
@@ -61,7 +66,6 @@ def lstm_forecast(ts_data, look_back=60, steps=5, epochs=20):
61
  predictions = scaler.inverse_transform(np.array(predictions).reshape(-1,1))
62
  return predictions.flatten().tolist()
63
 
64
-
65
  def infer(model_type: str, input_data: list, steps: int = 5):
66
  """
67
  model_type: 'arima' or 'lstm'
 
21
  """
22
  ts_data: list of historical stock prices
23
  steps: number of future steps to forecast
24
+ Automatically adjusts look_back if input is shorter than look_back.
25
  """
26
+ # Adjust look_back if input is too short
27
+ if len(ts_data) < look_back + 1:
28
+ look_back = max(1, len(ts_data) - 1)
29
+
30
  # Normalize
31
  scaler = MinMaxScaler(feature_range=(0, 1))
32
  scaled_data = scaler.fit_transform(np.array(ts_data).reshape(-1,1))
 
66
  predictions = scaler.inverse_transform(np.array(predictions).reshape(-1,1))
67
  return predictions.flatten().tolist()
68
 
 
69
  def infer(model_type: str, input_data: list, steps: int = 5):
70
  """
71
  model_type: 'arima' or 'lstm'