# forcaster / app.py — Gradio time-series forecasting app (Nixtla TimeGPT)
# Author: anujkum0x — commit 7783537 ("Update app.py")
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import io
import os
import numpy as np
import yaml
import logging
import json
import csv
from datetime import datetime
from plotly.colors import n_colors
from nixtla import NixtlaClient
import tempfile
from typing import Tuple
from datetime import date
from datetime import time
# Configure logging for the whole app; module-level logger per convention.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# SECURITY FIX: the API key was previously hard-coded in source (and is
# therefore compromised — it should be rotated). Read it from the
# environment instead; on Hugging Face Spaces set it as a repository secret.
nixtla_client = NixtlaClient(api_key=os.environ.get('NIXTLA_API_KEY', ''))
# --- Utility Functions ---
def load_data(file_obj):
    """
    Load time series data from an uploaded file into a DataFrame.

    Supports CSV, Excel (.xlsx/.xls), JSON, and YAML. The extension match is
    case-insensitive, so files like ``DATA.CSV`` are accepted (the original
    check was case-sensitive and rejected them).

    Args:
        file_obj: Object with a ``name`` attribute holding the file path
            (e.g. a Gradio file upload).

    Returns:
        pd.DataFrame: The parsed data.

    Raises:
        ValueError: If the format is unsupported or parsing fails.
    """
    try:
        path = file_obj.name
        ext = os.path.splitext(path)[1].lower()
        if ext == '.csv':
            df = pd.read_csv(path)
        elif ext in ('.xlsx', '.xls'):
            df = pd.read_excel(path)
        elif ext == '.json':
            df = pd.read_json(path)
        elif ext in ('.yaml', '.yml'):
            with open(path, 'r') as f:
                df = pd.DataFrame(yaml.safe_load(f))
        else:
            raise ValueError("Unsupported file format")
        logger.debug("DataFrame loaded successfully: %s rows", len(df))
        return df
    except Exception as e:
        logger.error(f"Error loading data: {e}", exc_info=True)
        raise ValueError(f"Error loading data: {e}")
def forecast_nixtla(df, forecast_horizon, finetune_steps, freq, time_col, target_col):
    """
    Request a forecast from the Nixtla TimeGPT API.

    Args:
        df: Historical data containing ``time_col`` and ``target_col``.
        forecast_horizon: Number of future periods to predict.
        finetune_steps: Fine-tuning iterations passed through to the API.
        freq: Pandas frequency string of the series (e.g. 'D', 'H').
        time_col: Name of the timestamp column.
        target_col: Name of the value column.

    Returns:
        The forecast object returned by ``NixtlaClient.forecast``.

    Raises:
        ValueError: If the API call fails for any reason.
    """
    try:
        result = nixtla_client.forecast(
            df=df,
            h=forecast_horizon,
            finetune_steps=finetune_steps,
            time_col=time_col,
            target_col=target_col,
            freq=freq,
        )
    except Exception as e:
        logger.error(f"Error communicating with the forecasting API: {e}", exc_info=True)
        raise ValueError(f"Error communicating with the forecasting API: {e}")
    logger.info("Nixtla API call successful")
    return result
def process_forecast_data(forecast_data, time_col) -> pd.DataFrame:
"""
Process the forecast data to be more human-readable.
"""
try:
forecast_df = pd.DataFrame(forecast_data)
forecast_df[time_col] = pd.to_datetime(forecast_df[time_col])
forecast_df[time_col] = forecast_df[time_col].dt.strftime('%Y-%m-%d %H:%M:%S')
return forecast_df
except Exception as e:
logger.error(f"Error processing forecast data: {e}", exc_info=True)
raise ValueError(f"Error processing forecast data: {e}")
def apply_zero_patterns(df: pd.DataFrame, forecast_df: pd.DataFrame, time_col: str, target_col: str) -> pd.DataFrame:
    """
    Zero out forecast values during historically inactive periods.

    Learns, from the historical data, which hours of the day and days of
    the week have an average target value below 1, and forces the forecast
    to 0 there. All remaining forecast values are clipped at 0 so no
    negative forecasts are emitted. Hours/weekdays never seen in the
    history are treated as inactive (the original row-wise version raised a
    KeyError for them).

    Fixes vs. the original: the except handler no longer references
    ``forecast_value_col`` before assignment (which turned any early failure
    into a NameError), the input frames are no longer mutated, and the
    row-wise ``apply`` is replaced with vectorized ``dt`` accessors.

    Args:
        df: Historical data with ``time_col`` and ``target_col`` columns.
        forecast_df: Forecast with ``time_col`` plus exactly one value column.
        time_col: Timestamp column name shared by both frames.
        target_col: Target column name in ``df``.

    Returns:
        pd.DataFrame: A copy of ``forecast_df`` with the patterns applied.

    Raises:
        ValueError: If the pattern analysis fails.
    """
    try:
        df = df.copy()
        forecast_df = forecast_df.copy()
        df[time_col] = pd.to_datetime(df[time_col])
        forecast_df[time_col] = pd.to_datetime(forecast_df[time_col])

        # Average historical value per hour-of-day and per day-of-week
        # (Monday=0, Sunday=6).
        hourly_avg = df.groupby(df[time_col].dt.hour)[target_col].mean()
        daily_avg = df.groupby(df[time_col].dt.dayofweek)[target_col].mean()

        # The forecast frame carries exactly one non-time column.
        forecast_value_col = [col for col in forecast_df.columns if col != time_col][0]

        # Map each forecast timestamp to its historical averages; periods
        # absent from the history become NaN and are zeroed below.
        hour_mean = forecast_df[time_col].dt.hour.map(hourly_avg)
        day_mean = forecast_df[time_col].dt.dayofweek.map(daily_avg)
        inactive = (hour_mean < 1) | (day_mean < 1) | hour_mean.isna() | day_mean.isna()

        forecast_df[forecast_value_col] = forecast_df[forecast_value_col].clip(lower=0)
        forecast_df.loc[inactive, forecast_value_col] = 0
        return forecast_df
    except Exception as e:
        logger.error(f"Error applying zero patterns: {e}", exc_info=True)
        raise ValueError(f"Error applying zero patterns: {e}")
