from transformers import PretrainedConfig
class SSLConfig(PretrainedConfig):
    """Configuration for the ``ssl-aasist`` model type.

    Stores four architecture hyper-parameter lists on the instance and
    forwards every other keyword argument to ``PretrainedConfig.__init__``.

    Args:
        filts: Defaults to ``[128, [1, 32], [32, 32], [32, 64], [64, 64]]``.
            NOTE(review): presumably per-stage filter/channel sizes of the
            encoder — confirm against the model implementation.
        gat_dims: Defaults to ``[64, 32]``.
        pool_ratios: Defaults to ``[0.5, 0.5, 0.5, 0.5]``.
        temperatures: Defaults to ``[2.0, 2.0, 100.0, 100.0]``.
        **kwargs: Passed through unchanged to ``PretrainedConfig.__init__``.

    Passing ``None`` (or omitting an argument) yields a fresh copy of the
    default list, so instances never share mutable state.
    """

    model_type = "ssl-aasist"

    def __init__(
        self,
        filts=None,
        gat_dims=None,
        pool_ratios=None,
        temperatures=None,
        **kwargs,
    ):
        # Defaults are built inside the call rather than in the signature:
        # a list default is evaluated once at definition time and would be
        # shared (nested sub-lists included) by every config created without
        # that argument, so an in-place edit on one config would leak into all
        # later ones.
        self.filts = (
            filts
            if filts is not None
            else [128, [1, 32], [32, 32], [32, 64], [64, 64]]
        )
        self.gat_dims = gat_dims if gat_dims is not None else [64, 32]
        self.pool_ratios = (
            pool_ratios if pool_ratios is not None else [0.5, 0.5, 0.5, 0.5]
        )
        self.temperatures = (
            temperatures
            if temperatures is not None
            else [2.0, 2.0, 100.0, 100.0]
        )
        super().__init__(**kwargs)