File size: 5,275 Bytes
ddb382a
 
 
 
 
 
 
 
 
8031e67
 
 
ddb382a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from typing import Literal, Optional
import json
# import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
from PrismAudio.models.factory import create_model_from_config
from PrismAudio.models.utils import load_ckpt_state_dict
from PrismAudio.training.utils import copy_state_dict
from transformers import AutoModel
from transformers import AutoProcessor
from transformers import T5EncoderModel, AutoTokenizer
import logging


from data_utils.ext.synchformer import Synchformer

# getLogger() is called without a name, so this is the root logger.
log = logging.getLogger()

def patch_clip(clip_model):
    """Replace ``clip_model.get_text_features`` so it also returns hidden states.

    The replacement runs ``clip_model.text_model`` and returns a 2-tuple
    ``(text_features, last_hidden_state)``, where ``text_features`` is the
    pooled output passed through ``clip_model.text_projection`` and
    ``last_hidden_state`` is the first element of the text model's output.

    The model is modified in place and the same object is returned.
    """

    # Hack to expose the last hidden states next to the projected features.
    # Reference the original author pointed at:
    # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
    def new_get_text_features(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Any flag left as None falls back to the model's own config value.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        text_outputs = self.text_model(**encoder_kwargs)

        # Positional indexing: element 0 is used as the hidden states,
        # element 1 as the pooled output.
        last_hidden_state = text_outputs[0]
        pooled_output = text_outputs[1]
        return self.text_projection(pooled_output), last_hidden_state

    # Bind the function to this particular instance so it receives `self`.
    bound_method = new_get_text_features.__get__(clip_model)
    clip_model.get_text_features = bound_method
    return clip_model


class FeaturesUtils(nn.Module):
    """Inference-only bundle of text encoders and an optional audio VAE.

    Depending on the constructor flags the instance holds:

    * a CLIP-style model and its processor (``clip_model``, ``clip_processor``),
      with ``get_text_features`` patched by ``patch_clip``;
    * a T5 encoder and its tokenizer (``t5_model``, ``t5_tokenizer``);
    * an autoencoder (``vae``) built from a JSON config and a checkpoint.

    Every optional component attribute is always defined and is ``None`` when
    the component was not loaded. The module is kept permanently in eval mode
    (see ``train``).
    """

    def __init__(
        self,
        *,
        vae_ckpt: Optional[str] = None,
        vae_config: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        need_vae_encoder: bool = True,
    ):
        """Load the requested components.

        Args:
            vae_ckpt: Path to the autoencoder checkpoint. When ``None`` no VAE
                is created and ``self.vae`` is ``None``.
            vae_config: Path to the JSON model config; required when
                ``vae_ckpt`` is given.
            synchformer_ckpt: Accepted for interface compatibility; not used
                by this class.
            enable_conditions: When true, load the CLIP model/processor and
                the T5 encoder/tokenizer.
            need_vae_encoder: Accepted for interface compatibility; not used
                by this class.

        Raises:
            TypeError: If ``vae_ckpt`` is given without ``vae_config``.
        """
        super().__init__()

        if enable_conditions:
            self.clip_model = AutoModel.from_pretrained("metaclip-h14-fullcc2.5b")
            self.clip_model = patch_clip(self.clip_model)
            self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
            self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl")
            self.clip_processor = AutoProcessor.from_pretrained("metaclip-h14-fullcc2.5b")
        else:
            # Define every text-side attribute so the "not loaded" asserts in
            # the encode_* methods fire instead of an AttributeError.
            self.clip_model = None
            self.clip_processor = None
            self.t5_tokenizer = None
            self.t5_model = None

        # No Synchformer or standalone tokenizer is created here; the
        # attributes are still defined in both branches so compile() and
        # external checks can test them against None safely.
        self.synchformer = None
        self.tokenizer = None

        if vae_ckpt is not None:
            if vae_config is None:
                # Same exception type open(None) used to raise, clearer message.
                raise TypeError('vae_config must be a path to a JSON config when vae_ckpt is given')
            with open(vae_config, encoding='utf-8') as f:
                vae_config = json.load(f)
            self.vae = create_model_from_config(vae_config)
            print(f"Loading model checkpoint from {vae_ckpt}")
            # Load checkpoint; only keys under the 'autoencoder.' prefix are requested.
            copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.'))
        else:
            self.vae = None

    def compile(self):
        """Wrap the loaded encoders with ``torch.compile``.

        On the CLIP model, whichever of the open_clip-style names
        (``encode_image``/``encode_text``) or Transformers-style names
        (``get_image_features``/``get_text_features``) exist are compiled;
        missing ones are skipped instead of raising ``AttributeError``.
        """
        if self.clip_model is not None:
            candidate_names = (
                'encode_image',
                'encode_text',
                'get_image_features',
                'get_text_features',
            )
            for method_name in candidate_names:
                method = getattr(self.clip_model, method_name, None)
                if callable(method):
                    setattr(self.clip_model, method_name, torch.compile(method))
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)

    def train(self, mode: bool = True) -> 'FeaturesUtils':
        """Keep the module in eval mode regardless of ``mode``.

        ``mode`` is ignored on purpose; it defaults to ``True`` so the
        standard ``module.train()`` call form works. Returns ``self``.
        """
        return super().train(False)

    @torch.inference_mode()
    def encode_text(self, text: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode ``text`` with the CLIP text tower.

        The input is padded/truncated to 77 tokens by ``clip_processor``.

        Returns:
            The ``(text_features, last_hidden_state)`` pair produced by the
            patched ``get_text_features``.
        """
        assert self.clip_model is not None, 'CLIP is not loaded'
        assert self.clip_processor is not None, 'CLIP processor is not loaded'
        tokens = self.clip_processor(
            text=text,
            truncation=True,
            max_length=77,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)
        return self.clip_model.get_text_features(**tokens)

    @torch.inference_mode()
    def encode_t5_text(self, text: list[str]) -> torch.Tensor:
        """Encode ``text`` with the T5 encoder and return ``last_hidden_state``.

        The input is padded/truncated to 77 tokens by ``t5_tokenizer``.
        """
        assert self.t5_model is not None, 'T5 model is not loaded'
        assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded'
        inputs = self.t5_tokenizer(
            text,
            truncation=True,
            max_length=77,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)
        return self.t5_model(**inputs).last_hidden_state

    @torch.inference_mode()
    def encode_audio(self, x) -> torch.Tensor:
        """Return ``self.vae.encode(x)``.

        NOTE(review): the expected layout of ``x`` is defined by the VAE
        implementation, which is not visible here - confirm against it.
        """
        assert self.vae is not None, 'VAE is not loaded'
        return self.vae.encode(x)

    @property
    def device(self):
        """Device of the first registered parameter.

        Raises ``StopIteration`` when the module holds no parameters
        (e.g. nothing was loaded).
        """
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter.

        Raises ``StopIteration`` when the module holds no parameters.
        """
        return next(self.parameters()).dtype