File size: 4,437 Bytes
4f0238f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""

TouchGrass model configuration for HuggingFace ``transformers``.

Defines ``TouchGrassConfig``, a ``PretrainedConfig`` subclass.

"""

from typing import Optional, List, Dict, Any
from transformers import PretrainedConfig


class TouchGrassConfig(PretrainedConfig):
    """
    Configuration class for the TouchGrass model.

    A ``transformers.PretrainedConfig`` subclass. Every constructor argument
    is stored as an instance attribute of the same name; no validation is
    performed. Unrecognised keyword arguments (including
    ``tie_word_embeddings``) are forwarded to ``PretrainedConfig.__init__``.

    Args:
        base_model: Identifier of the base checkpoint. It is only stored;
            nothing is downloaded or loaded here.
        model_type: Instance-level model type; also emitted by ``to_dict``.
        d_model, num_layers, num_heads, head_dim, ffn_expansion, vocab_size,
        max_seq_len: Architecture hyperparameters, stored as given.
        enable_tab_chord_module, enable_music_theory_module,
        enable_ear_training_module, enable_eq_adapter,
        enable_songwriting_module: Boolean flags. Presumably the model reads
            them to switch optional music modules on or off (confirm in the
            model code).
        eq_hidden_dim, eq_loss_weight: EQ-adapter settings, stored as given.
        special_tokens: ``Dict[str, int]`` mapping (presumably token -> id).
            ``None`` means an empty dict; a passed dict is copied.
        music_domains, skill_levels, notation_tags: Tag lists. ``None``
            selects the built-in defaults. An explicitly passed list, even
            an empty one, is copied and kept.
        initializer_range: Stored as given.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "touchgrass"
    # Class-level default; ``__init__`` falls back to it when the caller does
    # not pass ``tie_word_embeddings`` explicitly.
    tie_word_embeddings = True

    def __init__(
        self,
        base_model: str = "Qwen/Qwen3.5-3B-Instruct",
        model_type: str = "touchgrass",
        d_model: int = 2048,
        num_layers: int = 36,
        num_heads: int = 16,
        head_dim: int = 128,
        ffn_expansion: float = 2.67,
        vocab_size: int = 32000,
        max_seq_len: int = 4096,
        # Music modules
        enable_tab_chord_module: bool = True,
        enable_music_theory_module: bool = True,
        enable_ear_training_module: bool = True,
        enable_eq_adapter: bool = True,
        enable_songwriting_module: bool = True,
        eq_hidden_dim: int = 32,
        eq_loss_weight: float = 0.1,
        # Special tokens
        special_tokens: Optional[Dict[str, int]] = None,
        music_domains: Optional[List[str]] = None,
        skill_levels: Optional[List[str]] = None,
        notation_tags: Optional[List[str]] = None,
        initializer_range: float = 0.02,
        **kwargs
    ):
        # A class attribute is not visible as a bare name inside a method, so
        # referencing ``tie_word_embeddings`` directly raised NameError. Take it
        # from kwargs (so callers / config.json can override it without a
        # duplicate-keyword TypeError) and default to the class attribute.
        tie_word_embeddings = kwargs.pop(
            "tie_word_embeddings", type(self).tie_word_embeddings
        )
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        self.base_model = base_model
        self.model_type = model_type
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.ffn_expansion = ffn_expansion
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.enable_tab_chord_module = enable_tab_chord_module
        self.enable_music_theory_module = enable_music_theory_module
        self.enable_ear_training_module = enable_ear_training_module
        self.enable_eq_adapter = enable_eq_adapter
        self.enable_songwriting_module = enable_songwriting_module
        self.eq_hidden_dim = eq_hidden_dim
        self.eq_loss_weight = eq_loss_weight

        # Use ``is None`` rather than ``or`` so an explicitly passed empty
        # container is respected. Copy inputs so configs never share the
        # caller's mutable objects.
        self.special_tokens = (
            dict(special_tokens) if special_tokens is not None else {}
        )
        self.music_domains = (
            list(music_domains)
            if music_domains is not None
            else ["[GUITAR]", "[PIANO]", "[DRUMS]", "[VOCALS]", "[THEORY]", "[DJ]"]
        )
        self.skill_levels = (
            list(skill_levels)
            if skill_levels is not None
            else ["[BEGINNER]", "[INTERMEDIATE]", "[ADVANCED]"]
        )
        self.notation_tags = (
            list(notation_tags)
            if notation_tags is not None
            else ["[TAB]", "[CHORD]", "[SHEET]", "[LYRICS]", "[PROGRESSION]"]
        )
        self.initializer_range = initializer_range

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load a config from a local directory containing ``config.json``.

        Values in ``kwargs`` override those read from the file. If
        ``<path>/config.json`` does not exist, a config built from ``kwargs``
        alone (i.e. the defaults) is returned; no Hub download is attempted.

        Args:
            pretrained_model_name_or_path: Directory path (``str`` or
                ``os.PathLike``).
            **kwargs: Constructor overrides.

        Returns:
            A new ``TouchGrassConfig`` instance.
        """
        import json
        import os

        # NOTE(review): unlike ``PretrainedConfig.from_pretrained``, this does
        # not resolve Hub repo ids, strip loader options (``cache_dir``,
        # ``revision``, ...) or honour ``return_unused_kwargs``. Such kwargs
        # are passed to the constructor and presumably end up as config
        # attributes. Verify before loading via ``PreTrainedModel.from_pretrained``.
        config_path = os.path.join(
            os.fspath(pretrained_model_name_or_path), "config.json"
        )
        if not os.path.exists(config_path):
            # Return default config
            return cls(**kwargs)

        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            config_dict = json.load(f)
        config_dict.update(kwargs)
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this config to a plain dictionary.

        Starts from ``PretrainedConfig.to_dict()`` so base-class fields such
        as ``tie_word_embeddings`` survive a save/load round trip. It then
        overlays the TouchGrass-specific fields, including the instance-level
        ``model_type``.

        Returns:
            A dict suitable for JSON serialization.
        """
        output = super().to_dict()
        output.update(
            {
                "model_type": self.model_type,
                "base_model": self.base_model,
                "d_model": self.d_model,
                "num_layers": self.num_layers,
                "num_heads": self.num_heads,
                "head_dim": self.head_dim,
                "ffn_expansion": self.ffn_expansion,
                "vocab_size": self.vocab_size,
                "max_seq_len": self.max_seq_len,
                "enable_tab_chord_module": self.enable_tab_chord_module,
                "enable_music_theory_module": self.enable_music_theory_module,
                "enable_ear_training_module": self.enable_ear_training_module,
                "enable_eq_adapter": self.enable_eq_adapter,
                "enable_songwriting_module": self.enable_songwriting_module,
                "eq_hidden_dim": self.eq_hidden_dim,
                "eq_loss_weight": self.eq_loss_weight,
                "special_tokens": self.special_tokens,
                "music_domains": self.music_domains,
                "skill_levels": self.skill_levels,
                "notation_tags": self.notation_tags,
                "initializer_range": self.initializer_range,
            }
        )
        return output