File size: 3,839 Bytes
1126ea7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
# pylint: disable=protected-access
"""Utils to handle CASA layers construction"""
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Generic, TypeVar
import torch
def delta_w_factory(
    org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a linear op whose parameters are the sum of two layers' parameters.

    The returned callable computes
    ``linear(x, org_lin.weight + new_lin.weight, org_bias + new_bias)``, which
    equals ``org_lin(x) + new_lin(x)`` evaluated as a single linear op.
    Weights and biases are read on every call, so later updates to either
    layer are picked up.

    A missing bias (``None``) on either layer is treated as zero. If neither
    layer has a bias, no bias is applied.

    Args:
        org_lin: The original linear layer.
        new_lin: The layer whose weight (and bias, if any) is added on top of
            ``org_lin``'s. Its weight must be addable to ``org_lin.weight``.

    Returns:
        A function mapping an input tensor to the output of the summed linear op.
    """

    # `input` mirrors torch.nn.functional.linear's argument name; kept so
    # keyword callers (`fn(input=x)`) keep working.
    def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:  # pylint: disable=redefined-builtin
        org_bias, new_bias = org_lin.bias, new_lin.bias
        # Treat an absent bias as zero instead of crashing on `Tensor + None`
        # or silently dropping the other layer's bias.
        if org_bias is None:
            bias = new_bias
        elif new_bias is None:
            bias = org_bias
        else:
            bias = org_bias + new_bias
        return torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)

    return _delta_w_fwd
@dataclass
class StreamingState:
    """Streaming state kept by CASA layers during inference.

    Stores values that must persist between streaming steps, e.g. the
    offset, the KV cache and other persistent states. Only ``offset`` is
    declared here; it is the single field with an auto-initializable default.
    """

    offset: int = 0

    def _is_valid_field(self, key: str) -> bool:
        """Tell whether ``key`` names one of this dataclass's fields."""
        known_names = frozenset(fld.name for fld in fields(self))
        return key in known_names

    def _init_field(self, key: str) -> None:
        """Give field ``key`` its argument-independent default.

        ``offset`` restarts at 0. Every other field is set to ``None``: such
        fields must be assigned explicitly and cannot be auto-initialized.
        """
        assert self._is_valid_field(key)
        setattr(self, key, 0 if key == "offset" else None)

    def init(self) -> None:
        """Bring every field to its default through ``_init_field``."""
        for fld in fields(self):
            self._init_field(fld.name)

    def _reset_field(self, name: str) -> None:
        """Reset one field; this simply delegates to ``_init_field``."""
        self._init_field(name)

    def reset(self) -> None:
        """Reset every field, one at a time, through ``_reset_field``."""
        for field_name in [fld.name for fld in fields(self)]:
            self._reset_field(field_name)

    def _get_field(self, f: str) -> Any:
        """Return field ``f``, lazily running ``_init_field`` if it is ``None``."""
        assert self._is_valid_field(f)
        value = getattr(self, f)
        if value is None:
            self._init_field(f)
            value = getattr(self, f)
        return value

    def _set_field(self, f: str, value: Any) -> None:
        """Store ``value`` into the (validated) field ``f``."""
        assert self._is_valid_field(f)
        setattr(self, f, value)
# Type variable for the concrete state class a StreamingModule stores; the bound
# restricts it (for type checkers) to StreamingState and its subclasses.
StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)
class StreamingModule(torch.nn.Module, Generic[StreamingStateT]):  # pylint: disable=abstract-method
    """Overrides Audiocraft's Streaming modules with additional small utils.

    Each instance owns a streaming state object plus two flags,
    ``is_streaming`` and ``enable_viz``. The :meth:`streaming` context manager
    pushes these flags to every ``StreamingModule`` yielded by
    ``named_modules()`` and resets all their states on exit.
    """

    def __init__(self, state_class: type) -> None:
        """Create the module and its streaming state.

        Args:
            state_class: Class called with no arguments to build this module's
                streaming state (presumably a ``StreamingState`` subclass,
                per the ``StreamingStateT`` bound).
        """
        torch.nn.Module.__init__(self)
        self.is_streaming: bool = False
        self.enable_viz: tuple[str, ...] = ()
        self._streaming_state: StreamingStateT = state_class()

    @property
    def streaming_state(self) -> StreamingStateT:
        """The streaming state object owned by this module."""
        return self._streaming_state

    def _apply_named_streaming(
        self, fn: Callable[[str, "StreamingModule"], Any]
    ) -> None:
        """Call ``fn(name, module)`` for every StreamingModule in ``named_modules()``."""
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def reset_streaming(self) -> None:
        """Reset the streaming state of every StreamingModule in ``named_modules()``."""

        def _reset(_: str, module: "StreamingModule") -> None:
            module._streaming_state.reset()

        self._apply_named_streaming(_reset)

    def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()) -> None:
        """Set ``is_streaming``/``enable_viz`` on all streaming modules.

        When ``streaming`` is true, each module's state is also (re)initialized
        via its ``init()`` method.
        """

        def _apply(_: str, module: "StreamingModule") -> None:
            module.is_streaming = streaming
            module.enable_viz = viz
            if streaming:
                module.streaming_state.init()

        self._apply_named_streaming(_apply)

    @contextmanager
    def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
        """Context manager to enter streaming mode. Reset streaming state on exit.

        Args:
            stream: Streaming flag applied to all streaming modules on entry.
            viz: Value stored in every streaming module's ``enable_viz``.
        """
        try:
            # Entered inside `try` so that a failure part-way through (e.g. one
            # state's `init()` raising) still restores non-streaming mode below
            # instead of leaving some modules flagged as streaming.
            self._set_streaming(stream, viz)
            yield
        finally:
            self._set_streaming(False, ())
            self.reset_streaming()
|