Spaces:
Runtime error
Runtime error
Add simulstreaming_whisper module, update requirements, improve Dockerfile and model handling
d860e14 | from .whisper.decoding import PyTorchInference | |
| # extention of PyTorchInference for beam search | |
| class BeamPyTorchInference(PyTorchInference): | |
| def _kv_modules(self): | |
| key_modules = [block.attn.key.cache_id for block in self.model.decoder.blocks] | |
| value_modules = [block.attn.value.cache_id for block in self.model.decoder.blocks] | |
| return key_modules + value_modules | |
| def rearrange_kv_cache(self, source_indices): | |
| if source_indices != list(range(len(source_indices))): | |
| for module_cache_id in self._kv_modules(): | |
| self.kv_cache[module_cache_id] = self.kv_cache[module_cache_id][source_indices].detach() | |
| from torch import Tensor | |
| def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor: | |
| return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache) |