Create configuration_speech_encoder.py
Browse files
configuration_speech_encoder.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class SpeechEncoderConfig(PretrainedConfig):
    """Configuration for a GSLM-style speech encoder (HuBERT features + K-Means quantizer).

    Mirrors the interface of the textless-lib ``SpeechEncoder``: dense features are
    extracted from a HuBERT layer, quantized against K-Means centers, optionally
    deduplicated, and optionally wrapped with BOS/EOS ids.

    Args:
        hubert_backend: Backend for dense features: ``"fairseq"`` (a local uploaded
            ``.pt`` checkpoint) or ``"transformers"`` (an HF hub model).
        hubert_ckpt: Used when ``hubert_backend="fairseq"`` — filename of the uploaded
            HuBERT checkpoint (e.g. ``"hubert_base_ls960.pt"``).
        hubert_hf_name: Used when ``hubert_backend="transformers"`` — HF model id
            (e.g. ``"facebook/hubert-base-ls960"``).
        hubert_layer: Which HuBERT layer's features to use.
        expected_sample_rate: Input sample rate in Hz (HuBERT-base default: 16000).
        code_hop_size: Samples per code frame (320 = 20 ms at 16 kHz).
        quantizer_file: Quantizer file holding the K-Means centers.
        quantizer_key: Optional key locating the centers inside a ``.pt``/``.pkl``.
        deduplicate: Collapse consecutive repeated codes (textless-lib behavior).
        add_bos_eos: Prepend/append BOS/EOS ids to the code sequence.
        need_f0: Keep ``False`` — the F0 pipeline is not bundled here.
        bos_id: BOS id; if ``None``, callers place it at ``vocab_size``.
        eos_id: EOS id; if ``None``, callers place it at ``vocab_size + 1``.
        feature_norm: Feature normalization applied before K-Means:
            ``None``, ``"unit"``, or ``"layernorm"``.
        **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.

    Raises:
        ValueError: If ``hubert_backend`` or ``feature_norm`` is not one of the
            supported values.
    """

    model_type = "gslm-speech-encoder"

    def __init__(
        self,
        hubert_backend: str = "fairseq",
        hubert_ckpt: str = "hubert_base_ls960.pt",
        hubert_hf_name: str = "facebook/hubert-base-ls960",
        hubert_layer: int = 9,
        expected_sample_rate: int = 16000,
        code_hop_size: int = 320,
        quantizer_file: str = "kmeans_100.pt",
        quantizer_key: str = "",
        deduplicate: bool = True,
        add_bos_eos: bool = False,
        need_f0: bool = False,
        bos_id: int | None = None,
        eos_id: int | None = None,
        feature_norm: str | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Fail fast on unsupported enum-like values instead of producing a config
        # that only breaks deep inside the feature-extraction pipeline.
        if hubert_backend not in ("fairseq", "transformers"):
            raise ValueError(
                f"hubert_backend must be 'fairseq' or 'transformers', got {hubert_backend!r}"
            )
        if feature_norm not in (None, "unit", "layernorm"):
            raise ValueError(
                f"feature_norm must be None, 'unit' or 'layernorm', got {feature_norm!r}"
            )

        self.hubert_backend = hubert_backend
        self.hubert_ckpt = hubert_ckpt
        self.hubert_hf_name = hubert_hf_name
        self.hubert_layer = int(hubert_layer)
        self.expected_sample_rate = int(expected_sample_rate)
        self.code_hop_size = int(code_hop_size)

        self.quantizer_file = quantizer_file
        self.quantizer_key = quantizer_key

        self.deduplicate = bool(deduplicate)
        self.add_bos_eos = bool(add_bos_eos)
        self.need_f0 = bool(need_f0)
        self.bos_id = bos_id
        self.eos_id = eos_id

        self.feature_norm = feature_norm
|