anujkum0x committed on
Commit
555bf89
·
verified ·
1 Parent(s): 4622a9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -43
app.py CHANGED
@@ -157,7 +157,7 @@ def create_plot(data, forecast_data, time_col, target_col):
157
  )
158
  return fig
159
 
160
- def full_forecast_pipeline(file_obj, time_col, target_col, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data, forecast_start_date, forecast_end_date) -> Tuple[str, object, str, str]:
161
  """
162
  Full pipeline: loads the data, calls the forecast function, and then processes the data.
163
  """
@@ -172,63 +172,73 @@ def full_forecast_pipeline(file_obj, time_col, target_col, finetune_steps, freq,
172
  # Sort the DataFrame by the time column
173
  data = data.sort_values(by=time_col)
174
 
175
- # Apply date range selection
 
 
 
 
 
 
 
176
  if start_date and end_date:
177
  start_datetime = pd.to_datetime(start_date)
178
  end_datetime = pd.to_datetime(end_date)
179
  data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
180
  logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
181
 
 
 
182
  # Resample the data
183
  data = data.resample(resample_freq).mean()
184
  data.reset_index(inplace=True)
185
 
186
- # Convert forecast start and end dates to datetime
187
- if forecast_start_date and forecast_end_date:
188
- forecast_start_datetime = pd.to_datetime(forecast_start_date)
189
  forecast_end_datetime = pd.to_datetime(forecast_end_date)
190
-
191
- # Calculate the time difference
192
- time_difference = forecast_end_datetime - forecast_start_datetime
193
-
194
- # Calculate forecast horizon based on frequency
195
- if freq == 'D':
196
- forecast_horizon = time_difference.days
197
- elif freq == 'W':
198
- forecast_horizon = time_difference.days / 7
199
- elif freq == 'M':
200
- forecast_horizon = time_difference.days / 30.44 # Average days in a month
201
- elif freq == 'Y':
202
- forecast_horizon = time_difference.days / 365.25 # Average days in a year
203
- elif 'min' in freq:
204
- minutes = int(freq.replace('min', ''))
205
- forecast_horizon = time_difference.total_seconds() / (minutes * 60)
206
- elif 'H' in freq:
207
- hours = int(freq.replace('H', ''))
208
- forecast_horizon = time_difference.total_seconds() / (hours * 3600)
209
  else:
210
- raise ValueError("Unsupported frequency. Please select a valid frequency.")
211
 
212
- forecast_horizon = int(forecast_horizon)
213
- else:
214
- raise ValueError("Forecast start and end dates must be provided.")
215
 
216
  forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
217
  processed_data = process_forecast_data(forecast_result, time_col)
218
  processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)
219
 
 
 
 
 
 
 
 
 
220
  if merge_data:
221
  merged_data = pd.merge(data.reset_index(), processed_data, on=time_col, how='inner')
222
  else:
223
  merged_data = processed_data
224
 
225
- # Filter forecast data based on forecast start and end dates
226
- merged_data[time_col] = pd.to_datetime(merged_data[time_col]) # Ensure time_col is datetime
227
- merged_data = merged_data[(merged_data[time_col] >= forecast_start_datetime) & (merged_data[time_col] <= forecast_end_datetime)]
228
- logger.info(f"Forecast data filtered from {forecast_start_datetime} to {forecast_end_datetime}. Shape: {merged_data.shape}")
229
-
230
- plot = create_plot(data, merged_data, time_col, target_col)
231
- csv_data = merged_data.to_csv(index=False)
232
 
233
  # Create a temporary file and write the CSV data to it
234
  with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
@@ -282,15 +292,21 @@ def create_interface():
282
  target_col_input = gr.Textbox(label="Target Column", placeholder="Enter target column name")
283
 
284
  with gr.Row():
285
- forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10)
286
  finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
287
  freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Frequency", value='D')
288
 
289
- with gr.Row():
290
- start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-01-01")
291
- start_time_input = gr.Textbox(label="Start Time (HH:MM)", placeholder="HH:MM", value="00:00")
292
- end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
293
- end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59")
 
 
 
 
 
 
294
 
295
  resample_freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Resample Frequency", value='D')
296
 
@@ -303,10 +319,10 @@ def create_interface():
303
  btn = gr.Button("Generate Forecast")
304
  btn.click(
305
  fn=full_forecast_pipeline,
306
- inputs=[file_input, time_col_input, target_col_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False), gr.Textbox(label="Forecast Start Date", placeholder="YYYY-MM-DD", value="2023-01-01"), gr.Textbox(label="Forecast End Date", placeholder="YYYY-MM-DD", value="2023-12-31")],
307
  outputs=[output_csv, output_plot, download_button, error_output]
308
  )
309
  return iface
310
 
311
  iface = create_interface()
312
- iface.launch()
 
157
  )
158
  return fig
159
 
