from typing import List, Optional

from transformers import PretrainedConfig
class CompressionConfig(PretrainedConfig):
    """Configuration for the ``compression_head`` model type.

    Stores the hyperparameters below as instance attributes. Every other
    keyword argument is forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        input_size: Size of the input features the head receives
            (presumably the encoder hidden size -- confirm against the model
            code). Defaults to 768.
        compression_sizes: Integer sizes, stored as a new list.
            Defaults to ``[512, 256, 128, 64, 32]``.
        dropout: Dropout value (presumably a probability). Defaults to 0.1.
        loss_k_vals: Integer ``k`` values consumed by the loss. Their exact
            meaning is defined by the model code (TODO confirm). Stored as a
            new list. Defaults to an empty list.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    model_type = "compression_head"

    def __init__(
        self,
        input_size: int = 768,
        compression_sizes: Optional[List[int]] = None,
        dropout: float = 0.1,
        loss_k_vals: Optional[List[int]] = None,
        **kwargs,
    ):
        # Why None sentinels instead of list literals in the signature:
        # a default list is built once and shared by every instance, so
        # mutating one config's list in place would change the defaults of
        # all later configs. The effective defaults are unchanged.
        if compression_sizes is None:
            compression_sizes = [512, 256, 128, 64, 32]
        if loss_k_vals is None:
            loss_k_vals = []

        self.input_size = input_size
        # Copy so the config never aliases a list the caller still holds.
        # This also accepts tuples and other iterables of ints.
        self.compression_sizes = list(compression_sizes)
        self.dropout = dropout
        self.loss_k_vals = list(loss_k_vals)
        super().__init__(**kwargs)