import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import io
import os
import numpy as np
import yaml
import logging
import json
import csv
from datetime import datetime
from plotly.colors import n_colors
from nixtla import NixtlaClient
import tempfile
from typing import Tuple
from datetime import date
from datetime import time
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the Nixtla client.
# SECURITY: the API key was previously hard-coded here. Prefer the
# NIXTLA_API_KEY environment variable so the secret stays out of source
# control; the literal below is kept only as a backward-compatible fallback
# and should be rotated/removed.
nixtla_client = NixtlaClient(
    api_key=os.environ.get(
        'NIXTLA_API_KEY',
        'nixak-IzAtInwxiZNzvbdatMlOlak0IK6aLlUTJAvbQvnUzYSc45xuQHjqtMyOFYhg2IRIMphbFV3qGBYZbbvr',
    )
)
# --- Utility Functions ---
def load_data(file_obj):
    """
    Load a time-series dataset from an uploaded file into a DataFrame.

    Supported formats, detected by file extension (case-insensitive):
    CSV, Excel (.xlsx/.xls), JSON, and YAML.

    Args:
        file_obj: An object with a ``name`` attribute holding the file path
            (e.g. a Gradio ``File`` upload).

    Returns:
        pd.DataFrame: The loaded data.

    Raises:
        ValueError: If the format is unsupported or parsing fails.
    """
    try:
        path = file_obj.name
        # Lower-case once so '.CSV', '.Xlsx', etc. are accepted too.
        lowered = path.lower()
        if lowered.endswith('.csv'):
            df = pd.read_csv(path)
        elif lowered.endswith(('.xlsx', '.xls')):
            df = pd.read_excel(path)
        elif lowered.endswith('.json'):
            df = pd.read_json(path)
        elif lowered.endswith(('.yaml', '.yml')):
            # safe_load avoids executing arbitrary tags from untrusted files.
            with open(path, 'r') as f:
                data = yaml.safe_load(f)
            df = pd.DataFrame(data)
        else:
            raise ValueError("Unsupported file format")
        # Log via the module logger instead of print-debugging.
        logger.info("DataFrame loaded successfully (shape=%s)", df.shape)
        return df
    except Exception as e:
        logger.error(f"Error loading data: {e}", exc_info=True)
        raise ValueError(f"Error loading data: {e}")
def forecast_nixtla(df, forecast_horizon, finetune_steps, freq, time_col, target_col):
    """
    Request a forecast from the Nixtla API for the given history.

    Args:
        df: Historical data containing ``time_col`` and ``target_col``.
        forecast_horizon: Number of future steps to predict.
        finetune_steps: Finetune iterations passed to the API.
        freq: Frequency code of the series (e.g. 'H', 'D').
        time_col: Name of the timestamp column.
        target_col: Name of the target column.

    Returns:
        The forecast returned by ``NixtlaClient.forecast``.

    Raises:
        ValueError: If the API call fails for any reason.
    """
    try:
        prediction = nixtla_client.forecast(
            df=df,
            h=forecast_horizon,
            finetune_steps=finetune_steps,
            time_col=time_col,
            target_col=target_col,
            freq=freq,
        )
    except Exception as e:
        logger.error(f"Error communicating with the forecasting API: {e}", exc_info=True)
        raise ValueError(f"Error communicating with the forecasting API: {e}")
    logger.info("Nixtla API call successful")
    return prediction
