karagmercola commited on
Commit
e859897
·
verified ·
1 Parent(s): 5f13ae1

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +201 -0
models.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchtune
6
+ from huggingface_hub import PyTorchModelHubMixin
7
+ from torchtune.models import llama3_2
8
+
9
+
10
+ def llama3_2_1B() -> torchtune.modules.transformer.TransformerDecoder:
11
+ return llama3_2.llama3_2(
12
+ vocab_size=128_256,
13
+ num_layers=16,
14
+ num_heads=32,
15
+ num_kv_heads=8,
16
+ embed_dim=2048,
17
+ max_seq_len=2048,
18
+ intermediate_dim=8192,
19
+ attn_dropout=0.0,
20
+ norm_eps=1e-5,
21
+ rope_base=500_000,
22
+ scale_factor=32,
23
+ )
24
+
25
+
26
+ def llama3_2_100M() -> torchtune.modules.transformer.TransformerDecoder:
27
+ return llama3_2.llama3_2(
28
+ vocab_size=128_256,
29
+ num_layers=4,
30
+ num_heads=8,
31
+ num_kv_heads=2,
32
+ embed_dim=1024,
33
+ max_seq_len=2048,
34
+ intermediate_dim=8192,
35
+ attn_dropout=0.0,
36
+ norm_eps=1e-5,
37
+ rope_base=500_000,
38
+ scale_factor=32,
39
+ )
40
+
41
+
42
+ FLAVORS = {
43
+ "llama-1B": llama3_2_1B,
44
+ "llama-100M": llama3_2_100M,
45
+ }
46
+
47
+
48
+ def _prepare_transformer(model):
49
+ embed_dim = model.tok_embeddings.embedding_dim
50
+ model.tok_embeddings = nn.Identity()
51
+ model.output = nn.Identity()
52
+ return model, embed_dim
53
+
54
+
55
+ def _create_causal_mask(seq_len: int, device: torch.device):
56
+ return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
57
+
58
+
59
+ def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
60
+ """
61
+ Args:
62
+ mask: (max_seq_len, max_seq_len)
63
+ input_pos: (batch_size, seq_len)
64
+ Returns:
65
+ (batch_size, seq_len, max_seq_len)
66
+ """
67
+ r = mask[input_pos, :]
68
+ return r
69
+
70
+
71
+ def _multinomial_sample_one_no_sync(probs): # Does multinomial sampling without a cuda synchronization
72
+ q = torch.empty_like(probs).exponential_(1)
73
+ return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int)
74
+
75
+
76
+ def sample_topk(logits: torch.Tensor, topk: int, temperature: float):
77
+ logits = logits / temperature
78
+
79
+ filter_value: float = -float("Inf")
80
+ indices_to_remove = logits < torch.topk(logits, topk)[0][..., -1, None]
81
+ scores_processed = logits.masked_fill(indices_to_remove, filter_value)
82
+ scores_processed = torch.nn.functional.log_softmax(scores_processed, dim=-1)
83
+ probs = torch.nn.functional.softmax(scores_processed, dim=-1)
84
+
85
+ sample_token = _multinomial_sample_one_no_sync(probs)
86
+ return sample_token
87
+
88
+
89
+ @dataclass
90
+ class ModelArgs:
91
+ backbone_flavor: str
92
+ decoder_flavor: str
93
+ text_vocab_size: int
94
+ audio_vocab_size: int
95
+ audio_num_codebooks: int
96
+
97
+
98
+ class Model(
99
+ nn.Module,
100
+ PyTorchModelHubMixin,
101
+ repo_url="https://github.com/SesameAILabs/csm",
102
+ pipeline_tag="text-to-speech",
103
+ license="apache-2.0",
104
+ ):
105
+ def __init__(self, config: ModelArgs):
106
+ super().__init__()
107
+ self.config = config
108
+
109
+ self.backbone, backbone_dim = _prepare_transformer(FLAVORS[config.backbone_flavor]())
110
+ self.decoder, decoder_dim = _prepare_transformer(FLAVORS[config.decoder_flavor]())
111
+
112
+ self.text_embeddings = nn.Embedding(config.text_vocab_size, backbone_dim)
113
+ self.audio_embeddings = nn.Embedding(config.audio_vocab_size * config.audio_num_codebooks, backbone_dim)
114
+
115
+ self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
116
+ self.codebook0_head = nn.Linear(backbone_dim, config.audio_vocab_size, bias=False)
117
+ self.audio_head = nn.Parameter(torch.empty(config.audio_num_codebooks - 1, decoder_dim, config.audio_vocab_size))
118
+
119
+ def setup_caches(self, max_batch_size: int) -> torch.Tensor:
120
+ """Setup KV caches and return a causal mask."""
121
+ dtype = next(self.parameters()).dtype
122
+ device = next(self.parameters()).device
123
+
124
+ with device:
125
+ self.backbone.setup_caches(max_batch_size, dtype)
126
+ self.decoder.setup_caches(max_batch_size, dtype, decoder_max_seq_len=self.config.audio_num_codebooks)
127
+
128
+ self.register_buffer("backbone_causal_mask", _create_causal_mask(self.backbone.max_seq_len, device))
129
+ self.register_buffer("decoder_causal_mask", _create_causal_mask(self.config.audio_num_codebooks, device))
130
+
131
+ def generate_frame(
132
+ self,
133
+ tokens: torch.Tensor,
134
+ tokens_mask: torch.Tensor,
135
+ input_pos: torch.Tensor,
136
+ temperature: float,
137
+ topk: int,
138
+ ) -> torch.Tensor:
139
+ """
140
+ Args:
141
+ tokens: (batch_size, seq_len, audio_num_codebooks+1)
142
+ tokens_mask: (batch_size, seq_len, audio_num_codebooks+1)
143
+ input_pos: (batch_size, seq_len) positions for each token
144
+ mask: (batch_size, seq_len, max_seq_len
145
+ Returns:
146
+ (batch_size, audio_num_codebooks) sampled tokens
147
+ """
148
+ dtype = next(self.parameters()).dtype
149
+ b, s, _ = tokens.size()
150
+
151
+ assert self.backbone.caches_are_enabled(), "backbone caches are not enabled"
152
+ curr_backbone_mask = _index_causal_mask(self.backbone_causal_mask, input_pos)
153
+ embeds = self._embed_tokens(tokens)
154
+ masked_embeds = embeds * tokens_mask.unsqueeze(-1)
155
+ h = masked_embeds.sum(dim=2)
156
+ h = self.backbone(h, input_pos=input_pos, mask=curr_backbone_mask).to(dtype=dtype)
157
+
158
+ last_h = h[:, -1, :]
159
+ c0_logits = self.codebook0_head(last_h)
160
+ c0_sample = sample_topk(c0_logits, topk, temperature)
161
+ c0_embed = self._embed_audio(0, c0_sample)
162
+
163
+ curr_h = torch.cat([last_h.unsqueeze(1), c0_embed], dim=1)
164
+ curr_sample = c0_sample.clone()
165
+ curr_pos = torch.arange(0, curr_h.size(1), device=curr_h.device).unsqueeze(0).repeat(curr_h.size(0), 1)
166
+
167
+ # Decoder caches must be reset every frame.
168
+ self.decoder.reset_caches()
169
+ for i in range(1, self.config.audio_num_codebooks):
170
+ curr_decoder_mask = _index_causal_mask(self.decoder_causal_mask, curr_pos)
171
+ decoder_h = self.decoder(self.projection(curr_h), input_pos=curr_pos, mask=curr_decoder_mask).to(
172
+ dtype=dtype
173
+ )
174
+ ci_logits = torch.mm(decoder_h[:, -1, :], self.audio_head[i - 1])
175
+ ci_sample = sample_topk(ci_logits, topk, temperature)
176
+ ci_embed = self._embed_audio(i, ci_sample)
177
+
178
+ curr_h = ci_embed
179
+ curr_sample = torch.cat([curr_sample, ci_sample], dim=1)
180
+ curr_pos = curr_pos[:, -1:] + 1
181
+
182
+ return curr_sample
183
+
184
+ def reset_caches(self):
185
+ self.backbone.reset_caches()
186
+ self.decoder.reset_caches()
187
+
188
+ def _embed_audio(self, codebook: int, tokens: torch.Tensor) -> torch.Tensor:
189
+ return self.audio_embeddings(tokens + codebook * self.config.audio_vocab_size)
190
+
191
+ def _embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
192
+ text_embeds = self.text_embeddings(tokens[:, :, -1]).unsqueeze(-2)
193
+
194
+ audio_tokens = tokens[:, :, :-1] + (
195
+ self.config.audio_vocab_size * torch.arange(self.config.audio_num_codebooks, device=tokens.device)
196
+ )
197
+ audio_embeds = self.audio_embeddings(audio_tokens.view(-1)).reshape(
198
+ tokens.size(0), tokens.size(1), self.config.audio_num_codebooks, -1
199
+ )
200
+
201
+ return torch.cat([audio_embeds, text_embeds], dim=-2)