File size: 4,024 Bytes
23bc32f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from dataclasses import dataclass
import math
from typing import Dict, Optional

from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class QuantizerOutput:
    q: torch.FloatTensor
    tokens: torch.LongTensor
    loss: Dict[str, torch.FloatTensor]
    metrics: Dict[str, float]


class Quantizer(nn.Module):
    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
        super().__init__()
        assert math.log2(codebook_size).is_integer()
        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival
        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
        codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.register_buffer('num_codebook_updates', torch.tensor(0))
        self.register_buffer('codebook', codebook)
        self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))

    def forward(self, z: torch.Tensor) -> QuantizerOutput:
        z = self.pre_quant_proj(z)
        z = F.normalize(z, dim=-1)
        b, k = z.size(0), z.size(2)
        z = rearrange(z, 'b t k e -> (b t k) e')

        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
        tokens = cosine_similarity.argmax(dim=-1)
        q = self.codebook[tokens]

        losses = {'commitment_loss': 0.02 * (z - q.detach()).pow(2).mean()}

        if self.training:
            metrics = {**self.update_codebook(z, tokens), 'codebook_entropy': self.compute_codebook_entropy()}
        else:
            metrics = {}

        q = z + (q - z).detach()
        q = self.post_quant_proj(q)

        q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
        tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)

        return QuantizerOutput(q, tokens, losses, metrics)

    @torch.no_grad()
    def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> None:
        tokens_one_hot = F.one_hot(tokens, self.codebook.size(0)).float()  # (N, C)

        # Update codebook
        counts = tokens_one_hot.sum(dim=0) 
        codebook_update = torch.einsum('n e, n c -> c e', z, tokens_one_hot) / torch.clamp(counts.unsqueeze(-1), min=1)
        codebook_update = F.normalize(codebook_update, dim=-1)
        self.codebook.lerp_(codebook_update, 1 - 0.99)

        # Update counts and revive dead codewords
        freqs = counts / tokens_one_hot.size(0)
        self.codewords_freqs.lerp_(freqs, 1 - 0.98)

        can_revive = (self.compute_codebook_entropy() < 1) or (self.max_codebook_updates_with_revival is None) or (self.num_codebook_updates.item() < self.max_codebook_updates_with_revival) 
        if can_revive and (self.compute_codebook_entropy() < self.revival_entropy_threshold):
            expired = torch.where(self.codewords_freqs < 1 / (10 * self.codewords_freqs.size(0)))[0]
            num_expired = expired.size(0)
            expired = expired[torch.randperm(num_expired)[:z.size(0)]]
            idx_revived = torch.randperm(z.size(0), device=z.device)[:expired.size(0)]
            self.codebook[expired] = z[idx_revived]
            self.codewords_freqs[expired] = 1 / self.codewords_freqs.size(0)
        else:
            num_expired = 0

        self.codebook = F.normalize(self.codebook, dim=-1)

        self.num_codebook_updates += 1
        metrics = {'codewords_revived': num_expired}

        return metrics

    def compute_codebook_entropy(self) -> float:
        probs = self.codewords_freqs[self.codewords_freqs != 0]
        return -(torch.log2(probs) * probs).sum().item()

    @torch.no_grad()
    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
        return self.post_quant_proj(self.codebook[tokens])