Dixtral / configuration_dixtral.py
Lakoc's picture
Upload DixtralForConditionalGeneration
cc34926 verified
Raw
History Blame Contribute Delete
2.77 kB
from transformers.models.voxtral.configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig
class DixtralEncoderConfig(VoxtralEncoderConfig):
def __init__(
self,
# DiCoW-specific parameters
use_dicow_encoder: bool = False,
ctc_weight: float = 0.0,
additional_layer: bool = False,
additional_self_attention_layer: bool = False,
pre_ctc_sub_sample: bool = False,
final_dropout: float = 0.0,
use_fddt: bool = False,
apply_fddt_to_n_layers: int = -1,
use_pre_pos_fddt: bool = False,
fddt_init: str = "zeros",
fddt_is_diagonal: bool = False,
fddt_bias_only: bool = False,
fddt_use_silence: bool = True,
fddt_use_target: bool = True,
fddt_use_overlap: bool = True,
fddt_use_non_target: bool = True,
non_target_fddt_value: float = 1.0,
use_enrollments: bool = False,
scb_layers: int = None,
remove_timestamps_from_ctc: bool = False,
ctc_loss_reduction: str = "mean",
**kwargs
):
super().__init__(**kwargs)
self.use_dicow_encoder = use_dicow_encoder
self.ctc_weight = ctc_weight
self.additional_layer = additional_layer
self.additional_self_attention_layer = additional_self_attention_layer
self.pre_ctc_sub_sample = pre_ctc_sub_sample
self.final_dropout = final_dropout
self.use_fddt = use_fddt
self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
self.use_pre_pos_fddt = use_pre_pos_fddt
self.fddt_init = fddt_init
self.fddt_is_diagonal = fddt_is_diagonal
self.fddt_bias_only = fddt_bias_only
self.fddt_use_silence = fddt_use_silence
self.fddt_use_target = fddt_use_target
self.fddt_use_overlap = fddt_use_overlap
self.fddt_use_non_target = fddt_use_non_target
self.non_target_fddt_value = non_target_fddt_value
self.use_enrollments = use_enrollments
self.scb_layers = scb_layers
self.remove_timestamps_from_ctc = remove_timestamps_from_ctc
self.ctc_loss_reduction = ctc_loss_reduction
class DixtralConfig(VoxtralConfig):
def __init__(
self,
audio_config: dict = None,
num_soft_prompts: int = 0,
**kwargs
):
# Convert audio_config to DiCoW version if provided
if audio_config is not None and not isinstance(audio_config, DixtralEncoderConfig):
audio_config = DixtralEncoderConfig(**audio_config)
super().__init__(audio_config=audio_config, **kwargs)
self.num_soft_prompts = num_soft_prompts