ESRT-4B / srt_model.py
yxdu's picture
Update srt_model.py
afbfdfd verified
Raw
History Blame Contribute Delete
13.6 kB
import os
import time
import asyncio
import httpx
import torch
import torch.nn as nn
from typing import List, Optional
from transformers import (
PretrainedConfig,
PreTrainedModel,
WhisperModel,
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
AutoModel,
Blip2QFormerConfig,
Blip2QFormerModel
)
# --- 1. 定义配置类 ---
class CustomSLMConfig(PretrainedConfig):
    """Hugging Face configuration for :class:`CustomSLM`.

    Holds the locations of the pretrained sub-components (speech encoder,
    LLM, tokenizer, optional vLLM checkpoint), the adapter sizes, and the
    runtime switches ``use_vllm`` / ``encoder_only``. Every other keyword
    argument is handed to ``PretrainedConfig.__init__`` unchanged.
    """

    model_type = "custom_slm"

    def __init__(
        self,
        encoder_path: str = "",
        llm_path: str = "",
        tokenizer_path: str = "",
        llm_dim: int = 3840,
        query_len: int = 80,
        encoder_dim: int = 1280,
        qformer_layers: int = 8,
        vllm_path: str = "",
        use_vllm: bool = False,  # newly added constructor flag
        encoder_only: bool = False,  # newly added constructor flag
        **kwargs,
    ) -> None:
        # Where each sub-component is loaded from.
        self.encoder_path = encoder_path
        self.llm_path = llm_path
        self.tokenizer_path = tokenizer_path
        # Adapter geometry: LLM embedding width, number of Q-Former queries,
        # encoder output width fed to the Q-Former, and Q-Former depth.
        self.llm_dim = llm_dim
        self.query_len = query_len
        self.encoder_dim = encoder_dim
        self.qformer_layers = qformer_layers
        # vLLM checkpoint path and runtime switches.
        self.vllm_path = vllm_path
        self.use_vllm = use_vllm
        self.encoder_only = encoder_only
        # Anything not consumed above belongs to the base config.
        super().__init__(**kwargs)
