File size: 5,671 Bytes
b4bbb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from pathlib import Path
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.torch_utilities import (
    load_pretrained_model, merge_matched_keys, create_mask_from_length,
    loss_with_mask, create_alignment_path
)


class LoadPretrainedBase(nn.Module):
    def process_state_dict(
        self, model_dict: dict[str, torch.Tensor],
        state_dict: dict[str, torch.Tensor]
    ):
        """
        Custom processing functions of each model that transforms `state_dict` loaded from 
        checkpoints to the state that can be used in `load_state_dict`.
        Use `merge_mathced_keys` to update parameters with matched names and shapes by 
        default.  

        Args
            model_dict:
                The state dict of the current model, which is going to load pretrained parameters
            state_dict:
                A dictionary of parameters from a pre-trained model.

            Returns:
                dict[str, torch.Tensor]:
                    The updated state dict, where parameters with matched keys and shape are 
                    updated with values in `state_dict`.      
        """
        state_dict = merge_matched_keys(model_dict, state_dict)
        return state_dict

    def load_pretrained(self, ckpt_path: str | Path):
        load_pretrained_model(
            self, ckpt_path, state_dict_process_fn=self.process_state_dict
        )


class CountParamsBase(nn.Module):
    """
    Base class that can report how many parameters a module holds.
    """

    def count_params(self):
        """
        Count the elements of every parameter yielded by `self.parameters()`.

        Returns:
            tuple[int, int]:
                `(total, trainable)`, where `trainable` only includes
                parameters whose `requires_grad` flag is set.
        """
        sizes = [(p.numel(), p.requires_grad) for p in self.parameters()]
        total = sum(size for size, _ in sizes)
        trainable = sum(size for size, needs_grad in sizes if needs_grad)
        return total, trainable


class SaveTrainableParamsBase(nn.Module):
    """
    Base class for modules whose checkpoints only need the trainable
    parameters and the buffers, rather than every parameter.
    """

    @property
    def param_names_to_save(self):
        """
        Names of the state entries that a checkpoint must contain.

        Returns:
            list[str]:
                Names of all parameters with `requires_grad=True`, followed
                by the names of all persistent buffers.
        """
        # Buffers registered with `persistent=False` are never part of
        # `state_dict()`. Listing them here would make `load_state_dict`
        # reject every checkpoint produced from this very module, so collect
        # their fully qualified names and leave them out.
        non_persistent = set()
        for prefix, module in self.named_modules():
            for buf_name in getattr(module, "_non_persistent_buffers_set", ()):
                non_persistent.add(
                    f"{prefix}.{buf_name}" if prefix else buf_name
                )

        names = [
            name for name, param in self.named_parameters()
            if param.requires_grad
        ]
        names.extend(
            name for name, _ in self.named_buffers()
            if name not in non_persistent
        )
        return names

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """
        Load `state_dict` after checking that it provides every entry listed
        in `param_names_to_save`.

        The check runs regardless of `strict`; `strict` (and `assign`, when
        requested) are only forwarded to `nn.Module.load_state_dict`.

        Raises:
            Exception: if a required key is missing from `state_dict`.
        """
        for key in self.param_names_to_save:
            if key not in state_dict:
                raise Exception(
                    f"{key} not found in either pre-trained models (e.g. BERT)"
                    " or resumed checkpoints (e.g. epoch_40/model.pt)"
                )
        if assign:
            # Only forward `assign` when it is actually requested, so the
            # default path keeps working on torch versions without it.
            return super().load_state_dict(
                state_dict, strict=strict, assign=True
            )
        return super().load_state_dict(state_dict, strict)


class DurationAdapterMixin:
    def __init__(
        self,
        latent_token_rate: int,
        offset: float = 1.0,
        frame_resolution: float | None = None
    ):
        self.latent_token_rate = latent_token_rate
        self.offset = offset
        self.frame_resolution = frame_resolution

    def get_global_duration_loss(
        self,
        pred: torch.Tensor,
        latent_mask: torch.Tensor,
        reduce: bool = True,
    ):
        target = torch.log(
            latent_mask.sum(1) / self.latent_token_rate + self.offset
        )
        loss = F.mse_loss(target, pred, reduction="mean" if reduce else "none")
        return loss

    def get_local_duration_loss(
        self, ground_truth: torch.Tensor, pred: torch.Tensor,
        mask: torch.Tensor, is_time_aligned: Sequence[bool], reduce: bool
    ):
        n_frames = torch.round(ground_truth / self.frame_resolution)
        target = torch.log(n_frames + self.offset)
        loss = loss_with_mask(
            (target - pred)**2,
            mask,
            reduce=False,
        )
        loss *= is_time_aligned
        if reduce:
            if is_time_aligned.sum().item() == 0:
                loss *= 0.0
                loss = loss.mean()
            else:
                loss = loss.sum() / is_time_aligned.sum()

        return loss

    def prepare_local_duration(self, pred: torch.Tensor, mask: torch.Tensor):
        pred = torch.exp(pred) * mask
        pred = torch.ceil(pred) - self.offset
        pred *= self.frame_resolution
        return pred

    def prepare_global_duration(
        self,
        global_pred: torch.Tensor,
        local_pred: torch.Tensor,
        is_time_aligned: Sequence[bool],
        use_local: bool = True,
    ):
        """
        global_pred: predicted duration value, processed by logarithmic and offset
        local_pred: predicted latent length 
        """
        global_pred = torch.exp(global_pred) - self.offset
        result = global_pred
        # avoid error accumulation for each frame
        if use_local:
            pred_from_local = torch.round(local_pred * self.latent_token_rate)
            pred_from_local = pred_from_local.sum(1) / self.latent_token_rate
            result[is_time_aligned] = pred_from_local[is_time_aligned]

        return result

    def expand_by_duration(
        self,
        x: torch.Tensor,
        content_mask: torch.Tensor,
        local_duration: torch.Tensor,
        global_duration: torch.Tensor | None = None,
    ):
        n_latents = torch.round(local_duration * self.latent_token_rate)
        if global_duration is not None:
            latent_length = torch.round(
                global_duration * self.latent_token_rate
            )
        else:
            latent_length = n_latents.sum(1)
        latent_mask = create_mask_from_length(latent_length).to(
            content_mask.device
        )
        attn_mask = content_mask.unsqueeze(-1) * latent_mask.unsqueeze(1)
        align_path = create_alignment_path(n_latents, attn_mask)
        expanded_x = torch.matmul(align_path.transpose(1, 2).to(x.dtype), x)
        return expanded_x, latent_mask