import numpy as np  # NOTE(review): not referenced in this module; kept in case importers rely on it


class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    ``d["lr"]`` and ``d.lr`` refer to the same entry because the instance
    ``__dict__`` is aliased to the mapping itself.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Alias attribute storage to the mapping so attribute and item access
        # always stay in sync.
        self.__dict__ = self

    def override(self, attrs):
        """Merge ``attrs`` into this mapping in place and return ``self``.

        Args:
            attrs: A dict, merged shallowly (top-level keys replace existing
                ones; nested mappings are not merged recursively); a
                list/tuple/set of such values, applied in iteration order;
                or ``None``, which is a no-op.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            NotImplementedError: If ``attrs`` (or any nested item) is of any
                other type.
        """
        if isinstance(attrs, dict):
            # Call dict.update via the class: because __dict__ is the mapping,
            # a stored key like "update" would shadow the bound method, and
            # ``update(**attrs)`` would reject non-string keys.
            dict.update(self, attrs)
        elif isinstance(attrs, (list, tuple, set)):
            # Look the method up on the type for the same shadowing reason.
            override = type(self).override
            for attr in attrs:
                override(self, attr)
        elif attrs is not None:
            raise NotImplementedError(
                f"cannot override AttrDict with {type(attrs).__name__!r}"
            )
        return self


all_params = {
    'Plugin_freevc': AttrDict(
        # Diff(usion) params -- presumably the noise schedule and step counts;
        # confirm against the consumer of this config.
        diff=AttrDict(
            num_train_steps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            num_infer_steps=50,
            v_prediction=True,
        ),
        # Text encoder; the value looks like a Hugging Face model id (verify).
        text_encoder=AttrDict(
            model='google/flan-t5-base',
        ),
        # Optimizer settings (Adam-style, judging by the key names).
        opt=AttrDict(
            learning_rate=1e-4,
            beta1=0.9,
            beta2=0.999,
            weight_decay=1e-4,
            adam_epsilon=1e-08,
        ),
    ),
}


def get_params(name):
    """Return the parameter set registered under ``name`` in ``all_params``.

    The shared module-level object is returned (not a copy), so in-place
    changes such as ``get_params(name).override(...)`` are visible to every
    later caller.

    Raises:
        KeyError: If ``name`` is not registered; the message lists the
            available names.
    """
    try:
        return all_params[name]
    except KeyError:
        raise KeyError(
            f"unknown params {name!r}; available: {sorted(all_params)}"
        ) from None