File size: 3,539 Bytes
68f155a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Initialize a Deeplm model from its config, apply BitNet b1.58 quantization, and
save the weights to model.safetensors along with config.json and
generation_config.json.
"""
import sys
import os
import json
import torch

# Add deeplm to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "deeplm"))

from deeplm.config import DeeplmConfig
from deeplm.model.deeplm import DeeplmModel
from deeplm.quantization.bitnet_quantize import apply_bitnet_quantization

# Hyperparameters passed to the DeeplmConfig constructor (not read back from
# the config object here) or to the quantizer, defined once so the saved
# model and config.json cannot drift apart.
_VOCAB_SIZE = 32000
_MAX_SEQ_LENGTH = 4096
_BITNET_SCALE = "absmean"

# Sampling defaults written verbatim to generation_config.json.
_GENERATION_CONFIG = {
    "max_new_tokens": 512,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "pad_token_id": 0,
    "eos_token_id": 2,
    "bos_token_id": 1,
}


def _build_config():
    """Return the DeeplmConfig used by this script.

    Sub-config fields that mirror ``config.architecture`` (MLA head counts,
    RoPE/NoPE head dims, MTP hidden size) are copied from it rather than
    repeated as literals, so changing one value cannot desynchronize them.
    """
    config = DeeplmConfig(
        vocab_size=_VOCAB_SIZE,
        max_seq_length=_MAX_SEQ_LENGTH,
        dtype="float32",
    )
    arch = config.architecture
    arch.num_layers = 10
    arch.hidden_size = 512
    arch.intermediate_size = 2048
    arch.num_attention_heads = 8
    arch.num_key_value_heads = 1
    arch.head_dim = 128
    arch.rope_head_dim = 64
    arch.nope_head_dim = 64
    arch.max_seq_length = _MAX_SEQ_LENGTH
    arch.rope_theta = 50000.0

    mla = config.mla
    mla.q_lora_rank = 192
    mla.kv_lora_rank = 64
    mla.qk_rope_head_dim = arch.rope_head_dim
    mla.qk_nope_head_dim = arch.nope_head_dim
    mla.v_head_dim = 128
    mla.num_heads = arch.num_attention_heads
    mla.kv_heads = arch.num_key_value_heads

    config.moe.num_routed_experts = 4
    config.moe.num_shared_experts = 1
    config.moe.top_k = 2

    config.mtp.num_mtp_layers = 2
    config.mtp.mtp_depth = 2
    config.mtp.mtp_hidden_size = arch.hidden_size

    config.output_heads.lm_head.type = "tied"
    config.output_heads.lm_head.bias = False
    return config


def _hf_config_dict(config):
    """Return the HF-style config.json payload describing ``config``.

    Values are read back from ``config`` (or the module constants used to
    build it) instead of being re-typed, so the JSON matches the saved model.
    """
    arch = config.architecture
    return {
        "architectures": ["DeeplmModel"],
        "model_type": "deeplm",
        "vocab_size": _VOCAB_SIZE,
        "hidden_size": arch.hidden_size,
        "intermediate_size": arch.intermediate_size,
        "num_hidden_layers": arch.num_layers,
        "num_attention_heads": arch.num_attention_heads,
        "num_key_value_heads": arch.num_key_value_heads,
        "max_position_embeddings": arch.max_seq_length,
        # NOTE(review): never set on DeeplmConfig by this script; assumed to
        # match the model's RMSNorm epsilon default — confirm.
        "rms_norm_eps": 1e-06,
        "rope_theta": arch.rope_theta,
        "rope_dim": arch.rope_head_dim,
        "tie_word_embeddings": config.output_heads.lm_head.type == "tied",
        "num_routed_experts": config.moe.num_routed_experts,
        "num_shared_experts": config.moe.num_shared_experts,
        "expert_topk": config.moe.top_k,
        "q_lora_rank": config.mla.q_lora_rank,
        "kv_lora_rank": config.mla.kv_lora_rank,
        "qk_rope_head_dim": config.mla.qk_rope_head_dim,
        "qk_nope_head_dim": config.mla.qk_nope_head_dim,
        "v_head_dim": config.mla.v_head_dim,
        "mtp_depth": config.mtp.mtp_depth,
        "mtp_num_layers": config.mtp.num_mtp_layers,
        "bitnet_quantized": True,
        "bitnet_scale": _BITNET_SCALE,
    }


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as 2-space-indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)


def main(output_dir="."):
    """Build, BitNet-quantize and save a Deeplm model plus its JSON configs.

    Writes ``model.safetensors``, ``config.json`` and
    ``generation_config.json`` into ``output_dir``, creating it if needed.

    Args:
        output_dir: Destination directory. Defaults to the current working
            directory, which is where the files were always written before.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Building DeeplmConfig...")
    config = _build_config()

    print("Creating DeeplmModel...")
    model = DeeplmModel(config)

    total_params = model.num_parameters()
    print(f"Total parameters: {total_params:,}")

    print("Applying BitNet b1.58 ternary quantization (absmean)...")
    stats = apply_bitnet_quantization(model, scale=_BITNET_SCALE, verbose=True)
    print(f"Quantized {stats['quantized']}/{stats['total_linear']} linear layers")

    print("Saving to model.safetensors...")
    # save_model de-duplicates tensors that share storage before writing.
    # With lm_head.type == "tied", the LM head presumably shares its weight
    # with the embedding (not visible here); save_file on a raw state_dict
    # raises RuntimeError in that case. Load the file back with
    # safetensors.torch.load_model, which tolerates the dropped alias keys.
    from safetensors.torch import save_model
    save_model(
        model,
        os.path.join(output_dir, "model.safetensors"),
        metadata={"format": "pt"},
    )

    _write_json(os.path.join(output_dir, "config.json"), _hf_config_dict(config))
    print("Saved config.json")

    _write_json(
        os.path.join(output_dir, "generation_config.json"), _GENERATION_CONFIG
    )
    print("Saved generation_config.json")

    print("Done!")


if __name__ == "__main__":
    main()