Spaces:
Sleeping
Sleeping
| from ..custom_types import * | |
| from .. import constants | |
| from tqdm import tqdm | |
| from . import files_utils | |
| import os | |
| from .. import options | |
| from ..models import models_utils, occ_gmm | |
# Scalar accepted by Logger.stash: a T instance (converted with .item(); presumably a tensor
# alias from custom_types), a float or an int.
LI = Union[T, float, int]
# Registry used by model_factory: model name (opt.model_name / override_model) -> model class.
Models = dict(spaghetti=occ_gmm.Spaghetti)
def is_model_clean(model: nn.Module) -> bool:
    """Report whether ``model`` is free of NaN weights.

    Returns False as soon as any parameter tensor holds at least one NaN, True otherwise
    (including for a model without parameters). Infinite values are not checked.
    """
    return not any(torch.isnan(weight).any() for weight in model.parameters())
def model_factory(opt: options.Options, override_model: Optional[str], device: D) -> models_utils.Model:
    """Instantiate a registered model class and move it to ``device``.

    The class is looked up in ``Models`` under ``override_model`` when one is given, otherwise
    under ``opt.model_name``; it is constructed with ``opt`` as its only argument.
    """
    key = opt.model_name if override_model is None else override_model
    model_cls = Models[key]
    return model_cls(opt).to(device)
def load_model(opt, device, suffix: str = '', override_model: Optional[str] = None) -> models_utils.Model:
    """Build a model via ``model_factory`` and restore its weights if a checkpoint exists.

    The checkpoint path is ``{opt.cp_folder}/model`` or ``{opt.cp_folder}/model_{suffix}`` when
    ``suffix`` is non-empty. Without a checkpoint file the freshly constructed model is returned.
    """
    model_path = f'{opt.cp_folder}/' + 'model' + ('_' + suffix if suffix else '')
    model = model_factory(opt, override_model, device)
    label = override_model if override_model is not None else opt.model_name
    if not os.path.isfile(model_path):
        print(f'init {label} model')
        return model
    print(f'loading {label} model from {model_path}')
    # NOTE(review): torch.load unpickles the file; only point cp_folder at trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model
