anujkum0x committed on
Commit
4c65297
·
verified ·
1 Parent(s): c9a20e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -16
app.py CHANGED
@@ -51,7 +51,7 @@ def load_data(file_obj):
51
  logger.error(f"Error loading data: {e}", exc_info=True)
52
  raise ValueError(f"Error loading data: {e}")
53
 
54
- def forecast_nixtla(df, forecast_horizon, finetune_steps, freq):
55
  """
56
  Function to call the Nixtla API directly.
57
  """
@@ -61,8 +61,8 @@ def forecast_nixtla(df, forecast_horizon, finetune_steps, freq):
61
  df=df,
62
  h=forecast_horizon,
63
  finetune_steps=finetune_steps,
64
- time_col="start_time",
65
- target_col="num_calls_queued",
66
  freq=freq
67
  )
68
  logger.info("Nixtla API call successful")
@@ -157,7 +157,7 @@ def create_plot(data, forecast_data, time_col, target_col):
157
  )
158
  return fig
159
 
160
- def full_forecast_pipeline(file_obj, forecast_horizon, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data) -> Tuple[str, object, str, str]:
161
  """
162
  Full pipeline: loads the data, calls the forecast function, and then processes the data.
163
  """
@@ -167,14 +167,14 @@ def full_forecast_pipeline(file_obj, forecast_horizon, finetune_steps, freq, sta
167
  return "Error loading data. Please check the file format and content.", None, None, None
168
 
169
  # Convert time column to datetime
170
- data["start_time"] = pd.to_datetime(data["start_time"])
171
 
172
  # Sort the DataFrame by the time column
173
- data = data.sort_values(by="start_time")
174
 
175
  # Get min and max dates from the data
176
- min_date = data["start_time"].min().strftime('%Y-%m-%d')
177
- max_date = data["start_time"].max().strftime('%Y-%m-%d')
178
 
179
  # Fill missing values with 0
180
  data = data.fillna(0)
@@ -183,25 +183,64 @@ def full_forecast_pipeline(file_obj, forecast_horizon, finetune_steps, freq, sta
183
  if start_date and end_date:
184
  start_datetime = pd.to_datetime(start_date)
185
  end_datetime = pd.to_datetime(end_date)
186
- data = data[(data["start_time"] >= start_datetime) & (data["start_time"] <= end_datetime)]
187
  logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
188
 
189
- data = data.set_index("start_time")
190
 
191
  # Resample the data
192
  data = data.resample(resample_freq).mean()
193
  data.reset_index(inplace=True)
194
 
195
- forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq)
196
- processed_data = process_forecast_data(forecast_result, "start_time")
197
- processed_data = apply_zero_patterns(data.copy(), processed_data, "start_time", "num_calls_queued")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  if merge_data:
200
- merged_data = pd.merge(data.reset_index(), processed_data, on="start_time", how='inner')
201
  else:
202
  merged_data = processed_data
203
 
204
- plot = create_plot(data, processed_data, "start_time", "num_calls_queued")
205
  csv_data = processed_data.to_csv(index=False)
206
 
207
  # Create a temporary file and write the CSV data to it
@@ -217,6 +256,31 @@ def full_forecast_pipeline(file_obj, forecast_horizon, finetune_steps, freq, sta
217
  logger.exception("An unexpected error occurred:")
218
  return f"Error: An unexpected error occurred: {e}", None, None, None
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  def create_interface():
221
  with gr.Blocks() as iface:
222
  gr.Markdown("""
@@ -226,6 +290,10 @@ def create_interface():
226
 
227
  file_input = gr.File(label="Upload Time Series Data (CSV, Excel, JSON, YAML)")
228
 
 
 
 
 
229
  with gr.Row():
230
  forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10)
231
  finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
@@ -237,6 +305,10 @@ def create_interface():
237
  end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
238
  end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59")
239
 
 
 
 
 
240
  resample_freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Resample Frequency", value='D')
241
 
242
  output_csv = gr.Textbox(label="Forecast Data (CSV)")
@@ -248,7 +320,7 @@ def create_interface():
248
  btn = gr.Button("Generate Forecast")
249
  btn.click(
250
  fn=full_forecast_pipeline,
251
- inputs=[file_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown],
252
  outputs=[output_csv, output_plot, download_button, error_output]
253
  )
254
  return iface
 
51
  logger.error(f"Error loading data: {e}", exc_info=True)
52
  raise ValueError(f"Error loading data: {e}")
53
 
54
+ def forecast_nixtla(df, forecast_horizon, finetune_steps, freq, time_col, target_col):
55
  """
56
  Function to call the Nixtla API directly.
57
  """
 
61
  df=df,
62
  h=forecast_horizon,
63
  finetune_steps=finetune_steps,
64
+ time_col=time_col,
65
+ target_col=target_col,
66
  freq=freq
67
  )
68
  logger.info("Nixtla API call successful")
 
157
  )
158
  return fig
159
 
160
+ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data, forecast_start_date, forecast_end_date) -> Tuple[str, object, str, str]:
161
  """
162
  Full pipeline: loads the data, calls the forecast function, and then processes the data.
163
  """
 
167
  return "Error loading data. Please check the file format and content.", None, None, None
168
 
169
  # Convert time column to datetime
170
+ data[time_col] = pd.to_datetime(data[time_col])
171
 
172
  # Sort the DataFrame by the time column
173
+ data = data.sort_values(by=time_col)
174
 
175
  # Get min and max dates from the data
176
+ min_date = data[time_col].min().strftime('%Y-%m-%d')
177
+ max_date = data[time_col].max().strftime('%Y-%m-%d')
178
 
179
  # Fill missing values with 0
180
  data = data.fillna(0)
 
183
  if start_date and end_date:
184
  start_datetime = pd.to_datetime(start_date)
185
  end_datetime = pd.to_datetime(end_date)
186
+ data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
187
  logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
188
 
189
+ data = data.set_index(time_col)
190
 
191
  # Resample the data
192
  data = data.resample(resample_freq).mean()
193
  data.reset_index(inplace=True)
194
 
195
+ if forecast_start_date and forecast_end_date:
196
+ forecast_start_datetime = pd.to_datetime(forecast_start_date)
197
+ forecast_end_datetime = pd.to_datetime(forecast_end_date)
198
+
199
+ # Calculate the time difference in days
200
+ time_difference = (forecast_end_datetime - forecast_start_datetime).days
201
+
202
+ # Adjust forecast horizon based on frequency
203
+ if freq == 'D':
204
+ forecast_horizon = time_difference
205
+ elif freq == 'W':
206
+ forecast_horizon = time_difference / 7
207
+ elif freq == 'M':
208
+ forecast_horizon = time_difference / 30 # Approximation
209
+ elif freq == 'Y':
210
+ forecast_horizon = time_difference / 365 # Approximation
211
+ elif 'min' in freq:
212
+ minutes = int(freq.replace('min', ''))
213
+ forecast_horizon = time_difference * 24 * 60 / minutes
214
+ elif 'H' in freq:
215
+ hours = int(freq.replace('H', ''))
216
+ forecast_horizon = time_difference * 24 / hours
217
+ else:
218
+ raise ValueError("Unsupported frequency. Please select a valid frequency.")
219
+
220
+ forecast_horizon = int(forecast_horizon) # Convert to integer
221
+
222
+ # Generate complete date range
223
+ start_datetime = data[time_col].min()
224
+ end_datetime = data[time_col].max()
225
+ complete_date_range = pd.date_range(start=start_datetime, end=end_datetime, freq=resample_freq)
226
+
227
+ # Reindex the data
228
+ data = data.set_index(time_col)
229
+ data = data.reindex(complete_date_range)
230
+ data = data.fillna(0)
231
+ data = data.reset_index()
232
+ data = data.rename(columns={'index': time_col})
233
+
234
+ forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
235
+ processed_data = process_forecast_data(forecast_result, time_col)
236
+ processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)
237
 
238
  if merge_data:
239
+ merged_data = pd.merge(data.reset_index(), processed_data, on=time_col, how='inner')
240
  else:
241
  merged_data = processed_data
242
 
243
+ plot = create_plot(data, processed_data, time_col, target_col)
244
  csv_data = processed_data.to_csv(index=False)
245
 
246
  # Create a temporary file and write the CSV data to it
 
256
  logger.exception("An unexpected error occurred:")
257
  return f"Error: An unexpected error occurred: {e}", None, None, None
258
 
259
def get_column_names(file_obj):
    """
    Extracts column names from the uploaded file.

    Args:
        file_obj: The uploaded file object; passed unchanged to ``load_data``.

    Returns:
        The column labels of the loaded DataFrame as a list, or an empty list
        when no file was supplied or the file could not be loaded.
    """
    # No file supplied: skip the load attempt instead of logging a spurious
    # traceback for an input that can never succeed.
    if file_obj is None:
        return []

    try:
        df = load_data(file_obj)
        columns = df.columns.tolist()
    except Exception as e:
        # Best-effort helper: report the failure through the module logger and
        # fall back to "no columns" so the caller still gets a list back.
        logger.error(f"Error in get_column_names: {e}", exc_info=True)
        return []

    # Use the module logger rather than print() so this output follows the same
    # handlers and levels as the rest of the file's diagnostics.
    logger.info(f"Column names: {columns}")
    return columns
272
+
273
def update_dropdown_choices(file_obj):
    """
    Updates the dropdown choices based on the uploaded file.

    Args:
        file_obj: The uploaded file object; forwarded to ``get_column_names``.

    Returns:
        A tuple of two component updates, each setting ``choices`` to the
        column names found in the file (an empty list if they could not be
        read). The two updates are meant for two separate output components.
    """
    try:
        columns = get_column_names(file_obj)
    except Exception as e:
        logger.error(f"Error updating dropdown choices: {e}", exc_info=True)
        columns = []

    # gr.update() is the component-agnostic update helper. The per-component
    # gr.Dropdown.update() classmethod was removed in Gradio 4, so this form
    # keeps the callback working on both older and newer Gradio releases.
    return gr.update(choices=columns), gr.update(choices=columns)
283
+
284
  def create_interface():
285
  with gr.Blocks() as iface:
286
  gr.Markdown("""
 
290
 
291
  file_input = gr.File(label="Upload Time Series Data (CSV, Excel, JSON, YAML)")
292
 
293
+ with gr.Row():
294
+ time_col_input = gr.Textbox(label="Time Column", placeholder="Enter time column name")
295
+ target_col_input = gr.Textbox(label="Target Column", placeholder="Enter target column name")
296
+
297
  with gr.Row():
298
  forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10)
299
  finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
 
305
  end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
306
  end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59")
307
 
308
+ with gr.Row():
309
+ forecast_start_date_input = gr.Textbox(label="Forecast Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")
310
+ forecast_end_date_input = gr.Textbox(label="Forecast End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")
311
+
312
  resample_freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Resample Frequency", value='D')
313
 
314
  output_csv = gr.Textbox(label="Forecast Data (CSV)")
 
320
  btn = gr.Button("Generate Forecast")
321
  btn.click(
322
  fn=full_forecast_pipeline,
323
+ inputs=[file_input, time_col_input, target_col_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False), forecast_start_date_input, forecast_end_date_input],
324
  outputs=[output_csv, output_plot, download_button, error_output]
325
  )
326
  return iface