| | """
|
| | Tab & Chord Generation Module for TouchGrass.
|
| | Generates guitar tabs, chord diagrams, and validates musical correctness.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from typing import Optional, Tuple, List, Dict
|
| |
|
| |
|
class TabChordModule(nn.Module):
    """
    Generates and validates guitar tabs and chord diagrams.

    Features:
    - Neural heads over mean-pooled base-model hidden states (tab validity,
      difficulty, chord quality, root note) plus an optional greedy GRU
      decoder that emits fret-token ids
    - Formats ASCII tablature for guitar, bass, ukulele
    - Formats and parses compact chord strings (e.g. "x32010")
    - Validates ASCII tab rows (string count, fret range, token syntax)
    - Beginner heuristic that opens up some strings of barre voicings
    - Tuning constants for standard / drop D / open G guitar, bass, ukulele
    """

    # Tunings are listed from the lowest-pitched string to the highest.
    STANDARD_TUNING = ["E2", "A2", "D3", "G3", "B3", "E4"]
    BASS_TUNING = ["E1", "A1", "D2", "G2"]
    UKULELE_TUNING = ["G4", "C4", "E4", "A4"]
    DROP_D_TUNING = ["D2", "A2", "D3", "G3", "B3", "E4"]
    OPEN_G_TUNING = ["D2", "G2", "D3", "G3", "B3", "D4"]

    MAX_FRET = 24  # highest fret number accepted by the tab validators
    OPEN_FRET = 0
    MUTED_FRET = -1  # rendered as "x" in tabs and chord strings

    def __init__(self, d_model: int, num_strings: int = 6, num_frets: int = 24):
        """
        Initialize TabChordModule.

        Args:
            d_model: Hidden dimension from base model
            num_strings: Number of strings (6 for guitar, 4 for bass)
            num_frets: Number of frets (typically 24); the fret-token
                vocabulary has num_frets + 2 entries
        """
        super().__init__()
        self.d_model = d_model
        self.num_strings = num_strings
        self.num_frets = num_frets

        # 64-dim embeddings for a string index and for a fret token.
        self.string_embed = nn.Embedding(num_strings, 64)
        self.fret_embed = nn.Embedding(num_frets + 2, 64)

        # Scalar score in (0, 1) from the pooled hidden state.
        self.tab_validator = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # 3 difficulty logits (presumably beginner/intermediate/advanced —
        # confirm against the training labels).
        self.difficulty_head = nn.Linear(d_model, 3)

        # 8 instrument slots, one per entry of _instrument_to_idx.
        self.instrument_embed = nn.Embedding(8, 64)

        # Input = decoder output (d_model) + string embedding (64)
        #       + instrument embedding (64).
        self.fret_predictor = nn.Linear(d_model + 128, num_frets + 2)

        # Per-step input = context vector (d_model) + symbol embedding (64).
        self.tab_generator = nn.GRU(
            input_size=d_model + 64,
            hidden_size=d_model,
            num_layers=1,
            batch_first=True,
        )

        # 8 chord-quality logits; label set is defined outside this file.
        self.chord_quality_head = nn.Linear(d_model, 8)

        # 12 logits, presumably one per pitch class — confirm.
        self.root_note_head = nn.Linear(d_model, 12)

    def forward(
        self,
        hidden_states: torch.Tensor,
        instrument: str = "guitar",
        skill_level: str = "intermediate",
        generate_tab: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Run the classification heads and, optionally, the tab decoder.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            instrument: Instrument type ("guitar", "bass", "ukulele", ...);
                only used when generate_tab is True
            skill_level: Accepted for API compatibility; currently unused
            generate_tab: Whether to also decode a fret-token sequence

        Returns:
            Dictionary with "tab_validity" [batch, 1], "difficulty_logits"
            [batch, 3], "chord_quality_logits" [batch, 8], "root_note_logits"
            [batch, 12] and, when generate_tab is True, "tab_sequence"
            [batch, 100] of fret-token ids.

        Raises:
            ValueError: If hidden_states is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                f"hidden_states must be [batch, seq_len, d_model], got shape "
                f"{tuple(hidden_states.shape)}"
            )

        # Unmasked mean over the sequence: padded positions are included.
        pooled = hidden_states.mean(dim=1)

        outputs = {
            "tab_validity": self.tab_validator(pooled),
            "difficulty_logits": self.difficulty_head(pooled),
            "chord_quality_logits": self.chord_quality_head(pooled),
            "root_note_logits": self.root_note_head(pooled),
        }

        if generate_tab:
            outputs["tab_sequence"] = self._generate_tab_sequence(
                hidden_states, instrument
            )

        return outputs

    def _generate_tab_sequence(
        self,
        hidden_states: torch.Tensor,
        instrument: str,
        max_length: int = 100,
    ) -> torch.Tensor:
        """
        Greedily decode a sequence of fret-token ids with the GRU.

        The decoder starts from three seeds. The initial GRU state is the
        mean-pooled hidden state. The initial context is the first position's
        hidden state. The initial symbol is the instrument embedding.

        At step t the next token is predicted for string index
        t % num_strings from [GRU output, string embedding, instrument
        embedding]. The GRU output and the embedding of the predicted token
        become the next step's input.

        NOTE(review): the previous wiring could not run for any d_model:
        expand() on a 3-D tensor, a 64-dim fret embedding fed where
        d_model + 64 was expected, and a d_model output fed to a
        (d_model + 128)-input predictor. This wiring uses every existing layer
        at its declared size, so parameter shapes and state_dict keys are
        unchanged. Confirm it matches the intended training setup.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            instrument: Instrument type
            max_length: Number of tokens to generate

        Returns:
            Long tensor [batch, max_length] of token ids in
            [0, num_frets + 1] (meaning of the top id is not defined here).
        """
        batch_size = hidden_states.size(0)
        device = hidden_states.device

        if max_length <= 0:
            return hidden_states.new_zeros((batch_size, 0), dtype=torch.long)

        instrument_ids = torch.full(
            (batch_size,),
            self._instrument_to_idx(instrument),
            dtype=torch.long,
            device=device,
        )
        instrument_emb = self.instrument_embed(instrument_ids).unsqueeze(1)  # [B, 1, 64]

        # String embedding for every step, hoisted out of the loop: [T, 64].
        string_ids = torch.arange(max_length, device=device) % self.num_strings
        string_embs = self.string_embed(string_ids)

        # GRU state is (num_layers, batch, hidden) even with batch_first=True.
        state = hidden_states.mean(dim=1).unsqueeze(0)
        context = hidden_states[:, :1, :]  # [B, 1, d_model]
        symbol = instrument_emb  # [B, 1, 64]

        generated = []
        for step in range(max_length):
            gru_in = torch.cat([context, symbol], dim=-1)
            output, state = self.tab_generator(gru_in, state)  # [B, 1, d_model]

            string_emb = string_embs[step].expand(batch_size, 1, -1)
            fret_logits = self.fret_predictor(
                torch.cat([output, string_emb, instrument_emb], dim=-1)
            )
            next_token = fret_logits.argmax(dim=-1)  # [B, 1]
            generated.append(next_token.squeeze(1))

            context = output
            symbol = self.fret_embed(next_token)  # [B, 1, 64]

        return torch.stack(generated, dim=1)

    def _instrument_to_idx(self, instrument: str) -> int:
        """Convert instrument name to an embedding index (unknown -> guitar, 0)."""
        mapping = {
            "guitar": 0,
            "bass": 1,
            "ukulele": 2,
            "piano": 3,
            "drums": 4,
            "vocals": 5,
            "theory": 6,
            "dj": 7,
        }
        return mapping.get(instrument, 0)

    def validate_tab(
        self,
        tab_strings: List[str],
        instrument: str = "guitar",
    ) -> Tuple[bool, List[str]]:
        """
        Validate ASCII tab rows.

        Checks the row count against the instrument, the "|"-delimited row
        format, and that every fret token is "x" or an integer in
        [0, MAX_FRET].

        Args:
            tab_strings: One string per tab row, e.g. "e|3-5-x-|"
            instrument: Instrument type (selects the expected row count)

        Returns:
            (is_valid, error_messages)
        """
        errors = []

        expected_strings = self._get_expected_strings(instrument)
        if len(tab_strings) != expected_strings:
            errors.append(f"Expected {expected_strings} strings, got {len(tab_strings)}")

        for i, string_row in enumerate(tab_strings):
            if not self._validate_tab_row(string_row, i, instrument):
                errors.append(f"Invalid format on string {i}: {string_row}")

        if not self._check_musical_consistency(tab_strings):
            errors.append("Tab has musical inconsistencies (impossible fingering)")

        return not errors, errors

    def _get_expected_strings(self, instrument: str) -> int:
        """Get expected number of strings for instrument (default 6)."""
        mapping = {
            "guitar": 6,
            "bass": 4,
            "ukulele": 4,
        }
        return mapping.get(instrument, 6)

    @staticmethod
    def _fret_tokens(row: str) -> List[str]:
        """
        Split a tab row's bar-delimited segments into fret tokens.

        Only text strictly between the first and last "|" is inspected. The
        leading string label and anything after the final bar are ignored.
        Dashes and whitespace separate tokens, so "e|3-5--x-12|" yields
        ["3", "5", "x", "12"].
        """
        tokens = []
        for segment in row.split("|")[1:-1]:
            for chunk in segment.split():
                tokens.extend(tok for tok in chunk.split("-") if tok)
        return tokens

    def _is_valid_fret_token(self, token: str) -> bool:
        """Return True for a muted marker ("x"/"X") or a fret in [0, MAX_FRET]."""
        if token.lower() == "x":
            return True
        return token.isdecimal() and int(token) <= self.MAX_FRET

    def _validate_tab_row(self, row: str, string_idx: int, instrument: str) -> bool:
        """
        Validate a single tab row.

        The row must contain at least one "|". Every dash/space-separated
        token between the bars must be a valid fret token. string_idx and
        instrument are currently unused.
        """
        if "|" not in row:
            return False
        return all(self._is_valid_fret_token(tok) for tok in self._fret_tokens(row))

    def _check_musical_consistency(self, tab_strings: List[str]) -> bool:
        """
        Check that every fret token across all rows is playable.

        Currently this only verifies that each token is "x" or a fret in
        [0, MAX_FRET]. Stretches between strings are NOT checked.
        """
        return all(
            self._is_valid_fret_token(tok)
            for row in tab_strings
            for tok in self._fret_tokens(row)
        )

    def _default_tuning(self, instrument: str) -> List[str]:
        """Tuning used when none is given (standard guitar for unknown instruments)."""
        if instrument == "bass":
            return self.BASS_TUNING
        if instrument == "ukulele":
            return self.UKULELE_TUNING
        return self.STANDARD_TUNING

    @staticmethod
    def _tuning_to_labels(tuning: List[str]) -> List[str]:
        """
        Turn a low-to-high tuning into row labels ordered high string first.

        Octave digits are dropped ("E4" -> "E"). The top label is lower-cased
        when its note name repeats lower down, e.g. standard guitar gives
        ["e", "B", "G", "D", "A", "E"].
        """
        labels = [note.rstrip("0123456789") or note for note in reversed(tuning)]
        if labels and any(lbl.upper() == labels[0].upper() for lbl in labels[1:]):
            labels[0] = labels[0].lower()
        return labels

    def format_tab(
        self,
        frets: List[List[int]],
        instrument: str = "guitar",
        tuning: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Format fret positions into ASCII tab.

        Args:
            frets: One list per string, ordered high string first (same order
                as the output rows). Each list holds that string's fret per
                time step (0=open, -1=muted).
            instrument: Instrument type; picks the default tuning
            tuning: Optional tuning, listed low to high like STANDARD_TUNING;
                used to derive the row labels

        Returns:
            List of formatted tab rows, e.g. "e|3-0-x-|". Extra rows beyond
            the number of strings in the tuning are dropped.
        """
        if tuning is None:
            tuning = self._default_tuning(instrument)

        tab_strings = []
        for label, fret_row in zip(self._tuning_to_labels(tuning), frets):
            cells = "".join(
                f"{'x' if fret == self.MUTED_FRET else fret}-" for fret in fret_row
            )
            tab_strings.append(f"{label}|{cells}|")
        return tab_strings

    def format_chord(
        self,
        frets: List[int],
        instrument: str = "guitar",
    ) -> str:
        """
        Format chord as compact diagram.

        Args:
            frets: List of fret numbers for each string (low to high);
                negative values mean muted
            instrument: Instrument type (currently unused)

        Returns:
            Chord string, e.g. "320003" for G major. When any fret is 10 or
            higher the symbols are joined with "-" (e.g. "x-10-12-12-11-10"),
            so the result stays unambiguous for parse_chord.
        """
        symbols = ["x" if fret < 0 else str(fret) for fret in frets]
        separator = "-" if any(fret > 9 for fret in frets) else ""
        return separator.join(symbols)

    def parse_chord(self, chord_str: str) -> List[int]:
        """
        Parse chord string to fret positions.

        Args:
            chord_str: Compact string like "320003" / "x32010" (one character
                per string), or a separated form like "x-10-12-12-11-10"
                (tokens split on "-", "," or whitespace)

        Returns:
            List of fret positions (-1 for muted strings)

        Raises:
            ValueError: If a symbol is neither "x"/"X" nor an integer.
        """
        text = chord_str.strip()
        if any(ch in "-," or ch.isspace() for ch in text):
            tokens = text.replace(",", " ").replace("-", " ").split()
        else:
            tokens = list(text)
        return [self.MUTED_FRET if tok.lower() == "x" else int(tok) for tok in tokens]

    def suggest_easier_voicing(
        self,
        chord_frets: List[int],
        skill_level: str = "beginner",
    ) -> List[int]:
        """
        Suggest easier chord voicing for beginners.

        For every fret > 0 that appears on 3 or more strings (a barre), the
        even-indexed strings at that fret are set to open (0).

        NOTE(review): this is a fingering heuristic only. Opening a fretted
        string changes its pitch, so the result may not be the same chord.

        Args:
            chord_frets: Original chord frets
            skill_level: Target skill level; only "beginner" is simplified

        Returns:
            A new list (the input is never returned or modified).
        """
        if skill_level != "beginner":
            return list(chord_frets)

        fret_counts: Dict[int, int] = {}
        for fret in chord_frets:
            if fret > 0:
                fret_counts[fret] = fret_counts.get(fret, 0) + 1
        barre_frets = {fret for fret, count in fret_counts.items() if count >= 3}

        return [
            0 if (i % 2 == 0 and fret in barre_frets) else fret
            for i, fret in enumerate(chord_frets)
        ]