# --- 2. 辅助组件 ---
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Return token accuracy, in percent, over the non-ignored target positions.

    Args:
        pad_outputs: predicted ids, compared element-wise with ``pad_targets``
            at the positions kept by the mask.
        pad_targets: reference ids; positions equal to ``ignore_label`` are
            excluded from both numerator and denominator.
        ignore_label: target value marking positions to skip (e.g. -100).

    Returns:
        A scalar float tensor in [0, 100]. If every position is ignored, a
        0.0 tensor placed on ``pad_outputs.device``.
    """
    valid = pad_targets != ignore_label
    if valid.sum() == 0:
        return torch.tensor(0.0).to(pad_outputs.device)
    hits = (pad_outputs.masked_select(valid) == pad_targets.masked_select(valid)).sum()
    total = valid.sum()
    return hits.float() / total.float() * 100
class MLPProjector(nn.Module):
    """Map Q-Former features to the LLM embedding width, followed by LayerNorm.

    * ``llm_dim <= 1536``: a single ``linear`` layer.
    * ``llm_dim <= 3072``: ``linear1`` (hidden width 1536) -> ReLU -> ``linear2``.
    * otherwise:           ``linear1`` (hidden width 2560) -> ReLU -> ``linear2``.

    The attribute names (linear / linear1 / relu / linear2 / norm) determine
    the state-dict keys, so they must stay as they are for checkpoints to load.
    """

    def __init__(self, encoder_dim=1280, llm_dim=3840):
        super().__init__()
        self.llm_dim = llm_dim
        if llm_dim <= 1536:
            self.linear = nn.Linear(encoder_dim, llm_dim)
        else:
            hidden_width = 1536 if llm_dim <= 3072 else 2560
            self.linear1 = nn.Linear(encoder_dim, hidden_width)
            self.relu = nn.ReLU()
            self.linear2 = nn.Linear(hidden_width, llm_dim)
        self.norm = nn.LayerNorm(llm_dim, eps=1e-5)

    def forward(self, x):
        """Project the last dimension of ``x`` from encoder_dim to llm_dim."""
        if self.llm_dim > 1536:
            projected = self.linear2(self.relu(self.linear1(x)))
        else:
            projected = self.linear(x)
        return self.norm(projected)
class QFormerModule(nn.Module):
    """BLIP-2 Q-Former with ``query_len`` learned query tokens.

    The Q-Former uses the default ``Blip2QFormerConfig`` except for
    ``encoder_hidden_size`` (set to ``encoder_dim``) and ``num_hidden_layers``
    (set to ``num_layers``). Its queries cross-attend to the encoder hidden
    states, giving a fixed-length output of ``query_len`` vectors per example.
    """

    def __init__(self, encoder_dim=1280, query_len=80, num_layers=8):
        super().__init__()
        self.query_len = query_len
        qformer_cfg = Blip2QFormerConfig()
        qformer_cfg.encoder_hidden_size = encoder_dim
        qformer_cfg.num_hidden_layers = num_layers
        # Queries are drawn from N(0, 1) *before* the Q-Former is built, so the
        # RNG is consumed in the same order as before.
        initial_queries = torch.zeros(1, self.query_len, qformer_cfg.hidden_size)
        initial_queries.normal_(mean=0.0, std=1.0)
        self.query = nn.Parameter(initial_queries)
        self.qformer = Blip2QFormerModel(qformer_cfg)

    def forward(self, x, atts):
        """Run the Q-Former over encoder states ``x`` with encoder mask ``atts``.

        ``x`` is presumably (batch, frames, encoder_dim); only ``x.shape[0]`` is
        read here, to broadcast the queries over the batch. ``atts`` may be None
        (``translate_encode`` passes None) -- TODO confirm the Q-Former then
        attends to every frame.
        """
        batch_queries = self.query.expand(x.shape[0], -1, -1)
        qformer_out = self.qformer(
            query_embeds=batch_queries,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )
        return qformer_out.last_hidden_state
# --- 3. 主模型类 ---
class CustomSLM(PreTrainedModel):
    """Speech LLM: Whisper encoder -> Q-Former -> MLP projector -> causal LLM.

    Audio features are encoded, compressed to ``config.query_len`` vectors by
    the Q-Former, projected to the LLM embedding width, and spliced into the
    token embeddings at the positions flagged by ``modality_mask``.

    Construction flags. A keyword argument of the same name takes precedence
    over the config attribute.
        use_vllm: build the LLM, then drop its decoder layers and ``lm_head``
            so that only the input embeddings remain. Decoding is then done
            by vLLM in :meth:`translate_batch_embeds`.
        vllm_path: checkpoint path passed to ``vllm.LLM``.
        encoder_only: do not build an LLM at all.
    """

    config_class = CustomSLMConfig
    _auto_class = "AutoModel"

    def __init__(self, config: CustomSLMConfig, **kwargs):
        # Take our own flags out of kwargs *before* delegating, so they are not
        # forwarded to PreTrainedModel.__init__.
        vllm_path = kwargs.pop("vllm_path", getattr(config, "vllm_path", None))
        use_vllm = kwargs.pop("use_vllm", getattr(config, "use_vllm", False))
        encoder_only = kwargs.pop("encoder_only", getattr(config, "encoder_only", False))
        super().__init__(config, **kwargs)
        self.vllm_path = vllm_path
        self.use_vllm = use_vllm
        self.encoder_only = encoder_only

        # Speech encoder: architecture built from its config (no weights loaded here).
        encoder_config = AutoConfig.from_pretrained(config.encoder_path)
        self.encoder = WhisperModel(encoder_config).encoder

        # LLM: an empty shell built from its config. Skipped for encoder_only,
        # which also means llm_path is not needed in that mode.
        if not self.encoder_only:
            llm_config = AutoConfig.from_pretrained(config.llm_path)
            self.llm = AutoModelForCausalLM.from_config(
                llm_config,
                dtype=torch.bfloat16,
                attn_implementation="sdpa",
            )
            if self.use_vllm:
                print("⚠️ [Info] 检测到 use_vllm=True: 正在释放原生 LLM 解码器参数,仅保留 Embedding 层以防止 vLLM OOM!")
                self._strip_llm_decoder()

        # Q-Former compresses the encoder frames into a fixed number of queries.
        self.q_former = QFormerModule(
            encoder_dim=config.encoder_dim,
            query_len=config.query_len,
            num_layers=config.qformer_layers,
        )
        # The MLP maps the Q-Former width (the BLIP-2 default hidden size) to the LLM width.
        qformer_hidden_size = Blip2QFormerConfig().hidden_size
        self.mlp = MLPProjector(
            encoder_dim=qformer_hidden_size,
            llm_dim=config.llm_dim,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
        self.post_init()

    def _strip_llm_decoder(self):
        """Drop the LLM's lm_head and decoder blocks, keeping only the embeddings."""
        if hasattr(self.llm, "lm_head"):
            self.llm.lm_head = nn.Identity()
        if hasattr(self.llm, "model") and hasattr(self.llm.model, "layers"):
            self.llm.model.layers = nn.ModuleList()
        # Compatibility with older Qwen, ChatGLM, etc., which keep blocks under transformer.h.
        elif hasattr(self.llm, "transformer") and hasattr(self.llm.transformer, "h"):
            self.llm.transformer.h = nn.ModuleList()
        import gc
        gc.collect()

    def get_input_embeddings(self):
        """Return the LLM's input embedding module, or None if no LLM was built."""
        if hasattr(self, 'llm'):
            return self.llm.get_input_embeddings()
        return None

    def _embed_tokens(self, input_ids):
        """Embed ``input_ids`` with the LLM embedding table; fail clearly if there is none."""
        embed_tokens = self.get_input_embeddings()
        if embed_tokens is None:
            raise RuntimeError("No LLM embedding layer is available: the model was built with encoder_only=True.")
        return embed_tokens(input_ids)

    def forward(self,
                audio_mel: torch.Tensor = None,
                audio_mel_post_mask: torch.Tensor = None,
                modality_mask: torch.Tensor = None,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.LongTensor] = None,
                inference_mode: bool = False,
                **kwargs):
        """Fuse audio and text embeddings and (unless inference_mode) run the LLM.

        Args:
            audio_mel: audio features; permuted with (0, 2, 1) before the encoder.
            audio_mel_post_mask: encoder attention mask handed to the Q-Former.
            modality_mask: boolean mask marking the token positions that are
                overwritten with audio embeddings (one contiguous run per row,
                starting at the first True).
            input_ids: text token ids. Required. Entries equal to -1 are
                embedded as token id 0.
            attention_mask: passed through to the LLM / returned unchanged.
            labels: LLM labels. If None, no accuracy is computed.
            inference_mode: if True, return ``(inputs_embeds, attention_mask)``
                without calling the LLM.
            **kwargs: ignored (lets ``generate`` forward its generation options).

        Returns:
            ``(inputs_embeds, attention_mask)`` in inference mode. Otherwise
            ``(model_outputs, acc)``, where ``acc`` is the next-token accuracy
            in percent, or None when ``labels`` is None.

        Raises:
            RuntimeError: the model was built with ``use_vllm`` and
                ``inference_mode`` is False, or no LLM is loaded.
            ValueError: ``input_ids`` is None.
        """
        if self.use_vllm and not inference_mode:
            raise RuntimeError("模型初始化开启了 use_vllm,已释放深层参数,无法进行标准 forward 计算。")
        if input_ids is None:
            raise ValueError("CustomSLM.forward requires input_ids to build the text embeddings.")

        # 1. Speech feature extraction and projection.
        # NOTE(review): forward permutes audio_mel but translate_encode does not,
        # so callers presumably pass (batch, frames, n_mels) here -- confirm.
        encoder_outs = self.encoder(audio_mel.permute(0, 2, 1)).last_hidden_state
        encoder_outs = self.q_former(encoder_outs, audio_mel_post_mask)
        encoder_outs = self.mlp(encoder_outs)

        # 2. Text embeddings (-1 placeholders mapped to id 0 so the lookup is valid).
        input_ids_cleaned = input_ids.clone()
        input_ids_cleaned[input_ids_cleaned == -1] = 0
        inputs_embeds = self._embed_tokens(input_ids_cleaned)

        # 3. Multimodal fusion: write the audio embeddings over the masked span.
        if modality_mask is not None:
            modality_mask_start_indices = (modality_mask == True).float().argmax(dim=1)  # noqa: E712
            # Never copy more audio vectors than the Q-Former produced.
            modality_lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                length = int(modality_lengths[i])
                start = int(modality_mask_start_indices[i])
                encoder_outs_pad[i, start:start + length] = encoder_outs[i, :length]
            inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])

        if inference_mode:
            return inputs_embeds, attention_mask

        # 4. LLM forward pass.
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

        # 5. Next-token accuracy (training / validation only; needs labels).
        acc = None
        if labels is not None:
            with torch.no_grad():
                preds = torch.argmax(model_outputs.logits, dim=-1)
                acc = compute_accuracy(preds[:, :-1], labels[:, 1:], ignore_label=-100)
        return model_outputs, acc

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """Build fused embeddings via ``forward`` and decode with the native HF LLM.

        Generation options are read from ``kwargs`` with these defaults:
        max_new_tokens=400, num_beams=1, do_sample=False, min_length=1,
        top_p=1.0, repetition_penalty=1.0, length_penalty=1.0,
        temperature=1.0, no_repeat_ngram_size=5.

        Raises:
            RuntimeError: if the model was built with ``use_vllm``.
        """
        if getattr(self, "use_vllm", False):
            raise RuntimeError("当前为 use_vllm=True,LLM 解码器未加载,无法执行原生 generate。请调用 translate_batch_* 方法。")
        kwargs["inference_mode"] = True
        inputs_embeds, attention_mask = self.forward(*args, **kwargs)
        model_outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=kwargs.get("max_new_tokens", 400),
            num_beams=kwargs.get("num_beams", 1),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            no_repeat_ngram_size=kwargs.get("no_repeat_ngram_size", 5),
            attention_mask=attention_mask,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        return model_outputs

    @torch.no_grad()
    def translate_encode(self, mels):
        """Run the speech encoder and Q-Former (not the MLP) on ``mels``.

        ``mels`` is moved to the model's device/dtype and fed to the encoder
        without permutation. NOTE(review): forward permutes its audio input
        first, so this path presumably expects (batch, n_mels, frames) -- confirm.
        The result is meant to be passed as ``embeds`` to translate_batch_embeds.
        """
        device = self.device
        dtype = self.dtype
        mels = mels.to(device).to(dtype)
        enc_out = self.encoder(mels).last_hidden_state
        adapter_embeds = self.q_former(enc_out, None)
        return adapter_embeds

    @torch.no_grad()
    def translate_batch_embeds(self, beam_search, embeds, prompts, max_new_tokens=300, use_vllm=False):
        """Decode text from Q-Former outputs followed by the tokenized ``prompts``.

        Args:
            beam_search: ``num_beams`` for the native HF path (unused by vLLM).
            embeds: Q-Former outputs, e.g. from ``translate_encode``; projected
                by ``self.mlp`` here.
            prompts: text prompts, tokenized with padding and placed after the
                audio embeddings.
            max_new_tokens: generation budget.
            use_vllm: force the vLLM path even if the model was not built for it.

        Returns:
            One decoded string per batch element.
        """
        use_vllm_flag = use_vllm or getattr(self, "use_vllm", False)
        if use_vllm_flag and getattr(self, "use_vllm", False) is False:
            print("⚠️ 警告: 您通过参数强制要求 use_vllm,但初始化未设置 use_vllm=True。")
        if use_vllm_flag:
            # Imported lazily so that vllm is only required when actually used.
            import importlib
            vllm_module = importlib.import_module("vllm")
            LLM = vllm_module.LLM
            SamplingParams = vllm_module.SamplingParams

        def mark():
            # Wait for queued device work so the timing split is meaningful.
            device_type = self.device.type
            if device_type == 'mps':
                torch.mps.synchronize()
            elif device_type == 'cuda':
                torch.cuda.synchronize()
            elif device_type == 'npu':
                # torch.npu is installed by the Ascend torch_npu plugin;
                # torch.cuda.synchronize() would fail on a CUDA-less NPU host.
                npu_backend = getattr(torch, "npu", None)
                if npu_backend is not None:
                    npu_backend.synchronize()
            return time.perf_counter()

        # Build the vLLM engine once and cache it on the instance.
        if use_vllm_flag and getattr(self, '_vllm_llm', None) is None:
            self._vllm_llm = LLM(
                self.vllm_path,
                dtype="bfloat16",
                trust_remote_code=True,
                tensor_parallel_size=1,
                gpu_memory_utilization=0.7,
                enable_prompt_embeds=True,
                max_model_len=1024,
            )

        device = self.device
        dtype = self.dtype
        t0 = mark()
        adapter_embeds = self.mlp(embeds.to(device).to(dtype))
        text_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(device)
        text_embeds = self._embed_tokens(text_inputs.input_ids)
        inputs_embeds = torch.cat((adapter_embeds, text_embeds), dim=1)
        # Attention mask over [audio prefix | prompt tokens]: the prefix is always
        # valid; prompt padding (from batching prompts of different lengths) is not.
        text_mask = text_inputs.get("attention_mask")
        if text_mask is None:
            text_mask = torch.ones_like(text_inputs.input_ids)
        prefix_mask = torch.ones(adapter_embeds.shape[:2], dtype=text_mask.dtype, device=text_mask.device)
        attention_mask = torch.cat((prefix_mask, text_mask), dim=1)
        t1 = mark()

        if use_vllm_flag:
            sampling_params = SamplingParams(
                max_tokens=max_new_tokens,
                temperature=0.0,
            )
            # vLLM takes variable-length prompts, so drop pad positions per sample
            # rather than feeding pad-token embeddings to the model.
            keep = attention_mask.bool()
            prompt_embeds_list = [
                {"prompt_embeds": inputs_embeds[i][keep[i]]}
                for i in range(inputs_embeds.size(0))
            ]
            outputs = self._vllm_llm.generate(
                prompt_embeds_list,
                sampling_params=sampling_params,
            )
            t2 = mark()
            print(f"⏱️ Embeds Prep: {t1-t0:.2f}s | LLM: {t2-t1:.2f}s | Total: {t2-t0:.2f}s")
            return [output.outputs[0].text for output in outputs]

        output_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=beam_search,
            no_repeat_ngram_size=5,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        t2 = mark()
        print(f"⏱️ Embeds Prep: {t1-t0:.2f}s | LLM: {t2-t1:.2f}s | Total: {t2-t0:.2f}s")
        return [self.tokenizer.decode(g, skip_special_tokens=True) for g in output_ids]
# --- 4. 注册与保存脚本 ---
def register_and_save_model(model, tokenizer, args):
    """Register the custom classes with the Auto* factories and export the model.

    If ``model.use_lora`` is truthy, the LoRA adapters on ``model.llm`` are
    merged first. The saved config gets an ``auto_map`` that points
    AutoConfig/AutoModel at ``srt_model.CustomSLMConfig`` / ``srt_model.CustomSLM``.
    Model and tokenizer are both written to ``args.merge_model``.
    """
    AutoConfig.register("custom_slm", CustomSLMConfig)
    AutoModel.register(CustomSLMConfig, CustomSLM)

    if getattr(model, "use_lora", False):
        print("Merging LoRA weights...")
        model.llm = model.llm.merge_and_unload()

    model.config.auto_map = {
        "AutoConfig": "srt_model.CustomSLMConfig",
        "AutoModel": "srt_model.CustomSLM",
    }
    output_dir = args.merge_model
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")