| | import gradio as gr |
| | import pandas as pd |
| | import plotly.graph_objects as go |
| | from plotly.subplots import make_subplots |
| | import io |
| | import os |
| | import numpy as np |
| | import yaml |
| | import logging |
| | import json |
| | import csv |
| | from datetime import datetime |
| | from plotly.colors import n_colors |
| | from nixtla import NixtlaClient |
| | import tempfile |
| | from typing import Tuple |
| | from datetime import date |
| | from datetime import time |
| |
|
| | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The Nixtla API key is read from the environment rather than hard-coded:
# a secret committed to source is readable by anyone with access to the code.
# NOTE(review): the key that used to be embedded here must be treated as
# leaked and revoked/rotated in the Nixtla dashboard.
_nixtla_api_key = os.environ.get("NIXTLA_API_KEY")
if not _nixtla_api_key:
    raise RuntimeError(
        "The NIXTLA_API_KEY environment variable is not set; "
        "it is required to call the Nixtla forecasting API."
    )
nixtla_client = NixtlaClient(api_key=_nixtla_api_key)
| |
|
| | |
def load_data(file_obj):
    """
    Load an uploaded file into a pandas DataFrame.

    The reader is chosen from the file extension, compared case-insensitively:
    ``.csv``, ``.xlsx``/``.xls``, ``.json`` and ``.yaml``/``.yml``.

    Args:
        file_obj: Either a filesystem path (``str`` / ``os.PathLike``) or an
            object exposing the path through a ``.name`` attribute (e.g. an
            upload wrapper or an open file object).

    Returns:
        The loaded ``pd.DataFrame``.

    Raises:
        ValueError: If the extension is unsupported or reading fails. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        # Accept plain paths as well as objects carrying the path in ``.name``.
        # (``pathlib.Path.name`` is only the basename, so PathLike is handled first.)
        if isinstance(file_obj, (str, os.PathLike)):
            path = os.fspath(file_obj)
        else:
            path = file_obj.name
        extension = os.path.splitext(path)[1].lower()

        if extension == '.csv':
            df = pd.read_csv(path)
        elif extension in ('.xlsx', '.xls'):
            df = pd.read_excel(path)
        elif extension == '.json':
            df = pd.read_json(path)
        elif extension in ('.yaml', '.yml'):
            # safe_load never constructs arbitrary Python objects from the upload;
            # binary mode lets PyYAML detect the UTF-8/UTF-16 encoding itself.
            with open(path, 'rb') as f:
                data = yaml.safe_load(f)
            df = pd.DataFrame(data)
        else:
            raise ValueError("Unsupported file format")

        logger.info("DataFrame loaded successfully: %d rows x %d columns", *df.shape)
        logger.debug("Loaded data:\n%s", df)
        return df

    except Exception as e:
        logger.error(f"Error loading data: {e}", exc_info=True)
        raise ValueError(f"Error loading data: {e}") from e
| |
|
def forecast_nixtla(df, forecast_horizon, finetune_steps, freq, time_col, target_col):
    """
    Request a forecast for ``df`` through the module-level ``nixtla_client``.

    Args:
        df: Historical data containing ``time_col`` and ``target_col``.
        forecast_horizon: Number of future steps to predict. Coerced to ``int``
            (truncating) because UI number widgets may deliver floats like ``10.0``.
        finetune_steps: Number of fine-tuning steps; coerced to ``int`` likewise.
        freq: Frequency string forwarded unchanged to the client.
        time_col: Name of the timestamp column.
        target_col: Name of the value column.

    Returns:
        The object returned by ``nixtla_client.forecast``.

    Raises:
        ValueError: If the numeric arguments cannot be converted or the API call
            fails. The underlying exception is chained as ``__cause__``.
    """
    try:
        forecast = nixtla_client.forecast(
            df=df,
            h=int(forecast_horizon),
            finetune_steps=int(finetune_steps),
            time_col=time_col,
            target_col=target_col,
            freq=freq
        )
        logger.info("Nixtla API call successful")
        return forecast

    except Exception as e:
        logger.error(f"Error communicating with the forecasting API: {e}", exc_info=True)
        raise ValueError(f"Error communicating with the forecasting API: {e}") from e
| |
|
def process_forecast_data(forecast_data, time_col) -> pd.DataFrame:
    """
    Rewrite the timestamp column of a forecast as readable strings.

    Args:
        forecast_data: Anything ``pd.DataFrame`` accepts (a DataFrame, a dict of
            columns, ...) that contains ``time_col``.
        time_col: Name of the timestamp column to reformat.

    Returns:
        A DataFrame whose ``time_col`` holds ``'YYYY-MM-DD HH:MM:SS'`` strings;
        all other columns are left as they were.

    Raises:
        ValueError: If the frame cannot be built or the column cannot be parsed
            as datetimes.
    """
    readable_format = '%Y-%m-%d %H:%M:%S'
    try:
        table = pd.DataFrame(forecast_data)
        table[time_col] = pd.to_datetime(table[time_col]).dt.strftime(readable_format)
    except Exception as e:
        logger.error(f"Error processing forecast data: {e}", exc_info=True)
        raise ValueError(f"Error processing forecast data: {e}")
    return table
| |
|
def apply_zero_patterns(df: pd.DataFrame, forecast_df: pd.DataFrame, time_col: str, target_col: str) -> pd.DataFrame:
    """
    Force forecast values to zero where the history shows (near-)zero activity.

    The historical mean of ``target_col`` is computed per hour-of-day and per
    day-of-week. A forecast row is set to 0 when either mean for its slot is
    below 1. Every other forecast value is clipped at 0 from below. Slots that
    never occur in the history (e.g. intraday hours when the history is daily)
    have no evidence of a zero pattern, so they are left unchanged instead of
    raising.

    Args:
        df: Historical data with ``time_col`` and ``target_col``. Not modified.
        forecast_df: Forecast with ``time_col`` plus value column(s). The first
            column other than ``time_col`` is treated as the forecast value.
            Not modified.
        time_col: Name of the timestamp column in both frames.
        target_col: Name of the historical value column.

    Returns:
        A copy of ``forecast_df`` with ``time_col`` parsed to datetimes, rows
        with unparseable/missing timestamps dropped, and the value column
        adjusted as described above.

    Raises:
        ValueError: If timestamps cannot be parsed, a column is missing, or the
            forecast has no value column. The cause is chained as ``__cause__``.
    """
    try:
        # Per-slot historical means, computed without adding columns to ``df``.
        history_times = pd.to_datetime(df[time_col])
        history_values = df[target_col]
        hourly_avg = history_values.groupby(history_times.dt.hour).mean()
        daily_avg = history_values.groupby(history_times.dt.dayofweek).mean()

        result = forecast_df.copy()
        result[time_col] = pd.to_datetime(result[time_col])
        result = result.dropna(subset=[time_col])

        value_cols = [col for col in result.columns if col != time_col]
        if not value_cols:
            raise ValueError(f"forecast has no value column besides '{time_col}'")
        forecast_value_col = value_cols[0]

        # ``map`` yields NaN for slots absent from the history, and NaN < 1 is
        # False, so unseen slots are never zeroed.
        hour_mean = result[time_col].dt.hour.map(hourly_avg)
        day_mean = result[time_col].dt.dayofweek.map(daily_avg)
        is_zero_slot = hour_mean.lt(1) | day_mean.lt(1)

        # NOTE(review): the "< 1" threshold treats small averages as zero; whether
        # that suits the data's units should be confirmed.
        # Clip negatives to 0 (NaN forecasts also become 0), then zero the idle slots.
        clipped = result[forecast_value_col].clip(lower=0).fillna(0)
        result[forecast_value_col] = clipped.mask(is_zero_slot, 0)
        return result
    except Exception as e:
        logger.error(f"Error applying zero patterns: {e}", exc_info=True)
        raise ValueError(f"Error applying zero patterns: {e}") from e
| |
|
def create_plot(data, forecast_data, time_col, target_col):
    """
    Plot the historical series and, when provided, the forecast on one chart.

    Args:
        data: DataFrame holding the history in ``time_col`` / ``target_col``.
        forecast_data: DataFrame with ``time_col`` plus at least one other column
            (the first such column is plotted), or ``None`` to plot history only.
        time_col: Name of the timestamp column (x axis).
        target_col: Name of the historical value column (y axis).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    traces = [
        go.Scatter(x=data[time_col], y=data[target_col], mode='lines', name='Historical Data'),
    ]

    if forecast_data is not None:
        # The forecast's value column is the first one that is not the timestamp.
        value_col = [c for c in forecast_data.columns if c != time_col][0]
        traces.append(
            go.Scatter(x=forecast_data[time_col], y=forecast_data[value_col], mode='lines', name='Forecast')
        )

    figure = go.Figure(data=traces)
    figure.update_layout(
        title='Time Series Data and Forecast',
        xaxis_title='Time',
        yaxis_title='Value',
        template='plotly_white',
        hovermode="x unified",
    )
    return figure
