anujkum0x commited on
Commit
4d6fc82
·
verified ·
1 Parent(s): 42be1b5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +287 -0
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import plotly.graph_objects as go
4
+ from plotly.subplots import make_subplots
5
+ import io
6
+ import os
7
+ import numpy as np
8
+ import yaml
9
+ import logging
10
+ import json
11
+ import csv
12
+ from datetime import datetime
13
+ from plotly.colors import n_colors
14
+ from nixtla import NixtlaClient
15
+ import tempfile
16
+ from typing import Tuple
17
+ from datetime import date
18
+ from datetime import time
19
+
20
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize NixtlaClient.
# SECURITY FIX: the previous revision hard-coded a live API key in source
# control. Read it from the environment instead — and rotate the leaked key.
nixtla_client = NixtlaClient(api_key=os.environ.get("NIXTLA_API_KEY", ""))
26
+
27
+ # --- Utility Functions ---
28
def load_data(file_obj):
    """Load a DataFrame from an uploaded CSV / Excel / JSON / YAML file.

    Args:
        file_obj: Uploaded file object exposing a ``name`` attribute holding
            the on-disk path (as provided by ``gr.File``).

    Returns:
        pd.DataFrame with the file's contents.

    Raises:
        ValueError: If the extension is unsupported or parsing fails.
    """
    try:
        path = file_obj.name
        # Case-insensitive extension match so e.g. "DATA.CSV" also loads
        # (the original endswith() checks rejected upper-case extensions).
        ext = os.path.splitext(path)[1].lower()
        if ext == '.csv':
            df = pd.read_csv(path)
        elif ext in ('.xlsx', '.xls'):
            df = pd.read_excel(path)
        elif ext == '.json':
            df = pd.read_json(path)
        elif ext in ('.yaml', '.yml'):
            with open(path, 'r') as f:
                data = yaml.safe_load(f)
            df = pd.DataFrame(data)
        else:
            raise ValueError("Unsupported file format")
        print("DataFrame loaded successfully:")
        print(df)
        return df

    except Exception as e:
        logger.error(f"Error loading data: {e}", exc_info=True)
        raise ValueError(f"Error loading data: {e}")
53
+
54
def forecast_nixtla(df, forecast_horizon, finetune_steps, freq, time_col, target_col):
    """Request a forecast from the Nixtla TimeGPT API.

    Args:
        df: Historical data as a DataFrame.
        forecast_horizon: Number of future periods to predict.
        finetune_steps: Fine-tuning iterations for the model.
        freq: Pandas frequency string of the series (e.g. 'D', 'H').
        time_col: Name of the timestamp column.
        target_col: Name of the value column.

    Returns:
        The forecast object returned by the Nixtla client.

    Raises:
        ValueError: Wrapping any failure of the API call.
    """
    try:
        prediction = nixtla_client.forecast(
            df=df,
            h=forecast_horizon,
            finetune_steps=finetune_steps,
            time_col=time_col,
            target_col=target_col,
            freq=freq,
        )
        logger.info("Nixtla API call successful")
        return prediction

    except Exception as e:
        logger.error(f"Error communicating with the forecasting API: {e}", exc_info=True)
        raise ValueError(f"Error communicating with the forecasting API: {e}")
74
+
75
def process_forecast_data(forecast_data, time_col) -> pd.DataFrame:
    """Normalize the raw forecast into a DataFrame with readable timestamps.

    Args:
        forecast_data: Forecast payload convertible to a DataFrame.
        time_col: Name of the timestamp column to reformat.

    Returns:
        DataFrame whose time column is formatted as 'YYYY-MM-DD HH:MM:SS'.

    Raises:
        ValueError: Wrapping any conversion/formatting failure.
    """
    try:
        result = pd.DataFrame(forecast_data)
        timestamps = pd.to_datetime(result[time_col])
        result[time_col] = timestamps.dt.strftime('%Y-%m-%d %H:%M:%S')
        return result

    except Exception as e:
        logger.error(f"Error processing forecast data: {e}", exc_info=True)
        raise ValueError(f"Error processing forecast data: {e}")
