| import gradio as gr |
| import pandas as pd |
| import plotly.graph_objects as go |
| from plotly.subplots import make_subplots |
| import io |
| import os |
| import numpy as np |
| import yaml |
| import logging |
| import json |
| import csv |
| from datetime import datetime |
| from plotly.colors import n_colors |
| from nixtla import NixtlaClient |
| import tempfile |
| from typing import Tuple |
| from datetime import date |
| from datetime import time |
|
|
| |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# SECURITY(review): the Nixtla API key used to be hard-coded below. Prefer the
# NIXTLA_API_KEY environment variable; the old literal is kept only as a
# backward-compatible fallback and should be revoked/rotated, since it has
# already been committed to source control.
nixtla_client = NixtlaClient(
    api_key=os.environ.get(
        'NIXTLA_API_KEY',
        'nixak-IzAtInwxiZNzvbdatMlOlak0IK6aLlUTJAvbQvnUzYSc45xuQHjqtMyOFYhg2IRIMphbFV3qGBYZbbvr',
    )
)
|
|
| |
def load_data(file_obj):
    """
    Load a time-series dataset from CSV, Excel, JSON, or YAML into a DataFrame.

    Parameters
    ----------
    file_obj : object with a ``name`` attribute (e.g. a Gradio file upload)
        The file path is read from ``file_obj.name``.

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    ValueError
        If the extension is unsupported or parsing fails.
    """
    try:
        path = file_obj.name
        # Compare extensions case-insensitively so e.g. "DATA.CSV" is accepted.
        ext = os.path.splitext(path)[1].lower()
        if ext == '.csv':
            df = pd.read_csv(path)
        elif ext in ('.xlsx', '.xls'):
            df = pd.read_excel(path)
        elif ext == '.json':
            df = pd.read_json(path)
        elif ext in ('.yaml', '.yml'):
            with open(path, 'r') as f:
                df = pd.DataFrame(yaml.safe_load(f))
        else:
            raise ValueError("Unsupported file format")
        # Debug prints of the full DataFrame removed; callers log what they need.
        return df

    except Exception as e:
        logger.error(f"Error loading data: {e}", exc_info=True)
        raise ValueError(f"Error loading data: {e}")
|
|
def forecast_nixtla(df, forecast_horizon, finetune_steps, freq, time_col, target_col):
    """
    Thin wrapper around the hosted Nixtla forecasting API.

    Forwards the historical frame plus the horizon/fine-tuning settings and
    returns whatever the client produces; API failures are re-raised as
    ValueError so the pipeline can surface them to the UI.
    """
    try:
        result = nixtla_client.forecast(
            df=df,
            h=forecast_horizon,
            finetune_steps=finetune_steps,
            time_col=time_col,
            target_col=target_col,
            freq=freq,
        )
    except Exception as e:
        logger.error(f"Error communicating with the forecasting API: {e}", exc_info=True)
        raise ValueError(f"Error communicating with the forecasting API: {e}")
    else:
        logger.info("Nixtla API call successful")
        return result
