# Ad_gen / cascade / padding_orthogonalization.py
# Page header residue: user Flulike99, last commit "cascade" (6df1cf2)
"""
Enhanced Text-to-Image Generation via Padding Token Orthogonalization
This module implements the padding token orthogonalization method described in the poster
"Enhanced Text-to-Image Generation via Padding Token Orthogonalization" by Jiafeng Mao,
Qianru Qiu, Xueting Wang from CyberAgent AI Lab.
The core idea is to use padding tokens as registers that collect, store, and redistribute
features across layers via attention pathways through Gram-Schmidt orthogonalization.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging
logger = logging.getLogger(__name__)
def orthogonalize_rows(X: torch.Tensor) -> torch.Tensor:
    """
    Orthogonalize the rows of matrix X using a (batched) QR decomposition.

    This is the core function from the poster: Q, _ = torch.linalg.qr(X.T); return Q.T
    Every output row has unit norm and is orthogonal to the other rows of the same
    matrix. For full-rank X, the first k output rows span the same subspace as the
    first k input rows. QR does not fix the sign of Q's columns, so an output row may
    point opposite to the Gram-Schmidt result for the same input.

    Args:
        X: Input tensor of shape (..., n_rows, n_cols) with n_rows <= n_cols.
           Any leading dimensions are treated as a batch.
    Returns:
        Orthogonalized tensor with the same shape and dtype as X.
    Raises:
        ValueError: if X has fewer than 2 dimensions, or if n_rows > n_cols. More rows
            than dimensions cannot all be mutually orthogonal, and reduced QR would
            return a differently shaped result.
    """
    if X.dim() < 2:
        raise ValueError(
            f"orthogonalize_rows expects a tensor with at least 2 dims, got shape {tuple(X.shape)}"
        )
    n_rows, n_cols = X.shape[-2], X.shape[-1]
    if n_rows > n_cols:
        raise ValueError(
            f"cannot orthogonalize {n_rows} rows in a {n_cols}-dimensional space "
            f"(input shape {tuple(X.shape)})"
        )

    original_dtype = X.dtype
    # torch.linalg.qr has no half-precision kernels; compute in float32 and cast back.
    if original_dtype in (torch.float16, torch.bfloat16):
        X = X.to(torch.float32)

    # torch.linalg.qr is batched over all leading dims, so no manual loop or reshape is
    # needed. This also avoids `view`, which fails on non-contiguous inputs.
    Q, _ = torch.linalg.qr(X.transpose(-2, -1))
    return Q.transpose(-2, -1).to(original_dtype)
class PaddingTokenOrthogonalizer(nn.Module):
    """
    A module that applies padding token orthogonalization to text embeddings.

    Based on the poster's method, this enhances text-image alignment by:
    1. Identifying padding tokens in the sequence
    2. Orthogonalizing their representations using QR decomposition
    3. Maintaining feature diversity and preventing biased attention

    The module holds no learnable parameters; its behaviour is controlled entirely by
    the constructor flags.
    """

    def __init__(
        self,
        enabled: bool = True,
        preserve_norm: bool = True,
        orthogonalize_all: bool = False,
    ):
        """
        Args:
            enabled: Whether to apply orthogonalization. If False, forward returns its
                input unchanged (the very same tensor object).
            preserve_norm: Whether to rescale each orthogonalized token back to the norm
                it had before orthogonalization. Tokens that were not orthogonalized are
                never rescaled.
            orthogonalize_all: If True, orthogonalize all tokens; if False, only padding tokens.
        """
        super().__init__()
        self.enabled = enabled
        self.preserve_norm = preserve_norm
        self.orthogonalize_all = orthogonalize_all

    def identify_padding_tokens(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Identify padding token positions in the sequence.

        Sources are tried in this order:
        1. attention_mask: zero entries are padding, nonzero entries are real tokens.
        2. input_ids == pad_token_id, when both are given.
        3. A heuristic that marks the last 25% of positions as padding. A warning is
           logged because this guess ignores the actual prompt length.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size]
            attention_mask: Attention mask where 0 indicates padding
            pad_token_id: ID of the padding token
            input_ids: Input token IDs
        Returns:
            Boolean mask indicating padding positions [batch, seq_len]
        """
        if attention_mask is not None:
            return ~attention_mask.bool()
        if pad_token_id is not None and input_ids is not None:
            return input_ids == pad_token_id

        batch_size, seq_len = embeddings.shape[:2]
        logger.warning(
            "No attention_mask or (pad_token_id, input_ids) provided; assuming the last "
            "25%% of the %d sequence positions are padding tokens.",
            seq_len,
        )
        padding_start = int(seq_len * 0.75)
        mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=embeddings.device)
        mask[:, padding_start:] = True
        return mask

    @staticmethod
    def _rescale_to_original_norms(
        original: torch.Tensor,
        updated: torch.Tensor,
        rows: torch.Tensor,
    ) -> torch.Tensor:
        """
        Rescale the rows of `updated` selected by the boolean mask `rows` to the norms
        the same rows have in `original`. Unselected rows are returned bit-identical.
        """
        # Work in at least float32: in fp16 the 1e-8 floor below would round to 0.
        work_dtype = torch.promote_types(original.dtype, torch.float32)
        original_norms = torch.linalg.vector_norm(original.to(work_dtype), dim=-1, keepdim=True)
        current_norms = torch.linalg.vector_norm(updated.to(work_dtype), dim=-1, keepdim=True)
        scale = original_norms / current_norms.clamp_min(1e-8)
        # A scale of exactly 1 leaves untouched rows unchanged.
        scale = torch.where(rows.unsqueeze(-1), scale, torch.ones_like(scale))
        return (updated.to(work_dtype) * scale).to(original.dtype)

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply padding token orthogonalization.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size]
            attention_mask: Attention mask where 0 indicates padding (nonzero = real token)
            pad_token_id: ID of the padding token
            input_ids: Input token IDs
        Returns:
            Embeddings of the same shape. Only the orthogonalized rows differ from the input.
        Raises:
            ValueError: (from orthogonalize_rows) if more rows must be orthogonalized
                than there are hidden dimensions.
        """
        if not self.enabled:
            return embeddings

        if self.orthogonalize_all:
            # Orthogonalize all tokens in each sequence
            enhanced_embeddings = orthogonalize_rows(embeddings)
            modified = torch.ones(
                embeddings.shape[:-1], dtype=torch.bool, device=embeddings.device
            )
        else:
            padding_mask = self.identify_padding_tokens(
                embeddings, attention_mask, pad_token_id, input_ids
            ).to(embeddings.device)
            enhanced_embeddings = embeddings.clone()
            modified = torch.zeros(
                embeddings.shape[:-1], dtype=torch.bool, device=embeddings.device
            )
            # The number of padding tokens varies per sample, so process samples one by one.
            for batch_idx in range(embeddings.shape[0]):
                padding_indices = torch.where(padding_mask[batch_idx])[0]
                if padding_indices.numel() < 2:
                    # Need at least 2 tokens for orthogonalization to have any effect.
                    continue
                enhanced_embeddings[batch_idx, padding_indices] = orthogonalize_rows(
                    embeddings[batch_idx, padding_indices]
                )
                modified[batch_idx, padding_indices] = True

        if self.preserve_norm:
            enhanced_embeddings = self._rescale_to_original_norms(
                embeddings, enhanced_embeddings, modified
            )
        return enhanced_embeddings
def apply_padding_token_orthogonalization(
    prompt_embeds: torch.Tensor,
    text_attention_mask: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
) -> torch.Tensor:
    """
    Build a PaddingTokenOrthogonalizer from `config` and run it on `prompt_embeds`.

    Recognised config keys (default used when the key is absent):
        'padding_orthogonalization_enabled' -> enabled           (True)
        'preserve_norm'                     -> preserve_norm     (True)
        'orthogonalize_all_tokens'          -> orthogonalize_all (False)

    Args:
        prompt_embeds: Text prompt embeddings [batch, seq_len, hidden_size]
        text_attention_mask: Optional mask, forwarded as `attention_mask`
            (0 = padding). Token ids are not forwarded, so when this is None the
            orthogonalizer falls back to its positional padding heuristic.
        config: Optional settings dictionary; None behaves like an empty dict.
    Returns:
        The orthogonalizer's output for `prompt_embeds`.
    """
    settings = {} if config is None else config
    module_kwargs = {
        'enabled': settings.get('padding_orthogonalization_enabled', True),
        'preserve_norm': settings.get('preserve_norm', True),
        'orthogonalize_all': settings.get('orthogonalize_all_tokens', False),
    }
    orthogonalizer = PaddingTokenOrthogonalizer(**module_kwargs)
    return orthogonalizer(embeddings=prompt_embeds, attention_mask=text_attention_mask)
# Gram-Schmidt orthogonalization alternative implementation
def gram_schmidt_orthogonalization(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Alternative implementation using an explicit (modified) Gram-Schmidt process.

    This provides more control but is generally slower than QR decomposition, because
    it loops in Python over every pair of rows. Unlike the QR route, no sign flips
    occur: each output row points along the residual of its own input row after the
    projections onto earlier rows have been removed.

    Args:
        vectors: Input vectors to orthogonalize, one per row [n_vectors, dim]
        eps: Small epsilon for numerical stability. A residual whose norm is <= eps is
            treated as linearly dependent and kept un-normalized (near zero).
    Returns:
        Tensor with the same shape and dtype as `vectors`. Its non-degenerate rows are
        unit-norm and mutually orthogonal.
    """
    basis = []
    for vector in vectors:
        residual = vector
        # Project the *updated* residual at each step (modified Gram-Schmidt).
        for q in basis:
            projection = torch.dot(residual, q) / (torch.dot(q, q) + eps)
            residual = residual - projection * q
        norm = torch.norm(residual)
        if norm > eps:
            basis.append(residual / norm)
        else:
            # Handle zero / linearly dependent vector case
            basis.append(residual)

    if not basis:
        return torch.zeros_like(vectors)
    # Stack instead of writing rows into a preallocated tensor in place. In-place row
    # writes modify tensors that torch.dot saved for backward, which makes autograd raise.
    return torch.stack(basis, dim=0)