import torch
import torch.nn as nn
import numpy as np
import whisper
import librosa
from transformers import (
    PreTrainedModel,
    PretrainedConfig,
    WhisperModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    Blip2QFormerConfig,
    Blip2QFormerModel,
)
from peft import LoraConfig, get_peft_model, PeftType


class SMTConfig(PretrainedConfig):
    model_type = "smt_model"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class EncoderProjectorQFormer(nn.Module):
    """Q-Former bridge that maps Whisper encoder states into the LLM embedding space."""

    def __init__(self):
        super().__init__()
        self.encoder_dim = 1280  # Whisper large-v3 encoder hidden size
        self.llm_dim = 3584      # GemmaX2-28-9B hidden size
        configuration = Blip2QFormerConfig()
        configuration.encoder_hidden_size = self.encoder_dim
        configuration.num_hidden_layers = 8

        # 80 learnable query vectors; the Q-Former compresses the (up to
        # 1500-frame) encoder sequence down to this fixed length.
        self.query_len = 80
        self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
        self.query.data.normal_(mean=0.0, std=1.0)
        self.qformer = Blip2QFormerModel(configuration)

        # Two-layer MLP projects the Q-Former outputs to the LLM dimension.
        self.linear1 = nn.Linear(configuration.hidden_size, 2560)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(2560, self.llm_dim)
        self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)

    def forward(self, x, atts):
        query = self.query.expand(x.shape[0], -1, -1)
        query_output = self.qformer(
            query_embeds=query,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )
        x = self.linear2(self.relu(self.linear1(query_output.last_hidden_state)))
        return self.norm(x)
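

# A minimal, self-contained shape check for the projector (an illustrative
# sketch, not part of the original module): random tensors stand in for real
# Whisper features, and the function only runs when called explicitly.
def _demo_projector_shapes():
    projector = EncoderProjectorQFormer()
    enc_out = torch.randn(2, 1500, 1280)          # (batch, frames, encoder_dim)
    mask = torch.ones(2, 1500, dtype=torch.long)  # all frames marked valid
    out = projector(enc_out, mask)
    assert out.shape == (2, 80, 3584)             # (batch, query_len, llm_dim)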
|
|
class SMTModel(PreTrainedModel):
    config_class = SMTConfig
    base_model_prefix = "model"

    def __init__(self, config):
        super().__init__(config)

        # Speech encoder: the encoder half of Whisper large-v3.
        self.encoder = WhisperModel.from_pretrained(
            "yxdu/whisper-large-v3-encoder",
            torch_dtype=torch.bfloat16,
        ).encoder

        # Translation LLM, fine-tuned with LoRA adapters on the attention
        # query/value projections.
        base_llm = AutoModelForCausalLM.from_pretrained(
            "ModelSpace/GemmaX2-28-9B-v0.1",
            torch_dtype=torch.bfloat16,
        )
        peft_config = LoraConfig(
            peft_type=PeftType.LORA,
            task_type="CAUSAL_LM",
            r=16,
            target_modules=["q_proj", "v_proj"],
            lora_alpha=32,
            lora_dropout=0.05,
        )
        self.llm = get_peft_model(base_llm, peft_config)
        self.tokenizer = AutoTokenizer.from_pretrained("ModelSpace/GemmaX2-28-9B-v0.1")
        self.encoder_projector = EncoderProjectorQFormer()
|
|
    def _prepare_audio(self, audio_input):
        """Accept either a file path (str) or an HF Audio feature (dict)."""
        if isinstance(audio_input, str):
            return whisper.load_audio(audio_input)

        if isinstance(audio_input, dict):
            audio_array = audio_input["array"]
            sr = audio_input["sampling_rate"]
            # Whisper expects 16 kHz input; resample anything else.
            if sr != 16000:
                audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
            return audio_array.astype(np.float32)

        raise ValueError("audio_input must be a file path or an HF Audio dict")
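
    # Example inputs (illustrative, not from the original):
    #   self._prepare_audio("clip.wav")
    #   self._prepare_audio({"array": arr, "sampling_rate": 44100})  # resampled to 16 kHz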
|
|
    @torch.no_grad()
    def translate_batch(self, audio_inputs, prompts, max_new_tokens=300):
        """
        Core batched-inference logic.
        audio_inputs: List[str] or List[dict]
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype

        # Compute log-mel features, tracking each clip's valid (unpadded)
        # length in encoder frames.
        mels = []
        valid_lengths = []
        for audio_in in audio_inputs:
            audio = self._prepare_audio(audio_in)
            # Whisper uses a 160-sample hop (10 ms) and a stride-2 conv, so
            # the encoder yields one state per ~20 ms, at most 1500 per 30 s.
            post_len = (len(audio) // 160 + 1) // 2
            valid_lengths.append(min(post_len, 1500))

            audio = whisper.pad_or_trim(audio)  # pad or cut to 30 s
            mel = whisper.log_mel_spectrogram(audio, n_mels=128).unsqueeze(0)
            mels.append(mel)

        mels = torch.cat(mels, dim=0).to(device).to(dtype)

        # Encode, then mask out states that come from padding.
        enc_out = self.encoder(mels).last_hidden_state
        enc_mask = torch.zeros(enc_out.size()[:-1], dtype=torch.long, device=device)
        for i, length in enumerate(valid_lengths):
            enc_mask[i, :length] = 1

        audio_embeds = self.encoder_projector(enc_out, enc_mask)

        # Left-pad the prompts so generation continues from the last real
        # token of every sequence in the batch.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        text_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(device)
        # Reach through the PEFT wrapper to the base model's embedding table.
        text_embeds = self.llm.model.model.embed_tokens(text_inputs.input_ids)

        # Prepend the projected audio tokens to the prompt embeddings.
        inputs_embeds = torch.cat((audio_embeds, text_embeds), dim=1)
        audio_mask = torch.ones(audio_embeds.size()[:-1], dtype=torch.long, device=device)
        attention_mask = torch.cat((audio_mask, text_inputs.attention_mask), dim=1)

        output_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            no_repeat_ngram_size=5,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

        results = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [res.strip() for res in results]
|
|
    def forward(self, mel, prompt_ids):
        """Forward pass used for training."""
        enc_out = self.encoder(mel).last_hidden_state
        # Training assumes fully padded 30 s inputs, so the mask is all ones.
        mask = torch.ones(enc_out.size()[:-1], dtype=torch.long, device=enc_out.device)
        audio_embeds = self.encoder_projector(enc_out, mask)
        text_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        inputs_embeds = torch.cat((audio_embeds, text_embeds), dim=1)
        # The first query_len (80) logit positions correspond to the audio
        # prefix; a training loss would typically ignore them.
        return self.llm(inputs_embeds=inputs_embeds, return_dict=True).logits
|
|
# Register so the checkpoint can be loaded via AutoConfig / AutoModel
# with trust_remote_code=True.
SMTConfig.register_for_auto_class()
SMTModel.register_for_auto_class("AutoModel")
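

# A hedged usage sketch, not part of the original module. "sample.wav" and the
# prompt string are illustrative placeholders; a freshly constructed SMTModel
# has untrained projector/LoRA weights, so in practice you would load a trained
# checkpoint instead (e.g. AutoModel.from_pretrained(..., trust_remote_code=True)).
if __name__ == "__main__":
    model = SMTModel(SMTConfig()).to(torch.bfloat16).to("cuda").eval()
    translations = model.translate_batch(
        ["sample.wav"],
        ["Translate the audio from English into Chinese:"],
    )
    print(translations[0])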