|
|
|
|
|
"""Utilities for constructing CASA layers."""
|
|
|
|
|
from contextlib import contextmanager |
|
|
from dataclasses import dataclass, fields |
|
|
from typing import Any, Callable, Generic, TypeVar |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def delta_w_factory(
    org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a linear op whose parameters are the sum of two layers' parameters.

    The returned closure reads the parameters of ``org_lin`` and ``new_lin``
    each time it is called, so later updates to either layer are reflected.

    Args:
        org_lin: first linear layer.
        new_lin: second linear layer; its weight must be addable to
            ``org_lin.weight``.

    Returns:
        A function ``f(input)`` computing
        ``torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)``.
        ``bias`` is the sum of the biases of whichever layers have one, or
        ``None`` when neither layer has a bias.
    """

    def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
        # Either layer may have been built with ``bias=False``: sum only the
        # biases that exist instead of adding a tensor to ``None`` (TypeError)
        # or silently dropping the second layer's bias.
        bias = org_lin.bias
        if new_lin.bias is not None:
            bias = new_lin.bias if bias is None else bias + new_lin.bias
        return torch.nn.functional.linear(
            input, org_lin.weight + new_lin.weight, bias
        )

    return _delta_w_fwd
|
|
|
|
|
|
|
|
@dataclass
class StreamingState:
    """Persistent state kept by a CASA layer between streaming inference steps.

    Stores things such as the current ``offset``, a KV cache, or any other
    state that must survive across calls. Subclasses declare extra dataclass
    fields; every field can be (re)set to a default that does not depend on
    constructor arguments.
    """

    offset: int = 0

    def _is_valid_field(self, key: str) -> bool:
        """Return True when ``key`` is the name of one of this dataclass's fields."""
        known_names = {descriptor.name for descriptor in fields(self)}
        return key in known_names

    def _init_field(self, key: str) -> None:
        """Set field ``key`` to its argument-independent default.

        ``offset`` becomes 0; every other field becomes None. Override in
        subclasses that need non-None defaults.
        """
        assert self._is_valid_field(key)
        default = 0 if key == "offset" else None
        setattr(self, key, default)

    def init(self) -> None:
        """Initialise every field to its default through ``_init_field``."""
        for descriptor in fields(self):
            self._init_field(descriptor.name)

    def _reset_field(self, name: str) -> None:
        """Reset one field; this simply delegates to ``_init_field``."""
        self._init_field(name)

    def reset(self) -> None:
        """Reset every field through ``_reset_field``."""
        field_names = [descriptor.name for descriptor in fields(self)]
        for field_name in field_names:
            self._reset_field(field_name)

    def _get_field(self, f: str) -> Any:
        """Return field ``f``, lazily initialising it first when it is None."""
        assert self._is_valid_field(f)
        current = getattr(self, f)
        if current is None:
            self._init_field(f)
            current = getattr(self, f)
        return current

    def _set_field(self, f: str, value: Any) -> None:
        """Assign ``value`` to field ``f`` after checking the field exists."""
        assert self._is_valid_field(f)
        setattr(self, f, value)
|
|
|
|
|
|
|
|
# Type variable for the concrete StreamingState (sub)class held by a
# StreamingModule; parameterises StreamingModule via Generic[StreamingStateT].
StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)
|
|
|
|
|
|
|
|
class StreamingModule(torch.nn.Module, Generic[StreamingStateT]):
    """``torch.nn.Module`` base class with streaming-mode support.

    Adapted from Audiocraft's streaming modules, with a few extra utilities.
    Each instance owns a streaming state object built from ``state_class``.
    The ``streaming`` context manager toggles streaming mode on this module
    and on every ``StreamingModule`` nested inside it.
    """

    def __init__(self, state_class: type) -> None:
        """
        Args:
            state_class: zero-argument callable (normally a ``StreamingState``
                subclass) used to build this module's streaming state.
        """
        torch.nn.Module.__init__(self)
        # True while this module is inside a ``streaming()`` context.
        self.is_streaming: bool = False
        # Names passed as ``streaming(viz=...)``. How they are consumed is left
        # to subclasses (presumably visualisation switches - confirm).
        self.enable_viz: tuple[str, ...] = ()
        self._streaming_state: StreamingStateT = state_class()

    @property
    def streaming_state(self) -> StreamingStateT:
        """The streaming state object owned by this module."""
        return self._streaming_state

    def _apply_named_streaming(
        self, fn: Callable[[str, "StreamingModule"], Any]
    ) -> None:
        """Call ``fn(name, module)`` on every ``StreamingModule`` in the tree.

        Modules are enumerated with ``named_modules()``, which also yields
        this module itself.
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def reset_streaming(self) -> None:
        """Reset the streaming state of this module and all nested streaming modules."""

        def _reset(_: str, module: StreamingModule) -> None:
            module._streaming_state.reset()

        self._apply_named_streaming(_reset)

    def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()) -> None:
        """Set the streaming flags on this module and all nested streaming modules.

        When ``streaming`` is True, each module's state is also initialised
        through ``streaming_state.init()``.
        """

        # Named differently from the enclosing method to avoid shadowing it.
        def _configure(_: str, module: StreamingModule) -> None:
            module.is_streaming = streaming
            module.enable_viz = viz
            if streaming:
                module.streaming_state.init()

        self._apply_named_streaming(_configure)

    @contextmanager
    def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
        """Context manager that puts the whole module tree in streaming mode.

        On exit, streaming mode is switched off and every streaming state is
        reset. This also happens when the body raises or when enabling
        streaming fails part-way.
        """
        try:
            # Enabling happens inside ``try`` so that a failure (e.g. raised by
            # a state's ``init``) still rolls back the modules already enabled.
            self._set_streaming(stream, viz)
            yield
        finally:
            self._set_streaming(False, ())
            self.reset_streaming()
|
|
|