""" Injects the style vector into the model via soft prompt conditioning. The style vector is projected to the model's hidden dimension and prepended to the input token embeddings as virtual tokens. This technique is called "prefix tuning" / "style prefix injection". It biases the model's attention toward the desired output style without modifying the base model weights. For Flan-T5: injects into encoder input embeddings For BART: injects into encoder input embeddings For Llama: prepends to the full input context """ import torch import torch.nn as nn class StyleConditioner(nn.Module): """ Projects a 512-dim style vector to n_prefix_tokens virtual tokens in the model's embedding space. """ def __init__( self, style_dim: int = 512, model_hidden_dim: int = 512, # T5-Small=512, Base=768, Large=1024, XL=2048 n_prefix_tokens: int = 10, # Number of virtual prefix tokens ): super().__init__() self.style_dim = style_dim self.model_hidden_dim = model_hidden_dim self.n_prefix_tokens = n_prefix_tokens # Project style vector to prefix embeddings # style_dim → n_prefix_tokens * model_hidden_dim total_output_dim = n_prefix_tokens * model_hidden_dim self.projection = nn.Sequential( nn.Linear(style_dim, total_output_dim), nn.Tanh(), ) def forward(self, style_vector: torch.Tensor) -> torch.Tensor: """ Args: style_vector: [batch_size, 512] Returns: prefix_embeddings: [batch_size, n_prefix_tokens, model_hidden_dim] """ # Project: [batch, 512] → [batch, n_prefix * hidden_dim] projected = self.projection(style_vector) # Reshape: [batch, n_prefix * hidden_dim] → [batch, n_prefix, hidden_dim] batch_size = style_vector.size(0) prefix_embeddings = projected.view(batch_size, self.n_prefix_tokens, self.model_hidden_dim) return prefix_embeddings def prepend_style_prefix( input_embeddings: torch.Tensor, style_prefix: torch.Tensor, ) -> torch.Tensor: """ Concatenates style prefix to input embeddings along sequence dimension. Args: input_embeddings: [batch, seq_len, hidden_dim] style_prefix: [batch, n_prefix, hidden_dim] Returns: [batch, n_prefix + seq_len, hidden_dim] """ return torch.cat([style_prefix, input_embeddings], dim=1)