| """ | |
| Enhanced Text-to-Image Generation via Padding Token Orthogonalization | |
| This module implements the padding token orthogonalization method described in the poster | |
| "Enhanced Text-to-Image Generation via Padding Token Orthogonalization" by Jiafeng Mao, | |
| Qianru Qiu, Xueting Wang from CyberAgent AI Lab. | |
| The core idea is to use padding tokens as registers that collect, store, and redistribute | |
| features across layers via attention pathways through Gram-Schmidt orthogonalization. | |
| """ | |
import logging
from typing import Optional, Tuple

import torch
import torch.nn as nn
# Module-level logger, named after the module per the standard logging convention.
logger = logging.getLogger(__name__)
def orthogonalize_rows(X: torch.Tensor) -> torch.Tensor:
    """
    Orthogonalize the rows of ``X`` via a reduced QR decomposition.

    This is the poster's core recipe — ``Q, _ = torch.linalg.qr(X.T); return Q.T`` —
    applied independently over any leading batch dimensions.

    Args:
        X: Input tensor of shape (..., n_rows, n_cols). The rows are the vectors
           to orthogonalize; n_rows <= n_cols is assumed so the output keeps the
           input shape (true for token embeddings where hidden_size >= seq_len).

    Returns:
        Tensor of the same shape and dtype as ``X`` whose rows are orthonormal
        (independently for each batch element).
    """
    original_dtype = X.dtype

    # torch.linalg.qr does not support bfloat16; compute in float32 and cast back.
    if original_dtype == torch.bfloat16:
        X = X.to(torch.float32)

    # torch.linalg.qr operates on the last two dimensions and batches over all
    # leading ones natively, so no Python-level loop over the batch is needed.
    # ``mT`` transposes the last two dims (identical to ``.T`` for 2-D input).
    Q, _ = torch.linalg.qr(X.mT)
    result = Q.mT

    # Restore the caller's dtype.
    if original_dtype == torch.bfloat16:
        result = result.to(original_dtype)
    return result
class PaddingTokenOrthogonalizer(nn.Module):
    """
    A module that applies padding token orthogonalization to text embeddings.

    Based on the poster's method, this enhances text-image alignment by:
    1. Identifying padding tokens in the sequence
    2. Orthogonalizing their representations using QR decomposition
    3. Maintaining feature diversity and preventing biased attention
    """

    def __init__(
        self,
        enabled: bool = True,
        preserve_norm: bool = True,
        orthogonalize_all: bool = False,
    ):
        """
        Args:
            enabled: Whether to apply orthogonalization at all.
            preserve_norm: Whether to restore each token's original L2 norm
                afterwards (QR yields unit-norm rows).
            orthogonalize_all: If True, orthogonalize every token in the
                sequence; if False, only the padding tokens.
        """
        super().__init__()
        self.enabled = enabled
        self.preserve_norm = preserve_norm
        self.orthogonalize_all = orthogonalize_all

    def identify_padding_tokens(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Identify padding token positions in the sequence.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size].
            attention_mask: Mask with 1 for real tokens, 0 for padding.
            pad_token_id: ID of the padding token (used with ``input_ids``).
            input_ids: Input token IDs.

        Returns:
            Boolean mask indicating padding positions [batch, seq_len].
        """
        batch_size, seq_len = embeddings.shape[:2]
        if attention_mask is not None:
            # Invert: attention mask marks real tokens, we want padding.
            return ~attention_mask.bool()
        elif pad_token_id is not None and input_ids is not None:
            return input_ids == pad_token_id
        else:
            # Fallback heuristic: assume the last 25% of the sequence is padding.
            padding_start = int(seq_len * 0.75)
            mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=embeddings.device)
            mask[:, padding_start:] = True
            return mask

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply padding token orthogonalization.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size].
            attention_mask: Mask with 1 for real tokens, 0 for padding.
            pad_token_id: ID of the padding token (used with ``input_ids``).
            input_ids: Input token IDs.

        Returns:
            Enhanced embeddings with orthogonalized padding tokens.
        """
        if not self.enabled:
            return embeddings

        # Record the per-token norms before any modification.
        if self.preserve_norm:
            original_norms = torch.norm(embeddings, dim=-1, keepdim=True)

        if self.orthogonalize_all:
            # Orthogonalize every token in the sequence.
            enhanced_embeddings = orthogonalize_rows(embeddings)
        else:
            # Only orthogonalize padding tokens; real tokens stay untouched.
            padding_mask = self.identify_padding_tokens(
                embeddings, attention_mask, pad_token_id, input_ids
            )
            enhanced_embeddings = embeddings.clone()
            # Padding counts can differ per sample, so process each one separately.
            for batch_idx in range(embeddings.shape[0]):
                padding_indices = torch.where(padding_mask[batch_idx])[0]
                if len(padding_indices) > 1:  # Need at least 2 tokens to orthogonalize
                    padding_embeddings = embeddings[batch_idx, padding_indices]
                    orthogonalized = orthogonalize_rows(padding_embeddings)
                    enhanced_embeddings[batch_idx, padding_indices] = orthogonalized

        # Restore original norms if requested.
        if self.preserve_norm:
            current_norms = torch.norm(enhanced_embeddings, dim=-1, keepdim=True)
            # Clamp instead of adding eps to the denominator: n / (n + 1e-8)
            # would slightly shrink *every* token, including real (untouched)
            # ones; with clamp the rescale factor is exactly 1.0 for them.
            enhanced_embeddings = enhanced_embeddings * (
                original_norms / current_norms.clamp_min(1e-8)
            )
        return enhanced_embeddings
def apply_padding_token_orthogonalization(
    prompt_embeds: torch.Tensor,
    text_attention_mask: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
) -> torch.Tensor:
    """
    Convenience wrapper: run padding token orthogonalization on prompt embeddings.

    Args:
        prompt_embeds: Text prompt embeddings [batch, seq_len, hidden_size].
        text_attention_mask: Attention mask for the text tokens (1 = real token).
        config: Optional settings dict; recognized keys are
            'padding_orthogonalization_enabled', 'preserve_norm', and
            'orthogonalize_all_tokens'.

    Returns:
        Enhanced prompt embeddings.
    """
    cfg = config if config is not None else {}
    module = PaddingTokenOrthogonalizer(
        enabled=cfg.get('padding_orthogonalization_enabled', True),
        preserve_norm=cfg.get('preserve_norm', True),
        orthogonalize_all=cfg.get('orthogonalize_all_tokens', False),
    )
    return module(embeddings=prompt_embeds, attention_mask=text_attention_mask)
# Explicit Gram-Schmidt alternative to the QR-based implementation.
def gram_schmidt_orthogonalization(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Orthogonalize vectors with the classical Gram-Schmidt process.

    Offers more control than QR decomposition, at the cost of speed.

    Args:
        vectors: Input vectors to orthogonalize [n_vectors, dim].
        eps: Small epsilon for numerical stability.

    Returns:
        Orthogonalized (and, where possible, normalized) vectors.
    """
    basis = []
    for row in vectors:
        residual = row.clone()
        # Remove the components lying along every previously produced vector.
        for prev in basis:
            coeff = torch.dot(residual, prev) / (torch.dot(prev, prev) + eps)
            residual = residual - coeff * prev
        # Normalize unless the residual is (numerically) the zero vector,
        # in which case it is kept as-is.
        length = torch.norm(residual)
        basis.append(residual / length if length > eps else residual)
    if not basis:
        return torch.zeros_like(vectors)
    return torch.stack(basis)