Spaces:
Sleeping
Sleeping
Update app.py
Browse filesShipping some updates.
app.py
CHANGED
|
@@ -14,89 +14,126 @@ model = load_model("NX-AI/TiRex")
|
|
| 14 |
|
| 15 |
def load_columns(file):
|
| 16 |
if file is None:
|
| 17 |
-
return gr.Dropdown(choices=[], label="Select Column
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
| 19 |
try:
|
| 20 |
# Handle file as path string (Gradio convention)
|
| 21 |
with open(file, 'rb') as f:
|
| 22 |
content = f.read()
|
| 23 |
df_preview = pd.read_csv(io.BytesIO(content))
|
| 24 |
-
|
| 25 |
-
#
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# Available numeric columns for forecast
|
| 31 |
numeric_cols = df_preview.select_dtypes(include=['number']).columns.tolist()
|
| 32 |
-
|
| 33 |
-
numeric_cols.remove('date')
|
| 34 |
-
|
| 35 |
if numeric_cols:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
value=numeric_cols[0],
|
| 39 |
-
label="Select Column to Forecast:",
|
| 40 |
-
interactive=True
|
| 41 |
-
)
|
| 42 |
else:
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
except Exception as e:
|
| 50 |
-
return gr.Dropdown(
|
| 51 |
choices=[],
|
| 52 |
value=None,
|
| 53 |
label=f"Error loading CSV: {str(e)}",
|
| 54 |
interactive=False
|
| 55 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
def run_forecast(file, selected_col, prediction_length, confidence):
|
| 58 |
-
if file is None or selected_col is None:
|
| 59 |
-
return None, "### Error\nPlease upload a CSV and select a column!"
|
| 60 |
-
|
| 61 |
try:
|
| 62 |
# Handle file as path string (Gradio convention)
|
| 63 |
with open(file, 'rb') as f:
|
| 64 |
content = f.read()
|
| 65 |
df = pd.read_csv(io.BytesIO(content))
|
| 66 |
|
| 67 |
-
#
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
df = df.rename(columns={date_cols[0]: 'date'})
|
| 71 |
-
else:
|
| 72 |
-
return None, "### Error\nNo date column found (looking for 'Day' or 'date'). Edit CSV."
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
df = df.rename(columns={selected_col: 'sales'})
|
| 76 |
|
| 77 |
# Validate
|
| 78 |
required_cols = ['date', 'sales']
|
| 79 |
if not all(col in df.columns for col in required_cols):
|
| 80 |
-
return None, f"### Error\nMissing
|
| 81 |
|
| 82 |
# Prep data
|
| 83 |
df['date'] = pd.to_datetime(df['date'])
|
| 84 |
df = df.set_index('date').sort_index()
|
| 85 |
-
if len(df) < 10:
|
| 86 |
-
return None, "### Error\nNeed at least 10 data points."
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
# Infer freq
|
| 92 |
-
freq = pd.infer_freq(
|
| 93 |
if freq is None:
|
| 94 |
freq = 'D'
|
| 95 |
print(f"Frequency: '{freq}'.")
|
| 96 |
|
| 97 |
# Prep context
|
| 98 |
-
context_len = min(len(
|
| 99 |
-
context = torch.tensor(
|
| 100 |
|
| 101 |
pred_len = prediction_length
|
| 102 |
conf_level = confidence / 100.0
|
|
@@ -122,6 +159,8 @@ def run_forecast(file, selected_col, prediction_length, confidence):
|
|
| 122 |
lower_slider = np.zeros(pred_len)
|
| 123 |
upper_slider = np.zeros(pred_len)
|
| 124 |
|
|
|
|
|
|
|
| 125 |
for t in range(pred_len):
|
| 126 |
q_t = q[t]
|
| 127 |
lower50[t] = np.interp(lower_alpha_50, alphas, q_t)
|
|
@@ -129,11 +168,22 @@ def run_forecast(file, selected_col, prediction_length, confidence):
|
|
| 129 |
lower_slider[t] = np.interp(lower_alpha_slider, alphas, q_t)
|
| 130 |
upper_slider[t] = np.interp(upper_alpha_slider, alphas, q_t)
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
# Mean forecast
|
| 133 |
mean_forecast = mean[0].detach().numpy()
|
| 134 |
|
| 135 |
# Future dates
|
| 136 |
-
last_date =
|
| 137 |
if freq == 'D':
|
| 138 |
future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=pred_len, freq='D')
|
| 139 |
else:
|
|
@@ -144,35 +194,49 @@ def run_forecast(file, selected_col, prediction_length, confidence):
|
|
| 144 |
'predicted_sales_median': median,
|
| 145 |
'predicted_sales_lower': lower_slider,
|
| 146 |
'predicted_sales_upper': upper_slider,
|
| 147 |
-
'predicted_sales_mean': mean_forecast
|
|
|
|
| 148 |
}).set_index('date')
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
# Prepare markdown output (broken into smaller strings to avoid multiline f-string parsing issues)
|
| 151 |
-
markdown_text = "###
|
| 152 |
-
markdown_text += "
|
| 153 |
-
markdown_text += "
|
| 154 |
-
for idx, row in pred_df.iterrows():
|
| 155 |
-
markdown_text += "| {} | {:.2f} | {:.2f} | {:.2f} | {:.2f} |\n".format(
|
| 156 |
-
idx.strftime('%Y-%m-%d'),
|
| 157 |
-
row['predicted_sales_median'],
|
| 158 |
-
row['predicted_sales_lower'],
|
| 159 |
-
row['predicted_sales_upper'],
|
| 160 |
-
row['predicted_sales_mean']
|
| 161 |
-
)
|
| 162 |
-
|
| 163 |
-
markdown_text += "\n### 📊 Summary\n"
|
| 164 |
markdown_text += "- **Prediction Length:** {} periods\n".format(pred_len)
|
| 165 |
markdown_text += "- **Confidence Level:** {}% (alphas: {:.3f} - {:.3f})\n".format(confidence, lower_alpha_slider, upper_alpha_slider)
|
| 166 |
markdown_text += "- **Sum of Median Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_median'].sum())
|
| 167 |
-
markdown_text += "- **Sum of Mean Predicted Values:** {:.2f}\n
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
# Create plot
|
| 173 |
fig, ax = plt.subplots(figsize=(14, 7))
|
| 174 |
-
ax.plot(
|
| 175 |
-
|
|
|
|
| 176 |
ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
|
| 177 |
|
| 178 |
# Fan chart: non-overlapping bands
|
|
@@ -185,81 +249,122 @@ def run_forecast(file, selected_col, prediction_length, confidence):
|
|
| 185 |
ax.fill_between(pred_df.index, upper50, upper_slider,
|
| 186 |
color='#d62728', alpha=0.3, label=f'{confidence}% Uncertainty Wings')
|
| 187 |
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
ax.set_xlabel('Date', fontsize=12)
|
| 190 |
ax.set_ylabel(selected_col, fontsize=12)
|
| 191 |
-
ax.legend(fontsize=10)
|
| 192 |
ax.tick_params(axis='x', rotation=45)
|
| 193 |
plt.tight_layout()
|
| 194 |
-
|
| 195 |
return fig, markdown_text
|
| 196 |
|
| 197 |
except Exception as e:
|
| 198 |
-
return None, f"###
|
| 199 |
|
| 200 |
# Create the Gradio interface
|
| 201 |
-
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="
|
| 202 |
gr.Markdown("""
|
| 203 |
-
#
|
| 204 |
-
Upload a CSV file with a
|
| 205 |
-
|
| 206 |
-
The dashboard will display in this new window/tab for a cool, interactive experience!
|
| 207 |
""")
|
| 208 |
-
|
| 209 |
with gr.Row(variant="panel"):
|
| 210 |
with gr.Column(scale=1):
|
| 211 |
csv_file = gr.File(
|
| 212 |
file_types=[".csv"],
|
| 213 |
-
label="
|
| 214 |
elem_id="file_upload"
|
| 215 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
column_dropdown = gr.Dropdown(
|
| 217 |
choices=[],
|
| 218 |
-
label="
|
| 219 |
interactive=True,
|
| 220 |
elem_id="column_select"
|
| 221 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
prediction_length = gr.Slider(
|
| 223 |
-
minimum=1, maximum=
|
| 224 |
-
label="
|
| 225 |
elem_id="pred_length"
|
| 226 |
)
|
| 227 |
confidence = gr.Slider(
|
| 228 |
minimum=50, maximum=95, value=80, step=5,
|
| 229 |
-
label="
|
| 230 |
elem_id="confidence"
|
| 231 |
)
|
| 232 |
run_button = gr.Button(
|
| 233 |
-
"
|
| 234 |
variant="primary",
|
| 235 |
size="lg",
|
| 236 |
elem_id="run_btn"
|
| 237 |
)
|
| 238 |
-
|
| 239 |
with gr.Column(scale=2):
|
| 240 |
forecast_plot = gr.Plot(
|
| 241 |
-
label="
|
| 242 |
elem_id="plot"
|
| 243 |
)
|
| 244 |
output_text = gr.Markdown(
|
| 245 |
"### Welcome!\nUpload your CSV to get started.",
|
| 246 |
elem_id="output"
|
| 247 |
)
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
| 250 |
csv_file.change(
|
| 251 |
load_columns,
|
| 252 |
inputs=csv_file,
|
| 253 |
-
outputs=column_dropdown
|
| 254 |
)
|
| 255 |
-
|
| 256 |
# Event for running forecast
|
| 257 |
run_button.click(
|
| 258 |
run_forecast,
|
| 259 |
-
inputs=[csv_file, column_dropdown, prediction_length, confidence],
|
| 260 |
outputs=[forecast_plot, output_text]
|
| 261 |
)
|
| 262 |
|
| 263 |
# Launch the app
|
| 264 |
-
|
| 265 |
-
demo.launch()
|
|
|
|
| 14 |
|
| 15 |
def load_columns(file):
|
| 16 |
if file is None:
|
| 17 |
+
return (gr.Dropdown(choices=[], label="Select Time Column:", interactive=True),
|
| 18 |
+
gr.Dropdown(choices=[], label="Select Column to Forecast:", interactive=True),
|
| 19 |
+
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index (1-based)"),
|
| 20 |
+
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index (1-based)"))
|
| 21 |
+
|
| 22 |
try:
|
| 23 |
# Handle file as path string (Gradio convention)
|
| 24 |
with open(file, 'rb') as f:
|
| 25 |
content = f.read()
|
| 26 |
df_preview = pd.read_csv(io.BytesIO(content))
|
| 27 |
+
|
| 28 |
+
# All columns for time selection
|
| 29 |
+
all_cols = df_preview.columns.tolist()
|
| 30 |
+
time_choices = [(col, col) for col in all_cols]
|
| 31 |
+
time_value = all_cols[0] if all_cols else None
|
| 32 |
+
|
| 33 |
+
# Available numeric columns for forecast
|
| 34 |
numeric_cols = df_preview.select_dtypes(include=['number']).columns.tolist()
|
| 35 |
+
|
|
|
|
|
|
|
| 36 |
if numeric_cols:
|
| 37 |
+
value_choices = [(col, col) for col in numeric_cols]
|
| 38 |
+
value_value = numeric_cols[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
else:
|
| 40 |
+
value_choices = []
|
| 41 |
+
value_value = None
|
| 42 |
+
|
| 43 |
+
n_rows = len(df_preview)
|
| 44 |
+
|
| 45 |
+
time_dropdown = gr.Dropdown(
|
| 46 |
+
choices=time_choices,
|
| 47 |
+
value=time_value,
|
| 48 |
+
label="Select Time Column:",
|
| 49 |
+
interactive=True
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
value_dropdown = gr.Dropdown(
|
| 53 |
+
choices=value_choices,
|
| 54 |
+
value=value_value,
|
| 55 |
+
label="Select Column to Forecast:",
|
| 56 |
+
interactive=True
|
| 57 |
+
) if value_choices else gr.Dropdown(
|
| 58 |
+
choices=[],
|
| 59 |
+
value=None,
|
| 60 |
+
label="No numeric columns found",
|
| 61 |
+
interactive=False
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
start_slider = gr.Slider(
|
| 65 |
+
minimum=1, maximum=n_rows, value=1, step=1,
|
| 66 |
+
label="Historical Start Index (1-based)"
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
end_slider = gr.Slider(
|
| 70 |
+
minimum=1, maximum=n_rows, value=n_rows, step=1,
|
| 71 |
+
label="Historical End Index (1-based)"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
return time_dropdown, value_dropdown, start_slider, end_slider
|
| 75 |
+
|
| 76 |
except Exception as e:
|
| 77 |
+
return (gr.Dropdown(
|
| 78 |
choices=[],
|
| 79 |
value=None,
|
| 80 |
label=f"Error loading CSV: {str(e)}",
|
| 81 |
interactive=False
|
| 82 |
+
), gr.Dropdown(
|
| 83 |
+
choices=[],
|
| 84 |
+
value=None,
|
| 85 |
+
label=f"Error loading CSV: {str(e)}",
|
| 86 |
+
interactive=False
|
| 87 |
+
), gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index (1-based)"),
|
| 88 |
+
gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index (1-based)"))
|
| 89 |
+
|
| 90 |
+
def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence):
|
| 91 |
+
if file is None or time_col is None or selected_col is None:
|
| 92 |
+
return None, "### Error\nPlease upload a CSV and select time and value columns!"
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
try:
|
| 95 |
# Handle file as path string (Gradio convention)
|
| 96 |
with open(file, 'rb') as f:
|
| 97 |
content = f.read()
|
| 98 |
df = pd.read_csv(io.BytesIO(content))
|
| 99 |
|
| 100 |
+
# Validate columns exist
|
| 101 |
+
if time_col not in df.columns or selected_col not in df.columns:
|
| 102 |
+
return None, f"### Error\nSelected columns '{time_col}' or '{selected_col}' not found in CSV."
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
# Rename selected columns
|
| 105 |
+
df = df.rename(columns={time_col: 'date', selected_col: 'sales'})
|
| 106 |
|
| 107 |
# Validate
|
| 108 |
required_cols = ['date', 'sales']
|
| 109 |
if not all(col in df.columns for col in required_cols):
|
| 110 |
+
return None, f"### Error\nMissing renamed columns."
|
| 111 |
|
| 112 |
# Prep data
|
| 113 |
df['date'] = pd.to_datetime(df['date'])
|
| 114 |
df = df.set_index('date').sort_index()
|
|
|
|
|
|
|
| 115 |
|
| 116 |
+
full_len = len(df)
|
| 117 |
+
context_start = max(0, int(start_idx) - 1)
|
| 118 |
+
context_end = min(full_len, int(end_idx))
|
| 119 |
+
context_df = df.iloc[context_start:context_end]
|
| 120 |
+
held_out_df = df.iloc[context_end:] if context_end < full_len else pd.DataFrame(index=pd.DatetimeIndex([]), columns=df.columns)
|
| 121 |
+
|
| 122 |
+
if len(context_df) < 10:
|
| 123 |
+
return None, "### Error\nNeed at least 10 data points in the selected historical range."
|
| 124 |
+
|
| 125 |
+
context_series = context_df['sales'].dropna().values
|
| 126 |
+
print(f"Loaded context: {len(context_series)} points from {context_df.index.min().date()} to {context_df.index.max().date()} (Column: {selected_col})") # For logs
|
| 127 |
|
| 128 |
# Infer freq
|
| 129 |
+
freq = pd.infer_freq(context_df.index)
|
| 130 |
if freq is None:
|
| 131 |
freq = 'D'
|
| 132 |
print(f"Frequency: '{freq}'.")
|
| 133 |
|
| 134 |
# Prep context
|
| 135 |
+
context_len = min(len(context_series), 2048)
|
| 136 |
+
context = torch.tensor(context_series[-context_len:]).unsqueeze(0).float()
|
| 137 |
|
| 138 |
pred_len = prediction_length
|
| 139 |
conf_level = confidence / 100.0
|
|
|
|
| 159 |
lower_slider = np.zeros(pred_len)
|
| 160 |
upper_slider = np.zeros(pred_len)
|
| 161 |
|
| 162 |
+
skew_directions = []
|
| 163 |
+
|
| 164 |
for t in range(pred_len):
|
| 165 |
q_t = q[t]
|
| 166 |
lower50[t] = np.interp(lower_alpha_50, alphas, q_t)
|
|
|
|
| 168 |
lower_slider[t] = np.interp(lower_alpha_slider, alphas, q_t)
|
| 169 |
upper_slider[t] = np.interp(upper_alpha_slider, alphas, q_t)
|
| 170 |
|
| 171 |
+
# Compute skew direction based on asymmetry around median
|
| 172 |
+
med = median[t]
|
| 173 |
+
upside_dist = upper_slider[t] - med
|
| 174 |
+
downside_dist = med - lower_slider[t]
|
| 175 |
+
if upside_dist > downside_dist:
|
| 176 |
+
skew_directions.append("Upside")
|
| 177 |
+
elif downside_dist > upside_dist:
|
| 178 |
+
skew_directions.append("Downside")
|
| 179 |
+
else:
|
| 180 |
+
skew_directions.append("Neutral")
|
| 181 |
+
|
| 182 |
# Mean forecast
|
| 183 |
mean_forecast = mean[0].detach().numpy()
|
| 184 |
|
| 185 |
# Future dates
|
| 186 |
+
last_date = context_df.index[-1]
|
| 187 |
if freq == 'D':
|
| 188 |
future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=pred_len, freq='D')
|
| 189 |
else:
|
|
|
|
| 194 |
'predicted_sales_median': median,
|
| 195 |
'predicted_sales_lower': lower_slider,
|
| 196 |
'predicted_sales_upper': upper_slider,
|
| 197 |
+
'predicted_sales_mean': mean_forecast,
|
| 198 |
+
'skew_direction': skew_directions
|
| 199 |
}).set_index('date')
|
| 200 |
|
| 201 |
+
# Count skews for summary
|
| 202 |
+
upside_count = skew_directions.count("Upside")
|
| 203 |
+
downside_count = skew_directions.count("Downside")
|
| 204 |
+
neutral_count = skew_directions.count("Neutral")
|
| 205 |
+
|
| 206 |
# Prepare markdown output (broken into smaller strings to avoid multiline f-string parsing issues)
|
| 207 |
+
markdown_text = "### Summary\n"
|
| 208 |
+
markdown_text += "- **Number of Historical Periods Used:** {} points\n".format(len(context_series))
|
| 209 |
+
markdown_text += "- **Held Out Periods:** {} points {}\n".format(len(held_out_df), "(Full Context Used)" if len(held_out_df) == 0 else "(For Validation)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
markdown_text += "- **Prediction Length:** {} periods\n".format(pred_len)
|
| 211 |
markdown_text += "- **Confidence Level:** {}% (alphas: {:.3f} - {:.3f})\n".format(confidence, lower_alpha_slider, upper_alpha_slider)
|
| 212 |
markdown_text += "- **Sum of Median Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_median'].sum())
|
| 213 |
+
markdown_text += "- **Sum of Mean Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_mean'].sum())
|
| 214 |
+
markdown_text += "- **Skew Distribution:** {} Upside, {} Downside, {} Neutral\n\n".format(upside_count, downside_count, neutral_count)
|
| 215 |
+
|
| 216 |
+
forecast_table = "### TiRex Forecast Results (Median + {}% Interval)\n\n".format(confidence)
|
| 217 |
+
forecast_table += "| Date | Median | Lower Bound | Upper Bound | Mean | Skew |\n"
|
| 218 |
+
forecast_table += "|------|--------|-------------|-------------|------|------|\n"
|
| 219 |
+
for idx, row in pred_df.iterrows():
|
| 220 |
+
forecast_table += "| {} | {:.2f} | {:.2f} | {:.2f} | {:.2f} | {} |\n".format(
|
| 221 |
+
idx.strftime('%Y-%m-%d'),
|
| 222 |
+
row['predicted_sales_median'],
|
| 223 |
+
row['predicted_sales_lower'],
|
| 224 |
+
row['predicted_sales_upper'],
|
| 225 |
+
row['predicted_sales_mean'],
|
| 226 |
+
row['skew_direction']
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
sample_data = "### Sample Historical Data (Context)\n"
|
| 230 |
+
sample_data += "```\n" + context_df.head().to_string() + "\n```"
|
| 231 |
+
|
| 232 |
+
markdown_text += f'\n<details><summary>Click to expand Forecast Table</summary>\n\n{forecast_table}\n</details>\n\n'
|
| 233 |
+
markdown_text += f'<details><summary>Click to expand Sample Historical Data</summary>\n\n{sample_data}\n</details>'
|
| 234 |
|
| 235 |
# Create plot
|
| 236 |
fig, ax = plt.subplots(figsize=(14, 7))
|
| 237 |
+
ax.plot(context_df.index, context_df['sales'], label=f'Used Historical {selected_col}', color='#1f77b4', linewidth=1.5, alpha=0.8)
|
| 238 |
+
if not held_out_df.empty:
|
| 239 |
+
ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)', color='#2ca02c', linestyle=':', linewidth=2)
|
| 240 |
ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
|
| 241 |
|
| 242 |
# Fan chart: non-overlapping bands
|
|
|
|
| 249 |
ax.fill_between(pred_df.index, upper50, upper_slider,
|
| 250 |
color='#d62728', alpha=0.3, label=f'{confidence}% Uncertainty Wings')
|
| 251 |
|
| 252 |
+
# Subtle skew visualization: colored segments on the median forecast line
|
| 253 |
+
from matplotlib.lines import Line2D
|
| 254 |
+
legend_elements = []
|
| 255 |
+
|
| 256 |
+
skew_colors = {'Upside': 'green', 'Downside': 'red', 'Neutral': 'gray'}
|
| 257 |
+
for i in range(len(pred_df) - 1):
|
| 258 |
+
start_date = pred_df.index[i]
|
| 259 |
+
end_date = pred_df.index[i + 1]
|
| 260 |
+
start_val = median[i]
|
| 261 |
+
end_val = median[i + 1]
|
| 262 |
+
skew = skew_directions[i]
|
| 263 |
+
color = skew_colors[skew]
|
| 264 |
+
ax.plot([start_date, end_date], [start_val, end_val], color=color, linewidth=2.5, alpha=0.7)
|
| 265 |
+
|
| 266 |
+
# Connect the last point if needed, but since segments cover, add a small marker at end if desired
|
| 267 |
+
ax.plot(pred_df.index[-1], median[-1], marker='o', color=skew_colors[skew_directions[-1]], markersize=4, alpha=0.7)
|
| 268 |
+
|
| 269 |
+
# Add to legend only if present
|
| 270 |
+
if upside_count > 0:
|
| 271 |
+
legend_elements.append(Line2D([0], [0], color='green', lw=2, label='Upside Skew'))
|
| 272 |
+
if downside_count > 0:
|
| 273 |
+
legend_elements.append(Line2D([0], [0], color='red', lw=2, label='Downside Skew'))
|
| 274 |
+
if neutral_count > 0:
|
| 275 |
+
legend_elements.append(Line2D([0], [0], color='gray', lw=2, label='Neutral Skew'))
|
| 276 |
+
|
| 277 |
+
ax.set_title(f'{selected_col} Forecast with TiRex (Context: {context_start+1}-{context_end}, Horizon: {pred_len})', fontsize=16, fontweight='bold')
|
| 278 |
ax.set_xlabel('Date', fontsize=12)
|
| 279 |
ax.set_ylabel(selected_col, fontsize=12)
|
| 280 |
+
ax.legend(handles=ax.get_legend_handles_labels()[0] + legend_elements, fontsize=10)
|
| 281 |
ax.tick_params(axis='x', rotation=45)
|
| 282 |
plt.tight_layout()
|
| 283 |
+
|
| 284 |
return fig, markdown_text
|
| 285 |
|
| 286 |
except Exception as e:
|
| 287 |
+
return None, f"### Error\n{str(e)}\n\nTips: Ensure the time column can be parsed as dates; check NaNs/zeros; ensure data is valid."
|
| 288 |
|
| 289 |
# Create the Gradio interface
|
| 290 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
|
| 291 |
gr.Markdown("""
|
| 292 |
+
# TiRex Forecaster Dashboard
|
| 293 |
+
Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
|
|
|
|
|
|
|
| 294 |
""")
|
| 295 |
+
|
| 296 |
with gr.Row(variant="panel"):
|
| 297 |
with gr.Column(scale=1):
|
| 298 |
csv_file = gr.File(
|
| 299 |
file_types=[".csv"],
|
| 300 |
+
label="Upload CSV File",
|
| 301 |
elem_id="file_upload"
|
| 302 |
)
|
| 303 |
+
gr.Markdown("The minimum effective input is around 128 time steps per series. Use a full context of 2048 steps for optimal performance.")
|
| 304 |
+
time_dropdown = gr.Dropdown(
|
| 305 |
+
choices=[],
|
| 306 |
+
label="Select Time Column",
|
| 307 |
+
interactive=True,
|
| 308 |
+
elem_id="time_select"
|
| 309 |
+
)
|
| 310 |
column_dropdown = gr.Dropdown(
|
| 311 |
choices=[],
|
| 312 |
+
label="Select Column to Forecast",
|
| 313 |
interactive=True,
|
| 314 |
elem_id="column_select"
|
| 315 |
)
|
| 316 |
+
start_slider = gr.Slider(
|
| 317 |
+
minimum=1, maximum=1, value=1, step=1,
|
| 318 |
+
label="Historical Start Index (1-based)",
|
| 319 |
+
elem_id="start_idx"
|
| 320 |
+
)
|
| 321 |
+
end_slider = gr.Slider(
|
| 322 |
+
minimum=1, maximum=1, value=1, step=1,
|
| 323 |
+
label="Historical End Index (1-based)",
|
| 324 |
+
elem_id="end_idx"
|
| 325 |
+
)
|
| 326 |
prediction_length = gr.Slider(
|
| 327 |
+
minimum=1, maximum=720, value=12, step=1,
|
| 328 |
+
label="Prediction Length (Periods)",
|
| 329 |
elem_id="pred_length"
|
| 330 |
)
|
| 331 |
confidence = gr.Slider(
|
| 332 |
minimum=50, maximum=95, value=80, step=5,
|
| 333 |
+
label="Confidence Level (%)",
|
| 334 |
elem_id="confidence"
|
| 335 |
)
|
| 336 |
run_button = gr.Button(
|
| 337 |
+
"Run TiRex Forecast",
|
| 338 |
variant="primary",
|
| 339 |
size="lg",
|
| 340 |
elem_id="run_btn"
|
| 341 |
)
|
| 342 |
+
|
| 343 |
with gr.Column(scale=2):
|
| 344 |
forecast_plot = gr.Plot(
|
| 345 |
+
label="Forecast Visualization",
|
| 346 |
elem_id="plot"
|
| 347 |
)
|
| 348 |
output_text = gr.Markdown(
|
| 349 |
"### Welcome!\nUpload your CSV to get started.",
|
| 350 |
elem_id="output"
|
| 351 |
)
|
| 352 |
+
|
| 353 |
+
gr.Markdown("**Built by** [next one gmbh](https://www.nextone.at)")
|
| 354 |
+
|
| 355 |
+
# Event for updating dropdowns on file upload
|
| 356 |
csv_file.change(
|
| 357 |
load_columns,
|
| 358 |
inputs=csv_file,
|
| 359 |
+
outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
|
| 360 |
)
|
| 361 |
+
|
| 362 |
# Event for running forecast
|
| 363 |
run_button.click(
|
| 364 |
run_forecast,
|
| 365 |
+
inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence],
|
| 366 |
outputs=[forecast_plot, output_text]
|
| 367 |
)
|
| 368 |
|
| 369 |
# Launch the app
|
| 370 |
+
demo.launch()
|
|
|