Snxt1 commited on
Commit
2742aad
·
verified ·
1 Parent(s): 3a84b34

Update app.py

Browse files

Shipping some updates.

Files changed (1) hide show
  1. app.py +197 -92
app.py CHANGED
@@ -14,89 +14,126 @@ model = load_model("NX-AI/TiRex")
14
 
15
  def load_columns(file):
16
  if file is None:
17
- return gr.Dropdown(choices=[], label="Select Column to Forecast:", interactive=True)
18
-
 
 
 
19
  try:
20
  # Handle file as path string (Gradio convention)
21
  with open(file, 'rb') as f:
22
  content = f.read()
23
  df_preview = pd.read_csv(io.BytesIO(content))
24
-
25
- # Assume first column is date-like, rename if 'Day'
26
- date_cols = [col for col in df_preview.columns if 'day' in col.lower() or 'date' in col.lower()]
27
- if date_cols:
28
- df_preview = df_preview.rename(columns={date_cols[0]: 'date'})
29
-
30
- # Available numeric columns for forecast (exclude date)
31
  numeric_cols = df_preview.select_dtypes(include=['number']).columns.tolist()
32
- if 'date' in numeric_cols:
33
- numeric_cols.remove('date')
34
-
35
  if numeric_cols:
36
- return gr.Dropdown(
37
- choices=[(col, col) for col in numeric_cols],
38
- value=numeric_cols[0],
39
- label="Select Column to Forecast:",
40
- interactive=True
41
- )
42
  else:
43
- return gr.Dropdown(
44
- choices=[],
45
- value=None,
46
- label="No numeric columns found",
47
- interactive=False
48
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  except Exception as e:
50
- return gr.Dropdown(
51
  choices=[],
52
  value=None,
53
  label=f"Error loading CSV: {str(e)}",
54
  interactive=False
55
- )
 
 
 
 
 
 
 
 
 
 
56
 
57
- def run_forecast(file, selected_col, prediction_length, confidence):
58
- if file is None or selected_col is None:
59
- return None, "### Error\nPlease upload a CSV and select a column!"
60
-
61
  try:
62
  # Handle file as path string (Gradio convention)
63
  with open(file, 'rb') as f:
64
  content = f.read()
65
  df = pd.read_csv(io.BytesIO(content))
66
 
67
- # Rename date column if needed
68
- date_cols = [col for col in df.columns if 'day' in col.lower() or 'date' in col.lower()]
69
- if date_cols:
70
- df = df.rename(columns={date_cols[0]: 'date'})
71
- else:
72
- return None, "### Error\nNo date column found (looking for 'Day' or 'date'). Edit CSV."
73
 
74
- # Use selected column as 'sales'
75
- df = df.rename(columns={selected_col: 'sales'})
76
 
77
  # Validate
78
  required_cols = ['date', 'sales']
79
  if not all(col in df.columns for col in required_cols):
80
- return None, f"### Error\nMissing 'date' or selected column '{selected_col}'."
81
 
82
  # Prep data
83
  df['date'] = pd.to_datetime(df['date'])
84
  df = df.set_index('date').sort_index()
85
- if len(df) < 10:
86
- return None, "### Error\nNeed at least 10 data points."
87
 
88
- series = df['sales'].dropna().values
89
- print(f"Loaded: {len(series)} points from {df.index.min().date()} to {df.index.max().date()} (Column: {selected_col})") # For logs
 
 
 
 
 
 
 
 
 
90
 
91
  # Infer freq
92
- freq = pd.infer_freq(df.index)
93
  if freq is None:
94
  freq = 'D'
95
  print(f"Frequency: '{freq}'.")
96
 
97
  # Prep context
98
- context_len = min(len(series), 2048)
99
- context = torch.tensor(series[-context_len:]).unsqueeze(0).float()
100
 
101
  pred_len = prediction_length
102
  conf_level = confidence / 100.0
@@ -122,6 +159,8 @@ def run_forecast(file, selected_col, prediction_length, confidence):
122
  lower_slider = np.zeros(pred_len)
123
  upper_slider = np.zeros(pred_len)
124
 
 
 
125
  for t in range(pred_len):
126
  q_t = q[t]
127
  lower50[t] = np.interp(lower_alpha_50, alphas, q_t)
@@ -129,11 +168,22 @@ def run_forecast(file, selected_col, prediction_length, confidence):
129
  lower_slider[t] = np.interp(lower_alpha_slider, alphas, q_t)
130
  upper_slider[t] = np.interp(upper_alpha_slider, alphas, q_t)
131
 
 
 
 
 
 
 
 
 
 
 
 
132
  # Mean forecast
133
  mean_forecast = mean[0].detach().numpy()
134
 
135
  # Future dates
136
- last_date = df.index[-1]
137
  if freq == 'D':
138
  future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=pred_len, freq='D')
