karagmercola committed on
Commit fe1f6dc · verified · 1 Parent(s): e859897

Create generator.py

Files changed (1): generator.py +195 -0
generator.py ADDED
@@ -0,0 +1,195 @@
+ from dataclasses import dataclass
+ from typing import List, Tuple
+
+ import torch
+ import torchaudio
+ from huggingface_hub import hf_hub_download, login
+ from models import Model
+ from moshi.models import loaders
+ from tokenizers.processors import TemplateProcessing
+ from transformers import AutoTokenizer
+ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
+
+ # Fall back to eager execution instead of raising when torch.compile/Triton errors out
+ torch._dynamo.config.suppress_errors = True
+
+ @dataclass
+ class Segment:
+     speaker: int
+     text: str
+     # (num_samples,), sample_rate = 24_000
+     audio: torch.Tensor
+
+
+ def load_llama3_tokenizer():
+     """
+     Load the Llama 3.2 tokenizer, patched so that encoded sequences are
+     wrapped in BOS/EOS tokens (the stock tokenizer does not append EOS):
+     https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
+     """
+     tokenizer_name = "meta-llama/Llama-3.2-1B"
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+     bos = tokenizer.bos_token
+     eos = tokenizer.eos_token
+     tokenizer._tokenizer.post_processor = TemplateProcessing(
+         single=f"{bos}:0 $A:0 {eos}:0",
+         pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
+         special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
+     )
+
+     return tokenizer
+
+
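+ # Illustrative check only (not load-bearing): with the post-processor installed,
+ # encode() wraps the ids as [bos_id, *text_ids, eos_id], e.g.
+ #   tok = load_llama3_tokenizer()
+ #   ids = tok.encode("[0]hello")
+ #   assert ids[0] == tok.bos_token_id and ids[-1] == tok.eos_token_id
+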
+ class Generator:
+     def __init__(
+         self,
+         model: Model,
+     ):
+         self._model = model
+         self._model.setup_caches(1)
+
+         self._text_tokenizer = load_llama3_tokenizer()
+
+         device = next(model.parameters()).device
+         # The Mimi codec (from Moshi) serves as the audio tokenizer
+         mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
+         mimi = loaders.get_mimi(mimi_weight, device=device)
+         mimi.set_num_codebooks(32)
+         self._audio_tokenizer = mimi
+
+         self._watermarker = load_watermarker(device=device)
+
+         self.sample_rate = mimi.sample_rate
+         self.device = device
+
+     def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         frame_tokens = []
+         frame_masks = []
+
+         text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
+         # Each frame has 33 columns: 32 Mimi codebooks plus 1 text token (last column)
+         text_frame = torch.zeros(len(text_tokens), 33).long()
+         text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
+         text_frame[:, -1] = torch.tensor(text_tokens)
+         text_frame_mask[:, -1] = True
+
+         frame_tokens.append(text_frame.to(self.device))
+         frame_masks.append(text_frame_mask.to(self.device))
+
+         return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+     def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         frame_tokens = []
+         frame_masks = []
+
+         audio = audio.to(self.device)
+         # Mimi expects (batch, channels, samples); encode returns (K=32 codebooks, T frames)
+         audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
+         # add EOS frame
+         eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
+         audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
+
+         audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
+         audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
+         audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
+         audio_frame_mask[:, :-1] = True
+
+         frame_tokens.append(audio_frame)
+         frame_masks.append(audio_frame_mask)
+
+         return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+     def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Returns:
+             (seq_len, 33), (seq_len, 33)
+         """
+         text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
+         audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
+
+         return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
+
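+     # Layout sketch (illustrative only): a tokenized segment is a (seq_len, 33)
+     # grid in which text rows fill only the final column and audio rows fill
+     # the first 32 columns with Mimi codebook ids:
+     #
+     #   row kind   cols 0..31 (audio codebooks)   col 32 (text)
+     #   text       zeros (mask False)             token id (mask True)
+     #   audio      codebook ids (mask True)       zero (mask False)
+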
+     @torch.inference_mode()
+     def generate(
+         self,
+         text: str,
+         speaker: int,
+         context: List[Segment],
+         max_audio_length_ms: float = 90_000,
+         temperature: float = 0.9,
+         topk: int = 50,
+     ) -> torch.Tensor:
+         self._model.reset_caches()
+
+         # Mimi emits 12.5 frames per second, i.e. one frame per 80 ms
+         max_audio_frames = int(max_audio_length_ms / 80)
+         tokens, tokens_mask = [], []
+         for segment in context:
+             segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
+             tokens.append(segment_tokens)
+             tokens_mask.append(segment_tokens_mask)
+
+         gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
+         tokens.append(gen_segment_tokens)
+         tokens_mask.append(gen_segment_tokens_mask)
+
+         prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
+         prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
+
+         samples = []
+         curr_tokens = prompt_tokens.unsqueeze(0)
+         curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
+         curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
+
+         max_seq_len = 2048 - max_audio_frames
+         if curr_tokens.size(1) >= max_seq_len:
+             raise ValueError(f"Inputs too long: the prompt must be shorter than 2048 - max_audio_frames = {max_seq_len} frames")
+
+         for _ in range(max_audio_frames):
+             sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
+             if torch.all(sample == 0):
+                 break  # eos
+
+             samples.append(sample)
+
+             # Feed the sampled audio codes back as the next input frame (text column zeroed)
+             curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
+             curr_tokens_mask = torch.cat(
+                 [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
+             ).unsqueeze(1)
+             curr_pos = curr_pos[:, -1:] + 1
+
+         if not samples:
+             raise RuntimeError("Model produced EOS on the first frame; no audio was generated")
+
+         audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
+
+         # This applies an imperceptible watermark to identify audio as AI-generated.
+         # Watermarking ensures transparency, dissuades misuse, and enables traceability.
+         # Please be a responsible AI citizen and keep the watermarking in place.
+         # If using CSM 1B in another application, use your own private key and keep it secret.
+         audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
+         audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
+
+         return audio
+
+
+ def load_csm_1b(device: str = "cuda") -> Generator:
+     """
+     Load the CSM-1B model from Hugging Face Hub.
+
+     Args:
+         device: Device to run the model on (cuda or cpu)
+
+     Returns:
+         Generator: Generator object to create audio from text
+     """
+     try:
+         print(f"Loading model on {device}")
+         model = Model.from_pretrained("sesame/csm-1b")
+         # In ZeroGPU, CUDA must not be initialized in the main process, so keep
+         # the model on CPU here and move it to the GPU only inside a function
+         # decorated with @spaces.GPU.
+         if 'cuda' not in device or torch.cuda.is_initialized():
+             model.to(device=device, dtype=torch.bfloat16)
+
+         generator = Generator(model)
+         return generator
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         print("Please check if you are logged in to Hugging Face Hub.")
+         print("You may need to request access to the model at: https://huggingface.co/sesame/csm-1b")
+         raise
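
A minimal usage sketch, assuming models.py and watermarking.py from the CSM repo sit alongside this file, access to sesame/csm-1b has been granted, and you are logged in to the Hub (the text, speaker id, and output path below are arbitrary placeholders):

import torchaudio
from generator import load_csm_1b

generator = load_csm_1b(device="cuda")
audio = generator.generate(
    text="Hello from CSM.",
    speaker=0,
    context=[],  # optionally pass prior Segment objects to condition the voice
    max_audio_length_ms=10_000,
)
# audio is a 1-D waveform at generator.sample_rate (24 kHz)
torchaudio.save("output.wav", audio.float().unsqueeze(0).cpu(), generator.sample_rate)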