|
|
| def process_forecast_data(forecast_data, time_col) -> pd.DataFrame: |
| """ |
| Process the forecast data to be more human-readable. |
| """ |
| try: |
| forecast_df = pd.DataFrame(forecast_data) |
| forecast_df[time_col] = pd.to_datetime(forecast_df[time_col]) |
| forecast_df[time_col] = forecast_df[time_col].dt.strftime('%Y-%m-%d %H:%M:%S') |
| return forecast_df |
|
|
| except Exception as e: |
| logger.error(f"Error processing forecast data: {e}", exc_info=True) |
| raise ValueError(f"Error processing forecast data: {e}") |
|
|
def apply_zero_patterns(df: pd.DataFrame, forecast_df: pd.DataFrame, time_col: str, target_col: str) -> pd.DataFrame:
    """
    Zero out forecast values at hours/weekdays that are historically ~zero.

    If the historical mean for a forecast row's hour-of-day or day-of-week is
    below 1, that forecast value is forced to 0; otherwise the value is
    clamped to be non-negative.

    Note: mutates ``df`` (adds/parses helper columns) — callers pass a copy.

    Parameters
    ----------
    df : pd.DataFrame
        Historical data containing ``time_col`` and ``target_col``.
    forecast_df : pd.DataFrame
        Forecast with ``time_col`` plus exactly one value column.

    Returns
    -------
    pd.DataFrame
        The forecast with the zero pattern applied.

    Raises
    ------
    ValueError
        If the pattern extraction/application fails.
    """
    try:
        df[time_col] = pd.to_datetime(df[time_col])
        forecast_df[time_col] = pd.to_datetime(forecast_df[time_col])

        df['hour'] = df[time_col].dt.hour
        df['dayofweek'] = df[time_col].dt.dayofweek

        # Historical mean target per hour-of-day and per day-of-week.
        hourly_avg = df.groupby('hour')[target_col].mean()
        daily_avg = df.groupby('dayofweek')[target_col].mean()

        # The forecast frame carries exactly one non-time column.
        forecast_value_col = [col for col in forecast_df.columns if col != time_col][0]

        # Vectorized .dt accessors replace the old per-row apply/isinstance check
        # (after pd.to_datetime the values are Timestamps, so .dt always works).
        forecast_df['hour'] = forecast_df[time_col].dt.hour
        forecast_df['dayofweek'] = forecast_df[time_col].dt.dayofweek
        forecast_df = forecast_df.dropna(subset=['hour', 'dayofweek'])

        # Hours/weekdays absent from history map to NaN, which compares False,
        # so such rows are left un-zeroed instead of raising (old KeyError path).
        zero_mask = (forecast_df['hour'].map(hourly_avg) < 1) | (forecast_df['dayofweek'].map(daily_avg) < 1)
        clamped = forecast_df[forecast_value_col].clip(lower=0)
        forecast_df[forecast_value_col] = clamped.where(~zero_mask, 0)

        forecast_df.drop(columns=['hour', 'dayofweek'], inplace=True)
        return forecast_df
    except Exception as e:
        # Old code assigned forecast_df[[forecast_value_col]] = 0 here, which
        # raised NameError when the failure happened before that name was
        # bound, masking the real error — removed.
        logger.error(f"Error applying zero patterns: {e}", exc_info=True)
        raise ValueError(f"Error applying zero patterns: {e}")
|
|
def create_plot(data, forecast_data, time_col, target_col):
    """
    Build a Plotly figure with the historical series and, when available,
    the forecast overlaid on a shared time axis.
    """
    fig = go.Figure()

    # Historical series.
    fig.add_trace(go.Scatter(
        x=data[time_col],
        y=data[target_col],
        mode='lines',
        name='Historical Data',
    ))

    # Forecast overlay (the forecast frame has one non-time value column).
    if forecast_data is not None:
        value_col = next(col for col in forecast_data.columns if col != time_col)
        fig.add_trace(go.Scatter(
            x=forecast_data[time_col],
            y=forecast_data[value_col],
            mode='lines',
            name='Forecast',
        ))

    fig.update_layout(
        title='Time Series Data and Forecast',
        xaxis_title='Time',
        yaxis_title='Value',
        template='plotly_white',
        hovermode="x unified",
    )
    return fig
