"""
Tab & Chord Generation Module for TouchGrass.
Generates guitar tabs, chord diagrams, and validates musical correctness.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
class TabChordModule(nn.Module):
    """
    Generates and validates guitar tabs and chord diagrams.

    Features:
    - Generates ASCII tablature for guitar, bass, ukulele
    - Creates chord diagrams in compact text form
    - Validates tabs (string counts, fret ranges)
    - Difficulty-aware: suggests easier voicings for beginners
    - Supports multiple tunings
    """

    # Tunings, in the reverse of the order tab lines are written: the first
    # entry is the string drawn on the bottom line of a tab.
    STANDARD_TUNING = ["E2", "A2", "D3", "G3", "B3", "E4"]  # Guitar
    BASS_TUNING = ["E1", "A1", "D2", "G2"]
    UKULELE_TUNING = ["G4", "C4", "E4", "A4"]
    DROP_D_TUNING = ["D2", "A2", "D3", "G3", "B3", "E4"]
    OPEN_G_TUNING = ["D2", "G2", "D3", "G3", "B3", "D4"]

    # Fretboard limits
    MAX_FRET = 24
    OPEN_FRET = 0
    MUTED_FRET = -1

    # Lookup tables, built once instead of on every call.
    _INSTRUMENT_INDEX = {
        "guitar": 0,
        "bass": 1,
        "ukulele": 2,
        "piano": 3,
        "drums": 4,
        "vocals": 5,
        "theory": 6,
        "dj": 7,
    }
    _STRING_COUNTS = {
        "guitar": 6,
        "bass": 4,
        "ukulele": 4,
    }
    _DEFAULT_TUNINGS = {
        "guitar": STANDARD_TUNING,
        "bass": BASS_TUNING,
        "ukulele": UKULELE_TUNING,
    }

    def __init__(self, d_model: int, num_strings: int = 6, num_frets: int = 24):
        """
        Initialize TabChordModule.

        Args:
            d_model: Hidden dimension from base model
            num_strings: Number of strings (6 for guitar, 4 for bass)
            num_frets: Number of frets (typically 24)
        """
        super().__init__()
        self.d_model = d_model
        self.num_strings = num_strings
        self.num_frets = num_frets

        # NOTE: layer shapes and creation order are part of the checkpoint /
        # seeding contract and are intentionally left untouched.

        # Embeddings
        self.string_embed = nn.Embedding(num_strings, 64)
        self.fret_embed = nn.Embedding(num_frets + 2, 64)  # +2 for open/muted

        # Tab validator head (sigmoid score per batch element)
        self.tab_validator = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Difficulty classifier (beginner/intermediate/advanced)
        self.difficulty_head = nn.Linear(d_model, 3)

        # Instrument type embedder
        self.instrument_embed = nn.Embedding(8, 64)  # guitar/bass/ukulele/piano/etc

        # Fret position predictor for tab generation.  Its input is the GRU
        # output (d_model) plus two 64-wide embeddings, hence d_model + 128.
        self.fret_predictor = nn.Linear(d_model + 128, num_frets + 2)

        # Tab sequence generator (for multi-token tab output).  Its input is
        # a d_model-wide state plus one 64-wide embedding.
        self.tab_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        # Chord quality classifier (major, minor, dim, aug, etc.)
        self.chord_quality_head = nn.Linear(d_model, 8)

        # Root note predictor (12 chromatic notes)
        self.root_note_head = nn.Linear(d_model, 12)

    def forward(
        self,
        hidden_states: torch.Tensor,
        instrument: str = "guitar",
        skill_level: str = "intermediate",
        generate_tab: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through TabChordModule.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            instrument: Instrument type ("guitar", "bass", "ukulele")
            skill_level: "beginner", "intermediate", or "advanced".
                Accepted for interface compatibility; not used by the heads.
            generate_tab: Whether to generate tab sequences

        Returns:
            Dictionary with "tab_validity" [batch, 1], "difficulty_logits"
            [batch, 3], "chord_quality_logits" [batch, 8], "root_note_logits"
            [batch, 12] and, when ``generate_tab`` is set, "tab_sequence"
            [batch, 100] of fret-token ids.

        Raises:
            ValueError: If ``hidden_states`` is not a 3-D tensor.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must be [batch, seq_len, d_model], got shape "
                f"{tuple(hidden_states.shape)}"
            )

        # Mean-pool over the sequence axis: [batch, d_model]
        pooled = hidden_states.mean(dim=1)

        outputs = {
            "tab_validity": self.tab_validator(pooled),
            "difficulty_logits": self.difficulty_head(pooled),
            "chord_quality_logits": self.chord_quality_head(pooled),
            "root_note_logits": self.root_note_head(pooled),
        }

        if generate_tab:
            outputs["tab_sequence"] = self._generate_tab_sequence(
                hidden_states, instrument
            )

        return outputs

    def _generate_tab_sequence(
        self,
        hidden_states: torch.Tensor,
        instrument: str,
        max_length: int = 100,
    ) -> torch.Tensor:
        """
        Greedily decode a sequence of fret-token ids with the GRU decoder.

        The tensors are wired so that every layer receives the width it was
        declared with:

        - GRU input  = [previous d_model-wide state, instrument embedding]
        - predictor  = [GRU output, instrument embedding, previous fret embedding]

        The first GRU input is the first position of ``hidden_states``; later
        steps feed the previous GRU output.  The previous-fret embedding is
        all zeros on the first step.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            instrument: Instrument type
            max_length: Number of tokens to generate

        Returns:
            Long tensor [batch, max_length] of ids in ``range(num_frets + 2)``.

        Raises:
            ValueError: If ``hidden_states`` has an empty sequence axis.
        """
        batch_size, seq_len, _ = hidden_states.shape
        if seq_len == 0:
            raise ValueError("hidden_states must contain at least one position")
        if max_length <= 0:
            # torch.stack() rejects an empty list, so build the result directly.
            return torch.empty(
                batch_size, 0, dtype=torch.long, device=hidden_states.device
            )

        # One instrument embedding, broadcast over the batch: [batch, 1, 64]
        instrument_idx = torch.tensor(
            [self._instrument_to_idx(instrument)], device=hidden_states.device
        )
        instrument_emb = (
            self.instrument_embed(instrument_idx)
            .unsqueeze(0)
            .expand(batch_size, 1, -1)
        )

        # Initial GRU hidden state from the mean-pooled input: [1, batch, d_model]
        gru_state = hidden_states.mean(dim=1).unsqueeze(0).contiguous()

        step_input = hidden_states[:, 0:1, :]  # [batch, 1, d_model]
        prev_fret_emb = self.fret_embed.weight.new_zeros(
            batch_size, 1, self.fret_embed.embedding_dim
        )

        generated = []
        for _ in range(max_length):
            gru_input = torch.cat([step_input, instrument_emb], dim=2)
            output, gru_state = self.tab_generator(gru_input, gru_state)

            predictor_input = torch.cat(
                [output, instrument_emb, prev_fret_emb], dim=2
            )
            fret_logits = self.fret_predictor(predictor_input)  # [batch, 1, num_frets+2]
            next_token = fret_logits.argmax(dim=-1)  # [batch, 1]
            generated.append(next_token.squeeze(1))

            # Feed the prediction back in for the next step.
            prev_fret_emb = self.fret_embed(next_token)  # [batch, 1, 64]
            step_input = output

        return torch.stack(generated, dim=1)  # [batch, max_length]

    def _instrument_to_idx(self, instrument: str) -> int:
        """Convert instrument name to embedding index (unknown names map to 0)."""
        return self._INSTRUMENT_INDEX.get(instrument, 0)

    def validate_tab(
        self,
        tab_strings: List[str],
        instrument: str = "guitar",
    ) -> Tuple[bool, List[str]]:
        """
        Validate an ASCII tab.

        Args:
            tab_strings: One text row per string, e.g. "e|--3--|"
            instrument: Instrument type (sets the expected row count)

        Returns:
            (is_valid, error_messages)
        """
        errors = []

        # Check number of strings
        expected_strings = self._get_expected_strings(instrument)
        if len(tab_strings) != expected_strings:
            errors.append(f"Expected {expected_strings} strings, got {len(tab_strings)}")

        # Validate each string
        for i, string_row in enumerate(tab_strings):
            if not self._validate_tab_row(string_row, i, instrument):
                errors.append(f"Invalid format on string {i}: {string_row}")

        # Check for musical consistency
        if not self._check_musical_consistency(tab_strings):
            errors.append("Tab has musical inconsistencies (impossible fingering)")

        return len(errors) == 0, errors

    def _get_expected_strings(self, instrument: str) -> int:
        """Get expected number of strings for instrument (defaults to 6)."""
        return self._STRING_COUNTS.get(instrument, 6)

    def _row_frets(self, row: str) -> Optional[List[int]]:
        """
        Extract the fret numbers written on one tab row.

        Only the segments between the label and the final pipe are read.
        Within a segment, dashes separate notes; each note must be ``x``
        (muted, returned as ``MUTED_FRET``) or an integer in
        ``0..MAX_FRET``.

        Returns:
            The frets in reading order, or ``None`` if any note is invalid.
        """
        frets = []
        for segment in row.split("|")[1:-1]:
            for token in segment.split("-"):
                token = token.strip()
                if not token:
                    continue
                if token.lower() == "x":
                    frets.append(self.MUTED_FRET)
                    continue
                if not (token.isascii() and token.isdigit()):
                    return None
                fret = int(token)
                if fret > self.MAX_FRET:
                    return None
                frets.append(fret)
        return frets

    def _validate_tab_row(self, row: str, string_idx: int, instrument: str) -> bool:
        """
        Validate a single tab row.

        A row is valid when it is a string containing at least one pipe and
        every note on it is ``x`` or a fret within ``0..MAX_FRET``.
        ``string_idx`` and ``instrument`` are accepted for interface
        compatibility and are not consulted.
        """
        if not isinstance(row, str) or "|" not in row:
            return False
        return self._row_frets(row) is not None

    def _check_musical_consistency(self, tab_strings: List[str]) -> bool:
        """
        Check that every note on every row is a playable fret.

        This is a simplified check: it only confirms that all notes are
        ``x`` or within ``0..MAX_FRET``.  Hand stretches are not analysed.
        """
        return all(
            isinstance(row, str) and self._row_frets(row) is not None
            for row in tab_strings
        )

    def _string_labels(
        self, instrument: str, tuning: Optional[List[str]]
    ) -> List[str]:
        """
        Build tab row labels, top line first, from a tuning.

        Octave digits are dropped and the order is reversed (tunings list
        the bottom tab line first).  If the top string's name also appears
        on another string it is lower-cased, giving the usual "e ... E"
        labelling for standard guitar tuning.
        """
        if tuning is None:
            tuning = self._DEFAULT_TUNINGS.get(instrument, self.STANDARD_TUNING)
        labels = [str(note).rstrip("0123456789") for note in list(tuning)[::-1]]
        if labels and labels[0] in labels[1:]:
            labels[0] = labels[0].lower()
        return labels

    def format_tab(
        self,
        frets: List[List[int]],
        instrument: str = "guitar",
        tuning: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Format fret positions into ASCII tab.

        Args:
            frets: One list of fret numbers per string, top tab line first
                (0=open, -1=muted).  Rows beyond the number of strings in
                the tuning are ignored.
            instrument: Instrument type; selects the default tuning
            tuning: Optional custom tuning in the same form as the class
                tuning constants (e.g. ``DROP_D_TUNING``)

        Returns:
            List of formatted tab strings such as "e|3-0-|".  Notes in the
            same column are padded with dashes to a common width so that
            two-digit frets do not shift later notes out of alignment.
        """
        labels = self._string_labels(instrument, tuning)

        # Render every note first so column widths can be computed.
        token_rows = [
            ["x" if fret == self.MUTED_FRET else str(fret) for fret in fret_row]
            for _, fret_row in zip(labels, frets)
        ]
        num_columns = max((len(row) for row in token_rows), default=0)
        widths = [
            max(len(row[col]) for row in token_rows if col < len(row))
            for col in range(num_columns)
        ]

        tab_strings = []
        for label, row in zip(labels, token_rows):
            cells = "".join(
                token.ljust(widths[col], "-") + "-" for col, token in enumerate(row)
            )
            tab_strings.append(f"{label}|{cells}|")
        return tab_strings

    def format_chord(
        self,
        frets: List[int],
        instrument: str = "guitar",
    ) -> str:
        """
        Format chord as compact text.

        Args:
            frets: List of fret numbers for each string (low to high);
                negative values mean muted
            instrument: Instrument type (not used by the formatter)

        Returns:
            Chord string, e.g. "320003" for G major.  When any fret has more
            than one digit the frets are dash-separated ("x-10-12-12-11-10")
            so the string stays unambiguous.
        """
        tokens = [str(fret) if fret >= 0 else "x" for fret in frets]
        separator = "-" if any(len(token) > 1 for token in tokens) else ""
        return separator.join(tokens)

    def parse_chord(self, chord_str: str) -> List[int]:
        """
        Parse chord string to fret positions.

        Args:
            chord_str: Compact form like "320003" / "x32010", or a form
                separated by dashes, spaces or commas like "x-10-12-12-11-10"

        Returns:
            List of fret positions (-1 for muted strings)

        Raises:
            ValueError: If a token is neither ``x`` nor an integer.
        """
        text = chord_str.strip()
        if any(separator in text for separator in ("-", " ", ",")):
            tokens = text.replace(",", " ").replace("-", " ").split()
        else:
            tokens = list(text)
        return [
            self.MUTED_FRET if token.lower() == "x" else int(token)
            for token in tokens
        ]

    def suggest_easier_voicing(
        self,
        chord_frets: List[int],
        skill_level: str = "beginner",
    ) -> List[int]:
        """
        Suggest easier chord voicing for beginners.

        A fret held on three or more strings is treated as a barre; on
        even-indexed strings that barre is replaced by an open string.
        NOTE(review): opening a string changes the sounding note, so the
        result may no longer be the same chord -- confirm this heuristic is
        what callers want.

        Args:
            chord_frets: Original chord frets
            skill_level: Target skill level; anything other than "beginner"
                leaves the voicing unchanged

        Returns:
            A new list of frets (the input list is never modified or aliased)
        """
        simplified = list(chord_frets)
        if skill_level != "beginner":
            return simplified

        # Count how many strings share each fretted position.
        fret_counts: Dict[int, int] = {}
        for fret in chord_frets:
            if fret > 0:
                fret_counts[fret] = fret_counts.get(fret, 0) + 1
        barre_frets = {fret for fret, count in fret_counts.items() if count >= 3}

        # Open every other string of a detected barre.
        for i in range(0, len(simplified), 2):
            if simplified[i] in barre_frets:
                simplified[i] = self.OPEN_FRET

        return simplified
