| from transformers.generation.streamers import BaseStreamer |
| from typing import Optional |
| from transformers import MusicgenForConditionalGeneration, set_seed |
| from queue import Queue |
| import numpy as np |
| import torch |
| import time |
class MusicgenStreamer(BaseStreamer):
    def __init__(
        self,
        model: MusicgenForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: int = 10,
        duration: float = 30,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
        initial_streamer: bool = False,
        audio_context: Optional[list] = None,
    ):
        """
        Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
        useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
        Gradio demo).

        Parameters:
            model (`MusicgenForConditionalGeneration`):
                The MusicGen model used to generate the audio waveform.
            device (`str`, *optional*):
                The torch device on which to run the computation. If `None`, will default to the device of the model.
            play_steps (`int`, *optional*, defaults to 10):
                The number of generation steps with which to return the generated audio array. Using fewer steps will
                mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
                should be tuned to your device and latency requirements.
            duration (`float`, *optional*, defaults to 30):
                Target duration (in seconds) of the generation; used to estimate how long the initial buffer window
                should be before playback starts.
            stride (`int`, *optional*):
                The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
                the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
                play_steps // 6 in the audio space.
            timeout (`int`, *optional*):
                The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
                in `.generate()`, when it is called in a separate thread.
            initial_streamer (`bool`, *optional*, defaults to `False`):
                Whether this streamer starts a playback session. When `True`, iteration blocks until an initial buffer
                window has elapsed so that enough audio accumulates before the consumer starts playing it back.
            audio_context (`list`, *optional*):
                Audio samples already emitted upstream; their length is pre-counted in `to_yield` so the same audio is
                not yielded twice. Assumes one entry per audio sample — TODO confirm against caller.
        """
        # Model components used for decoding generated tokens into audio.
        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device

        # Streaming state.
        self.play_steps = play_steps
        if stride is not None:
            self.stride = stride
        else:
            # hop_length = audio samples per codec frame; default stride is
            # roughly (play_steps // 6) frames expressed in the audio space.
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
        self.token_cache = None
        self.to_yield = 0
        self.audio_context = audio_context
        if self.audio_context is not None:
            # Skip over samples that were already played back upstream.
            self.to_yield += len(self.audio_context)

        # Queue used to hand finished audio chunks to the consumer thread.
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout
        self.initial_streamer = initial_streamer
        self.streamer_start = time.time()
        print(f"duration {duration}, playsteps {play_steps}, {int(duration*0.25)}")
        # Estimate the buffer needed before playback, capped at 13 seconds.
        required_buffer_estimate = int(duration*0.55+(play_steps/self.audio_encoder.config.frame_rate))
        self.buffer_length = min(13, required_buffer_estimate)

        if self.initial_streamer:
            print(f"Initial Streamer: {self.initial_streamer}, Buffer Estimate: {required_buffer_estimate}, buffer length: {self.buffer_length}s")

    def apply_delay_pattern_mask(self, input_ids):
        """Strip the delay pattern from cached tokens and decode them to a waveform (1-D float numpy array)."""
        # Rebuild the delay pattern mask for the full cached sequence length.
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        # Apply the mask, then drop pad tokens and restore the
        # (1, num_codebooks, seq_len) layout expected by the codec.
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )
        # Add the extra leading dim the audio encoder expects, and move the
        # tokens to the encoder's device before decoding.
        input_ids = input_ids[None, ...]
        input_ids = input_ids.to(self.audio_encoder.device)

        output_values = self.audio_encoder.decode(
            input_ids,
            audio_scales=[None],
        )
        # Take the first (and only) batch element / channel.
        audio_values = output_values.audio_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        """Receive one generation step of token ids; decode and enqueue audio every `play_steps` steps.

        Raises:
            ValueError: if the incoming tokens imply a batch size greater than 1.
        """
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        if self.token_cache is None:
            # First call: `value` already carries a sequence dimension.
            self.token_cache = value
        else:
            # Subsequent calls: append the single new token per codebook.
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

        # Every `play_steps` tokens, decode the cache and emit the new audio,
        # holding back `stride` samples to smooth chunk boundaries.
        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
            self.to_yield += len(audio_values) - self.to_yield - self.stride

    def end(self):
        """Flushes any remaining cache and appends the stop symbol."""
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            # Nothing was generated: emit silence of the already-counted length.
            audio_values = np.zeros(self.to_yield)

        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next audio chunk, blocking through the initial buffer window if required."""
        if self.initial_streamer:
            # Sleep once for the remaining buffer time instead of polling in a
            # busy loop; once the window has elapsed this is a no-op.
            remaining = self.buffer_length - (time.time() - self.streamer_start)
            if remaining > 0:
                time.sleep(remaining)
        return self._next_from_queue()

    def _next_from_queue(self):
        """Pop one item from the audio queue; raise StopIteration on the stop signal."""
        value = self.audio_queue.get(timeout=self.timeout)
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        return value