# src/model/style_conditioner.py
"""
Injects the style vector into the model via soft prompt conditioning.
The style vector is projected to the model's hidden dimension and
prepended to the input token embeddings as virtual tokens.
This technique is called "prefix tuning" / "style prefix injection".
It biases the model's attention toward the desired output style
without modifying the base model weights.
For Flan-T5: injects into encoder input embeddings
For BART: injects into encoder input embeddings
For Llama: prepends to the full input context
"""
import torch
import torch.nn as nn
class StyleConditioner(nn.Module):
    """
    Maps a fixed-size style vector onto a short sequence of virtual
    prefix tokens living in the target model's embedding space.

    A single linear layer expands the style vector to
    ``n_prefix_tokens * model_hidden_dim`` values, a Tanh squashes them
    into (-1, 1), and the result is reshaped into per-token embeddings.
    """

    def __init__(
        self,
        style_dim: int = 512,
        model_hidden_dim: int = 512,  # T5-Small=512, Base=768, Large=1024, XL=2048
        n_prefix_tokens: int = 10,  # Number of virtual prefix tokens
    ):
        super().__init__()
        self.style_dim = style_dim
        self.model_hidden_dim = model_hidden_dim
        self.n_prefix_tokens = n_prefix_tokens
        # One flat projection producing all prefix-token embeddings at once:
        # style_dim → n_prefix_tokens * model_hidden_dim
        flat_width = n_prefix_tokens * model_hidden_dim
        self.projection = nn.Sequential(
            nn.Linear(style_dim, flat_width),
            nn.Tanh(),
        )

    def forward(self, style_vector: torch.Tensor) -> torch.Tensor:
        """
        Args:
            style_vector: [batch_size, style_dim] style embedding.
        Returns:
            prefix_embeddings: [batch_size, n_prefix_tokens, model_hidden_dim]
        """
        # [batch, style_dim] → [batch, n_prefix * hidden] → [batch, n_prefix, hidden]
        flat = self.projection(style_vector)
        return flat.view(
            style_vector.size(0), self.n_prefix_tokens, self.model_hidden_dim
        )
def prepend_style_prefix(
    input_embeddings: torch.Tensor,
    style_prefix: torch.Tensor,
) -> torch.Tensor:
    """
    Prepend the style prefix tokens to a batch of input embeddings.

    Args:
        input_embeddings: [batch, seq_len, hidden_dim] token embeddings.
        style_prefix: [batch, n_prefix, hidden_dim] virtual style tokens.
    Returns:
        [batch, n_prefix + seq_len, hidden_dim] — prefix first, then input.
    """
    combined = torch.cat((style_prefix, input_embeddings), dim=1)
    return combined