# Source: TouchGrass-7b / tokenizer / music_token_extension.py
# Provenance (Hugging Face upload metadata): Zandy-Wandy, "Upload 39 files", commit 4f0238f (verified)
"""
Music Tokenizer Extension for Qwen3.5
Extends Qwen's tokenizer with music-specific tokens without replacing the base tokenizer.
"""
from transformers import AutoTokenizer
from typing import Dict, List, Optional
import json
import os
class MusicTokenizerExtension:
    """
    Extends a base tokenizer with music-specific special tokens.
    Does NOT replace the base tokenizer vocabulary — adds tokens on top.

    NOTE(review): HuggingFace ``add_special_tokens`` / ``add_tokens`` assign
    their own sequential IDs starting at ``len(tokenizer)``; the IDs declared
    in ``special_tokens`` are validated against the base vocab but are NOT
    guaranteed to be the IDs actually assigned. Use ``get_music_token_id``
    to look up the real ID of a token after extension.
    """

    def __init__(
        self,
        base_tokenizer_name: str = "Qwen/Qwen3.5-3B-Instruct",
        special_tokens: Optional[Dict[str, int]] = None,
        music_vocab_extensions: Optional[List[str]] = None,
    ):
        """
        Initialize music tokenizer extension.

        Args:
            base_tokenizer_name: HuggingFace tokenizer to extend
            special_tokens: Dict mapping token strings to IDs (must not conflict with base vocab)
            music_vocab_extensions: Additional music notation tokens to add

        Raises:
            ValueError: if any requested special-token ID falls inside the
                base vocabulary range (see ``_validate_token_ids``).
        """
        # Load base tokenizer
        print(f"Loading base tokenizer: {base_tokenizer_name}")
        self.base_tokenizer = AutoTokenizer.from_pretrained(
            base_tokenizer_name,
            trust_remote_code=True,
        )
        # Store original vocab size. Note: `vocab_size` reports only the base
        # vocabulary and does not change when tokens are added later.
        self.base_vocab_size = self.base_tokenizer.vocab_size
        print(f"Base tokenizer vocab size: {self.base_vocab_size}")
        # Define special tokens if not provided
        if special_tokens is None:
            special_tokens = self._default_special_tokens()
        self.special_tokens = special_tokens
        self.music_vocab_extensions = music_vocab_extensions or self._default_music_extensions()
        # Verify requested token IDs don't collide with the base vocabulary
        self._validate_token_ids()
        # Add special tokens to tokenizer
        self._extend_tokenizer()
        # BUG FIX: the original printed `vocab_size` here, which excludes added
        # tokens and therefore showed an unchanged number; len() includes them.
        print(f"Extended tokenizer vocab size: {len(self.base_tokenizer)}")

    def _default_special_tokens(self) -> Dict[str, int]:
        """Default music special tokens (target IDs assume a 32000-token base vocab)."""
        return {
            # Music domain tokens
            "[GUITAR]": 32000,
            "[PIANO]": 32001,
            "[DRUMS]": 32002,
            "[VOCALS]": 32003,
            "[THEORY]": 32004,
            "[DJ]": 32005,
            # Notation tokens
            "[TAB]": 32006,
            "[/TAB]": 32007,
            "[CHORD]": 32008,
            "[/CHORD]": 32009,
            "[SHEET]": 32010,
            "[/SHEET]": 32011,
            "[LYRICS]": 32012,
            "[/LYRICS]": 32013,
            "[PROGRESSION]": 32014,
            "[/PROGRESSION]": 32015,
            # Skill level tokens
            "[BEGINNER]": 32016,
            "[INTERMEDIATE]": 32017,
            "[ADVANCED]": 32018,
            # EQ (emotional-intelligence) tokens
            "[FRUSTRATED]": 32019,
            "[ENCOURAGED]": 32020,
        }

    def _default_music_extensions(self) -> List[str]:
        """Default music notation tokens to add to vocabulary."""
        return [
            # Notes
            "C#", "Db", "D#", "Eb", "F#", "Gb", "G#", "Ab", "A#", "Bb",
            # Chord types
            "maj7", "min7", "dom7", "dim7", "aug7", "sus2", "sus4", "add9",
            "maj9", "min9", "11th", "13th",
            # Guitar-specific
            "barre", "capo", "hammer-on", "pull-off", "bend", "vibrato", "tremolo",
            # Rhythm (time signatures)
            "4/4", "3/4", "6/8", "12/8", "5/4", "7/8",
            # Tempo markings
            "allegro", "andante", "adagio", "presto", "moderato", "ritardando",
            # Music theory
            "pentatonic", "diatonic", "chromatic", "arpeggio", "ostinato",
            "counterpoint", "modulation", "cadence", "interval", "tritone",
            # Scales / modes
            "dorian", "phrygian", "lydian", "mixolydian", "locrian", "aeolian",
            # Production
            "BPM", "DAW", "MIDI", "reverb", "delay", "compression", "EQ",
            "sidechain", "quantize", "automation", "synthesizer", "sequencer",
            # ABC notation support
            "|:", ":|", "||", "|]",
        ]

    def _validate_token_ids(self) -> None:
        """Ensure requested token IDs don't conflict with the base vocabulary.

        Raises:
            ValueError: for the first token whose ID is below ``base_vocab_size``.
        """
        for token, token_id in self.special_tokens.items():
            if token_id < self.base_vocab_size:
                raise ValueError(
                    f"Special token '{token}' ID {token_id} conflicts with base vocab. "
                    f"Use IDs >= {self.base_vocab_size}"
                )

    def _extend_tokenizer(self) -> None:
        """Register special tokens and extra music vocabulary on the tokenizer."""
        # Add special tokens (returns the number of *new* tokens registered)
        num_added = self.base_tokenizer.add_special_tokens({
            "additional_special_tokens": list(self.special_tokens.keys())
        })
        # Add music vocabulary extensions; add_tokens also returns how many
        # were actually new (tokens already in the vocab are skipped).
        num_vocab_added = 0
        if self.music_vocab_extensions:
            num_vocab_added = self.base_tokenizer.add_tokens(self.music_vocab_extensions)
        print(f"Added {num_added} special tokens")
        print(f"Added {num_vocab_added} music vocabulary tokens")
        # BUG FIX: `vocab_size` excludes added tokens; len() is the true total.
        print(f"Total vocabulary size: {len(self.base_tokenizer)}")

    def get_tokenizer(self):
        """Get the extended tokenizer (the mutated base tokenizer instance)."""
        return self.base_tokenizer

    def get_music_token_id(self, token: str) -> int:
        """Get the ID actually assigned to a music token by the tokenizer."""
        return self.base_tokenizer.convert_tokens_to_ids(token)

    def is_music_token(self, token_id: int) -> bool:
        """Check if a token ID decodes to one of the music special tokens."""
        token = self.base_tokenizer.convert_ids_to_tokens(token_id)
        return token in self.special_tokens

    def save_pretrained(self, save_directory: str) -> None:
        """Save extended tokenizer plus extension metadata to a directory.

        Writes the HF tokenizer files and a ``music_tokenizer_metadata.json``
        sidecar that ``from_pretrained`` reads back.
        """
        os.makedirs(save_directory, exist_ok=True)
        # Save base tokenizer (includes the added tokens)
        self.base_tokenizer.save_pretrained(save_directory)
        # Save extension metadata
        metadata = {
            "base_tokenizer": self.base_tokenizer.name_or_path,
            "base_vocab_size": self.base_vocab_size,
            "special_tokens": self.special_tokens,
            "music_vocab_extensions": self.music_vocab_extensions,
        }
        metadata_path = os.path.join(save_directory, "music_tokenizer_metadata.json")
        # Explicit UTF-8 so musical symbols survive on any platform default
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=2, ensure_ascii=False)
        print(f"Music tokenizer saved to {save_directory}")

    @classmethod
    def from_pretrained(cls, model_path: str) -> "MusicTokenizerExtension":
        """Load a music tokenizer extension from a directory written by ``save_pretrained``.

        Raises:
            FileNotFoundError: if the metadata sidecar file is missing.
        """
        metadata_path = os.path.join(model_path, "music_tokenizer_metadata.json")
        if not os.path.exists(metadata_path):
            raise FileNotFoundError(f"Music tokenizer metadata not found at {metadata_path}")
        with open(metadata_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)
        # Load base tokenizer (saved copy already contains the added tokens)
        base_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        # Bypass __init__ so we don't re-download / re-extend the tokenizer
        instance = cls.__new__(cls)
        instance.base_tokenizer = base_tokenizer
        instance.base_vocab_size = metadata["base_vocab_size"]
        instance.special_tokens = metadata["special_tokens"]
        instance.music_vocab_extensions = metadata.get("music_vocab_extensions", [])
        return instance