def process_forecast_data(forecast_data, time_col) -> pd.DataFrame:
"""
Process the forecast data to be more human-readable.
"""
try:
forecast_df = pd.DataFrame(forecast_data)
forecast_df[time_col] = pd.to_datetime(forecast_df[time_col])
forecast_df[time_col] = forecast_df[time_col].dt.strftime('%Y-%m-%d %H:%M:%S')
return forecast_df
except Exception as e:
logger.error(f"Error processing forecast data: {e}", exc_info=True)
raise ValueError(f"Error processing forecast data: {e}")
def apply_zero_patterns(df: pd.DataFrame, forecast_df: pd.DataFrame, time_col: str, target_col: str) -> pd.DataFrame:
    """
    Zero out forecast values at hours/weekdays that were (near-)zero historically.

    Learns, from the historical data, the mean target value per hour-of-day
    and per day-of-week. Any forecast timestamp whose hour or weekday had a
    historical mean below 1 is forced to 0; all remaining forecast values are
    clipped at 0 so no negative forecasts survive.

    Args:
        df: Historical data containing ``time_col`` and ``target_col``.
        forecast_df: Forecast with ``time_col`` plus exactly one value column.
        time_col: Name of the timestamp column in both frames.
        target_col: Name of the target column in ``df``.

    Returns:
        pd.DataFrame: A new frame with the zero patterns applied; the input
        frames are left unmodified.

    Raises:
        ValueError: If the pattern application fails.
    """
    try:
        # Work on copies so neither caller-owned frame is mutated.
        history = df[[time_col, target_col]].copy()
        forecast_df = forecast_df.copy()
        history[time_col] = pd.to_datetime(history[time_col])
        forecast_df[time_col] = pd.to_datetime(forecast_df[time_col])

        # Mean target per hour-of-day and per day-of-week (Monday=0 .. Sunday=6).
        hourly_avg = history.groupby(history[time_col].dt.hour)[target_col].mean()
        daily_avg = history.groupby(history[time_col].dt.dayofweek)[target_col].mean()

        # The forecast frame is expected to hold exactly one value column.
        forecast_value_col = [col for col in forecast_df.columns if col != time_col][0]

        hours = forecast_df[time_col].dt.hour
        days = forecast_df[time_col].dt.dayofweek
        # .map yields NaN for hours/days unseen in history; NaN < 1 is False,
        # so unseen periods keep their forecast instead of raising KeyError
        # (the previous row-wise indexing crashed on them).
        zero_mask = (hours.map(hourly_avg) < 1) | (days.map(daily_avg) < 1)

        # Clip negatives to 0, then force historically-dead periods to 0.
        forecast_df[forecast_value_col] = forecast_df[forecast_value_col].clip(lower=0)
        forecast_df.loc[zero_mask, forecast_value_col] = 0
        return forecast_df
    except Exception as e:
        # Note: the previous version referenced a possibly-unbound local and
        # mutated the forecast before re-raising; just log and propagate.
        logger.error(f"Error applying zero patterns: {e}", exc_info=True)
        raise ValueError(f"Error applying zero patterns: {e}")
def create_plot(data, forecast_data, time_col, target_col):
    """
    Build a Plotly figure showing the historical series and, if present,
    the forecast overlaid on the same axes.

    Args:
        data: Historical data containing ``time_col`` and ``target_col``.
        forecast_data: Forecast frame (``time_col`` plus one value column),
            or None to plot history only.
        time_col: Name of the timestamp column.
        target_col: Name of the historical target column.

    Returns:
        plotly.graph_objects.Figure: The assembled figure.
    """
    fig = go.Figure()

    # Historical series.
    fig.add_trace(
        go.Scatter(
            x=data[time_col],
            y=data[target_col],
            mode='lines',
            name='Historical Data',
        )
    )

    # Overlay the forecast when one was produced; its value column is
    # whichever column is not the time column.
    if forecast_data is not None:
        value_cols = [c for c in forecast_data.columns if c != time_col]
        fig.add_trace(
            go.Scatter(
                x=forecast_data[time_col],
                y=forecast_data[value_cols[0]],
                mode='lines',
                name='Forecast',
            )
        )

    fig.update_layout(
        title='Time Series Data and Forecast',
        xaxis_title='Time',
        yaxis_title='Value',
        template='plotly_white',
        hovermode="x unified",
    )
    return fig
