Image-Text-to-Text
Transformers
Safetensors
English
Helium1_VL_2B
custom_code
Helium1-VL-2B / utils.py
ameroyer's picture
Super-squash branch 'main' using huggingface_hub
1126ea7 verified
# pylint: disable=protected-access
"""Utils to handle CASA layers construction"""
from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Generic, TypeVar
import torch
def delta_w_factory(
    org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Build a linear op whose weights (and biases) are the sum of two layers'.

    The returned closure computes
    ``F.linear(input, org_lin.weight + new_lin.weight, bias)``, re-reading both
    layers' parameters on every call, so later updates to either layer are
    reflected in the output.

    The bias is the sum of whichever biases exist:

    * both layers have a bias -> their sum;
    * only one layer has a bias -> that bias;
    * neither has a bias -> ``None``.

    Args:
        org_lin: The original (base) linear layer.
        new_lin: The layer whose parameters are added on top of ``org_lin``.
            Its weight is expected to have the same shape as ``org_lin.weight``.

    Returns:
        A callable mapping an input tensor to the combined linear output.
    """

    def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:  # pylint: disable=redefined-builtin
        # Parameters are summed at call time (not captured once) so the
        # closure always uses the layers' current values.
        weight = org_lin.weight + new_lin.weight
        bias = org_lin.bias
        if new_lin.bias is not None:
            # Previously `org.bias + None` crashed, and a bias present only on
            # `new_lin` was silently ignored.
            bias = new_lin.bias if bias is None else bias + new_lin.bias
        return torch.nn.functional.linear(input, weight, bias)

    return _delta_w_fwd
@dataclass
class StreamingState:
    """Mutable state kept by CASA layers across streaming inference steps.

    Holds things such as the current offset, the KV cache and other persistent
    states. Only ``offset`` has an automatic default (0); every other field is
    reset to ``None`` and must be assigned explicitly.
    """

    # Streaming position offset; ``init``/``reset`` bring it back to 0.
    offset: int = 0

    def _is_valid_field(self, key: str) -> bool:
        """Return whether ``key`` is the name of one of this dataclass's fields."""
        field_names = {fld.name for fld in fields(self)}
        return key in field_names

    def _init_field(self, key: str) -> None:
        """Initialize field ``key`` to its argument-independent default.

        ``offset`` becomes 0. Any other field becomes ``None``, because it
        cannot be auto-initialized and has to be set explicitly.
        """
        assert self._is_valid_field(key)
        default_value = 0 if key == "offset" else None
        setattr(self, key, default_value)

    def init(self) -> None:
        """Initialize every field through ``_init_field``."""
        for fld in fields(self):
            self._init_field(fld.name)

    def _reset_field(self, name: str) -> None:
        """Reset field ``name``; this delegates to ``_init_field``."""
        self._init_field(name)

    def reset(self) -> None:
        """Reset every field through ``_reset_field``."""
        for field_name in (fld.name for fld in fields(self)):
            self._reset_field(field_name)

    def _get_field(self, f: str) -> Any:
        """Return field ``f``, initializing it first if it is currently ``None``."""
        assert self._is_valid_field(f)
        current = getattr(self, f)
        if current is None:
            self._init_field(f)
            current = getattr(self, f)
        return current

    def _set_field(self, f: str, value: Any) -> None:
        """Assign ``value`` to field ``f`` (which must be a declared field)."""
        assert self._is_valid_field(f)
        setattr(self, f, value)
# Type variable for the concrete StreamingState (sub)class a StreamingModule
# holds, so that `StreamingModule.streaming_state` is typed as that class.
StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)
class StreamingModule(torch.nn.Module, Generic[StreamingStateT]):  # pylint: disable=abstract-method
    """Overrides Audiocraft's Streaming modules with additional small utils.

    Each instance owns one streaming state object, created from the
    ``state_class`` passed at construction. The :meth:`streaming` context
    manager toggles streaming mode on this module and on every
    ``StreamingModule`` found among its submodules.
    """

    def __init__(self, state_class: type[StreamingStateT]) -> None:
        """Create the module.

        Args:
            state_class: Class of the streaming state. It is instantiated
                with no arguments.
        """
        torch.nn.Module.__init__(self)
        # True while inside a `streaming()` context.
        self.is_streaming: bool = False
        # Visualization names enabled for the current streaming session;
        # presumably read by subclasses — TODO confirm.
        self.enable_viz: tuple[str, ...] = ()
        self._streaming_state: StreamingStateT = state_class()

    @property
    def streaming_state(self) -> StreamingStateT:
        """The streaming state object owned by this module."""
        return self._streaming_state

    def _apply_named_streaming(
        self, fn: Callable[[str, "StreamingModule"], None]
    ) -> None:
        """Call ``fn(name, module)`` for every StreamingModule in ``named_modules()``.

        This includes ``self`` (under the empty name ``""``).
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def reset_streaming(self) -> None:
        """Reset the streaming state of this module and all streaming submodules."""

        def _reset(_: str, module: StreamingModule) -> None:
            module._streaming_state.reset()

        self._apply_named_streaming(_reset)

    def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()) -> None:
        """Set streaming mode and visualization flags on all streaming modules.

        When ``streaming`` is True, each module's state is also (re-)initialized
        with ``init()``.
        """

        def _apply(_: str, module: StreamingModule) -> None:
            module.is_streaming = streaming
            module.enable_viz = viz
            if streaming:
                module.streaming_state.init()

        self._apply_named_streaming(_apply)

    @contextmanager
    def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
        """Context manager to enter streaming mode. Reset streaming state on exit.

        Args:
            stream: Streaming flag applied to all streaming modules on entry.
            viz: Visualization names stored on each module as ``enable_viz``.
        """
        try:
            # Setup runs inside the `try`. If initializing some submodule's
            # state fails part-way, the modules already switched are still
            # switched back and reset by the `finally`.
            self._set_streaming(stream, viz)
            yield
        finally:
            self._set_streaming(False, ())
            self.reset_streaming()