File size: 9,035 Bytes
81bf056
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""
Enhanced Text-to-Image Generation via Padding Token Orthogonalization

This module implements the padding token orthogonalization method described in the poster
"Enhanced Text-to-Image Generation via Padding Token Orthogonalization" by Jiafeng Mao, 
Qianru Qiu, Xueting Wang from CyberAgent AI Lab.

The core idea is to use padding tokens as registers that collect, store, and redistribute
features across layers via attention pathways through Gram-Schmidt orthogonalization.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging

logger = logging.getLogger(__name__)


def orthogonalize_rows(X: torch.Tensor) -> torch.Tensor:
    """
    Orthogonalize rows of matrix X using QR decomposition.
    
    This is the core function from the poster: Q, _ = torch.linalg.qr(X.T) return Q.T
    
    Args:
        X: Input tensor of shape (..., n_rows, n_cols)
        
    Returns:
        Orthogonalized tensor of the same shape
    """
    # Save original dtype and convert to float32 for QR decomposition
    original_dtype = X.dtype
    original_shape = X.shape
    
    # Convert to float32 if needed (QR doesn't support bfloat16)
    if X.dtype == torch.bfloat16:
        X = X.to(torch.float32)
    
    # Handle batch dimensions by flattening
    if X.dim() > 2:
        # Reshape to (batch_size, n_rows, n_cols)
        X_flat = X.view(-1, original_shape[-2], original_shape[-1])
        results = []
        
        for i in range(X_flat.shape[0]):
            # Apply QR decomposition: Q, _ = torch.linalg.qr(X.T)
            Q, _ = torch.linalg.qr(X_flat[i].T)
            # Return Q.T to get orthogonalized rows
            results.append(Q.T)
        
        result = torch.stack(results, dim=0)
        # Reshape back to original shape
        result = result.view(original_shape)
    else:
        # Simple 2D case
        Q, _ = torch.linalg.qr(X.T)
        result = Q.T
    
    # Convert back to original dtype
    if original_dtype == torch.bfloat16:
        result = result.to(original_dtype)
    
    return result


class PaddingTokenOrthogonalizer(nn.Module):
    """
    A module that applies padding token orthogonalization to text embeddings.
    
    Based on the poster's method, this enhances text-image alignment by:
    1. Identifying padding tokens in the sequence
    2. Orthogonalizing their representations using QR decomposition
    3. Maintaining feature diversity and preventing biased attention
    """
    
    def __init__(
        self,
        enabled: bool = True,
        preserve_norm: bool = True,
        orthogonalize_all: bool = False,
    ):
        """
        Args:
            enabled: Whether to apply orthogonalization
            preserve_norm: Whether to preserve the original norm of tokens
            orthogonalize_all: If True, orthogonalize all tokens; if False, only padding tokens
        """
        super().__init__()
        self.enabled = enabled
        self.preserve_norm = preserve_norm
        self.orthogonalize_all = orthogonalize_all
        
    def identify_padding_tokens(
        self, 
        embeddings: torch.Tensor, 
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Identify padding token positions in the sequence.
        
        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size]
            attention_mask: Attention mask where 0 indicates padding
            pad_token_id: ID of the padding token
            input_ids: Input token IDs
            
        Returns:
            Boolean mask indicating padding positions [batch, seq_len]
        """
        batch_size, seq_len = embeddings.shape[:2]
        
        if attention_mask is not None:
            # Attention mask: 1 for real tokens, 0 for padding
            return ~attention_mask.bool()
        elif pad_token_id is not None and input_ids is not None:
            return input_ids == pad_token_id
        else:
            # Fallback: assume last 25% of sequence are padding tokens
            # This is a heuristic based on common practice
            padding_start = int(seq_len * 0.75)
            mask = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=embeddings.device)
            mask[:, padding_start:] = True
            return mask
    
    def forward(
        self,
        embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        pad_token_id: Optional[int] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Apply padding token orthogonalization.
        
        Args:
            embeddings: Token embeddings [batch, seq_len, hidden_size]
            attention_mask: Attention mask where 1 indicates real tokens
            pad_token_id: ID of the padding token
            input_ids: Input token IDs
            
        Returns:
            Enhanced embeddings with orthogonalized padding tokens
        """
        if not self.enabled:
            return embeddings
        
        # Store original norms if we need to preserve them
        if self.preserve_norm:
            original_norms = torch.norm(embeddings, dim=-1, keepdim=True)
        
        if self.orthogonalize_all:
            # Orthogonalize all tokens in the sequence
            enhanced_embeddings = orthogonalize_rows(embeddings)
        else:
            # Only orthogonalize padding tokens
            padding_mask = self.identify_padding_tokens(
                embeddings, attention_mask, pad_token_id, input_ids
            )
            
            enhanced_embeddings = embeddings.clone()
            
            # Process each sample in the batch
            for batch_idx in range(embeddings.shape[0]):
                padding_indices = torch.where(padding_mask[batch_idx])[0]
                
                if len(padding_indices) > 1:  # Need at least 2 tokens to orthogonalize
                    # Extract padding token embeddings
                    padding_embeddings = embeddings[batch_idx, padding_indices]
                    
                    # Apply orthogonalization
                    orthogonalized = orthogonalize_rows(padding_embeddings)
                    
                    # Put back orthogonalized embeddings
                    enhanced_embeddings[batch_idx, padding_indices] = orthogonalized
        
        # Restore original norms if requested
        if self.preserve_norm:
            current_norms = torch.norm(enhanced_embeddings, dim=-1, keepdim=True)
            enhanced_embeddings = enhanced_embeddings * (original_norms / (current_norms + 1e-8))
        
        return enhanced_embeddings


