ShaswatRobotics commited on
Commit
c42256a
·
verified ·
1 Parent(s): f67d483

Update delta-iris/src/models/tokenizer/quantizer.py

Browse files
delta-iris/src/models/tokenizer/quantizer.py CHANGED
@@ -7,15 +7,6 @@ import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
 
10
-
11
- @dataclass
12
- class QuantizerOutput:
13
- q: torch.FloatTensor
14
- tokens: torch.LongTensor
15
- loss: Dict[str, torch.FloatTensor]
16
- metrics: Dict[str, float]
17
-
18
-
19
  class Quantizer(nn.Module):
20
  def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
21
  super().__init__()
@@ -51,8 +42,12 @@ class Quantizer(nn.Module):
51
 
52
  q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
53
  tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
54
-
55
- return QuantizerOutput(q, tokens, losses, metrics)
 
 
 
 
56
 
57
  @torch.no_grad()
58
  def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> None:
 
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
 
 
 
 
 
 
 
 
 
 
10
  class Quantizer(nn.Module):
11
  def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int, max_codebook_updates_with_revival: Optional[int] = None) -> None:
12
  super().__init__()
 
42
 
43
  q = rearrange(q, '(b t k) e -> b t k e', b=b, k=k)
44
  tokens = rearrange(tokens, '(b t k) -> b t k', b=b, k=k)
45
+ return {
46
+ "q": q,
47
+ "tokens": tokens,
48
+ "losses": losses,
49
+ "metrics": metrics
50
+ }
51
 
52
  @torch.no_grad()
53
  def update_codebook(self, z: torch.Tensor, tokens: torch.LongTensor) -> None: