| | import json |
| | from pathlib import Path |
| | from typing import Optional, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import ( |
| | AutoConfig, |
| | AutoModel, |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | PreTrainedModel, |
| | ) |
| | from transformers.generation import GenerationMixin |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| | try: |
| | from .asr_config import ASRConfig |
| | from .projectors import PROJECTOR_CLASSES |
| | except ImportError: |
| | from asr_config import ASRConfig |
| | from projectors import PROJECTOR_CLASSES |
| |
|
| |
|
| | from torchaudio.transforms import SpecAugment |
| |
|
| |
|
class ASRModel(PreTrainedModel, GenerationMixin):
    """Audio-to-text model combining an audio encoder, projector, and language model."""

    # Configuration class paired with this model.
    config_class = ASRConfig
    base_model_prefix = "model"
    # The primary model input is the audio feature tensor, not token ids.
    main_input_name = "input_features"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    # Class-level markers that from_pretrained() sets while a load is running
    # and resets in its `finally` clause.
    _is_loading_from_pretrained: bool = False
    _pretrained_model_path: Optional[str] = None

    # Optional instruction appended (after a space) to the audio placeholders
    # in the user turn of the generation prompt; empty means nothing is added.
    TRANSCRIBE_PROMPT = ""
| |
|
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
    """Build the model from its config, then load the saved trainable weights.

    The model is constructed with ``cls(config, **kwargs)`` and the checkpoint
    is applied with ``strict=False``, so only the keys present in the weights
    file are restored.

    Args:
        pretrained_model_name_or_path: Local directory or hub repo id.
        *args: Accepted for signature compatibility; unused.
        **kwargs: May contain ``config`` (used instead of loading one) and hub
            options such as ``subfolder``, ``revision``, ``cache_dir``,
            ``token``, ``local_files_only`` and ``force_download``.

    Returns:
        The constructed model. If no weights file is found, the model is
        returned exactly as constructed.
    """
    from safetensors.torch import load_file
    from transformers.utils.hub import cached_file

    config = kwargs.pop("config", None)
    if config is None:
        config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

    # Resolve the weights with the same hub options the config lookup received,
    # forwarding only the options that were actually supplied.
    hub_option_names = (
        "subfolder",
        "revision",
        "cache_dir",
        "token",
        "local_files_only",
        "force_download",
    )
    cache_kwargs = {name: kwargs[name] for name in hub_option_names if kwargs.get(name)}

    # save_pretrained() writes one of these two files depending on
    # ``safe_serialization``; safetensors is preferred when both exist.
    weight_filenames = ("model.safetensors", "pytorch_model.bin")

    cls._is_loading_from_pretrained = True
    cls._pretrained_model_path = pretrained_model_name_or_path

    try:
        model = cls(config, **kwargs)

        for filename in weight_filenames:
            weights_path = cached_file(
                pretrained_model_name_or_path,
                filename,
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )
            if weights_path is None:
                continue

            if filename.endswith(".safetensors"):
                state_dict = load_file(weights_path)
            else:
                # weights_only=True restricts unpickling to tensor data.
                state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
            model.load_state_dict(state_dict, strict=False)
            break

        return model
    finally:
        # Always clear the class-level markers, even if construction failed.
        cls._is_loading_from_pretrained = False
        cls._pretrained_model_path = None
| |
|
def __init__(self, config: ASRConfig, **kwargs) -> None:
    """Assemble the audio encoder, language model, tokenizer and projector.

    Construction order matters: the tokenizer setup resizes the language
    model's embeddings, generation defaults need the tokenizer, and the pad
    embedding needs ``config.llm_dim``, which ``_create_projector`` fills in
    when it is unset.

    Args:
        config: Model configuration.
        **kwargs: Accepted for compatibility with ``from_pretrained``; unused.
    """
    super().__init__(config)

    self.system_prompt = config.system_prompt
    target_dtype = getattr(torch, config.model_dtype)

    self.audio_tower = self._load_audio_encoder(config, target_dtype)
    self.language_model = self._load_language_model(config, target_dtype)
    self._init_tokenizer(config)
    self._apply_generation_defaults(config)

    self.feature_extractor = self._create_feature_extractor(config)
    self.projector = self._create_projector(config, target_dtype)

    # Embedding used by _encode_audio to pad samples that yield fewer audio
    # frames than the number of placeholder tokens.
    self.audio_pad_embedding = nn.Parameter(torch.randn(1, config.llm_dim) * 0.02)

    if getattr(config, "freeze_projector", False):
        self.projector.requires_grad_(False)

    if getattr(config, "use_specaugment", False):
        self.spec_augment = SpecAugment(
            n_time_masks=config.num_time_masks,
            time_mask_param=config.time_mask_length,
            n_freq_masks=config.num_freq_masks,
            freq_mask_param=config.freq_mask_length,
        )
    else:
        self.spec_augment = None

    if getattr(config, "use_audio_head", False):
        self.audio_head = self._create_audio_head(config, target_dtype)
    else:
        self.audio_head = None

    # The VAD model is loaded lazily by load_vad().
    self._vad_model = None
    self._vad_utils = None

    self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])

def _apply_generation_defaults(self, config: ASRConfig) -> None:
    """Adopt the language model's generation config and apply config overrides."""
    generation_config = self.language_model.generation_config
    self.generation_config = generation_config

    for option in (
        "max_new_tokens",
        "min_new_tokens",
        "num_beams",
        "do_sample",
        "temperature",
        "top_p",
        "top_k",
        "use_cache",
        "length_penalty",
        "repetition_penalty",
        "no_repeat_ngram_size",
    ):
        setattr(generation_config, option, getattr(config, option))

    # Only tokens that really exist in the vocabulary may act as EOS; a lookup
    # of an unknown token can yield None or the unknown-token id.
    vocab = self.tokenizer.get_vocab()
    eos_ids: list[int] = []
    for token in ("<|im_end|>", "<|endoftext|>"):
        if token not in vocab:
            continue
        token_id = self.tokenizer.convert_tokens_to_ids(token)
        if token_id is not None and token_id not in eos_ids:
            eos_ids.append(token_id)
    # With no candidate present, keep the language model's own EOS ids instead
    # of replacing them with an empty list.
    if eos_ids:
        generation_config.eos_token_id = eos_ids
    generation_config.pad_token_id = self.tokenizer.pad_token_id

def _create_audio_head(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
    """Build the optional audio head on the language model's device."""
    # Same relative-then-absolute import fallback as the module-level imports.
    try:
        from .audio_head import AudioHead, AudioHeadConfig
    except ImportError:
        from audio_head import AudioHead, AudioHeadConfig

    import gc

    device = next(self.language_model.parameters()).device

    audio_head_config = AudioHeadConfig(
        tts_model_id=getattr(config, "tts_model_id", "neuphonic/neutts-nano"),
        llm_model_id=config.text_model_id,
        projector_hidden=getattr(config, "audio_head_projector_hidden", 1024),
        max_audio_tokens=config.max_audio_tokens,
        neucodec_model_id=getattr(config, "neucodec_model_id", "neuphonic/neucodec"),
        temperature=getattr(config, "audio_head_temperature", 1.0),
        top_k=getattr(config, "audio_head_top_k", 50),
    )
    audio_head = AudioHead(audio_head_config).to(device=device, dtype=dtype)

    # Drop the head's own `llm` attribute and collect it right away.
    del audio_head.llm
    audio_head.llm = None
    gc.collect()

    if getattr(config, "freeze_audio_head", False):
        audio_head.requires_grad_(False)
    return audio_head
| |
|
| | def _tie_weights(self): |
| | """No-op: AudioHead manages its own embeddings.""" |
| | pass |
| |
|
def _create_feature_extractor(self, config: ASRConfig):
    """Load the feature extractor that belongs to ``config.audio_model_id``.

    Returns:
        The extractor, with its ``padding`` attribute set to ``False``.
    """
    from transformers import AutoFeatureExtractor

    extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
    setattr(extractor, "padding", False)
    return extractor
| |
|
@classmethod
def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
    """Load the audio encoder named by the config, frozen and in eval mode.

    The loader is chosen from the model id: ids containing "whisper" use the
    encoder of ``WhisperModel``, ids containing "glm" use the ``audio_tower``
    of a remote-code seq2seq model, and anything else goes through
    ``AutoModel``.

    Args:
        config: Supplies ``audio_model_id`` and ``attn_implementation``.
        dtype: Dtype passed to ``from_pretrained``.

    Returns:
        The encoder module with gradients disabled.
    """
    model_id = config.audio_model_id
    model_id_lower = model_id.lower()
    load_options = dict(
        attn_implementation=config.attn_implementation,
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
    )

    if "whisper" in model_id_lower:
        from transformers import WhisperModel

        whisper = WhisperModel.from_pretrained(model_id, **load_options)
        encoder = whisper.encoder
        del whisper
    elif "glm" in model_id_lower:
        from transformers import AutoModelForSeq2SeqLM

        wrapper = AutoModelForSeq2SeqLM.from_pretrained(
            model_id, trust_remote_code=True, **load_options
        )
        encoder = wrapper.audio_tower
        # Detach the parts that are not needed before dropping the wrapper.
        wrapper.language_model = None
        wrapper.multi_modal_projector = None
        del wrapper
    else:
        encoder = AutoModel.from_pretrained(model_id, **load_options)

    encoder.requires_grad_(False)
    encoder.eval()
    return encoder
| |
|
@classmethod
def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
    """Load the causal language model named by the config, frozen and in eval mode.

    Args:
        config: Supplies ``text_model_id``, ``attn_implementation`` and,
            optionally, ``use_cache`` (defaults to ``True``).
        dtype: Dtype passed to ``from_pretrained``.

    Returns:
        The language model with gradients disabled.
    """
    language_model = AutoModelForCausalLM.from_pretrained(
        config.text_model_id,
        attn_implementation=config.attn_implementation,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        dtype=dtype,
    )
    language_model.config.use_cache = getattr(config, "use_cache", True)
    language_model.requires_grad_(False)
    language_model.eval()
    return language_model
| |
|
def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
    """Instantiate the projector selected by ``config.projector_type``.

    Missing ``encoder_dim`` / ``llm_dim`` values are filled in on ``config``
    from the loaded sub-models' ``hidden_size`` or ``d_model``.

    Args:
        config: Model configuration; may be mutated as described above.
        dtype: Dtype the projector is cast to.

    Returns:
        The projector, placed on the language model's device.

    Raises:
        ValueError: If a dimension cannot be detected or the projector type
            is not registered in ``PROJECTOR_CLASSES``.
    """

    def _detect_width(sub_config):
        # First truthy value of hidden_size / d_model, else None.
        return getattr(sub_config, "hidden_size", None) or getattr(sub_config, "d_model", None)

    if config.encoder_dim is None:
        config.encoder_dim = _detect_width(self.audio_tower.config)
        if config.encoder_dim is None:
            raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")

    if config.llm_dim is None:
        config.llm_dim = _detect_width(self.language_model.config)
        if config.llm_dim is None:
            raise ValueError("Could not auto-detect llm_dim. Please specify in config.")

    projector_type = getattr(config, "projector_type", "mlp")
    projector_class = PROJECTOR_CLASSES.get(projector_type)
    if projector_class is None:
        raise ValueError(
            f"Unknown projector_type: {projector_type}. "
            f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
        )

    target_device = next(self.language_model.parameters()).device
    return projector_class(config).to(device=target_device, dtype=dtype)
| |
|
def _init_tokenizer(self, config: ASRConfig):
    """Load the tokenizer, register the ``<audio>`` token and sync token ids.

    Side effects: sets ``self.tokenizer`` and ``self.audio_token_id``, may
    resize the language model's token embeddings, and copies pad/eos/bos ids
    onto the text config, the language model config and the generation config.
    """
    tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)
    self.tokenizer = tokenizer

    # Use a dedicated pad token when none is set or when it equals EOS.
    pad_is_unusable = (
        tokenizer.pad_token is None or tokenizer.pad_token_id == tokenizer.eos_token_id
    )
    if pad_is_unusable and "<|finetune_right_pad_id|>" in tokenizer.get_vocab():
        tokenizer.pad_token = "<|finetune_right_pad_id|>"

    special_tokens = getattr(tokenizer, "additional_special_tokens", None) or []
    if "<audio>" not in special_tokens:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": special_tokens + ["<audio>"]}
        )
        self.language_model.resize_token_embeddings(
            len(tokenizer), mean_resizing=False, pad_to_multiple_of=64
        )

    self.audio_token_id = tokenizer.convert_tokens_to_ids("<audio>")
    tokenizer.padding_side = "right"

    targets = [self.config.text_config, self.language_model.config, self.generation_config]
    for target in targets:
        if target is None:
            continue
        target.pad_token_id = tokenizer.pad_token_id
        target.eos_token_id = tokenizer.eos_token_id
        target.bos_token_id = tokenizer.bos_token_id
| |
|
| | def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None): |
| | """Enable/disable gradient checkpointing for the language model.""" |
| | |
| | |
| | if hasattr(self.language_model, "_set_gradient_checkpointing"): |
| | self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func) |
| | elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable: |
| | self.language_model.gradient_checkpointing_enable( |
| | gradient_checkpointing_kwargs={"use_reentrant": False} |
| | ) |
| | elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable: |
| | self.language_model.gradient_checkpointing_disable() |
| |
|
def get_input_embeddings(self) -> nn.Module:
    """Return the input embedding module of the wrapped language model."""
    embedding_module = self.language_model.get_input_embeddings()
    return embedding_module
| |
|
def set_input_embeddings(self, value: nn.Module) -> None:
    """Install ``value`` as the input embedding module of the language model."""
    language_model = self.language_model
    language_model.set_input_embeddings(value)
| |
|
def get_output_embeddings(self) -> nn.Module:
    """Return the output embedding module of the wrapped language model."""
    output_module = self.language_model.get_output_embeddings()
    return output_module
| |
|
def set_output_embeddings(self, value: nn.Module) -> None:
    """Install ``value`` as the output embedding module of the language model."""
    language_model = self.language_model
    language_model.set_output_embeddings(value)
| |
|
def get_processor(self):
    """Build an ``ASRProcessor`` from this model's extractor, tokenizer and projector.

    The processor class is imported relatively first, falling back to an
    absolute import when the relative one raises ``ImportError``.
    """
    try:
        from .asr_processing import ASRProcessor
    except ImportError:
        from asr_processing import ASRProcessor

    processor_args = {
        "feature_extractor": self.feature_extractor,
        "tokenizer": self.tokenizer,
        "projector": self.projector,
        "encoder_conv_layers": self.config.encoder_conv_layers,
    }
    return ASRProcessor(**processor_args)
| |
|
| | |
| | |
| | |
| |
|
def load_vad(self, force_reload: bool = False) -> None:
    """Load the Silero VAD model through ``torch.hub`` and freeze it.

    The model and its utilities are stored on ``self._vad_model`` and
    ``self._vad_utils``. Nothing happens when a model is already loaded
    unless ``force_reload`` is set.

    Args:
        force_reload: Reload even if a model is already loaded; also passed
            through to ``torch.hub.load``.
    """
    already_loaded = self._vad_model is not None
    if already_loaded and not force_reload:
        return

    vad_model, vad_utils = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        force_reload=force_reload,
        trust_repo=True,
    )
    self._vad_model = vad_model
    self._vad_utils = vad_utils

    # Inference only: eval mode and no gradients.
    vad_model.eval()
    for weight in vad_model.parameters():
        weight.requires_grad = False
| |
|
def detect_speech(
    self,
    audio_chunk: torch.Tensor,
    sample_rate: int = 16000,
    threshold: float = 0.5,
) -> tuple[bool, float]:
    """Score one audio chunk with the VAD model.

    The VAD model is loaded on first use. Inputs with more than one dimension
    are squeezed, and any rate other than 8 kHz or 16 kHz is resampled to
    16 kHz before scoring.

    Args:
        audio_chunk: Waveform tensor, ``[samples]`` or ``[1, samples]``.
        sample_rate: Sample rate of ``audio_chunk``.
        threshold: Probability above which the chunk counts as speech.

    Returns:
        ``(is_speech, probability)``.
    """
    if self._vad_model is None:
        self.load_vad()

    waveform = audio_chunk.squeeze() if audio_chunk.dim() > 1 else audio_chunk

    accepted_rates = (8000, 16000)
    if sample_rate not in accepted_rates:
        import torchaudio.functional as audio_functional

        waveform = audio_functional.resample(waveform, sample_rate, 16000)
        sample_rate = 16000

    with torch.no_grad():
        probability = self._vad_model(waveform, sample_rate).item()

    is_speech = probability > threshold
    return is_speech, probability
| |
|
def reset_vad_state(self) -> None:
    """Reset the VAD model's internal state; do nothing if no model is loaded."""
    vad_model = self._vad_model
    if vad_model is None:
        return
    vad_model.reset_states()
| |
|
def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
    """Return only the trainable weights of the model.

    Contains the projector (prefix ``projector.``), the audio head when
    present (prefix ``audio_head.``) and the learnable ``audio_pad_embedding``.
    The pad embedding is included because it is an ``nn.Parameter`` that
    would otherwise be re-randomised on every load.

    Args:
        *args: Ignored; accepted for ``nn.Module.state_dict`` compatibility.
        **kwargs: Ignored; accepted for ``nn.Module.state_dict`` compatibility.

    Returns:
        Mapping from parameter name to tensor.
    """
    state = {f"projector.{key}": tensor for key, tensor in self.projector.state_dict().items()}

    if self.audio_head is not None:
        state.update(
            {f"audio_head.{key}": tensor for key, tensor in self.audio_head.state_dict().items()}
        )

    pad_embedding = getattr(self, "audio_pad_embedding", None)
    if pad_embedding is not None:
        # Detached, like the tensors produced by the sub-module state dicts.
        state["audio_pad_embedding"] = pad_embedding.detach()

    return state
| |
|
| | def _compute_encoder_output_lengths( |
| | self, |
| | audio_attention_mask: torch.Tensor, |
| | ) -> torch.Tensor: |
| | """Compute per-sample encoder output lengths using conv layer formulas. |
| | |
| | Args: |
| | audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len) |
| | |
| | Returns: |
| | Tensor of encoder output lengths per sample (batch,) |
| | """ |
| | |
| | lengths = audio_attention_mask.sum(dim=-1) |
| |
|
| | |
| | for padding, kernel_size, stride in self.config.encoder_conv_layers: |
| | lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1 |
| |
|
| | return lengths |
| |
|
| | def _encode_audio( |
| | self, |
| | audio_features: torch.Tensor, |
| | audio_attention_mask: torch.Tensor, |
| | expected_token_counts: torch.Tensor | None = None, |
| | ) -> torch.Tensor: |
| | """Encode audio and project to LLM embedding space. |
| | |
| | Args: |
| | audio_features: Mel spectrogram features (batch, n_mels, mel_len) |
| | audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len) |
| | expected_token_counts: Expected number of audio tokens per sample from input_ids. |
| | If provided, output will match these counts exactly (padding/truncating as needed). |
| | |
| | Returns: |
| | Flattened audio embeddings of shape (total_audio_tokens, hidden_dim). |
| | """ |
| | with torch.no_grad(): |
| | encoder_out = self.audio_tower(input_features=audio_features) |
| | hidden_states = encoder_out.last_hidden_state |
| |
|
| | |
| | audio_embeds = self.projector(hidden_states) |
| |
|
| | |
| | if expected_token_counts is not None: |
| | token_counts = expected_token_counts |
| | else: |
| | |
| | encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask) |
| | token_counts = torch.tensor( |
| | [ |
| | self.projector.get_output_length(int(length.item())) |
| | for length in encoder_lengths |
| | ], |
| | device=audio_embeds.device, |
| | ) |
| |
|
| | |
| | batch_size = audio_embeds.shape[0] |
| |
|
| | result_embeds = [] |
| | for i in range(batch_size): |
| | count = int(token_counts[i].item()) |
| | sample_embeds = audio_embeds[i, :count, :] |
| | |
| | if sample_embeds.shape[0] < count: |
| | pad_count = count - sample_embeds.shape[0] |
| | padding = self.audio_pad_embedding.expand(pad_count, -1).to( |
| | device=audio_embeds.device, dtype=audio_embeds.dtype |
| | ) |
| | sample_embeds = torch.cat([sample_embeds, padding], dim=0) |
| | result_embeds.append(sample_embeds) |
| |
|
| | return torch.cat(result_embeds, dim=0) |
| |
|
def forward(
    self,
    input_ids: Optional[torch.Tensor] = None,
    input_features: Optional[torch.Tensor] = None,
    audio_attention_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_values: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.Tensor] = None,
    **kwargs,
) -> CausalLMOutputWithPast:
    """Run the language model on text embeddings with audio embeddings spliced in.

    When both ``input_features`` and ``input_ids`` are given, every position
    holding the audio token id is overwritten with the projected audio
    embeddings. SpecAugment is applied first if the model is in training mode
    and augmentation is configured. The ``prompts`` and
    ``prompt_attention_mask`` keyword arguments are dropped before the
    language model is called.

    Returns:
        Whatever the wrapped language model returns for the assembled inputs.
    """
    if inputs_embeds is None:
        embed_tokens = self.language_model.get_input_embeddings()
        inputs_embeds = embed_tokens(input_ids)

    has_audio = input_features is not None and input_ids is not None
    if has_audio:
        if self.training and self.spec_augment is not None:
            input_features = self.spec_augment(input_features)

        # One comparison serves both the per-sample counts and the scatter mask.
        is_audio_token = input_ids == self.audio_token_id
        audio_embeds = self._encode_audio(
            input_features, audio_attention_mask, is_audio_token.sum(dim=-1)
        )

        inputs_embeds = inputs_embeds.masked_scatter(
            is_audio_token.unsqueeze(-1).to(inputs_embeds.device),
            audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
        )

    for dropped_key in ("prompts", "prompt_attention_mask"):
        kwargs.pop(dropped_key, None)

    return self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )
