Ale37 commited on
Commit
40e07d5
·
1 Parent(s): 794fde1
Files changed (1) hide show
  1. src/model.py +5 -6
src/model.py CHANGED
@@ -1,7 +1,7 @@
1
  from darts import TimeSeries
2
  from darts.datasets import ILINetDataset
3
  from darts.metrics import mape
4
- from darts.model import ExponentialSmoothing
5
  from darts.utils.missing_values import fill_missing_values
6
  from darts.dataprocessing.transformers import Scaler
7
  import matplotlib.pyplot as plt
@@ -70,7 +70,7 @@ def scale_train(train_ili):
70
  return train_ili_scaled, scaler
71
 
72
 
73
- def train_(train_ili_scaled, val_ili, save_model_path='./models'):
74
  model = ExponentialSmoothing()
75
  model.fit(train_ili_scaled)
76
 
@@ -79,12 +79,11 @@ def train_(train_ili_scaled, val_ili, save_model_path='./models'):
79
 
80
  def save_model(model, path):
81
  model_name = str(model).split('(')[0]
82
- path = f'{path}/{model_name}.pkl'
83
  if not os.path.exists(path=path):
84
  os.makedirs(path)
85
- model.save(path)
86
- else:
87
- model = ExponentialSmoothing.load(path)
88
 
89
 
90
  def load_model(path):
 
1
  from darts import TimeSeries
2
  from darts.datasets import ILINetDataset
3
  from darts.metrics import mape
4
+ from darts.models import ExponentialSmoothing
5
  from darts.utils.missing_values import fill_missing_values
6
  from darts.dataprocessing.transformers import Scaler
7
  import matplotlib.pyplot as plt
 
70
  return train_ili_scaled, scaler
71
 
72
 
73
+ def train(train_ili_scaled):
74
  model = ExponentialSmoothing()
75
  model.fit(train_ili_scaled)
76
 
 
79
 
80
  def save_model(model, path):
81
  model_name = str(model).split('(')[0]
 
82
  if not os.path.exists(path=path):
83
  os.makedirs(path)
84
+ model.save(os.path.join(path, model_name))
85
+ #else:
86
+ # model = ExponentialSmoothing.load(path)
87
 
88
 
89
  def load_model(path):