Update delta-iris/src/models/tokenizer/quantizer.py
Browse files
delta-iris/src/models/tokenizer/quantizer.py
CHANGED
|
@@ -7,15 +7,6 @@ import torch
|
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
|
| 10 |
-
|
| 11 |
-
@dataclass
|
| 12 |
-
class QuantizerOutput:
|
| 13 |
-
q: torch.FloatTensor
|
| 14 |
-
tokens: torch.LongTensor
|
| 15 |
-
loss: Dict[str, torch.FloatTensor]
|
| 16 |
-
metrics: Dict[str, float]
|
| 17 |
-
|
| 18 |
-
|
| 19 |
class Quantizer(nn.Module):
|
| 20 |
def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
|
| 21 |
super().__init__()
|
|
@@ -51,8 +42,12 @@ class Quantizer(nn.Module):
|
|
| 51 |
|
| 52 |
q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
|
| 53 |
tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
@torch.no_grad()
|
| 58 |
def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> None:
|
|
|
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
class Quantizer(nn.Module):
|
| 11 |
def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
|
| 12 |
super().__init__()
|
|
|
|
| 42 |
|
| 43 |
q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
|
| 44 |
tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
|
| 45 |
+
return {
|
| 46 |
+
"q": q,
|
| 47 |
+
"tokens": tokens,
|
| 48 |
+
"losses": losses,
|
| 49 |
+
"metrics": metrics
|
| 50 |
+
}
|
| 51 |
|
| 52 |
@torch.no_grad()
|
| 53 |
def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> None:
|