kawaiipeace commited on
Commit
43fc3c5
·
1 Parent(s): 270a901
Files changed (1) hide show
  1. app.py +25 -14
app.py CHANGED
@@ -4,9 +4,9 @@ import torch
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  from dotenv import load_dotenv
7
-
8
  from utils.preprocessing import load_and_process_data
9
  from models.custom_models import run_forecast
 
10
 
11
  load_dotenv()
12
 
@@ -16,30 +16,41 @@ models = ["LSTM", "BiLSTM", "GRU", "ARIMA", "ExponentialSmoothing", "Prophet"]
16
 
17
  classic_models = ["ARIMA", "ExponentialSmoothing", "Prophet"]
18
 
19
-
20
- # ---------- UI Logic ----------- #
21
- def plot_raw_data(df):
22
  print("Data Shape:", df.shape)
23
  print("Data Index (Datetime):", df.index[:5])
 
 
 
 
 
24
  if df.shape[1] == 1:
25
- fig = plt.figure(figsize=(12, 4))
26
- plt.plot(df.index, df.values, label="Target")
27
- plt.title("Univariate Time Series")
28
- plt.xlabel("Datetime")
29
- plt.ylabel("Value")
30
- plt.legend()
 
 
 
 
 
31
  return fig
32
  else:
33
  num_features = df.shape[1]
34
- fig, axes = plt.subplots(
35
- num_features, 1, figsize=(12, 2.5 * num_features), sharex=True
36
- )
37
  if num_features == 1:
38
  axes = [axes]
39
  for i, col in enumerate(df.columns):
40
  axes[i].plot(df.index, df[col].values)
41
  axes[i].set_title(f"{col}")
42
  axes[i].set_ylabel("Value")
 
 
 
 
 
43
  axes[-1].set_xlabel("Datetime")
44
  fig.tight_layout()
45
  return fig
@@ -66,7 +77,7 @@ def forecast_interface(
66
  is_multivariate == "Multivariate",
67
  keep_datetime_column_for_darts=True,
68
  )
69
- raw_plot = plot_raw_data(df)
70
 
71
  if model_type == "ARIMA":
72
  arima_order = (ar, i, ma)
 
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  from dotenv import load_dotenv
 
7
  from utils.preprocessing import load_and_process_data
8
  from models.custom_models import run_forecast
9
+ from matplotlib.dates import DateFormatter
10
 
11
  load_dotenv()
12
 
 
16
 
17
  classic_models = ["ARIMA", "ExponentialSmoothing", "Prophet"]
18
 
19
+ def plot_raw_data(df, horizon=None):
 
 
20
  print("Data Shape:", df.shape)
21
  print("Data Index (Datetime):", df.index[:5])
22
+
23
+ date_fmt = DateFormatter("%d/%m/%Y")
24
+ split_index = len(df) - horizon if horizon else None
25
+ split_time = df.index[split_index] if split_index else None
26
+
27
  if df.shape[1] == 1:
28
+ fig, ax = plt.subplots(figsize=(12, 4))
29
+ ax.plot(df.index, df.values, label="Target")
30
+ if split_time:
31
+ ax.axvline(split_time, color='red', linestyle='--', label='Train/Test Split')
32
+ ax.text(split_time, ax.get_ylim()[1], " ← Train | Test →", color='red', va='top', fontsize=10)
33
+ ax.set_title("Univariate Time Series")
34
+ ax.set_xlabel("Datetime")
35
+ ax.set_ylabel("Value")
36
+ ax.xaxis.set_major_formatter(date_fmt)
37
+ plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
38
+ ax.legend()
39
  return fig
40
  else:
41
  num_features = df.shape[1]
42
+ fig, axes = plt.subplots(num_features, 1, figsize=(12, 2.5 * num_features), sharex=True)
 
 
43
  if num_features == 1:
44
  axes = [axes]
45
  for i, col in enumerate(df.columns):
46
  axes[i].plot(df.index, df[col].values)
47
  axes[i].set_title(f"{col}")
48
  axes[i].set_ylabel("Value")
49
+ if split_time:
50
+ axes[i].axvline(split_time, color='red', linestyle='--')
51
+ axes[i].text(split_time, axes[i].get_ylim()[1], " ← Train | Test →", color='red', va='top', fontsize=10)
52
+ axes[i].xaxis.set_major_formatter(date_fmt)
53
+ plt.setp(axes[i].get_xticklabels(), rotation=45, ha="right")
54
  axes[-1].set_xlabel("Datetime")
55
  fig.tight_layout()
56
  return fig
 
77
  is_multivariate == "Multivariate",
78
  keep_datetime_column_for_darts=True,
79
  )
80
+ raw_plot = plot_raw_data(df, horizon)
81
 
82
  if model_type == "ARIMA":
83
  arima_order = (ar, i, ma)