139
  else:
@@ -144,35 +194,49 @@ def run_forecast(file, selected_col, prediction_length, confidence):
144
  'predicted_sales_median': median,
145
  'predicted_sales_lower': lower_slider,
146
  'predicted_sales_upper': upper_slider,
147
- 'predicted_sales_mean': mean_forecast
 
148
  }).set_index('date')
149
 
 
 
 
 
 
150
  # Prepare markdown output (broken into smaller strings to avoid multiline f-string parsing issues)
151
- markdown_text = "### ✅ TiRex Forecast Results (Median + {}% Interval)\n\n".format(confidence)
152
- markdown_text += "| Date | Median | Lower Bound | Upper Bound | Mean |\n"
153
- markdown_text += "|------|--------|-------------|-------------|------|\n"
154
- for idx, row in pred_df.iterrows():
155
- markdown_text += "| {} | {:.2f} | {:.2f} | {:.2f} | {:.2f} |\n".format(
156
- idx.strftime('%Y-%m-%d'),
157
- row['predicted_sales_median'],
158
- row['predicted_sales_lower'],
159
- row['predicted_sales_upper'],
160
- row['predicted_sales_mean']
161
- )
162
-
163
- markdown_text += "\n### 📊 Summary\n"
164
  markdown_text += "- **Prediction Length:** {} periods\n".format(pred_len)
165
  markdown_text += "- **Confidence Level:** {}% (alphas: {:.3f} - {:.3f})\n".format(confidence, lower_alpha_slider, upper_alpha_slider)
166
  markdown_text += "- **Sum of Median Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_median'].sum())
167
- markdown_text += "- **Sum of Mean Predicted Values:** {:.2f}\n\n".format(pred_df['predicted_sales_mean'].sum())
168
-
169
- markdown_text += "### Sample Historical Data\n"
170
- markdown_text += "```\n" + df.head().to_string() + "\n```"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  # Create plot
173
  fig, ax = plt.subplots(figsize=(14, 7))
174
- ax.plot(df.index, df['sales'], label=f'Historical {selected_col} (Full CSV Data)', color='#1f77b4', linewidth=1.5, alpha=0.8)
175
- ax.plot(pred_df.index, pred_df['predicted_sales_median'], label='TiRex Forecast (Median)', color='#d62728', linestyle='-', linewidth=2)
 
176
  ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
177
 
178
  # Fan chart: non-overlapping bands
@@ -185,81 +249,122 @@ def run_forecast(file, selected_col, prediction_length, confidence):
185
  ax.fill_between(pred_df.index, upper50, upper_slider,
186
  color='#d62728', alpha=0.3, label=f'{confidence}% Uncertainty Wings')
187
 
188
- ax.set_title(f'{selected_col} Forecast with TiRex (Full History + Horizon: {pred_len})', fontsize=16, fontweight='bold')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  ax.set_xlabel('Date', fontsize=12)
190
  ax.set_ylabel(selected_col, fontsize=12)
191
- ax.legend(fontsize=10)
192
  ax.tick_params(axis='x', rotation=45)
193
  plt.tight_layout()
194
-
195
  return fig, markdown_text
196
 
197
  except Exception as e:
198
- return None, f"### Error\n{str(e)}\n\nTips: Check NaNs/zeros; ensure data is valid."
199
 
200
  # Create the Gradio interface
201
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="🚀 TiRex Forecaster") as demo:
202
  gr.Markdown("""
203
- # 🚀 TiRex Forecaster Dashboard
204
- Upload a CSV file with a date column (e.g., 'Day' or 'date') and numeric columns. Select one to forecast future values using the TiRex model.
205
-
206
- The dashboard will display in this new window/tab for a cool, interactive experience!
207
  """)
208
-
209
  with gr.Row(variant="panel"):
210
  with gr.Column(scale=1):