|
| |
|
| |
|
def test_tab_chord_module():
    """Smoke-test TabChordModule end to end and print the results."""
    d_model = 4096
    module = TabChordModule(d_model=d_model, num_strings=6, num_frets=24)

    batch_size = 2
    seq_len = 10
    hidden_states = torch.randn(batch_size, seq_len, d_model)

    # Call the module (not .forward) so nn.Module hooks run.
    outputs = module(
        hidden_states,
        instrument="guitar",
        skill_level="beginner",
        generate_tab=True,
    )

    print("Outputs:")
    for key, value in outputs.items():
        if isinstance(value, torch.Tensor):
            print(f" {key}: {value.shape}")
        else:
            print(f" {key}: {value}")

    # format_tab takes one row per string, ordered high string first
    # (e, B, G, D, A, E), with one fret per time step in each row.
    # This is a single G major chord: 3-2-0-0-3-3 from low E to high e.
    frets = [[3], [3], [0], [0], [2], [3]]
    tab = module.format_tab(frets, instrument="guitar")
    print("\nFormatted tab:")
    for line in tab:
        print(f" {line}")

    chord = module.format_chord([3, 2, 0, 0, 3, 3])
    print(f"\nChord: {chord}")

    is_valid, errors = module.validate_tab(tab, instrument="guitar")
    print(f"\nTab valid: {is_valid}")
    if errors:
        print(f"Errors: {errors}")

    print("\nTabChordModule test complete!")
|
| |
|
| |
|
if __name__ == "__main__":
    # Run the printing smoke test only when this file is executed as a script.
    test_tab_chord_module()