fixes
Browse files- src/model.py +5 -6
src/model.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from darts import TimeSeries
|
| 2 |
from darts.datasets import ILINetDataset
|
| 3 |
from darts.metrics import mape
|
| 4 |
-
from darts.
|
| 5 |
from darts.utils.missing_values import fill_missing_values
|
| 6 |
from darts.dataprocessing.transformers import Scaler
|
| 7 |
import matplotlib.pyplot as plt
|
|
@@ -70,7 +70,7 @@ def scale_train(train_ili):
|
|
| 70 |
return train_ili_scaled, scaler
|
| 71 |
|
| 72 |
|
| 73 |
-
def
|
| 74 |
model = ExponentialSmoothing()
|
| 75 |
model.fit(train_ili_scaled)
|
| 76 |
|
|
@@ -79,12 +79,11 @@ def train_(train_ili_scaled, val_ili, save_model_path='./models'):
|
|
| 79 |
|
| 80 |
def save_model(model, path):
|
| 81 |
model_name = str(model).split('(')[0]
|
| 82 |
-
path = f'{path}/{model_name}.pkl'
|
| 83 |
if not os.path.exists(path=path):
|
| 84 |
os.makedirs(path)
|
| 85 |
-
model.save(path)
|
| 86 |
-
else:
|
| 87 |
-
|
| 88 |
|
| 89 |
|
| 90 |
def load_model(path):
|
|
|
|
| 1 |
from darts import TimeSeries
|
| 2 |
from darts.datasets import ILINetDataset
|
| 3 |
from darts.metrics import mape
|
| 4 |
+
from darts.models import ExponentialSmoothing
|
| 5 |
from darts.utils.missing_values import fill_missing_values
|
| 6 |
from darts.dataprocessing.transformers import Scaler
|
| 7 |
import matplotlib.pyplot as plt
|
|
|
|
| 70 |
return train_ili_scaled, scaler
|
| 71 |
|
| 72 |
|
| 73 |
+
def train(train_ili_scaled):
|
| 74 |
model = ExponentialSmoothing()
|
| 75 |
model.fit(train_ili_scaled)
|
| 76 |
|
|
|
|
| 79 |
|
| 80 |
def save_model(model, path):
|
| 81 |
model_name = str(model).split('(')[0]
|
|
|
|
| 82 |
if not os.path.exists(path=path):
|
| 83 |
os.makedirs(path)
|
| 84 |
+
model.save(os.path.join(path, model_name))
|
| 85 |
+
#else:
|
| 86 |
+
# model = ExponentialSmoothing.load(path)
|
| 87 |
|
| 88 |
|
| 89 |
def load_model(path):
|