88
+
89
def apply_zero_patterns(df: pd.DataFrame, forecast_df: pd.DataFrame, time_col: str, target_col: str) -> pd.DataFrame:
    """Zero out forecast values in periods that are historically (near-)zero.

    Learns per-hour and per-weekday mean levels from the history and sets the
    forecast to 0 wherever either mean is below 1; every other forecast value
    is clamped to be non-negative.

    Args:
        df: Historical data (copied internally; the caller's frame is no
            longer mutated with helper columns).
        forecast_df: Forecast data; modified in place and returned.
        time_col: Timestamp column name shared by both frames.
        target_col: Value column name in ``df``.

    Returns:
        ``forecast_df`` with the zero patterns applied.

    Raises:
        ValueError: If the pattern application fails; the forecast column
            (when already identified) is zeroed as a safe fallback first.
    """
    # BUG FIX: pre-bind so the except-branch fallback never hits an unbound
    # name — the original raised NameError (masking the real error) when the
    # failure happened before the value column had been identified.
    forecast_value_col = None
    try:
        df = df.copy()
        df[time_col] = pd.to_datetime(df[time_col])
        forecast_df[time_col] = pd.to_datetime(forecast_df[time_col])

        # Average level for each hour-of-day and day-of-week in the history.
        hours = df[time_col].dt.hour
        weekdays = df[time_col].dt.dayofweek  # Monday=0, Sunday=6
        hourly_avg = df.groupby(hours)[target_col].mean()
        daily_avg = df.groupby(weekdays)[target_col].mean()

        # First non-time column is assumed to hold the forecast values.
        forecast_value_col = [col for col in forecast_df.columns if col != time_col][0]

        # After pd.to_datetime every entry is a Timestamp (or NaT, which the
        # .dt accessor maps to NaN), so vectorized accessors replace the
        # original per-row isinstance() apply.
        forecast_df['hour'] = forecast_df[time_col].dt.hour
        forecast_df['dayofweek'] = forecast_df[time_col].dt.dayofweek
        forecast_df = forecast_df.dropna(subset=['hour', 'dayofweek'])

        # Zero the forecast where history says the slot is dormant; otherwise
        # clamp model-produced negatives to zero.
        forecast_df[forecast_value_col] = forecast_df.apply(
            lambda row: 0 if hourly_avg[row['hour']] < 1 or daily_avg[row['dayofweek']] < 1 else max(0, row[forecast_value_col]),
            axis=1
        )
        forecast_df.drop(columns=['hour', 'dayofweek'], inplace=True)
        return forecast_df
    except Exception as e:
        if forecast_value_col is not None:
            forecast_df[[forecast_value_col]] = 0
        logger.error(f"Error applying zero patterns: {e}", exc_info=True)
        raise ValueError(f"Error applying zero patterns: {e}")
126
+
127
def create_plot(data, forecast_data, time_col, target_col):
    """Build a Plotly figure overlaying the history and the forecast.

    Args:
        data: Historical DataFrame.
        forecast_data: Forecast DataFrame, or None to plot history alone.
        time_col: Timestamp column name.
        target_col: Value column name in ``data``.

    Returns:
        A plotly ``go.Figure`` with unified-hover layout.
    """
    figure = go.Figure()

    # Historical series.
    figure.add_trace(go.Scatter(
        x=data[time_col],
        y=data[target_col],
        mode='lines',
        name='Historical Data'
    ))

    # Forecast series — the first non-time column holds the predictions.
    if forecast_data is not None:
        value_col = [col for col in forecast_data.columns if col != time_col][0]
        figure.add_trace(go.Scatter(
            x=forecast_data[time_col],
            y=forecast_data[value_col],
            mode='lines',
            name='Forecast'
        ))

    figure.update_layout(
        title='Time Series Data and Forecast',
        xaxis_title='Time',
        yaxis_title='Value',
        template='plotly_white',
        hovermode="x unified"
    )
    return figure
