from typing import Literal, Optional

import json

import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained, create_model
from torchvision.transforms import Normalize

from ...ext.autoencoder import AutoEncoderModule
from ...ext.mel_converter import get_mel_converter
from ...ext.synchformer.synchformer import Synchformer
from ...model.utils.distributions import DiagonalGaussianDistribution
from shared.utils import files_locator as fl


def patch_clip(clip_model):
    # Patch encode_text to return the full sequence of token features
    # (no pooling at the EOT token and no final text projection).
    def new_encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        return F.normalize(x, dim=-1) if normalize else x

    clip_model.encode_text = new_encode_text.__get__(clip_model)
    return clip_model
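
# Note (for orientation, not part of the original logic): stock open_clip's
# encode_text pools at the EOT token and applies the text projection, giving
# (batch, embed_dim); after patch_clip, encode_text returns per-token features
# of shape (batch, n_ctx, transformer.width) -- roughly (1, 77, 1024) for a
# ViT-H-14 text tower, with exact sizes depending on the loaded config.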


def get_model_config(model_name):
    # Drop-in replacement for open_clip.factory.get_model_config: ignore the
    # requested model name and read the locally stored DFN5B-CLIP config instead.
    with open(fl.locate_file("DFN5B-CLIP-ViT-H-14-378/open_clip_config.json"), 'r', encoding='utf-8') as f:
        return json.load(f)["model_cfg"]


class FeaturesUtils(nn.Module):
    """Bundles the frozen feature extractors used for conditioning: a CLIP
    text/image encoder, a Synchformer video encoder, and a mel VAE (+ vocoder)."""

    def __init__(
        self,
        *,
        tod_vae_ckpt: Optional[str] = None,
        bigvgan_vocoder_ckpt: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        mode: Literal['16k', '44k'] = '16k',
        need_vae_encoder: bool = True,
    ):
        super().__init__()
        self.device = "cuda"
        if enable_conditions:
            # Temporarily swap in the local config loader so create_model picks up
            # the DFN5B-CLIP config shipped alongside the weights.
            old_get_model_config = open_clip.factory.get_model_config
            open_clip.factory.get_model_config = get_model_config
            with open(fl.locate_file("DFN5B-CLIP-ViT-H-14-378/open_clip_config.json"),
                      'r', encoding='utf-8') as f:
                override_preprocess = json.load(f)["preprocess_cfg"]

            self.clip_model = create_model(
                'DFN5B-CLIP-ViT-H-14-378',
                pretrained=fl.locate_file('DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin'),
                force_preprocess_cfg=override_preprocess)
            open_clip.factory.get_model_config = old_get_model_config

            self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                             std=[0.26862954, 0.26130258, 0.27577711])
            self.clip_model = patch_clip(self.clip_model)

            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))

            self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')
        else:
            self.clip_model = None
            self.synchformer = None
            self.tokenizer = None

        if tod_vae_ckpt is not None:
            self.mel_converter = get_mel_converter(mode)
            self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                         vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                         mode=mode,
                                         need_vae_encoder=need_vae_encoder)
        else:
            self.tod = None

    def compile(self):
        if self.clip_model is not None:
            self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
            self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)
        self.decode = torch.compile(self.decode)
        self.vocode = torch.compile(self.vocode)

    def train(self, mode: bool = True) -> 'FeaturesUtils':
        # The feature extractors are frozen utilities; always stay in eval mode.
        return super().train(False)

    @torch.inference_mode()
    def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'

        b, t, c, h, w = x.shape
        assert c == 3 and h == 384 and w == 384
        x = self.clip_preprocess(x)
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        outputs = []
        if batch_size < 0:
            batch_size = b * t
        for i in range(0, b * t, batch_size):
            outputs.append(self.clip_model.encode_image(x[i:i + batch_size], normalize=True))
        x = torch.cat(outputs, dim=0)

        x = rearrange(x, '(b t) d -> b t d', b=b)
        return x

    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.synchformer is not None, 'Synchformer is not loaded'

        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # Split the frame sequence into overlapping windows of 16 frames with a
        # stride of 8, e.g. t = 48 frames -> (48 - 16) // 8 + 1 = 5 segments.
        segment_size = 16
        step_size = 8
        num_segments = (t - segment_size) // step_size + 1
        segments = []
        for i in range(num_segments):
            segments.append(x[:, i * step_size:i * step_size + segment_size])
        x = torch.stack(segments, dim=1)  # (b, num_segments, segment_size, c, h, w)

        outputs = []
        if batch_size < 0:
            batch_size = b
        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        for i in range(0, b * num_segments, batch_size):
            outputs.append(self.synchformer(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x

    @torch.inference_mode()
    def encode_text(self, text: list[str]) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        assert self.tokenizer is not None, 'Tokenizer is not loaded'

        tokens = self.tokenizer(text).to(self.device)
        return self.clip_model.encode_text(tokens, normalize=True)

    @torch.inference_mode()
    def encode_audio(self, x) -> DiagonalGaussianDistribution:
        assert self.tod is not None, 'VAE is not loaded'

        mel = self.mel_converter(x)
        dist = self.tod.encode(mel)
        return dist

    @torch.inference_mode()
    def vocode(self, mel: torch.Tensor) -> torch.Tensor:
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.vocode(mel)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.decode(z.transpose(1, 2))

    @property
    def dtype(self):
        return next(self.parameters()).dtype
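

# A minimal usage sketch (illustrative only): the checkpoint paths below are
# placeholders, real weights and a CUDA device are assumed, and the input
# shapes follow the asserts in the methods above.
if __name__ == '__main__':
    utils = FeaturesUtils(
        tod_vae_ckpt='path/to/vae.ckpt',              # placeholder
        bigvgan_vocoder_ckpt='path/to/vocoder.ckpt',  # placeholder
        synchformer_ckpt='path/to/synchformer.ckpt',  # placeholder
        enable_conditions=True,
        mode='16k',
    ).to('cuda').eval()

    clip_frames = torch.randn(1, 8, 3, 384, 384, device='cuda')   # (b, t, c, h, w) for CLIP
    sync_frames = torch.randn(1, 48, 3, 224, 224, device='cuda')  # (b, t, c, h, w) for Synchformer

    clip_feat = utils.encode_video_with_clip(clip_frames)  # (b, t, d)
    sync_feat = utils.encode_video_with_sync(sync_frames)  # (b, s * t', d)
    text_feat = utils.encode_text(['a dog barking'])        # (b, n_ctx, d) after patch_clip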