import json from pathlib import Path from threading import Thread from typing import Iterator, Optional, Union import torch import torch.nn as nn from transformers import ( AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TextIteratorStreamer, ) from transformers.generation import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast try: from .asr_config import ASRConfig from .projectors import PROJECTOR_CLASSES except ImportError: from asr_config import ASRConfig # type: ignore[no-redef] from projectors import PROJECTOR_CLASSES # type: ignore[no-redef] def _compute_mask_indices( shape: tuple[int, int], mask_prob: float, mask_length: int, min_masks: int = 0, device: torch.device = None, ) -> torch.Tensor: """Compute random mask spans for SpecAugment. Based on transformers' _compute_mask_indices for Wav2Vec2/Whisper. Args: shape: (batch_size, sequence_length) mask_prob: Probability for each token to be chosen as start of mask span mask_length: Maximum length of mask span min_masks: Minimum number of masks per sample device: Device to create tensor on Returns: Boolean mask tensor of shape (batch_size, sequence_length) """ batch_size, sequence_length = shape if mask_length < 1: raise ValueError(f"mask_length must be >= 1, got {mask_length}") if mask_length > sequence_length: raise ValueError(f"mask_length {mask_length} must be <= sequence_length {sequence_length}") # Compute number of masked spans per sample num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand(1).item()) num_masked_spans = max(num_masked_spans, min_masks) # Clamp to ensure we don't exceed sequence length if num_masked_spans * mask_length > sequence_length: num_masked_spans = sequence_length // mask_length if num_masked_spans == 0: return torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device) # Uniformly sample span start indices mask = torch.zeros((batch_size, sequence_length), dtype=torch.bool, device=device) for i in range(batch_size): # Random start indices for this sample spec_aug_start_indices = torch.randint( 0, sequence_length - mask_length + 1, (num_masked_spans,), device=device ) # Create mask spans for start_idx in spec_aug_start_indices: mask[i, start_idx : start_idx + mask_length] = True return mask def apply_specaugment( input_features: torch.Tensor, mask_time_prob: float = 0.05, mask_time_length: int = 10, mask_time_min_masks: int = 2, mask_feature_prob: float = 0.0, mask_feature_length: int = 10, mask_feature_min_masks: int = 0, ) -> torch.Tensor: """Apply SpecAugment to mel spectrogram features. Args: input_features: Mel spectrogram of shape (batch, n_mels, time) mask_time_prob: Probability of masking time steps mask_time_length: Max length of time mask mask_time_min_masks: Min number of time masks mask_feature_prob: Probability of masking frequency bins mask_feature_length: Max length of frequency mask mask_feature_min_masks: Min number of frequency masks Returns: Augmented mel spectrogram with same shape """ batch_size, n_mels, time_steps = input_features.shape device = input_features.device # Clone to avoid modifying original augmented = input_features.clone() # Time masking (along time dimension) # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0) if mask_time_prob > 0 or mask_time_min_masks > 0: time_mask = _compute_mask_indices( shape=(batch_size, time_steps), mask_prob=mask_time_prob, mask_length=mask_time_length, min_masks=mask_time_min_masks, device=device, ) # Expand to (batch, 1, time) for broadcasting time_mask = time_mask.unsqueeze(1) augmented = augmented.masked_fill(time_mask, 0.0) # Frequency masking (along mel dimension) # Apply if prob > 0 OR min_masks > 0 (to support fixed mask count with prob=0) if mask_feature_prob > 0 or mask_feature_min_masks > 0: feature_mask = _compute_mask_indices( shape=(batch_size, n_mels), mask_prob=mask_feature_prob, mask_length=mask_feature_length, min_masks=mask_feature_min_masks, device=device, ) # Expand to (batch, n_mels, 1) for broadcasting feature_mask = feature_mask.unsqueeze(2) augmented = augmented.masked_fill(feature_mask, 0.0) return augmented class ASRModel(PreTrainedModel, GenerationMixin): """Audio-to-text model combining an audio encoder, projector, and language model.""" config_class = ASRConfig base_model_prefix = "model" main_input_name = "input_features" _supports_flash_attn_2 = True supports_gradient_checkpointing = True _is_loading_from_pretrained: bool = False _pretrained_model_path: Optional[str] = None TRANSCRIBE_PROMPT = "Transcribe: " @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs): """Load model from pretrained, handling device placement correctly.""" from safetensors.torch import load_file from transformers.utils.hub import cached_file config = kwargs.pop("config", None) if config is None: config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) # Set flag to avoid device_map="auto" in sub-model loaders cls._is_loading_from_pretrained = True cls._pretrained_model_path = pretrained_model_name_or_path try: model = cls(config, **kwargs) # Load projector weights from safetensors subfolder = kwargs.get("subfolder") revision = kwargs.get("revision") cache_kwargs = {} if subfolder: cache_kwargs["subfolder"] = subfolder if revision: cache_kwargs["revision"] = revision model_file = cached_file( pretrained_model_name_or_path, "model.safetensors", _raise_exceptions_for_missing_entries=False, **cache_kwargs, ) if model_file is not None: state_dict = load_file(model_file) model.load_state_dict(state_dict, strict=False) return model finally: cls._is_loading_from_pretrained = False cls._pretrained_model_path = None def __init__(self, config: ASRConfig, **kwargs): super().__init__(config) self.system_prompt = config.system_prompt target_dtype = getattr(torch, config.model_dtype) # Audio encoder (frozen) self.audio_tower = self._load_audio_encoder(config, target_dtype) # Language model (frozen) self.language_model = self._load_language_model(config, target_dtype) # Initialize tokenizer and special tokens self._init_tokenizer(config) # Set up generation config with greedy decoding defaults self.generation_config = self.language_model.generation_config self.generation_config.max_new_tokens = config.max_new_tokens self.generation_config.min_new_tokens = config.min_new_tokens self.generation_config.num_beams = config.num_beams self.generation_config.do_sample = False # Clear sampling params (inherited from LLM) since we use greedy decoding self.generation_config.temperature = None self.generation_config.top_p = None self.generation_config.top_k = None self.generation_config.use_cache = config.use_cache self.generation_config.length_penalty = config.length_penalty self.generation_config.repetition_penalty = config.repetition_penalty self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size self.generation_config.eos_token_id = [ self.tokenizer.convert_tokens_to_ids("<|im_end|>"), self.tokenizer.convert_tokens_to_ids("<|endoftext|>"), ] self.generation_config.pad_token_id = self.tokenizer.pad_token_id # Feature extractor for audio preprocessing self.feature_extractor = self._create_feature_extractor(config) # Audio projector (trainable) self.projector = self._create_projector(config, target_dtype) # For model parallelism self._no_split_modules = getattr(self.language_model, "_no_split_modules", []) def _create_feature_extractor(self, config: ASRConfig): """Create the appropriate feature extractor for the audio encoder.""" from transformers import AutoFeatureExtractor return AutoFeatureExtractor.from_pretrained(config.audio_model_id) @classmethod def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module: """Load and freeze the audio encoder.""" encoder_kwargs = { "attn_implementation": config.attn_implementation, "low_cpu_mem_usage": True, "dtype": dtype, } if "whisper" in config.audio_model_id.lower(): from transformers import WhisperModel full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs) encoder = full_model.encoder del full_model elif "glm" in config.audio_model_id.lower(): # GLM-ASR models use audio_tower as the encoder # Requires transformers >= 5.x or installed from source from transformers import AutoModelForSeq2SeqLM full_model = AutoModelForSeq2SeqLM.from_pretrained( config.audio_model_id, trust_remote_code=True, **encoder_kwargs ) # GLM stores encoder at audio_tower (GlmAsrEncoder) encoder = full_model.audio_tower # Clear references to free VRAM from the LLM decoder full_model.language_model = None full_model.multi_modal_projector = None del full_model if torch.cuda.is_available(): torch.cuda.empty_cache() else: encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs) encoder.requires_grad_(False) encoder.eval() return encoder @classmethod def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel: """Load and freeze the language model.""" decoder_kwargs = { "attn_implementation": config.attn_implementation, "trust_remote_code": True, "tie_word_embeddings": False, "low_cpu_mem_usage": True, "dtype": dtype, } decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs) decoder.config.use_cache = getattr(config, "use_cache", True) decoder.requires_grad_(False) decoder.eval() return decoder def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module: """Create the trainable audio projector.""" # Auto-detect dimensions if not specified if config.encoder_dim is None: enc_cfg = self.audio_tower.config config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr( enc_cfg, "d_model", None ) if config.encoder_dim is None: raise ValueError("Could not auto-detect encoder_dim. Please specify in config.") if config.llm_dim is None: dec_cfg = self.language_model.config config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr( dec_cfg, "d_model", None ) if config.llm_dim is None: raise ValueError("Could not auto-detect llm_dim. Please specify in config.") # Select projector type based on config projector_type = getattr(config, "projector_type", "mlp") projector_class = PROJECTOR_CLASSES.get(projector_type) if projector_class is None: raise ValueError( f"Unknown projector_type: {projector_type}. " f"Valid options: {list(PROJECTOR_CLASSES.keys())}" ) projector = projector_class(config) # Move projector to same device as language model (important when using quantization) device = next(self.language_model.parameters()).device return projector.to(device=device, dtype=dtype) def _init_tokenizer(self, config: ASRConfig): """Initialize tokenizer with audio token.""" self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True) # Set pad token if ( self.tokenizer.pad_token is None or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab(): self.tokenizer.pad_token = "<|finetune_right_pad_id|>" # Add audio token existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or [] if "