import os
import time
import asyncio
import httpx
import torch
import torch.nn as nn
from typing import List, Optional
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    WhisperModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    AutoModel,
    Blip2QFormerConfig,
    Blip2QFormerModel,
)


# --- 1. Configuration class ---
class CustomSLMConfig(PretrainedConfig):
    """Configuration for the ``custom_slm`` audio-conditioned language model.

    Args:
        encoder_path: path/name whose config builds the Whisper audio encoder.
        llm_path: path/name whose config builds the causal LLM shell.
        tokenizer_path: path/name of the tokenizer to load.
        llm_dim: output width of the MLP projector (should match the LLM hidden size).
        query_len: number of learned Q-Former query tokens.
        encoder_dim: hidden size of the audio encoder fed to the Q-Former.
        qformer_layers: number of Q-Former transformer layers.
        vllm_path: model path handed to the vLLM engine when vLLM decoding is used.
        use_vllm: build only the LLM embedding table and decode with vLLM.
        encoder_only: skip building the LLM entirely.
        **kwargs: forwarded to ``PretrainedConfig``.
    """

    model_type = "custom_slm"

    def __init__(
        self,
        encoder_path: str = "",
        llm_path: str = "",
        tokenizer_path: str = "",
        llm_dim: int = 3840,
        query_len: int = 80,
        encoder_dim: int = 1280,
        qformer_layers: int = 8,
        vllm_path: str = "",
        use_vllm: bool = False,
        encoder_only: bool = False,
        **kwargs,
    ):
        # Sub-component locations.
        self.encoder_path = encoder_path
        self.llm_path = llm_path
        self.tokenizer_path = tokenizer_path
        # Architecture hyper-parameters.
        self.llm_dim = llm_dim
        self.query_len = query_len
        self.encoder_dim = encoder_dim
        self.qformer_layers = qformer_layers
        # Runtime / deployment switches.
        self.vllm_path = vllm_path
        self.use_vllm = use_vllm
        self.encoder_only = encoder_only
        super().__init__(**kwargs)
