Delete delta-iris/src/models/tokenizer
- delta-iris/src/models/tokenizer/__init__.py +0 -0
- delta-iris/src/models/tokenizer/__pycache__/__init__.cpython-310.pyc +0 -0
- delta-iris/src/models/tokenizer/__pycache__/quantizer.cpython-310.pyc +0 -0
- delta-iris/src/models/tokenizer/__pycache__/tokenizer.cpython-310.pyc +0 -0
- delta-iris/src/models/tokenizer/quantizer.py +0 -90
delta-iris/src/models/tokenizer/__init__.py
DELETED
File without changes
delta-iris/src/models/tokenizer/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (257 Bytes)
delta-iris/src/models/tokenizer/__pycache__/quantizer.cpython-310.pyc
DELETED
Binary file (3.74 kB)
delta-iris/src/models/tokenizer/__pycache__/tokenizer.cpython-310.pyc
DELETED
Binary file (4.8 kB)
delta-iris/src/models/tokenizer/quantizer.py
DELETED
@@ -1,90 +0,0 @@
-from dataclasses import dataclass
-import math
-from typing import Dict, Optional
-
-from einops import rearrange
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class Quantizer(nn.Module):
-    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
-        super().__init__()
-        assert math.log2(codebook_size).is_integer()
-        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
-        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival
-        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
-        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
-        codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
-        self.register_buffer('num_codebook_updates', torch.tensor(0))
-        self.register_buffer('codebook', codebook)
-        self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))
-
-    def forward(self, z: torch.Tensor) -> dict:
-        z = self.pre_quant_proj(z)
-        z = F.normalize(z, dim=-1)
-        b, k = z.size(0), z.size(2)
-        z = rearrange(z, 'b t k e -> (b t k) e')
-
-        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
-        tokens = cosine_similarity.argmax(dim=-1)
-        q = self.codebook[tokens]
-
-        losses = {'commitment_loss': 0.02 * (z - q.detach()).pow(2).mean()}
-
-        if self.training:
-            metrics = {**self.update_codebook(z, tokens), 'codebook_entropy': self.compute_codebook_entropy()}
-        else:
-            metrics = {}
-
-        q = z + (q - z).detach()
-        q = self.post_quant_proj(q)
-
-        q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
-        tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
-        return {
-            "q": q,
-            "tokens": tokens,
-            "losses": losses,
-            "metrics": metrics
-        }
-
-    @torch.no_grad()
-    def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> Dict[str, int]:
-        tokens_one_hot = F.one_hot(tokens, self.codebook.size(0)).float()  # (N, C)
-
-        # Update codebook
-        counts = tokens_one_hot.sum(dim=0)
-        codebook_update = torch.einsum('n e, n c -> c e', z, tokens_one_hot) / torch.clamp(counts.unsqueeze(-1), min=1)
-        codebook_update = F.normalize(codebook_update, dim=-1)
-        self.codebook.lerp_(codebook_update, 1 - 0.99)
-
-        # Update counts and revive dead codewords
-        freqs = counts / tokens_one_hot.size(0)
-        self.codewords_freqs.lerp_(freqs, 1 - 0.98)
-
-        can_revive = (self.compute_codebook_entropy() < 1) or (self.max_codebook_updates_with_revival is None) or (self.num_codebook_updates.item() < self.max_codebook_updates_with_revival)
-        if can_revive and (self.compute_codebook_entropy() < self.revival_entropy_threshold):
-            expired = torch.where(self.codewords_freqs < 1 / (10 * self.codewords_freqs.size(0)))[0]
-            num_expired = expired.size(0)
-            expired = expired[torch.randperm(num_expired)[:z.size(0)]]
-            idx_revived = torch.randperm(z.size(0), device=z.device)[:expired.size(0)]
-            self.codebook[expired] = z[idx_revived]
-            self.codewords_freqs[expired] = 1 / self.codewords_freqs.size(0)
-        else:
-            num_expired = 0
-
-        self.codebook = F.normalize(self.codebook, dim=-1)
-
-        self.num_codebook_updates += 1
-        metrics = {'codewords_revived': num_expired}
-
-        return metrics
-
-    def compute_codebook_entropy(self) -> float:
-        probs = self.codewords_freqs[self.codewords_freqs != 0]
-        return -(torch.log2(probs) * probs).sum().item()
-
-    @torch.no_grad()
-    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-        return self.post_quant_proj(self.codebook[tokens])
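For context on what this commit removes: the deleted Quantizer is a cosine-similarity vector quantizer with a straight-through gradient estimator, EMA codebook updates, and dead-codeword revival. Below is a minimal usage sketch, assuming a checkout of the parent revision where src.models.tokenizer.quantizer is still importable; the shapes and hyperparameters are illustrative, not taken from the repo.

# Hypothetical usage sketch of the deleted module (parent-revision import path assumed).
# The (B, T, K, input_dim) layout follows the 'b t k e' rearrange pattern in the file above.
import torch

from src.models.tokenizer.quantizer import Quantizer

quantizer = Quantizer(codebook_size=512, codebook_dim=64, input_dim=256)  # 512 = 2**9 satisfies the power-of-two assert
z = torch.randn(4, 10, 16, 256)  # (batch, timesteps, tokens per step, feature dim)

quantizer.train()
out = quantizer(z)
print(out['q'].shape)       # torch.Size([4, 10, 16, 256]) -- straight-through quantized features
print(out['tokens'].shape)  # torch.Size([4, 10, 16])      -- discrete codebook indices
print(out['losses'])        # {'commitment_loss': ...}
print(out['metrics'])       # codebook entropy and revived-codeword count (training mode only)

# Decoding tokens back to feature space:
recon = quantizer.embed_tokens(out['tokens'])  # (4, 10, 16, 256)

Only the power-of-two assert on codebook_size, the input layout, and the output keys come from the file above; everything else in the sketch is an assumption.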
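One detail worth recording before the file disappears: revival_entropy_threshold is log2(codebook_size) - 2, so revival stays active whenever the EMA usage entropy sits more than 2 bits below the uniform maximum. A standalone sketch of that arithmetic, reimplemented here for illustration rather than imported from the repo:

import math
import torch

codebook_size = 512
revival_entropy_threshold = int(math.log2(codebook_size)) - 2  # log2(512) = 9, so threshold = 7 bits

def codebook_entropy(freqs: torch.Tensor) -> float:
    # Same formula as Quantizer.compute_codebook_entropy above
    probs = freqs[freqs != 0]
    return -(torch.log2(probs) * probs).sum().item()

# Perfectly uniform usage hits the 9-bit maximum, well above the threshold
uniform = torch.full((codebook_size,), 1.0 / codebook_size)
print(codebook_entropy(uniform))  # 9.0

# Collapsed usage (only 16 of 512 codewords active) falls to 4 bits,
# below the 7-bit threshold, so the revival branch would fire
collapsed = torch.zeros(codebook_size)
collapsed[:16] = 1.0 / 16
print(codebook_entropy(collapsed))  # 4.0

# A codeword counts as expired when its EMA frequency drops below this cutoff
expired_cutoff = 1 / (10 * codebook_size)  # ~0.000195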