from __future__ import annotations

import asyncio
import time
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional

import torch

from transformers.generation import BaseStreamer
|
|
|
|
|
class AudioStreamer(BaseStreamer):
    """
    Audio streamer that stores audio chunks in queues for each sample in the batch.

    This allows streaming audio generation for multiple samples simultaneously.

    Parameters:
        batch_size (`int`):
            The batch size for generation
        stop_signal (`Any`, *optional*):
            The signal to put in the queue when generation ends. Defaults to None.
        timeout (`float`, *optional*):
            The timeout for the audio queue. If `None`, the queue will block indefinitely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        self.batch_size = batch_size
        self.stop_signal = stop_signal
        self.timeout = timeout

        # One queue per sample so each stream can be consumed independently.
        self.audio_queues = [Queue() for _ in range(batch_size)]
        # Marks samples whose stop signal has already been enqueued.
        self.finished_flags = [False for _ in range(batch_size)]
        # NOTE(review): not read or written anywhere in this module; presumably
        # for external callers — confirm before removing.
        self.sample_indices_map = {}

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """
        Receives audio chunks and puts them in the appropriate queues.

        Args:
            audio_chunks: Tensor of shape (num_samples, ...) containing audio chunks
            sample_indices: Tensor indicating which samples these chunks belong to
        """
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            # Ignore out-of-range indices and samples that already finished.
            if idx < self.batch_size and not self.finished_flags[idx]:
                # Detach and move to CPU so consumers never touch autograd
                # state or device memory.
                audio_chunk = audio_chunks[i].detach().cpu()
                self.audio_queues[idx].put(audio_chunk, timeout=self.timeout)

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """
        Signals the end of generation for specified samples or all samples.

        Args:
            sample_indices: Optional tensor of sample indices to end. If None, ends all.
        """
        # Normalize to a plain iterable of ints, mirroring AsyncAudioStreamer.end.
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [
                s.item() if torch.is_tensor(s) else s for s in sample_indices
            ]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout)
                self.finished_flags[idx] = True

    def __iter__(self):
        """Returns an iterator over the batch of audio streams."""
        return AudioBatchIterator(self)

    def get_stream(self, sample_idx: int):
        """Get the audio stream for a specific sample."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")
        return AudioSampleIterator(self, sample_idx)
|
|
|
|
|
class AudioSampleIterator:
    """Iterator for a single audio stream from the batch.

    Yields audio chunks for one sample until the streamer's stop signal
    arrives on that sample's queue.
    """

    def __init__(self, streamer: "AudioStreamer", sample_idx: int):
        self.streamer = streamer
        self.sample_idx = sample_idx

    def __iter__(self):
        return self

    def __next__(self):
        value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout)
        # Identity check, not `==`: chunks are tensors, and tensor equality is
        # element-wise (ambiguous/raising truth value). The producer enqueues
        # the stop_signal object itself, so `is` is the correct sentinel test.
        if value is self.streamer.stop_signal:
            raise StopIteration()
        return value
|
|
|
|
|
class AudioBatchIterator:
    """Iterator that yields audio chunks for all samples in the batch.

    Each `__next__` returns a dict mapping sample index -> chunk for every
    active sample that currently has a chunk available.
    """

    # Seconds to sleep between polls when no chunk is ready yet.
    _POLL_INTERVAL = 0.01

    def __init__(self, streamer: "AudioStreamer"):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __iter__(self):
        return self

    def __next__(self):
        # Loop (rather than tail-recursing as before) so long generations
        # cannot exhaust the interpreter recursion limit.
        while self.active_samples:
            batch_chunks = {}
            finished = set()

            for idx in self.active_samples:
                try:
                    value = self.streamer.audio_queues[idx].get(block=False)
                except Empty:
                    # Nothing ready for this sample right now; try it again
                    # on the next poll.
                    continue
                # Identity check: chunks are tensors, and `==` on a tensor is
                # element-wise rather than a sentinel comparison.
                if value is self.streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value

            self.active_samples -= finished

            if batch_chunks:
                return batch_chunks
            if self.active_samples:
                # Avoid a busy-wait while producers catch up.
                time.sleep(self._POLL_INTERVAL)

        raise StopIteration()