# --- 2. Helper components ---
def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Return the percentage of positions where prediction equals target.

    Positions whose target equals ``ignore_label`` are excluded. When every
    position is ignored, a zero tensor on ``pad_outputs``' device is returned.
    """
    keep = pad_targets != ignore_label
    if keep.sum() == 0:
        return torch.tensor(0.0).to(pad_outputs.device)
    hits = (pad_outputs.masked_select(keep) == pad_targets.masked_select(keep)).sum()
    total = keep.sum()
    return hits.float() / total.float() * 100


class MLPProjector(nn.Module):
    """Project features to ``llm_dim`` and layer-normalise them.

    Widths up to 1536 use a single linear layer. Wider targets use a two-layer
    ReLU MLP whose hidden width is 1536 (for ``llm_dim <= 3072``) or 2560.
    Submodule names (``linear`` / ``linear1``, ``relu``, ``linear2``, ``norm``)
    are part of the checkpoint format.
    """

    def __init__(self, encoder_dim=1280, llm_dim=3840):
        super().__init__()
        self.llm_dim = llm_dim
        if llm_dim <= 1536:
            self.linear = nn.Linear(encoder_dim, llm_dim)
        else:
            hidden_width = 1536 if llm_dim <= 3072 else 2560
            self.linear1 = nn.Linear(encoder_dim, hidden_width)
            self.relu = nn.ReLU()
            self.linear2 = nn.Linear(hidden_width, llm_dim)
        # Registered last in every branch, matching the checkpoint layout.
        self.norm = nn.LayerNorm(llm_dim, eps=1e-5)

    def forward(self, x):
        if self.llm_dim > 1536:
            hidden = self.linear2(self.relu(self.linear1(x)))
        else:
            hidden = self.linear(x)
        return self.norm(hidden)


class QFormerModule(nn.Module):
    """Cross-attend a fixed set of learned query tokens over encoder features.

    Uses a default ``Blip2QFormerConfig`` with ``encoder_hidden_size`` and
    ``num_hidden_layers`` overridden. The output has ``query_len`` tokens per
    batch row, at the Q-Former hidden size.
    """

    def __init__(self, encoder_dim=1280, query_len=80, num_layers=8):
        super().__init__()
        self.query_len = query_len
        qformer_cfg = Blip2QFormerConfig()
        qformer_cfg.encoder_hidden_size = encoder_dim
        qformer_cfg.num_hidden_layers = num_layers
        # Query tokens are initialised before the Q-Former is built (RNG order matters).
        self.query = nn.Parameter(torch.zeros(1, query_len, qformer_cfg.hidden_size))
        nn.init.normal_(self.query, mean=0.0, std=1.0)
        self.qformer = Blip2QFormerModel(qformer_cfg)

    def forward(self, x, atts):
        # One copy of the learned queries per batch row (view, no data copy).
        batch_queries = self.query.expand(x.shape[0], -1, -1)
        result = self.qformer(
            query_embeds=batch_queries,
            encoder_hidden_states=x,
            encoder_attention_mask=atts,
            return_dict=True,
        )
        return result.last_hidden_state
# --- 3. Main model class ---
class CustomSLM(PreTrainedModel):
    """Audio-conditioned causal LM: Whisper encoder -> Q-Former -> MLP -> LLM.

    Audio features are encoded by a Whisper encoder and compressed by a Q-Former
    into a fixed number of query tokens. They are then projected to ``llm_dim``
    and spliced into the LLM input embeddings at the positions flagged by
    ``modality_mask``.

    Modes (read from ``config`` unless overridden by constructor kwargs):
      * ``encoder_only``: no LLM is built, so ``self.llm`` does not exist.
      * ``use_vllm``: the LLM shell is built, but its LM head and decoder blocks
        are dropped; only the input embeddings remain. Decoding then goes
        through a vLLM engine loaded from ``vllm_path``.
    """

    config_class = CustomSLMConfig
    _auto_class = "AutoModel"

    def __init__(self, config: CustomSLMConfig, **kwargs):
        # Pop the mode overrides *before* forwarding kwargs to the base class.
        vllm_path = kwargs.pop("vllm_path", getattr(config, "vllm_path", None))
        use_vllm = kwargs.pop("use_vllm", getattr(config, "use_vllm", False))
        encoder_only = kwargs.pop("encoder_only", getattr(config, "encoder_only", False))
        super().__init__(config, **kwargs)
        self.vllm_path = vllm_path
        self.use_vllm = use_vllm
        self.encoder_only = encoder_only

        # Audio encoder built from its config only (no pretrained weights loaded here).
        encoder_config = AutoConfig.from_pretrained(config.encoder_path)
        self.encoder = WhisperModel(encoder_config).encoder

        if not self.encoder_only:
            # LLM shell built from its config only (no pretrained weights loaded here).
            llm_config = AutoConfig.from_pretrained(config.llm_path)
            self.llm = AutoModelForCausalLM.from_config(
                llm_config, dtype=torch.bfloat16, attn_implementation="sdpa"
            )
            if self.use_vllm:
                print("⚠️ [Info] 检测到 use_vllm=True: 正在释放原生 LLM 解码器参数,仅保留 Embedding 层以防止 vLLM OOM!")
                self._strip_llm_decoder()

        self.q_former = QFormerModule(
            encoder_dim=config.encoder_dim,
            query_len=config.query_len,
            num_layers=config.qformer_layers,
        )
        # The projector input width is the default Q-Former hidden size.
        qformer_hidden_size = Blip2QFormerConfig().hidden_size
        self.mlp = MLPProjector(encoder_dim=qformer_hidden_size, llm_dim=config.llm_dim)

        self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_path)
        self.post_init()

    def _strip_llm_decoder(self):
        """Replace the LLM head and decoder stack with empty modules to free memory."""
        import gc

        if hasattr(self.llm, "lm_head"):
            self.llm.lm_head = nn.Identity()
        # Two decoder-stack layouts are handled: `model.layers`, and `transformer.h`
        # (kept for compatibility with older Qwen / ChatGLM-style models).
        if hasattr(self.llm, "model") and hasattr(self.llm.model, "layers"):
            self.llm.model.layers = nn.ModuleList()
        elif hasattr(self.llm, "transformer") and hasattr(self.llm.transformer, "h"):
            self.llm.transformer.h = nn.ModuleList()
        gc.collect()

    def get_input_embeddings(self):
        """Return the LLM input-embedding module, or None when no LLM was built."""
        if hasattr(self, "llm"):
            return self.llm.get_input_embeddings()
        return None

    def forward(
        self,
        audio_mel: torch.Tensor = None,
        audio_mel_post_mask: torch.Tensor = None,
        modality_mask: torch.Tensor = None,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inference_mode: bool = False,
        **kwargs,
    ):
        """Fuse audio and text embeddings and (unless ``inference_mode``) run the LLM.

        Args:
            audio_mel: audio features. They are permuted (0, 2, 1) before the encoder,
                so they are presumably (batch, frames, n_mels) — TODO confirm.
            audio_mel_post_mask: encoder-side attention mask passed to the Q-Former.
            modality_mask: boolean mask, assumed (batch, seq), marking where audio
                embeddings replace text embeddings.
            input_ids: text token ids. ``-1`` entries are looked up as id 0.
            attention_mask: forwarded to the LLM (or returned in inference mode).
            labels: LLM labels. Accuracy is computed only when given.
            inference_mode: return ``(inputs_embeds, attention_mask)`` instead of
                running the LLM.
            **kwargs: ignored (lets ``generate`` pass generation options through).

        Returns:
            ``(inputs_embeds, attention_mask)`` in inference mode. Otherwise
            ``(model_outputs, acc)``, where ``acc`` is None when ``labels`` is None.

        Raises:
            RuntimeError: ``use_vllm`` is set and ``inference_mode`` is False.
            ValueError: ``input_ids`` is None.
        """
        if self.use_vllm and not inference_mode:
            raise RuntimeError("模型初始化开启了 use_vllm,已释放深层参数,无法进行标准 forward 计算。")
        if input_ids is None:
            raise ValueError("input_ids is required: audio features are spliced into the text embeddings.")

        # 1. Audio features -> Q-Former queries -> LLM width.
        # NOTE(review): translate_encode feeds mels to the encoder without this
        # permute; the two entry points appear to expect different layouts.
        encoder_outs = self.encoder(audio_mel.permute(0, 2, 1)).last_hidden_state
        encoder_outs = self.q_former(encoder_outs, audio_mel_post_mask)
        encoder_outs = self.mlp(encoder_outs)

        # 2. Text embeddings. -1 placeholders are mapped to a valid id for the lookup.
        input_ids_cleaned = input_ids.clone()
        input_ids_cleaned[input_ids_cleaned == -1] = 0
        inputs_embeds = self.get_input_embeddings()(input_ids_cleaned)

        # 3. Write audio embeddings into each row's masked span, which starts at the
        #    first True position. Span length is capped at the number of audio tokens.
        if modality_mask is not None:
            # Convert to Python lists once to avoid a device sync per row.
            starts = (modality_mask == True).float().argmax(dim=1).tolist()  # noqa: E712
            lengths = torch.clamp(modality_mask.sum(dim=1), max=encoder_outs.shape[1]).tolist()
            encoder_outs_pad = torch.zeros_like(inputs_embeds)
            for i in range(encoder_outs.shape[0]):
                length = int(lengths[i])
                start = int(starts[i])
                encoder_outs_pad[i, start:start + length] = encoder_outs[i, :length]
            inputs_embeds = encoder_outs_pad + inputs_embeds * (~modality_mask[:, :, None])

        if inference_mode:
            return inputs_embeds, attention_mask

        # 4. LLM forward pass.
        model_outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )

        # 5. Shifted next-token accuracy (training/validation only).
        acc = None
        if labels is not None:
            with torch.no_grad():
                preds = torch.argmax(model_outputs.logits, dim=-1)
                acc = compute_accuracy(preds[:, :-1], labels[:, 1:], ignore_label=-100)
        return model_outputs, acc

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        """Build fused embeddings via ``forward(inference_mode=True)`` and decode with HF.

        All args/kwargs go to ``forward``. Generation options (``max_new_tokens``,
        ``num_beams``, ``do_sample``, ``min_length``, ``top_p``,
        ``repetition_penalty``, ``length_penalty``, ``temperature``,
        ``no_repeat_ngram_size``) are read from the same kwargs, with the defaults
        shown below.

        Raises:
            RuntimeError: the model was built with ``use_vllm``.
        """
        if getattr(self, "use_vllm", False):
            raise RuntimeError("当前为 use_vllm=True,LLM 解码器未加载,无法执行原生 generate。请调用 translate_batch_* 方法。")

        kwargs["inference_mode"] = True
        inputs_embeds, attention_mask = self.forward(*args, **kwargs)

        return self.llm.generate(
            inputs_embeds=inputs_embeds,
            max_new_tokens=kwargs.get("max_new_tokens", 400),
            num_beams=kwargs.get("num_beams", 1),
            do_sample=kwargs.get("do_sample", False),
            min_length=kwargs.get("min_length", 1),
            top_p=kwargs.get("top_p", 1.0),
            repetition_penalty=kwargs.get("repetition_penalty", 1.0),
            length_penalty=kwargs.get("length_penalty", 1.0),
            temperature=kwargs.get("temperature", 1.0),
            no_repeat_ngram_size=kwargs.get("no_repeat_ngram_size", 5),
            attention_mask=attention_mask,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )

    @torch.no_grad()
    def translate_encode(self, mels):
        """Run the audio encoder and Q-Former (no MLP projection) on ``mels``.

        ``mels`` is moved to the model device/dtype and fed to the encoder as is
        (no permute, unlike ``forward``). The result is meant for
        ``translate_batch_embeds``, which applies ``self.mlp``.
        """
        mels = mels.to(self.device).to(self.dtype)
        enc_out = self.encoder(mels).last_hidden_state
        return self.q_former(enc_out, None)

    def _synchronized_time(self):
        """Wait for queued work on ``self.device`` (mps/cuda/npu), then return perf_counter()."""
        device_type = self.device.type
        if device_type == "mps":
            torch.mps.synchronize()
        elif device_type == "cuda":
            torch.cuda.synchronize(self.device)
        elif device_type == "npu":
            # `torch.npu` is registered by the Ascend torch_npu extension;
            # torch.cuda.synchronize does not target NPU devices.
            torch.npu.synchronize()
        return time.perf_counter()

    def _tokenize_prompts_left_padded(self, prompts, device):
        """Tokenize ``prompts`` with left padding; return ``(input_ids, attention_mask)`` on device."""
        # Left padding keeps every row ending on a real token. Causal decoding
        # continues from the last position, so right padding would condition the
        # first generated token on a pad.
        previous_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        try:
            text_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True)
        finally:
            self.tokenizer.padding_side = previous_side
        input_ids = text_inputs["input_ids"].to(device)
        text_mask = text_inputs.get("attention_mask")
        if text_mask is None:
            text_mask = torch.ones_like(input_ids)
        return input_ids, text_mask.to(device)

    @torch.no_grad()
    def translate_batch_embeds(self, beam_search, embeds, prompts, max_new_tokens=300, use_vllm=False):
        """Decode text for a batch of Q-Former outputs followed by text prompts.

        Args:
            beam_search: number of beams for the HF path (ignored by vLLM, which
                decodes greedily with temperature 0).
            embeds: Q-Former outputs (e.g. from ``translate_encode``). They are
                projected by ``self.mlp`` here.
            prompts: text prompts placed after the audio embeddings, one per row.
            max_new_tokens: generation budget.
            use_vllm: force the vLLM path even if the model was not built with it.

        Returns:
            list[str]: decoded text, one entry per batch row.
        """
        use_vllm_flag = use_vllm or getattr(self, "use_vllm", False)
        if use_vllm_flag and getattr(self, "use_vllm", False) is False:
            print("⚠️ 警告: 您通过参数强制要求 use_vllm,但初始化未设置 use_vllm=True。")

        vllm_module = None
        if use_vllm_flag:
            # vLLM is optional; import it only when this path is taken.
            import importlib

            vllm_module = importlib.import_module("vllm")
            if getattr(self, "_vllm_llm", None) is None:
                # The engine is created once and cached on the instance.
                self._vllm_llm = vllm_module.LLM(
                    self.vllm_path,
                    dtype="bfloat16",
                    trust_remote_code=True,
                    tensor_parallel_size=1,
                    gpu_memory_utilization=0.7,
                    enable_prompt_embeds=True,
                    max_model_len=1024,
                )

        device = self.device
        dtype = self.dtype
        t0 = self._synchronized_time()

        adapter_embeds = self.mlp(embeds.to(device).to(dtype))
        text_ids, text_mask = self._tokenize_prompts_left_padded(prompts, device)
        text_embeds = self.get_input_embeddings()(text_ids)
        inputs_embeds = torch.cat((adapter_embeds, text_embeds), dim=1)
        # Audio positions are always valid; prompt pads are masked out.
        audio_mask = torch.ones(adapter_embeds.shape[:2], dtype=text_mask.dtype, device=device)
        attention_mask = torch.cat((audio_mask, text_mask), dim=1)

        t1 = self._synchronized_time()

        if use_vllm_flag:
            sampling_params = vllm_module.SamplingParams(
                max_tokens=max_new_tokens,
                temperature=0.0,
            )
            # vLLM takes one unpadded embedding sequence per prompt: drop pad rows.
            keep = attention_mask.bool()
            prompt_embeds_list = [
                {"prompt_embeds": row[row_keep]} for row, row_keep in zip(inputs_embeds, keep)
            ]
            outputs = self._vllm_llm.generate(
                prompt_embeds_list,
                sampling_params=sampling_params,
            )
            t2 = self._synchronized_time()
            print(f"⏱️ Embeds Prep: {t1-t0:.2f}s | LLM: {t2-t1:.2f}s | Total: {t2-t0:.2f}s")
            return [output.outputs[0].text for output in outputs]

        output_ids = self.llm.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=beam_search,
            no_repeat_ngram_size=5,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        t2 = self._synchronized_time()
        print(f"⏱️ Embeds Prep: {t1-t0:.2f}s | LLM: {t2-t1:.2f}s | Total: {t2-t0:.2f}s")
        return [self.tokenizer.decode(g, skip_special_tokens=True) for g in output_ids]


# --- 4. Registration and export ---
def register_and_save_model(model, tokenizer, args):
    """Register CustomSLM with the Auto classes, optionally merge LoRA, and save.

    Args:
        model: a ``CustomSLM``. If it has a truthy ``use_lora`` attribute,
            ``model.llm`` is replaced by ``model.llm.merge_and_unload()`` first.
        tokenizer: tokenizer saved into the same directory.
        args: object whose ``merge_model`` attribute is the output directory.
    """
    AutoConfig.register("custom_slm", CustomSLMConfig)
    AutoModel.register(CustomSLMConfig, CustomSLM)

    if getattr(model, "use_lora", False):
        print("Merging LoRA weights...")
        model.llm = model.llm.merge_and_unload()

    # auto_map entries are "module.Class"; the defining module is expected to be
    # shipped as srt_model.py alongside the checkpoint.
    model.config.auto_map = {
        "AutoConfig": "srt_model.CustomSLMConfig",
        "AutoModel": "srt_model.CustomSLM",
    }
    model.save_pretrained(args.merge_model)
    tokenizer.save_pretrained(args.merge_model)
    print(f"Model saved to {args.merge_model}")