# Safetensors | smt_model | custom_code
# smt-9b-hf / models.py — yxdu: "Update models.py" (commit 3009f01, verified)
import torch
import torch.nn as nn
import numpy as np
import whisper
import librosa
from transformers import (
PreTrainedModel,
PretrainedConfig,
WhisperModel,
AutoModelForCausalLM,
AutoTokenizer,
Blip2QFormerConfig,
Blip2QFormerModel
)
from peft import LoraConfig, get_peft_model, PeftType
# --- 1. Configuration class ---
class SMTConfig(PretrainedConfig):
    """HF configuration for the SMT speech-translation model.

    Registered under model_type "smt_model" so AutoConfig/AutoModel can
    resolve it via trust_remote_code. Declares no extra fields of its own;
    all keyword arguments are forwarded to PretrainedConfig.
    """
    model_type = "smt_model"
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
# --- 2. Q-Former projection layer ---
class EncoderProjectorQFormer(nn.Module):
    """Compress Whisper encoder features into a fixed number of LLM-space embeddings.

    A BLIP-2 style Q-Former cross-attends a set of learned query vectors to the
    (variable-length) encoder output, producing exactly ``query_len`` vectors,
    which are then projected by a 2-layer MLP into the LLM embedding dimension
    and layer-normalized.

    All dimensions were previously hard-coded; they are now keyword parameters
    whose defaults reproduce the original configuration exactly
    (Whisper large-v3 encoder dim 1280 -> Gemma hidden dim 3584, 80 queries,
    8 Q-Former layers, 2560-wide MLP), so existing ``EncoderProjectorQFormer()``
    callers are unaffected.

    Args:
        encoder_dim: Hidden size of the audio encoder output (cross-attention
            key/value dimension).
        llm_dim: Target embedding dimension of the language model.
        query_len: Number of learned query tokens (fixed output length).
        num_hidden_layers: Number of Q-Former transformer layers.
        mlp_hidden_dim: Width of the intermediate projection layer.
    """
    def __init__(self, encoder_dim=1280, llm_dim=3584, query_len=80,
                 num_hidden_layers=8, mlp_hidden_dim=2560):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        configuration = Blip2QFormerConfig()
        configuration.encoder_hidden_size = self.encoder_dim
        configuration.num_hidden_layers = num_hidden_layers
        self.query_len = query_len
        # Learned queries, initialized from N(0, 1); shape (1, query_len, hidden).
        self.query = nn.Parameter(torch.zeros(1, self.query_len, configuration.hidden_size))
        self.query.data.normal_(mean=0.0, std=1.0)
        self.qformer = Blip2QFormerModel(configuration)
        # 2-layer MLP projecting Q-Former outputs into the LLM embedding space.
        self.linear1 = nn.Linear(configuration.hidden_size, mlp_hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(mlp_hidden_dim, self.llm_dim)
        self.norm = nn.LayerNorm(self.llm_dim, eps=1e-5)

    def forward(self, x, atts):
        """Project encoder features to LLM embeddings.

        Args:
            x: Encoder hidden states, shape (batch, enc_seq, encoder_dim).
            atts: Attention mask over ``x`` (1 = valid frame, 0 = padding).

        Returns:
            Tensor of shape (batch, query_len, llm_dim).
        """
        # Broadcast the learned queries over the batch dimension.
        query = self.query.expand(x.shape[0], -1, -1)
        query_output = self.qformer(query_embeds=query, encoder_hidden_states=x,
                                    encoder_attention_mask=atts, return_dict=True)
        x = self.linear2(self.relu(self.linear1(query_output.last_hidden_state)))
        return self.norm(x)
# --- 3. Main model class ---
class SMTModel(PreTrainedModel):
    """Speech translation model: Whisper encoder -> Q-Former projector -> LoRA LLM.

    Audio is turned into log-mel spectrograms, encoded with a Whisper
    large-v3 encoder, compressed to a fixed number of query embeddings by
    the Q-Former projector, and prepended (as soft-prompt embeddings) to the
    tokenized text prompt before generation with a LoRA-adapted Gemma LLM.
    """
    config_class = SMTConfig
    base_model_prefix = "model"
    def __init__(self, config):
        super().__init__(config)
        # Audio encoder: Whisper large-v3, encoder half only, loaded in bf16.
        self.encoder = WhisperModel.from_pretrained(
            "yxdu/whisper-large-v3-encoder",
            torch_dtype=torch.bfloat16
        ).encoder
        # LLM + LoRA adapters on the attention query/value projections.
        base_llm = AutoModelForCausalLM.from_pretrained(
            "ModelSpace/GemmaX2-28-9B-v0.1",
            torch_dtype=torch.bfloat16
        )
        peft_config = LoraConfig(
            peft_type=PeftType.LORA,
            task_type="CAUSAL_LM",
            r=16,
            target_modules=["q_proj", "v_proj"],
            lora_alpha=32,
            lora_dropout=0.05,
        )
        self.llm = get_peft_model(base_llm, peft_config)
        self.tokenizer = AutoTokenizer.from_pretrained("ModelSpace/GemmaX2-28-9B-v0.1")
        self.encoder_projector = EncoderProjectorQFormer()
    def _prepare_audio(self, audio_input):
        """Normalize input to a 16 kHz waveform; accepts a path (str) or HF Audio dict."""
        if isinstance(audio_input, str):
            # whisper.load_audio already resamples files to 16 kHz.
            return whisper.load_audio(audio_input)
        if isinstance(audio_input, dict):
            # Extract the raw samples and their sampling rate from the HF Audio dict.
            audio_array = audio_input["array"]
            sr = audio_input["sampling_rate"]
            # Force-resample to 16 kHz (the rate Whisper expects).
            if sr != 16000:
                audio_array = librosa.resample(audio_array, orig_sr=sr, target_sr=16000)
            return audio_array.astype(np.float32)
        raise ValueError("输入应为文件路径或 HF Audio 字典")
    @torch.no_grad()
    def translate_batch(self, audio_inputs, prompts, max_new_tokens=300):
        """Batched speech-translation inference.

        Args:
            audio_inputs: List[str] (file paths) or List[dict] (HF Audio dicts).
            prompts: List of text prompts, one per audio input.
            max_new_tokens: Generation budget for the LLM.

        Returns:
            List of decoded, stripped output strings (one per input).
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        # 1. Audio preprocessing: mel spectrograms + per-sample valid lengths.
        mels = []
        valid_lengths = []
        for audio_in in audio_inputs:
            audio = self._prepare_audio(audio_in)
            # Valid frame count after Whisper's downsampling:
            # 160-sample hop -> mel frames, then the encoder's stride-2 conv
            # halves again; capped at 1500 (Whisper's 30 s frame budget).
            post_len = (len(audio) // 160 + 1) // 2
            valid_lengths.append(min(post_len, 1500))
            audio = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio, n_mels=128).unsqueeze(0)
            mels.append(mel)
        mels = torch.cat(mels, dim=0).to(device).to(dtype)
        # 2. Encode audio and build a mask marking real (non-padded) frames.
        enc_out = self.encoder(mels).last_hidden_state
        enc_mask = torch.zeros(enc_out.size()[:-1], dtype=torch.long, device=device)
        for i, length in enumerate(valid_lengths):
            enc_mask[i, :length] = 1
        audio_embeds = self.encoder_projector(enc_out, enc_mask)
        # 3. Text processing (left padding, so generation continues from the
        #    right edge of every sequence in the batch).
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        text_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(device)
        # .model.model unwraps the PEFT wrapper to reach the base LLM's embedding table.
        text_embeds = self.llm.model.model.embed_tokens(text_inputs.input_ids)
        # 4. Concatenate audio soft-prompt embeddings before the text embeddings.
        inputs_embeds = torch.cat((audio_embeds, text_embeds), dim=1)
        audio_mask = torch.ones(audio_embeds.size()[:-1], dtype=torch.long, device=device)
        attention_mask = torch.cat((audio_mask, text_inputs.attention_mask), dim=1)
        # 5. Greedy generation (no sampling), with 5-gram repetition blocking.
        output_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            no_repeat_ngram_size=5,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # 6. Decode and strip whitespace.
        results = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        return [res.strip() for res in results]
    def forward(self, mel, prompt_ids):
        """Training forward pass: returns LLM logits over [audio embeds; prompt embeds].

        NOTE(review): unlike translate_batch, the encoder mask here is all-ones
        (no padding handling) — assumes fixed-length mel inputs at train time.
        """
        enc_out = self.encoder(mel).last_hidden_state
        mask = torch.ones(enc_out.size()[:-1], dtype=torch.long, device=enc_out.device)
        audio_embeds = self.encoder_projector(enc_out, mask)
        text_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        inputs_embeds = torch.cat((audio_embeds, text_embeds), dim=1)
        return self.llm(inputs_embeds=inputs_embeds, return_dict=True).logits
# --- 4. Registration ---
# Register the custom classes so AutoConfig / AutoModel can load this model
# from the Hub with trust_remote_code=True.
SMTConfig.register_for_auto_class()
SMTModel.register_for_auto_class("AutoModel")