# Guess_the_Live / model.py
# trioskosmos's picture
# Upload 6 files
# 6d30012 verified
import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    The encoding table is computed once at construction time and stored in
    the buffer ``pe`` with shape ``(max_len, 1, d_model)``. Even feature
    columns hold ``sin(position * freq)`` and odd feature columns hold
    ``cos(position * freq)``.

    Args:
        d_model: Size of the feature (last) dimension of the inputs. Both
            even and odd values are supported.
        max_len: Number of positions to precompute. ``forward`` cannot encode
            an input whose first dimension is larger than this.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # One angle per (position, frequency): (max_len, ceil(d_model / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even column,
        # so trim the angle table to the number of odd columns; for even
        # d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Singleton middle axis so the table broadcasts over the batch
        # dimension of a (seq_len, batch, d_model) input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose first dimension indexes positions and whose last
                dimension is ``d_model``, e.g. ``(seq_len, batch, d_model)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        return x + self.pe[:x.size(0), :]
class LoveLiveTransformer(nn.Module):
    """Transformer encoder that classifies an id sequence into ``num_lives`` classes.

    Each time step is described by three integer ids (song, artist, feedback).
    Their embeddings are summed, scaled by ``sqrt(d_model)``, position-encoded,
    run through a ``nn.TransformerEncoder``, mean-pooled over the non-padding
    steps and projected to ``num_lives`` logits.

    Id ``0`` is the padding id in all three vocabularies (``padding_idx=0``),
    so it always embeds to a zero vector. Which steps count as padding is
    decided from ``song_seq`` alone.

    Args:
        num_songs: Size of the song vocabulary (including the padding id 0).
        num_artists: Size of the artist vocabulary (including the padding id 0).
        num_feedback_types: Size of the feedback vocabulary (including the
            padding id 0).
        num_lives: Number of output classes.
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_layers: Number of stacked encoder layers.
        max_len: Longest sequence the positional encoder can handle.
    """

    def __init__(self, num_songs, num_artists, num_feedback_types, num_lives,
                 d_model=64, nhead=4, num_layers=2, max_len=500):
        super().__init__()
        # Use padding_idx=0 so the embedding for id 0 is always a zero vector.
        self.song_embedding = nn.Embedding(num_songs, d_model, padding_idx=0)
        self.artist_embedding = nn.Embedding(num_artists, d_model, padding_idx=0)
        self.feedback_embedding = nn.Embedding(num_feedback_types, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=d_model * 4)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.fc_out = nn.Linear(d_model, num_lives)
        self.d_model = d_model

    def forward(self, song_seq, artist_seq, feedback_seq):
        """Compute class logits for a batch of sequences.

        Args:
            song_seq: Integer tensor of shape ``(seq_len, batch_size)``;
                entries equal to 0 mark padded steps.
            artist_seq: Integer tensor of shape ``(seq_len, batch_size)``.
            feedback_seq: Integer tensor of shape ``(seq_len, batch_size)``.

        Returns:
            Tensor of shape ``(batch_size, num_lives)``. A batch element whose
            steps are all padding pools to a zero vector, so its logits equal
            the bias of the output layer.
        """
        # Padding mask in the (batch_size, seq_len) layout the encoder
        # expects; True where the song id is 0.
        src_key_padding_mask = (song_seq == 0).transpose(0, 1)

        # Sum the three embeddings, then scale before adding positions.
        src = (
            self.song_embedding(song_seq)
            + self.artist_embedding(artist_seq)
            + self.feedback_embedding(feedback_seq)
        )
        src = self.pos_encoder(src * math.sqrt(self.d_model))

        # (seq_len, batch_size, d_model) -> (batch_size, seq_len, d_model)
        output = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
        output = output.transpose(0, 1)

        # Mean pooling over valid steps. Padded steps are overwritten with 0
        # rather than multiplied by 0: a fully padded sequence can come out of
        # the encoder as NaN, and NaN * 0 would still be NaN.
        padded = src_key_padding_mask.unsqueeze(2)
        sum_output = output.masked_fill(padded, 0.0).sum(dim=1)

        # Number of valid steps per sequence, clamped to avoid dividing by 0.
        count_valid = (~src_key_padding_mask).sum(dim=1, keepdim=True).clamp(min=1)
        pooled_output = sum_output / count_valid

        # Classification head.
        return self.fc_out(pooled_output)