biosn2 committed on
Commit
6a652d2
·
verified ·
1 Parent(s): e072c2a

Upload indextts/gpt/conformer/embedding.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. indextts/gpt/conformer/embedding.py +163 -0
indextts/gpt/conformer/embedding.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Modified from ESPnet(https://github.com/espnet/espnet)
15
+
16
+ """Positional Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+
24
+
25
class PositionalEncoding(torch.nn.Module):
    """Sinusoidal absolute positional encoding.

    The table is precomputed once for ``max_len`` positions:

        PE(pos, 2i)   = sin(pos / (10000^(2i / d_model)))
        PE(pos, 2i+1) = cos(pos / (10000^(2i / d_model)))

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout probability.
        max_len (int): Number of positions held in the precomputed table.
        reverse (bool): Accepted for interface compatibility; this
            implementation never reads it.
    """
    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct a PositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2) *
            -(math.log(10000.0) / self.d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even
        # columns, so trim div_term to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[:self.d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and add the positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int, torch.tensor): position offset

        Returns:
            torch.Tensor: ``dropout(x * sqrt(d_model) + pos_emb)``, same
                shape as ``x``.
            torch.Tensor: ``dropout(pos_emb)``, returned for compatibility
                with RelPositionalEncoding.
        """
        self.pe = self.pe.to(x.device)
        # Dropout is applied here, so ask for the raw encoding.
        pos_emb = self.position_encoding(offset, x.size(1), False)
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self, offset: Union[int, torch.Tensor], size: int,
                          apply_dropout: bool = True) -> torch.Tensor:
        """Return the encoding for ``size`` positions starting at ``offset``.

        Intended for streaming use. Note that in the non-streaming path
        dropout is applied only once at utterance level, whereas a streaming
        caller invokes this function repeatedly with growing input, so
        dropout ends up being applied several times.

        Args:
            offset (int or torch.tensor): start offset. May be a Python int,
                a 0-dim tensor, or a 1-D tensor of per-sample offsets.
            size (int): required size of position encoding
            apply_dropout (bool): whether to apply dropout to the result.

        Returns:
            torch.Tensor: Corresponding encoding, (1, size, d_model) for a
                scalar offset or (batch, size, d_model) for batched offsets.

        Raises:
            AssertionError: if the requested window runs past ``max_len``.
        """
        # How to subscript a Union type:
        #   https://github.com/pytorch/pytorch/issues/69434
        # The window [offset, offset + size) is valid up to and including
        # max_len, hence "<=" rather than "<".
        if isinstance(offset, int):
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        elif isinstance(offset, torch.Tensor) and offset.dim() == 0:  # scalar
            assert offset + size <= self.max_len
            pos_emb = self.pe[:, offset:offset + size]
        else:  # for batched streaming decoding on GPU
            assert torch.max(offset) + size <= self.max_len
            index = offset.unsqueeze(1) + \
                torch.arange(0, size, device=offset.device)  # B X T
            # Negative positions (from negative offsets) map to position 0.
            index = torch.clamp(index, min=0)
            pos_emb = F.embedding(index, self.pe[0])  # B X T X d_model

        if apply_dropout:
            pos_emb = self.dropout(pos_emb)
        return pos_emb
114
+
115
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Unlike the parent class, ``forward`` does not add the encoding to the
    input; it returns the scaled input and the encoding separately.

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """
    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class, delegating to the parent with ``reverse=True``."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Scale the input and fetch the matching positional embedding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int, torch.tensor): Start position handed to
                ``position_encoding``.

        Returns:
            torch.Tensor: ``dropout(x * sqrt(d_model))`` (batch, time, `*`).
            torch.Tensor: ``dropout`` of the positional embedding; its
                leading dimension is 1 for a scalar offset.
        """
        self.pe = self.pe.to(x.device)
        scaled = x * self.xscale
        # Fetch the raw table slice; dropout is applied on return below.
        time_steps = scaled.size(1)
        rel_emb = self.position_encoding(offset, time_steps, False)
        # Dropout order (input first, embedding second) matches the RNG
        # consumption order of the single-expression form.
        dropped_x = self.dropout(scaled)
        dropped_emb = self.dropout(rel_emb)
        return dropped_x, dropped_emb
142
+
143
+
144
class NoPositionalEncoding(torch.nn.Module):
    """Placeholder encoding that adds no positional information.

    It exposes the same ``forward`` / ``position_encoding`` interface as
    PositionalEncoding but always produces an all-zero embedding.

    Args:
        d_model (int): Embedding dimension (last dim of the zero tensor).
        dropout_rate (float): Dropout probability applied to the input.
    """
    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the input after dropout plus a zero positional embedding.

        Args:
            x (torch.Tensor): Input tensor; dim 1 is used as the length.
            offset (int, torch.tensor): Ignored; kept for interface
                compatibility.

        Returns:
            torch.Tensor: ``dropout(x)`` (no scaling is applied).
            torch.Tensor: Zeros of shape (1, x.size(1), d_model) on
                ``x.device``.
        """
        # Allocate directly on the target device rather than CPU + copy.
        pos_emb = torch.zeros(1, x.size(1), self.d_model, device=x.device)
        return self.dropout(x), pos_emb

    def position_encoding(
            self, offset: Union[int, torch.Tensor], size: int,
            apply_dropout: bool = True) -> torch.Tensor:
        """Return a zero encoding of the requested length.

        Args:
            offset (int or torch.tensor): Start offset. Its value is unused;
                when it is a tensor the result is created on its device.
            size (int): Required size of position encoding.
            apply_dropout (bool): Accepted so the signature matches
                ``PositionalEncoding.position_encoding``; it has no effect
                because the result is all zeros.

        Returns:
            torch.Tensor: Zeros of shape (1, size, d_model).
        """
        if isinstance(offset, torch.Tensor):
            # The offset tensor is the only device hint available here.
            return torch.zeros(1, size, self.d_model, device=offset.device)
        return torch.zeros(1, size, self.d_model)