|
|
|
|
|
class AsyncAudioStreamer(AudioStreamer):
    """
    Async version of AudioStreamer for use in async contexts.

    Must be constructed from within a running event loop: `__init__` captures
    the loop so that `put`/`end` (typically invoked from the generation
    thread) can hand chunks to the asyncio queues thread-safely.
    """

    def __init__(
        self,
        batch_size: int,
        stop_signal: Optional[Any] = None,
        timeout: Optional[float] = None,
    ):
        super().__init__(batch_size, stop_signal, timeout)

        # Replace the base class's thread-safe queues with asyncio queues.
        self.audio_queues = [asyncio.Queue() for _ in range(batch_size)]
        # Raises RuntimeError if called outside a running event loop.
        self.loop = asyncio.get_running_loop()

    def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor):
        """Put audio chunks in the appropriate async queues."""
        for i, sample_idx in enumerate(sample_indices):
            idx = sample_idx.item()
            if idx < self.batch_size and not self.finished_flags[idx]:
                audio_chunk = audio_chunks[i].detach().cpu()
                # call_soon_threadsafe: put may be invoked from a worker
                # thread while the loop runs elsewhere.
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, audio_chunk
                )

    def end(self, sample_indices: Optional[torch.Tensor] = None):
        """Signal the end of generation for specified samples."""
        if sample_indices is None:
            indices_to_end = range(self.batch_size)
        else:
            indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices]

        for idx in indices_to_end:
            if idx < self.batch_size and not self.finished_flags[idx]:
                self.loop.call_soon_threadsafe(
                    self.audio_queues[idx].put_nowait, self.stop_signal
                )
                self.finished_flags[idx] = True

    async def get_stream(self, sample_idx: int):
        """Get async iterator for a specific sample's audio stream."""
        if sample_idx >= self.batch_size:
            raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}")

        while True:
            value = await self.audio_queues[sample_idx].get()
            # Identity check: tensor `==` is element-wise, not a sentinel test.
            if value is self.stop_signal:
                break
            yield value

    def __aiter__(self):
        """Returns an async iterator over all audio streams."""
        return AsyncAudioBatchIterator(self)
|
|
|
|
|
class AsyncAudioBatchIterator:
    """Async iterator for batch audio streaming.

    Each `__anext__` awaits until at least one active sample has a chunk
    ready (bounded by the streamer's timeout per wait) and returns a dict
    mapping sample index -> chunk.
    """

    def __init__(self, streamer: "AsyncAudioStreamer"):
        self.streamer = streamer
        self.active_samples = set(range(streamer.batch_size))

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Loop (rather than tail-recursing as before) so long generations
        # cannot exhaust the interpreter recursion limit.
        while self.active_samples:
            batch_chunks = {}
            finished = set()

            # One pending get() per active sample; return as soon as any lands.
            tasks = {
                idx: asyncio.create_task(self._get_chunk(idx))
                for idx in self.active_samples
            }

            done, pending = await asyncio.wait(
                tasks.values(),
                return_when=asyncio.FIRST_COMPLETED,
                timeout=self.streamer.timeout,
            )

            # Cancel gets that did not complete this round; they will be
            # re-issued on the next pass.
            for task in pending:
                task.cancel()

            for idx, task in tasks.items():
                if task not in done:
                    continue
                try:
                    value = await task
                except asyncio.CancelledError:
                    continue
                # Identity check: chunks are tensors, and `==` on a tensor is
                # element-wise rather than a sentinel comparison.
                if value is self.streamer.stop_signal:
                    finished.add(idx)
                else:
                    batch_chunks[idx] = value

            self.active_samples -= finished

            if batch_chunks:
                return batch_chunks

        raise StopAsyncIteration()

    async def _get_chunk(self, idx):
        """Await the next item from the queue for sample `idx`."""
        return await self.streamer.audio_queues[idx].get()