import json
from pathlib import Path
from threading import Thread
from typing import Iterator, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    TextIteratorStreamer,
)
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from .asr_config import ASRConfig, compute_encoder_output_length
    from .projectors import PROJECTOR_CLASSES
except ImportError:
    from asr_config import ASRConfig, compute_encoder_output_length  # type: ignore[no-redef]
    from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]


def _resolve_attn_implementation(requested: Optional[str]) -> Optional[str]:
    """Coerce flash_attention_2 to sdpa when CUDA isn't available.

    FA2 is CUDA-only. On MPS/CPU, requesting it either errors at load or
    silently falls back to a slower path; either way the user pays the FA2
    install + import cost for no win. Coerce here so a saved config that
    pins flash_attention_2 still loads on Mac / CPU-only Linux boxes.
    """
    if requested == "flash_attention_2" and not torch.cuda.is_available():
        return "sdpa"
    return requested


def _gather_audio_embeds(audio_embeds: torch.Tensor, token_counts: torch.Tensor) -> torch.Tensor:
    """Flatten per-sample audio embeddings into a packed tensor.

    For each row i, takes the first ``token_counts[i]`` rows of
    ``audio_embeds[i]`` and concatenates them. If any token count exceeds
    ``audio_embeds.shape[1]``, the deficit is zero-padded. Equivalent to a
    per-sample slice/cat loop but with O(1) host-device syncs per call (one
    ``max().item()``) instead of one per sample.

    Args:
        audio_embeds: 3-D tensor ``(batch, max_len, dim)`` (rank is fixed by
            the three-way shape unpack below).
        token_counts: per-row valid lengths, one entry per batch row. It may
            live on a different device; it is moved to ``audio_embeds.device``
            before the mask comparison (a no-op when already there).

    Returns:
        A 2-D ``(total_tokens, dim)`` tensor with the selected rows in batch
        order. An empty batch yields a ``(0, dim)`` tensor.
    """
    # Empty batch: ``.max()`` of an empty tensor raises, so short-circuit with
    # an empty result of the right width/dtype/device.
    if token_counts.numel() == 0:
        return audio_embeds.new_zeros((0, audio_embeds.shape[-1]))

    _, max_len, _ = audio_embeds.shape
    # Read the max on token_counts' own device (single sync) before moving it.
    needed = int(token_counts.max().item())
    if needed > max_len:
        # Zero-pad the time axis so every requested row exists.
        audio_embeds = F.pad(audio_embeds, (0, 0, 0, needed - max_len))
        max_len = needed

    counts = token_counts.to(audio_embeds.device)
    # (1, max_len) vs (batch, 1) broadcasts to a (batch, max_len) keep-mask.
    positions = torch.arange(max_len, device=audio_embeds.device).unsqueeze(0)
    keep = positions < counts.unsqueeze(1)
    # Boolean-mask indexing over the first two dims packs rows in batch order.
    return audio_embeds[keep]


