from transformers import ASTModel, ViTModel, PretrainedConfig, PreTrainedModel
import numpy as np
import torch
import torch.nn as nn
from einops import reduce

class MuVis(nn.Module):
    """Two-tower audio-image model: AST encodes audio, ViT encodes images,
    and both are projected into a shared latent space."""

    def __init__(self, embed_dims=768, latent_dims=128, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate
        # Learnable temperature for contrastive training, initialized to 1/0.07 as in CLIP.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Audio tower: Audio Spectrogram Transformer plus a linear projection head.
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", low_cpu_mem_usage=True)
        self.wav_lin = nn.Linear(embed_dims, latent_dims)
        # Image tower: ViT plus a linear projection head.
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", low_cpu_mem_usage=True)
        self.img_lin = nn.Linear(embed_dims, latent_dims)

    def forward(self, wav=None, img=None):
        wav_out = None
        img_out = None
        if wav is not None:
            # Encode, project each token, mean-pool over the sequence, then L2-normalize.
            wav_out = self.ast(**wav)["last_hidden_state"]
            wav_out = self.wav_lin(wav_out)
            wav_out = reduce(wav_out, "b n d -> b d", "mean")
            wav_out = wav_out / wav_out.norm(dim=-1, keepdim=True)
        if img is not None:
            img_out = self.vit(**img)["last_hidden_state"]
            img_out = self.img_lin(img_out)
            img_out = reduce(img_out, "b n d -> b d", "mean")
            img_out = img_out / img_out.norm(dim=-1, keepdim=True)
        assert wav_out is not None or img_out is not None, "provide at least one of wav or img"
        if wav_out is None or img_out is None:
            # Single-modality call: return whichever embedding was computed.
            return wav_out if img_out is None else img_out
        return (wav_out, img_out)
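
# logit_scale is defined above but never used inside forward, which suggests the
# training objective lives outside the module. Below is a minimal sketch of the
# symmetric CLIP-style InfoNCE loss that this temperature would typically drive.
# clip_contrastive_loss is a hypothetical helper, not part of the original file;
# it assumes wav_emb and img_emb are the index-aligned, L2-normalized outputs of
# MuVis.forward for a batch of paired audio clips and images.
def clip_contrastive_loss(model, wav_emb, img_emb):
    logits = model.logit_scale.exp() * wav_emb @ img_emb.t()  # (b, b) scaled cosine similarities
    targets = torch.arange(logits.size(0), device=logits.device)  # matching pairs lie on the diagonal
    loss_a2i = nn.functional.cross_entropy(logits, targets)  # audio -> image direction
    loss_i2a = nn.functional.cross_entropy(logits.t(), targets)  # image -> audio direction
    return (loss_a2i + loss_i2a) / 2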

class MuVisConfig(PretrainedConfig):
    model_type = "muvis"

    def __init__(
        self,
        embed_dims=768,
        latent_dims=128,
        sampling_rate=16000,
        **kwargs,
    ):
        self.embed_dims = embed_dims
        self.latent_dims = latent_dims
        self.sampling_rate = sampling_rate
        super().__init__(**kwargs)

class MuVisModel(PreTrainedModel):
    """Hugging Face wrapper so MuVis can be saved and loaded via save_pretrained/from_pretrained."""

    config_class = MuVisConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = MuVis(
            embed_dims=config.embed_dims,
            latent_dims=config.latent_dims,
            sampling_rate=config.sampling_rate,
        )

    def forward(self, wav=None, img=None):
        return self.model(wav=wav, img=img)
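
# Illustrative smoke test, not part of the original module. It assumes the
# matching transformers preprocessors (ASTFeatureExtractor, ViTImageProcessor)
# are available and feeds dummy inputs through both towers.
if __name__ == "__main__":
    from transformers import ASTFeatureExtractor, ViTImageProcessor

    config = MuVisConfig()
    model = MuVisModel(config)

    audio_fe = ASTFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
    image_fe = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

    # One second of random audio and one random RGB image, batched as size 1.
    wav = audio_fe(np.random.randn(config.sampling_rate), sampling_rate=config.sampling_rate, return_tensors="pt")
    img = image_fe(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8), return_tensors="pt")

    with torch.no_grad():
        wav_emb, img_emb = model(wav=wav, img=img)
    print(wav_emb.shape, img_emb.shape)  # expected: torch.Size([1, 128]) for each modality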