File size: 2,402 Bytes
22cfe7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torch.nn as nn
import torch.nn.functional as F


from transformers import BertModel
from timeseries_vqvae_transformer import TransformerVQVAE



class SemanticSimilarityModel(nn.Module):
    """Two-tower model projecting music time series and text to equal-size vectors.

    Music tower: a frozen ``TransformerVQVAE`` encodes and quantizes every
    inner item of the time series, the results are stacked, passed through
    two trainable ``Conv3d`` layers, flattened, and projected by a trainable
    linear layer.

    Text tower: a frozen ``bert-base-cased`` ``BertModel`` is run on the token
    ids; the hidden state at sequence position 0 is projected by a trainable
    linear layer.

    Args:
        embed_dim: Output size of both projection layers.
        ts_flat_features: Number of features per sample once the output of
            ``conv2`` is flattened. It depends on the shape produced by the
            music encoder, which is not visible from this module.

    NOTE(review): setting ``requires_grad = False`` does not put the frozen
    encoders in eval mode. Any dropout (or other train-time behaviour) inside
    them stays active while this module is in train mode — confirm that this
    is intended.
    """

    def __init__(self, embed_dim: int = 512, ts_flat_features: int = 141312):
        super().__init__()

        # Transformer VQ-VAE and convolutional layers for time series music
        # embeddings. The VQ-VAE is used as a fixed feature extractor.
        self.music_encoder = TransformerVQVAE()
        self._freeze(self.music_encoder)

        # Shape notes carried over from the original author (not verifiable
        # from this file): conv1 output (14 x 24 x 4), conv2 output
        # (12 x 22 x 2).
        self.conv1 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3))
        self.conv2 = nn.Conv3d(128, 256, kernel_size=(3, 3, 3))
        # NOTE(review): the default 141312 equals 256 * 552, while the
        # (12 x 22 x 2) shape noted above would give 256 * 528 == 135168.
        # Either the shape note or this size is stale — TODO confirm against
        # the real encoder output.
        self.timeseries_fc = nn.Linear(ts_flat_features, embed_dim)

        # BERT model and linear layer for text embeddings. BERT is frozen;
        # only the projection on top of it is trained.
        self.text_encoder = BertModel.from_pretrained('bert-base-cased')
        self._freeze(self.text_encoder)
        self.text_fc = nn.Linear(self.text_encoder.config.hidden_size, embed_dim)

    @staticmethod
    def _freeze(module):
        """Disable gradient computation for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad = False

    def _quantize_clip(self, clip):
        """Encode and quantize each item of ``clip``; concatenate along dim 0."""
        quantized = []
        for melspec in clip:
            # The encoder is called with a leading dimension of size 1.
            z = self.music_encoder.encode(melspec.unsqueeze(0))
            # Only the first value returned by the quantizer is used.
            z_q, *_ = self.music_encoder.vq(z)
            quantized.append(z_q)
        return torch.cat(quantized, dim=0)

    def forward(self, time_series: torch.Tensor, input_ids: torch.Tensor,
                attention_mask: torch.Tensor):
        """Compute the time-series and text embeddings for one batch.

        Args:
            time_series: Iterated over its first dimension (samples) and then
                its second (items per sample); each item is fed to the music
                encoder individually. Presumably a batch of mel-spectrogram
                sequences — TODO confirm the exact layout with the caller.
            input_ids: Token ids forwarded to the text encoder.
            attention_mask: Attention mask forwarded to the text encoder.

        Returns:
            Tuple ``(x_ts, x_text)``: the projected time-series embedding and
            the projected text embedding, one row per sample.
        """
        # Get time series embeddings. Stacking yields one 5-D tensor; the
        # permute swaps dims 1 and 2 so that the dimension conv1 treats as
        # channels comes right after the batch dimension.
        melspec_embed = torch.stack(
            [self._quantize_clip(clip) for clip in time_series], dim=0
        )
        x_ts = melspec_embed.permute(0, 2, 1, 3, 4)
        x_ts = self.conv2(self.conv1(x_ts))
        # flatten() keeps the batch dimension and, unlike view(), also works
        # when the convolution output is not contiguous.
        x_ts = self.timeseries_fc(x_ts.flatten(start_dim=1))

        # Get text embeddings from the hidden state at sequence position 0.
        text_out = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        x_text = self.text_fc(text_out.last_hidden_state[:, 0, :])

        return x_ts, x_text