Update app.py
Browse files
app.py
CHANGED
|
@@ -157,7 +157,7 @@ def create_plot(data, forecast_data, time_col, target_col):
|
|
| 157 |
)
|
| 158 |
return fig
|
| 159 |
|
| 160 |
-
def full_forecast_pipeline(file_obj, time_col, target_col,
|
| 161 |
"""
|
| 162 |
Full pipeline: loads the data, calls the forecast function, and then processes the data.
|
| 163 |
"""
|
|
@@ -172,13 +172,6 @@ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, fin
|
|
| 172 |
# Sort the DataFrame by the time column
|
| 173 |
data = data.sort_values(by=time_col)
|
| 174 |
|
| 175 |
-
# Get min and max dates from the data
|
| 176 |
-
min_date = data[time_col].min().strftime('%Y-%m-%d')
|
| 177 |
-
max_date = data[time_col].max().strftime('%Y-%m-%d')
|
| 178 |
-
|
| 179 |
-
# Fill missing values with 0
|
| 180 |
-
data = data.fillna(0)
|
| 181 |
-
|
| 182 |
# Apply date range selection
|
| 183 |
if start_date and end_date:
|
| 184 |
start_datetime = pd.to_datetime(start_date)
|
|
@@ -186,12 +179,40 @@ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, fin
|
|
| 186 |
data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
|
| 187 |
logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
|
| 188 |
|
| 189 |
-
data = data.set_index(time_col)
|
| 190 |
-
|
| 191 |
# Resample the data
|
| 192 |
data = data.resample(resample_freq).mean()
|
| 193 |
data.reset_index(inplace=True)
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
|
| 196 |
processed_data = process_forecast_data(forecast_result, time_col)
|
| 197 |
processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)
|
|
@@ -201,8 +222,13 @@ def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, fin
|
|
| 201 |
else:
|
| 202 |
merged_data = processed_data
|
| 203 |
|
| 204 |
-
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
# Create a temporary file and write the CSV data to it
|
| 208 |
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
|
|
@@ -277,7 +303,7 @@ def create_interface():
|
|
| 277 |
btn = gr.Button("Generate Forecast")
|
| 278 |
btn.click(
|
| 279 |
fn=full_forecast_pipeline,
|
| 280 |
-
inputs=[file_input, time_col_input, target_col_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False)],
|
| 281 |
outputs=[output_csv, output_plot, download_button, error_output]
|
| 282 |
)
|
| 283 |
return iface
|
|
|
|
| 157 |
)
|
| 158 |
return fig
|
| 159 |
|
| 160 |
+
def full_forecast_pipeline(file_obj, time_col, target_col, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data, forecast_start_date, forecast_end_date) -> Tuple[str, object, str, str]:
|
| 161 |
"""
|
| 162 |
Full pipeline: loads the data, calls the forecast function, and then processes the data.
|
| 163 |
"""
|
|
|
|
| 172 |
# Sort the DataFrame by the time column
|
| 173 |
data = data.sort_values(by=time_col)
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
# Apply date range selection
|
| 176 |
if start_date and end_date:
|
| 177 |
start_datetime = pd.to_datetime(start_date)
|
|
|
|
| 179 |
data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
|
| 180 |
logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")
|
| 181 |
|
|
|
|
|
|
|
| 182 |
# Resample the data
|
| 183 |
data = data.resample(resample_freq).mean()
|
| 184 |
data.reset_index(inplace=True)
|
| 185 |
|
| 186 |
+
# Convert forecast start and end dates to datetime
|
| 187 |
+
if forecast_start_date and forecast_end_date:
|
| 188 |
+
forecast_start_datetime = pd.to_datetime(forecast_start_date)
|
| 189 |
+
forecast_end_datetime = pd.to_datetime(forecast_end_date)
|
| 190 |
+
|
| 191 |
+
# Calculate the time difference
|
| 192 |
+
time_difference = forecast_end_datetime - forecast_start_datetime
|
| 193 |
+
|
| 194 |
+
# Calculate forecast horizon based on frequency
|
| 195 |
+
if freq == 'D':
|
| 196 |
+
forecast_horizon = time_difference.days
|
| 197 |
+
elif freq == 'W':
|
| 198 |
+
forecast_horizon = time_difference.days / 7
|
| 199 |
+
elif freq == 'M':
|
| 200 |
+
forecast_horizon = time_difference.days / 30.44 # Average days in a month
|
| 201 |
+
elif freq == 'Y':
|
| 202 |
+
forecast_horizon = time_difference.days / 365.25 # Average days in a year
|
| 203 |
+
elif 'min' in freq:
|
| 204 |
+
minutes = int(freq.replace('min', ''))
|
| 205 |
+
forecast_horizon = time_difference.total_seconds() / (minutes * 60)
|
| 206 |
+
elif 'H' in freq:
|
| 207 |
+
hours = int(freq.replace('H', ''))
|
| 208 |
+
forecast_horizon = time_difference.total_seconds() / (hours * 3600)
|
| 209 |
+
else:
|
| 210 |
+
raise ValueError("Unsupported frequency. Please select a valid frequency.")
|
| 211 |
+
|
| 212 |
+
forecast_horizon = int(forecast_horizon)
|
| 213 |
+
else:
|
| 214 |
+
raise ValueError("Forecast start and end dates must be provided.")
|
| 215 |
+
|
| 216 |
forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
|
| 217 |
processed_data = process_forecast_data(forecast_result, time_col)
|
| 218 |
processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)
|
|
|
|
| 222 |
else:
|
| 223 |
merged_data = processed_data
|
| 224 |
|
| 225 |
+
# Filter forecast data based on forecast start and end dates
|
| 226 |
+
merged_data[time_col] = pd.to_datetime(merged_data[time_col]) # Ensure time_col is datetime
|
| 227 |
+
merged_data = merged_data[(merged_data[time_col] >= forecast_start_datetime) & (merged_data[time_col] <= forecast_end_datetime)]
|
| 228 |
+
logger.info(f"Forecast data filtered from {forecast_start_datetime} to {forecast_end_datetime}. Shape: {merged_data.shape}")
|
| 229 |
+
|
| 230 |
+
plot = create_plot(data, merged_data, time_col, target_col)
|
| 231 |
+
csv_data = merged_data.to_csv(index=False)
|
| 232 |
|
| 233 |
# Create a temporary file and write the CSV data to it
|
| 234 |
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
|
|
|
|
| 303 |
btn = gr.Button("Generate Forecast")
|
| 304 |
btn.click(
|
| 305 |
fn=full_forecast_pipeline,
|
| 306 |
+
inputs=[file_input, time_col_input, target_col_input, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False), gr.Textbox(label="Forecast Start Date", placeholder="YYYY-MM-DD", value="2023-01-01"), gr.Textbox(label="Forecast End Date", placeholder="YYYY-MM-DD", value="2023-12-31")],
|
| 307 |
outputs=[output_csv, output_plot, download_button, error_output]
|
| 308 |
)
|
| 309 |
return iface
|