Image-Text-to-Text
Transformers
Safetensors
English
Helium1_VL_2B
custom_code
File size: 3,839 Bytes
1126ea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# pylint: disable=protected-access
"""Utils to handle CASA layers construction"""

from contextlib import contextmanager
from dataclasses import dataclass, fields
from typing import Any, Callable, Generic, TypeVar

import torch


def delta_w_factory(
    org_lin: torch.nn.Linear, new_lin: torch.nn.Linear
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Factory for building a linear op whose weights are the sum of two layers' weights.

    Args:
        org_lin: The original linear layer.
        new_lin: The layer holding the delta that is added on top of ``org_lin``.

    Returns:
        A callable applying ``torch.nn.functional.linear`` with the weight
        ``org_lin.weight + new_lin.weight``. The parameters are read at call
        time, so later updates to either layer are reflected in the output.

    Bias handling:
        * ``org_lin`` has no bias: no bias is applied (``new_lin.bias`` is ignored).
        * ``org_lin`` has a bias but ``new_lin`` does not: ``org_lin.bias`` is used
          as is (a missing delta bias is treated as "no change").
        * both have a bias: the two biases are summed.
    """

    # The parameter keeps the name `input` (same as `torch.nn.Linear.forward`) so that
    # callers passing it by keyword keep working.
    def _delta_w_fwd(input: torch.Tensor) -> torch.Tensor:
        bias = org_lin.bias
        # Only add the delta bias when both layers have one; adding `None` to a
        # tensor would raise a TypeError.
        if bias is not None and new_lin.bias is not None:
            bias = bias + new_lin.bias
        return torch.nn.functional.linear(input, org_lin.weight + new_lin.weight, bias)

    return _delta_w_fwd


@dataclass
class StreamingState:
    """Persistent state kept by CASA layers while streaming at inference,
    e.g. the offset, the KV cache and any other value that must survive between calls"""

    offset: int = 0

    def _is_valid_field(self, key: str) -> bool:
        """Tell whether `key` is the name of one of this dataclass' declared fields"""
        return any(spec.name == key for spec in fields(self))

    def _init_field(self, key: str) -> None:
        """Init function for non-argument dependent defaults

        `offset` restarts at 0; any other field is set to None, since such
        fields should be set explicitly and cannot be auto-initialized.
        """
        assert self._is_valid_field(key)
        setattr(self, key, 0 if key == "offset" else None)

    def init(self) -> None:
        """Run `_init_field` on every declared field"""
        for spec in fields(self):
            self._init_field(spec.name)

    def _reset_field(self, name: str) -> None:
        """Reset a single field; here this simply re-runs its initialization"""
        self._init_field(name)

    def reset(self) -> None:
        """Run `_reset_field` on every declared field"""
        for field_name in [spec.name for spec in fields(self)]:
            self._reset_field(field_name)

    def _get_field(self, f: str) -> Any:
        """Return the value of field `f`, initializing it first if it is currently None"""
        assert self._is_valid_field(f)
        current = getattr(self, f)
        if current is not None:
            return current
        self._init_field(f)
        return getattr(self, f)

    def _set_field(self, f: str, value: Any) -> None:
        """Assign `value` to field `f` after checking that `f` is a declared field"""
        assert self._is_valid_field(f)
        setattr(self, f, value)


# Type variable for the concrete `StreamingState` (sub)class a `StreamingModule` is parameterized with
StreamingStateT = TypeVar("StreamingStateT", bound=StreamingState)


class StreamingModule(torch.nn.Module, Generic[StreamingStateT]):  # pylint: disable=abstract-method
    """Streaming-aware `torch.nn.Module` base class.

    Overrides Audiocraft's Streaming modules with additional small utils: each
    instance owns a streaming state object and can switch itself, together with
    every nested `StreamingModule`, in and out of streaming mode.
    """

    def __init__(self, state_class: type) -> None:
        """Create the module in non-streaming mode, with a fresh `state_class()` as state"""
        torch.nn.Module.__init__(self)
        self._streaming_state: StreamingStateT = state_class()
        self.enable_viz: tuple[str, ...] = ()
        self.is_streaming: bool = False

    @property
    def streaming_state(self) -> StreamingStateT:
        """The streaming state object owned by this module"""
        return self._streaming_state

    def _apply_named_streaming(self, fn: Callable):
        """Call `fn(name, module)` on every streaming module yielded by `named_modules()`"""
        for module_name, submodule in self.named_modules():
            if not isinstance(submodule, StreamingModule):
                continue
            fn(module_name, submodule)

    def reset_streaming(self):
        """Reset the streaming state of all streaming modules."""
        self._apply_named_streaming(lambda _name, submodule: submodule._streaming_state.reset())

    def _set_streaming(self, streaming: bool, viz: tuple[str, ...] = ()):
        """Switch all streaming modules to the given streaming mode and `viz` setting"""

        def _switch_mode(_name: str, submodule: StreamingModule) -> None:
            submodule.is_streaming = streaming
            submodule.enable_viz = viz
            # States are only (re-)initialized when entering streaming mode
            if streaming:
                submodule.streaming_state.init()

        self._apply_named_streaming(_switch_mode)

    @contextmanager
    def streaming(self, stream: bool = True, viz: tuple[str, ...] = ()):
        """Context manager to enter streaming mode.

        On exit (including when the body raises), streaming mode is switched
        off, `viz` is cleared and the streaming states are reset.
        """
        self._set_streaming(stream, viz)
        try:
            yield
        finally:
            self._set_streaming(False, ())
            self.reset_streaming()