| |
|
def _combine_date_time(date_value, time_value):
    # Parse a date string plus an optional "HH:MM" time string into one Timestamp (None if no date).
    date_text = str(date_value or "").strip()
    if not date_text:
        return None
    time_text = str(time_value or "").strip()
    return pd.to_datetime(f"{date_text} {time_text}" if time_text else date_text)


def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data=False) -> Tuple[str, object, str, str]:
    """
    Run the full pipeline: load, clean, filter, resample, forecast, post-process.

    Args:
        file_obj: Uploaded file (path or object with ``.name``) passed to ``load_data``.
        time_col: Name of the timestamp column.
        target_col: Name of the value column to forecast.
        forecast_horizon: Number of steps to forecast.
        finetune_steps: Fine-tuning steps forwarded to the API.
        freq: Frequency string forwarded to the API.
        start_date / end_date: Optional "YYYY-MM-DD" bounds (inclusive); each is
            applied independently when non-empty.
        start_time / end_time: Optional "HH:MM" times combined with the
            corresponding date bound.
        resample_freq: pandas offset alias used to resample (mean) the history.
        merge_data: Accepted for interface compatibility; it currently has no
            effect on the outputs.

    Returns:
        ``(html, figure, csv_path, None)`` on success, or
        ``(error_message, None, None, None)`` on failure.
    """
    try:
        data = load_data(file_obj)
        if not isinstance(data, pd.DataFrame):
            return "Error loading data. Please check the file format and content.", None, None, None

        if time_col not in data.columns or target_col not in data.columns:
            return "Error: Timestamp column or Value column not found in the data.", None, None, None

        # Only the selected columns take part: resample().mean() fails on
        # non-numeric extras, and extra columns would otherwise reach the API.
        data = data[[time_col, target_col]].copy()
        data[time_col] = pd.to_datetime(data[time_col])
        # Missing observations are treated as zero throughout this app.
        data[target_col] = data[target_col].fillna(0)
        data = data.dropna(subset=[time_col]).sort_values(by=time_col)

        # Honour the time-of-day inputs so e.g. an end of "2023-12-31 23:59"
        # keeps that whole day instead of stopping at midnight.
        start_bound = _combine_date_time(start_date, start_time)
        end_bound = _combine_date_time(end_date, end_time)
        if start_bound is not None:
            data = data[data[time_col] >= start_bound]
        if end_bound is not None:
            data = data[data[time_col] <= end_bound]
        if start_bound is not None or end_bound is not None:
            logger.info(f"Data filtered from {start_bound} to {end_bound}. Shape: {data.shape}")

        if data.empty:
            return "Error: No data left after applying the selected date/time range.", None, None, None

        # resample() yields NaN for bins without observations; apply the same
        # missing-as-zero rule so no NaN is sent to the forecasting API.
        data = data.set_index(time_col).resample(resample_freq).mean()
        data = data.fillna(0).reset_index()

        # NOTE(review): ``freq`` is sent to the API independently of ``resample_freq``;
        # if the two differ the request will likely be rejected — consider unifying them.
        forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
        processed_data = process_forecast_data(forecast_result, time_col)
        processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)

        plot = create_plot(data, processed_data, time_col, target_col)

        # delete=False: the file must outlive this function so the UI can serve it.
        # newline='' stops text mode from doubling the line terminators to_csv writes.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv", encoding='utf-8', newline='') as tmpfile:
            processed_data.to_csv(tmpfile, index=False)
            csv_path = tmpfile.name

        return processed_data.to_html(index=False), plot, csv_path, None

    except ValueError as e:
        return f"Error: {e}", None, None, None
    except Exception as e:
        logger.exception("An unexpected error occurred:")
        return f"Error: An unexpected error occurred: {e}", None, None, None