def apply_padding_token_orthogonalization(
    prompt_embeds: torch.Tensor,
    text_attention_mask: Optional[torch.Tensor] = None,
    config: Optional[dict] = None,
) -> torch.Tensor:
    """
    Convenience function to apply padding token orthogonalization to prompt embeddings.
    
    Args:
        prompt_embeds: Text prompt embeddings [batch, seq_len, hidden_size]
        text_attention_mask: Attention mask for text tokens
        config: Configuration dictionary with orthogonalization settings
        
    Returns:
        Enhanced prompt embeddings
    """
    if config is None:
        config = {}
    
    orthogonalizer = PaddingTokenOrthogonalizer(
        enabled=config.get('padding_orthogonalization_enabled', True),
        preserve_norm=config.get('preserve_norm', True),
        orthogonalize_all=config.get('orthogonalize_all_tokens', False),
    )
    
    return orthogonalizer(
        embeddings=prompt_embeds,
        attention_mask=text_attention_mask,
    )


# Gram-Schmidt orthogonalization alternative implementation
def gram_schmidt_orthogonalization(vectors: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Alternative implementation using explicit Gram-Schmidt process.
    This provides more control but is generally slower than QR decomposition.
    
    Args:
        vectors: Input vectors to orthogonalize [n_vectors, dim]
        eps: Small epsilon for numerical stability
        
    Returns:
        Orthogonalized vectors
    """
    n_vectors = vectors.shape[0]
    orthogonal_vectors = torch.zeros_like(vectors)
    
    for i in range(n_vectors):
        vector = vectors[i].clone()
        
        # Subtract projections onto previous orthogonal vectors
        for j in range(i):
            projection = torch.dot(vector, orthogonal_vectors[j]) / (
                torch.dot(orthogonal_vectors[j], orthogonal_vectors[j]) + eps
            )
            vector = vector - projection * orthogonal_vectors[j]
        
        # Normalize
        norm = torch.norm(vector)
        if norm > eps:
            orthogonal_vectors[i] = vector / norm
        else:
            # Handle zero vector case
            orthogonal_vectors[i] = vector
    
    return orthogonal_vectors