211
  csv_file = gr.File(
212
  file_types=[".csv"],
213
- label="📁 Upload CSV File",
214
  elem_id="file_upload"
215
  )
 
 
 
 
 
 
 
216
  column_dropdown = gr.Dropdown(
217
  choices=[],
218
- label="📈 Select Column to Forecast",
219
  interactive=True,
220
  elem_id="column_select"
221
  )
 
 
 
 
 
 
 
 
 
 
222
  prediction_length = gr.Slider(
223
- minimum=1, maximum=100, value=12, step=1,
224
- label="🔮 Prediction Length (Periods)",
225
  elem_id="pred_length"
226
  )
227
  confidence = gr.Slider(
228
  minimum=50, maximum=95, value=80, step=5,
229
- label="🎯 Confidence Level (%)",
230
  elem_id="confidence"
231
  )
232
  run_button = gr.Button(
233
- "Run TiRex Forecast",
234
  variant="primary",
235
  size="lg",
236
  elem_id="run_btn"
237
  )
238
-
239
  with gr.Column(scale=2):
240
  forecast_plot = gr.Plot(
241
- label="📊 Forecast Visualization",
242
  elem_id="plot"
243
  )
244
  output_text = gr.Markdown(
245
  "### Welcome!\nUpload your CSV to get started.",
246
  elem_id="output"
247
  )
248
-
249
- # Event for updating dropdown on file upload
 
 
250
  csv_file.change(
251
  load_columns,
252
  inputs=csv_file,
253
- outputs=column_dropdown
254
  )
255
-
256
  # Event for running forecast
257
  run_button.click(
258
  run_forecast,
259
- inputs=[csv_file, column_dropdown, prediction_length, confidence],
260
  outputs=[forecast_plot, output_text]
261
  )
262
 
263
  # Launch the app
264
- if __name__ == "__main__":
265
- demo.launch()
 
14
 
15
  def load_columns(file):
16
  if file is None:
