# pylint: disable=protected-access
"""Utils to handle CASA layers construction"""

from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Generic, TypeVar

import torch


def delta_w_factory(
    org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Factory for building a linear op whose parameters are the sum of two layers'.

    The returned function computes ``linear(input, W_org + W_new, b)`` where ``b``
    is the sum of whichever of the two biases exist (``None`` if neither layer has
    a bias). Parameters are read at call time, so later updates to either layer
    are reflected in subsequent calls.

    Args:
        org_lin: The original linear layer.
        new_lin: The linear layer whose weights (and bias, if any) are added on top.

    Returns:
        A callable taking the input tensor and returning the linear output.
    """

    def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
        org_bias, new_bias = org_lin.bias, new_lin.bias
        # Either layer may have been built with ``bias=False``: sum only the biases
        # that exist, instead of failing on ``Tensor + None`` or silently dropping
        # ``new_lin``'s bias when ``org_lin`` has none.
        if org_bias is None:
            bias = new_bias
        elif new_bias is None:
            bias = org_bias
        else:
            bias = org_bias + new_bias
        return torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)

    return _delta_w_fwd


@dataclass
class StreamingState:
    """Streaming State used by CASA layers at inference to save e.g. the offset,
    the KV Cache and other persistent states.

    In this base implementation, ``offset`` (re)initializes to ``0`` and every
    other dataclass field (re)initializes to ``None``.
    """

    offset: int = 0

    def _is_valid_field(self, key: str) -> bool:
        """Return whether ``key`` names a dataclass field of this state."""
        return key in {x.name for x in fields(self)}

    def _init_field(self, key: str) -> None:
        """Init function for non-argument-dependent defaults."""
        assert self._is_valid_field(key)
        if key == "offset":
            self.offset = 0
        else:
            # For fields which should be set explicitly and cannot be auto-initialized.
            setattr(self, key, None)

    def init(self) -> None:
        """Initialize every field to its default via ``_init_field``."""
        for key in [x.name for x in fields(self)]:
            self._init_field(key)

    def _reset_field(self, name: str) -> None:
        """Resets the given field (currently identical to re-initializing it)."""
        self._init_field(name)

    def reset(self) -> None:
        """Reset every field via ``_reset_field``."""
        for f in fields(self):
            self._reset_field(f.name)

    def _get_field(self, f: str) -> Any:
        """Get field ``f``, first initializing it if its current value is ``None``."""
        assert self._is_valid_field(f)
        if getattr(self, f) is None:
            self._init_field(f)
        return getattr(self, f)

    def _set_field(self, f: str, value: Any) -> None:
        """Set field ``f`` to ``value``."""
        assert self._is_valid_field(f)
        setattr(self, f, value)


StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)


class StreamingModule(torch.nn.Module, Generic[StreamingStateT]):  # pylint: disable=abstract-method
    """Overrides Audiocraft's Streaming modules with additional small utils"""

    def __init__(self, state_class: type[StreamingStateT]) -> None:
        """Create the module and its streaming state.

        Args:
            state_class: State class, instantiated with no arguments to hold this
                module's streaming state.
        """
        torch.nn.Module.__init__(self)
        self.is_streaming: bool = False
        # Names passed via ``streaming(viz=...)``; how they are used is left to
        # subclasses (not visible here).
        self.enable_viz: tuple[str, ...] = ()
        self._streaming_state: StreamingStateT = state_class()

    @property
    def streaming_state(self) -> StreamingStateT:
        """This module's own streaming state."""
        return self._streaming_state

    def _apply_named_streaming(self, fn: Callable[[str, "StreamingModule"], Any]) -> None:
        """Apply ``fn(name, module)`` to this module and every ``StreamingModule`` descendant."""
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def reset_streaming(self) -> None:
        """Reset the streaming state of this module and all streaming submodules."""

        def _reset(_: str, module: StreamingModule) -> None:
            module._streaming_state.reset()

        self._apply_named_streaming(_reset)

    def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()) -> None:
        """Set streaming mode and ``viz`` on all streaming modules.

        When ``streaming`` is true, each module's state is also (re-)initialized.
        """

        def _set_streaming(_: str, module: StreamingModule) -> None:
            module.is_streaming = streaming
            module.enable_viz = viz
            if streaming:
                module.streaming_state.init()

        self._apply_named_streaming(_set_streaming)

    @contextmanager
    def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
        """Context manager to enter streaming mode. Reset streaming state on exit.

        Entering is done inside the ``try`` so that if state initialization fails
        part-way through the module tree, the modules already switched are still
        restored to non-streaming mode and reset.
        """
        try:
            self._set_streaming(stream, viz)
            yield
        finally:
            self._set_streaming(False, ())
            self.reset_streaming()