# TouchGrass-7b / models/tab_chord_module.py
# (Hugging Face page header residue: "Zandy-Wandy's picture", "Upload 39 files", commit 4f0238f verified)
"""
Tab & Chord Generation Module for TouchGrass.
Generates guitar tabs, chord diagrams, and validates musical correctness.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Dict
class TabChordModule(nn.Module):
    """
    Generates and validates guitar tabs and chord diagrams.

    Features:
    - Generates ASCII tablature for guitar, bass, ukulele
    - Creates chord diagrams in standard format
    - Validates musical correctness (fret ranges, string counts)
    - Difficulty-aware: suggests easier voicings for beginners
    - Supports multiple tunings
    """

    # Tunings are listed in string order starting from the bottom line of a
    # tab (low E first for guitar); format_tab reverses them so the top tab
    # line comes first. The ukulele tuning is re-entrant (G4 sits above C4).
    STANDARD_TUNING = ["E2", "A2", "D3", "G3", "B3", "E4"]  # Guitar
    BASS_TUNING = ["E1", "A1", "D2", "G2"]
    UKULELE_TUNING = ["G4", "C4", "E4", "A4"]
    DROP_D_TUNING = ["D2", "A2", "D3", "G3", "B3", "E4"]
    OPEN_G_TUNING = ["D2", "G2", "D3", "G3", "B3", "D4"]

    # Tuning used by format_tab when no explicit tuning is passed; any other
    # instrument falls back to STANDARD_TUNING.
    _DEFAULT_TUNINGS = {
        "guitar": STANDARD_TUNING,
        "bass": BASS_TUNING,
        "ukulele": UKULELE_TUNING,
    }

    # Fretboard limits
    MAX_FRET = 24
    OPEN_FRET = 0
    MUTED_FRET = -1

    # Row index into ``instrument_embed`` (8 entries); unknown names map to 0.
    _INSTRUMENT_INDEX = {
        "guitar": 0,
        "bass": 1,
        "ukulele": 2,
        "piano": 3,
        "drums": 4,
        "vocals": 5,
        "theory": 6,
        "dj": 7,
    }

    # String count expected by validate_tab; other instruments default to 6.
    _EXPECTED_STRINGS = {
        "guitar": 6,
        "bass": 4,
        "ukulele": 4,
    }

    def __init__(self, d_model: int, num_strings: int = 6, num_frets: int = 24):
        """
        Initialize TabChordModule.

        Args:
            d_model: Hidden dimension from base model
            num_strings: Number of strings (6 for guitar, 4 for bass)
            num_frets: Number of frets (typically 24)
        """
        super().__init__()
        self.d_model = d_model
        self.num_strings = num_strings
        self.num_frets = num_frets

        # 64-dim embeddings used by the tab decoder.
        self.string_embed = nn.Embedding(num_strings, 64)
        self.fret_embed = nn.Embedding(num_frets + 2, 64)  # +2 for open/muted

        # Tab validator head: pooled hidden state -> score in (0, 1).
        self.tab_validator = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Difficulty classifier (beginner/intermediate/advanced)
        self.difficulty_head = nn.Linear(d_model, 3)

        # Instrument type embedder (indices from _instrument_to_idx)
        self.instrument_embed = nn.Embedding(8, 64)

        # Fret predictor for tab generation. Its input is the GRU output
        # (d_model) concatenated with the instrument and string embeddings
        # (64 + 64 = 128).
        self.fret_predictor = nn.Linear(d_model + 128, num_frets + 2)

        # Tab decoder. Each step's input is a d_model context vector
        # concatenated with a 64-dim token embedding.
        self.tab_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        # Chord quality classifier (8 classes: major, minor, dim, aug, etc.)
        self.chord_quality_head = nn.Linear(d_model, 8)

        # Root note predictor (12 chromatic notes)
        self.root_note_head = nn.Linear(d_model, 12)

    def forward(
        self,
        hidden_states: torch.Tensor,
        instrument: str = "guitar",
        skill_level: str = "intermediate",
        generate_tab: bool = False,
        max_tab_length: int = 100,
    ) -> Dict[str, torch.Tensor]:
        """
        Run all prediction heads on mean-pooled hidden states.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            instrument: Instrument type ("guitar", "bass", "ukulele", ...);
                only used when ``generate_tab`` is True.
            skill_level: "beginner", "intermediate", or "advanced". Accepted
                for API compatibility; not used by the forward pass.
            generate_tab: Whether to also decode a tab token sequence.
            max_tab_length: Number of tab tokens to decode when
                ``generate_tab`` is True.

        Returns:
            Dict with "tab_validity" [batch, 1], "difficulty_logits" [batch, 3],
            "chord_quality_logits" [batch, 8], "root_note_logits" [batch, 12]
            and, when requested, "tab_sequence" [batch, max_tab_length].

        Raises:
            ValueError: If ``hidden_states`` is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must be [batch, seq_len, d_model], "
                f"got shape {tuple(hidden_states.shape)}"
            )

        pooled = hidden_states.mean(dim=1)  # [batch, d_model]

        outputs = {
            "tab_validity": self.tab_validator(pooled),  # [batch, 1]
            "difficulty_logits": self.difficulty_head(pooled),  # [batch, 3]
            "chord_quality_logits": self.chord_quality_head(pooled),  # [batch, 8]
            "root_note_logits": self.root_note_head(pooled),  # [batch, 12]
        }

        if generate_tab:
            outputs["tab_sequence"] = self._generate_tab_sequence(
                hidden_states, instrument, max_length=max_tab_length
            )

        return outputs

    def _generate_tab_sequence(
        self,
        hidden_states: torch.Tensor,
        instrument: str,
        max_length: int = 100,
    ) -> torch.Tensor:
        """
        Greedily decode a tab token sequence with the GRU decoder.

        Step ``t`` predicts a fret class for string ``t % num_strings``.
        The GRU is seeded with the mean-pooled hidden state. The first step
        consumes the first hidden state plus the instrument embedding. Each
        later step consumes the previous GRU output plus the embedding of the
        previously predicted fret. All parameter shapes match ``__init__``.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            instrument: Instrument type
            max_length: Number of tokens to generate

        Returns:
            Long tensor [batch, max_length] of fret class indices in
            ``[0, num_frets + 2)``.
        """
        batch_size = hidden_states.size(0)
        device = hidden_states.device
        if max_length <= 0:
            return torch.empty(batch_size, 0, dtype=torch.long, device=device)

        # [1, 64] -> [batch, 1, 64]
        instrument_idx = torch.tensor([self._instrument_to_idx(instrument)], device=device)
        instrument_emb = (
            self.instrument_embed(instrument_idx).unsqueeze(1).expand(batch_size, -1, -1)
        )

        # All string embeddings, computed once: [num_strings, 64]
        string_embs = self.string_embed(torch.arange(self.num_strings, device=device))

        # GRU hidden state: [num_layers=1, batch, d_model]
        hidden = hidden_states.mean(dim=1).unsqueeze(0)

        context = hidden_states[:, 0:1, :]  # [batch, 1, d_model]
        token_emb = instrument_emb  # [batch, 1, 64]

        generated = []
        for step in range(max_length):
            gru_input = torch.cat([context, token_emb], dim=2)  # [batch, 1, d_model+64]
            output, hidden = self.tab_generator(gru_input, hidden)  # output: [batch, 1, d_model]

            string_emb = string_embs[step % self.num_strings].expand(batch_size, 1, -1)
            predictor_input = torch.cat([output, instrument_emb, string_emb], dim=2)
            fret_logits = self.fret_predictor(predictor_input)  # [batch, 1, num_frets+2]
            next_token = fret_logits.argmax(dim=-1)  # [batch, 1]
            generated.append(next_token.squeeze(1))

            # Feed the decoder its own output and the chosen fret's embedding.
            context = output
            token_emb = self.fret_embed(next_token)  # [batch, 1, 64]

        return torch.stack(generated, dim=1)  # [batch, max_length]

    def _instrument_to_idx(self, instrument: str) -> int:
        """Convert an instrument name to its embedding index (unknown -> 0)."""
        return self._INSTRUMENT_INDEX.get(instrument, 0)

    def validate_tab(
        self,
        tab_strings: List[str],
        instrument: str = "guitar",
    ) -> Tuple[bool, List[str]]:
        """
        Validate ASCII tab for musical correctness.

        Args:
            tab_strings: Tab rows, one string per instrument string
                (e.g. "e|--3--5--|"; 6 rows for guitar)
            instrument: Instrument type

        Returns:
            (is_valid, error_messages)
        """
        errors = []

        # Check number of strings
        expected_strings = self._get_expected_strings(instrument)
        if len(tab_strings) != expected_strings:
            errors.append(f"Expected {expected_strings} strings, got {len(tab_strings)}")

        # Validate each row's format and fret values
        for i, string_row in enumerate(tab_strings):
            if not self._validate_tab_row(string_row, i, instrument):
                errors.append(f"Invalid format on string {i}: {string_row}")

        if not self._check_musical_consistency(tab_strings):
            errors.append("Tab has musical inconsistencies (impossible fingering)")

        return len(errors) == 0, errors

    def _get_expected_strings(self, instrument: str) -> int:
        """Get expected number of strings for instrument (default 6)."""
        return self._EXPECTED_STRINGS.get(instrument, 6)

    def _parse_tab_segment(self, segment: str) -> Optional[List[int]]:
        """
        Parse the text between two pipes of a tab row into fret values.

        Dashes and whitespace separate tokens. Each token must be "x"/"X"
        (muted, returned as MUTED_FRET) or a decimal number in
        [0, MAX_FRET]. Returns None if any token is invalid.
        """
        frets = []
        for token in segment.replace("-", " ").split():
            if token.lower() == "x":
                frets.append(self.MUTED_FRET)
            elif token.isdecimal() and int(token) <= self.MAX_FRET:
                frets.append(int(token))
            else:
                return None
        return frets

    def _validate_tab_row(self, row: str, string_idx: int, instrument: str) -> bool:
        """
        Validate a single tab row such as "e|--3--5--|".

        The row must contain a pipe. Every segment between the label and the
        final pipe must parse as fret tokens (see ``_parse_tab_segment``).
        ``string_idx`` and ``instrument`` are currently unused.
        """
        if "|" not in row:
            return False
        # Skip the string label (before the first pipe) and anything after
        # the last pipe.
        segments = row.split("|")[1:-1]
        return all(self._parse_tab_segment(seg) is not None for seg in segments)

    def _check_musical_consistency(self, tab_strings: List[str]) -> bool:
        """
        Check that every fret token in the tab is playable.

        Currently this only verifies that each token between pipes is "x" or
        a fret in [0, MAX_FRET]. Stretch/fingering analysis is not
        implemented yet.
        """
        for string_row in tab_strings:
            for segment in string_row.split("|")[1:-1]:
                if self._parse_tab_segment(segment) is None:
                    return False
        return True

    def format_tab(
        self,
        frets: List[List[int]],
        instrument: str = "guitar",
        tuning: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Format fret positions into ASCII tab.

        Args:
            frets: One list per string, ordered from the top tab line (high e
                on guitar) down. Each list holds that string's frets over
                time (0 = open, negative = muted, rendered as "x").
            instrument: Instrument type; selects the default tuning.
            tuning: Optional tuning in the same order as the class constants
                (bottom tab line first, e.g. STANDARD_TUNING). Octave digits
                are stripped to build the row labels.

        Returns:
            List of formatted tab rows, e.g. ["e|3-|", "B|3-|", ...]. Extra
            rows in ``frets`` beyond the tuning's string count are dropped.
        """
        if tuning is None:
            tuning = self._DEFAULT_TUNINGS.get(instrument, self.STANDARD_TUNING)

        # Top tab line first; "E4" -> "E", "C#3" -> "C#".
        labels = [note.rstrip("0123456789") for note in reversed(tuning)]
        # Distinguish the top string when its note name repeats lower down
        # (standard guitar: "e" on top, "E" at the bottom).
        if labels and labels[0].upper() in (label.upper() for label in labels[1:]):
            labels[0] = labels[0].lower()

        tab_strings = []
        for label, fret_row in zip(labels, frets):
            cells = "".join("x-" if fret < 0 else f"{fret}-" for fret in fret_row)
            tab_strings.append(f"{label}|{cells}|")
        return tab_strings

    def format_chord(
        self,
        frets: List[int],
        instrument: str = "guitar",
    ) -> str:
        """
        Format chord as compact diagram.

        Args:
            frets: Fret number for each string, low to high (negative = muted)
            instrument: Instrument type (unused)

        Returns:
            Chord string, e.g. "320033" for G major. If any fret is above 9,
            values are dash-separated (e.g. "x-10-12-12-11-x") so the string
            stays unambiguous and parse_chord can read it back.
        """
        symbols = ["x" if fret < 0 else str(fret) for fret in frets]
        separator = "-" if any(fret > 9 for fret in frets) else ""
        return separator.join(symbols)

    def parse_chord(self, chord_str: str) -> List[int]:
        """
        Parse chord string to fret positions.

        Args:
            chord_str: Compact form like "320003" or "x32010", or a separated
                form like "x-10-12-12-11-x" (dashes, spaces or commas).

        Returns:
            List of fret positions, with MUTED_FRET (-1) for "x".

        Raises:
            ValueError: If a token is neither "x" nor an integer.
        """
        normalized = chord_str.strip().replace(",", "-").replace(" ", "-")
        if "-" in normalized:
            tokens = [token for token in normalized.split("-") if token]
        else:
            tokens = list(normalized)
        return [self.MUTED_FRET if token.lower() == "x" else int(token) for token in tokens]

    def suggest_easier_voicing(
        self,
        chord_frets: List[int],
        skill_level: str = "beginner",
    ) -> List[int]:
        """
        Suggest easier chord voicing for beginners.

        Any fret shared by 3+ strings is treated as a barre, and every
        even-indexed string on that fret is replaced with an open string.
        NOTE(review): the heuristic does not check that the result is still
        the same chord.

        Args:
            chord_frets: Original chord frets
            skill_level: Target skill level; anything other than "beginner"
                returns ``chord_frets`` unchanged (same object).

        Returns:
            Simplified chord frets (a new list for beginners)
        """
        if skill_level != "beginner":
            return chord_frets

        simplified = list(chord_frets)

        # Count how many fretted (non-open, non-muted) strings share each fret.
        fret_counts: Dict[int, int] = {}
        for fret in chord_frets:
            if fret > 0:
                fret_counts[fret] = fret_counts.get(fret, 0) + 1

        for fret, count in fret_counts.items():
            if count >= 3:
                for i, current in enumerate(simplified):
                    if current == fret and i % 2 == 0:  # every other string
                        simplified[i] = self.OPEN_FRET
        return simplified
def test_tab_chord_module():
    """Smoke-test TabChordModule and print its outputs.

    Builds a module with d_model=4096, runs a forward pass with tab
    generation on random hidden states, formats a G-major chord as ASCII
    tab and as a compact chord string, and validates the formatted tab.
    """
    d_model = 4096
    module = TabChordModule(d_model=d_model, num_strings=6, num_frets=24)

    # Random stand-in for base-model hidden states: [batch, seq_len, d_model]
    batch_size = 2
    seq_len = 10
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    outputs = module(
        hidden_states,
        instrument="guitar",
        skill_level="beginner",
        generate_tab=True,
    )
    print("Outputs:")
    for key, value in outputs.items():
        if isinstance(value, torch.Tensor):
            print(f" {key}: {value.shape}")
        else:
            print(f" {key}: {value}")

    # format_tab takes one row per string (top tab line first), each row
    # holding that string's frets over time. A single G-major strum is
    # therefore one column: e=3, B=3, G=0, D=0, A=2, E=3.
    g_major_high_to_low = [3, 3, 0, 0, 2, 3]
    frets = [[fret] for fret in g_major_high_to_low]
    tab = module.format_tab(frets, instrument="guitar")
    print("\nFormatted tab:")
    for line in tab:
        print(f" {line}")

    # format_chord takes frets from low E to high e.
    chord = module.format_chord([3, 2, 0, 0, 3, 3])
    print(f"\nChord: {chord}")

    is_valid, errors = module.validate_tab(tab, instrument="guitar")
    print(f"\nTab valid: {is_valid}")
    if errors:
        print(f"Errors: {errors}")
    print("\nTabChordModule test complete!")
if __name__ == "__main__":
    # Executed directly as a script: run the smoke test / demo.
    test_tab_chord_module()