Training in progress - step 1000
- asr_config.py +65 -16
- asr_modeling.py +379 -532
- asr_pipeline.py +29 -256
- asr_processing.py +68 -54
- chat_template.jinja +94 -6
- mlp_projector.py +42 -0
- moe_projector.py +162 -0
- preprocessor_config.json +0 -3
- residual_projector.py +153 -0
- shared_moe_projector.py +182 -0
- special_tokens_map.json +8 -23
- swiglu_projector.py +68 -0
- tokenizer.json +2 -2
- tokenizer_config.json +0 -0
asr_config.py
CHANGED

@@ -11,9 +11,8 @@ class ASRConfig(transformers.PretrainedConfig):
         self,
         audio_model_id: str = "openai/whisper-large-v3-turbo",
         text_model_id: str = "HuggingFaceTB/SmolLM3-3B",
-        attn_implementation: str = "…
+        attn_implementation: str = "flash_attention_2",
         model_dtype: str = "bfloat16",
-        audio_downsample_rate: int = 5,  # Deprecated: use projector_pool_stride instead
         num_beams: Optional[int] = None,
         system_prompt: str = "/no_think /system_override",
         user_prompt: str = "Transcribe: <audio>",
@@ -22,8 +21,18 @@ class ASRConfig(transformers.PretrainedConfig):
         audio_sample_rate: int = 16000,
         projector_init_std: float = 0.02,
         projector_pool_stride: int = 2,
+        downsample_rate: int = 16,
         projector_hidden_dim: Optional[int] = None,
-        …
+        projector_type: str = "moe",  # "moe", "swiglu", "residual", "shared_moe", "mlp"
+        projector_num_layers: int = 2,  # Number of layers (for residual projector)
+        projector_dropout: float = 0.05,  # Dropout rate for projector layers
+        projector_input_noise: float = 0.02,  # Input noise for projector
+        # MoE-specific configuration
+        num_experts: int = 4,  # Number of experts in MoE projectors
+        num_experts_per_tok: int = 2,  # Top-k experts per token
+        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
+        use_specaugment: bool = True,  # Apply SpecAugment during training
+        label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
         inference_diversity_penalty: float = 0.0,
         inference_warmup_tokens: int = 10,
         max_new_tokens: Optional[int] = None,
@@ -42,10 +51,12 @@ class ASRConfig(transformers.PretrainedConfig):
         # Set default generation parameters
         generation_defaults = {
             "num_beams": 1,
-            "max_new_tokens": …
-            "min_new_tokens": …
+            "max_new_tokens": 96,
+            "min_new_tokens": 0,
             "do_sample": False,
-            "…
+            "temperature": 0.1,
+            "repetition_penalty": 1.0,
+            "length_penalty": 1.0,
             "no_repeat_ngram_size": 0,
             "use_cache": True,
         }
@@ -57,7 +68,6 @@ class ASRConfig(transformers.PretrainedConfig):
         self.text_model_id = text_model_id
         self.attn_implementation = attn_implementation
         self.model_dtype = model_dtype
-        self.audio_downsample_rate = audio_downsample_rate
         self.system_prompt = system_prompt
         self.user_prompt = user_prompt
         self.encoder_dim = encoder_dim
@@ -65,12 +75,55 @@ class ASRConfig(transformers.PretrainedConfig):
         self.audio_sample_rate = audio_sample_rate
         self.projector_init_std = projector_init_std
         self.projector_pool_stride = projector_pool_stride
+        self.downsample_rate = downsample_rate
         self.projector_hidden_dim = projector_hidden_dim
+        self.projector_type = projector_type
+        self.projector_num_layers = projector_num_layers
         self.projector_dropout = projector_dropout
+        self.projector_input_noise = projector_input_noise
+        # MoE-specific configuration
+        self.num_experts = num_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.use_specaugment = use_specaugment
+        self.label_smoothing = label_smoothing
         self.inference_diversity_penalty = inference_diversity_penalty
         self.inference_warmup_tokens = inference_warmup_tokens
+
+        # Generation parameters (use explicit value if provided, else use default)
+        self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
+        self.max_new_tokens = (
+            max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
+        )
+        self.min_new_tokens = (
+            min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
+        )
+        self.do_sample = do_sample if do_sample is not None else generation_defaults["do_sample"]
+        self.repetition_penalty = (
+            repetition_penalty
+            if repetition_penalty is not None
+            else generation_defaults["repetition_penalty"]
+        )
+        self.length_penalty = (
+            length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
+        )
+        self.no_repeat_ngram_size = (
+            no_repeat_ngram_size
+            if no_repeat_ngram_size is not None
+            else generation_defaults["no_repeat_ngram_size"]
+        )
+        self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
+        self.temperature = (
+            temperature if temperature is not None else generation_defaults["temperature"]
+        )
+        self.top_k = top_k
+        self.top_p = top_p
+        self.early_stopping = early_stopping
+
         if "audio_config" not in kwargs:
             self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
+            # Override dtype to match model_dtype
+            self.audio_config.dtype = model_dtype
         else:
             self.audio_config = kwargs.pop("audio_config")

@@ -78,20 +131,16 @@ class ASRConfig(transformers.PretrainedConfig):
             self.text_config = transformers.AutoConfig.from_pretrained(
                 text_model_id, trust_remote_code=True
             )
+            # Override dtype to match model_dtype
+            self.text_config.dtype = model_dtype
         else:
             self.text_config = kwargs.pop("text_config")

         if isinstance(self.text_config, dict):
             # Reconstruct config from dict using the model_type stored in the dict
-            model_type = self.text_config…
-            …
-            self.text_config = config_class(**self.text_config)
-        else:
-            # Fallback: try to load from model_id
-            self.text_config = transformers.AutoConfig.from_pretrained(
-                text_model_id, trust_remote_code=True
-            )
+            model_type = self.text_config["model_type"]
+            config_class = transformers.AutoConfig.for_model(model_type).__class__
+            self.text_config = config_class(**self.text_config)

         if isinstance(self.audio_config, dict):
             model_type = self.audio_config.get("model_type")
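The config change above replaces the single deprecated audio_downsample_rate knob with explicit projector, MoE, and generation settings. For orientation only (a sketch, not part of this commit; it assumes both referenced checkpoints are resolvable from the Hub, since __init__ loads their configs):

import torch  # noqa: F401  (torch dtype names referenced via config.model_dtype)
from asr_config import ASRConfig

# Pick the MoE projector and its routing hyperparameters; anything not
# passed falls back to the defaults visible in the diff above.
config = ASRConfig(
    projector_type="moe",
    num_experts=4,
    num_experts_per_tok=2,
    router_aux_loss_coef=0.01,
    use_specaugment=True,
)
print(config.max_new_tokens)  # 96: max_new_tokens is None, so generation_defaults applies

Because max_new_tokens defaults to None, the printed value comes from the generation_defaults table rather than an explicit argument.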
asr_modeling.py
CHANGED

@@ -1,148 +1,78 @@
+import json
 from pathlib import Path
 from typing import Optional, Union

 import torch
 import torch.nn as nn
-import torch.nn.functional as F  # noqa: N812
 from transformers import (
     AutoConfig,
     AutoModel,
     AutoModelForCausalLM,
     AutoTokenizer,
     PreTrainedModel,
-    Wav2Vec2FeatureExtractor,
 )
-from transformers.generation …
-    …
-    GenerateEncoderDecoderOutput,
-)
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.whisper.modeling_whisper import (
+    _compute_mask_indices,
+)

 try:
     from .asr_config import ASRConfig
+    from .mlp_projector import MLPAudioProjector
+    from .moe_projector import MoEAudioProjector
+    from .residual_projector import ResidualAudioProjector
+    from .shared_moe_projector import SharedMoEAudioProjector
+    from .swiglu_projector import AudioProjector
 except ImportError:
     from asr_config import ASRConfig  # type: ignore[no-redef]
+    from mlp_projector import MLPAudioProjector  # type: ignore[no-redef]
+    from moe_projector import MoEAudioProjector  # type: ignore[no-redef]
+    from residual_projector import ResidualAudioProjector  # type: ignore[no-redef]
+    from shared_moe_projector import SharedMoEAudioProjector  # type: ignore[no-redef]
+    from swiglu_projector import AudioProjector  # type: ignore[no-redef]

+# Map projector type names to classes
+PROJECTOR_CLASSES = {
+    "swiglu": AudioProjector,
+    "residual": ResidualAudioProjector,
+    "moe": MoEAudioProjector,
+    "shared_moe": SharedMoEAudioProjector,
+    "mlp": MLPAudioProjector,
+}

-class SwiGLU(nn.Module):
-    def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
-        super().__init__()
-        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
-        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
-        self.act = nn.SiLU()
-        self.dropout = nn.Dropout(dropout)
-
-    …
-        x_val = self.w2(x)
-        x = x_gate * x_val
-        x = self.dropout(x)
-        return self.w3(x)
-
-
-class AudioProjector(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.k = getattr(config, "projector_pool_stride", 2)  # Downsampling rate
-        in_dim = config.encoder_dim * self.k
-        out_dim = config.llm_dim
-        hidden_dim = config.projector_hidden_dim
-        if hidden_dim is None:
-            hidden_dim = config.encoder_dim * 4
-
-        dropout_rate = getattr(config, "projector_dropout", 0.0)
-
-        from transformers.models.llama.modeling_llama import LlamaRMSNorm
-
-        self.ln_pre = LlamaRMSNorm(in_dim, eps=1e-6)
-        self.proj = SwiGLU(in_dim, hidden_dim, out_dim, dropout=dropout_rate)
-        self.ln_post = LlamaRMSNorm(out_dim, eps=1e-6)
-        self.output_dropout = nn.Dropout(dropout_rate)
-
-        with torch.no_grad():
-            std = getattr(config, "projector_init_std", 0.02)
-            self.ln_pre.weight.data.fill_(1.0)
-            self.ln_post.weight.data.fill_(1.0)
-            nn.init.normal_(self.proj.w1.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj.w2.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj.w3.weight, mean=0.0, std=std)
-
-    def forward(self, x):
-        batch_size, seq_len, dim = x.size()
-
-        target_dtype = self.proj.w1.weight.dtype
-        if x.dtype != target_dtype:
-            x = x.to(target_dtype)
-
-        remainder = seq_len % self.k
-        if remainder:
-            pad_len = self.k - remainder
-            x = F.pad(x, (0, 0, 0, pad_len))
-
-        x = x.contiguous().view(batch_size, -1, dim * self.k)
-        x = self.ln_pre(x)
-        x = self.proj(x)
-        x = self.ln_post(x)
-
-        return self.output_dropout(x)
-
-
-class ASRModel(PreTrainedModel):
+
+class ASRModel(PreTrainedModel, GenerationMixin):
+    """Audio-to-text model combining an audio encoder, projector, and language model."""
+
     config_class = ASRConfig
     base_model_prefix = "model"
-    main_input_name = "…
+    main_input_name = "input_features"
     _supports_flash_attn_2 = True
     supports_gradient_checkpointing = True
     _is_loading_from_pretrained: bool = False
     _pretrained_model_path: Optional[str] = None

-    TASK_PROMPTS = {
-        "transcribe": "Transcribe: <audio>",
-        "continue": "Continue: <audio>",
-        "describe": "Describe: <audio>",
-        "emotion": "Emotion: <audio>",
-    }
-
-    @staticmethod
-    def _create_feature_extractor(audio_model_id: str):
-        """Factory method to create the appropriate feature extractor."""
-        is_whisper = "whisper" in audio_model_id.lower()
-        if is_whisper:
-            from transformers import WhisperConfig, WhisperFeatureExtractor
-
-            encoder_config = WhisperConfig.from_pretrained(audio_model_id)
-            num_mel_bins = encoder_config.num_mel_bins
-            return WhisperFeatureExtractor.from_pretrained(
-                audio_model_id,
-                feature_size=num_mel_bins,
-            )
-        return Wav2Vec2FeatureExtractor.from_pretrained(audio_model_id)
+    TRANSCRIBE_PROMPT = "Transcribe: "

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
-        from …
+        """Load model from pretrained, handling device placement correctly."""
+        from safetensors.torch import load_file
+        from transformers.utils.hub import cached_file

         config = kwargs.pop("config", None)
         if config is None:
             config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

-        # …
-        kwargs["feature_extractor"] = AutoFeatureExtractor.from_pretrained(
-            pretrained_model_name_or_path, **kwargs
-        )
-
+        # Set flag to avoid device_map="auto" in sub-model loaders
         cls._is_loading_from_pretrained = True
         cls._pretrained_model_path = pretrained_model_name_or_path

         try:
-            from safetensors.torch import load_file
-            from transformers.utils.hub import cached_file
-
             model = cls(config, **kwargs)

+            # Load projector weights from safetensors
             subfolder = kwargs.get("subfolder")
             revision = kwargs.get("revision")
             cache_kwargs = {}
@@ -158,102 +88,76 @@ class ASRModel(PreTrainedModel):
                 **cache_kwargs,
             )

-            if not …
-                …
-                "The repository may not have been trained yet."
-            )
-
-            state_dict = load_file(model_file)
-            model.load_state_dict(state_dict, strict=False, assign=True)
-
-            target_dtype = getattr(torch, config.model_dtype)
-            model.projector = model.projector.to(dtype=target_dtype)
-
-            device = kwargs.get("device")
-            if device is not None:
-                model = model.to(device)
+            if model_file is not None:
+                state_dict = load_file(model_file)
+                model.load_state_dict(state_dict, strict=False)

             return model
         finally:
             cls._is_loading_from_pretrained = False
-
+            cls._pretrained_model_path = None

     def __init__(self, config: ASRConfig, **kwargs):
         super().__init__(config)

-        feature_extractor = kwargs.pop("feature_extractor", None)
-
         self.system_prompt = config.system_prompt
+        target_dtype = getattr(torch, config.model_dtype)

-        …
-        else:
-            self.…
-            self.generation_config = self.…
-            self._init_tokenizer()
-            from types import SimpleNamespace
-
-        if hasattr(self.encoder.config, "hidden_size"):
-            encoder_dim = self.encoder.config.hidden_size
-        elif hasattr(self.encoder.config, "d_model"):
-            encoder_dim = self.encoder.config.d_model
-        else:
-            raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
-
-        if hasattr(self.decoder.config, "hidden_size"):
-            llm_dim = self.decoder.config.hidden_size
-        elif hasattr(self.decoder.config, "d_model"):
-            llm_dim = self.decoder.config.d_model
-        else:
-            raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
-
-        …
-            llm_dim=llm_dim,
-            projector_pool_stride=getattr(config, "projector_pool_stride", 2),
-            projector_hidden_dim=getattr(config, "projector_hidden_dim", None),
-            projector_init_std=getattr(config, "projector_init_std", 0.02),
-            projector_dropout=getattr(config, "projector_dropout", 0.0),
-        )
-        self.projector = AudioProjector(projector_config)
+        # Audio encoder (frozen)
+        self.audio_tower = self._load_audio_encoder(config, target_dtype)
+
+        # Language model (frozen)
+        self.language_model = self._load_language_model(config, target_dtype)
+
+        # Initialize tokenizer and special tokens
+        self._init_tokenizer(config)
+
+        # Set up generation config with our defaults
+        self.generation_config = self.language_model.generation_config
+        self.generation_config.max_new_tokens = config.max_new_tokens
+        self.generation_config.num_beams = config.num_beams
+        self.generation_config.do_sample = config.do_sample
+        self.generation_config.use_cache = config.use_cache
+        self.generation_config.length_penalty = config.length_penalty
+        self.generation_config.repetition_penalty = config.repetition_penalty
+        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
+        # Only set sampling params when do_sample=True, otherwise clear them
+        if config.do_sample:
+            self.generation_config.temperature = config.temperature
+            if config.top_k is not None:
+                self.generation_config.top_k = config.top_k
+            if config.top_p is not None:
+                self.generation_config.top_p = config.top_p
+        else:
+            self.generation_config.temperature = None
+            self.generation_config.top_k = None
+            self.generation_config.top_p = None
+        self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
+        self.generation_config.pad_token_id = self.tokenizer.pad_token_id
+
+        # Feature extractor for audio preprocessing
+        self.feature_extractor = self._create_feature_extractor(config)
+
+        # Audio projector (trainable)
+        self.projector = self._create_projector(config, target_dtype)
+
+        # For model parallelism
+        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
+
+    def _create_feature_extractor(self, config: ASRConfig):
+        """Create the appropriate feature extractor for the audio encoder."""
+        from transformers import AutoFeatureExtractor
+
+        return AutoFeatureExtractor.from_pretrained(config.audio_model_id)

     @classmethod
-    def …
+    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
+        """Load and freeze the audio encoder."""
         encoder_kwargs = {
             "attn_implementation": config.attn_implementation,
-            "dtype": target_dtype,
             "low_cpu_mem_usage": True,
+            "dtype": dtype,
         }
-        if not cls._is_loading_from_pretrained:
-            encoder_kwargs["device_map"] = "auto"

         if "whisper" in config.audio_model_id.lower():
             from transformers import WhisperModel
@@ -264,471 +168,414 @@ class ASRModel(PreTrainedModel):
         else:
             encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

-        is_whisper = "whisper" in config.audio_model_id.lower() or (
-            hasattr(encoder.config, "model_type") and "whisper" in encoder.config.model_type.lower()
-        )
-
-        original_forward = encoder.forward
-        input_key = "input_features" if is_whisper else "input_values"
-
-        def safe_encoder_forward(self_encoder, input_values=None, **kwargs):
-            kwargs.pop("input_ids", None)
-            return original_forward(**{input_key: input_values}, **kwargs)
-
-        import types
-
-        encoder.forward = types.MethodType(safe_encoder_forward, encoder)
         encoder.requires_grad_(False)
-
+        encoder.eval()
         return encoder

     @classmethod
-    def …
+    def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
+        """Load and freeze the language model."""
         decoder_kwargs = {
             "attn_implementation": config.attn_implementation,
-            "dtype": target_dtype,
             "trust_remote_code": True,
+            "tie_word_embeddings": True,
+            "low_cpu_mem_usage": True,
+            "dtype": dtype,
         }

         decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
-        decoder.config.use_cache = config…
+        decoder.config.use_cache = getattr(config, "use_cache", True)
         decoder.requires_grad_(False)
-
+        decoder.eval()
         return decoder

-    def …
-        …
-            if self._is_loading_from_pretrained
-            else self.config.text_model_id
-        )
-
-        …
+    def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
+        """Create the trainable audio projector."""
+        # Auto-detect dimensions if not specified
+        if config.encoder_dim is None:
+            enc_cfg = self.audio_tower.config
+            config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
+                enc_cfg, "d_model", None
+            )
+            if config.encoder_dim is None:
+                raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
+
+        if config.llm_dim is None:
+            dec_cfg = self.language_model.config
+            config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
+                dec_cfg, "d_model", None
+            )
+            if config.llm_dim is None:
+                raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
+
+        # Select projector type based on config
+        projector_type = getattr(config, "projector_type", "moe")
+        projector_class = PROJECTOR_CLASSES.get(projector_type)
+        if projector_class is None:
+            raise ValueError(
+                f"Unknown projector_type: {projector_type}. "
+                f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
+            )
+        projector = projector_class(config)
+
+        # Move projector to same device as language model (important when using quantization)
+        device = next(self.language_model.parameters()).device
+        return projector.to(device=device, dtype=dtype)
+
+    def _init_tokenizer(self, config: ASRConfig):
+        """Initialize tokenizer with audio token."""
+        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)

+        # Set pad token
         if (
             self.tokenizer.pad_token is None
             or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
         ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
             self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

+        # Add audio token
         existing_special = self.tokenizer.additional_special_tokens or []
-
         if "<audio>" not in existing_special:
-            …
-            current_embed_size = self.decoder.get_input_embeddings().weight.shape[0]
-            expected_size = len(self.tokenizer)
-            if current_embed_size != expected_size:
-                self.decoder.resize_token_embeddings(expected_size, mean_resizing=False)
+            self.tokenizer.add_special_tokens(
+                {"additional_special_tokens": existing_special + ["<audio>"]}
+            )
+            self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)

         self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
-
         self.tokenizer.padding_side = "right"

-        …
-            cfg["eos_token_id"] = self.tokenizer.eos_token_id
-            cfg["bos_token_id"] = self.tokenizer.bos_token_id
-        else:
+        # Sync token IDs to configs
+        for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
+            if cfg is not None:
                 cfg.pad_token_id = self.tokenizer.pad_token_id
                 cfg.eos_token_id = self.tokenizer.eos_token_id
                 cfg.bos_token_id = self.tokenizer.bos_token_id

-    def …
-        …
-        except ImportError:
-            from asr_processing import ASRProcessor  # type: ignore[no-redef]
-
-        return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
-
-    def state_dict(self, *args, **kwargs):
-        return self._get_trainable_state_dict()
-
-    def _get_trainable_state_dict(self):
-        state = {}
-
-        projector_state = self.projector.state_dict()
-        for name, tensor in projector_state.items():
-            state[f"projector.{name}"] = tensor
-
-        …
+    def _init_weights(self, module):
+        """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
+        pass
+
+    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
+        """Enable/disable gradient checkpointing for the language model."""
+        # The LLM still stores activations during forward for backprop to projector
+        # Gradient checkpointing trades compute for memory by recomputing activations
+        if hasattr(self.language_model, "_set_gradient_checkpointing"):
+            self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
+        elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
+            self.language_model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs={"use_reentrant": False}
+            )
+        elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
+            self.language_model.gradient_checkpointing_disable()

     def get_input_embeddings(self):
-        return self.…
+        return self.language_model.get_input_embeddings()

     def set_input_embeddings(self, value):
-        self.…
+        self.language_model.set_input_embeddings(value)

     def get_output_embeddings(self):
-        return self.…
+        return self.language_model.get_output_embeddings()

     def set_output_embeddings(self, value):
-        self.…
+        self.language_model.set_output_embeddings(value)

-    def …
-        …
-        encoder_dtype = next(self.encoder.parameters()).dtype
-        input_values = input_values.clone().to(device=encoder_device, dtype=encoder_dtype)
-
-        with torch.no_grad():
-            audio_features = self.encoder(
-                input_values=input_values,
-                attention_mask=audio_attention_mask,
-            ).last_hidden_state
-
-        audio_embeds = self.projector(audio_features)
-
-        decoder_dtype = next(self.decoder.parameters()).dtype
-        if audio_embeds.dtype != decoder_dtype:
-            audio_embeds = audio_embeds.to(dtype=decoder_dtype)
-
-        return audio_embeds
-
-    def _get_audio_expansion_details(self, input_ids: torch.Tensor, num_audio_tokens: int) -> dict:
-        batch_size, seq_len = input_ids.shape
-        device = input_ids.device
-        audio_mask = input_ids == self.audio_token_id
-
-        audio_counts = audio_mask.sum(dim=1)
-        if not (audio_counts == 1).all():
-            missing = (audio_counts == 0).any()
-            multiple = (audio_counts > 1).any()
-            if missing:
-                raise ValueError("Some samples are missing audio token")
-            if multiple:
-                raise ValueError("Some samples have multiple audio tokens")
-
-        token_counts = torch.where(audio_mask, num_audio_tokens, 1)
-        cumsum_counts = torch.cumsum(token_counts, dim=1)
-        new_start_positions = torch.cat(
-            [
-                torch.zeros(batch_size, 1, dtype=torch.long, device=device),
-                cumsum_counts[:, :-1],
-            ],
-            dim=1,
-        )
-
-        …
-            "audio_mask": audio_mask,
-        }
+    def get_processor(self):
+        """Get the processor for this model."""
+        try:
+            from .asr_processing import ASRProcessor
+        except ImportError:
+            from asr_processing import ASRProcessor  # type: ignore[no-redef]
+
+        return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
+
+    def state_dict(self, *args, **kwargs):
+        """Only save trainable projector weights."""
+        return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}

-    def …
+    def _apply_specaugment(
         self,
-        …
-        num_audio_tokens: int,
-        fill_value: Optional[Union[int, float]] = None,
-        audio_fill_value: Optional[Union[int, float]] = None,
+        input_features: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        …
-            fill_value,
-            dtype=tensor_to_expand.dtype,
-            device=device,
-        )
-
-        batch_indices = torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, seq_len)
-        non_audio_mask = ~audio_mask
-        expanded[batch_indices[non_audio_mask], new_start_positions[non_audio_mask]] = (
-            tensor_to_expand[non_audio_mask]
-        )
-
-        if audio_fill_value != fill_value:
-            audio_positions = audio_mask.int().argmax(dim=1)
-            audio_new_start = new_start_positions[
-                torch.arange(batch_size, device=device), audio_positions
-            ]
-            audio_token_indices = torch.arange(num_audio_tokens, device=device).unsqueeze(0)
-            audio_positions_expanded = audio_new_start.unsqueeze(1) + audio_token_indices
-            batch_idx_expanded = (
-                torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, num_audio_tokens)
-            )
-        …
-
-        return self._expand_tensor_for_audio(input_ids, None, num_audio_tokens)
+        if not getattr(self.config, "use_specaugment", False):
+            return input_features
+
+        if not self.training:
+            return input_features
+
+        # Input shape: (batch_size, num_mel_bins, sequence_length) for Whisper
+        batch_size, hidden_size, sequence_length = input_features.size()
+
+        mask_time_prob = getattr(self.config, "mask_time_prob", 0.05)
+        mask_time_length = getattr(self.config, "mask_time_length", 10)
+        mask_feature_prob = getattr(self.config, "mask_feature_prob", 0.0)
+        mask_feature_length = getattr(self.config, "mask_feature_length", 10)
+
+        # Time masking
+        if mask_time_prob > 0:
+            mask_time_np = _compute_mask_indices(
+                (batch_size, sequence_length),
+                mask_prob=mask_time_prob,
+                mask_length=mask_time_length,
+                attention_mask=attention_mask,
+                min_masks=2,
+            )
+            mask_time_indices = torch.tensor(
+                mask_time_np, device=input_features.device, dtype=torch.bool
+            )
+            # Expand to cover all features: (batch, seq) -> (batch, features, seq)
+            mask_time_expanded = mask_time_indices[:, None].expand(-1, hidden_size, -1)
+            input_features = input_features.masked_fill(mask_time_expanded, 0.0)
+
+        # Feature masking
+        if mask_feature_prob > 0:
+            mask_feature_np = _compute_mask_indices(
+                (batch_size, hidden_size),
+                mask_prob=mask_feature_prob,
+                mask_length=mask_feature_length,
+                min_masks=2,
+            )
+            mask_feature_indices = torch.tensor(
+                mask_feature_np, device=input_features.device, dtype=torch.bool
+            )
+            # Expand: (batch, features) -> (batch, features, seq)
+            mask_feature_expanded = mask_feature_indices[:, :, None].expand(-1, -1, sequence_length)
+            input_features = input_features.masked_fill(mask_feature_expanded, 0.0)
+
+        return input_features

-    def …
+    def _encode_audio(
         self,
-        …
-        num_audio_tokens: int,
-        fill_value: Union[int, float],
+        audio_features: torch.Tensor,
+        audio_attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        …
-            input_ids, tensor_to_expand, num_audio_tokens, fill_value
-        )
-
-        …
+        """Encode audio and project to LLM embedding space.
+
+        Returns flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
+        """
+        # Apply SpecAugment during training (before encoding)
+        audio_features = self._apply_specaugment(audio_features, audio_attention_mask)
+
+        with torch.no_grad():
+            encoder_out = self.audio_tower(
+                input_features=audio_features, attention_mask=audio_attention_mask
+            )
+            hidden_states = encoder_out.last_hidden_state
+
+        audio_embeds = self.projector(hidden_states)
+
+        # Flatten: (batch, seq, hidden) -> (batch * seq, hidden)
+        # This allows masked_scatter to do 1:1 replacement
+        return audio_embeds.reshape(-1, audio_embeds.shape[-1])

     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
-        …
-        input_features: Optional[torch.Tensor] = None,  # For Whisper
-        labels: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        …
-        ] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         **kwargs,
-    ):
-        …
-        audio_embeds = self._encode_audio(
-            input_values=audio_inputs,  # Will be mapped to input_features for Whisper by safe_encoder_forward
-            audio_attention_mask=audio_attention_mask,
-        )
-
-        …
-            raise ValueError("Audio token <audio> must be present in input")
-
-        …
-        else:
-            full_attention_mask = None
-
-        …
-            labels = self._expand_for_audio_tokens(
-                input_ids, labels, num_audio_tokens, fill_value=-100
-            )
-        else:
-            inputs_embeds = self.decoder.get_input_embeddings()(input_ids)
-            full_attention_mask = attention_mask
-            use_cache = kwargs.pop("use_cache", None)
-
-        …
+    ) -> CausalLMOutputWithPast:
+        """Forward pass for training and inference."""
+        # Get text embeddings if not provided
+        if inputs_embeds is None:
+            inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+
+        if input_features is not None and input_ids is not None:
+            # Encode audio -> flattened (total_audio_tokens, hidden_dim)
+            audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+            # Replace <audio> token placeholders with audio embeddings using masked_scatter
+            audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+            inputs_embeds = inputs_embeds.masked_scatter(
+                audio_token_mask.to(inputs_embeds.device),
+                audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
+            )
+
+        # Run through language model (let it compute loss if labels provided)
+        outputs = self.language_model(
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        # Add auxiliary loss from MoE projectors if available
+        if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
+            aux_loss = self.projector.get_aux_loss()
+            if aux_loss is not None and aux_loss.numel() > 0:
+                outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
+
+        return outputs
+
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        """Prepare inputs for generation, handling audio features for cached decoding."""
+        input_features = kwargs.pop("input_features", None)
+        cache_position = kwargs.get("cache_position")
+
+        model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
+
+        # Only pass audio features on the first generation step (cache_position[0] == 0)
+        if cache_position is not None and cache_position[0] == 0 and input_features is not None:
+            model_inputs["input_features"] = input_features
+
+        return model_inputs
+
+    def _get_num_audio_tokens(self, input_features: torch.Tensor) -> int:
+        """Calculate number of audio tokens based on input shape.
+
+        Whisper: input_features shape is (batch, n_mels, mel_len)
+        Encoder output is mel_len // 2 due to stride-2 conv
+        MLP projector adds another stride-2 for 4x total downsampling
+        """
+        mel_len = input_features.shape[-1]
+        return mel_len // 4

     @torch.no_grad()
     def generate(
         self,
-        …
-        input_features: Optional[torch.Tensor] = None,
+        input_ids: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         system_prompt: Optional[str] = None,
-        user_prompt: Optional[str] = None,
-        task: Optional[str] = None,
         **generate_kwargs,
-    ) -> …
-        …
-        messages.append({"role": "…
-        …
-        if not (prompt_ids == self.audio_token_id).any():
-            raise ValueError("Audio token <audio> not found in prompt")
-
-        num_audio_tokens = audio_embeds.shape[1]
-        expanded_prompt_ids = self._expand_audio_tokens(prompt_ids, num_audio_tokens)
-        inputs_embeds = self._prepare_audio_inputs_embeds(expanded_prompt_ids, audio_embeds)
-        total_seq_len = inputs_embeds.shape[1]
-        attention_mask = torch.ones(batch_size, total_seq_len, dtype=torch.long, device=device)
-        config_params = [
-            "max_new_tokens",
-            "min_new_tokens",
-            "num_beams",
-            "do_sample",
-            "temperature",
-            "top_k",
-            "top_p",
-            "repetition_penalty",
-            "length_penalty",
-            "no_repeat_ngram_size",
-            "early_stopping",
-        ]
-        for param in config_params:
-            if hasattr(self.config, param) and getattr(self.config, param) is not None:
-                generate_kwargs.setdefault(param, getattr(self.config, param))
-
-        generate_kwargs.setdefault("use_cache", True)
-        generate_kwargs.setdefault(
-            "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        )
-        generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
-        prompt_length = expanded_prompt_ids.shape[1]
-
-        …
+    ) -> torch.Tensor:
+        """Generate transcription from audio input.
+
+        Can be called in two ways:
+        1. With input_ids containing <audio> tokens (from processor)
+        2. With just audio, and we build the prompt internally
+        """
+        if input_features is None:
+            raise ValueError("input_features required for generation")
+
+        device = input_features.device
+        batch_size = input_features.shape[0]
+
+        # Encode audio -> flattened embeddings
+        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+        # If input_ids not provided, build prompt with correct number of audio tokens
+        if input_ids is None:
+            num_audio_tokens = self._get_num_audio_tokens(input_features)
+            audio_placeholder = "<audio>" * num_audio_tokens
+
+            system_prompt = system_prompt or self.system_prompt
+
+            messages: list[dict[str, str]] = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+
+            input_ids = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            ).to(device)
+
+            if input_ids.dim() == 1:
+                input_ids = input_ids.unsqueeze(0)
+            if input_ids.shape[0] == 1 and batch_size > 1:
+                input_ids = input_ids.expand(batch_size, -1)
+
+            attention_mask = torch.ones_like(input_ids)
+
+        # Get text embeddings and replace audio tokens with audio embeddings
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            audio_token_mask.to(inputs_embeds.device),
+            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
+        )
+
+        # Generate using language model
+        output = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
+            generation_config=self.generation_config,
             **generate_kwargs,
         )

-        …
+        # When using inputs_embeds without input_ids, generate returns only new tokens
+        if isinstance(output, torch.Tensor):
+            return output
+        return output.sequences

     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
+        """Save model, tokenizer, and processor."""
         import shutil
         from pathlib import Path as PathlibPath

         save_dir = PathlibPath(save_directory)
         save_dir.mkdir(parents=True, exist_ok=True)

-        …
-        self.config.vocab_size = …
-        self.config.text_config.vocab_size = …
+        # Update config with actual vocab size
+        self.config.vocab_size = self.language_model.config.vocab_size
+        self.config.text_config.vocab_size = self.language_model.config.vocab_size

-        if hasattr(self.…
-            self.config.audio_config.num_mel_bins = self.…
+        if hasattr(self.audio_tower.config, "num_mel_bins"):
+            self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins

-        …
+        # Save model (temporarily remove non-serializable attributes)
         tokenizer = self.tokenizer
-        del self.feature_extractor
         del self.tokenizer

         try:
             super().save_pretrained(save_dir, **kwargs)
         finally:
-            self.feature_extractor = feature_extractor
             self.tokenizer = tokenizer

+        # Save tokenizer and feature extractor
         self.tokenizer.save_pretrained(save_dir)
+        self.feature_extractor.save_pretrained(save_dir)

-        …
-        self.feature_extractor.n_mels = num_mel_bins
-        self.feature_extractor.nb_max_frames = 3000  # Whisper's max frames
-
-        …
+        # Add processor auto_map to preprocessor_config.json
+        config_path = save_dir / "preprocessor_config.json"
+        if config_path.exists():
+            with config_path.open() as f:
+                processor_config = json.load(f)
+        else:
+            processor_config = {}
+
+        processor_config.update(
+            {
+                "processor_class": "ASRProcessor",
+                "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
+            }
+        )
+
+        with config_path.open("w") as f:
+            json.dump(processor_config, f, indent=2)

+        # Copy source files for auto-loading
         src_dir = PathlibPath(__file__).parent
         for asr_file in src_dir.glob("asr_*.py"):
             shutil.copy(asr_file, save_dir / asr_file.name)
+        # Copy projector files
+        projector_files = [
+            "mlp_projector.py",
+            "moe_projector.py",
+            "residual_projector.py",
+            "swiglu_projector.py",
+            "shared_moe_projector.py",
+        ]
+        for projector_file in projector_files:
+            src_path = src_dir / projector_file
+            if src_path.exists():
+                shutil.copy(src_path, save_dir / projector_file)


+# Register with transformers Auto classes
 AutoConfig.register("asr_model", ASRConfig)
 AutoModel.register(ASRConfig, ASRModel)
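The splice that both forward and generate rely on is the masked_scatter call: the projector output is flattened to (total_audio_tokens, hidden) so that each <audio> placeholder position receives exactly one audio embedding, in order. A self-contained toy illustration of just that mechanic (shapes and the token id 5 are invented for the demo, unrelated to the real vocabulary):

import torch

hidden = 8
audio_token_id = 5  # hypothetical id standing in for <audio>

# Two prompts, each containing three <audio> placeholders.
input_ids = torch.tensor([[1, 5, 5, 5, 2],
                          [3, 5, 5, 5, 4]])
inputs_embeds = torch.zeros(2, 5, hidden)  # stand-in for text embeddings

# Projector output (batch, audio_seq, hidden), flattened as in _encode_audio.
audio_embeds = torch.randn(2, 3, hidden).reshape(-1, hidden)

mask = (input_ids == audio_token_id).unsqueeze(-1)  # broadcasts over hidden
spliced = inputs_embeds.masked_scatter(mask, audio_embeds)

assert torch.equal(spliced[0, 1], audio_embeds[0])      # placeholders filled in order
assert torch.equal(spliced[0, 0], torch.zeros(hidden))  # text positions untouched

The flattening matters: masked_scatter consumes the source tensor element by element, so the number of True positions in the mask must equal the number of embedding rows the projector produces.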
| 1 |
+
import json
|
| 2 |
from pathlib import Path
|
| 3 |
from typing import Optional, Union
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
|
|
|
| 7 |
from transformers import (
|
| 8 |
AutoConfig,
|
| 9 |
AutoModel,
|
| 10 |
AutoModelForCausalLM,
|
| 11 |
AutoTokenizer,
|
| 12 |
PreTrainedModel,
|
|
|
|
| 13 |
)
|
| 14 |
+
from transformers.generation import GenerationMixin
|
| 15 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 16 |
+
from transformers.models.whisper.modeling_whisper import (
|
| 17 |
+
_compute_mask_indices,
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
try:
|
| 21 |
from .asr_config import ASRConfig
|
| 22 |
+
from .mlp_projector import MLPAudioProjector
|
| 23 |
+
from .moe_projector import MoEAudioProjector
|
| 24 |
+
from .residual_projector import ResidualAudioProjector
|
| 25 |
+
from .shared_moe_projector import SharedMoEAudioProjector
|
| 26 |
+
from .swiglu_projector import AudioProjector
|
| 27 |
except ImportError:
|
| 28 |
from asr_config import ASRConfig # type: ignore[no-redef]
|
| 29 |
+
from mlp_projector import MLPAudioProjector # type: ignore[no-redef]
|
| 30 |
+
from moe_projector import MoEAudioProjector # type: ignore[no-redef]
|
| 31 |
+
from residual_projector import ResidualAudioProjector # type: ignore[no-redef]
|
| 32 |
+
from shared_moe_projector import SharedMoEAudioProjector # type: ignore[no-redef]
|
| 33 |
+
from swiglu_projector import AudioProjector # type: ignore[no-redef]
|
| 34 |
|
| 35 |
+
# Map projector type names to classes
|
| 36 |
+
PROJECTOR_CLASSES = {
|
| 37 |
+
"swiglu": AudioProjector,
|
| 38 |
+
"residual": ResidualAudioProjector,
|
| 39 |
+
"moe": MoEAudioProjector,
|
| 40 |
+
"shared_moe": SharedMoEAudioProjector,
|
| 41 |
+
"mlp": MLPAudioProjector,
|
| 42 |
+
}
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
class ASRModel(PreTrainedModel, GenerationMixin):
|
| 46 |
+
"""Audio-to-text model combining an audio encoder, projector, and language model."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
config_class = ASRConfig
|
| 49 |
base_model_prefix = "model"
|
| 50 |
+
main_input_name = "input_features"
|
| 51 |
_supports_flash_attn_2 = True
|
| 52 |
supports_gradient_checkpointing = True
|
| 53 |
_is_loading_from_pretrained: bool = False
|
| 54 |
_pretrained_model_path: Optional[str] = None
|
| 55 |
|
| 56 |
+
TRANSCRIBE_PROMPT = "Transcribe: "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
@classmethod
|
| 59 |
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
| 60 |
+
"""Load model from pretrained, handling device placement correctly."""
|
| 61 |
+
from safetensors.torch import load_file
|
| 62 |
+
from transformers.utils.hub import cached_file
|
| 63 |
|
| 64 |
config = kwargs.pop("config", None)
|
| 65 |
if config is None:
|
| 66 |
config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 67 |
|
| 68 |
+
# Set flag to avoid device_map="auto" in sub-model loaders
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
cls._is_loading_from_pretrained = True
|
| 70 |
cls._pretrained_model_path = pretrained_model_name_or_path
|
| 71 |
|
| 72 |
try:
|
|
|
|
|
|
|
|
|
|
| 73 |
model = cls(config, **kwargs)
|
| 74 |
|
| 75 |
+
# Load projector weights from safetensors
|
| 76 |
subfolder = kwargs.get("subfolder")
|
| 77 |
revision = kwargs.get("revision")
|
| 78 |
cache_kwargs = {}
|
|
|
|
| 88 |
**cache_kwargs,
|
| 89 |
)
|
| 90 |
|
| 91 |
+
if model_file is not None:
|
| 92 |
+
state_dict = load_file(model_file)
|
| 93 |
+
model.load_state_dict(state_dict, strict=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
return model
|
| 96 |
finally:
|
| 97 |
cls._is_loading_from_pretrained = False
|
| 98 |
+
cls._pretrained_model_path = None
|
| 99 |
|
| 100 |
def __init__(self, config: ASRConfig, **kwargs):
|
| 101 |
super().__init__(config)
|
| 102 |
|
|
|
|
|
|
|
| 103 |
self.system_prompt = config.system_prompt
|
| 104 |
+
target_dtype = getattr(torch, config.model_dtype)
|
| 105 |
|
| 106 |
+
# Audio encoder (frozen)
|
| 107 |
+
self.audio_tower = self._load_audio_encoder(config, target_dtype)
|
| 108 |
+
|
| 109 |
+
# Language model (frozen)
|
| 110 |
+
self.language_model = self._load_language_model(config, target_dtype)
|
| 111 |
+
|
| 112 |
+
# Initialize tokenizer and special tokens
|
| 113 |
+
self._init_tokenizer(config)
|
| 114 |
+
|
| 115 |
+
# Set up generation config with our defaults
|
| 116 |
+
self.generation_config = self.language_model.generation_config
|
| 117 |
+
self.generation_config.max_new_tokens = config.max_new_tokens
|
| 118 |
+
self.generation_config.num_beams = config.num_beams
|
| 119 |
+
self.generation_config.do_sample = config.do_sample
|
| 120 |
+
self.generation_config.use_cache = config.use_cache
|
| 121 |
+
self.generation_config.length_penalty = config.length_penalty
|
| 122 |
+
self.generation_config.repetition_penalty = config.repetition_penalty
|
| 123 |
+
self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
|
| 124 |
+
# Only set sampling params when do_sample=True, otherwise clear them
|
| 125 |
+
if config.do_sample:
|
| 126 |
+
self.generation_config.temperature = config.temperature
|
| 127 |
+
if config.top_k is not None:
|
| 128 |
+
self.generation_config.top_k = config.top_k
|
| 129 |
+
if config.top_p is not None:
|
| 130 |
+
self.generation_config.top_p = config.top_p
|
| 131 |
else:
|
| 132 |
+
self.generation_config.temperature = None
|
| 133 |
+
self.generation_config.top_k = None
|
| 134 |
+
self.generation_config.top_p = None
|
| 135 |
+
self.generation_config.eos_token_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 136 |
+
self.generation_config.pad_token_id = self.tokenizer.pad_token_id
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
# Feature extractor for audio preprocessing
|
| 139 |
+
self.feature_extractor = self._create_feature_extractor(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
+
# Audio projector (trainable)
|
| 142 |
+
self.projector = self._create_projector(config, target_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
+
# For model parallelism
|
| 145 |
+
self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
+
def _create_feature_extractor(self, config: ASRConfig):
|
| 148 |
+
"""Create the appropriate feature extractor for the audio encoder."""
|
| 149 |
+
from transformers import AutoFeatureExtractor
|
| 150 |
|
| 151 |
+
return AutoFeatureExtractor.from_pretrained(config.audio_model_id)
|
| 152 |
|
| 153 |
@classmethod
|
| 154 |
+
def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
|
| 155 |
+
"""Load and freeze the audio encoder."""
|
|
|
|
| 156 |
encoder_kwargs = {
|
| 157 |
"attn_implementation": config.attn_implementation,
|
|
|
|
| 158 |
"low_cpu_mem_usage": True,
|
| 159 |
+
"dtype": dtype,
|
| 160 |
}
|
|
|
|
|
|
|
| 161 |
|
| 162 |
if "whisper" in config.audio_model_id.lower():
|
| 163 |
from transformers import WhisperModel
|
|
|
|
| 168 |
else:
|
| 169 |
encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
encoder.requires_grad_(False)
|
| 172 |
+
encoder.eval()
|
| 173 |
return encoder
|
| 174 |
|
| 175 |
@classmethod
|
| 176 |
+
def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
|
| 177 |
+
"""Load and freeze the language model."""
|
|
|
|
| 178 |
decoder_kwargs = {
|
| 179 |
"attn_implementation": config.attn_implementation,
|
|
|
|
| 180 |
"trust_remote_code": True,
|
| 181 |
+
"tie_word_embeddings": True,
|
| 182 |
+
"low_cpu_mem_usage": True,
|
| 183 |
+
"dtype": dtype,
|
| 184 |
}
|
| 185 |
|
| 186 |
decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
|
| 187 |
+
decoder.config.use_cache = getattr(config, "use_cache", True)
|
| 188 |
decoder.requires_grad_(False)
|
| 189 |
+
decoder.eval()
|
| 190 |
return decoder
|
| 191 |
|
| 192 |
+
def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
|
| 193 |
+
"""Create the trainable audio projector."""
|
| 194 |
+
# Auto-detect dimensions if not specified
|
| 195 |
+
if config.encoder_dim is None:
|
| 196 |
+
enc_cfg = self.audio_tower.config
|
| 197 |
+
config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
|
| 198 |
+
enc_cfg, "d_model", None
|
| 199 |
+
)
|
| 200 |
+
if config.encoder_dim is None:
|
| 201 |
+
raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
|
| 202 |
|
| 203 |
+
if config.llm_dim is None:
|
| 204 |
+
dec_cfg = self.language_model.config
|
| 205 |
+
config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
|
| 206 |
+
dec_cfg, "d_model", None
|
| 207 |
+
)
|
| 208 |
+
if config.llm_dim is None:
|
| 209 |
+
raise ValueError("Could not auto-detect llm_dim. Please specify in config.")
|
| 210 |
|
| 211 |
+
# Select projector type based on config
|
| 212 |
+
projector_type = getattr(config, "projector_type", "moe")
|
| 213 |
+
projector_class = PROJECTOR_CLASSES.get(projector_type)
|
| 214 |
+
if projector_class is None:
|
| 215 |
+
raise ValueError(
|
| 216 |
+
f"Unknown projector_type: {projector_type}. "
|
| 217 |
+
f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
|
| 218 |
+
)
|
| 219 |
+
projector = projector_class(config)
|
| 220 |
|
| 221 |
+
# Move projector to same device as language model (important when using quantization)
|
| 222 |
+
device = next(self.language_model.parameters()).device
|
| 223 |
+
return projector.to(device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
def _init_tokenizer(self, config: ASRConfig):
|
| 226 |
+
"""Initialize tokenizer with audio token."""
|
| 227 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
|
| 228 |
|
| 229 |
+
# Set pad token
|
| 230 |
if (
|
| 231 |
self.tokenizer.pad_token is None
|
| 232 |
or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
|
| 233 |
) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
|
| 234 |
self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
|
| 235 |
|
| 236 |
+
# Add audio token
|
| 237 |
existing_special = self.tokenizer.additional_special_tokens or []
|
|
|
|
| 238 |
if "<audio>" not in existing_special:
|
| 239 |
+
self.tokenizer.add_special_tokens(
|
| 240 |
+
{"additional_special_tokens": existing_special + ["<audio>"]}
|
| 241 |
+
)
|
| 242 |
+
self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
|
| 244 |
self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
|
|
|
|
| 245 |
self.tokenizer.padding_side = "right"
|
| 246 |
|
| 247 |
+
# Sync token IDs to configs
|
| 248 |
+
for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
|
| 249 |
+
if cfg is not None:
|
|
|
|
|
|
|
|
|
|
| 250 |
cfg.pad_token_id = self.tokenizer.pad_token_id
|
| 251 |
cfg.eos_token_id = self.tokenizer.eos_token_id
|
| 252 |
cfg.bos_token_id = self.tokenizer.bos_token_id
|
| 253 |
|
| 254 |
+
def _init_weights(self, module):
|
| 255 |
+
"""Weight initialization (projector weights are initialized in MoEAudioProjector)."""
|
| 256 |
+
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
|
| 258 |
+
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
|
| 259 |
+
"""Enable/disable gradient checkpointing for the language model."""
|
| 260 |
+
# The LLM still stores activations during forward for backprop to projector
|
| 261 |
+
# Gradient checkpointing trades compute for memory by recomputing activations
|
| 262 |
+
if hasattr(self.language_model, "_set_gradient_checkpointing"):
|
| 263 |
+
self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
|
| 264 |
+
elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
|
| 265 |
+
self.language_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
| 266 |
+
elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
|
| 267 |
+
self.language_model.gradient_checkpointing_disable()
|
| 268 |
|
| 269 |
def get_input_embeddings(self):
|
| 270 |
+
return self.language_model.get_input_embeddings()
|
| 271 |
|
| 272 |
def set_input_embeddings(self, value):
|
| 273 |
+
self.language_model.set_input_embeddings(value)
|
| 274 |
|
| 275 |
def get_output_embeddings(self):
|
| 276 |
+
return self.language_model.get_output_embeddings()
|
| 277 |
|
| 278 |
def set_output_embeddings(self, value):
|
| 279 |
+
self.language_model.set_output_embeddings(value)
|
| 280 |
|
| 281 |
+
def get_processor(self):
|
| 282 |
+
"""Get the processor for this model."""
|
| 283 |
+
try:
|
| 284 |
+
from .asr_processing import ASRProcessor
|
| 285 |
+
except ImportError:
|
| 286 |
+
from asr_processing import ASRProcessor # type: ignore[no-redef]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
+
return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
|
| 289 |
|
| 290 |
+
def state_dict(self, *args, **kwargs):
|
| 291 |
+
"""Only save trainable projector weights."""
|
| 292 |
+
return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}
|
|
|
|
|
|
|
| 293 |
|
| 294 |
+
def _apply_specaugment(
|
| 295 |
self,
|
| 296 |
+
input_features: torch.Tensor,
|
| 297 |
+
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
|
|
|
| 298 |
) -> torch.Tensor:
|
| 299 |
+
if not getattr(self.config, "use_specaugment", False):
|
| 300 |
+
return input_features
|
| 301 |
+
|
| 302 |
+
if not self.training:
|
| 303 |
+
return input_features
|
| 304 |
+
|
| 305 |
+
# Input shape: (batch_size, num_mel_bins, sequence_length) for Whisper
|
| 306 |
+
batch_size, hidden_size, sequence_length = input_features.size()
|
| 307 |
+
|
| 308 |
+
mask_time_prob = getattr(self.config, "mask_time_prob", 0.05)
|
| 309 |
+
mask_time_length = getattr(self.config, "mask_time_length", 10)
|
| 310 |
+
mask_feature_prob = getattr(self.config, "mask_feature_prob", 0.0)
|
| 311 |
+
mask_feature_length = getattr(self.config, "mask_feature_length", 10)
|
| 312 |
+
|
| 313 |
+
# Time masking
|
| 314 |
+
if mask_time_prob > 0:
|
| 315 |
+
mask_time_np = _compute_mask_indices(
|
| 316 |
+
(batch_size, sequence_length),
|
| 317 |
+
mask_prob=mask_time_prob,
|
| 318 |
+
mask_length=mask_time_length,
|
| 319 |
+
attention_mask=attention_mask,
|
| 320 |
+
min_masks=2,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
)
|
| 322 |
+
mask_time_indices = torch.tensor(
|
| 323 |
+
mask_time_np, device=input_features.device, dtype=torch.bool
|
| 324 |
+
)
|
| 325 |
+
# Expand to cover all features: (batch, seq) -> (batch, features, seq)
|
| 326 |
+
mask_time_expanded = mask_time_indices[:, None].expand(-1, hidden_size, -1)
|
| 327 |
+
input_features = input_features.masked_fill(mask_time_expanded, 0.0)
|
| 328 |
+
|
| 329 |
+
# Feature masking
|
| 330 |
+
if mask_feature_prob > 0:
|
| 331 |
+
mask_feature_np = _compute_mask_indices(
|
| 332 |
+
(batch_size, hidden_size),
|
| 333 |
+
mask_prob=mask_feature_prob,
|
| 334 |
+
mask_length=mask_feature_length,
|
| 335 |
+
min_masks=2,
|
| 336 |
+
)
|
| 337 |
+
mask_feature_indices = torch.tensor(
|
| 338 |
+
mask_feature_np, device=input_features.device, dtype=torch.bool
|
| 339 |
+
)
|
| 340 |
+
# Expand: (batch, features) -> (batch, features, seq)
|
| 341 |
+
mask_feature_expanded = mask_feature_indices[:, :, None].expand(-1, -1, sequence_length)
|
| 342 |
+
input_features = input_features.masked_fill(mask_feature_expanded, 0.0)
|
| 343 |
|
| 344 |
+
return input_features
|
|
|
|
| 345 |
|
| 346 |
+
def _encode_audio(
|
| 347 |
self,
|
| 348 |
+
audio_features: torch.Tensor,
|
| 349 |
+
audio_attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
| 350 |
) -> torch.Tensor:
|
| 351 |
+
"""Encode audio and project to LLM embedding space.
|
|
|
|
|
|
|
| 352 |
|
| 353 |
+
Returns flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
|
| 354 |
+
"""
|
| 355 |
+
# Apply SpecAugment during training (before encoding)
|
| 356 |
+
audio_features = self._apply_specaugment(audio_features, audio_attention_mask)
|
| 357 |
+
|
| 358 |
+
with torch.no_grad():
|
| 359 |
+
encoder_out = self.audio_tower(
|
| 360 |
+
input_features=audio_features, attention_mask=audio_attention_mask
|
| 361 |
+
)
|
| 362 |
+
hidden_states = encoder_out.last_hidden_state
|
| 363 |
+
|
| 364 |
+
audio_embeds = self.projector(hidden_states)
|
| 365 |
+
|
| 366 |
+
# Flatten: (batch, seq, hidden) -> (batch * seq, hidden)
|
| 367 |
+
# This allows masked_scatter to do 1:1 replacement
|
| 368 |
+
return audio_embeds.reshape(-1, audio_embeds.shape[-1])
|
| 369 |
|
| 370 |
def forward(
|
| 371 |
self,
|
| 372 |
input_ids: Optional[torch.Tensor] = None,
|
| 373 |
+
input_features: Optional[torch.Tensor] = None,
|
|
|
|
|
|
|
| 374 |
attention_mask: Optional[torch.Tensor] = None,
|
| 375 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 376 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 377 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 378 |
+
labels: Optional[torch.Tensor] = None,
|
| 379 |
+
use_cache: Optional[bool] = None,
|
| 380 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 381 |
+
audio_attention_mask: Optional[torch.Tensor] = None,
|
| 382 |
**kwargs,
|
| 383 |
+
) -> CausalLMOutputWithPast:
|
| 384 |
+
"""Forward pass for training and inference."""
|
| 385 |
+
# Get text embeddings if not provided
|
| 386 |
+
if inputs_embeds is None:
|
| 387 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
| 388 |
+
|
| 389 |
+
if input_features is not None and input_ids is not None:
|
| 390 |
+
# Encode audio -> flattened (total_audio_tokens, hidden_dim)
|
| 391 |
+
audio_embeds = self._encode_audio(input_features, audio_attention_mask)
|
| 392 |
+
|
| 393 |
+
# Replace <audio> token placeholders with audio embeddings using masked_scatter
|
| 394 |
+
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
| 395 |
+
inputs_embeds = inputs_embeds.masked_scatter(
|
| 396 |
+
audio_token_mask.to(inputs_embeds.device),
|
| 397 |
+
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|
|
|
|
|
|
|
|
|
|
| 398 |
)
|
| 399 |
|
| 400 |
+
# Run through language model (let it compute loss if labels provided)
|
| 401 |
+
outputs = self.language_model(
|
| 402 |
+
attention_mask=attention_mask,
|
| 403 |
+
position_ids=position_ids,
|
| 404 |
+
past_key_values=past_key_values,
|
| 405 |
+
inputs_embeds=inputs_embeds,
|
| 406 |
+
labels=labels,
|
| 407 |
+
use_cache=use_cache,
|
| 408 |
+
cache_position=cache_position,
|
| 409 |
+
**kwargs,
|
| 410 |
+
)
|
| 411 |
|
| 412 |
+
# Add auxiliary loss from MoE projectors if available
|
| 413 |
+
if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
|
| 414 |
+
aux_loss = self.projector.get_aux_loss()
|
| 415 |
+
if aux_loss is not None and aux_loss.numel() > 0:
|
| 416 |
+
outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
|
| 417 |
|
| 418 |
+
return outputs
|
|
|
|
| 419 |
|
| 420 |
+
def prepare_inputs_for_generation(self, *args, **kwargs):
|
| 421 |
+
"""Prepare inputs for generation, handling audio features for cached decoding."""
|
| 422 |
+
input_features = kwargs.pop("input_features", None)
|
| 423 |
+
cache_position = kwargs.get("cache_position")
|
| 424 |
|
| 425 |
+
model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
|
| 426 |
|
| 427 |
+
# Only pass audio features on the first generation step (cache_position[0] == 0)
|
| 428 |
+
if cache_position is not None and cache_position[0] == 0 and input_features is not None:
|
| 429 |
+
model_inputs["input_features"] = input_features
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
+
return model_inputs
|
|
|
|
+    def _get_num_audio_tokens(self, input_features: torch.Tensor) -> int:
+        """Calculate number of audio tokens based on input shape.
+
+        Whisper: input_features shape is (batch, n_mels, mel_len)
+        Encoder output is mel_len // 2 due to stride-2 conv
+        MLP projector adds another stride-2 for 4x total downsampling
+        """
+        mel_len = input_features.shape[-1]
+        return mel_len // 4
 
     @torch.no_grad()
     def generate(
         self,
+        input_ids: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        audio_attention_mask: Optional[torch.Tensor] = None,
         system_prompt: Optional[str] = None,
         **generate_kwargs,
+    ) -> torch.Tensor:
+        """Generate transcription from audio input.
+
+        Can be called in two ways:
+        1. With input_ids containing <audio> tokens (from processor)
+        2. With just audio, and we build the prompt internally
+        """
+        if input_features is None:
+            raise ValueError("input_features required for generation")
+
+        device = input_features.device
+        batch_size = input_features.shape[0]
+
+        # Encode audio -> flattened embeddings
+        audio_embeds = self._encode_audio(input_features, audio_attention_mask)
+
+        # If input_ids not provided, build prompt with correct number of audio tokens
+        if input_ids is None:
+            num_audio_tokens = self._get_num_audio_tokens(input_features)
+            audio_placeholder = "<audio>" * num_audio_tokens
+
+            system_prompt = system_prompt or self.system_prompt
+
+            messages: list[dict[str, str]] = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            messages.append({"role": "user", "content": self.TRANSCRIBE_PROMPT + audio_placeholder})
+
+            input_ids = self.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            ).to(device)
+
+            if input_ids.dim() == 1:
+                input_ids = input_ids.unsqueeze(0)
+            if input_ids.shape[0] == 1 and batch_size > 1:
+                input_ids = input_ids.expand(batch_size, -1)
+
+            attention_mask = torch.ones_like(input_ids)
+
+        # Get text embeddings and replace audio tokens with audio embeddings
+        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+        audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            audio_token_mask.to(inputs_embeds.device),
+            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
         )
 
+        # Generate using language model
+        output = self.language_model.generate(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
+            generation_config=self.generation_config,
             **generate_kwargs,
         )
 
+        # When using inputs_embeds without input_ids, generate returns only new tokens
+        if isinstance(output, torch.Tensor):
+            return output
+        return output.sequences
 
     def save_pretrained(self, save_directory: Union[str, Path], **kwargs):
+        """Save model, tokenizer, and processor."""
         import shutil
         from pathlib import Path as PathlibPath
 
         save_dir = PathlibPath(save_directory)
         save_dir.mkdir(parents=True, exist_ok=True)
 
+        # Update config with actual vocab size
+        self.config.vocab_size = self.language_model.config.vocab_size
+        self.config.text_config.vocab_size = self.language_model.config.vocab_size
 
+        if hasattr(self.audio_tower.config, "num_mel_bins"):
+            self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
 
+        # Save model (temporarily remove non-serializable attributes)
         tokenizer = self.tokenizer
         del self.tokenizer
 
         try:
             super().save_pretrained(save_dir, **kwargs)
         finally:
             self.tokenizer = tokenizer
 
+        # Save tokenizer and feature extractor
         self.tokenizer.save_pretrained(save_dir)
+        self.feature_extractor.save_pretrained(save_dir)
+
+        # Add processor auto_map to preprocessor_config.json
+        config_path = save_dir / "preprocessor_config.json"
+        if config_path.exists():
+            with config_path.open() as f:
+                processor_config = json.load(f)
+        else:
+            processor_config = {}
 
+        processor_config.update(
+            {
+                "processor_class": "ASRProcessor",
+                "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
+            }
+        )
 
+        with config_path.open("w") as f:
+            json.dump(processor_config, f, indent=2)
 
+        # Copy source files for auto-loading
         src_dir = PathlibPath(__file__).parent
         for asr_file in src_dir.glob("asr_*.py"):
             shutil.copy(asr_file, save_dir / asr_file.name)
+        # Copy projector files
+        projector_files = [
+            "mlp_projector.py",
+            "moe_projector.py",
+            "residual_projector.py",
+            "swiglu_projector.py",
+            "shared_moe_projector.py",
+        ]
+        for projector_file in projector_files:
+            src_path = src_dir / projector_file
+            if src_path.exists():
+                shutil.copy(src_path, save_dir / projector_file)
 
 
+# Register with transformers Auto classes
 AutoConfig.register("asr_model", ASRConfig)
 AutoModel.register(ASRConfig, ASRModel)
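
Taken together, generate consumes Whisper log-mel features, splices the projected audio embeddings over the <audio> placeholder positions, and delegates to the language model. A rough end-to-end sketch (the repo id and the random waveform are placeholders, not part of this commit):

import torch
import transformers

repo = "user/asr-checkpoint"  # placeholder: wherever this checkpoint is published
model = transformers.AutoModel.from_pretrained(repo, trust_remote_code=True)
processor = transformers.AutoProcessor.from_pretrained(repo, trust_remote_code=True)

waveform = torch.randn(16000 * 5).numpy()  # stand-in for 5 s of 16 kHz audio
features = processor.feature_extractor(waveform, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    token_ids = model.generate(input_features=features["input_features"])
print(processor.tokenizer.decode(token_ids[0], skip_special_tokens=True))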
asr_pipeline.py
CHANGED

@@ -1,8 +1,7 @@
 from typing import Any
 
 import torch
 import transformers
-from truecase import get_true_case
 
 try:
     from .asr_modeling import ASRModel
@@ -11,284 +10,58 @@ except ImportError:
 
 
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
+    """ASR Pipeline for audio-to-text transcription."""
+
     model: ASRModel
 
     def __init__(self, model: ASRModel, **kwargs):
-        feature_extractor = kwargs.pop("feature_extractor",
+        feature_extractor = kwargs.pop("feature_extractor", None)
         tokenizer = kwargs.pop("tokenizer", model.tokenizer)
 
+        if feature_extractor is None:
+            feature_extractor = model.get_processor().feature_extractor
+
         super().__init__(
             model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
         )
 
-        # Initialize text normalizer (same as train.py)
-        if hasattr(tokenizer, "normalize"):
-            self.text_normalizer = tokenizer
-        else:
-            # Fallback to whisper-tiny tokenizer for its normalize() method only
-            from transformers import WhisperTokenizer
-
-            self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
-
-    def __call__(self, inputs, **kwargs):
-        generate_kwargs = {}
-        for key in [
-            "max_new_tokens",
-            "num_beams",
-            "do_sample",
-            "length_penalty",
-            "repetition_penalty",
-            "no_repeat_ngram_size",
-            "early_stopping",
-            "num_beam_groups",
-            "diversity_penalty",
-            "top_k",
-            "temperature",
-            "top_p",
-            "user_prompt",
-            "task",
-            "text_input",
-        ]:
-            if key in kwargs:
-                generate_kwargs[key] = kwargs.pop(key)
-
-        # Handle text-only mode
-        task = generate_kwargs.get("task")
-        if task == "text" or generate_kwargs.get("text_input"):
-            return self._process_text_only(generate_kwargs)
-
-        if isinstance(inputs, list):
-            results = []
-            for single_input in inputs:
-                result = self.__call__(single_input, **kwargs, **generate_kwargs)
-                results.append(result)
-            return results
-
-        model_inputs = self.preprocess(inputs, **kwargs)
-
-        from collections.abc import Iterator
-
-        if isinstance(model_inputs, Iterator):
-            # Convert iterator to list to process chunks
-            chunks = list(model_inputs)
-
-            all_outputs = []
-            for _chunk_num, chunk in enumerate(chunks, start=1):
-                chunk_output = self._forward(chunk, **generate_kwargs)
-                # Move tensors to CPU before adding to outputs
-                for key, value in chunk_output.items():
-                    if torch.is_tensor(value):
-                        chunk_output[key] = value.cpu()
-                all_outputs.append(chunk_output)
-
-            # Merge chunks and decode ourselves to ensure skip_special_tokens=True
-            all_tokens: list[int] = []
-            for output in all_outputs:
-                tokens = output.get("tokens")
-                if tokens is None:
-                    tokens = output.get("generated_ids")
-                if tokens is not None:
-                    if torch.is_tensor(tokens):
-                        tokens = tokens.cpu()
-                    if len(tokens.shape) > 1:
-                        tokens = tokens[0]
-                    all_tokens.extend(tokens.tolist() if torch.is_tensor(tokens) else tokens)
-
-            # Decode the merged tokens with skip_special_tokens
-            text = self.tokenizer.decode(all_tokens, skip_special_tokens=True)
-            text = text.strip()
-
-            # Apply Whisper normalization (matches training)
-            text = self.text_normalizer.normalize(text)
-
-            # Apply truecasing for proper capitalization
-            text = get_true_case(text)
-
-            return {"text": text}
-
-        model_outputs = self._forward(model_inputs, **generate_kwargs)
-        return self.postprocess(model_outputs)
-
     def preprocess(self, inputs, **preprocess_params):
-        if isinstance(inputs, list):
-            raise ValueError("Lists should not reach preprocess - bug in __call__")
-
-        preprocess_params.setdefault("stride_length_s", (5, 5))
-
-        # Handle different formats from datasets
-        if isinstance(inputs, dict):
-            if "bytes" in inputs:
-                # Decode bytes to audio array using torchcodec
-                import tempfile
-
-                from torchcodec.decoders import AudioDecoder
-
-                wav_bytes = inputs["bytes"]
-                # Write to temp file for torchcodec to read
-                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
-                    f.write(wav_bytes)
-                    temp_path = f.name
-                try:
-                    decoder = AudioDecoder(temp_path)
-                    # Get all audio samples
-                    audio_result = decoder.get_all_samples()
-                    audio_tensor = audio_result.data
-                    sample_rate = audio_result.sample_rate
-                    inputs = {"raw": audio_tensor.squeeze().numpy(), "sampling_rate": sample_rate}
-                finally:
-                    from pathlib import Path
-
-                    Path(temp_path).unlink()
-            elif "array" in inputs:
-                # Convert "array" key to "raw" key
-                inputs = {"raw": inputs["array"], "sampling_rate": inputs["sampling_rate"]}
-            # If it already has "raw" and "sampling_rate", it's good to go
-        elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
-            # Audio object with attributes (not dict)
-            inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
-        elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
-            inputs = {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}
-        elif torch.is_tensor(inputs):
+        preprocess_params.setdefault("chunk_length_s", 0)
+
+        # Handle dict with "array" key (from datasets)
+        if isinstance(inputs, dict) and "array" in inputs:
             inputs = {
-                "raw": inputs,
-                "sampling_rate": self.model.config.audio_sample_rate,
+                "raw": inputs["array"],
+                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
             }
 
         return super().preprocess(inputs, **preprocess_params)
 
-    def _forward(self, model_inputs, **generate_kwargs):
-        # Extract
-        task = generate_kwargs.pop("task", None)
-
-        # Task-specific sampling parameters
-        task_params: Dict[str, Dict[str, Any]] = {
-            "transcribe": {"do_sample": False},
-            "emotion": {"do_sample": True, "temperature": 0.7},
-            "describe": {"do_sample": True, "temperature": 0.7},
-            "continue": {"do_sample": True, "temperature": 1.0},
-        }
-
-        if task in task_params:
-            for key, value in task_params[task].items():
-                generate_kwargs.setdefault(key, value)
-
-        # Extract audio inputs from various formats
-        is_last = True
-        audio_inputs = None
-        is_whisper = False  # Track if this is Whisper input
-
-        # Normalize model_inputs to dict format
-        if isinstance(model_inputs, torch.Tensor):
-            audio_inputs = model_inputs
-        elif isinstance(model_inputs, (list, tuple)) and model_inputs:
-            model_inputs = (
-                model_inputs[0]
-                if isinstance(model_inputs[0], dict)
-                else {"input_values": model_inputs[0]}
-            )
-
+    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
+        # Extract audio features
         if isinstance(model_inputs, dict):
-            # Get audio input (Whisper uses input_features, others use input_values)
-            if "input_features" in model_inputs:
-                audio_inputs = model_inputs["input_features"]
-                is_whisper = True
-            else:
-                audio_inputs = model_inputs.get("input_values")
-
-        if audio_inputs is None:
-            raise ValueError(
-                f"Could not extract input_values or input_features from {type(model_inputs)}"
-            )
-
-        if isinstance(audio_inputs, torch.Tensor):
-            audio_inputs = audio_inputs.to(self.model.device)
+            input_features = model_inputs.get("input_features")
+            if input_features is not None:
+                input_features = input_features.to(self.model.device)
         else:
+            input_features = model_inputs.to(self.model.device)
 
-        im_end_id = self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
-        generate_kwargs.setdefault("eos_token_id", im_end_id)
-        generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)
-
-        # Pass the appropriate input type to generate
-        if is_whisper:
-            # Whisper model - use input_features
-            generated_ids = self.model.generate(
-                input_features=audio_inputs,
-                system_prompt=self.model.config.system_prompt,
-                task=task,
-                **generate_kwargs,
-            )
-        else:
-            # Wav2Vec2/HuBERT model - use input_values
-            generated_ids = self.model.generate(
-                input_values=audio_inputs,
-                system_prompt=self.model.config.system_prompt,
-                task=task,
-                **generate_kwargs,
-            )
-
-        return {"tokens": generated_ids, "is_last": is_last}
-
-    def _process_text_only(self, generate_kwargs):
-        """Process text-only input without audio encoding."""
-        text_input = generate_kwargs.pop("text_input", None)
-        if text_input is None:
-            raise ValueError("text_input is required for text task")
-
-        generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)
-
-        # Decode the generated text
-        generated_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-
-        return {"text": generated_text}
-
-    def postprocess(
-        self, model_outputs: Dict[str, Any], return_timestamps=None, return_language=None
-    ):
-        # Handle chunked outputs from iterator
-        if isinstance(model_outputs, list):
-            # Move all tensors to CPU before calling parent postprocess
-            for output_dict in model_outputs:
-                for key, value in output_dict.items():
-                    if torch.is_tensor(value):
-                        output_dict[key] = value.cpu()
-            return super().postprocess(model_outputs)
-
-        model_outputs.pop("is_last")
+        generated_ids = self.model.generate(
+            input_features=input_features,
+            **generate_kwargs,
+        )
 
+        return {"tokens": generated_ids}
+
+    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
         tokens = model_outputs.get("tokens")
         if tokens is None:
-            tokens = model_outputs.get("generated_ids")
-
-        if tokens is None:
-            raise ValueError(
-                f"Expected 'tokens' or 'generated_ids' in model_outputs, got: {model_outputs.keys()}"
-            )
+            return super().postprocess(model_outputs, **kwargs)
 
-        # Move to CPU if on MPS or other device
-        if torch.is_tensor(tokens) and tokens.device.type != "cpu":
+        if torch.is_tensor(tokens):
             tokens = tokens.cpu()
+            if tokens.dim() > 1:
+                tokens = tokens[0]
 
-        if len(tokens.shape) > 1:
-            tokens = tokens[0]
-
-        text = self.tokenizer.decode(tokens, skip_special_tokens=True)
-        text = text.strip()
-
-        # Apply Whisper normalization (matches training)
-        text = self.text_normalizer.normalize(text)
-
-        # Apply truecasing for proper capitalization
-        text = get_true_case(text)
-
+        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
         return {"text": text}
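
The slimmed-down pipeline now leans on the parent AutomaticSpeechRecognitionPipeline for chunking and batching, overriding only feature routing and decoding. A hypothetical invocation, reusing model and waveform from the sketch above:

from asr_pipeline import ASRPipeline  # copied next to the checkpoint by save_pretrained

asr = ASRPipeline(model=model)  # feature_extractor/tokenizer default from the model
result = asr({"array": waveform, "sampling_rate": 16000})
print(result["text"])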
asr_processing.py
CHANGED

@@ -1,7 +1,9 @@
+from typing import Optional, Union
+
+import torch
 import transformers
-from transformers import
+from transformers import ProcessorMixin
 
-# Handle both package and standalone imports
 try:
     from .asr_config import ASRConfig
 except ImportError:
@@ -9,69 +11,81 @@
 
 
 class ASRProcessor(ProcessorMixin):
-    """
+    """Processor for Whisper-based ASR models."""
 
+    attributes = ["feature_extractor", "tokenizer"]
     feature_extractor_class = "AutoFeatureExtractor"
     tokenizer_class = "AutoTokenizer"
+    AUDIO_TOKEN = "<audio>"
+    TRANSCRIBE_PROMPT = "Transcribe: "
 
     def __init__(self, feature_extractor, tokenizer):
         self.feature_extractor = feature_extractor
         self.tokenizer = tokenizer
-
-    def save_pretrained(self, save_directory, **kwargs):
-        import json
-        from pathlib import Path
-
-        save_path = Path(save_directory)
-        save_path.mkdir(parents=True, exist_ok=True)
-
-        # Save the feature extractor (this creates preprocessor_config.json with all feature extractor settings)
-        if self.feature_extractor is not None:
-            self.feature_extractor.save_pretrained(save_directory)
-
-        # Save the tokenizer
-        if self.tokenizer is not None:
-            self.tokenizer.save_pretrained(save_directory)
-
-        # Load the existing preprocessor_config.json and add processor-specific metadata
-        config_path = save_path / "preprocessor_config.json"
-        if config_path.exists():
-            with config_path.open() as f:
-                processor_config = json.load(f)
-        else:
-            processor_config = {}
-
-        # Add/update processor metadata while preserving feature extractor settings
-        feature_extractor_type = self.feature_extractor.__class__.__name__
-        processor_config.update(
-            {
-                "processor_class": self.__class__.__name__,
-                "feature_extractor_class": self.feature_extractor_class,
-                "tokenizer_class": self.tokenizer_class,
-                "feature_extractor_type": feature_extractor_type,  # Dynamic based on actual type
-                "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
-            }
-        )
-
-        with config_path.open("w") as f:
-            json.dump(processor_config, f, indent=2)
+        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
+
+    def __call__(
+        self,
+        audio: Optional[Union[list, "torch.Tensor"]] = None,
+        text: Optional[str] = None,
+        system_prompt: Optional[str] = None,
+        return_tensors: str = "pt",
+        **kwargs,
+    ) -> dict:
+        """Process audio and text inputs for inference.
+
+        Args:
+            audio: Raw audio waveform(s)
+            text: Target transcription (optional, for training - but use DataCollator instead)
+            system_prompt: Optional system prompt
+            return_tensors: Return format ("pt" for PyTorch)
+
+        Returns:
+            Dict with input_features, input_ids, attention_mask
+        """
+        result = {}
+
+        # Process audio
+        if audio is not None:
+            audio_inputs = self.feature_extractor(
+                audio,
+                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
+                return_tensors=return_tensors,
+                **kwargs,
+            )
+            result["input_features"] = audio_inputs["input_features"]
+            # Whisper encoder output length = mel_len // 2 (stride-2 conv)
+            num_audio_tokens = audio_inputs["input_features"].shape[-1] // 2
+        else:
+            num_audio_tokens = 0
+
+        # Build prompt with audio token placeholders
+        user_content = self.TRANSCRIBE_PROMPT
+        if num_audio_tokens > 0:
+            user_content += self.AUDIO_TOKEN * num_audio_tokens
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_content})
+        if text is not None:
+            messages.append({"role": "assistant", "content": text})
+
+        # Tokenize
+        input_ids = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=(text is None),
+            return_tensors=return_tensors,
         )
 
+        if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
+            input_ids = input_ids.unsqueeze(0)
 
+        result["input_ids"] = input_ids
+        result["attention_mask"] = torch.ones_like(input_ids)
 
+        return result
 
 
 ASRProcessor.register_for_auto_class()
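
One way to see the processor contract in action (a sketch with illustrative values; the whisper-large-v3-turbo extractor pads to 3000 mel frames, so the prompt carries 1500 <audio> placeholders, and the checkpoint repo id is a stand-in):

import numpy as np
from transformers import AutoFeatureExtractor, AutoTokenizer
from asr_processing import ASRProcessor

feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large-v3-turbo")
tokenizer = AutoTokenizer.from_pretrained("user/asr-checkpoint")  # placeholder; must contain <audio>
processor = ASRProcessor(feature_extractor, tokenizer)

batch = processor(audio=np.zeros(16000, dtype=np.float32), system_prompt="/no_think /system_override")
print(batch["input_features"].shape)                              # (1, n_mels, 3000) after padding
print(int((batch["input_ids"] == processor.audio_token_id).sum()))  # 3000 // 2 = 1500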
chat_template.jinja
CHANGED

@@ -1,6 +1,94 @@
-{
+{# ───── defaults ───── #}
+{%- if enable_thinking is not defined -%}
+{%- set enable_thinking = true -%}
+{%- endif -%}
+
+{# ───── reasoning mode ───── #}
+{%- if enable_thinking -%}
+{%- set reasoning_mode = "/think" -%}
+{%- else -%}
+{%- set reasoning_mode = "/no_think" -%}
+{%- endif -%}
+
+{# ───── header (system message) ───── #}
+{{- "<|im_start|>system\n" -}}
+
+{%- if messages[0].role == "system" -%}
+{%- set system_message = messages[0].content -%}
+{%- if "/no_think" in system_message -%}
+{%- set reasoning_mode = "/no_think" -%}
+{%- elif "/think" in system_message -%}
+{%- set reasoning_mode = "/think" -%}
+{%- endif -%}
+{%- set custom_instructions = system_message.replace("/no_think", "").replace("/think", "").rstrip() -%}
+{%- endif -%}
+
+{%- if "/system_override" in system_message -%}
+{{- custom_instructions.replace("/system_override", "").rstrip() -}}
+{{- "<|im_end|>\n" -}}
+{%- else -%}
+{{- "## Metadata\n\n" -}}
+{{- "Knowledge Cutoff Date: June 2025\n" -}}
+{%- set today = strftime_now("%d %B %Y") -%}
+{{- "Today Date: " ~ today ~ "\n" -}}
+{{- "Reasoning Mode: " + reasoning_mode + "\n\n" -}}
+
+{{- "## Custom Instructions\n\n" -}}
+{%- if custom_instructions -%}
+{{- custom_instructions + "\n\n" -}}
+{%- elif reasoning_mode == "/think" -%}
+{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face. Your role as an assistant involves thoroughly exploring questions through a systematic thinking process before providing the final precise and accurate solutions. This requires engaging in a comprehensive cycle of analysis, summarizing, exploration, reassessment, reflection, backtracking, and iteration to develop well-considered thinking process. Please structure your response into two main sections: Thought and Solution using the specified format: <think> Thought section </think> Solution section. In the Thought section, detail your reasoning process in steps. Each step should include detailed considerations such as analysing questions, summarizing relevant findings, brainstorming new ideas, verifying the accuracy of the current steps, refining any errors, and revisiting previous steps. In the Solution section, based on various attempts, explorations, and reflections from the Thought section, systematically present the final solution that you deem correct. The Solution section should be logical, accurate, and concise and detail necessary steps needed to reach the conclusion.\n\n" -}}
+{%- else -%}
+{{- "You are a helpful AI assistant named SmolLM, trained by Hugging Face.\n\n" -}}
+{%- endif -%}
+
+{%- if xml_tools or python_tools or tools -%}
+{{- "### Tools\n\n" -}}
+{%- if xml_tools or tools -%}
+{%- if tools -%}
+{%- set xml_tools = tools -%}
+{%- endif -%}
+{%- set ns = namespace(xml_tool_string="You may call one or more functions to assist with the user query.\nYou are provided with function signatures within <tools></tools> XML tags:\n\n<tools>\n") -%}
+{%- for tool in xml_tools[:] -%} {# The slicing makes sure that xml_tools is a list #}
+{%- set ns.xml_tool_string = ns.xml_tool_string ~ (tool | string) ~ "\n" -%}
+{%- endfor -%}
+{%- set xml_tool_string = ns.xml_tool_string + "</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call>" -%}
+{{- xml_tool_string -}}
+{%- endif -%}
+{%- if python_tools -%}
+{%- set ns = namespace(python_tool_string="When you send a message containing Python code between '<code>' and '</code>' tags, it will be executed in a stateful Jupyter notebook environment, and you will then be given the output to continued reasoning in an agentic loop.\n\nYou can use the following tools in your python code like regular functions:\n<tools>\n") -%}
+{%- for tool in python_tools[:] -%} {# The slicing makes sure that python_tools is a list #}
+{%- set ns.python_tool_string = ns.python_tool_string ~ (tool | string) ~ "\n" -%}
+{%- endfor -%}
+{%- set python_tool_string = ns.python_tool_string + "</tools>\n\nThe state persists between code executions: so variables that you define in one step are still available thereafter." -%}
+{{- python_tool_string -}}
+{%- endif -%}
+{{- "\n\n" -}}
+{{- "<|im_end|>\n" -}}
+{%- endif -%}
+{%- endif -%}
+{# ───── main loop ───── #}
+{%- for message in messages -%}
+{%- set content = message.content if message.content is string else "" -%}
+{%- if message.role == "user" -%}
+{{ "<|im_start|>" + message.role + "\n" + content + "<|im_end|>\n" }}
+{%- elif message.role == "assistant" -%}
+{% generation %}
+{%- if reasoning_mode == "/think" -%}
+{{ "<|im_start|>assistant\n" + content.lstrip("\n") + "<|im_end|>\n" }}
+{%- else -%}
+{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" + content.lstrip("\n") + "<|im_end|>\n" }}
+{%- endif -%}
+{% endgeneration %}
+{%- elif message.role == "tool" -%}
+{{ "<|im_start|>" + "user\n" + content + "<|im_end|>\n" }}
+{%- endif -%}
+{%- endfor -%}
+{# ───── generation prompt ───── #}
+{%- if add_generation_prompt -%}
+{%- if reasoning_mode == "/think" -%}
+{{ "<|im_start|>assistant\n" }}
+{%- else -%}
+{{ "<|im_start|>assistant\n" + "<think>\n\n</think>\n" }}
+{%- endif -%}
+{%- endif -%}
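
The template's reasoning-mode and /system_override switches can be exercised without a model. A small sketch (the tokenizer path is a placeholder; using /system_override keeps the render on the branch that skips the dated-metadata header):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("user/asr-checkpoint", trust_remote_code=True)
messages = [
    {"role": "system", "content": "/no_think /system_override You transcribe audio."},
    {"role": "user", "content": "Transcribe: <audio>"},
]
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# /no_think pre-fills the assistant turn with an empty <think>\n\n</think> block
print(prompt)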
mlp_projector.py
ADDED

@@ -0,0 +1,42 @@
import torch.nn as nn


class MLPAudioProjector(nn.Module):
    """2-layer MLP projector with Qwen-style 2x temporal downsampling."""

    def __init__(self, config):
        super().__init__()

        encoder_dim = getattr(config, "encoder_dim", 768)
        llm_dim = getattr(config, "llm_dim", 2048)

        self.downsample = nn.Conv1d(
            encoder_dim, encoder_dim, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.linear_1 = nn.Linear(encoder_dim, llm_dim, bias=False)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(llm_dim, llm_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Conv1d):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """
        x: [Batch, Seq_Len, Dim]
        Returns: [Batch, Seq_Len // 2, llm_dim]
        """
        # Conv1d expects [Batch, Channels, Seq_Len]
        x = x.transpose(1, 2)
        x = self.downsample(x)
        x = x.transpose(1, 2)

        x = self.linear_1(x)
        x = self.act(x)
        return self.linear_2(x)
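
A quick shape check of the documented contract (made-up sizes, with a bare namespace standing in for the real config object):

import torch
from types import SimpleNamespace
from mlp_projector import MLPAudioProjector

proj = MLPAudioProjector(SimpleNamespace(encoder_dim=1280, llm_dim=2048))
y = proj(torch.randn(2, 1500, 1280))  # Whisper-large-v3 encoder output length
assert y.shape == (2, 750, 2048)      # one stride-2 conv halves the sequence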
moe_projector.py
ADDED

@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers.models.llama.modeling_llama import LlamaRMSNorm


class SimpleAdapter(nn.Module):
    """
    MOSA Section III-B:
    "consists of two linear layers with a ReLU activation in between,
    projecting the hidden dimension from 3072 to 4096 and back to 3072."
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        return self.fc2(x)


class MoEAudioProjector(nn.Module):
    """
    MOSA-style projector: Mixture of Simple Adapters.

    From paper (arXiv:2508.18998):
    - Dense mixture (softmax over ALL experts) instead of sparse Top-K
    - Simple Linear->ReLU->Linear adapters (3072->4096->3072)
    - No auxiliary losses - just cross-entropy on transcripts
    - Conv downsampling: stride 4 total (two conv layers, stride 2 each)
    """

    def __init__(self, config):
        super().__init__()

        # Dimensions:
        #   Whisper-large-v3 encoder_dim = 1280
        #   SmolLM3-3B hidden_size = 2048
        self.encoder_dim = config.encoder_dim  # 1280
        self.llm_dim = config.llm_dim  # 2048

        # Number of experts: Base=4, Large=8
        self.num_experts = getattr(config, "num_experts", 4)

        # Adapter hidden dim: paper uses 4096
        adapter_hidden = getattr(config, "projector_hidden_dim", None) or 4096

        # Dropout rate for experts (not applied to router)
        self.dropout_rate = getattr(config, "projector_dropout", 0.1)

        # --- Convolutional Subsampling (Section III-B) ---
        # "two convolutional layers, each with a kernel size of 3 and a stride of 2"
        # Maps encoder_dim (1280) -> llm_dim (3072), total stride=4
        self.conv = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # --- Router (Section III-B) ---
        # Base: "two linear layers... mapping from 1280 to 512 and finally to 4"
        router_hidden = 512
        self.router = nn.Sequential(
            nn.Linear(self.encoder_dim, router_hidden),
            nn.ReLU(),
            nn.Linear(router_hidden, self.num_experts),
        )

        # --- Experts / Adapters (Section III-B) ---
        # "projecting the hidden dimension from 3072 to 4096 and back to 3072"
        self.experts = nn.ModuleList(
            [
                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim, dropout=self.dropout_rate)
                for _ in range(self.num_experts)
            ]
        )

        # Normalization for stability (not in original MOSA but prevents FPE)
        self.ln_post = LlamaRMSNorm(self.llm_dim, eps=1e-6)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        std = 0.02
        with torch.no_grad():
            # Conv layers
            for module in self.conv:
                if isinstance(module, nn.Conv1d):
                    nn.init.normal_(module.weight, mean=0.0, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

            # Router
            for module in self.router:
                if isinstance(module, nn.Linear):
                    nn.init.normal_(module.weight, mean=0.0, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

            # Experts
            for expert in self.experts:
                nn.init.normal_(expert.fc1.weight, mean=0.0, std=std)
                nn.init.normal_(expert.fc2.weight, mean=0.0, std=std)
                if expert.fc1.bias is not None:
                    nn.init.zeros_(expert.fc1.bias)
                if expert.fc2.bias is not None:
                    nn.init.zeros_(expert.fc2.bias)

            # LayerNorm
            self.ln_post.weight.data.fill_(1.0)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, encoder_dim] from Whisper encoder (1280)

        Returns:
            output: [batch_size, seq_len // 4, llm_dim] (3072)
        """
        batch_size, seq_len, _ = x.shape

        # Pad to be divisible by stride (4)
        pad_amt = (4 - (seq_len % 4)) % 4
        if pad_amt > 0:
            x = F.pad(x, (0, 0, 0, pad_amt))
            seq_len = x.shape[1]

        # 1. Convolutional Downsampling
        #    (B, T, C) -> (B, C, T) -> conv -> (B, C, T//4) -> (B, T//4, C)
        h_conv = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)

        # 2. Router on high-res input, then downsample weights
        router_logits = self.router(x)  # [B, T, num_experts]
        # Average over stride window to match conv output
        router_logits = router_logits.view(batch_size, seq_len // 4, 4, self.num_experts).mean(
            dim=2
        )
        # Dense softmax
        routing_weights = F.softmax(router_logits, dim=-1)  # [B, T//4, num_experts]

        # 3. Weighted sum of expert outputs (Eq. 2: y = sum(w_i * E_i(x)))
        #    Use in-place add to reduce memory allocations
        final_out = torch.zeros_like(h_conv)
        for i, expert in enumerate(self.experts):
            expert_out = expert(h_conv)
            expert_weight = routing_weights[:, :, i : i + 1]
            final_out.add_(expert_out * expert_weight)

        return self.ln_post(final_out)

    def get_aux_loss(self) -> torch.Tensor:
        """Return auxiliary loss (none for dense MoE - all experts always used)."""
        return torch.tensor(0.0)
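
The same kind of sanity sketch for the MOSA-style projector (sizes are illustrative; the namespace stands in for the real config object):

import torch
from types import SimpleNamespace
from moe_projector import MoEAudioProjector

proj = MoEAudioProjector(SimpleNamespace(encoder_dim=1280, llm_dim=2048))  # 4 dense experts by default
y = proj(torch.randn(2, 1500, 1280))
assert y.shape == (2, 375, 2048)  # two stride-2 convs give 4x downsampling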
preprocessor_config.json
CHANGED

@@ -7,14 +7,11 @@
   "n_fft": 400,
   "n_samples": 480000,
   "nb_max_frames": 3000,
-  "num_mel_bins": 128,
   "padding_side": "right",
   "padding_value": 0.0,
   "processor_class": "ASRProcessor",
   "return_attention_mask": false,
   "sampling_rate": 16000,
-  "feature_extractor_class": "AutoFeatureExtractor",
-  "tokenizer_class": "AutoTokenizer",
   "auto_map": {
     "AutoProcessor": "asr_processing.ASRProcessor"
   }
residual_projector.py
ADDED

@@ -0,0 +1,153 @@
"""Residual MLP projector for Whisper → LLM feature space translation.

Philosophy: Whisper features are already information-complete. The projector
learns a nonlinear correction/refinement to align them with the LLM's expected
input distribution, rather than replacing them entirely.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812


class ResidualMLP(nn.Module):
    """MLP block with residual connection.

    Output = x + MLP(x)

    At initialization (weights near zero), output ≈ input, providing a stable
    starting point. The network learns to add nonlinear corrections as needed.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return residual + x


class ResidualAudioProjector(nn.Module):
    """Residual MLP projector for audio-to-LLM feature translation.

    Architecture:
    1. Temporal pooling (concatenate k consecutive frames)
    2. Linear projection to LLM dimension
    3. N residual MLP blocks for nonlinear refinement
    4. Final layer norm

    The linear projection handles dimension matching, while residual MLPs
    learn the nonlinear corrections needed to align acoustic features
    with semantic embedding space.
    """

    def __init__(self, config):
        super().__init__()

        # Temporal downsampling factor
        self.k = getattr(config, "projector_pool_stride", 4)

        # Dimensions
        in_dim = config.encoder_dim * self.k  # After concatenating k frames
        out_dim = config.llm_dim
        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4

        # Number of residual blocks
        self.num_layers = getattr(config, "projector_num_layers", 2)

        dropout_rate = getattr(config, "projector_dropout", 0.0)

        from transformers.models.llama.modeling_llama import LlamaRMSNorm

        # Initial projection: encoder_dim * k → llm_dim
        self.input_proj = nn.Linear(in_dim, out_dim)
        self.ln_input = LlamaRMSNorm(out_dim, eps=1e-6)

        # Residual MLP blocks for nonlinear refinement
        self.layers = nn.ModuleList(
            [ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate) for _ in range(self.num_layers)]
        )

        # Per-layer norms (applied after each residual block)
        self.layer_norms = nn.ModuleList(
            [LlamaRMSNorm(out_dim, eps=1e-6) for _ in range(self.num_layers)]
        )

        self.output_dropout = nn.Dropout(dropout_rate)

        # Initialize for stable training
        self._init_weights(config)

    def _init_weights(self, config):
        """Initialize weights for stable residual learning.

        Key insight: Initialize fc2 of each residual block to near-zero
        so that initially output ≈ input (identity function).
        """
        std = getattr(config, "projector_init_std", 0.02)

        with torch.no_grad():
            # Input projection: standard init
            nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
            if self.input_proj.bias is not None:
                nn.init.zeros_(self.input_proj.bias)

            # Layer norms
            self.ln_input.weight.data.fill_(1.0)
            for ln in self.layer_norms:
                ln.weight.data.fill_(1.0)

            # Residual blocks: small init on output projection
            for layer in self.layers:
                nn.init.normal_(layer.fc1.weight, mean=0.0, std=std)
                # Initialize fc2 smaller so residual starts near identity
                nn.init.normal_(layer.fc2.weight, mean=0.0, std=std * 0.1)
                if layer.fc1.bias is not None:
                    nn.init.zeros_(layer.fc1.bias)
                if layer.fc2.bias is not None:
                    nn.init.zeros_(layer.fc2.bias)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, encoder_dim] from Whisper encoder

        Returns:
            [batch_size, seq_len // k, llm_dim] projected features
        """
        batch_size, seq_len, dim = x.size()

        # Ensure correct dtype
        target_dtype = self.input_proj.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Pad sequence to be divisible by k
        remainder = seq_len % self.k
        if remainder:
            pad_len = self.k - remainder
            x = F.pad(x, (0, 0, 0, pad_len))

        # Temporal pooling: concatenate k consecutive frames
        # [B, T, D] → [B, T//k, D*k]
        x = x.contiguous().view(batch_size, -1, dim * self.k)

        # Project to LLM dimension
        x = self.input_proj(x)
        x = self.ln_input(x)

        # Apply residual MLP blocks
        for layer, ln in zip(self.layers, self.layer_norms):
            x = layer(x)
            x = ln(x)

        return self.output_dropout(x)
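
The near-identity claim is checkable at initialization: with fc2 scaled down by 0.1, the residual blocks start close to pass-through, so early outputs are dominated by the linear projection. A sketch with illustrative sizes (namespace in place of the real config):

import torch
from types import SimpleNamespace
from residual_projector import ResidualAudioProjector

proj = ResidualAudioProjector(SimpleNamespace(encoder_dim=1280, llm_dim=2048)).eval()
y = proj(torch.randn(1, 1500, 1280))
assert y.shape == (1, 375, 2048)  # k=4 frame concatenation -> seq_len // 4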
shared_moe_projector.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F # noqa: N812
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SwiGLUExpert(nn.Module):
|
| 7 |
+
"""SwiGLU expert MLP (used for both shared and routed experts)."""
|
| 8 |
+
|
| 9 |
+
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
|
| 12 |
+
self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
|
| 13 |
+
self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
|
| 14 |
+
self.act = nn.SiLU()
|
| 15 |
+
|
| 16 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 17 |
+
return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SharedMoEBlock(nn.Module):
|
| 21 |
+
"""MoE block with shared expert + sparse routed experts."""
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
input_dim: int,
|
| 26 |
+
hidden_dim: int,
|
| 27 |
+
output_dim: int,
|
| 28 |
+
num_experts: int = 4,
|
| 29 |
+
top_k: int = 2,
|
| 30 |
+
):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.num_experts = num_experts
|
| 33 |
+
self.top_k = top_k
|
| 34 |
+
self.output_dim = output_dim
|
| 35 |
+
|
| 36 |
+
# Router: zero-initialized for natural learning
|
| 37 |
+
self.router = nn.Linear(input_dim, num_experts, bias=False)
|
| 38 |
+
nn.init.zeros_(self.router.weight)
|
| 39 |
+
|
| 40 |
+
# Shared expert (always active)
|
| 41 |
+
self.shared_expert = SwiGLUExpert(input_dim, hidden_dim, output_dim)
|
| 42 |
+
|
| 43 |
+
# Routed experts (sparse)
|
| 44 |
+
self.experts = nn.ModuleList(
|
| 45 |
+
[SwiGLUExpert(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# For auxiliary loss (cached to avoid recomputation)
|
| 49 |
+
self.last_router_logits = None
|
| 50 |
+
self.last_router_probs = None
|
| 51 |
+
|
| 52 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
batch_size, seq_len, dim = hidden_states.shape
|
| 54 |
+
|
| 55 |
+
# Shared expert output (all tokens)
|
| 56 |
+
shared_out = self.shared_expert(hidden_states)
|
| 57 |
+
|
| 58 |
+
# Routing
|
| 59 |
+
flat_hidden = hidden_states.view(-1, dim)
|
| 60 |
+
router_logits = self.router(flat_hidden)
|
| 61 |
+
router_probs = F.softmax(router_logits.float(), dim=-1)
|
| 62 |
+
|
| 63 |
+
# Cache for aux loss
|
| 64 |
+
self.last_router_logits = router_logits
|
| 65 |
+
self.last_router_probs = router_probs
|
| 66 |
+
|
| 67 |
+
# Top-k selection and renormalization
|
| 68 |
+
top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
|
| 69 |
+
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
|
| 70 |
+
top_k_weights = top_k_weights.to(hidden_states.dtype)
|
| 71 |
+
|
| 72 |
+
# Routed expert output via token dispatch
|
| 73 |
+
routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
|
| 74 |
+
routed_out = routed_out.view(batch_size, seq_len, -1)
|
| 75 |
+
|
| 76 |
+
# Combine: shared expert baseline + routed experts (grow in via zero-init down_proj)
|
| 77 |
+
return shared_out + routed_out
|
| 78 |
+
|
| 79 |
+
def _dispatch_experts(
|
| 80 |
+
self,
|
| 81 |
+
hidden_states: torch.Tensor,
|
| 82 |
+
top_k_indices: torch.Tensor,
|
| 83 |
+
top_k_weights: torch.Tensor,
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
"""Token dispatch - gather tokens per expert, process, scatter back."""
|
| 86 |
+
num_tokens = hidden_states.shape[0]
|
| 87 |
        output = torch.zeros(
            num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
        )

        for expert_idx, expert in enumerate(self.experts):
            expert_mask = top_k_indices == expert_idx
            if not expert_mask.any():
                continue

            token_indices, slot_indices = torch.where(expert_mask)
            expert_input = hidden_states[token_indices]
            expert_output = expert(expert_input)
            weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
            output.index_add_(0, token_indices, expert_output * weights)

        return output


def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    """Auxiliary loss to encourage balanced expert usage."""
    _, selected = torch.topk(router_probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected, num_experts).float()
    tokens_per_expert = expert_mask.mean(dim=(0, 1))
    prob_per_expert = router_probs.mean(dim=0)
    return (tokens_per_expert * prob_per_expert).sum() * num_experts


def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Z-loss to prevent router logits from growing too large."""
    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()


class SharedMoEAudioProjector(nn.Module):
    def __init__(self, config):
        super().__init__()

        # Temporal downsampling
        self.k = getattr(config, "projector_pool_stride", 4)

        # Dimensions
        encoder_dim = config.encoder_dim
        in_dim = encoder_dim * self.k
        out_dim = config.llm_dim
        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim

        # MoE config
        self.num_experts = getattr(config, "num_experts", 4)
        self.top_k = getattr(config, "num_experts_per_tok", 2)
        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)

        # Layers
        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)

        # Init
        self._init_weights(in_dim)

    def _init_weights(self, in_dim: int):
        with torch.no_grad():
            # Shared expert - orthogonal init for stable condition numbers
            nn.init.orthogonal_(self.moe.shared_expert.gate_proj.weight)
            nn.init.orthogonal_(self.moe.shared_expert.up_proj.weight)
            nn.init.orthogonal_(self.moe.shared_expert.down_proj.weight, gain=0.5)

            # Routed experts - orthogonal for gate/up, tiny orthogonal for down (grow-in)
            # gain=0.01 gives ~1% initial contribution while maintaining good conditioning
            for expert in self.moe.experts:
                nn.init.orthogonal_(expert.gate_proj.weight)
                nn.init.orthogonal_(expert.up_proj.weight)
                nn.init.orthogonal_(expert.down_proj.weight, gain=0.01)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = x.size()

        target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Pad for pooling (at most k-1 frames -> 1 extra token, negligible impact)
        if seq_len % self.k:
            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))

        # Temporal pooling
        x = x.view(batch_size, -1, dim * self.k)

        return self.moe(x)

    def get_aux_loss(self) -> torch.Tensor:
        """Get auxiliary losses (call after forward)."""
        if self.moe.last_router_logits is None:
            return torch.tensor(0.0, device=self.moe.router.weight.device)

        balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
        z = z_loss(self.moe.last_router_logits)

        return self.aux_loss_coef * balance + self.z_loss_coef * z
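For reference, a minimal smoke test of the projector above; this is a sketch, not part of the checkpoint. The config values are illustrative (1280 matches the Whisper large-v3-turbo encoder width, 2048 stands in for the LLM hidden size), and it assumes SharedMoEBlock, defined earlier in this file, records last_router_logits / last_router_probs during its forward pass, as get_aux_loss() expects.

# Hypothetical usage sketch -- config fields and shapes are illustrative.
from types import SimpleNamespace

import torch

from shared_moe_projector import SharedMoEAudioProjector

config = SimpleNamespace(
    encoder_dim=1280,           # Whisper large-v3-turbo encoder width
    llm_dim=2048,               # stand-in for the LLM hidden size
    projector_pool_stride=4,
    projector_hidden_dim=None,  # falls back to encoder_dim * k
    num_experts=4,
    num_experts_per_tok=2,
    router_aux_loss_coef=0.02,
    router_z_loss_coef=0.001,
)

projector = SharedMoEAudioProjector(config)
audio_features = torch.randn(2, 1500, config.encoder_dim)  # (batch, frames, dim)

projected = projector(audio_features)  # (2, 1500 / 4, llm_dim) = (2, 375, 2048)
aux = projector.get_aux_loss()         # add to the main ASR loss during training
print(projected.shape, aux.item())

One useful property of load_balancing_loss as written: perfectly uniform routing gives a balance term of exactly 1.0, and the value grows as routing skews toward a subset of experts, so it can be logged directly as an expert-collapse indicator.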
special_tokens_map.json
CHANGED
@@ -1,15 +1,13 @@
 {
   "additional_special_tokens": [
-
-
+    {
+      "content": "<audio>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false
+    }
   ],
-  "bos_token": {
-    "content": "<|im_start|>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
   "eos_token": {
     "content": "<|im_end|>",
     "lstrip": false,
@@ -17,18 +15,5 @@
     "rstrip": false,
     "single_word": false
   },
-  "pad_token": {
-    "content": "<|im_end|>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  },
-  "unk_token": {
-    "content": "<|endoftext|>",
-    "lstrip": false,
-    "normalized": false,
-    "rstrip": false,
-    "single_word": false
-  }
+  "pad_token": "<|finetune_right_pad_id|>"
 }
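The net effect of this change: "<audio>" becomes a dedicated special token (it is the placeholder in the "Transcribe: <audio>" prompt), and padding moves from reusing the EOS token to "<|finetune_right_pad_id|>". Keeping pad distinct from EOS avoids any ambiguity between real end-of-turn tokens and padding when masking. A quick sanity check -- the checkpoint path below is a placeholder:

from transformers import AutoTokenizer

# Load the tokenizer shipped with this checkpoint (path is a placeholder).
tokenizer = AutoTokenizer.from_pretrained("path/to/this/checkpoint")

# As a special token, "<audio>" now maps to a single id, so it can be located
# in the input ids and swapped for projected audio embeddings.
audio_token_id = tokenizer.convert_tokens_to_ids("<audio>")
print(audio_token_id)
print(tokenizer.pad_token)  # "<|finetune_right_pad_id|>"
print(tokenizer.eos_token)  # "<|im_end|>"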
swiglu_projector.py
ADDED
@@ -0,0 +1,68 @@
"""Simple SwiGLU-based audio projector."""

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812


class SwiGLU(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, bias=False, dropout=0.0):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
        self.act = nn.SiLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x_gate = self.act(self.w1(x))
        x_val = self.w2(x)
        x = x_gate * x_val
        x = self.dropout(x)
        return self.w3(x)


class AudioProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.k = getattr(config, "projector_pool_stride", 4)
        in_dim = config.encoder_dim * self.k
        out_dim = config.llm_dim
        hidden_dim = config.projector_hidden_dim
        if hidden_dim is None:
            hidden_dim = config.encoder_dim * 2

        dropout_rate = getattr(config, "projector_dropout", 0.0)

        self.proj1 = SwiGLU(in_dim, hidden_dim, hidden_dim, dropout=dropout_rate)
        self.proj2 = SwiGLU(hidden_dim, hidden_dim, out_dim, dropout=dropout_rate)
        self.output_dropout = nn.Dropout(dropout_rate)

        with torch.no_grad():
            std = getattr(config, "projector_init_std", 0.02)
            # Initialize first layer
            nn.init.normal_(self.proj1.w1.weight, mean=0.0, std=std)
            nn.init.normal_(self.proj1.w2.weight, mean=0.0, std=std)
            nn.init.normal_(self.proj1.w3.weight, mean=0.0, std=std)
            # Initialize second layer
            nn.init.normal_(self.proj2.w1.weight, mean=0.0, std=std)
            nn.init.normal_(self.proj2.w2.weight, mean=0.0, std=std)
            nn.init.normal_(self.proj2.w3.weight, mean=0.0, std=std)

    def forward(self, x):
        batch_size, seq_len, dim = x.size()

        target_dtype = self.proj1.w1.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        remainder = seq_len % self.k
        if remainder:
            pad_len = self.k - remainder
            x = F.pad(x, (0, 0, 0, pad_len))

        x = x.contiguous().view(batch_size, -1, dim * self.k)
        x = self.proj1(x)
        x = self.proj2(x)

        return self.output_dropout(x)
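A shape check for this projector, with the same illustrative dimensions as before (a sketch, not part of the checkpoint):

from types import SimpleNamespace

import torch

from swiglu_projector import AudioProjector

config = SimpleNamespace(
    encoder_dim=1280,
    llm_dim=2048,
    projector_pool_stride=4,
    projector_hidden_dim=None,  # defaults to encoder_dim * 2 inside the class
    projector_dropout=0.0,
    projector_init_std=0.02,
)

projector = AudioProjector(config).eval()
x = torch.randn(1, 1499, config.encoder_dim)  # 1499 % 4 != 0, so one frame of padding

with torch.no_grad():
    y = projector(x)

print(y.shape)  # torch.Size([1, 375, 2048]): padded to 1500 frames, then pooled by k=4

Compared to the MoE variant, this projector trades capacity for simplicity: the same stack-k-frames temporal pooling, but a fixed two-layer SwiGLU stack with no router and no auxiliary losses.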
tokenizer.json
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
+size 17209003
tokenizer_config.json
CHANGED
Binary files a/tokenizer_config.json and b/tokenizer_config.json differ