# LexaLCM_Pre0/lcm/nn/incremental_state.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
#

from typing import Dict, Optional, final

from fairseq2.nn.incremental_state import IncrementalState, IncrementalStateBag
from fairseq2.nn.transformer import FullAttentionState
from torch import Tensor
from torch.nn import Module


@final
class LCMIncrementalStateBag(IncrementalStateBag):  # type: ignore
    """Holds the module states during incremental decoding."""

    _module_states: Dict[Module, FullAttentionState]  # type: ignore

    def __init__(
        self, max_num_steps: int, *, capacity_increment: Optional[int] = 16
    ) -> None:
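        # `capacity_increment` is forwarded to fairseq2, which grows the
        # cached key/value buffers in chunks of this many steps; passing
        # `None` would preallocate the full `max_num_steps` up front.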
        super().__init__(
            max_num_steps=max_num_steps, capacity_increment=capacity_increment
        )

    def reorder(self, new_order: Tensor) -> None:
        """Reorder the module states.

        See :meth:`IncrementalState.reorder` for more information.
        """
        # FIXME Deal with reordering diffusion state bags here
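        # `new_order` maps each new batch row to the source row it should be
        # copied from (e.g. beam-search hypothesis reordering); each cached
        # attention state index-selects its keys/values along the batch dim.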
        for state in self._module_states.values():
            state.reorder(new_order)

    def set_state(self, m: Module, state: IncrementalState) -> None:
        """Set the state of ``m``.

        :param m:
            The module.
        :param state:
            The state to store.

        There is currently no call to `set_state` while the bag is frozen,
        but it is implemented here for completeness.
        """
        super().set_state(m, state)
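

if __name__ == "__main__":
    # Usage sketch (an illustration, not part of the original module): drive
    # the bag through a few decoding steps and reorder it as beam search
    # would. This assumes fairseq2's `IncrementalStateBag` API, where
    # `step_nr` tracks the current step, `increment_step_nr` advances it,
    # and `reorder` forwards the new order to every registered module state.
    import torch

    bag = LCMIncrementalStateBag(max_num_steps=8)

    for _ in range(3):
        # In real decoding, a decoder forward pass would run here with
        # `state_bag=bag`, letting attention layers cache keys/values.
        bag.increment_step_nr()

    print(bag.step_nr)  # -> 3

    # Keep batch rows 2, 0, and 1 (in that order) in every cached state;
    # with no module states registered yet, this is a no-op.
    bag.reorder(torch.tensor([2, 0, 1]))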