def test_tab_chord_module():
    """Smoke-test TabChordModule: forward pass, tab/chord formatting, validation.

    Prints the output tensor shapes, a formatted G-chord tab, a compact
    chord string and the validation result for the formatted tab.
    """
    import torch

    # A small hidden size is enough to exercise every code path; with
    # d_model=4096 the GRU decoder alone holds about 100 million weights.
    batch_size = 2
    seq_len = 10
    d_model = 64

    # Create module
    module = TabChordModule(d_model=d_model, num_strings=6, num_frets=24)
    module.eval()

    # Test input
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Forward pass (call the module rather than .forward() so hooks run)
    with torch.no_grad():
        outputs = module(
            hidden_states,
            instrument="guitar",
            skill_level="beginner",
            generate_tab=True,
        )

    print("Outputs:")
    for key, value in outputs.items():
        if isinstance(value, torch.Tensor):
            print(f"  {key}: {value.shape}")
        else:
            print(f"  {key}: {value}")

    # Test tab formatting: format_tab takes one row per string, top tab line
    # first, so an open G chord (320003 low to high) is six one-note rows.
    frets = [[3], [0], [0], [0], [2], [3]]  # G chord
    tab = module.format_tab(frets, instrument="guitar")
    print("\nFormatted tab:")
    for line in tab:
        print(f"  {line}")

    # Test chord formatting
    chord = module.format_chord([3, 2, 0, 0, 3, 3])
    print(f"\nChord: {chord}")

    # Test validation
    is_valid, errors = module.validate_tab(tab, instrument="guitar")
    print(f"\nTab valid: {is_valid}")
    if errors:
        print(f"Errors: {errors}")

    print("\nTabChordModule test complete!")
# Run the smoke test when this file is executed as a script.
if __name__ == "__main__":
    test_tab_chord_module()