""" Tab & Chord Generation Module for TouchGrass. Generates guitar tabs, chord diagrams, and validates musical correctness. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, List, Dict class TabChordModule(nn.Module): """ Generates and validates guitar tabs and chord diagrams. Features: - Generates ASCII tablature for guitar, bass, ukulele - Creates chord diagrams in standard format - Validates musical correctness (fret ranges, string counts) - Difficulty-aware: suggests easier voicings for beginners - Supports multiple tunings """ # Standard tunings STANDARD_TUNING = ["E2", "A2", "D3", "G3", "B3", "E4"] # Guitar BASS_TUNING = ["E1", "A1", "D2", "G2"] UKULELE_TUNING = ["G4", "C4", "E4", "A4"] DROP_D_TUNING = ["D2", "A2", "D3", "G3", "B3", "E4"] OPEN_G_TUNING = ["D2", "G2", "D3", "G3", "B3", "D4"] # Fretboard limits MAX_FRET = 24 OPEN_FRET = 0 MUTED_FRET = -1 def __init__(self, d_model: int, num_strings: int = 6, num_frets: int = 24): """ Initialize TabChordModule. Args: d_model: Hidden dimension from base model num_strings: Number of strings (6 for guitar, 4 for bass) num_frets: Number of frets (typically 24) """ super().__init__() self.d_model = d_model self.num_strings = num_strings self.num_frets = num_frets # Embeddings self.string_embed = nn.Embedding(num_strings, 64) self.fret_embed = nn.Embedding(num_frets + 2, 64) # +2 for open/muted # Tab validator head self.tab_validator = nn.Sequential( nn.Linear(d_model, 128), nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid() ) # Difficulty classifier (beginner/intermediate/advanced) self.difficulty_head = nn.Linear(d_model, 3) # Instrument type embedder self.instrument_embed = nn.Embedding(8, 64) # guitar/bass/ukulele/piano/etc # Fret position predictor for tab generation self.fret_predictor = nn.Linear(d_model + 128, num_frets + 2) # Tab sequence generator (for multi-token tab output) self.tab_generator = nn.GRU( input_size=d_model + 64, # hidden + string embedding hidden_size=d_model, num_layers=1, batch_first=True, ) # Chord quality classifier (major, minor, dim, aug, etc.) self.chord_quality_head = nn.Linear(d_model, 8) # Root note predictor (12 chromatic notes) self.root_note_head = nn.Linear(d_model, 12) def forward( self, hidden_states: torch.Tensor, instrument: str = "guitar", skill_level: str = "intermediate", generate_tab: bool = False, ) -> Dict[str, torch.Tensor]: """ Forward pass through TabChordModule. Args: hidden_states: Base model hidden states [batch, seq_len, d_model] instrument: Instrument type ("guitar", "bass", "ukulele") skill_level: "beginner", "intermediate", or "advanced" generate_tab: Whether to generate tab sequences Returns: Dictionary with tab_validity, difficulty_logits, fret_predictions, etc. """ batch_size, seq_len, _ = hidden_states.shape # Pool hidden states pooled = hidden_states.mean(dim=1) # [batch, d_model] # Validate tab tab_validity = self.tab_validator(pooled) # [batch, 1] # Predict difficulty difficulty_logits = self.difficulty_head(pooled) # [batch, 3] # Predict chord quality and root note chord_quality_logits = self.chord_quality_head(pooled) # [batch, 8] root_note_logits = self.root_note_head(pooled) # [batch, 12] outputs = { "tab_validity": tab_validity, "difficulty_logits": difficulty_logits, "chord_quality_logits": chord_quality_logits, "root_note_logits": root_note_logits, } if generate_tab: # Generate tab sequence tab_seq = self._generate_tab_sequence(hidden_states, instrument) outputs["tab_sequence"] = tab_seq return outputs def _generate_tab_sequence( self, hidden_states: torch.Tensor, instrument: str, max_length: int = 100, ) -> torch.Tensor: """ Generate tab sequence using GRU decoder. Args: hidden_states: Base model hidden states instrument: Instrument type max_length: Maximum tab sequence length Returns: Generated tab token sequence """ batch_size, seq_len, d_model = hidden_states.shape # Get instrument embedding instrument_idx = self._instrument_to_idx(instrument) instrument_emb = self.instrument_embed( torch.tensor([instrument_idx], device=hidden_states.device) ).unsqueeze(0).expand(batch_size, -1) # [batch, 64] # Initialize GRU hidden state h0 = hidden_states.mean(dim=1, keepdim=True).transpose(0, 1) # [1, batch, d_model] # Generate tokens auto-regressively generated = [] input_emb = hidden_states[:, 0:1, :] # Start with first token for _ in range(max_length): # Concatenate instrument embedding input_with_instr = torch.cat([input_emb, instrument_emb.unsqueeze(1)], dim=2) # GRU step output, h0 = self.tab_generator(input_with_instr, h0) # Predict fret positions fret_logits = self.fret_predictor(output) # [batch, 1, num_frets+2] next_token = fret_logits.argmax(dim=-1) # [batch, 1] generated.append(next_token.squeeze(1)) # Next input is predicted token embedding input_emb = self.fret_embed(next_token) return torch.stack(generated, dim=1) # [batch, max_length] def _instrument_to_idx(self, instrument: str) -> int: """Convert instrument name to index.""" mapping = { "guitar": 0, "bass": 1, "ukulele": 2, "piano": 3, "drums": 4, "vocals": 5, "theory": 6, "dj": 7, } return mapping.get(instrument, 0) def validate_tab( self, tab_strings: List[List[str]], instrument: str = "guitar", ) -> Tuple[bool, List[str]]: """ Validate ASCII tab for musical correctness. Args: tab_strings: List of tab rows (6 strings for guitar) instrument: Instrument type Returns: (is_valid, error_messages) """ errors = [] # Check number of strings expected_strings = self._get_expected_strings(instrument) if len(tab_strings) != expected_strings: errors.append(f"Expected {expected_strings} strings, got {len(tab_strings)}") # Validate each string for i, string_row in enumerate(tab_strings): # Check format (e.g., "e|--3--|") if not self._validate_tab_row(string_row, i, instrument): errors.append(f"Invalid format on string {i}: {string_row}") # Check for musical consistency if not self._check_musical_consistency(tab_strings): errors.append("Tab has musical inconsistencies (impossible fingering)") return len(errors) == 0, errors def _get_expected_strings(self, instrument: str) -> int: """Get expected number of strings for instrument.""" mapping = { "guitar": 6, "bass": 4, "ukulele": 4, } return mapping.get(instrument, 6) def _validate_tab_row(self, row: str, string_idx: int, instrument: str) -> bool: """Validate a single tab row.""" # Basic format check: should have string label and pipe separators if "|" not in row: return False # Extract fret numbers parts = row.split("|") if len(parts) < 2: return False # Check fret values are in valid range for part in parts[1:-1]: # Skip string label and last pipe if part.strip(): try: fret = int(part.strip().replace("-", "")) if fret < 0 or fret > self.MAX_FRET: return False except ValueError: # Could be 'x' for muted if part.strip().lower() != "x": return False return True def _check_musical_consistency(self, tab_strings: List[List[str]]) -> bool: """ Check if tab is musically possible (basic checks). - No impossible stretches - Open strings are marked as 0 """ # Simplified check: ensure all fret numbers are within range for string_row in tab_strings: for part in string_row.split("|")[1:-1]: fret_str = part.strip().replace("-", "") if fret_str and fret_str.lower() != "x": try: fret = int(fret_str) if fret < 0 or fret > self.MAX_FRET: return False except ValueError: return False return True def format_tab( self, frets: List[List[int]], instrument: str = "guitar", tuning: List[str] = None, ) -> List[str]: """ Format fret positions into ASCII tab. Args: frets: List of [num_strings] lists with fret numbers (0=open, -1=muted) instrument: Instrument type tuning: Optional custom tuning labels Returns: List of formatted tab strings """ if tuning is None: tuning = self.STANDARD_TUNING tab_strings = [] string_labels = ["e", "B", "G", "D", "A", "E"] # High to low for i, (label, fret_row) in enumerate(zip(string_labels, frets)): # Build tab row: "e|--3--|" row = f"{label}|" for fret in fret_row: if fret == -1: row += "x-" elif fret == 0: row += "0-" else: row += f"{fret}-" row += "|" tab_strings.append(row) return tab_strings def format_chord( self, frets: List[int], instrument: str = "guitar", ) -> str: """ Format chord as compact diagram. Args: frets: List of fret numbers for each string (low to high) instrument: Instrument type Returns: Chord string (e.g., "320003" for G major) """ # Format as: 320003 (from low E to high e) return "".join(str(fret) if fret >= 0 else "x" for fret in frets) def parse_chord(self, chord_str: str) -> List[int]: """ Parse chord string to fret positions. Args: chord_str: Chord string like "320003" or "x32010" Returns: List of fret positions """ frets = [] for char in chord_str: if char.lower() == "x": frets.append(-1) else: frets.append(int(char)) return frets def suggest_easier_voicing( self, chord_frets: List[int], skill_level: str = "beginner", ) -> List[int]: """ Suggest easier chord voicing for beginners. Args: chord_frets: Original chord frets skill_level: Target skill level Returns: Simplified chord frets """ if skill_level != "beginner": return chord_frets # Simplify: reduce barre chords, avoid wide stretches simplified = chord_frets.copy() # Count barre (same fret on multiple strings) fret_counts = {} for fret in chord_frets: if fret > 0: fret_counts[fret] = fret_counts.get(fret, 0) + 1 # If barre detected (3+ strings on same fret), try to simplify for fret, count in fret_counts.items(): if count >= 3: # Replace some with open strings if possible for i, f in enumerate(simplified): if f == fret and i % 2 == 0: # Every other string simplified[i] = 0 # Open string return simplified def test_tab_chord_module(): """Test the TabChordModule.""" import torch # Create module module = TabChordModule(d_model=4096, num_strings=6, num_frets=24) # Test input batch_size = 2 seq_len = 10 d_model = 4096 hidden_states = torch.randn(batch_size, seq_len, d_model) # Forward pass outputs = module.forward( hidden_states, instrument="guitar", skill_level="beginner", generate_tab=True, ) print("Outputs:") for key, value in outputs.items(): if isinstance(value, torch.Tensor): print(f" {key}: {value.shape}") else: print(f" {key}: {value}") # Test tab formatting frets = [[3, 3, 0, 0, 2, 3]] # G chord tab = module.format_tab(frets, instrument="guitar") print("\nFormatted tab:") for line in tab: print(f" {line}") # Test chord formatting chord = module.format_chord([3, 2, 0, 0, 3, 3]) print(f"\nChord: {chord}") # Test validation is_valid, errors = module.validate_tab(tab, instrument="guitar") print(f"\nTab valid: {is_valid}") if errors: print(f"Errors: {errors}") print("\nTabChordModule test complete!") if __name__ == "__main__": test_tab_chord_module()