| | """
|
| | TouchGrass configuration for HuggingFace.
|
| | Integrates with transformers library.
|
| | """
|
| |
|
| | from typing import Optional, List, Dict, Any
|
| | from transformers import PretrainedConfig
|
| |
|
| |
|
class TouchGrassConfig(PretrainedConfig):
    """
    Configuration class for TouchGrass model.

    Compatible with HuggingFace transformers: it subclasses
    ``PretrainedConfig`` and forwards any unrecognised keyword arguments
    to the base-class constructor.

    Class attributes:
        model_type: Identifier string for this configuration family.
        tie_word_embeddings: Default passed to the base class when the
            caller does not supply ``tie_word_embeddings`` explicitly.
    """

    model_type = "touchgrass"
    tie_word_embeddings = True

    def __init__(
        self,
        base_model: str = "Qwen/Qwen3.5-3B-Instruct",
        model_type: str = "touchgrass",
        d_model: int = 2048,
        num_layers: int = 36,
        num_heads: int = 16,
        head_dim: int = 128,
        ffn_expansion: float = 2.67,
        vocab_size: int = 32000,
        max_seq_len: int = 4096,
        enable_tab_chord_module: bool = True,
        enable_music_theory_module: bool = True,
        enable_ear_training_module: bool = True,
        enable_eq_adapter: bool = True,
        enable_songwriting_module: bool = True,
        eq_hidden_dim: int = 32,
        eq_loss_weight: float = 0.1,
        special_tokens: Optional[Dict[str, int]] = None,
        music_domains: Optional[List[str]] = None,
        skill_levels: Optional[List[str]] = None,
        notation_tags: Optional[List[str]] = None,
        initializer_range: float = 0.02,
        **kwargs
    ):
        """Store every argument on the instance.

        Args:
            base_model: Name of the base model; stored unchanged.
            model_type: Stored on the instance as ``model_type``.
            d_model, num_layers, num_heads, head_dim, ffn_expansion,
                vocab_size, max_seq_len: Architecture hyper-parameters,
                stored unchanged (their consumers are not part of this
                class).
            enable_tab_chord_module, enable_music_theory_module,
                enable_ear_training_module, enable_eq_adapter,
                enable_songwriting_module: Boolean flags, stored
                unchanged.
            eq_hidden_dim, eq_loss_weight: Stored unchanged.
            special_tokens: Mapping of token string to id; ``None`` or an
                empty mapping becomes ``{}``.
            music_domains, skill_levels, notation_tags: Lists of tag
                strings; ``None`` or an empty list is replaced by the
                built-in default list for that field.
            initializer_range: Stored unchanged.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``. May
                contain ``tie_word_embeddings`` to override the class
                default.
        """
        # The class attribute is not visible as a bare name inside a method,
        # so the default is resolved through the class and injected into
        # kwargs. An explicit caller-supplied value wins, and the keyword is
        # passed to the base class exactly once.
        kwargs.setdefault("tie_word_embeddings", type(self).tie_word_embeddings)
        super().__init__(**kwargs)

        self.base_model = base_model
        self.model_type = model_type
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.ffn_expansion = ffn_expansion
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        self.enable_tab_chord_module = enable_tab_chord_module
        self.enable_music_theory_module = enable_music_theory_module
        self.enable_ear_training_module = enable_ear_training_module
        self.enable_eq_adapter = enable_eq_adapter
        self.enable_songwriting_module = enable_songwriting_module
        self.eq_hidden_dim = eq_hidden_dim
        self.eq_loss_weight = eq_loss_weight

        # ``or`` (rather than ``is None``) means empty containers also fall
        # back to the defaults; a fresh object is built per instance.
        self.special_tokens = special_tokens or {}
        self.music_domains = music_domains or ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[DJ]"]
        self.skill_levels = skill_levels or ["[BEGINNER]", "[INTERMEDIATE]", "[ADVANCED]"]
        self.notation_tags = notation_tags or ["[TAB]", "[CHORD]", "[SHEET]", "[LYRICS]", "[PROGRESSION]"]
        self.initializer_range = initializer_range

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load config from pretrained model.

        Looks for ``config.json`` inside the local directory
        ``pretrained_model_name_or_path``.

        Args:
            pretrained_model_name_or_path: Directory expected to contain
                ``config.json``.
            **kwargs: Overrides applied on top of the values read from
                the file (or used alone when no file exists).

        Returns:
            A new instance built from the file contents merged with
            ``kwargs``. If ``config.json`` does not exist, an instance
            built from ``kwargs`` alone (i.e. defaults plus overrides).
        """
        import json
        import os

        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if not os.path.exists(config_path):
            # Best-effort fallback: no file on disk, use defaults + overrides.
            return cls(**kwargs)

        # Explicit encoding so the file is decoded the same way on every
        # platform instead of depending on the locale default.
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)

        # Caller-supplied overrides take precedence over stored values.
        config_dict.update(kwargs)
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary.

        Returns:
            A new dict holding exactly the fields set explicitly in
            ``__init__``. Attributes managed only by the base class are
            not included. Container values are the instance's own
            objects, not copies.
        """
        return {
            "model_type": self.model_type,
            "base_model": self.base_model,
            "d_model": self.d_model,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "ffn_expansion": self.ffn_expansion,
            "vocab_size": self.vocab_size,
            "max_seq_len": self.max_seq_len,
            "enable_tab_chord_module": self.enable_tab_chord_module,
            "enable_music_theory_module": self.enable_music_theory_module,
            "enable_ear_training_module": self.enable_ear_training_module,
            "enable_eq_adapter": self.enable_eq_adapter,
            "enable_songwriting_module": self.enable_songwriting_module,
            "eq_hidden_dim": self.eq_hidden_dim,
            "eq_loss_weight": self.eq_loss_weight,
            "special_tokens": self.special_tokens,
            "music_domains": self.music_domains,
            "skill_levels": self.skill_levels,
            "notation_tags": self.notation_tags,
            "initializer_range": self.initializer_range,
        }