def create_plot(data, forecast_data, time_col, target_col):
    """
    Build a Plotly figure overlaying the historical series and its forecast.

    Args:
        data: Historical DataFrame with ``time_col`` and ``target_col``.
        forecast_data: Forecast DataFrame (``time_col`` plus one value
            column), or None to plot only the history.
        time_col: Timestamp column name.
        target_col: Historical value column name.

    Returns:
        plotly.graph_objects.Figure: The assembled chart.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=data[time_col],
        y=data[target_col],
        mode='lines',
        name='Historical Data',
    ))
    if forecast_data is not None:
        # The forecast frame's single non-time column holds the values.
        value_col = [col for col in forecast_data.columns if col != time_col][0]
        fig.add_trace(go.Scatter(
            x=forecast_data[time_col],
            y=forecast_data[value_col],
            mode='lines',
            name='Forecast',
        ))
    fig.update_layout(
        title='Time Series Data and Forecast',
        xaxis_title='Time',
        yaxis_title='Value',
        template='plotly_white',
        hovermode="x unified",
    )
    return fig
def full_forecast_pipeline(file_obj, time_col, target_col, forecast_horizon, finetune_steps, freq, start_date, end_date, start_time, end_time, resample_freq, merge_data=False) -> Tuple[str, object, str, str]:
    """
    Run the end-to-end forecasting pipeline.

    Loads the uploaded file, filters it to the requested date/time window,
    resamples it onto a regular grid, requests a forecast from the Nixtla
    API, applies the historical zero-activity patterns, and renders the
    results.

    Fixes vs. the original: ``merge_data`` now has a default (the UI click
    handler passed only 11 inputs, so every call raised a TypeError);
    ``start_time``/``end_time`` are honored in the window filter instead of
    being ignored; the merged frame is actually used for the outputs; the
    spurious double ``reset_index`` before the merge is gone.

    Args:
        file_obj: Uploaded file object (has a ``name`` path attribute).
        time_col / target_col: Column names selected in the UI.
        forecast_horizon: Number of future periods to predict.
        finetune_steps: Fine-tuning iterations for the API.
        freq: Frequency string passed to the forecast API.
        start_date / end_date: Window bounds, 'YYYY-MM-DD' strings.
        start_time / end_time: Optional 'HH:MM' times combined with the
            window bounds.
        resample_freq: Pandas frequency used to regularize the data.
        merge_data: When True, join the historical rows onto the forecast
            output. Defaults to False so callers may omit it.

    Returns:
        Tuple of (HTML table, Plotly figure, CSV file path, error message).
        On failure the first element carries the error text and the rest
        are None.
    """
    try:
        data = load_data(file_obj)
        if not isinstance(data, pd.DataFrame):
            return "Error loading data. Please check the file format and content.", None, None, None
        if time_col not in data.columns or target_col not in data.columns:
            return "Error: Timestamp column or Value column not found in the data.", None, None, None

        data[time_col] = pd.to_datetime(data[time_col])
        data = data.sort_values(by=time_col)
        # Fill missing values with 0 so resampling/forecasting see a dense series.
        data = data.fillna(0)

        # Restrict to the requested window, combining the date with the
        # optional time-of-day bounds.
        if start_date and end_date:
            start_dt = pd.to_datetime(f"{start_date} {start_time}" if start_time else start_date)
            end_dt = pd.to_datetime(f"{end_date} {end_time}" if end_time else end_date)
            data = data[(data[time_col] >= start_dt) & (data[time_col] <= end_dt)]
            logger.info(f"Data filtered from {start_dt} to {end_dt}. Shape: {data.shape}")

        # Resample onto a regular grid; the forecast API expects evenly
        # spaced observations.
        data = data.set_index(time_col).resample(resample_freq).mean()
        data.reset_index(inplace=True)

        forecast_result = forecast_nixtla(data, forecast_horizon, finetune_steps, freq, time_col, target_col)
        processed_data = process_forecast_data(forecast_result, time_col)
        processed_data = apply_zero_patterns(data.copy(), processed_data, time_col, target_col)

        if merge_data:
            # time_col is already a plain column after reset_index above.
            output_data = pd.merge(data, processed_data, on=time_col, how='inner')
        else:
            output_data = processed_data

        plot = create_plot(data, processed_data, time_col, target_col)

        # Write the CSV to a temp file so Gradio can offer it for download.
        csv_data = output_data.to_csv(index=False)
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".csv") as tmpfile:
            tmpfile.write(csv_data)
            csv_path = tmpfile.name

        return output_data.to_html(index=False), plot, csv_path, None
    except ValueError as e:
        return f"Error: {e}", None, None, None
    except Exception as e:
        logger.exception("An unexpected error occurred:")
        return f"Error: An unexpected error occurred: {e}", None, None, None
def get_column_names(file_obj):
    """
    Return the column names of an uploaded data file.

    Falls back to an empty list when the file cannot be read, so the UI
    dropdowns simply stay empty on a bad upload.

    Args:
        file_obj: Uploaded file object accepted by ``load_data``.

    Returns:
        list: Column names, or [] on failure.
    """
    try:
        frame = load_data(file_obj)
        columns = list(frame.columns)
        print(f"Column names: {columns}")
        return columns
    except Exception as e:
        logger.error(f"Error in get_column_names: {e}", exc_info=True)
        print(f"Error in get_column_names: {e}")
        return []
def create_interface():
    """
    Build the Gradio Blocks UI for the forecasting app.

    Wires the file upload to the column dropdowns and the "Generate
    Forecast" button to the full forecasting pipeline.

    Fix vs. the original: the click handler passed only 11 inputs to a
    12-parameter pipeline (the ``merge_data`` flag was never wired up), so
    every click failed; a checkbox now supplies it.

    Returns:
        gr.Blocks: The assembled interface.
    """
    # Shared frequency options for both the forecast and resample dropdowns.
    freq_choices = ['15min', '30min', 'H', '2H', '3H', '4H', '5H', '6H', '12H', 'D', 'W', 'M', 'Y']
    with gr.Blocks() as iface:
        gr.Markdown("""
        # CP360 App
        Upload your time series data, select the appropriate columns, and generate a forecast!
        """)
        file_input = gr.File(label="Upload Time Series Data (CSV, Excel, JSON, YAML)")
        with gr.Row():
            time_col_dropdown = gr.Dropdown(choices=[], label="Select Timestamp Column")
            target_col_dropdown = gr.Dropdown(choices=[], label="Select Value Column")

        def update_dropdown_choices(file_obj):
            # Populate both dropdowns from the uploaded file's header row.
            columns = get_column_names(file_obj)
            return gr.update(choices=columns), gr.update(choices=columns)

        file_input.upload(
            update_dropdown_choices,
            [file_input],
            [time_col_dropdown, target_col_dropdown]
        )
        with gr.Row():
            forecast_horizon_input = gr.Number(label="Forecast Horizon", value=10)
            finetune_steps_input = gr.Number(label="Finetune Steps", value=100)
            freq_dropdown = gr.Dropdown(choices=freq_choices, label="Frequency", value='D')
        with gr.Row():
            start_date_input = gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-01-01")
            start_time_input = gr.Textbox(label="Start Time (HH:MM)", placeholder="HH:MM", value="00:00")
            end_date_input = gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD", value="2023-12-31")
            end_time_input = gr.Textbox(label="End Time (HH:MM)", placeholder="HH:MM", value="23:59")
        resample_freq_dropdown = gr.Dropdown(choices=freq_choices, label="Resample Frequency", value='D')
        # BUGFIX: supplies the pipeline's merge_data flag, previously missing
        # from the click inputs.
        merge_data_checkbox = gr.Checkbox(label="Merge historical data into forecast output", value=False)

        output_html = gr.HTML(label="Forecast Data")
        output_plot = gr.Plot(label="Time Series Plot")
        download_button = gr.File(label="Download Forecast Data as CSV")
        error_output = gr.Markdown(label="Error Messages")

        # Button to trigger the full pipeline
        btn = gr.Button("Generate Forecast")
        btn.click(
            fn=full_forecast_pipeline,
            inputs=[file_input, time_col_dropdown, target_col_dropdown,
                    forecast_horizon_input, finetune_steps_input, freq_dropdown,
                    start_date_input, end_date_input, start_time_input,
                    end_time_input, resample_freq_dropdown, merge_data_checkbox],
            outputs=[output_html, output_plot, download_button, error_output]
        )
    return iface
# Build the UI at import time (kept module-level so `iface` stays importable),
# but only start the server when this file is run as a script.
iface = create_interface()

if __name__ == "__main__":
    iface.launch()