class ASRModel(PreTrainedModel, GenerationMixin):
    """Audio-to-text model combining an audio encoder, projector, and language model."""

    config_class = ASRConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    _is_loading_from_pretrained: bool = False

    TRANSCRIBE_PROMPT = "Transcribe the speech to text"

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
        """Load model from pretrained, handling device placement correctly.

        Steps:
        1. Build the model from its config.
        2. Overlay any weights found in ``model.safetensors`` with
           ``strict=False``.
        3. If ``config.use_lora`` is set, load saved PEFT adapters when
           ``adapter_config.json`` exists. Otherwise wrap the language model
           in a fresh LoRA config built from the ``config.lora_*`` fields.

        Args:
            pretrained_model_name_or_path: Hub repo id or local directory.
            *args: Accepted for signature compatibility; unused.
            **kwargs: ``config`` may supply a pre-built ``ASRConfig``. Every
                other kwarg is forwarded to ``ASRConfig.from_pretrained`` (when
                no config is given) and to ``__init__``. The Hub options
                ``subfolder``, ``revision``, ``cache_dir``, ``token`` and
                ``local_files_only`` are also forwarded to the weight and
                adapter file lookups.

        Returns:
            The constructed ``ASRModel``.
        """
        from safetensors.torch import load_file
        from transformers.utils.hub import cached_file

        config = kwargs.pop("config", None)
        if config is None:
            config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Forward Hub/cache options to every file lookup below, not only the
        # config load. Gated/private repos (token), custom caches (cache_dir)
        # and offline loads (local_files_only) then resolve the weights and
        # adapters from the same place as the config. Falsy values are dropped,
        # matching the truthiness checks previously applied to subfolder/revision.
        cache_kwargs = {
            key: kwargs[key]
            for key in ("subfolder", "revision", "cache_dir", "token", "local_files_only")
            if kwargs.get(key)
        }

        # Set flag to avoid device_map="auto" in sub-model loaders; __init__
        # also reads it to skip its own LoRA setup (handled below instead).
        cls._is_loading_from_pretrained = True
        try:
            model = cls(config, **kwargs)

            # Load projector weights from safetensors
            model_file = cached_file(
                pretrained_model_name_or_path,
                "model.safetensors",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )
            if model_file is not None:
                state_dict = load_file(model_file)
                # strict=False: the checkpoint need not cover every parameter.
                model.load_state_dict(state_dict, strict=False)

            # Load LoRA adapters if use_lora is enabled
            if getattr(config, "use_lora", False):
                # Check for adapter_config.json (required by PEFT to load adapters)
                adapter_config_file = cached_file(
                    pretrained_model_name_or_path,
                    "adapter_config.json",
                    _raise_exceptions_for_missing_entries=False,
                    **cache_kwargs,
                )
                if adapter_config_file is not None:
                    # Load saved adapter weights using the original repo_id/path.
                    # PEFT handles Hub downloads and caching internally.
                    from peft import PeftModel

                    model.language_model = PeftModel.from_pretrained(
                        model.language_model,
                        pretrained_model_name_or_path,
                        is_trainable=True,
                        **cache_kwargs,
                    )
                else:
                    # No saved adapters - initialize fresh LLM LoRA for training
                    from peft import LoraConfig, get_peft_model

                    lora_config = LoraConfig(
                        r=config.lora_rank,
                        lora_alpha=config.lora_alpha,
                        target_modules=config.lora_target_modules,
                        lora_dropout=config.lora_dropout,
                        bias="none",
                        task_type="CAUSAL_LM",
                    )
                    model.language_model = get_peft_model(model.language_model, lora_config)

            return model
        finally:
            # Always clear the class-level flag, even if loading failed.
            cls._is_loading_from_pretrained = False

    def __init__(self, config: ASRConfig, **kwargs) -> None:
        super().__init__(config)

        self.system_prompt = config.system_prompt
        target_dtype = getattr(torch, config.model_dtype)

        # Audio encoder (frozen)
        self.audio_tower = self._load_audio_encoder(config, target_dtype)

        # Language model (frozen)
        self.language_model = self._load_language_model(config, target_dtype)

        # Initialize tokenizer and special tokens
        self._init_tokenizer(config)

        # Set up generation config with greedy decoding defaults
        self.generation_config = self.language_model.generation_config
        self.generation_config.max_new_tokens = config.max_new_tokens
        self.generation_config.min_new_tokens = config.min_new_tokens
        self.generation_config.num_beams = config.num_beams
        self.generation_config.do_sample = config.do_sample
        # Set sampling params from config (None means use model defaults)
        self.generation_config.temperature = config.temperature
        self.generation_config.top_p = config.top_p
        self.generation_config.top_k = config.top_k
        self.generation_config.use_cache = config.use_cache
        self.generation_config.length_penalty = config.length_penalty
        self.generation_config.repetition_penalty = config.repetition_penalty
self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size # Set EOS tokens, filtering out any that don't exist in the tokenizer eos_candidates = [ self.tokenizer.convert_tokens_to_ids("<|im_end|>"), self.tokenizer.convert_tokens_to_ids("<|endoftext|>"), ] self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None] self.generation_config.pad_token_id = self.tokenizer.pad_token_id # Feature extractor for audio preprocessing self.feature_extractor = self._create_feature_extractor(config) # Audio projector (trainable unless freeze_projector is set) self.projector = self._create_projector(config, target_dtype) # Setup LoRA if enabled (Stage 2 fine-tuning) # Skip if loading from pretrained - from_pretrained will handle adapter loading if getattr(config, "use_lora", False) and not getattr( self.__class__, "_is_loading_from_pretrained", False ): self._setup_lora(config) # Freeze projector if specified (for Stage 2 LoRA-only training) if getattr(config, "freeze_projector", False): self.projector.requires_grad_(False) # Freeze the text-vocab embedding table (preserves base Qwen3's # token→embedding mapping during joint fine-tune). With # tie_word_embeddings=True the same tensor backs lm_head, so this # also freezes the output projection. Audio tokens bypass this # table — they're scattered into inputs_embeds via masked_scatter # at