ShaswatRobotics committed on
Commit f7aba9e · verified · 1 parent: 709389a

Delete delta-iris/src/models/tokenizer

delta-iris/src/models/tokenizer/__init__.py DELETED
Empty file
delta-iris/src/models/tokenizer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (257 Bytes)
 
delta-iris/src/models/tokenizer/__pycache__/quantizer.cpython-310.pyc DELETED
Binary file (3.74 kB)
 
delta-iris/src/models/tokenizer/__pycache__/tokenizer.cpython-310.pyc DELETED
Binary file (4.8 kB)
 
delta-iris/src/models/tokenizer/quantizer.py DELETED
@@ -1,90 +0,0 @@
-from dataclasses import dataclass
-import math
-from typing import Dict, Optional
-
-from einops import rearrange
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class Quantizer(nn.Module):
-    def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
-        super().__init__()
-        assert math.log2(codebook_size).is_integer()
-        self.revival_entropy_threshold = int(math.log2(codebook_size)) - 2
-        self.max_codebook_updates_with_revival = max_codebook_updates_with_revival
-        self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
-        self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
-        codebook = torch.empty(codebook_size, codebook_dim, requires_grad=False).uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
-        self.register_buffer('num_codebook_updates', torch.tensor(0))
-        self.register_buffer('codebook', codebook)
-        self.register_buffer('codewords_freqs', torch.ones(codebook_size).div(codebook_size))
-
-    def forward(self, z: torch.Tensor) -> dict:
-        z = self.pre_quant_proj(z)
-        z = F.normalize(z, dim=-1)
-        b, k = z.size(0), z.size(2)
-        z = rearrange(z, 'b t k e -> (b t k) e')
-
-        cosine_similarity = torch.einsum('n e, c e -> n c', z, self.codebook)
-        tokens = cosine_similarity.argmax(dim=-1)
-        q = self.codebook[tokens]
-
-        losses = {'commitment_loss': 0.02 * (z - q.detach()).pow(2).mean()}
-
-        if self.training:
-            metrics = {**self.update_codebook(z, tokens), 'codebook_entropy': self.compute_codebook_entropy()}
-        else:
-            metrics = {}
-
-        q = z + (q - z).detach()  # straight-through estimator
-        q = self.post_quant_proj(q)
-
-        q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
-        tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
-        return {
-            "q": q,
-            "tokens": tokens,
-            "losses": losses,
-            "metrics": metrics
-        }
-
-    @torch.no_grad()
-    def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> Dict[str, int]:
-        tokens_one_hot = F.one_hot(tokens, self.codebook.size(0)).float()  # (N, C)
-
-        # Update codebook
-        counts = tokens_one_hot.sum(dim=0)
-        codebook_update = torch.einsum('n e, n c -> c e', z, tokens_one_hot) / torch.clamp(counts.unsqueeze(-1), min=1)
-        codebook_update = F.normalize(codebook_update, dim=-1)
-        self.codebook.lerp_(codebook_update, 1 - 0.99)
-
-        # Update counts and revive dead codewords
-        freqs = counts / tokens_one_hot.size(0)
-        self.codewords_freqs.lerp_(freqs, 1 - 0.98)
-
-        can_revive = (self.compute_codebook_entropy() < 1) or (self.max_codebook_updates_with_revival is None) or (self.num_codebook_updates.item() < self.max_codebook_updates_with_revival)
-        if can_revive and (self.compute_codebook_entropy() < self.revival_entropy_threshold):
-            expired = torch.where(self.codewords_freqs < 1 / (10 * self.codewords_freqs.size(0)))[0]
-            num_expired = expired.size(0)
-            expired = expired[torch.randperm(num_expired)[:z.size(0)]]
-            idx_revived = torch.randperm(z.size(0), device=z.device)[:expired.size(0)]
-            self.codebook[expired] = z[idx_revived]
-            self.codewords_freqs[expired] = 1 / self.codewords_freqs.size(0)
-        else:
-            num_expired = 0
-
-        self.codebook = F.normalize(self.codebook, dim=-1)
-
-        self.num_codebook_updates += 1
-        metrics = {'codewords_revived': num_expired}
-
-        return metrics
-
-    def compute_codebook_entropy(self) -> float:
-        probs = self.codewords_freqs[self.codewords_freqs != 0]
-        return -(torch.log2(probs) * probs).sum().item()
-
-    @torch.no_grad()
-    def embed_tokens(self, tokens: torch.LongTensor) -> torch.FloatTensor:
-        return self.post_quant_proj(self.codebook[tokens])
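For reference, the deleted module implemented a cosine-similarity vector quantizer with an EMA-updated codebook (decay 0.99), a straight-through estimator for gradients, and revival of dead codewords. Below is a minimal, hypothetical smoke test of that interface; the hyperparameters (codebook_size=1024, codebook_dim=64, input_dim=256) and tensor shapes are illustrative assumptions, not values taken from this repository.

import torch

# Assumes the Quantizer class shown in the diff above is importable.
quantizer = Quantizer(codebook_size=1024, codebook_dim=64, input_dim=256)
quantizer.train()  # update_codebook() and the training metrics only run in training mode

# forward() expects (batch, timesteps, tokens_per_step, input_dim)
z = torch.randn(8, 4, 16, 256)
out = quantizer(z)

print(out['q'].shape)       # torch.Size([8, 4, 16, 256]); gradients flow via straight-through
print(out['tokens'].shape)  # torch.Size([8, 4, 16]); discrete codeword indices
print(out['losses']['commitment_loss'])  # scalar tensor, weighted by 0.02
print(out['metrics'])       # {'codewords_revived': ..., 'codebook_entropy': ...}

# Decoding tokens back to the input space:
emb = quantizer.embed_tokens(out['tokens'])  # (8, 4, 16, 256)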