Spaces:
Runtime error
Runtime error
import logging
from typing import Any, Iterable, List, TypedDict, Union

from .multiple_linear_regressor import MultipleLinearRegressor
from .prophet import ProphetForecaster
from .xgboost import XGBoost
from .xgbregressor import XGBRegressor
class Model(TypedDict):
    """A named, instantiated model entry, as built by ``AllModels.init_models``."""

    # Registry key the model was requested by (e.g. 'xgbreg', 'mlr').
    name: str
    # The instantiated model object; its concrete class depends on the
    # registry entry, so it is typed as Any (the builtin ``any`` is a
    # function, not a type).
    model: Any
class AllModels:
    """Registry of available models, instantiated on demand by name."""

    def __init__(self) -> None:
        # Any available model must register here: registry key -> model class.
        # init_models() instantiates each class with no arguments.
        self.all_models = {
            'xgbreg': XGBRegressor,
            'mlr': MultipleLinearRegressor,
        }
        # Live view over the registered keys.
        self.all_model_names = self.all_models.keys()

    def init_models(
        self,
        models: Union[str, Iterable[str]],
    ) -> List["Model"]:
        """Instantiate the requested models.

        Args:
            models: ``'all'`` to instantiate every registered model, a single
                registry key, or an iterable of registry keys (any iterable,
                including a generator, is accepted).

        Returns:
            A list of ``{'name': key, 'model': instance}`` dicts in the order
            requested. The list is also stored on ``self.models``, and the
            resolved names on ``self.model_names``.

        Raises:
            ValueError: if any requested name is not registered.
        """
        logging.debug('Init models')
        if models == 'all':
            # Copy, so later changes to the registry don't alter this selection.
            model_names = list(self.all_model_names)
        elif isinstance(models, str):
            model_names = [models]
        else:
            # Materialise once: a one-shot iterable (e.g. a generator) would
            # otherwise be exhausted by the validation step below, and no
            # models would be built.
            model_names = list(models)
        self.model_names = model_names

        logging.debug('Check model names')
        unknown_models = set(model_names) - set(self.all_model_names)
        if unknown_models:
            # Sort (by str, to tolerate mixed types) for a deterministic message.
            raise ValueError(
                f'Unknown model : {sorted(unknown_models, key=str)}, '
                f'please use active models: {list(self.all_model_names)}')

        self.models = [
            {'name': name, 'model': self.all_models[name]()}
            for name in model_names
        ]
        return self.models