def full_forecast_pipeline(file_obj, time_col, target_col, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data, forecast_start_date, forecast_end_date) -> Tuple[str, object, str, str]:
    """
    End-to-end pipeline: load the uploaded file, filter/resample the history,
    call the Nixtla forecasting API, post-process the forecast, and build the
    plot plus a downloadable CSV.

    Args:
        file_obj: Uploaded file object (CSV/Excel/JSON/YAML).
        time_col: Name of the timestamp column.
        target_col: Name of the target column.
        finetune_steps: Finetune steps passed to the Nixtla API.
        freq: Frequency code ('15min', '30min', 'H', 'D', 'W', 'M', 'Y', ...).
        start_date: Optional 'YYYY-MM-DD' lower bound for the history.
        end_date: Optional 'YYYY-MM-DD' upper bound for the history.
        start_time: Currently unused (input is hidden in the UI).
        end_time: Currently unused (input is hidden in the UI).
        resample_freq: Pandas offset alias used to resample the history.
        merge_data: If True, the returned CSV joins history with the forecast.
        forecast_start_date: Optional lower bound for the forecast window.
        forecast_end_date: Optional upper bound; also drives the horizon.

    Returns:
        Tuple of (csv text, Plotly figure, CSV file path, error message or
        None). On failure the first element carries the error text and the
        rest are None.
    """
    try:
        data = load_data(file_obj)
        if not isinstance(data, pd.DataFrame):
            return "Error loading data. Please check the file format and content.", None, None, None

        # Normalize and order the time axis, then fill gaps with 0.
        data[time_col] = pd.to_datetime(data[time_col])
        data = data.sort_values(by=time_col)
        data = data.fillna(0)

        # Optional historical date-range filter.
        if start_date and end_date:
            start_datetime = pd.to_datetime(start_date)
            end_datetime = pd.to_datetime(end_date)
            data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
            logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")

        # Resample the history onto a regular grid, then restore time_col
        # as an ordinary column (leaving a fresh RangeIndex).
        data = data.set_index(time_col).resample(resample_freq).mean()
        data.reset_index(inplace=True)

        # Derive the forecast horizon from the requested forecast end date.
        forecast_horizon = 10  # default when no forecast end date is given
        if forecast_end_date:
            historical_end_date = pd.to_datetime(end_date) if end_date else data[time_col].max()
            forecast_end_datetime = pd.to_datetime(forecast_end_date)
            day_difference = (forecast_end_datetime - historical_end_date).days
            if day_difference <= 0:
                raise ValueError("Forecast end date must be after the historical data end date.")
            # Steps per day for each supported frequency; W/M/Y are
            # approximations (7/30/365 days). Unknown codes default to daily.
            steps_per_day = {
                'H': 24, '30min': 48, '15min': 96, 'D': 1,
                'W': 1 / 7, 'M': 1 / 30, 'Y': 1 / 365,
            }
            forecast_horizon = day_difference * steps_per_day.get(freq, 1)
            # Ensure the horizon is an integer and at least one step.
            forecast_horizon = max(1, int(round(forecast_horizon)))

        forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
        processed_data = process_forecast_data(forecast_result, time_col)
        processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)

        # Optional forecast date-range filter.
        if forecast_start_date and forecast_end_date:
            forecast_start_datetime = pd.to_datetime(forecast_start_date)
            forecast_end_datetime = pd.to_datetime(forecast_end_date)
            processed_data = processed_data[(processed_data[time_col] >= forecast_start_datetime) & (processed_data[time_col] <= forecast_end_datetime)]
            logger.info(f"Forecast data filtered from {forecast_start_datetime} to {forecast_end_datetime}. Shape: {processed_data.shape}")

        # When requested, join the history onto the forecast so the CSV shows
        # both. (Previously the merge result was computed but discarded, and a
        # redundant reset_index() injected a spurious 'index' column.)
        if merge_data:
            output_data = pd.merge(data, processed_data, on=time_col, how='inner')
        else:
            output_data = processed_data

        plot = create_plot(data, processed_data, time_col, target_col)

        csv_data = output_data.to_csv(index=False)
        # Persist the CSV so Gradio can offer it as a file download.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
            tmpfile.write(csv_data)
            csv_path = tmpfile.name

        return csv_data, plot, csv_path, None
    except ValueError as e:
        return f"Error: {e}", None, None, None
    except Exception as e:
        logger.exception("An unexpected error occurred:")
        return f"Error: An unexpected error occurred: {e}", None, None, None