| |
|
def prepare_inputs_for_generation(self, *args, **kwargs):
    """Delegate to the language model and re-attach audio features on the first step.

    ``input_features`` is removed from ``kwargs`` before delegating and is put
    back into the result only when ``cache_position`` starts at 0.
    """
    input_features = kwargs.pop("input_features", None)
    cache_position = kwargs.get("cache_position")

    model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)

    if input_features is None or cache_position is None:
        return model_inputs

    if cache_position[0] == 0:
        model_inputs["input_features"] = input_features
    return model_inputs
| |
|
| | def _get_num_audio_tokens( |
| | self, |
| | audio_attention_mask: torch.Tensor, |
| | ) -> int: |
| | """Calculate number of audio tokens based on actual audio length. |
| | |
| | Uses attention mask to get real audio length, then computes: |
| | mel_frames -> encoder_frames (via conv formulas) -> projector output tokens |
| | """ |
| | encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask) |
| | |
| | encoder_output_len = int(encoder_lengths.max().item()) |
| | return int(self.projector.get_output_length(encoder_output_len)) |
| |
|
| | def _build_audio_prompt( |
| | self, |
| | audio_attention_mask: torch.Tensor, |
| | batch_size: int, |
| | device: torch.device, |
| | system_prompt: Optional[str] = None, |
| | ) -> tuple[torch.Tensor, torch.Tensor]: |
| | """Build input_ids and attention_mask for audio-conditioned generation. |
| | |
| | Args: |
| | audio_attention_mask: Mask for real vs padded mel frames |
| | batch_size: Batch size for expanding single prompts |
| | device: Device to place tensors on |
| | system_prompt: Optional system prompt override |
| | |
| | Returns: |
| | Tuple of (input_ids, attention_mask) tensors |
| | """ |
| | num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask) |
| | audio_placeholder = "<audio>" * num_audio_tokens |
| |
|
| | system_prompt = system_prompt or self.system_prompt |
| |
|
| | messages: list[dict[str, str]] = [] |
| | if system_prompt: |
| | messages.append({"role": "system", "content": system_prompt}) |
| | user_content = audio_placeholder |
| | if self.TRANSCRIBE_PROMPT: |
| | user_content += " " + self.TRANSCRIBE_PROMPT |
| | messages.append({"role": "user", "content": user_content}) |
| |
|
| | chat_result = self.tokenizer.apply_chat_template( |
| | messages, |
| | tokenize=True, |
| | add_generation_prompt=True, |
| | return_tensors="pt", |
| | enable_thinking=getattr(self.config, "enable_thinking", False), |
| | ) |
| | input_ids = chat_result.input_ids.to(device) |
| |
|
| | if input_ids.dim() == 1: |
| | input_ids = input_ids.unsqueeze(0) |
| | if input_ids.shape[0] == 1 and batch_size > 1: |
| | input_ids = input_ids.expand(batch_size, -1) |
| |
|
| | return input_ids, torch.ones_like(input_ids) |
| |
|
| | def _inject_audio_embeddings( |
| | self, |
| | input_ids: torch.Tensor, |
| | audio_embeds: torch.Tensor, |
| | ) -> torch.Tensor: |
| | """Replace audio token placeholders with actual audio embeddings. |
| | |
| | Args: |
| | input_ids: Token IDs containing <audio> placeholder tokens |
| | audio_embeds: Encoded audio embeddings to inject |
| | |
| | Returns: |
| | Input embeddings with audio tokens replaced by audio embeddings |
| | """ |
| | inputs_embeds = self.language_model.get_input_embeddings()(input_ids) |
| | audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1) |
| | return inputs_embeds.masked_scatter( |
| | audio_token_mask.to(inputs_embeds.device), |
| | audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype), |
| | ) |
| |
|
@torch.no_grad()
def generate(
    self,
    input_ids: Optional[torch.Tensor] = None,
    input_features: Optional[torch.Tensor] = None,
    audio_attention_mask: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    system_prompt: Optional[str] = None,
    **generate_kwargs,
) -> torch.Tensor:
    """Generate a transcription from audio input.

    Can be called in two ways:
    1. With ``input_ids`` that already contain audio placeholder tokens.
    2. With audio only, in which case the prompt is built internally.

    In both cases the audio embeddings are sized per sample to the number of
    placeholder tokens in ``input_ids`` (as ``forward`` does), so the scatter
    into the prompt always receives exactly as many embeddings as there are
    placeholders.

    Args:
        input_ids: Optional prompt token ids containing audio placeholders.
        input_features: Audio features; required.
        audio_attention_mask: Mask of real vs padded mel frames; required.
        attention_mask: Attention mask for ``input_ids``; replaced when the
            prompt is built internally.
        system_prompt: Optional system prompt for the internally built prompt.
        **generate_kwargs: Forwarded to the language model's ``generate``.

    Returns:
        The generated token ids with the prompt portion removed.

    Raises:
        ValueError: If ``input_features`` or ``audio_attention_mask`` is missing.
    """
    if input_features is None:
        raise ValueError("input_features required for generation")
    if audio_attention_mask is None:
        raise ValueError("audio_attention_mask required for generation")

    device = input_features.device
    batch_size = input_features.shape[0]

    # The prompt is needed before encoding, because it determines how many
    # audio embeddings each sample must provide.
    if input_ids is None:
        input_ids, attention_mask = self._build_audio_prompt(
            audio_attention_mask, batch_size, device, system_prompt
        )

    placeholder_counts = (input_ids == self.audio_token_id).sum(dim=-1)
    audio_embeds = self._encode_audio(input_features, audio_attention_mask, placeholder_counts)
    inputs_embeds = self._inject_audio_embeddings(input_ids, audio_embeds)

    output = self.language_model.generate(
        input_ids=input_ids,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        generation_config=self.generation_config,
        **generate_kwargs,
    )

    # generate() may return a plain tensor or an output object with
    # `.sequences`; either way, strip the prompt tokens.
    sequences = output if isinstance(output, torch.Tensor) else output.sequences
    prompt_length = input_ids.shape[1]
    return sequences[:, prompt_length:]
