import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding, added to the input embeddings."""

    def __init__(self, d_model, max_len=500):
        super(PositionalEncoding, self).__init__()
        # Precompute the (max_len, d_model) table of sin/cos values
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # Reshape to (max_len, 1, d_model) so it broadcasts over the batch dimension
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Buffer: moves with the module (e.g. .to(device)) but is not a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, batch_size, d_model)
        return x + self.pe[:x.size(0), :]
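
# A minimal shape check for PositionalEncoding (illustrative only; the sizes
# below are assumptions, not values used elsewhere in this file):
#
#     pe = PositionalEncoding(d_model=64)
#     x = torch.zeros(10, 2, 64)       # (seq_len, batch_size, d_model)
#     out = pe(x)                      # same shape; encodings broadcast over batch
#     assert out.shape == (10, 2, 64)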

class LoveLiveTransformer(nn.Module):
    def __init__(self, num_songs, num_artists, num_feedback_types, num_lives,
                 d_model=64, nhead=4, num_layers=2):
        super(LoveLiveTransformer, self).__init__()
        # Use padding_idx=0 so the embedding for index 0 is always a zero vector
        self.song_embedding = nn.Embedding(num_songs, d_model, padding_idx=0)
        self.artist_embedding = nn.Embedding(num_artists, d_model, padding_idx=0)
        self.feedback_embedding = nn.Embedding(num_feedback_types, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=d_model * 4)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc_out = nn.Linear(d_model, num_lives)
        self.d_model = d_model
    def forward(self, song_seq, artist_seq, feedback_seq):
        # Sequence inputs: (seq_len, batch_size)
        # Padding mask: (batch_size, seq_len), True where the token is 0 (padding)
        src_key_padding_mask = (song_seq == 0).transpose(0, 1)
        # Embed the three input streams and sum them
        src = self.song_embedding(song_seq) + self.artist_embedding(artist_seq) + self.feedback_embedding(feedback_seq)
        src = src * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        # Transformer encoder; output: (seq_len, batch_size, d_model)
        output = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
        # Mean pooling over the sequence, excluding padding positions.
        # Transpose to (batch_size, seq_len, d_model)
        output = output.transpose(0, 1)
        # Pooling mask: (batch_size, seq_len, 1), 1 for valid tokens, 0 for padding
        mask = (~src_key_padding_mask).float().unsqueeze(2)
        # Sum the valid outputs
        sum_output = torch.sum(output * mask, dim=1)
        # Count valid tokens, clamping to avoid division by zero for all-padding rows
        count_valid = torch.clamp(torch.sum(mask, dim=1), min=1.0)
        pooled_output = sum_output / count_valid
        # Classification head
        logits = self.fc_out(pooled_output)
        return logits
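
# A minimal usage sketch (the vocabulary sizes and sequence shapes below are
# invented for illustration; index 0 is reserved for padding, as assumed by
# the model's padding_idx and mask construction):
if __name__ == "__main__":
    model = LoveLiveTransformer(num_songs=100, num_artists=20,
                                num_feedback_types=5, num_lives=10)
    seq_len, batch_size = 8, 4
    # Random token ids in [1, vocab); the tail of each sequence is zeroed as padding
    song_seq = torch.randint(1, 100, (seq_len, batch_size))
    artist_seq = torch.randint(1, 20, (seq_len, batch_size))
    feedback_seq = torch.randint(1, 5, (seq_len, batch_size))
    song_seq[5:] = 0      # padding positions (the mask is derived from song_seq)
    artist_seq[5:] = 0
    feedback_seq[5:] = 0
    logits = model(song_seq, artist_seq, feedback_seq)
    print(logits.shape)   # torch.Size([4, 10]) -> (batch_size, num_lives)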