17
+ return (gr.Dropdown(choices=[], label="Select Time Column:", interactive=True),
18
+ gr.Dropdown(choices=[], label="Select Column to Forecast:", interactive=True),
19
+ gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index (1-based)"),
20
+ gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index (1-based)"))
21
+
22
  try:
23
  # Handle file as path string (Gradio convention)
24
  with open(file, 'rb') as f:
25
  content = f.read()
26
  df_preview = pd.read_csv(io.BytesIO(content))
27
+
28
+ # All columns for time selection
29
+ all_cols = df_preview.columns.tolist()
30
+ time_choices = [(col, col) for col in all_cols]
31
+ time_value = all_cols[0] if all_cols else None
32
+
33
+ # Available numeric columns for forecast
34
  numeric_cols = df_preview.select_dtypes(include=['number']).columns.tolist()
35
+
 
 
36
  if numeric_cols:
37
+ value_choices = [(col, col) for col in numeric_cols]
38
+ value_value = numeric_cols[0]
 
 
 
 
39
  else:
40
+ value_choices = []
41
+ value_value = None
42
+
43
+ n_rows = len(df_preview)
44
+
45
+ time_dropdown = gr.Dropdown(
46
+ choices=time_choices,
47
+ value=time_value,
48
+ label="Select Time Column:",
49
+ interactive=True
50
+ )
51
+
52
+ value_dropdown = gr.Dropdown(
53
+ choices=value_choices,
54
+ value=value_value,
55
+ label="Select Column to Forecast:",
56
+ interactive=True
57
+ ) if value_choices else gr.Dropdown(
58
+ choices=[],
59
+ value=None,
60
+ label="No numeric columns found",
61
+ interactive=False
62
+ )
63
+
64
+ start_slider = gr.Slider(
65
+ minimum=1, maximum=n_rows, value=1, step=1,
66
+ label="Historical Start Index (1-based)"
67
+ )
68
+
69
+ end_slider = gr.Slider(
70
+ minimum=1, maximum=n_rows, value=n_rows, step=1,
71
+ label="Historical End Index (1-based)"
72
+ )
73
+
74
+ return time_dropdown, value_dropdown, start_slider, end_slider
75
+
76
  except Exception as e:
77
+ return (gr.Dropdown(
78
  choices=[],
79
  value=None,
80
  label=f"Error loading CSV: {str(e)}",
81
  interactive=False
82
+ ), gr.Dropdown(
83
+ choices=[],
84
+ value=None,
85
+ label=f"Error loading CSV: {str(e)}",
86
+ interactive=False
87
+ ), gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical Start Index (1-based)"),
88
+ gr.Slider(minimum=1, maximum=1, value=1, step=1, label="Historical End Index (1-based)"))
89
+
90
+ def run_forecast(file, time_col, selected_col, start_idx, end_idx, prediction_length, confidence):
91
+ if file is None or time_col is None or selected_col is None:
92
+ return None, "### Error\nPlease upload a CSV and select time and value columns!"
93
 
 
 
 
 
94
  try:
95
  # Handle file as path string (Gradio convention)
96
  with open(file, 'rb') as f:
97
  content = f.read()
98
  df = pd.read_csv(io.BytesIO(content))
99
 
100
+ # Validate columns exist
101
+ if time_col not in df.columns or selected_col not in df.columns:
102
+ return None, f"### Error\nSelected columns '{time_col}' or '{selected_col}' not found in CSV."
 
 
 
103
 
104
+ # Rename selected columns
105
+ df = df.rename(columns={time_col: 'date', selected_col: 'sales'})
106
 
107
  # Validate
108
  required_cols = ['date', 'sales']
109
  if not all(col in df.columns for col in required_cols):
110
+ return None, f"### Error\nMissing renamed columns."
111
 
112
  # Prep data
113
  df['date'] = pd.to_datetime(df['date'])
114
  df = df.set_index('date').sort_index()
 
 
115
 
116
+ full_len = len(df)
117
+ context_start = max(0, int(start_idx) - 1)
118
+ context_end = min(full_len, int(end_idx))
119
+ context_df = df.iloc[context_start:context_end]
120
+ held_out_df = df.iloc[context_end:] if context_end < full_len else pd.DataFrame(index=pd.DatetimeIndex([]), columns=df.columns)
121
+
122
+ if len(context_df) < 10:
123
+ return None, "### Error\nNeed at least 10 data points in the selected historical range."
124
+
125
+ context_series = context_df['sales'].dropna().values
126
+ print(f"Loaded context: {len(context_series)} points from {context_df.index.min().date()} to {context_df.index.max().date()} (Column: {selected_col})") # For logs
127
 
128
  # Infer freq
129
+ freq = pd.infer_freq(context_df.index)
130
  if freq is None:
131
  freq = 'D'
132
  print(f"Frequency: '{freq}'.")
133
 
134
  # Prep context
135
+ context_len = min(len(context_series), 2048)
136
+ context = torch.tensor(context_series[-context_len:]).unsqueeze(0).float()
137
 
138
  pred_len = prediction_length
139
  conf_level = confidence / 100.0
 
159
  lower_slider = np.zeros(pred_len)
160
  upper_slider = np.zeros(pred_len)
161
 
162
+ skew_directions = []
163
+
164
  for t in range(pred_len):
165
  q_t = q[t]
166
  lower50[t] = np.interp(lower_alpha_50, alphas, q_t)
 
168
  lower_slider[t] = np.interp(lower_alpha_slider, alphas, q_t)
169
  upper_slider[t] = np.interp(upper_alpha_slider, alphas, q_t)
170
 
171
+ # Compute skew direction based on asymmetry around median
172
+ med = median[t]
173
+ upside_dist = upper_slider[t] - med
174
+ downside_dist = med - lower_slider[t]
175
+ if upside_dist > downside_dist:
176
+ skew_directions.append("Upside")
177
+ elif downside_dist > upside_dist:
178
+ skew_directions.append("Downside")
179
+ else:
180
+ skew_directions.append("Neutral")
181
+
182
  # Mean forecast
183
  mean_forecast = mean[0].detach().numpy()
184
 
185
  # Future dates
186
+ last_date = context_df.index[-1]
187
  if freq == 'D':
188
  future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=pred_len, freq='D')
189
  else:
 
194
  'predicted_sales_median': median,
195
  'predicted_sales_lower': lower_slider,
196
  'predicted_sales_upper': upper_slider,
197
+ 'predicted_sales_mean': mean_forecast,
198
+ 'skew_direction': skew_directions
199
  }).set_index('date')
200
 