|
|
def full_forecast_pipeline(file_obj, time_col, target_col, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data, forecast_start_date=None, forecast_end_date=None) -> Tuple[str, object, str, str]:
    """
    Full pipeline: loads the data, filters/resamples it, calls the forecast
    function, post-processes the forecast, and renders the outputs.

    Returns
    -------
    Tuple[str, object, str, str]
        (csv_text, plotly_figure, csv_file_path, error_message). On failure
        the first element carries the error string and the rest are None.
    """
    # Forecast steps per calendar day for each supported frequency; used to
    # convert a day span into a horizon expressed in `freq` steps.
    periods_per_day = {
        'h': 24, '2h': 12, '3h': 8, '4h': 6, '5h': 4.8, '6h': 4, '12h': 2,
        '30min': 48, '15min': 96, 'D': 1, 'W': 1 / 7, 'M': 1 / 30, 'Y': 1 / 365,
    }
    try:
        data = load_data(file_obj)
        if not isinstance(data, pd.DataFrame):
            return "Error loading data. Please check the file format and content.", None, None, None

        data[time_col] = pd.to_datetime(data[time_col])
        data = data.sort_values(by=time_col)

        min_date = data[time_col].min()
        max_date = data[time_col].max()

        # Default the historical window to the full extent of the data.
        if not start_date:
            start_date = min_date.strftime('%Y-%m-%d')
            logger.info(f"start_date not provided, defaulting to min_date from data: {start_date}")
        if not end_date:
            end_date = max_date.strftime('%Y-%m-%d')
            logger.info(f"end_date not provided, defaulting to max_date from data: {end_date}")

        # NOTE(review): start_time/end_time are normalized but never used
        # below — confirm whether intraday windowing was intended.
        if not start_time:
            start_time = "00:00"
        if not end_time:
            end_time = "23:59"

        # Default forecast window: starts the day after history ends, 10 days long.
        if forecast_start_date is None:
            forecast_start_date = (pd.to_datetime(end_date) + pd.Timedelta(days=1)).strftime('%Y-%m-%d')
            logger.info(f"forecast_start_date not provided, defaulting to day after end_date: {forecast_start_date}")
        if forecast_end_date is None:
            default_forecast_horizon_days = 10
            forecast_end_date = (pd.to_datetime(forecast_start_date) + pd.Timedelta(days=default_forecast_horizon_days)).strftime('%Y-%m-%d')
            logger.info(f"forecast_end_date not provided, defaulting to {default_forecast_horizon_days} days after forecast_start_date: {forecast_end_date}")

        data = data.fillna(0)

        # Restrict history to the requested window.
        start_datetime = pd.to_datetime(start_date)
        end_datetime = pd.to_datetime(end_date)
        data = data[(data[time_col] >= start_datetime) & (data[time_col] <= end_datetime)]
        logger.info(f"Data filtered from {start_datetime} to {end_datetime}. Shape: {data.shape}")

        # Resample onto a regular grid. numeric_only=True avoids a TypeError
        # on non-numeric columns under pandas >= 2 (older pandas dropped them).
        data = data.set_index(time_col)
        data = data.resample(resample_freq).mean(numeric_only=True)
        data.reset_index(inplace=True)

        # Convert the day span into forecast steps for the chosen frequency
        # (dict lookup replaces the old 13-branch if/elif chain).
        forecast_horizon = 10
        if forecast_end_date:
            historical_end_date = pd.to_datetime(end_date) if end_date else data[time_col].max()
            forecast_end_datetime = pd.to_datetime(forecast_end_date)
            day_difference = (forecast_end_datetime - historical_end_date).days
            if day_difference <= 0:
                raise ValueError("Forecast end date must be after the historical data end date.")
            forecast_horizon = day_difference * periods_per_day.get(freq, 1)
            forecast_horizon = max(1, int(round(forecast_horizon)))

        forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
        processed_data = process_forecast_data(forecast_result, time_col)
        processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)

        # Keep only the requested forecast window.
        forecast_start_datetime = pd.to_datetime(forecast_start_date)
        forecast_end_datetime = pd.to_datetime(forecast_end_date)
        processed_data = processed_data[(processed_data[time_col] >= forecast_start_datetime) & (processed_data[time_col] <= forecast_end_datetime)]
        logger.info(f"Forecast data filtered from {forecast_start_datetime} to {forecast_end_date}. Shape: {processed_data.shape}")

        if merge_data:
            # `data` already has time_col as a plain column (reset_index above);
            # the old extra data.reset_index() injected a spurious 'index'
            # column into the merge result.
            output_data = pd.merge(data, processed_data, on=time_col, how='inner')
        else:
            output_data = processed_data

        plot = create_plot(data, processed_data, time_col, target_col)
        # Export the merged frame when merging was requested (previously the
        # merge result was computed but never used, so the checkbox did nothing).
        csv_data = output_data.to_csv(index=False)

        # Persist the CSV so Gradio can serve it as a downloadable file.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
            tmpfile.write(csv_data)
            csv_path = tmpfile.name

        return csv_data, plot, csv_path, None

    except ValueError as e:
        return f"Error: {e}", None, None, None
    except Exception as e:
        logger.exception("An unexpected error occurred:")
        return f"Error: An unexpected error occurred: {e}", None, None, None
