anujkum0x committed on
Commit
4622a9e
·
verified ·
1 Parent(s): 0d60395

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -13
app.py CHANGED
@@ -157,7 +157,7 @@ def create_plot(data, forecast_data, time_col, target_col):
157
  )
158
  return fig
159
 
160
- def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data) -> Tuple[str, object, str, str]:
161
  """
162
  Full pipeline: loads the data, calls the forecast function, and then processes the data.
163
  """
@@ -172,13 +172,6 @@ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, fin
172
  # Sort the DataFrame by the time column
173
  data = data.sort_values(by=time_col)
174
 
175
- # Get min and max dates from the data
176
- min_date = data[time_col].min().strftime('%Y-%m-%d')
177
- max_date = data[time_col].max().strftime('%Y-%m-%d')
178
-
179
- # Fill missing values with 0
180
- data = data.fillna(0)
181
-
182
  # Apply date range selection
183
  if start_date and end_date:
184
  start_datetime = pd.to_datetime(start_date)
@@ -186,12 +179,40 @@ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, fin
186
  data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
187
  logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
188
 
189
- data = data.set_index(time_col)
190
-
191
  # Resample the data
192
  data = data.resample(resample_freq).mean()
193
  data.reset_index(inplace=True)
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
196
  processed_data = process_forecast_data(forecast_result, time_col)
197
  processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)
@@ -201,8 +222,13 @@ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, fin
201
  else:
202
  merged_data = processed_data
203
 
204
- plot = create_plot(data, processed_data, time_col, target_col)
205
- csv_data = processed_data.to_csv(index=False)
 
 
 
 
 
206
 
207
  # Create a temporary file and write the CSV data to it
208
  with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
@@ -277,7 +303,7 @@ def create_interface():
277
  btn = gr.Button("Generate Forecast")
278
  btn.click(
279
  fn=full_forecast_pipeline,
280
- inputs=[file_input, time_col_input, target_col_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False)],
281
  outputs=[output_csv, output_plot, download_button, error_output]
282
  )
283
  return iface
 
157
  )
158
  return fig
159
 
160
+ def full_forecast_pipeline(file_obj, time_col, target_col, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data, forecast_start_date, forecast_end_date) -> Tuple[str, object, str, str]:
161
  """
162
  Full pipeline: loads the data, calls the forecast function, and then processes the data.
163
  """
 
172
  # Sort the DataFrame by the time column
173
  data = data.sort_values(by=time_col)
174
 
 
 
 
 
 
 
 
175
  # Apply date range selection
176
  if start_date and end_date:
177
  start_datetime = pd.to_datetime(start_date)
 
179
  data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
180
  logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
181
 
 
 
182
  # Resample the data
183
  data = data.resample(resample_freq).mean()
184
  data.reset_index(inplace=True)
185
 
186
+ # Convert forecast start and end dates to datetime
187
+ if forecast_start_date and forecast_end_date:
188
+ forecast_start_datetime = pd.to_datetime(forecast_start_date)
189
+ forecast_end_datetime = pd.to_datetime(forecast_end_date)
190
+
191
+ # Calculate the time difference
192
+ time_difference = forecast_end_datetime - forecast_start_datetime
193
+
194
+ # Calculate forecast horizon based on frequency
195
+ if freq == 'D':
196
+ forecast_horizon = time_difference.days
197
+ elif freq == 'W':
198
+ forecast_horizon = time_difference.days / 7
199
+ elif freq == 'M':
200
+ forecast_horizon = time_difference.days / 30.44 # Average days in a month
201
+ elif freq == 'Y':
202
+ forecast_horizon = time_difference.days / 365.25 # Average days in a year
203
+ elif 'min' in freq:
204
+ minutes = int(freq.replace('min', ''))
205
+ forecast_horizon = time_difference.total_seconds() / (minutes * 60)
206
+ elif 'H' in freq:
207
+ hours = int(freq.replace('H', ''))
208
+ forecast_horizon = time_difference.total_seconds() / (hours * 3600)
209
+ else:
210
+ raise ValueError("Unsupported frequency. Please select a valid frequency.")
211
+
212
+ forecast_horizon = int(forecast_horizon)
213
+ else:
214
+ raise ValueError("Forecast start and end dates must be provided.")
215
+
216
  forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
217
  processed_data = process_forecast_data(forecast_result, time_col)
218
  processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)
 
222
  else:
223
  merged_data = processed_data
224
 
225
+ # Filter forecast data based on forecast start and end dates
226
+ merged_data[time_col] = pd.to_datetime(merged_data[time_col]) # Ensure time_col is datetime
227
+ merged_data = merged_data[(merged_data[time_col] >= forecast_start_datetime) & (merged_data[time_col] <= forecast_end_datetime)]
228
+ logger.info(f"Forecast data filtered from {forecast_start_datetime} to {forecast_end_datetime}. Shape: {merged_data.shape}")
229
+
230
+ plot = create_plot(data, merged_data, time_col, target_col)
231
+ csv_data = merged_data.to_csv(index=False)
232
 
233
  # Create a temporary file and write the CSV data to it
234
  with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
 
303
  btn = gr.Button("Generate Forecast")
304
  btn.click(
305
  fn=full_forecast_pipeline,
306
+ inputs=[file_input, time_col_input, target_col_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False), gr.Textbox(label="Forecast Start Date", placeholder="YYYY-MM-DD", value="2023-01-01"), gr.Textbox(label="Forecast End Date", placeholder="YYYY-MM-DD", value="2023-12-31")],
307
  outputs=[output_csv, output_plot, download_button, error_output]
308
  )
309
  return iface