201
+ # Count skews for summary
202
+ upside_count = skew_directions.count("Upside")
203
+ downside_count = skew_directions.count("Downside")
204
+ neutral_count = skew_directions.count("Neutral")
205
+
206
  # Prepare markdown output (broken into smaller strings to avoid multiline f-string parsing issues)
207
+ markdown_text = "### Summary\n"
208
+ markdown_text += "- **Number of Historical Periods Used:** {} points\n".format(len(context_series))
209
+ markdown_text += "- **Held Out Periods:** {} points {}\n".format(len(held_out_df), "(Full Context Used)" if len(held_out_df) == 0 else "(For Validation)")
 
 
 
 
 
 
 
 
 
 
210
  markdown_text += "- **Prediction Length:** {} periods\n".format(pred_len)
211
  markdown_text += "- **Confidence Level:** {}% (alphas: {:.3f} - {:.3f})\n".format(confidence, lower_alpha_slider, upper_alpha_slider)
212
  markdown_text += "- **Sum of Median Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_median'].sum())
213
+ markdown_text += "- **Sum of Mean Predicted Values:** {:.2f}\n".format(pred_df['predicted_sales_mean'].sum())
214
+ markdown_text += "- **Skew Distribution:** {} Upside, {} Downside, {} Neutral\n\n".format(upside_count, downside_count, neutral_count)
215
+
216
+ forecast_table = "### TiRex Forecast Results (Median + {}% Interval)\n\n".format(confidence)
217
+ forecast_table += "| Date | Median | Lower Bound | Upper Bound | Mean | Skew |\n"
218
+ forecast_table += "|------|--------|-------------|-------------|------|------|\n"
219
+ for idx, row in pred_df.iterrows():
220
+ forecast_table += "| {} | {:.2f} | {:.2f} | {:.2f} | {:.2f} | {} |\n".format(
221
+ idx.strftime('%Y-%m-%d'),
222
+ row['predicted_sales_median'],
223
+ row['predicted_sales_lower'],
224
+ row['predicted_sales_upper'],
225
+ row['predicted_sales_mean'],
226
+ row['skew_direction']
227
+ )
228
+
229
+ sample_data = "### Sample Historical Data (Context)\n"
230
+ sample_data += "```\n" + context_df.head().to_string() + "\n```"
231
+
232
+ markdown_text += f'\n<details><summary>Click to expand Forecast Table</summary>\n\n{forecast_table}\n</details>\n\n'
233
+ markdown_text += f'<details><summary>Click to expand Sample Historical Data</summary>\n\n{sample_data}\n</details>'
234
 
235
  # Create plot
236
  fig, ax = plt.subplots(figsize=(14, 7))
237
+ ax.plot(context_df.index, context_df['sales'], label=f'Used Historical {selected_col}', color='#1f77b4', linewidth=1.5, alpha=0.8)
238
+ if not held_out_df.empty:
239
+ ax.plot(held_out_df.index, held_out_df['sales'], label='Held Out Actual (Validation)', color='#2ca02c', linestyle=':', linewidth=2)
240
  ax.plot(pred_df.index, pred_df['predicted_sales_mean'], label='TiRex Forecast (Mean)', color='#ff7f0e', linestyle='--', linewidth=2)
241
 
242
  # Fan chart: non-overlapping bands
 
249
  ax.fill_between(pred_df.index, upper50, upper_slider,
250
  color='#d62728', alpha=0.3, label=f'{confidence}% Uncertainty Wings')
251
 
252
+ # Subtle skew visualization: colored segments on the median forecast line
253
+ from matplotlib.lines import Line2D
254
+ legend_elements = []
255
+
256
+ skew_colors = {'Upside': 'green', 'Downside': 'red', 'Neutral': 'gray'}
257
+ for i in range(len(pred_df) - 1):
258
+ start_date = pred_df.index[i]
259
+ end_date = pred_df.index[i + 1]
260
+ start_val = median[i]
261
+ end_val = median[i + 1]
262
+ skew = skew_directions[i]
263
+ color = skew_colors[skew]
264
+ ax.plot([start_date, end_date], [start_val, end_val], color=color, linewidth=2.5, alpha=0.7)
265
+
266
+ # Connect the last point if needed, but since segments cover, add a small marker at end if desired
267
+ ax.plot(pred_df.index[-1], median[-1], marker='o', color=skew_colors[skew_directions[-1]], markersize=4, alpha=0.7)
268
+
269
+ # Add to legend only if present
270
+ if upside_count > 0:
271
+ legend_elements.append(Line2D([0], [0], color='green', lw=2, label='Upside Skew'))
272
+ if downside_count > 0:
273
+ legend_elements.append(Line2D([0], [0], color='red', lw=2, label='Downside Skew'))
274
+ if neutral_count > 0:
275
+ legend_elements.append(Line2D([0], [0], color='gray', lw=2, label='Neutral Skew'))
276
+
277
+ ax.set_title(f'{selected_col} Forecast with TiRex (Context: {context_start+1}-{context_end}, Horizon: {pred_len})', fontsize=16, fontweight='bold')
278
  ax.set_xlabel('Date', fontsize=12)
