File size: 7,290 Bytes
7bef20f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from typing import Mapping, Text, Tuple
from einops import rearrange
from torch.cuda.amp import autocast


class SoftVectorQuantizer(torch.nn.Module):
    """Soft (softmax-weighted) vector quantizer with multiple codebooks.

    Instead of hard nearest-neighbor assignment, each token is quantized to a
    convex combination of codebook entries weighted by softmax(logits / tau),
    so the quantization step is fully differentiable. The forward interface
    mirrors a hard VectorQuantizer: input is a `(b, c, h, w)` feature map and
    the return is `(z_quantized, result_dict)`.
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False,
                 entropy_loss_ratio: float = 0.01,
                 tau: float = 0.07,
                 num_codebooks: int = 1,
                 show_usage: bool = False
                 ):
        """
        Args:
            codebook_size: Number of entries per codebook.
            token_size: Dimensionality of each codebook entry / token.
            commitment_cost: Stored for API compatibility; not used by the
                soft-quantization loss in forward().
            use_l2_norm: If True, L2-normalize both codebook entries and
                inputs before computing logits (cosine-similarity logits).
            clustering_vq: Stored for API compatibility; unused here.
            entropy_loss_ratio: Weight on the entropy regularization loss.
            tau: Softmax temperature for the soft assignment.
            num_codebooks: Number of independent codebooks; the token
                sequence is split evenly across them.
            show_usage: If True, keep a rolling buffer of recently used
                codebook indices to estimate codebook utilization.
        """
        super().__init__()
        # Map new parameter names to internal names for compatibility
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost
        self.use_l2_norm = use_l2_norm
        self.clustering_vq = clustering_vq
        
        # Keep soft quantization specific parameters
        self.num_codebooks = num_codebooks
        self.n_e = codebook_size          # alias: number of embeddings
        self.e_dim = token_size           # alias: embedding dimension
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = use_l2_norm
        self.show_usage = show_usage
        self.tau = tau
        
        # Single embedding layer for all codebooks
        # Shape: (num_codebooks, codebook_size, token_size). The randn init
        # is immediately overwritten by the uniform_ re-init below.
        self.embedding = nn.Parameter(torch.randn(num_codebooks, codebook_size, token_size))
        self.embedding.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        
        if self.l2_norm:
            # One-time normalization of the initial codebook; forward() also
            # re-normalizes on every call when l2_norm is set.
            self.embedding.data = F.normalize(self.embedding.data, p=2, dim=-1)
        
        if self.show_usage:
            # Rolling FIFO buffer of the last 65536 selected indices per
            # codebook, used only for the usage statistic.
            self.register_buffer("codebook_used", torch.zeros(num_codebooks, 65536))

    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Softly quantize a feature map.

        Args:
            z: Input features of shape (b, c, h, w) with c == token_size.

        Returns:
            z_quantized: Soft-quantized features, same (b, c, h, w) shape.
            result_dict: quantizer_loss / commitment_loss / codebook_loss
                (the latter two are always zero here) and
                min_encoding_indices (argmax hard indices).
        """
        z = z.float()
        original_shape = z.shape
        
        # Handle input reshaping to match VectorQuantizer format:
        # (b, c, h, w) -> (b, h*w, c)
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z = z.view(z.size(0), -1, z.size(-1))
        
        batch_size, seq_length, _ = z.shape
        
        # Ensure sequence length is divisible by number of codebooks
        assert seq_length % self.num_codebooks == 0, \
            f"Sequence length ({seq_length}) must be divisible by number of codebooks ({self.num_codebooks})"
        
        segment_length = seq_length // self.num_codebooks
        # Split the sequence into one contiguous segment per codebook:
        # (b, num_codebooks, segment_length, e_dim)
        z_segments = z.view(batch_size, self.num_codebooks, segment_length, self.e_dim)
        
        # Apply L2 norm if needed (cosine-similarity logits)
        embedding = F.normalize(self.embedding, p=2, dim=-1) if self.l2_norm else self.embedding
        if self.l2_norm:
            z_segments = F.normalize(z_segments, p=2, dim=-1)
            
        # (num_codebooks, b * segment_length, e_dim)
        z_flat = z_segments.permute(1, 0, 2, 3).contiguous().view(self.num_codebooks, -1, self.e_dim)
        
        # Dot-product logits against a detached codebook; the codebook still
        # receives gradients through the soft-quantize einsum below.
        logits = torch.einsum('nbe, nke -> nbk', z_flat, embedding.detach())
        
        # Calculate probabilities (soft quantization)
        probs = F.softmax(logits / self.tau, dim=-1)  
        
        # Soft quantize: convex combination of codebook entries
        z_q = torch.einsum('nbk, nke -> nbe', probs, embedding)
        
        # Reshape back to (b, num_codebooks, segment_length, e_dim)
        z_q = z_q.view(self.num_codebooks, batch_size, segment_length, self.e_dim).permute(1, 0, 2, 3).contiguous()
        
        # Calculate cosine similarity between input and its soft quantization.
        # NOTE(review): zq_z_cos is computed but never used or returned —
        # candidate for removal or for inclusion in result_dict.
        with torch.no_grad():
            zq_z_cos = F.cosine_similarity(
                z_segments.view(-1, self.e_dim),
                z_q.view(-1, self.e_dim),
                dim=-1
            ).mean()
        
        # Get indices for usage tracking (hard argmax of the soft assignment)
        indices = torch.argmax(probs, dim=-1)  # (num_codebooks, batch_size * segment_length)
        indices = indices.transpose(0, 1).contiguous()  # (batch_size * segment_length, num_codebooks)
        
        # Track codebook usage: shift the FIFO buffer left and append the
        # newest indices at the end.
        if self.show_usage and self.training:
            for k in range(self.num_codebooks):
                cur_len = indices.size(0)
                self.codebook_used[k, :-cur_len].copy_(self.codebook_used[k, cur_len:].clone())
                self.codebook_used[k, -cur_len:].copy_(indices[:, k])
        
        # Calculate losses if training
        if self.training:
            # Soft quantization doesn't have traditional commitment/codebook loss
            # Map entropy loss to quantizer_loss for compatibility
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(logits.view(-1, self.n_e))
            quantizer_loss = entropy_loss
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)
        else:
            quantizer_loss = torch.tensor(0.0, device=z.device)
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)
        
        # Calculate codebook usage (fraction of distinct entries in the
        # rolling buffer). NOTE(review): computed but never returned.
        codebook_usage = torch.tensor([
            len(torch.unique(self.codebook_used[k])) / self.n_e 
            for k in range(self.num_codebooks)
        ]).mean() if self.show_usage else 0

        z_q = z_q.view(batch_size, -1, self.e_dim)
        
        # Reshape back to original input shape to match VectorQuantizer
        z_q = z_q.view(batch_size, original_shape[2], original_shape[3], original_shape[1])
        z_quantized = rearrange(z_q, 'b h w c -> b c h w').contiguous()
        
        # Calculate average probabilities.
        # NOTE(review): avg_probs / max_probs are computed but unused; also
        # the inner mean is over the codebook dim, so avg_probs is always
        # 1/n_e — likely dead/incorrect code, confirm intent before removal.
        avg_probs = torch.mean(torch.mean(probs, dim=-1))
        max_probs = torch.mean(torch.max(probs, dim=-1)[0])
        
        # Return format matching VectorQuantizer.
        # NOTE(review): the indices buffer is laid out as
        # (batch*segment, num_codebooks); viewing it as
        # (batch, num_codebooks, segment) is only layout-correct when
        # num_codebooks == 1 — verify for multi-codebook configs.
        result_dict = dict(
            quantizer_loss=quantizer_loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=indices.view(batch_size, self.num_codebooks, segment_length).view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )
        
        return z_quantized, result_dict

    def get_codebook_entry(self, indices):
        """Added for compatibility with VectorQuantizer API"""
        # NOTE(review): only codebook 0 is consulted here; multi-codebook
        # decoding is not supported by this method.
        if len(indices.shape) == 1:
            # For single codebook case: integer index lookup
            z_quantized = self.embedding[0][indices]
        elif len(indices.shape) == 2:
            # 2-D input is treated as soft weights over the codebook
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding[0])
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized


def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Entropy regularization for codebook assignment logits.

    Encourages confident per-sample assignments (low sample entropy) while
    spreading usage across the codebook (high average-distribution entropy):
    loss = E[H(p_i)] - H(E[p_i]).

    Args:
        affinity: Logits of shape (..., codebook_size).
        loss_type: Only "softmax" is supported.
        temperature: Scaling applied to logits before the softmax.

    Returns:
        Scalar entropy loss tensor.

    Raises:
        ValueError: If `loss_type` is not "softmax".
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # BUGFIX: was `flat_affinity /= temperature`, an in-place op. When
    # `reshape` returns a view (the common case), that silently mutated the
    # caller's tensor and performed an autograd-unsafe in-place update.
    flat_affinity = flat_affinity / temperature
    probs = F.softmax(flat_affinity, dim=-1)
    # The +1e-5 offset on the logits is kept from the original implementation.
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    # Entropy of the batch-averaged distribution (codebook coverage).
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-6))
    # Mean per-sample entropy (assignment confidence).
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss