| from transformers import ASTModel, ViTModel, PretrainedConfig, PreTrainedModel | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from einops import reduce | |
class MuVis(nn.Module):
    """CLIP-style dual encoder aligning audio and images in a shared latent space.

    Audio is encoded with a pretrained Audio Spectrogram Transformer (AST) and
    images with a pretrained ViT; each backbone's token embeddings are projected
    to ``latent_dims``, mean-pooled over the token axis, and L2-normalized.

    Args:
        embed_dims: Hidden size of both pretrained backbones (768 for the
            base checkpoints used here).
        latent_dims: Dimensionality of the shared latent space.
        sampling_rate: Expected audio sampling rate; stored for callers that
            build feature-extractor inputs (not used inside ``forward``).
    """

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        # Learnable temperature for contrastive logits, initialized to
        # ln(1/0.07) as in CLIP (Radford et al., 2021).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.ast = ASTModel.from_pretrained(
            "MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True
        )
        self.wav_lin = nn.Linear(embed_dims, latent_dims)
        self.vit = ViTModel.from_pretrained(
            "google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True
        )
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    def forward(self, wav=None, img=None):
        """Encode audio and/or image inputs into L2-normalized latents.

        Args:
            wav: Keyword-argument dict for :class:`ASTModel` (e.g. the output
                of an AST feature extractor), or ``None``.
            img: Keyword-argument dict for :class:`ViTModel` (e.g. the output
                of a ViT image processor), or ``None``.

        Returns:
            The single ``(batch, latent_dims)`` embedding when only one
            modality is given, or a ``(wav_out, img_out)`` tuple when both are.

        Raises:
            ValueError: If both ``wav`` and ``img`` are ``None``. (Previously
                an ``assert``, which is stripped under ``python -O``; the check
                now also runs *before* any encoder work.)
        """
        if wav is None and img is None:
            raise ValueError("At least one of `wav` or `img` must be provided.")
        wav_out = None
        img_out = None
        if wav is not None:
            wav_out = self.ast(**wav)["last_hidden_state"]
            wav_out = self.wav_lin(wav_out)
            # Mean-pool over the token axis, then L2-normalize for cosine sim.
            wav_out = reduce(wav_out, "b n d -> b d", "mean")
            wav_out = wav_out / wav_out.norm(dim=-1, keepdim=True)
        if img is not None:
            img_out = self.vit(**img)["last_hidden_state"]
            img_out = self.img_lin(img_out)
            img_out = reduce(img_out, "b n d -> b d", "mean")
            img_out = img_out / img_out.norm(dim=-1, keepdim=True)
        if wav_out is None:
            return img_out
        if img_out is None:
            return wav_out
        return (wav_out, img_out)
class MuVisConfig(PretrainedConfig):
    """Hugging Face configuration object for :class:`MuVisModel`.

    Stores the hyperparameters needed to rebuild the MuVis dual encoder:
    backbone hidden size, shared latent size, and audio sampling rate.
    """

    model_type = "muvis"

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000, **kwargs):
        # Record model hyperparameters before handing remaining kwargs to
        # PretrainedConfig (standard HF config pattern).
        self.embed_dims = embed_dims
        self.latent_dims = latent_dims
        self.sampling_rate = sampling_rate
        super().__init__(**kwargs)
class MuVisModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the :class:`MuVis` encoder.

    Enables ``save_pretrained`` / ``from_pretrained`` workflows by pairing
    :class:`MuVisConfig` with the underlying ``nn.Module``.
    """

    config_class = MuVisConfig

    def __init__(self, config):
        super().__init__(config)
        # Instantiate the inner dual encoder from the stored hyperparameters.
        self.model = MuVis(
            embed_dims=config.embed_dims,
            latent_dims=config.latent_dims,
            sampling_rate=config.sampling_rate,
        )

    def forward(self, wav=None, img=None):
        """Delegate straight to the wrapped :class:`MuVis` forward pass."""
        return self.model(wav=wav, img=img)