from __future__ import annotations

from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast


class SpeakerProjector(nn.Module):
    """Projects a fixed-size speaker embedding into the LM's hidden space."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.0,
    ):
        super().__init__()

        if hidden_dim is None:
            # Single linear layer with a bounded Tanh output.
            self.net = nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.Tanh(),
            )
        else:
            # Two-layer MLP with GELU and dropout for extra capacity.
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, output_dim),
                nn.Tanh(),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

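
# A quick shape sanity check for SpeakerProjector (a sketch; the 256 / 2048
# dimensions are illustrative assumptions, not values used elsewhere here):
#
#   projector = SpeakerProjector(input_dim=256, output_dim=2048, hidden_dim=1024)
#   speaker_vecs = torch.randn(4, 256)    # [B, D_s]
#   projected = projector(speaker_vecs)   # [B, H], values in (-1, 1) from Tanh
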
class SpeakerConditionedCausalLMConfig(PretrainedConfig):
    """Config for SpeakerConditionedCausalLM.

    Args:
        base_model_name_or_path: HF id or path of the wrapped causal LM.
        speaker_embedding_dim: Dimensionality of the raw speaker embedding.
        speaker_hidden_dim: Optional hidden width of the projector MLP.
        speaker_dropout: Dropout probability inside the projector MLP.
        speaker_token_id: Token id of the speaker marker to replace.
        freeze_base_model: Whether to freeze the wrapped LM's parameters.
    """

    model_type = "speaker_conditioned_causal_lm"

    def __init__(
        self,
        base_model_name_or_path: Optional[str] = None,
        speaker_embedding_dim: int = 512,
        speaker_hidden_dim: Optional[int] = None,
        speaker_dropout: float = 0.0,
        speaker_token_id: Optional[int] = None,
        freeze_base_model: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.base_model_name_or_path = base_model_name_or_path
        self.speaker_embedding_dim = speaker_embedding_dim
        self.speaker_hidden_dim = speaker_hidden_dim
        self.speaker_dropout = speaker_dropout
        self.speaker_token_id = speaker_token_id
        self.freeze_base_model = freeze_base_model


class SpeakerConditionedCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper that replaces a marker token's embedding with a
    projected speaker embedding before running the wrapped base model."""

    config_class = SpeakerConditionedCausalLMConfig
    base_model_prefix = "model"
    main_input_name = "input_ids"
    _no_split_modules = []

    def __init__(
        self,
        config: SpeakerConditionedCausalLMConfig,
        base_lm: Optional[PreTrainedModel] = None,
    ):
        super().__init__(config)

        if config.base_model_name_or_path is None and base_lm is None:
            raise ValueError(
                "You must provide either config.base_model_name_or_path or a preloaded base_lm."
            )

        if base_lm is None:
            self.model = AutoModelForCausalLM.from_pretrained(
                config.base_model_name_or_path
            )
        else:
            self.model = base_lm

        hidden_size = self.model.config.hidden_size

        self.speaker_projector = SpeakerProjector(
            input_dim=config.speaker_embedding_dim,
            output_dim=hidden_size,
            hidden_dim=config.speaker_hidden_dim,
            dropout=config.speaker_dropout,
        )

        # Mirror generation-relevant fields from the base config onto the
        # wrapper config so .generate() and serialization see the right values.
        for attr in [
            "pad_token_id",
            "bos_token_id",
            "eos_token_id",
            "vocab_size",
            "tie_word_embeddings",
        ]:
            if hasattr(self.model.config, attr):
                setattr(self.config, attr, getattr(self.model.config, attr))

        if config.freeze_base_model:
            self.freeze_base_model()

        # Run the standard HF post-init (weight init hooks, tying, etc.).
        self.post_init()

    def freeze_base_model(self) -> None:
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_base_model(self) -> None:
        for param in self.model.parameters():
            param.requires_grad = True

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.model.set_output_embeddings(new_embeddings)

    def resize_token_embeddings(
        self,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ):
        return self.model.resize_token_embeddings(
            new_num_tokens=new_num_tokens,
            pad_to_multiple_of=pad_to_multiple_of,
        )

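    # Typical setup for the speaker marker token (a sketch; the tokenizer
    # variable, the "<|SPEAKER_TOKEN_POS|>" literal from the docstrings below,
    # and the wrapped-model variable name are assumptions):
    #
    #   tokenizer.add_special_tokens(
    #       {"additional_special_tokens": ["<|SPEAKER_TOKEN_POS|>"]}
    #   )
    #   model.resize_token_embeddings(len(tokenizer))
    #   model.config.speaker_token_id = tokenizer.convert_tokens_to_ids(
    #       "<|SPEAKER_TOKEN_POS|>"
    #   )
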
    def _inject_speaker_embeddings(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.FloatTensor,
        speaker_embedding: Optional[torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Replace the embedding at speaker_token_id with the projected
        speaker_embedding.

        Shapes:
            input_ids:         [B, T]
            inputs_embeds:     [B, T, H]
            speaker_embedding: [B, D_s]
        """
        if speaker_embedding is None:
            return inputs_embeds

        if self.config.speaker_token_id is None:
            raise ValueError("config.speaker_token_id must be set.")

        # Locate the speaker marker token in each sequence.
        speaker_mask = input_ids.eq(self.config.speaker_token_id)

        if not speaker_mask.any():
            # No marker in this chunk (e.g. a cached decoding step): nothing to inject.
            return inputs_embeds

        # Project the speaker embedding into the LM hidden space and match
        # the dtype/device of the token embeddings.
        projected = self.speaker_projector(speaker_embedding)
        projected = projected.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        # Avoid mutating the tensor produced by the embedding layer in place.
        inputs_embeds = inputs_embeds.clone()

        # Enforce exactly one marker per sequence so the scatter below is unambiguous.
        num_markers_per_seq = speaker_mask.sum(dim=1)
        if not torch.all(num_markers_per_seq == 1):
            raise ValueError(
                "Each sequence must contain exactly one speaker token marker. "
                f"Got counts: {num_markers_per_seq.tolist()}"
            )

        batch_idx, time_idx = torch.where(speaker_mask)
        inputs_embeds[batch_idx, time_idx] = projected[batch_idx]

        return inputs_embeds

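    # Injection at a glance (a sketch with made-up ids; assume the marker id
    # is 7 and appears once per row of input_ids):
    #
    #   input_ids    = [[5, 7, 9], [7, 2, 3]]    # marker at t=1 and t=0
    #   speaker_mask = [[F, T, F], [T, F, F]]
    #
    # inputs_embeds[0, 1] and inputs_embeds[1, 0] are overwritten with
    # projected[0] and projected[1]; every other position keeps its ordinary
    # token embedding.
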
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, ...], ...]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        speaker_embedding: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Standard causal-LM forward pass, except that the projected
        speaker_embedding is injected at the speaker marker position
        (e.g. <|SPEAKER_TOKEN_POS|>).

        Notes:
            - input_ids is required whenever speaker injection is needed,
              because the marker is located via input_ids.
            - During cached generation, later decoding steps may no longer
              include the speaker token; that is fine because the prompt
              cache already contains the injected information from the
              first step.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is not None and input_ids is None:
            raise ValueError(
                "This wrapper currently requires input_ids whenever inputs_embeds is passed, "
                "because speaker injection is located via input_ids."
            )

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("You must provide input_ids or inputs_embeds.")
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if input_ids is not None:
            inputs_embeds = self._inject_speaker_embeddings(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                speaker_embedding=speaker_embedding,
            )

        # Delegate to the wrapped LM with embeddings instead of ids so the
        # injected positions are used; the base model computes the loss from
        # labels if they are provided.
        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

        return outputs

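    # A minimal training-style call (a sketch; the tokenized batch and the
    # spk_vecs produced by some speaker encoder are assumptions):
    #
    #   out = model(
    #       input_ids=batch["input_ids"],           # one marker per row
    #       attention_mask=batch["attention_mask"],
    #       labels=batch["labels"],                 # base LM computes the loss
    #       speaker_embedding=spk_vecs,             # [B, speaker_embedding_dim]
    #   )
    #   out.loss.backward()
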
    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, ...], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        speaker_embedding: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Called internally by .generate().

        speaker_embedding must be preserved across generation steps. On the
        first step the full prompt is passed and the speaker marker exists;
        on later steps with a cache, usually only the newest token is passed.
        """
        if past_key_values is not None:
            # With a populated cache, only the newest token needs a forward pass.
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "speaker_embedding": speaker_embedding,
        }

        return model_inputs

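    # Generation with a speaker vector (a sketch; prompt_ids containing the
    # marker token and spk_vecs are assumptions). Extra kwargs passed to
    # .generate() are forwarded to prepare_inputs_for_generation on every
    # decoding step:
    #
    #   generated = model.generate(
    #       input_ids=prompt_ids,          # marker token included in the prompt
    #       speaker_embedding=spk_vecs,    # preserved across steps
    #       max_new_tokens=64,
    #   )
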
    @classmethod
    def from_pretrained_base(
        cls,
        base_model_name_or_path: str,
        speaker_embedding_dim: int,
        speaker_token_id: int,
        speaker_hidden_dim: Optional[int] = None,
        speaker_dropout: float = 0.0,
        freeze_base_model: bool = True,
        **kwargs,
    ) -> "SpeakerConditionedCausalLM":
        """
        Convenience constructor for a fresh wrapper around a pretrained LM.

        Example:
            model = SpeakerConditionedCausalLM.from_pretrained_base(
                "meta-llama/Llama-3.2-1B-Instruct",
                speaker_embedding_dim=256,
                speaker_token_id=tokenizer.convert_tokens_to_ids("<|SPEAKER_TOKEN_POS|>"),
            )
        """
        config = SpeakerConditionedCausalLMConfig(
            base_model_name_or_path=base_model_name_or_path,
            speaker_embedding_dim=speaker_embedding_dim,
            speaker_hidden_dim=speaker_hidden_dim,
            speaker_dropout=speaker_dropout,
            speaker_token_id=speaker_token_id,
            freeze_base_model=freeze_base_model,
        )
        return cls(config=config, **kwargs)

    def save_pretrained(self, save_directory: str, **kwargs):
        """
        Saves the wrapper config and the full wrapper state dict, which
        includes the projector plus the underlying model weights as currently
        attached.

        This works even when the base model is frozen and unchanged, but the
        checkpoint then contains the full base weights. For smaller
        checkpoints, save only the projector instead, e.g.
        torch.save(self.speaker_projector.state_dict(), ...).
        """
        return super().save_pretrained(save_directory, **kwargs)


# Register the wrapper with the Auto* factories so it round-trips through
# AutoConfig / AutoModelForCausalLM.from_pretrained.
AutoConfig.register(SpeakerConditionedCausalLMConfig.model_type, SpeakerConditionedCausalLMConfig)
AutoModelForCausalLM.register(SpeakerConditionedCausalLMConfig, SpeakerConditionedCausalLM)
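

if __name__ == "__main__":
    # End-to-end smoke test (a sketch: "gpt2" and the "<|SPEAKER_TOKEN_POS|>"
    # marker are illustrative choices, and the random speaker vectors stand in
    # for a real speaker encoder).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<|SPEAKER_TOKEN_POS|>"]}
    )

    model = SpeakerConditionedCausalLM.from_pretrained_base(
        "gpt2",
        speaker_embedding_dim=256,
        speaker_token_id=tokenizer.convert_tokens_to_ids("<|SPEAKER_TOKEN_POS|>"),
    )
    # Make room in the embedding matrix for the newly added marker token.
    model.resize_token_embeddings(len(tokenizer))

    prompt = "<|SPEAKER_TOKEN_POS|> Hello, how are"
    inputs = tokenizer(prompt, return_tensors="pt")
    speaker_vecs = torch.randn(1, 256)  # stand-in for a speaker encoder output

    out = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        speaker_embedding=speaker_vecs,
    )
    print(out.logits.shape)  # [1, T, vocab_size]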