"""
Enhanced Text-to-Image Generation via Padding Token Orthogonalization
This module implements the padding token orthogonalization method described in the poster
"Enhanced Text-to-Image Generation via Padding Token Orthogonalization" by Jiafeng Mao,
Qianru Qiu, Xueting Wang from CyberAgent AI Lab.
The core idea is to use padding tokens as registers that collect, store, and redistribute
features across layers via attention pathways through Gram-Schmidt orthogonalization.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging
logger = logging.getLogger(__name__)
def orthogonalize_rows(X: torch.Tensor) -> torch.Tensor:
    """
    Orthogonalize the rows of X via QR decomposition.

    Implements the poster's core operation (Q, _ = torch.linalg.qr(X.T); return Q.T):
    the output rows are orthonormal and span the same subspace as the input rows.

    Args:
        X: Input tensor of shape (..., n_rows, n_cols) with n_rows <= n_cols.
           Leading dimensions are treated as batch dimensions.

    Returns:
        Tensor of the same shape and dtype as X whose rows are orthonormal.

    Raises:
        ValueError: If n_rows > n_cols -- the reduced QR cannot produce an
            output of the same shape as the input in that case (the original
            code silently returned a wrong-shaped tensor for 2D inputs and
            raised an opaque reshape error for batched ones).
    """
    if X.shape[-2] > X.shape[-1]:
        raise ValueError(
            f"orthogonalize_rows requires n_rows <= n_cols, got shape {tuple(X.shape)}"
        )
    original_dtype = X.dtype
    # torch.linalg.qr supports neither bfloat16 nor float16; compute in float32.
    if original_dtype in (torch.bfloat16, torch.float16):
        X = X.to(torch.float32)
    # torch.linalg.qr factorizes the last two dimensions and natively supports
    # arbitrary leading batch dimensions, so no Python-level loop is needed.
    Q, _ = torch.linalg.qr(X.transpose(-2, -1))
    # .to() is a no-op when the dtype is already original_dtype.
    return Q.transpose(-2, -1).to(original_dtype)
class PaddingTokenOrthogonalizer(nn.Module):
    """
    Orthogonalizes padding-token embeddings to improve text-image alignment.

    Following the poster's method, this module:
    1. Locates padding positions in the token sequence.
    2. Replaces their embeddings with an orthonormalized version (QR based).
    3. Optionally rescales every token back to its original norm so only the
       directions, not the magnitudes, are changed.
    """

    def __init__(
        self,
        enabled: bool = True,
        preserve_norm: bool = True,
        orthogonalize_all: bool = False,
    ):
        """
        Args:
            enabled: Whether to apply orthogonalization at all.
            preserve_norm: Whether to restore each token's original L2 norm.
            orthogonalize_all: If True, orthogonalize every token; otherwise
                only the tokens identified as padding.
        """
        super().__init__()
        self.enabled = enabled
        self.preserve_norm = preserve_norm
        self.orthogonalize_all = orthogonalize_all

    def identify_padding_tokens(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Locate padding positions in the token sequence.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size].
            attention_mask: Mask with 1 for real tokens and 0 for padding.
            pad_token_id: ID of the padding token.
            input_ids: Input token IDs.

        Returns:
            Boolean mask [batch, seq_len], True at padding positions.
        """
        n_batch, n_tokens = embeddings.shape[0], embeddings.shape[1]
        if attention_mask is not None:
            # Invert: mask marks real tokens, we want the padding ones.
            return attention_mask.bool().logical_not()
        if pad_token_id is not None and input_ids is not None:
            return input_ids.eq(pad_token_id)
        # Heuristic fallback (common practice): treat the trailing 25% of
        # positions as padding when no mask or token IDs are available.
        cutoff = int(n_tokens * 0.75)
        fallback = torch.zeros(
            n_batch, n_tokens, dtype=torch.bool, device=embeddings.device
        )
        fallback[:, cutoff:] = True
        return fallback

    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply padding-token orthogonalization.

        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size].
            attention_mask: Mask with 1 for real tokens and 0 for padding.
            pad_token_id: ID of the padding token.
            input_ids: Input token IDs.

        Returns:
            Embeddings with (at least) the padding tokens orthogonalized.
        """
        if not self.enabled:
            return embeddings
        # Capture per-token norms up front so they can be restored afterwards.
        norms_before = (
            torch.norm(embeddings, dim=-1, keepdim=True)
            if self.preserve_norm else None
        )
        if self.orthogonalize_all:
            output = orthogonalize_rows(embeddings)
        else:
            pad_positions = self.identify_padding_tokens(
                embeddings, attention_mask, pad_token_id, input_ids
            )
            output = embeddings.clone()
            for b, row_mask in enumerate(pad_positions):
                idx = torch.where(row_mask)[0]
                # Orthogonalization is meaningless for fewer than two vectors.
                if idx.numel() < 2:
                    continue
                output[b, idx] = orthogonalize_rows(embeddings[b, idx])
        if norms_before is not None:
            norms_after = torch.norm(output, dim=-1, keepdim=True)
            output = output * (norms_before / (norms_after + 1e-8))
        return output
def apply_padding_token_orthogonalization(
    prompt_embeds: torch.Tensor,
    text_attention_mask: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
) -> torch.Tensor:
    """
    Convenience wrapper that orthogonalizes padding tokens in prompt embeddings.

    Args:
        prompt_embeds: Text prompt embeddings [batch, seq_len, hidden_size].
        text_attention_mask: Attention mask for the text tokens (1 = real token).
        config: Optional settings dict; recognized keys are
            'padding_orthogonalization_enabled', 'preserve_norm', and
            'orthogonalize_all_tokens'.

    Returns:
        Enhanced prompt embeddings.
    """
    settings = config or {}
    module = PaddingTokenOrthogonalizer(
        enabled=settings.get('padding_orthogonalization_enabled', True),
        preserve_norm=settings.get('preserve_norm', True),
        orthogonalize_all=settings.get('orthogonalize_all_tokens', False),
    )
    return module(
        embeddings=prompt_embeds,
        attention_mask=text_attention_mask,
    )
# Gram-Schmidt orthogonalization alternative implementation
def gram_schmidt_orthogonalization(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Explicit (classical) Gram-Schmidt orthonormalization.

    Offers more control than QR decomposition at the cost of speed.

    Args:
        vectors: Input vectors to orthogonalize [n_vectors, dim].
        eps: Small epsilon for numerical stability.

    Returns:
        Orthonormalized vectors; (near-)zero residuals are left unnormalized.
    """
    basis = torch.zeros_like(vectors)
    for idx in range(vectors.shape[0]):
        residual = vectors[idx].clone()
        # Remove the components along every previously computed basis vector.
        for prev in basis[:idx]:
            coeff = torch.dot(residual, prev) / (torch.dot(prev, prev) + eps)
            residual = residual - coeff * prev
        length = torch.norm(residual)
        # Normalize unless the residual collapsed to (near) zero.
        basis[idx] = residual / length if length > eps else residual
    return basis