159
+
160
def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data=False) -> Tuple[str, object, str, str]:
    """Full pipeline: load the data, forecast it, and post-process the result.

    Args:
        file_obj: Uploaded file (CSV/Excel/JSON/YAML).
        time_col: Timestamp column name.
        target_col: Value column name.
        forecast_horizon: Periods to forecast (coerced to int; gr.Number yields floats).
        finetune_steps: Fine-tune iterations (coerced to int).
        freq: Frequency string passed to the Nixtla API.
        start_date: Optional 'YYYY-MM-DD' lower bound on the history.
        end_date: Optional 'YYYY-MM-DD' upper bound on the history.
        start_time: Accepted for UI symmetry; currently unused.
        end_time: Accepted for UI symmetry; currently unused.
        resample_freq: Pandas frequency the history is resampled to (mean).
        merge_data: When True, inner-join the history onto the forecast.
            BUG FIX: defaults to False — the UI wires only eleven inputs to
            this function, so without a default every click raised TypeError.

    Returns:
        Tuple of (forecast HTML table, plotly figure, CSV file path, error
        text); on failure the first element carries the message, rest None.
    """
    try:
        data = load_data(file_obj)
        if not isinstance(data, pd.DataFrame):
            return "Error loading data. Please check the file format and content.", None, None, None

        if time_col not in data.columns or target_col not in data.columns:
            return "Error: Timestamp column or Value column not found in the data.", None, None, None

        # Normalize and order the series chronologically.
        data[time_col] = pd.to_datetime(data[time_col])
        data = data.sort_values(by=time_col)

        # Fill missing values with 0.
        data = data.fillna(0)

        # Restrict to the requested date window, when one is given.
        if start_date and end_date:
            start_datetime = pd.to_datetime(start_date)
            end_datetime = pd.to_datetime(end_date)
            data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
            logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")

        # Resample onto a regular grid (mean per bucket), then restore the
        # time column. NOTE(review): .mean() assumes all remaining columns
        # are numeric — confirm for files with extra text columns.
        data = data.set_index(time_col)
        data = data.resample(resample_freq).mean()
        data.reset_index(inplace=True)

        # gr.Number delivers floats; the API expects integer counts.
        forecast_result = forecast_nixtla(data, int(forecast_horizon), int(finetune_steps), freq, time_col, target_col)
        processed_data = process_forecast_data(forecast_result, time_col)
        processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)

        if merge_data:
            # data already has a flat RangeIndex here, so merge it directly —
            # the original's extra reset_index() injected a spurious 'index'
            # column into the merge result.
            merged_data = pd.merge(data, processed_data, on=time_col, how='inner')
        else:
            merged_data = processed_data

        plot = create_plot(data, processed_data, time_col, target_col)
        csv_data = processed_data.to_csv(index=False)

        # Persist the CSV so gr.File can serve it for download.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
            tmpfile.write(csv_data)
            csv_path = tmpfile.name

        return processed_data.to_html(index=False), plot, csv_path, None

    except ValueError as e:
        return f"Error: {e}", None, None, None
    except Exception as e:
        logger.exception("An unexpected error occurred:")
        return f"Error: An unexpected error occurred: {e}", None, None, None
222
+
223
def get_column_names(file_obj):
    """Return the list of column names found in the uploaded file.

    Falls back to an empty list when the file cannot be parsed, so the UI
    dropdowns simply stay empty instead of surfacing an error to the user.
    """
    try:
        frame = load_data(file_obj)
        names = list(frame.columns)
        print(f"Column names: {names}")
        return names
    except Exception as e:
        logger.error(f"Error in get_column_names: {e}", exc_info=True)
        print(f"Error in get_column_names: {e}")
        return []
236
+
237
def create_interface():
    """Build the Gradio Blocks UI for uploading data and generating forecasts."""
    with gr.Blocks() as iface:
        gr.Markdown("""
        # CP360 App
        Upload your time series data, select the appropriate columns, and generate a forecast!
        """)

        file_input = gr.File(label="Upload Time Series Data (CSV, Excel, JSON, YAML)")
        with gr.Row():
            time_col_dropdown = gr.Dropdown(choices=[], label="Select Timestamp Column")
            target_col_dropdown = gr.Dropdown(choices=[], label="Select Value Column")

        def update_dropdown_choices(file_obj):
            # Populate both column selectors from the uploaded file's header.
            columns = get_column_names(file_obj)
            return gr.update(choices=columns), gr.update(choices=columns)

        file_input.upload(
            update_dropdown_choices,
            [file_input],
            [time_col_dropdown, target_col_dropdown]
        )

        with gr.Row():
            forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10)
            finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
            freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Frequency", value='D')

        with gr.Row():
            start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-01-01")
            start_time_input = gr.Textbox(label="Start Time (HH:MM)", placeholder="HH:MM", value="00:00")
            end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
            end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59")

        resample_freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Resample Frequency", value='D')

        # BUG FIX: full_forecast_pipeline takes a trailing merge_data argument,
        # but the click handler previously wired only eleven inputs, so every
        # click raised a TypeError. Expose merge_data as a checkbox and pass it.
        merge_data_checkbox = gr.Checkbox(label="Merge historical data with forecast", value=False)

        output_html = gr.HTML(label="Forecast Data")
        output_plot = gr.Plot(label="Time Series Plot")
        download_button = gr.File(label="Download Forecast Data as CSV")
        error_output = gr.Markdown(label="Error Messages")

        # Button to trigger the full pipeline
        btn = gr.Button("Generate Forecast")
        btn.click(
            fn=full_forecast_pipeline,
            inputs=[file_input, time_col_dropdown, target_col_dropdown, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, merge_data_checkbox],
            outputs=[output_html, output_plot, download_button, error_output]
        )
    return iface
285
+
286
# Launch only when run as a script, so importing this module (e.g. in tests)
# does not start a web server as a side effect.
if __name__ == "__main__":
    iface = create_interface()
    iface.launch()