def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save``.

    Returns True after saving, or False without writing anything when ``constants.DEBUG`` is set.
    """
    if not constants.DEBUG:
        print(f'saving model in {path}')
        torch.save(model.state_dict(), path)
        return True
    return False
def model_lc(opt: options.Options, override_model: Optional[str] = None) -> Tuple[occ_gmm.Spaghetti, options.Options]:
    """Create or restore a model together with the options stored next to its checkpoints.

    If ``files_utils.load_pickle('{opt.cp_folder}/options.pkl')`` returns something other than
    None, those options replace ``opt`` (keeping the caller's ``opt.device``). The model is then
    built with ``load_model`` and a ``save_model(model_, suffix='')`` closure is attached to it as
    ``model.save_model``.

    Args:
        opt: options to use when no pickled options are found.
        override_model: optional key into ``Models`` used instead of ``opt.model_name``; it is
            also the default checkpoint suffix used by ``save_model``.

    Returns:
        ``(model, opt)``, where ``opt`` is the options object actually in effect.
    """

    def save_model(model_: models_utils.Model, suffix: str = ''):
        """Checkpoint ``model_`` to ``{opt.cp_folder}/model`` (or ``model_{suffix}``).

        Returns False without touching the disk when ``constants.DEBUG`` is set or
        ``'debug' in opt.tag``; returns True otherwise. The first real call runs
        ``files_utils.init_folders`` (presumably creating the folder) and pickles ``opt``.
        A model containing NaN parameters is never written: if a checkpoint already exists it
        is reloaded into ``model_``; if none exists nothing happens (still returns True).
        """
        nonlocal already_init
        if override_model is not None and suffix == '':
            suffix = override_model
        model_path = f'{opt.cp_folder}/model{"_" + suffix if suffix else ""}'
        if constants.DEBUG or 'debug' in opt.tag:
            return False
        if not already_init:
            files_utils.init_folders(model_path)
            files_utils.save_pickle(opt, params_path)
            already_init = True
        if is_model_clean(model_):
            print(f'saving {opt.model_name} model at {model_path}')
            torch.save(model_.state_dict(), model_path)
        elif os.path.isfile(model_path):
            # Roll back the model that was passed in (not the closure's `model`) to the last good
            # checkpoint instead of overwriting that checkpoint with NaN weights.
            print('model is corrupted')
            print(f'loading {opt.model_name} model from {model_path}')
            model_.load_state_dict(torch.load(model_path, map_location=opt.device))
        return True

    # True once options have been loaded from, or written to, options.pkl.
    already_init = False
    params_path = f'{opt.cp_folder}/options.pkl'
    opt_ = files_utils.load_pickle(params_path)
    if opt_ is not None:
        # stored options win, but keep running on the device the caller asked for
        opt_.device = opt.device
        opt = opt_
        already_init = True
    model = load_model(opt, opt.device, override_model=override_model)
    model.save_model = save_model
    return model, opt
class Logger:
    """Progress logger that averages scalar metrics per iteration and per level on a tqdm bar.

    Metrics are kept as running sums with a companion ``<key>_counter`` entry. ``reset_iter``
    closes an iteration: it averages the iteration metrics, folds those averages into the level
    statistics and advances the bar. ``stop`` closes a level: it averages the level statistics,
    closes the bar and increments ``level``.
    """

    def __init__(self, level: int = 0):
        self.level_dictionary = dict()  # running sums / counters for the current level
        self.iter_dictionary = dict()  # running sums / counters for the current iteration
        self.level = level
        self.progress: Union[N, tqdm] = None
        self.iters = 0
        self.tag = ''

    @staticmethod
    def aggregate(dictionary: dict, parent_dictionary: Union[dict, N] = None) -> dict:
        """Return ``{key: sum / counter}`` for every metric in ``dictionary``.

        When ``parent_dictionary`` is given, each mean is also stashed into it as one sample.
        """
        aggregate_dictionary = dict()
        for key in dictionary:
            # Only the "<key>_counter" bookkeeping entries are skipped; a metric whose name merely
            # contains "counter" (e.g. "encounters") is still reported.
            if key.endswith('_counter'):
                continue
            aggregate_dictionary[key] = dictionary[key] / float(dictionary[f"{key}_counter"])
            if parent_dictionary is not None:
                Logger.stash(parent_dictionary, (key, aggregate_dictionary[key]))
        return aggregate_dictionary

    @staticmethod
    def flatten(items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> List[Union[str, LI]]:
        """Expand dicts in ``items`` into alternating key, value entries; keep other items as-is."""
        flat_items = []
        for item in items:
            if isinstance(item, dict):
                for key, value in item.items():
                    flat_items.append(key)
                    flat_items.append(value)
            else:
                flat_items.append(item)
        return flat_items

    @staticmethod
    def stash(dictionary: Dict[str, LI], items: Tuple[Union[Dict[str, LI], str, LI], ...]) -> Dict[str, LI]:
        """Add the (key, value) pairs in ``items`` to the running sums in ``dictionary``.

        ``items`` is flattened with ``flatten`` and read as alternating key, value entries.
        ``T`` values are converted with ``.item()``. ``dictionary`` is updated in place and returned.
        """
        flat_items = Logger.flatten(items)
        for i in range(0, len(flat_items), 2):
            key, item = flat_items[i], flat_items[i + 1]
            if isinstance(item, T):
                item = item.item()
            if key not in dictionary:
                dictionary[key] = 0
                dictionary[f"{key}_counter"] = 0
            dictionary[key] += item
            dictionary[f"{key}_counter"] += 1
        return dictionary

    def stash_iter(self, *items: Union[Dict[str, LI], str, LI]):
        """Accumulate metrics for the current iteration; returns ``self`` for chaining."""
        self.iter_dictionary = self.stash(self.iter_dictionary, items)
        return self

    def stash_level(self, *items: Union[Dict[str, LI], str, LI]):
        """Accumulate metrics directly into the current level; returns ``self`` for chaining."""
        self.level_dictionary = self.stash(self.level_dictionary, items)
        return self

    def reset_iter(self, *items: Union[Dict[str, LI], str, LI]):
        """Close the current iteration, optionally stashing ``items`` first.

        Iteration means are folded into the level statistics, shown on the bar (if one is open)
        and the bar is advanced by one step. Returns ``self``.
        """
        if len(items) > 0:
            self.stash_iter(*items)
        aggregate_dictionary = self.aggregate(self.iter_dictionary, self.level_dictionary)
        if self.progress is not None:
            self.progress.set_postfix(aggregate_dictionary)
            self.progress.update()
        self.iter_dictionary = dict()
        return self

    def start(self, iters: int = -1, tag: str = ''):
        """Open a new progress bar of ``iters`` steps, closing any bar that is still open.

        A negative ``iters`` (the default) or an empty ``tag`` reuses the previous value, which
        lets ``reset_level`` restart the bar without arguments. Returns ``self``.
        """
        if self.progress is not None:
            self.stop()
        if iters < 0:
            iters = self.iters
        if tag == '':
            tag = self.tag
        self.iters, self.tag = iters, tag
        self.progress = tqdm(total=self.iters, desc=f'{self.tag} {self.level}')
        return self

    def stop(self, aggregate: bool = True):
        """Close the current level and increment ``level``.

        Returns the per-metric level means when ``aggregate`` is True (also shown on the bar),
        otherwise an empty dict. The level statistics are cleared in both cases.
        """
        aggregate_dictionary = dict()
        if aggregate:
            aggregate_dictionary = self.aggregate(self.level_dictionary)
            if self.progress is not None:
                self.progress.set_postfix(aggregate_dictionary)
        self.level_dictionary = dict()
        if self.progress is not None:
            self.progress.close()
            self.progress = None
        self.level += 1
        return aggregate_dictionary

    def reset_level(self, aggregate: bool = True):
        """Close the current level and immediately open a new bar with the previous iters/tag."""
        self.stop(aggregate)
        self.start()
class LinearWarmupScheduler:
    """Linear learning-rate warm-up from each param group's initial lr to ``target_lr``.

    Step k (starting at 0) writes ``base + (target - base) * k / num_iters`` into every param
    group of ``optimizer``. The step taken with ``cur_iter == num_iters`` writes ``target_lr``
    and marks the scheduler finished; later ``step`` calls leave the optimizer untouched.
    """

    def __init__(self, optimizer, target_lr, num_iters):
        self.optimizer = optimizer
        self.target_lr = target_lr
        self.num_iters = num_iters
        self.cur_iter = 0.
        self.finished = False
        # initial rate of each param group and the distance it has to travel to target_lr
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self.delta_lrs = [target_lr - start for start in self.base_lrs]

    def get_lr(self):
        """Return the learning rate of every param group for the current iteration."""
        if self.cur_iter < self.num_iters:
            progress = self.cur_iter / self.num_iters
            return [start + gap * progress for start, gap in zip(self.base_lrs, self.delta_lrs)]
        return [self.target_lr for _ in self.base_lrs]

    def step(self):
        """Apply the current rates to the optimizer and advance one iteration (no-op once finished)."""
        if self.finished:
            return
        new_lrs = self.get_lr()
        for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs):
            param_group['lr'] = new_lr
        self.cur_iter += 1.
        self.finished = self.cur_iter > self.num_iters