| |
|
| | def _process_audio( |
| | self, |
| | audio, |
| | sampling_rate: int = 16000, |
| | ) -> dict[str, torch.Tensor]: |
| | """Process raw audio waveform to model inputs.""" |
| | |
| | if isinstance(audio, torch.Tensor): |
| | audio = audio.cpu().numpy() |
| |
|
| | |
| | inputs = self.feature_extractor( |
| | audio, |
| | sampling_rate=sampling_rate, |
| | return_attention_mask=True, |
| | return_tensors="pt", |
| | ) |
| |
|
| | device = next(self.language_model.parameters()).device |
| | return { |
| | "input_features": inputs["input_features"].to(device), |
| | "attention_mask": inputs["attention_mask"].to(device), |
| | } |
| |
|
def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
    """Save config, trainable weights, tokenizer, feature extractor and source files.

    Written into ``save_directory`` (created if needed):
    - the model config, after syncing vocab size and mel-bin count from the
      loaded sub-models;
    - ``model.safetensors``, or ``pytorch_model.bin`` when
      ``safe_serialization=False``, holding ``self.state_dict()``;
    - tokenizer and feature extractor files;
    - ``preprocessor_config.json`` extended with the ``ASRProcessor`` mapping;
    - the Python modules next to this file that the saved model relies on.

    Args:
        save_directory: Target directory.
        **kwargs: Only ``safe_serialization`` (default ``True``) is used.
    """
    import shutil

    save_dir = Path(save_directory)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Keep the saved config in sync with the (possibly resized) language model.
    vocab_size = self.language_model.config.vocab_size
    self.config.vocab_size = vocab_size
    self.config.text_config.vocab_size = vocab_size

    if hasattr(self.audio_tower.config, "num_mel_bins"):
        self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins

    self.config.save_pretrained(save_dir)

    state_dict = self.state_dict()
    if kwargs.get("safe_serialization", True):
        from safetensors.torch import save_file

        save_file(state_dict, save_dir / "model.safetensors")
    else:
        torch.save(state_dict, save_dir / "pytorch_model.bin")

    self.tokenizer.save_pretrained(save_dir)
    self.feature_extractor.save_pretrained(save_dir)

    # Extend the feature extractor's config so AutoProcessor resolves to
    # ASRProcessor. Explicit UTF-8 keeps the JSON independent of the locale.
    config_path = save_dir / "preprocessor_config.json"
    processor_config = {}
    if config_path.exists():
        with config_path.open(encoding="utf-8") as f:
            processor_config = json.load(f)

    processor_config.update(
        {
            "processor_class": "ASRProcessor",
            "auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
        }
    )

    with config_path.open("w", encoding="utf-8") as f:
        json.dump(processor_config, f, indent=2)

    # Ship the source files alongside the weights.
    src_dir = Path(__file__).parent
    for asr_file in src_dir.glob("asr_*.py"):
        shutil.copy(asr_file, save_dir / asr_file.name)

    # These modules are copied unconditionally; a missing file raises.
    for required_name in ("projectors.py", "alignment.py", "diarization.py"):
        shutil.copy(src_dir / required_name, save_dir / required_name)

    # These modules are copied only when they exist.
    for optional_name in ("audio_head.py", "full_duplex.py"):
        optional_path = src_dir / optional_name
        if optional_path.exists():
            shutil.copy(optional_path, save_dir / optional_name)
| |
|
def push_to_hub(self, repo_id: str, **kwargs) -> str:
    """Record ``repo_id`` on the config, then push through the parent implementation.

    Returns:
        Whatever the parent ``push_to_hub`` returns.
    """
    setattr(self.config, "pretrained_model_path", repo_id)
    result = super().push_to_hub(repo_id, **kwargs)
    return result
| |
|
def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
    """Do nothing: per the original note, MODEL_CARD.md in the repo is used instead."""
    return None
| |
|
| |
|
| | |
# Register with the transformers Auto* factories: the "asr_model" model type
# maps to ASRConfig, and ASRConfig maps to ASRModel. The config is registered
# first, before the model that is keyed on it.
AutoConfig.register("asr_model", ASRConfig)
AutoModel.register(ASRConfig, ASRModel)
| |
|