|
|
def get_column_names(file_obj):
    """
    Return the column names of the uploaded file as a list.

    Falls back to an empty list (after logging) when the file can't be loaded,
    so the dropdown update path never raises.
    """
    try:
        columns = list(load_data(file_obj).columns)
        print(f"Column names: {columns}")
        return columns
    except Exception as e:
        logger.error(f"Error in get_column_names: {e}", exc_info=True)
        print(f"Error in get_column_names: {e}")
        return []
|
|
def update_dropdown_choices(file_obj):
    """
    Refresh the time/target column dropdown choices from the uploaded file.

    Returns a pair of Gradio update payloads (one per dropdown); on failure
    both dropdowns are reset to empty choices.
    """
    try:
        columns = get_column_names(file_obj)
        # gr.update() works on both Gradio 3.x and 4.x; the class-level
        # gr.Dropdown.update() helper was removed in Gradio 4.
        return gr.update(choices=columns), gr.update(choices=columns)
    except Exception as e:
        logger.error(f"Error updating dropdown choices: {e}", exc_info=True)
        return gr.update(choices=[]), gr.update(choices=[])
|
|
def create_interface():
    """
    Build the Gradio Blocks UI: file upload, column/frequency controls,
    date-range inputs, and the forecast outputs wired to
    full_forecast_pipeline(). Returns the Blocks object (not launched).
    """
    with gr.Blocks() as iface:
        gr.Markdown("""
        # CP360 App
        Upload your time series data, select the appropriate columns, and generate a forecast!
        """)

        # Accepted formats mirror load_data()'s extension dispatch.
        file_input = gr.File(label="Upload Time Series Data (CSV, Excel, JSON, YAML)")

        # NOTE(review): update_dropdown_choices() exists but is never wired to
        # file_input, so column selection stays free-text — confirm intent.
        with gr.Row():
            time_col_input = gr.Textbox(label="Time Column", placeholder="Enter time column name")
            target_col_input = gr.Textbox(label="Target Column", placeholder="Enter target column name")

        with gr.Row():
            # Hidden and not passed to the pipeline: the horizon is derived
            # from the forecast date range inside full_forecast_pipeline().
            forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10, visible=False)
            finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
            freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'h', '2h', '3h', '4h', '5h', '6h', '12h', 'D', 'W', 'M', 'Y'], label="Frequency", value='D')

        with gr.Column():
            with gr.Row():
                start_date_input = gr.Textbox(label="Historical Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-01-01")
                # Hidden; the pipeline accepts but does not use the time-of-day fields.
                start_time_input = gr.Textbox(label="Start Time (HH:MM)", placeholder="HH:MM", value="00:00", visible=False)
            with gr.Row():
                end_date_input = gr.Textbox(label="Historical End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
                end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59", visible=False)
            with gr.Row():
                forecast_start_date_input = gr.Textbox(label="Forecast Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")
                forecast_end_date_input = gr.Textbox(label="Forecast End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD")

        resample_freq_dropdown = gr.Dropdown(choices=['15min', '30min', 'h', '2h', '3h', '4h', '5h', '6h', '12h', 'D', 'W', 'M', 'Y'], label="Resample Frequency", value='D')

        # Outputs: raw CSV text, interactive plot, downloadable file, error text.
        output_csv = gr.Textbox(label="Forecast Data (CSV)")
        output_plot = gr.Plot(label="Time Series Plot")
        download_button = gr.File(label="Download Forecast Data as CSV")
        error_output = gr.Markdown(label="Error Messages")

        btn = gr.Button("Generate Forecast")
        btn.click(
            fn=full_forecast_pipeline,
            # NOTE(review): the Merge Data checkbox is constructed inline, so
            # it renders at this point in the layout — confirm placement.
            inputs=[file_input, time_col_input, target_col_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, gr.Checkbox(label="Merge Data", value=False), forecast_start_date_input, forecast_end_date_input],
            outputs=[output_csv, output_plot, download_button, error_output]
        )
    return iface
|
|
# Build and serve the UI only when run as a script, so importing this module
# (e.g. from tests) doesn't start a web server as a side effect.
if __name__ == "__main__":
    iface = create_interface()
    iface.launch()