Instructions to use yxdu/ESRT-4B with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use yxdu/ESRT-4B with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("automatic-speech-recognition", model="yxdu/ESRT-4B", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("yxdu/ESRT-4B", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import time | |
| import asyncio | |
| import httpx | |
| import torch | |
| import torch.nn as nn | |
| from typing import List, Optional | |
| from transformers import ( | |
| PretrainedConfig, | |
| PreTrainedModel, | |
| WhisperModel, | |
| AutoModelForCausalLM, | |
| AutoTokenizer, | |
| AutoConfig, | |
| AutoModel, | |
| Blip2QFormerConfig, | |
| Blip2QFormerModel | |
| ) | |
| # --- 1. 定义配置类 --- | |
class CustomSLMConfig(PretrainedConfig):
    """Configuration container for ``CustomSLM``.

    Args:
        encoder_path: Location handed to ``AutoConfig.from_pretrained`` to
            build the speech encoder.
        llm_path: Location handed to ``AutoConfig.from_pretrained`` to build
            the causal language model.
        tokenizer_path: Location handed to ``AutoTokenizer.from_pretrained``.
        llm_dim: Output width of the MLP projector.
        query_len: Number of learned query vectors in the Q-Former.
        encoder_dim: Value assigned to the Q-Former's ``encoder_hidden_size``.
        qformer_layers: Value assigned to the Q-Former's
            ``num_hidden_layers``.
        vllm_path: Model location passed to ``vllm.LLM`` when vLLM decoding
            is used.
        use_vllm: Whether the model should decode through vLLM.
        encoder_only: Whether to skip building the language model.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "custom_slm"

    def __init__(
        self,
        encoder_path: str = "",
        llm_path: str = "",
        tokenizer_path: str = "",
        llm_dim: int = 3840,
        query_len: int = 80,
        encoder_dim: int = 1280,
        qformer_layers: int = 8,
        vllm_path: str = "",
        use_vllm: bool = False,
        encoder_only: bool = False,
        **kwargs
    ):
        # Store every path and hyper-parameter on the instance first, then let
        # the base class consume the remaining generic options.
        own_fields = {
            "encoder_path": encoder_path,
            "llm_path": llm_path,
            "tokenizer_path": tokenizer_path,
            "llm_dim": llm_dim,
            "query_len": query_len,
            "encoder_dim": encoder_dim,
            "qformer_layers": qformer_layers,
            "vllm_path": vllm_path,
            "use_vllm": use_vllm,
            "encoder_only": encoder_only,
        }
        for field_name, field_value in own_fields.items():
            setattr(self, field_name, field_value)
        super().__init__(**kwargs)
| # --- 2. 辅助组件 --- | |
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Return the percentage of non-ignored positions predicted correctly.

    Args:
        pad_outputs: Tensor of predicted ids.
        pad_targets: Tensor of reference ids; positions equal to
            ``ignore_label`` are excluded from the score.
        ignore_label: Target value marking positions to skip.

    Returns:
        A scalar tensor in the range 0-100. When every target position is
        ignored, a ``0.0`` tensor on ``pad_outputs``' device is returned.
    """
    keep = pad_targets != ignore_label
    kept_count = torch.sum(keep)
    if kept_count == 0:
        # Nothing to score: avoid a 0/0 division.
        return torch.tensor(0.0).to(pad_outputs.device)

    predicted = pad_outputs.masked_select(keep)
    expected = pad_targets.masked_select(keep)
    hit_count = torch.sum(predicted == expected)
    return (hit_count.float() / kept_count.float()) * 100
class MLPProjector(nn.Module):
    """Project adapter features to the language model's embedding width.

    The layout depends on ``llm_dim``:

    * ``llm_dim <= 1536``: one linear layer followed by LayerNorm.
    * ``llm_dim <= 3072``: linear -> ReLU -> linear with a 1536-wide hidden
      layer, followed by LayerNorm.
    * otherwise: the same two-layer layout with a 2560-wide hidden layer.

    Args:
        encoder_dim: Size of the last dimension of the input.
        llm_dim: Size of the last dimension of the output.
    """

    def __init__(self, encoder_dim=1280, llm_dim=3840):
        super().__init__()
        self.llm_dim = llm_dim
        if llm_dim <= 1536:
            self.linear = nn.Linear(encoder_dim, llm_dim)
        else:
            # Wider targets get a hidden layer; its width grows with llm_dim.
            hidden_dim = 1536 if llm_dim <= 3072 else 2560
            self.linear1 = nn.Linear(encoder_dim, hidden_dim)
            self.relu = nn.ReLU()
            self.linear2 = nn.Linear(hidden_dim, llm_dim)
        self.norm = nn.LayerNorm(llm_dim, eps=1e-5)

    def forward(self, x):
        """Apply the projection and the final LayerNorm to ``x``."""
        if self.llm_dim > 1536:
            projected = self.linear2(self.relu(self.linear1(x)))
        else:
            projected = self.linear(x)
        return self.norm(projected)
class QFormerModule(nn.Module):
    """Summarise encoder states into a fixed number of query vectors.

    A learned query tensor with ``query_len`` entries is passed through a
    ``Blip2QFormerModel`` together with the encoder states.

    Args:
        encoder_dim: Assigned to the Q-Former config's
            ``encoder_hidden_size``.
        query_len: Number of learned query vectors.
        num_layers: Assigned to the Q-Former config's ``num_hidden_layers``.
    """

    def __init__(self, encoder_dim=1280, query_len=80, num_layers=8):
        super().__init__()
        self.query_len = query_len

        qformer_cfg = Blip2QFormerConfig()
        qformer_cfg.encoder_hidden_size = encoder_dim
        qformer_cfg.num_hidden_layers = num_layers

        # The queries are created (and randomly drawn) before the Q-Former
        # itself so the order of random draws stays fixed.
        initial_query = torch.zeros(1, query_len, qformer_cfg.hidden_size)
        self.query = nn.Parameter(initial_query)
        self.query.data.normal_(mean=0.0, std=1.0)

        self.qformer = Blip2QFormerModel(qformer_cfg)

    def forward(self, x, atts):
        """Run the Q-Former over encoder states ``x``.

        Args:
            x: Encoder hidden states; the first dimension is used as the
                batch size.
            atts: Passed as ``encoder_attention_mask``; may be ``None``.

        Returns:
            The Q-Former's ``last_hidden_state``.
        """
        batch_size = x.shape[0]
        batched_query = self.query.expand(batch_size, -1, -1)
        qformer_out = self.qformer(
            query_embeds=batched_query,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )
        return qformer_out.last_hidden_state
| # --- 3. 主模型类 --- | |
class CustomSLM(PreTrainedModel):
    """Speech language model: Whisper encoder -> Q-Former -> MLP -> causal LLM.

    Audio features are encoded, compressed to ``config.query_len`` vectors by
    the Q-Former, projected to ``config.llm_dim`` and then either spliced
    into the text embeddings (``forward`` / ``generate``) or prepended to
    prompt embeddings (``translate_batch_embeds``).

    Construction options, read from keyword arguments first and from
    ``config`` otherwise:

    * ``use_vllm``: build the LLM, then drop its decoder layers and
      ``lm_head`` so only the embedding table stays resident; decoding is
      delegated to a lazily created ``vllm.LLM`` engine.
    * ``vllm_path``: model location given to ``vllm.LLM``.
    * ``encoder_only``: do not build the LLM at all.
    """

    config_class = CustomSLMConfig
    _auto_class = "AutoModel"

    def __init__(self, config: CustomSLMConfig, **kwargs):
        # Take the model-specific switches out of kwargs before the rest is
        # forwarded to the base constructor.
        vllm_path = kwargs.pop("vllm_path", getattr(config, "vllm_path", None))
        use_vllm = kwargs.pop("use_vllm", getattr(config, "use_vllm", False))
        encoder_only = kwargs.pop("encoder_only", getattr(config, "encoder_only", False))
        super().__init__(config, **kwargs)
        self.vllm_path = vllm_path
        self.use_vllm = use_vllm
        self.encoder_only = encoder_only

        # Speech encoder: built from config only; only the encoder half of
        # the Whisper model is kept.
        encoder_config = AutoConfig.from_pretrained(config.encoder_path)
        self.encoder = WhisperModel(encoder_config).encoder

        # LLM: built from config as an empty shell. Skipped entirely in
        # encoder-only mode, so no LLM config is needed there.
        if not self.encoder_only:
            llm_config = AutoConfig.from_pretrained(config.llm_path)
            self.llm = AutoModelForCausalLM.from_config(
                llm_config,
                dtype=torch.bfloat16,
                attn_implementation="sdpa"
            )
            if self.use_vllm:
                print("⚠️ [Info] 检测到 use_vllm=True: 正在释放原生 LLM 解码器参数,仅保留 Embedding 层以防止 vLLM OOM!")
                self._strip_llm_decoder()

        # Q-Former adapter.
        self.q_former = QFormerModule(
            encoder_dim=config.encoder_dim,
            query_len=config.query_len,
            num_layers=config.qformer_layers
        )

        # MLP projector: its input width is the Q-Former's default hidden
        # size.
        qformer_hidden_size = Blip2QFormerConfig().hidden_size
        self.mlp = MLPProjector(
            encoder_dim=qformer_hidden_size,
            llm_dim=config.llm_dim
        )

        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
        self.post_init()

    def _strip_llm_decoder(self):
        """Drop the LLM's output head and decoder layers, keeping embeddings."""
        if hasattr(self.llm, "lm_head"):
            self.llm.lm_head = nn.Identity()
        if hasattr(self.llm, "model") and hasattr(self.llm.model, "layers"):
            self.llm.model.layers = nn.ModuleList()
        elif hasattr(self.llm, "transformer") and hasattr(self.llm.transformer, "h"):
            # Layout used by some older architectures (e.g. older Qwen,
            # ChatGLM).
            self.llm.transformer.h = nn.ModuleList()
        import gc
        gc.collect()

    def _require_input_embeddings(self):
        """Return the LLM embedding layer or raise if no LLM is attached."""
        embed_layer = self.get_input_embeddings()
        if embed_layer is None:
            raise RuntimeError(
                "No LLM is attached to this model (encoder_only=True); "
                "text embeddings are unavailable."
            )
        return embed_layer

    def _sync_and_time(self):
        """Synchronise the accelerator (if any) and return a timestamp."""
        device_type = self.device.type
        if device_type == "mps":
            torch.mps.synchronize()
        elif device_type == "cuda":
            torch.cuda.synchronize()
        elif device_type == "npu":
            # Prefer the NPU backend's own synchronize; fall back to the
            # CUDA entry point for setups that redirect it to the NPU.
            npu_backend = getattr(torch, "npu", None)
            if npu_backend is not None:
                npu_backend.synchronize()
            else:
                torch.cuda.synchronize()
        return time.perf_counter()

    def get_input_embeddings(self):
        """Return the LLM's input embedding layer, or ``None`` without an LLM."""
        if hasattr(self, 'llm'):
            return self.llm.get_input_embeddings()
        return None

    def forward(self,
                audio_mel: torch.Tensor = None,
                audio_mel_post_mask: torch.Tensor = None,
                modality_mask: torch.Tensor = None,
                input_ids: torch.LongTensor = None,
                attention_mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.LongTensor] = None,
                inference_mode: bool = False,
                **kwargs):
        """Fuse audio features into the text embeddings and run the LLM.

        Args:
            audio_mel: Audio features; the last two axes are swapped before
                the encoder is called. Assumed ``(batch, frames, mel_bins)``
                — TODO confirm against the data pipeline.
            audio_mel_post_mask: Passed to the Q-Former as the encoder
                attention mask; may be ``None``.
            modality_mask: Boolean mask over the token axis marking where
                audio features replace text embeddings. The audio span is
                taken to start at the first ``True`` of each row.
            input_ids: Token ids; entries equal to ``-1`` are mapped to ``0``
                before the embedding lookup.
            attention_mask: Forwarded to the LLM.
            labels: Forwarded to the LLM and used for the accuracy metric.
            inference_mode: When true, return the fused embeddings instead of
                running the LLM.
            **kwargs: Ignored.

        Returns:
            ``(inputs_embeds, attention_mask)`` when ``inference_mode`` is
            true; otherwise ``(model_outputs, acc)`` where ``acc`` is the
            next-token accuracy in percent, or ``None`` if ``labels`` is
            ``None``.

        Raises:
            RuntimeError: If the model was built with ``use_vllm`` and
                ``inference_mode`` is false, or if no LLM is attached.
            ValueError: If ``audio_mel`` or ``input_ids`` is missing.
        """
        if self.use_vllm and not inference_mode:
            raise RuntimeError("模型初始化开启了 use_vllm,已释放深层参数,无法进行标准 forward 计算。")
        if audio_mel is None or input_ids is None:
            raise ValueError("forward() requires both `audio_mel` and `input_ids`.")
        embed_layer = self._require_input_embeddings()

        # 1. Speech feature extraction and projection.
        encoder_outs = self.encoder(audio_mel.permute(0, 2, 1)).last_hidden_state
        encoder_outs = self.q_former(encoder_outs, audio_mel_post_mask)
        encoder_outs = self.mlp(encoder_outs)

        # 2. Text embeddings (-1 placeholders are not valid ids, map to 0).
        input_ids_cleaned = input_ids.clone()
        input_ids_cleaned[input_ids_cleaned == -1] = 0
        inputs_embeds = embed_layer(input_ids_cleaned)

        # 3. Multimodal fusion: write the audio vectors into the masked span
        #    and zero the text embeddings underneath it.
        if modality_mask is not None:
            start_indices = modality_mask.float().argmax(dim=1).tolist()
            span_lengths = torch.clamp(
                modality_mask.sum(dim=1), max=encoder_outs.shape[1]
            ).tolist()
            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                length = int(span_lengths[i])
                start = int(start_indices[i])
                encoder_outs_pad[i, start : start + length] = encoder_outs[i, :length]
            inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])

        if inference_mode:
            return inputs_embeds, attention_mask

        # 4. LLM forward pass.
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels
        )

        # 5. Accuracy (training / validation only); needs labels.
        acc = None
        if labels is not None:
            with torch.no_grad():
                preds = torch.argmax(model_outputs.logits, dim=-1)
                # Position t predicts token t + 1, hence the one-step shift.
                acc = compute_accuracy(preds[:, :-1], labels[:, 1:], ignore_label=-100)
        return model_outputs, acc

    def generate(self, *args, **kwargs):
        """Generate token ids with the in-process LLM.

        Positional and keyword arguments are passed to ``forward`` (with
        ``inference_mode=True``). The decoding options ``max_new_tokens``,
        ``num_beams``, ``do_sample``, ``min_length``, ``top_p``,
        ``repetition_penalty``, ``length_penalty``, ``temperature`` and
        ``no_repeat_ngram_size`` (default 5) are read from ``kwargs``.

        Returns:
            The output of ``self.llm.generate``.

        Raises:
            RuntimeError: If the model was built with ``use_vllm``.
        """
        if getattr(self, "use_vllm", False):
            raise RuntimeError("当前为 use_vllm=True,LLM 解码器未加载,无法执行原生 generate。请调用 translate_batch_* 方法。")
        kwargs["inference_mode"] = True
        inputs_embeds, attention_mask = self.forward(*args, **kwargs)
        model_outputs = self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=kwargs.get("max_new_tokens", 400),
            num_beams=kwargs.get("num_beams", 1),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            no_repeat_ngram_size=kwargs.get("no_repeat_ngram_size", 5),
            attention_mask=attention_mask,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
        )
        return model_outputs

    def translate_encode(self, mels):
        """Encode audio features into Q-Former outputs (before the MLP).

        ``mels`` is moved to the model's device and dtype and fed to the
        encoder as-is.

        NOTE(review): unlike ``forward``, no axis swap is applied here, so
        the expected layout of ``mels`` presumably differs — confirm with
        callers.
        """
        device = self.device
        dtype = self.dtype
        mels = mels.to(device).to(dtype)
        enc_out = self.encoder(mels).last_hidden_state
        adapter_embeds = self.q_former(enc_out, None)
        return adapter_embeds

    def translate_batch_embeds(self, beam_search, embeds, prompts, max_new_tokens=300, use_vllm=False):
        """Decode text from precomputed adapter embeddings plus text prompts.

        Args:
            beam_search: Number of beams for the in-process LLM path (not
                used on the vLLM path).
            embeds: Q-Former outputs, e.g. from ``translate_encode``; they
                are projected by the MLP here.
            prompts: Prompt strings, tokenized with padding.
            max_new_tokens: Generation length limit.
            use_vllm: Force the vLLM path even if the model was not built
                with ``use_vllm``.

        Returns:
            One decoded string per batch element.

        Raises:
            RuntimeError: If no LLM embedding layer is available.
        """
        use_vllm_flag = use_vllm or getattr(self, "use_vllm", False)
        if use_vllm_flag and getattr(self, "use_vllm", False) is False:
            print("⚠️ 警告: 您通过参数强制要求 use_vllm,但初始化未设置 use_vllm=True。")

        if use_vllm_flag:
            # Imported lazily so vLLM is only required when it is used.
            import importlib
            vllm_module = importlib.import_module("vllm")
            LLM = vllm_module.LLM
            SamplingParams = vllm_module.SamplingParams
            # The engine is created once and cached on the instance.
            if getattr(self, "_vllm_llm", None) is None:
                self._vllm_llm = LLM(
                    self.vllm_path,
                    dtype="bfloat16",
                    trust_remote_code=True,
                    tensor_parallel_size=1,
                    gpu_memory_utilization=0.7,
                    enable_prompt_embeds=True,
                    max_model_len=1024,
                )

        embed_layer = self._require_input_embeddings()
        device = self.device
        dtype = self.dtype

        t0 = self._sync_and_time()
        adapter_embeds = embeds.to(device).to(dtype)
        adapter_embeds = self.mlp(adapter_embeds)
        text_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(device)
        text_embeds = embed_layer(text_inputs.input_ids)
        inputs_embeds = torch.cat((adapter_embeds, text_embeds), dim=1)

        # Prompts are padded to a common length; keep the tokenizer's mask
        # (extended with ones for the audio prefix) so pad positions are not
        # treated as real input.
        text_mask = text_inputs.get("attention_mask")
        if text_mask is None:
            full_mask = None
        else:
            prefix_mask = torch.ones(
                adapter_embeds.shape[:2], dtype=text_mask.dtype, device=text_mask.device
            )
            full_mask = torch.cat((prefix_mask, text_mask), dim=1)
        t1 = self._sync_and_time()

        if use_vllm_flag:
            sampling_params = SamplingParams(
                max_tokens=max_new_tokens,
                temperature=0.0,
            )
            # vLLM takes one embedding matrix per request, so pad rows are
            # simply dropped.
            if full_mask is None:
                prompt_embeds_list = [
                    {"prompt_embeds": inputs_embeds[i]} for i in range(inputs_embeds.size(0))
                ]
            else:
                keep = full_mask.bool()
                prompt_embeds_list = [
                    {"prompt_embeds": inputs_embeds[i][keep[i]]}
                    for i in range(inputs_embeds.size(0))
                ]
            outputs = self._vllm_llm.generate(
                prompt_embeds_list,
                sampling_params=sampling_params,
            )
            t2 = self._sync_and_time()
            print(f"⏱️ Embeds Prep: {t1-t0:.2f}s | LLM: {t2-t1:.2f}s | Total: {t2-t0:.2f}s")
            return [output.outputs[0].text for output in outputs]

        output_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=full_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=beam_search,
            no_repeat_ngram_size=5,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        t2 = self._sync_and_time()
        print(f"⏱️ Embeds Prep: {t1-t0:.2f}s | LLM: {t2-t1:.2f}s | Total: {t2-t0:.2f}s")
        return [self.tokenizer.decode(g, skip_special_tokens=True) for g in output_ids]
| # --- 4. 注册与保存脚本 --- | |
def register_and_save_model(model, tokenizer, args):
    """Register the custom classes with the Auto* factories and save to disk.

    Args:
        model: Model to save. If it has a truthy ``use_lora`` attribute, its
            ``llm`` is replaced by the result of ``llm.merge_and_unload()``
            first.
        tokenizer: Tokenizer saved next to the model.
        args: Object whose ``merge_model`` attribute is the output directory.
    """
    AutoConfig.register("custom_slm", CustomSLMConfig)
    AutoModel.register(CustomSLMConfig, CustomSLM)

    if getattr(model, "use_lora", False):
        print("Merging LoRA weights...")
        model.llm = model.llm.merge_and_unload()

    # Record where the custom classes live so the saved config can point
    # the Auto* loaders at them.
    auto_map = {}
    auto_map["AutoConfig"] = "srt_model.CustomSLMConfig"
    auto_map["AutoModel"] = "srt_model.CustomSLM"
    model.config.auto_map = auto_map

    for artifact in (model, tokenizer):
        artifact.save_pretrained(args.merge_model)
    print(f"Model saved to {args.merge_model}")