def get_column_names(file_obj):
    """
    Return the list of column names from an uploaded data file.

    Args:
        file_obj: Uploaded file object understood by ``load_data``.

    Returns:
        list: Column names, or an empty list when loading fails.
    """
    try:
        frame = load_data(file_obj)
        columns = frame.columns.tolist()
    except Exception as e:
        logger.error(f"Error in get_column_names: {e}", exc_info=True)
        print(f"Error in get_column_names: {e}")
        return []
    print(f"Column names: {columns}")
    return columns
def update_dropdown_choices(file_obj):
    """
    Build Gradio update payloads refreshing two dropdowns with the columns
    of the uploaded file.

    Args:
        file_obj: Uploaded file object understood by ``load_data``.

    Returns:
        tuple: Two Gradio update payloads (one per dropdown) carrying the
        column-name choices; empty choices on failure.

    Note:
        Uses ``gr.update`` rather than ``gr.Dropdown.update``; the latter
        was removed in Gradio 4, while ``gr.update`` works in 3.x and 4.x.
    """
    try:
        columns = get_column_names(file_obj)
        return gr.update(choices=columns), gr.update(choices=columns)
    except Exception as e:
        logger.error(f"Error updating dropdown choices: {e}", exc_info=True)
        return gr.update(choices=[]), gr.update(choices=[])
def create_interface():
    """
    Assemble the Gradio Blocks UI for the forecasting app.

    Wires the file upload, column/frequency/date controls, and a button that
    runs ``full_forecast_pipeline`` and renders the CSV text, the plot, and a
    downloadable file.

    Returns:
        gr.Blocks: The constructed (but not yet launched) interface.
    """
    with gr.Blocks() as iface:
        gr.Markdown("""
        # CP360 App
        Upload your time series data, select the appropriate columns, and generate a forecast!
        """)
        # Input file; format is detected from the extension by load_data.
        file_input = gr.File(label="Upload Time Series Data (CSV, Excel, JSON, YAML)")
        with gr.Row():
            # Column names are typed manually (not auto-populated from the file).
            time_col_input = gr.Textbox(label="Time Column", placeholder="Enter time column name")
            target_col_input = gr.Textbox(label="Target Column", placeholder="Enter target column name")
        with gr.Row():
            # Horizon is derived from the forecast end date in the pipeline,
            # so this input is hidden and not wired into the click handler.
            forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10, visible=False)  # Hide forecast horizon input
            finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
            freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Frequency", value='D')
        with gr.Column():  # Group date inputs in a column
            with gr.Row():
                start_date_input = gr.Textbox(label="Historical Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-01-01")
                # Time-of-day inputs are hidden; the pipeline ignores them.
                start_time_input = gr.Textbox(label="Start Time (HH:MM)", placeholder="HH:MM", value="00:00", visible=False)  # Hide start time input
            with gr.Row():
                end_date_input = gr.Textbox(label="Historical End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
                end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59", visible=False)  # Hide end time input
            with gr.Row():
                forecast_start_date_input = gr.Textbox(label="Forecast Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")
                forecast_end_date_input = gr.Textbox(label="Forecast End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")
        resample_freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y'], label="Resample Frequency", value='D')
        # Outputs: raw CSV text, interactive plot, downloadable file, errors.
        output_csv = gr.Textbox(label="Forecast Data (CSV)")
        output_plot = gr.Plot(label="Time Series Plot")
        download_button = gr.File(label="Download Forecast Data as CSV")
        error_output = gr.Markdown(label="Error Messages")
        # Button to trigger the full pipeline
        btn = gr.Button("Generate Forecast")
        # NOTE(review): the "Merge Data" checkbox is created inline in the
        # inputs list, so it renders after the button in the layout.
        btn.click(
            fn=full_forecast_pipeline,
            inputs=[file_input, time_col_input, target_col_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False), forecast_start_date_input, forecast_end_date_input],
            outputs=[output_csv, output_plot, download_button, error_output]
        )
    return iface
# Build the interface at import time, but only launch the web server when
# this file is executed as a script — importing the module (e.g. for tests)
# should not start Gradio.
iface = create_interface()

if __name__ == "__main__":
    iface.launch()