File size: 7,941 Bytes
d171350 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 | """Chart prediction model architecture.
FiLM-conditioned masked transformer for Guitar Hero chart generation.
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Utility layers
# ---------------------------------------------------------------------------
def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
    """Gated activation over interleaved channel pairs.

    Even-indexed channels of the last axis act as the gate and odd-indexed
    channels as the linear branch, so the output's last axis is half as wide
    (for an even-width input).

    The gate is clamped from above only, at ``limit``. The linear branch is
    clamped to ``[-limit, limit]``. The result is
    ``gate * sigmoid(alpha * gate) * (linear + 1)``.
    """
    gate = x[..., 0::2].clamp(max=limit)
    linear = x[..., 1::2].clamp(min=-limit, max=limit)
    activated = gate * torch.sigmoid(gate * alpha)
    return activated * (linear + 1)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The statistics and scaling are computed in float32. The result is cast
    back to the dtype of the input.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialised to 1 (identity scaling).
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.eps)
        return (normed * self.scale).to(x.dtype)
class FeedForward(nn.Module):
    """Position-wise feed-forward network with a gated (``swiglu``) activation.

    ``linear1`` expands the last axis to ``d_ff`` channels. ``swiglu`` then
    combines interleaved even/odd channel pairs, halving the width to
    ``d_ff // 2``. Finally, ``linear_out`` projects back to ``d_model``.

    Raises:
        ValueError: if ``d_ff`` is odd. An odd width would give ``swiglu``
            unequal gate/linear halves and fail only at forward time.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        if d_ff % 2 != 0:
            raise ValueError(
                f"d_ff must be even for the gated activation, got {d_ff}"
            )
        self.linear1 = nn.Linear(d_model, d_ff, bias=False)
        # swiglu halves the channel count, hence d_ff // 2 inputs here.
        self.linear_out = nn.Linear(d_ff // 2, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_out(self.dropout(swiglu(self.linear1(x))))
# ---------------------------------------------------------------------------
# Rotary position embeddings
# ---------------------------------------------------------------------------
def apply_rotary_emb(
    x: torch.Tensor, dim: int, base: float = 10000.0,
) -> torch.Tensor:
    """Apply rotary position embeddings (RoPE) to a ``[B, heads, T, head_dim]`` tensor.

    Positions ``0 .. T-1`` are taken along axis 2. The first ``dim`` channels
    of the last axis are rotated in "half-split" layout: channel ``i`` is
    paired with channel ``i + dim // 2``. Any channels beyond ``dim`` pass
    through unchanged.

    Args:
        x: Input tensor of shape ``[B, heads, T, head_dim]``.
        dim: Number of leading channels to rotate. Must be even and satisfy
            ``0 < dim <= head_dim``.
        base: Frequency base of the rotary schedule.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``dim`` is odd, non-positive, or larger than ``head_dim``.
    """
    head_dim = x.size(-1)
    if dim % 2 != 0 or not 0 < dim <= head_dim:
        raise ValueError(
            f"rotary dim must be even and in (0, {head_dim}], got {dim}"
        )
    seq_len = x.size(2)
    half = dim // 2
    # Build the angle table in at least float32. In bf16/fp16, integer
    # positions above 256/2048 are not exactly representable, which would
    # silently corrupt the rotation angles for long sequences.
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    inv_freq = base ** (
        -torch.arange(0, dim, 2, device=x.device, dtype=compute_dtype) / dim
    )
    positions = torch.arange(seq_len, device=x.device, dtype=compute_dtype)
    angles = torch.outer(positions, inv_freq)  # [T, dim // 2]
    sin = angles.sin().to(x.dtype)[None, None]  # [1, 1, T, dim // 2]
    cos = angles.cos().to(x.dtype)[None, None]
    x1 = x[..., :half]
    x2 = x[..., half:dim]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    if dim == head_dim:
        return rotated
    # Partial rotary: keep the un-rotated tail instead of dropping it.
    return torch.cat([rotated, x[..., dim:]], dim=-1)
# ---------------------------------------------------------------------------
# Bidirectional multi-head self-attention
# ---------------------------------------------------------------------------
class BidirectionalAttention(nn.Module):
    """Multi-head self-attention with RoPE on queries/keys and no causal mask.

    Every query position may attend to every key position that is not
    excluded by ``attn_mask``.

    Raises:
        ValueError: if ``n_heads`` is not positive, if ``d_model`` is not
            divisible by ``n_heads``, or if the resulting per-head dimension
            is odd (RoPE rotates channels in pairs).
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1,
                 rope_base: float = 10000.0):
        super().__init__()
        if n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {n_heads}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        if (d_model // n_heads) % 2 != 0:
            raise ValueError(
                f"per-head dim d_model // n_heads = {d_model // n_heads} must "
                "be even for rotary embeddings"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.rope_base = rope_base
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # Only its probability is used (as SDPA's dropout_p); it is not applied directly.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Attend over the sequence.

        Args:
            x: ``[B, T, d_model]`` input.
            attn_mask: Optional ``[B, T]`` mask over key positions. It is
                converted to bool and passed to
                ``F.scaled_dot_product_attention``, where ``True`` means "may
                be attended to". Padding positions must therefore be
                ``False`` (or zero).

        Returns:
            ``[B, T, d_model]`` output.
        """
        B, T, _ = x.shape
        Q = self.w_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.w_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.w_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        Q = apply_rotary_emb(Q, dim=self.d_k, base=self.rope_base)
        K = apply_rotary_emb(K, dim=self.d_k, base=self.rope_base)
        sdpa_mask = None
        if attn_mask is not None:
            # Broadcast the key mask over heads and query positions: [B, 1, 1, T].
            sdpa_mask = attn_mask[:, None, None, :].bool()
        out = F.scaled_dot_product_attention(
            Q, K, V, attn_mask=sdpa_mask,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=False,
        )
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.out_proj(out)
# ---------------------------------------------------------------------------
# FiLM-conditioned encoder block
# ---------------------------------------------------------------------------
class FiLMEncoderBlock(nn.Module):
    """Pre-norm encoder block whose feed-forward output is FiLM-modulated.

    The feed-forward output ``h`` is scaled and shifted by a per-sample
    ``(gamma, beta)`` pair projected from the difficulty embedding:

        h = (1 + gamma) * h + beta

    ``film_proj`` starts at all zeros, so at initialisation gamma = beta = 0
    and the block behaves as an un-modulated residual block.
    """

    def __init__(self, d_model: int, d_ff: int, n_heads: int,
                 dropout: float = 0.1, rope_base: float = 10000.0):
        super().__init__()
        # NOTE: keep submodule creation order stable; it fixes which random
        # draws each layer's initialisation consumes.
        self.norm1 = RMSNorm(d_model)
        self.attn = BidirectionalAttention(d_model, n_heads, dropout, rope_base)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
        # Produces gamma and beta (each d_model wide) from the difficulty embedding.
        self.film_proj = nn.Linear(d_model, d_model * 2)
        nn.init.zeros_(self.film_proj.weight)
        nn.init.zeros_(self.film_proj.bias)

    def forward(self, x: torch.Tensor, diff_emb: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Attention sub-layer with residual connection.
        attended = self.attn(self.norm1(x), attn_mask)
        residual = x + self.dropout(attended)

        # Feed-forward sub-layer, modulated before the residual add.
        ff_out = self.ff(self.norm2(residual))
        # diff_emb is presumably [B, d_model]; unsqueeze(1) broadcasts the
        # modulation over the sequence axis.
        gamma, beta = self.film_proj(diff_emb).unsqueeze(1).chunk(2, dim=-1)
        modulated = ff_out * (1 + gamma) + beta
        return residual + self.dropout(modulated)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Ids 0-31 are fret-combination tokens. The two special tokens follow them,
# with MASK as the highest id.
SILENCE_TOKEN = 32
MASK_TOKEN = SILENCE_TOKEN + 1  # 33
VOCAB_SIZE = MASK_TOKEN + 1  # 34: 32 fret combos + silence + MASK
# Number of output classes of ChartMaskPredictor.duration_head.
NUM_SUSTAIN_BUCKETS = 6
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
class ChartMaskPredictor(nn.Module):
    """Masked-token chart prediction model (v3).

    Projected audio features and chart-token embeddings are summed per
    position. The sum runs through a stack of difficulty-conditioned
    (FiLM) bidirectional encoder blocks and is decoded by three per-position
    linear heads.

    Token vocabulary: 0-31 fret combos, 32 silence, 33 MASK. The token head
    scores only the 33 non-MASK classes.
    """

    def __init__(self, config: "ChartMaskPredictorConfig"):
        super().__init__()
        self.config = config
        width = config.d_model
        # NOTE: submodule creation order is kept fixed; it determines the RNG
        # draws used for initialisation and the state_dict key order.
        self.audio_projection = nn.Linear(config.audio_dim, width, bias=False)
        self.chart_embedding = nn.Embedding(VOCAB_SIZE, width)
        self.input_dropout = nn.Dropout(config.dropout)
        # 4 rows, presumably one per difficulty level — confirm with data pipeline.
        self.difficulty_embedding = nn.Embedding(4, width)
        blocks = [
            FiLMEncoderBlock(
                d_model=width,
                d_ff=config.d_ff,
                n_heads=config.n_heads,
                dropout=config.dropout,
                rope_base=config.rope_base,
            )
            for _ in range(config.n_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.final_norm = RMSNorm(width)
        self.token_head = nn.Linear(width, VOCAB_SIZE - 1)  # every id except MASK
        self.sustain_head = nn.Linear(width, 1)
        self.duration_head = nn.Linear(width, NUM_SUSTAIN_BUCKETS)

    def forward(self, audio_features: torch.Tensor, chart_tokens: torch.Tensor,
                difficulty: torch.Tensor,
                padding_mask: Optional[torch.Tensor] = None) -> dict[str, torch.Tensor]:
        """Encode the inputs and return per-position logits.

        Args:
            audio_features: last dim ``config.audio_dim``; assumed
                ``[B, T, audio_dim]`` — TODO confirm.
            chart_tokens: integer ids in ``[0, VOCAB_SIZE)``; presumably
                ``[B, T]``, aligned with the audio frames.
            difficulty: integer ids in ``[0, 4)``; presumably ``[B]``.
            padding_mask: optional mask forwarded unchanged to every block's
                attention.

        Returns:
            Dict with ``"token_logits"`` (last dim ``VOCAB_SIZE - 1``),
            ``"sustain_logits"`` (last dim 1) and ``"duration_logits"``
            (last dim ``NUM_SUSTAIN_BUCKETS``).
        """
        fused = self.audio_projection(audio_features) + self.chart_embedding(chart_tokens)
        hidden = self.input_dropout(fused)
        cond = self.difficulty_embedding(difficulty)
        for block in self.layers:
            hidden = block(hidden, cond, attn_mask=padding_mask)
        hidden = self.final_norm(hidden)
        heads = (
            ("token_logits", self.token_head),
            ("sustain_logits", self.sustain_head),
            ("duration_logits", self.duration_head),
        )
        return {key: head(hidden) for key, head in heads}
@dataclass
class ChartMaskPredictorConfig:
    """Hyperparameters for ``ChartMaskPredictor``.

    ``__post_init__`` rejects configurations the model cannot run, so the
    failure happens at config creation instead of deep inside a forward pass.

    Raises:
        ValueError: if ``n_heads`` is not positive, if ``d_model`` is not
            divisible by ``n_heads``, if the per-head dim is odd (RoPE needs
            channel pairs), if ``d_ff`` is odd (the gated activation halves
            it), or if ``dropout`` is outside ``[0, 1]``.
    """

    audio_dim: int = 771  # width of each input audio feature vector
    d_model: int = 512  # hidden width of the encoder
    n_heads: int = 8  # attention heads per block
    n_layers: int = 6  # number of encoder blocks
    d_ff: int = 2048  # feed-forward expansion width (before gating halves it)
    dropout: float = 0.15
    rope_base: float = 10000.0  # frequency base for rotary embeddings

    def __post_init__(self) -> None:
        if self.n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {self.n_heads}")
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        head_dim = self.d_model // self.n_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"per-head dim d_model // n_heads = {head_dim} must be even "
                "for rotary embeddings"
            )
        if self.d_ff % 2 != 0:
            raise ValueError(f"d_ff must be even, got {self.d_ff}")
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout}")
|