import json from pathlib import Path from threading import Thread from typing import Iterator, Optional, Union import torch import torch.nn as nn from transformers import ( AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, TextIteratorStreamer, ) from transformers.generation import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast try: from .asr_config import ASRConfig from .projectors import PROJECTOR_CLASSES except ImportError: from asr_config import ASRConfig # type: ignore[no-redef] from projectors import PROJECTOR_CLASSES # type: ignore[no-redef] from torchaudio.transforms import SpecAugment class ASRModel(PreTrainedModel, GenerationMixin): """Audio-to-text model combining an audio encoder, projector, and language model.""" config_class = ASRConfig base_model_prefix = "model" main_input_name = "input_features" _supports_flash_attn_2 = True supports_gradient_checkpointing = True _is_loading_from_pretrained: bool = False _pretrained_model_path: Optional[str] = None TRANSCRIBE_PROMPT = "Transcribe: " @classmethod def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel": """Load model from pretrained, handling device placement correctly.""" from safetensors.torch import load_file from transformers.utils.hub import cached_file config = kwargs.pop("config", None) if config is None: config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) # Set flag to avoid device_map="auto" in sub-model loaders cls._is_loading_from_pretrained = True cls._pretrained_model_path = pretrained_model_name_or_path try: model = cls(config, **kwargs) # Load projector weights from safetensors subfolder = kwargs.get("subfolder") revision = kwargs.get("revision") cache_kwargs = {} if subfolder: cache_kwargs["subfolder"] = subfolder if revision: cache_kwargs["revision"] = revision model_file = cached_file( pretrained_model_name_or_path, "model.safetensors", _raise_exceptions_for_missing_entries=False, **cache_kwargs, ) if model_file is not None: state_dict = load_file(model_file) model.load_state_dict(state_dict, strict=False) # Load LoRA adapters if use_lora is enabled if getattr(config, "use_lora", False): # Check for adapter_config.json (required by PEFT to load adapters) adapter_config_file = cached_file( pretrained_model_name_or_path, "adapter_config.json", _raise_exceptions_for_missing_entries=False, **cache_kwargs, ) if adapter_config_file is not None: # Load saved adapter weights using the original repo_id/path # PEFT handles Hub downloads and caching internally from peft import LoraConfig, PeftModel # Pre-load and fix the adapter config to avoid str(None) -> "None" bug. # Some PEFT/transformers versions convert null to "None" string which # causes HF to try loading a model called "None". with open(adapter_config_file) as f: adapter_config_dict = json.load(f) # Fix base_model_name_or_path if it's None/null if adapter_config_dict.get("base_model_name_or_path") is None: adapter_config_dict["base_model_name_or_path"] = "" # Create LoraConfig from the fixed dict peft_config = LoraConfig(**adapter_config_dict) # language_model is bare (not PEFT-wrapped) since we skipped _setup_lora model.language_model = PeftModel.from_pretrained( model.language_model, pretrained_model_name_or_path, # Use original repo_id, not cache path is_trainable=True, config=peft_config, # Use our fixed config **cache_kwargs, ) else: # No saved adapters - initialize fresh LLM LoRA for training from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=config.lora_rank, lora_alpha=config.lora_alpha, target_modules=config.lora_target_modules, lora_dropout=config.lora_dropout, bias="none", task_type="CAUSAL_LM", ) model.language_model = get_peft_model(model.language_model, lora_config) # Clear base_model_name_or_path so PEFT doesn't save a reference # to the base LLM. Use empty string to avoid str(None) -> "None" bug. model.language_model.peft_config["default"].base_model_name_or_path = "" return model finally: cls._is_loading_from_pretrained = False cls._pretrained_model_path = None def __init__(self, config: ASRConfig, **kwargs) -> None: super().__init__(config) self.system_prompt = config.system_prompt target_dtype = getattr(torch, config.model_dtype) # Audio encoder (frozen) self.audio_tower = self._load_audio_encoder(config, target_dtype) # Language model (frozen) self.language_model = self._load_language_model(config, target_dtype) # Initialize tokenizer and special tokens self._init_tokenizer(config) # Set up generation config with greedy decoding defaults self.generation_config = self.language_model.generation_config self.generation_config.max_new_tokens = config.max_new_tokens self.generation_config.min_new_tokens = config.min_new_tokens self.generation_config.num_beams = config.num_beams self.generation_config.do_sample = False # Clear sampling params (inherited from LLM) since we use greedy decoding self.generation_config.temperature = None self.generation_config.top_p = None self.generation_config.top_k = None self.generation_config.use_cache = config.use_cache self.generation_config.length_penalty = config.length_penalty self.generation_config.repetition_penalty = config.repetition_penalty self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size self.generation_config.eos_token_id = [ self.tokenizer.convert_tokens_to_ids("<|im_end|>"), self.tokenizer.convert_tokens_to_ids("<|endoftext|>"), ] self.generation_config.pad_token_id = self.tokenizer.pad_token_id # Feature extractor for audio preprocessing self.feature_extractor = self._create_feature_extractor(config) # Audio projector (trainable unless freeze_projector is set) self.projector = self._create_projector(config, target_dtype) # Setup LoRA if enabled (Stage 2 fine-tuning) # Skip if loading from pretrained - from_pretrained will handle adapter loading if getattr(config, "use_lora", False) and not getattr( self.__class__, "_is_loading_from_pretrained", False ): self._setup_lora(config) # Freeze projector if specified (for Stage 2 LoRA-only training) if getattr(config, "freeze_projector", False): self.projector.requires_grad_(False) # SpecAugment for data augmentation during training if getattr(config, "use_specaugment", False): self.spec_augment = SpecAugment( n_time_masks=config.num_time_masks, time_mask_param=config.time_mask_length, n_freq_masks=config.num_freq_masks, freq_mask_param=config.freq_mask_length, ) else: self.spec_augment = None # For model parallelism self._no_split_modules = getattr(self.language_model, "_no_split_modules", []) def _create_feature_extractor(self, config: ASRConfig): """Create the appropriate feature extractor for the audio encoder.""" from transformers import AutoFeatureExtractor return AutoFeatureExtractor.from_pretrained(config.audio_model_id) @classmethod def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module: """Load and freeze the audio encoder.""" encoder_kwargs = { "attn_implementation": config.attn_implementation, "low_cpu_mem_usage": True, "dtype": dtype, } if "whisper" in config.audio_model_id.lower(): from transformers import WhisperModel full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs) encoder = full_model.encoder del full_model elif "glm" in config.audio_model_id.lower(): # GLM-ASR models use audio_tower as the encoder # Requires transformers >= 5.x or installed from source from transformers import AutoModelForSeq2SeqLM full_model = AutoModelForSeq2SeqLM.from_pretrained( config.audio_model_id, trust_remote_code=True, **encoder_kwargs ) # GLM stores encoder at audio_tower (GlmAsrEncoder) encoder = full_model.audio_tower # Clear references to free VRAM from the LLM decoder full_model.language_model = None full_model.multi_modal_projector = None del full_model else: encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs) encoder.requires_grad_(False) encoder.eval() return encoder @classmethod def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel: """Load and freeze the language model.""" decoder_kwargs = { "attn_implementation": config.attn_implementation, "trust_remote_code": True, "tie_word_embeddings": False, "low_cpu_mem_usage": True, "dtype": dtype, } decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs) decoder.config.use_cache = getattr(config, "use_cache", True) decoder.requires_grad_(False) decoder.eval() return decoder def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module: """Create the trainable audio projector.""" # Auto-detect dimensions if not specified if config.encoder_dim is None: enc_cfg = self.audio_tower.config config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr( enc_cfg, "d_model", None ) if config.encoder_dim is None: raise ValueError("Could not auto-detect encoder_dim. Please specify in config.") if config.llm_dim is None: dec_cfg = self.language_model.config config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr( dec_cfg, "d_model", None ) if config.llm_dim is None: raise ValueError("Could not auto-detect llm_dim. Please specify in config.") # Select projector type based on config projector_type = getattr(config, "projector_type", "mlp") projector_class = PROJECTOR_CLASSES.get(projector_type) if projector_class is None: raise ValueError( f"Unknown projector_type: {projector_type}. " f"Valid options: {list(PROJECTOR_CLASSES.keys())}" ) projector = projector_class(config) # Move projector to same device as language model (important when using quantization) device = next(self.language_model.parameters()).device return projector.to(device=device, dtype=dtype) def _setup_lora(self, config: ASRConfig): """Apply LoRA adapters to the language model for Stage 2 fine-tuning.""" from peft import LoraConfig, get_peft_model lora_config = LoraConfig( r=config.lora_rank, lora_alpha=config.lora_alpha, target_modules=config.lora_target_modules, lora_dropout=config.lora_dropout, bias="none", task_type="CAUSAL_LM", ) self.language_model = get_peft_model(self.language_model, lora_config) # Clear base_model_name_or_path so PEFT doesn't save a reference to the # base LLM (e.g. Qwen). This prevents pipeline() from redirecting to the # wrong model. Use empty string to avoid str(None) -> "None" bug. self.language_model.peft_config["default"].base_model_name_or_path = "" def _init_tokenizer(self, config: ASRConfig): """Initialize tokenizer with audio token.""" self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True) # Set pad token if ( self.tokenizer.pad_token is None or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab(): self.tokenizer.pad_token = "<|finetune_right_pad_id|>" # Add audio token existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or [] if "