import numpy as np
from torchcodec.decoders import AudioDecoder as _AudioDecoder


class AudioDecoder(_AudioDecoder):
    """torchcodec ``AudioDecoder`` extended with subscript (``obj[key]``) access.

    The keys ``"array"`` and ``"sampling_rate"`` are answered by this class;
    every other key is forwarded to the parent class when it defines
    ``__getitem__``.
    """

    def __getitem__(self, key: str):
        """Return the decoded audio or its sample rate for a string key.

        Args:
            key: ``"array"`` for the decoded samples as a NumPy array, or
                ``"sampling_rate"`` for the sample rate reported by the
                decoder. Any other value is delegated to the parent class.

        Returns:
            For ``"array"``: the samples as a NumPy array. When the decoded
            data has more than one dimension it is averaged over every axis
            except the last one; otherwise it is returned unchanged.
            For ``"sampling_rate"``: the ``sample_rate`` attribute of the
            samples object returned by the decoder.
            For other keys: whatever the parent's ``__getitem__`` returns.

        Raises:
            TypeError: If the key is not handled here and the parent class
                has no ``__getitem__``.
        """
        if key == "array":
            # Move the decoded data to CPU and convert it to a NumPy array.
            waveform = self.get_all_samples().data.cpu().numpy()
            if waveform.ndim <= 1:
                return waveform
            # Average over all leading axes, keeping only the last one.
            # NOTE(review): presumably this folds (channels, samples) down to
            # mono -- confirm the layout against torchcodec.
            leading_axes = tuple(range(waveform.ndim - 1))
            return np.mean(waveform, axis=leading_axes)

        if key == "sampling_rate":
            # Request the (0, 0) range just to read ``sample_rate`` off the
            # returned samples object.
            probe = self.get_samples_played_in_range(0, 0)
            return probe.sample_rate

        parent = super()
        if not hasattr(parent, "__getitem__"):
            raise TypeError("'torchcodec.decoders.AudioDecoder' object is not subscriptable")
        return parent.__getitem__(key)