File size: 4,234 Bytes
bf64b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""

Vortex configuration for HuggingFace.

"""

from typing import Optional, List, Dict, Any
from transformers import PretrainedConfig


class VortexConfig(PretrainedConfig):
    """Configuration class for the Vortex model.

    Compatible with HuggingFace ``transformers``: inherits from
    :class:`PretrainedConfig` so it can be serialized alongside a model
    checkpoint.

    Args:
        d_model: Hidden size of the model. Must be divisible by ``num_heads``.
        num_layers: Number of transformer/SSM blocks.
        num_heads: Number of attention heads.
        d_state: SSM state dimension.
        d_conv: SSM convolution kernel width.
        window_size: Local attention window size.
        ffn_expansion: Feed-forward expansion factor.
        num_domains: Number of domain tags the model distinguishes.
        vocab_size: Tokenizer vocabulary size.
        max_seq_len: Maximum supported sequence length.
        ssm_ratio: Fraction of layers that use the SSM mixer.
        enable_equation_module: Toggle for the equation sub-module.
        enable_numerical_module: Toggle for the numerical sub-module.
        enable_citation_module: Toggle for the citation sub-module.
        enable_molecular_module: Toggle for the molecular sub-module.
        special_tokens: Mapping of special-token strings to ids; a default
            scientific-token map is used when ``None``.
        domain_tags: Ordered list of domain tag tokens; defaults to the
            seven scientific domains.
        initializer_range: Stddev for weight initialization.
        tie_word_embeddings: Whether input/output embeddings share weights.
        **kwargs: Forwarded to :class:`PretrainedConfig`.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    model_type = "vortex"
    tie_word_embeddings = True

    def __init__(
        self,
        d_model: int = 4096,
        num_layers: int = 32,
        num_heads: int = 32,
        d_state: int = 16,
        d_conv: int = 4,
        window_size: int = 512,
        ffn_expansion: int = 4,
        num_domains: int = 7,
        vocab_size: int = 50000,
        max_seq_len: int = 16384,
        ssm_ratio: float = 0.6,
        enable_equation_module: bool = True,
        enable_numerical_module: bool = True,
        enable_citation_module: bool = True,
        enable_molecular_module: bool = True,
        special_tokens: Optional[Dict[str, int]] = None,
        domain_tags: Optional[List[str]] = None,
        initializer_range: float = 0.02,
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        # head_dim is derived below; reject configs where it would silently
        # truncate (e.g. d_model=100, num_heads=32).
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_state = d_state
        self.d_conv = d_conv
        self.window_size = window_size
        self.ffn_expansion = ffn_expansion
        self.num_domains = num_domains
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.ssm_ratio = ssm_ratio
        self.enable_equation_module = enable_equation_module
        self.enable_numerical_module = enable_numerical_module
        self.enable_citation_module = enable_citation_module
        self.enable_molecular_module = enable_molecular_module
        # Default map covers structural markers ([EQUATION], [CITATION], ...)
        # plus one tag token per scientific domain.
        self.special_tokens = special_tokens or {
            "[PAD]": 0, "[UNK]": 1, "[BOS]": 2, "[EOS]": 3,
            "[EQUATION]": 4, "[/EQUATION]": 5,
            "[CITATION]": 6, "[/CITATION]": 7,
            "[MOLECULE]": 8, "[/MOLECULE]": 9,
            "[FIGURE]": 10, "[TABLE]": 11,
            "[MATH]": 12, "[CHEM]": 13, "[BIO]": 14,
            "[PHYS]": 15, "[EARTH]": 16, "[SPACE]": 17, "[ZOO]": 18,
        }
        self.domain_tags = domain_tags or ["[MATH]", "[CHEM]", "[BIO]", "[PHYS]", "[EARTH]", "[SPACE]", "[ZOO]"]
        self.initializer_range = initializer_range
        # Derived attribute: per-head dimension (exact, divisibility checked above).
        self.head_dim = self.d_model // self.num_heads

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load a config from a local directory containing ``config.json``.

        NOTE(review): this override only supports local paths — unlike the
        HF base classmethod it shadows, it cannot resolve hub model IDs,
        and it silently falls back to a default config when no
        ``config.json`` is found. Confirm this is intended before relying
        on it for hub loading.

        Args:
            pretrained_model_name_or_path: Directory expected to contain
                ``config.json``.
            **kwargs: Overrides applied on top of the loaded values.

        Returns:
            A :class:`VortexConfig` instance.
        """
        import json
        import os

        config_path = os.path.join(pretrained_model_name_or_path, "config.json")
        if os.path.exists(config_path):
            # config.json is written as UTF-8 by HF tooling; decode explicitly
            # instead of relying on the platform default encoding.
            with open(config_path, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
            # Explicit kwargs win over values read from disk.
            config_dict.update(kwargs)
            return cls(**config_dict)
        # No config found: fall back to defaults (see NOTE above).
        return cls(**kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict of the Vortex-specific fields.

        Includes ``tie_word_embeddings`` so a ``to_dict`` ->
        ``from_pretrained`` round-trip preserves a non-default setting.
        """
        return {
            "model_type": self.model_type,
            "d_model": self.d_model,
            "num_layers": self.num_layers,
            "num_heads": self.num_heads,
            "head_dim": self.head_dim,
            "d_state": self.d_state,
            "d_conv": self.d_conv,
            "window_size": self.window_size,
            "ffn_expansion": self.ffn_expansion,
            "num_domains": self.num_domains,
            "vocab_size": self.vocab_size,
            "max_seq_len": self.max_seq_len,
            "ssm_ratio": self.ssm_ratio,
            "enable_equation_module": self.enable_equation_module,
            "enable_numerical_module": self.enable_numerical_module,
            "enable_citation_module": self.enable_citation_module,
            "enable_molecular_module": self.enable_molecular_module,
            "special_tokens": self.special_tokens,
            "domain_tags": self.domain_tags,
            "initializer_range": self.initializer_range,
            "tie_word_embeddings": self.tie_word_embeddings,
        }