279
  ax.set_ylabel(selected_col, fontsize=12)
280
+ ax.legend(handles=ax.get_legend_handles_labels()[0] + legend_elements, fontsize=10)
281
  ax.tick_params(axis='x', rotation=45)
282
  plt.tight_layout()
283
+
284
  return fig, markdown_text
285
 
286
  except Exception as e:
287
+ return None, f"### Error\n{str(e)}\n\nTips: Ensure the time column can be parsed as dates; check NaNs/zeros; ensure data is valid."
288
 
289
  # Create the Gradio interface
290
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="red"), title="TiRex Forecaster") as demo:
291
  gr.Markdown("""
292
+ # TiRex Forecaster Dashboard
293
+ Upload a CSV file with a time column and numeric columns. Select the time column and one numeric column to forecast future values using the TiRex model.
 
 
294
  """)
295
+
296
  with gr.Row(variant="panel"):
297
  with gr.Column(scale=1):
298
  csv_file = gr.File(
299
  file_types=[".csv"],
300
+ label="Upload CSV File",
301
  elem_id="file_upload"
302
  )
303
+ gr.Markdown("The minimum effective input is around 128 time steps per series. Use a full context of 2048 steps for optimal performance.")
304
+ time_dropdown = gr.Dropdown(
305
+ choices=[],
306
+ label="Select Time Column",
307
+ interactive=True,
308
+ elem_id="time_select"
309
+ )
310
  column_dropdown = gr.Dropdown(
311
  choices=[],
312
+ label="Select Column to Forecast",
313
  interactive=True,
314
  elem_id="column_select"
315
  )
316
+ start_slider = gr.Slider(
317
+ minimum=1, maximum=1, value=1, step=1,
318
+ label="Historical Start Index (1-based)",
319
+ elem_id="start_idx"
320
+ )
321
+ end_slider = gr.Slider(
322
+ minimum=1, maximum=1, value=1, step=1,
323
+ label="Historical End Index (1-based)",
324
+ elem_id="end_idx"
325
+ )
326
  prediction_length = gr.Slider(
327
+ minimum=1, maximum=720, value=12, step=1,
328
+ label="Prediction Length (Periods)",
329
  elem_id="pred_length"
330
  )
331
  confidence = gr.Slider(
332
  minimum=50, maximum=95, value=80, step=5,
333
+ label="Confidence Level (%)",
334
  elem_id="confidence"
335
  )
336
  run_button = gr.Button(
337
+ "Run TiRex Forecast",
338
  variant="primary",
339
  size="lg",
340
  elem_id="run_btn"
341
  )
342
+
343
  with gr.Column(scale=2):
344
  forecast_plot = gr.Plot(
345
+ label="Forecast Visualization",
346
  elem_id="plot"
347
  )
348
  output_text = gr.Markdown(
349
  "### Welcome!\nUpload your CSV to get started.",
350
  elem_id="output"
351
  )
352
+
353
+ gr.Markdown("**Built by** [next one gmbh](https://www.nextone.at)")
354
+
355
+ # Event for updating dropdowns on file upload
356
  csv_file.change(
357
  load_columns,
358
  inputs=csv_file,
359
+ outputs=[time_dropdown, column_dropdown, start_slider, end_slider]
360
  )
361
+
362
  # Event for running forecast
363
  run_button.click(
364
  run_forecast,
365
+ inputs=[csv_file, time_dropdown, column_dropdown, start_slider, end_slider, prediction_length, confidence],
366
  outputs=[forecast_plot, output_text]
367
  )
368
 
369
  # Launch the app
370
+ demo.launch()