"""
Enhanced Text-to-Image Generation via Padding Token Orthogonalization.

This module implements the padding token orthogonalization method described in
the poster "Enhanced Text-to-Image Generation via Padding Token
Orthogonalization" by Jiafeng Mao, Qianru Qiu, and Xueting Wang (CyberAgent AI
Lab).

The core idea is to use padding tokens as registers that collect, store, and
redistribute features across layers via attention pathways, after making their
representations mutually orthogonal (Gram-Schmidt / QR orthogonalization).
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging

logger = logging.getLogger(__name__)

# torch.linalg.qr does not support these reduced-precision dtypes, so inputs in
# them are upcast to float32 for the decomposition and cast back afterwards.
_QR_UPCAST_DTYPES = (torch.float16, torch.bfloat16)


def orthogonalize_rows(X: torch.Tensor) -> torch.Tensor:
    """
    Orthogonalize the rows of X using a QR decomposition.

    This is the core function from the poster:

        Q, _ = torch.linalg.qr(X.T)
        return Q.T

    generalized to any number of leading batch dimensions (torch.linalg.qr
    batches over them natively). The returned rows are orthonormal. For
    full-rank input they match the Gram-Schmidt result up to a per-row sign.

    Args:
        X: Input tensor of shape (..., n_rows, n_cols) with n_rows <= n_cols.

    Returns:
        Tensor of the same shape and dtype as X whose last-two-dim slices have
        orthonormal rows.

    Raises:
        ValueError: If X has fewer than 2 dims, or n_rows > n_cols (at most
            n_cols mutually orthogonal vectors exist in n_cols dimensions, so
            the output could not keep the input's shape).
    """
    if X.dim() < 2:
        raise ValueError(
            f"orthogonalize_rows expects at least 2 dims, got shape {tuple(X.shape)}"
        )
    n_rows, n_cols = X.shape[-2], X.shape[-1]
    if n_rows > n_cols:
        raise ValueError(
            f"cannot orthogonalize {n_rows} rows in a {n_cols}-dimensional space; "
            "n_rows must be <= n_cols"
        )

    original_dtype = X.dtype
    work = X.to(torch.float32) if original_dtype in _QR_UPCAST_DTYPES else X

    # Reduced QR of X^T (shape (..., n_cols, n_rows)) gives Q with orthonormal
    # columns spanning X's rows; transposing back turns them into rows.
    # transpose() needs no reshape, so non-contiguous inputs are fine.
    q, _ = torch.linalg.qr(work.transpose(-2, -1))
    result = q.transpose(-2, -1).contiguous()

    return result.to(original_dtype)


class PaddingTokenOrthogonalizer(nn.Module):
    """
    Apply padding token orthogonalization to text embeddings.

    Based on the poster's method, this aims to improve text-image alignment by:
      1. Identifying padding tokens in the sequence.
      2. Orthogonalizing their representations using QR decomposition.
      3. Maintaining feature diversity and preventing biased attention.
    """

    def __init__(
        self,
        enabled: bool = True,
        preserve_norm: bool = True,
        orthogonalize_all: bool = False,
    ):
        """
        Args:
            enabled: Whether to apply orthogonalization at all. When False,
                forward() returns its input unchanged.
            preserve_norm: Rescale each orthogonalized token back to its
                original L2 norm (QR output rows are unit-norm).
            orthogonalize_all: If True, orthogonalize all tokens; if False,
                only padding tokens.
        """
        super().__init__()
        self.enabled = enabled
        self.preserve_norm = preserve_norm
        self.orthogonalize_all = orthogonalize_all

    def identify_padding_tokens(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Identify padding token positions in the sequence.

        Sources are tried in order: attention_mask, then
        (pad_token_id, input_ids), then a positional heuristic.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size].
            attention_mask: Mask where 1 marks real tokens and 0 marks padding.
            pad_token_id: ID of the padding token (used with input_ids).
            input_ids: Input token IDs [batch, seq_len].

        Returns:
            Boolean mask [batch, seq_len], True at padding positions.
        """
        if attention_mask is not None:
            return ~attention_mask.bool()
        if pad_token_id is not None and input_ids is not None:
            return input_ids == pad_token_id

        # Fallback heuristic: treat the last 25% of the sequence as padding.
        # This is a guess, so make it visible rather than silent.
        batch_size, seq_len = embeddings.shape[:2]
        padding_start = int(seq_len * 0.75)
        logger.warning(
            "No attention_mask or (pad_token_id, input_ids) given; assuming "
            "positions %d..%d are padding tokens.",
            padding_start,
            seq_len - 1,
        )
        mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=embeddings.device)
        mask[:, padding_start:] = True
        return mask

    @staticmethod
    def _restore_norms(
        new: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        """Rescale each row of `new` to the L2 norm of the matching row of `reference`."""
        reference_norms = torch.norm(reference, dim=-1, keepdim=True)
        new_norms = torch.norm(new, dim=-1, keepdim=True)
        return new * (reference_norms / (new_norms + eps))

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply padding token orthogonalization.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size].
            attention_mask: Mask where 1 marks real tokens and 0 marks padding.
            pad_token_id: ID of the padding token.
            input_ids: Input token IDs.

        Returns:
            Embeddings with orthogonalized padding tokens (or all tokens when
            orthogonalize_all=True). Tokens that are not orthogonalized are
            returned unchanged.

        Raises:
            ValueError: From orthogonalize_rows if more tokens are to be
                orthogonalized in one sample than hidden_size.
        """
        if not self.enabled:
            return embeddings

        if self.orthogonalize_all:
            orthogonalized = orthogonalize_rows(embeddings)
            if self.preserve_norm:
                orthogonalized = self._restore_norms(orthogonalized, embeddings)
            return orthogonalized

        padding_mask = self.identify_padding_tokens(
            embeddings, attention_mask, pad_token_id, input_ids
        )
        enhanced_embeddings = embeddings.clone()

        # Padding counts vary per sample, so process the batch one row at a time.
        for batch_idx in range(embeddings.shape[0]):
            padding_indices = torch.where(padding_mask[batch_idx])[0]
            if padding_indices.numel() < 2:
                # Need at least 2 tokens for orthogonalization to mean anything.
                continue

            padding_embeddings = embeddings[batch_idx, padding_indices]
            orthogonalized = orthogonalize_rows(padding_embeddings)
            if self.preserve_norm:
                # Rescale only the modified rows so real tokens stay bit-exact.
                orthogonalized = self._restore_norms(orthogonalized, padding_embeddings)
            enhanced_embeddings[batch_idx, padding_indices] = orthogonalized

        return enhanced_embeddings


def apply_padding_token_orthogonalization(
    prompt_embeds: torch.Tensor,
    text_attention_mask: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
    input_ids: Optional[torch.Tensor] = None,
    pad_token_id: Optional[int] = None,
) -> torch.Tensor:
    """
    Apply padding token orthogonalization to prompt embeddings.

    Args:
        prompt_embeds: Text prompt embeddings [batch, seq_len, hidden_size].
        text_attention_mask: Attention mask for text tokens (1 = real, 0 = pad).
        config: Optional settings. Recognized keys are
            'padding_orthogonalization_enabled' (default True),
            'preserve_norm' (default True), and
            'orthogonalize_all_tokens' (default False).
        input_ids: Optional token IDs, used with pad_token_id when no
            attention mask is given.
        pad_token_id: Optional padding token ID.

    Returns:
        Enhanced prompt embeddings.
    """
    if config is None:
        config = {}

    orthogonalizer = PaddingTokenOrthogonalizer(
        enabled=config.get('padding_orthogonalization_enabled', True),
        preserve_norm=config.get('preserve_norm', True),
        orthogonalize_all=config.get('orthogonalize_all_tokens', False),
    )
    return orthogonalizer(
        embeddings=prompt_embeds,
        attention_mask=text_attention_mask,
        pad_token_id=pad_token_id,
        input_ids=input_ids,
    )


# Gram-Schmidt orthogonalization alternative implementation
def gram_schmidt_orthogonalization(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Orthogonalize vectors with an explicit Gram-Schmidt process.

    Each vector has its projections onto the previously produced vectors
    subtracted one at a time from the running residual, and is then
    normalized. This gives more control than QR but is generally slower. For
    full-rank input the result matches orthogonalize_rows up to a per-row sign.

    Args:
        vectors: Input vectors to orthogonalize [n_vectors, dim].
        eps: Small epsilon for numerical stability. Residuals with norm <= eps
            are treated as (near-)zero and are stored without normalization.

    Returns:
        Tensor of the same shape as `vectors` with orthogonalized rows.
    """
    n_vectors = vectors.shape[0]
    orthogonal_vectors = torch.zeros_like(vectors)

    for i in range(n_vectors):
        vector = vectors[i].clone()

        # Subtract projections onto previous orthogonal vectors.
        for j in range(i):
            projection = torch.dot(vector, orthogonal_vectors[j]) / (
                torch.dot(orthogonal_vectors[j], orthogonal_vectors[j]) + eps
            )
            vector = vector - projection * orthogonal_vectors[j]

        norm = torch.norm(vector)
        if norm > eps:
            orthogonal_vectors[i] = vector / norm
        else:
            # Linearly dependent input: keep the tiny residual as-is.
            orthogonal_vectors[i] = vector

    return orthogonal_vectors