160
+ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data, forecast_start_date, forecast_end_date) -> Tuple[str, object, str, str]:
161
  """
162
  Full pipeline: loads the data, calls the forecast function, and then processes the data.
163
  """
 
172
  # Sort the DataFrame by the time column
173
  data = data.sort_values(by=time_col)
174
 
175
+ # Get min and max dates from the data
176
+ min_date = data[time_col].min().strftime('%Y-%m-%d')
177
+ max_date = data[time_col].max().strftime('%Y-%m-%d')
178
+
179
+ # Fill missing values with 0
180
+ data = data.fillna(0)
181
+
182
+ # Apply date range selection for historical data
183
  if start_date and end_date:
184
  start_datetime = pd.to_datetime(start_date)
185
  end_datetime = pd.to_datetime(end_date)
186
  data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
187
  logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
188
 
189
+ data = data.set_index(time_col)
190
+
191
  # Resample the data
192
  data = data.resample(resample_freq).mean()
193
  data.reset_index(inplace=True)
194
 
195
+ # Calculate forecast horizon if forecast_end_date is provided
196
+ if forecast_end_date:
197
+ historical_end_date = pd.to_datetime(end_date) if end_date else data[time_col].max()
198
  forecast_end_datetime = pd.to_datetime(forecast_end_date)
199
+ day_difference = (forecast_end_datetime - historical_end_date).days
200
+ if day_difference <= 0:
201
+ raise ValueError("Forecast end date must be after the historical data end date.")
202
+
203
+ # Adjust forecast_horizon based on frequency
204
+ if freq == 'H':
205
+ forecast_horizon = day_difference * 24
206
+ elif freq == '30min':
207
+ forecast_horizon = day_difference * 48
208
+ elif freq == '15min':
209
+ forecast_horizon = day_difference * 96
210
+ elif freq == 'D':
211
+ forecast_horizon = day_difference
212
+ elif freq == 'W': # Approximation: 7 days in a week
213
+ forecast_horizon = day_difference / 7
214
+ elif freq == 'M': # Approximation: 30 days in a month
215
+ forecast_horizon = day_difference / 30
216
+ elif freq == 'Y': # Approximation: 365 days in a year
217
+ forecast_horizon = day_difference / 365
218
  else:
219
+ forecast_horizon = day_difference # Default to days if frequency is not recognized
220
 
221
+ forecast_horizon = max(1, int(round(forecast_horizon))) # Ensure forecast_horizon is at least 1 and integer
 
 
222
 
223
  forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
224
  processed_data = process_forecast_data(forecast_result, time_col)
225
  processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)
226
 
227
+ # Apply forecast date range selection
228
+ if forecast_start_date and forecast_end_date:
229
+ forecast_start_datetime = pd.to_datetime(forecast_start_date)
230
+ forecast_end_datetime = pd.to_datetime(forecast_end_date)
231
+ processed_data = processed_data[(processed_data[time_col] >= forecast_start_datetime) & (processed_data[time_col] <= forecast_end_datetime)]
232
+ logger.info(f"Forecast data filtered from {forecast_start_datetime} to {forecast_end_datetime}. Shape: {processed_data.shape}")
233
+
234
+
235
  if merge_data:
236
  merged_data = pd.merge(data.reset_index(), processed_data, on=time_col, how='inner')
237
  else:
238
  merged_data = processed_data
239
 
240
+ plot = create_plot(data, processed_data, time_col, target_col)
241
+ csv_data = processed_data.to_csv(index=False)
 
 
 
 
 
242
 
243
  # Create a temporary file and write the CSV data to it
244
  with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
 
292
  target_col_input = gr.Textbox(label="Target Column", placeholder="Enter target column name")
293
 
294
  with gr.Row():
295
+ forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10, visible=False) # Hide forecast horizon input
296
  finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
297
  freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Frequency", value='D')
298
 
299
+ with gr.Column(): # Group date inputs in a column
300
+ with gr.Row():
301
+ start_date_input = gr.Textbox(label="Historical Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-01-01")
302
+ start_time_input = gr.Textbox(label="Start Time (HH:MM)", placeholder="HH:MM", value="00:00", visible=False) # Hide start time input
303
+ with gr.Row():
304
+ end_date_input = gr.Textbox(label="Historical End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
305
+ end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59", visible=False) # Hide end time input
306
+ with gr.Row():
307
+ forecast_start_date_input = gr.Textbox(label="Forecast Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")
308
+ forecast_end_date_input = gr.Textbox(label="Forecast End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")
309
+
310
 
311
  resample_freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Resample Frequency", value='D')
312
 
 
319
  btn = gr.Button("Generate Forecast")
320
  btn.click(
321
  fn=full_forecast_pipeline,
322
+ inputs=[file_input, time_col_input, target_col_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False), forecast_start_date_input, forecast_end_date_input],
323
  outputs=[output_csv, output_plot, download_button, error_output]
324
  )
325
  return iface
326
 
327
  iface = create_interface()
328
+ iface.launch()