def extend_qwen_tokenizer(
    base_model_name: str = "Qwen/Qwen3.5-3B-Instruct",
    save_dir: Optional[str] = None,
) -> MusicTokenizerExtension:
    """
    Convenience wrapper: build a music-token extension over a Qwen tokenizer.

    Args:
        base_model_name: Qwen model name (3B or 7B)
        save_dir: Optional directory to save the extended tokenizer

    Returns:
        MusicTokenizerExtension instance
    """
    extension = MusicTokenizerExtension(base_tokenizer_name=base_model_name)
    # Persist only when a (non-empty) target directory was supplied
    if save_dir:
        extension.save_pretrained(save_dir)
    return extension
if __name__ == "__main__":
# Example usage
print("Extending Qwen3.5-3B tokenizer with music tokens...")
tokenizer_ext = extend_qwen_tokenizer(
base_model_name="Qwen/Qwen3.5-3B-Instruct",
save_dir="./touchgrass_tokenizer",
)
# Test encoding
test_text = "[GUITAR][BEGINNER] How do I play a G chord?"
tokens = tokenizer_ext.get_tokenizer().encode(test_text)
print(f"\nTest encoding: {test_text}")
print(f"Token IDs: {tokens[:20]}...")
print(f"Decoded: {tokenizer_ext.get_tokenizer().decode(tokens)}")