| |
|
def get_column_names(file_obj):
    """
    Return the column labels of an uploaded file.

    Args:
        file_obj: The upload handed to ``load_data``.

    Returns:
        A list of column labels, or an empty list if the file cannot be loaded.
        Failures are logged and printed rather than raised.
    """
    try:
        labels = load_data(file_obj).columns.tolist()
    except Exception as e:
        logger.error(f"Error in get_column_names: {e}", exc_info=True)
        print(f"Error in get_column_names: {e}")
        return []
    print(f"Column names: {labels}")
    return labels
| |
|
def create_interface():
    """
    Build the Gradio Blocks UI for uploading data and generating a forecast.

    Returns:
        The constructed ``gr.Blocks`` app; it is not launched here.
    """
    freq_choices = ['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y']

    with gr.Blocks() as iface:
        gr.Markdown("""
        # CP360 App
        Upload your time series data, select the appropriate columns, and generate a forecast!
        """)

        file_input = gr.File(label="Upload Time Series Data (CSV, Excel, JSON, YAML)")
        with gr.Row():
            time_col_dropdown = gr.Dropdown(choices=[], label="Select Timestamp Column")
            target_col_dropdown = gr.Dropdown(choices=[], label="Select Value Column")

        def update_dropdown_choices(file_obj):
            # Refresh both column pickers with the uploaded file's column names.
            columns = get_column_names(file_obj)
            return gr.update(choices=columns), gr.update(choices=columns)

        file_input.upload(
            update_dropdown_choices,
            [file_input],
            [time_col_dropdown, target_col_dropdown]
        )

        with gr.Row():
            # precision=0 makes Gradio deliver ints instead of floats such as 10.0.
            forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10, precision=0)
            finetune_steps_input = gr.Number(label="Finetune Steps", value=100, precision=0)
            freq_dropdown = gr.Dropdown(choices=list(freq_choices), label="Frequency", value='D')

        with gr.Row():
            start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-01-01")
            start_time_input = gr.Textbox(label="Start Time (HH:MM)", placeholder="HH:MM", value="00:00")
            end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
            end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59")

        resample_freq_dropdown = gr.Dropdown(choices=list(freq_choices), label="Resample Frequency", value='D')

        output_html = gr.HTML(label="Forecast Data")
        output_plot = gr.Plot(label="Time Series Plot")
        download_button = gr.File(label="Download Forecast Data as CSV")
        error_output = gr.Markdown(label="Error Messages")

        # full_forecast_pipeline takes a 12th ``merge_data`` argument; without a
        # matching input the click handler would be called with too few
        # arguments. A hidden State supplies the value.
        merge_data_state = gr.State(False)

        btn = gr.Button("Generate Forecast")
        btn.click(
            fn=full_forecast_pipeline,
            inputs=[file_input, time_col_dropdown, target_col_dropdown, forecast_horizon_input, finetune_steps_input, freq_dropdown, start_date_input, end_date_input, start_time_input, end_time_input, resample_freq_dropdown, merge_data_state],
            outputs=[output_html, output_plot, download_button, error_output]
        )
    return iface
| |
|
# The UI object is built at import time so code that imports this module can
# reach ``iface``; the server is started only when the file runs as a script.
iface = create_interface()

if __name__ == "__main__":
    iface.launch()