Lab1806 committed on
Commit
bfa9a3d
·
verified ·
1 Parent(s): bcf5db1

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "LlamaForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 1,
8
+ "eos_token_id": 2,
9
+ "head_dim": 128,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 4096,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 11008,
14
+ "max_position_embeddings": 4096,
15
+ "mlp_bias": false,
16
+ "model_type": "llama",
17
+ "num_attention_heads": 32,
18
+ "num_hidden_layers": 32,
19
+ "num_key_value_heads": 32,
20
+ "pretraining_tp": 1,
21
+ "rms_norm_eps": 1e-05,
22
+ "rope_scaling": null,
23
+ "rope_theta": 10000.0,
24
+ "tie_word_embeddings": false,
25
+ "torch_dtype": "bfloat16",
26
+ "transformers_version": "4.52.4",
27
+ "use_cache": true,
28
+ "vocab_size": 32000
29
+ }
generation_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 1,
3
+ "do_sample": true,
4
+ "eos_token_id": 2,
5
+ "max_length": 4096,
6
+ "pad_token_id": 0,
7
+ "temperature": 0.6,
8
+ "top_p": 0.9,
9
+ "transformers_version": "4.52.4"
10
+ }
load_model.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model loading script
3
+
4
+ Load Fairy2i-W2 model from Hugging Face repository.
5
+
6
+ Usage:
7
+ from load_model import load_model
8
+ model, tokenizer = load_model()
9
+ """
10
+
11
+ import torch
12
+ from transformers import AutoModelForCausalLM, AutoTokenizer
13
+ from safetensors.torch import load_file
14
+ from huggingface_hub import hf_hub_download
15
+ import os
16
+ import sys
17
+
18
+ # Add current directory to path for importing qat_modules
19
+ current_dir = os.path.dirname(os.path.abspath(__file__))
20
+ if current_dir not in sys.path:
21
+ sys.path.insert(0, current_dir)
22
+
23
+ from qat_modules import replace_modules_for_qat
24
+
25
+
26
def _download_state_dict(weights_repo_id):
    """Download model weights from a Hugging Face repo as a CPU state dict.

    Prefers sharded safetensors (via ``model.safetensors.index.json``); falls
    back to a single ``model.safetensors`` file only when the index file does
    not exist. A failure while downloading/reading an existing shard is raised
    as-is instead of silently retrying the single-file path, so the real error
    surfaces.

    Args:
        weights_repo_id: Hub repository id holding the weights.

    Returns:
        dict mapping parameter names to CPU tensors.

    Raises:
        RuntimeError: if no usable weight files can be downloaded.
    """
    import json
    from safetensors import safe_open

    # Probe for the shard index; only this download is allowed to fail over.
    try:
        index_path = hf_hub_download(
            repo_id=weights_repo_id,
            filename="model.safetensors.index.json",
        )
    except Exception:
        index_path = None  # no index file -> single-file repository

    if index_path is not None:
        # Sharded weights: pull every referenced shard and merge the tensors.
        with open(index_path, "r") as f:
            weight_map = json.load(f)["weight_map"]

        state_dict = {}
        # sorted() makes the download/merge order deterministic.
        for weight_file in sorted(set(weight_map.values())):
            file_path = hf_hub_download(
                repo_id=weights_repo_id,
                filename=weight_file,
            )
            with safe_open(file_path, framework="pt", device="cpu") as f:
                for key in f.keys():
                    state_dict[key] = f.get_tensor(key)
        print("✅ Weights loaded (sharded)")
        return state_dict

    # Single weight file.
    try:
        weights_path = hf_hub_download(
            repo_id=weights_repo_id,
            filename="model.safetensors",
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load weights from Hugging Face: {e}")
    state_dict = load_file(weights_path)
    print("✅ Weights loaded (single file)")
    return state_dict


def load_model(
    device="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16
):
    """
    Load Fairy2i-W2 model: standard architecture + custom weights + QAT linear layer replacement

    Loads weights and tokenizer from the Hugging Face repository.

    Args:
        device: Device, default auto-select
        torch_dtype: Data type, default torch.bfloat16

    Returns:
        model, tokenizer
    """
    # Configuration parameters
    base_model_id = "meta-llama/Llama-2-7b-hf"
    weights_repo_id = "PKU-DS-LAB/Fairy2i-W2"
    quant_method = "complex_phase_v2"
    skip_lm_head = False
    print("=" * 70)
    print("Loading Fairy2i-W2 Model")
    print("=" * 70)

    # Step 1: Load standard model architecture
    print(f"\n📥 Step 1/4: Loading standard model architecture: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=False
    )
    print("✅ Standard model architecture loaded")

    # Step 2: Load custom weights
    print(f"\n💾 Step 2/4: Loading weights from Hugging Face repository: {weights_repo_id}")
    state_dict = _download_state_dict(weights_repo_id)

    # strict=False tolerates partial checkpoints, but surface what was skipped
    # instead of discarding the report silently.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"⚠️  {len(incompatible.missing_keys)} parameters missing from checkpoint (base-model values kept)")
    if incompatible.unexpected_keys:
        print(f"⚠️  {len(incompatible.unexpected_keys)} unexpected checkpoint keys ignored")

    # Step 3: Apply QAT replacement
    print(f"\n🔧 Step 3/4: Applying QAT replacement ({quant_method})...")
    replace_modules_for_qat(model, quant_method, skip_lm_head=skip_lm_head)
    print("✅ QAT replacement completed")

    # Step 4: Load tokenizer
    print(f"\n📝 Step 4/4: Loading Tokenizer from Hugging Face repository: {weights_repo_id}")
    tokenizer = AutoTokenizer.from_pretrained(weights_repo_id)
    print("✅ Tokenizer loaded")

    print("\n" + "=" * 70)
    print("✅ Model loading completed!")
    print("=" * 70)

    return model, tokenizer
121
+
122
+
123
if __name__ == "__main__":
    # Demo entry point: load the model, then sample a short completion.
    model, tokenizer = load_model()

    # Smoke-test generation on a single prompt.
    print("\n🧪 Testing generation...")
    prompt = "Hello, how are you?"
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
        )

    response = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print(f"Response: {response}")
143
+
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7faa5a86a5d9968018d5ec54582e4c8d953541fdda69dd3855cad02b0ba62f8c
3
+ size 4938985352
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c10a92f9aa154ead36f02411ed284ac5a4abcba038d43c3d3b4599b3c4fd132e
3
+ size 4947390880
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33494b29047ea5f2696e949935b46b7e9372ec5de81d9718656ae8237766bf2d
3
+ size 3590488816
model.safetensors.index.json ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 13476831232
4
+ },
5
+ "weight_map": {
6
+ "lm_head.weight": "model-00003-of-00003.safetensors",
7
+ "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
8
+ "model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
9
+ "model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
10
+ "model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
11
+ "model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
12
+ "model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
13
+ "model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
14
+ "model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
15
+ "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
16
+ "model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
17
+ "model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
18
+ "model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
19
+ "model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
20
+ "model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
21
+ "model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
22
+ "model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
23
+ "model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
24
+ "model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
25
+ "model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
26
+ "model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
27
+ "model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
28
+ "model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
29
+ "model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
30
+ "model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
31
+ "model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
32
+ "model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
33
+ "model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
34
+ "model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
35
+ "model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
36
+ "model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
37
+ "model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
38
+ "model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
39
+ "model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
40
+ "model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
41
+ "model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
42
+ "model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
43
+ "model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
44
+ "model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
45
+ "model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
46
+ "model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
47
+ "model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
48
+ "model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
49
+ "model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
50
+ "model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
51
+ "model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
52
+ "model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
53
+ "model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
54
+ "model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
55
+ "model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
56
+ "model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
57
+ "model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
58
+ "model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
59
+ "model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
60
+ "model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
61
+ "model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
62
+ "model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
63
+ "model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
64
+ "model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
65
+ "model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
66
+ "model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
67
+ "model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
68
+ "model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
69
+ "model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
70
+ "model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
71
+ "model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
72
+ "model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
73
+ "model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
74
+ "model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
75
+ "model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
76
+ "model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
77
+ "model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
78
+ "model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
79
+ "model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
80
+ "model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
81
+ "model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
82
+ "model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
83
+ "model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
84
+ "model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
85
+ "model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
86
+ "model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
87
+ "model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
88
+ "model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
89
+ "model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
90
+ "model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
91
+ "model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
92
+ "model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
93
+ "model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
94
+ "model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
95
+ "model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
96
+ "model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
97
+ "model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
98
+ "model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
99
+ "model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
100
+ "model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
101
+ "model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
102
+ "model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
103
+ "model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
104
+ "model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
105
+ "model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
106
+ "model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
107
+ "model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
108
+ "model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
109
+ "model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
110
+ "model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
111
+ "model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
112
+ "model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
113
+ "model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
114
+ "model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
115
+ "model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
116
+ "model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
117
+ "model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
118
+ "model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
119
+ "model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
120
+ "model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
121
+ "model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
122
+ "model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
123
+ "model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
124
+ "model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
125
+ "model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
126
+ "model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
127
+ "model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
128
+ "model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
129
+ "model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
130
+ "model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
131
+ "model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
132
+ "model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
133
+ "model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
134
+ "model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
135
+ "model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
136
+ "model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
137
+ "model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
138
+ "model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
139
+ "model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
140
+ "model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
141
+ "model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
142
+ "model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
143
+ "model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
144
+ "model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
145
+ "model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
146
+ "model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
147
+ "model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
148
+ "model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
149
+ "model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
150
+ "model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
151
+ "model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
152
+ "model.layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
153
+ "model.layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
154
+ "model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
155
+ "model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
156
+ "model.layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
157
+ "model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
158
+ "model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
159
+ "model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
160
+ "model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
161
+ "model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
162
+ "model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
163
+ "model.layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
164
+ "model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
165
+ "model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
166
+ "model.layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
167
+ "model.layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
168
+ "model.layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
169
+ "model.layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
170
+ "model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
171
+ "model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
172
+ "model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
173
+ "model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
174
+ "model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
175
+ "model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
176
+ "model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
177
+ "model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
178
+ "model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
179
+ "model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
180
+ "model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
181
+ "model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
182
+ "model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
183
+ "model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
184
+ "model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
185
+ "model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
186
+ "model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
187
+ "model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
188
+ "model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
189
+ "model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
190
+ "model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
191
+ "model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
192
+ "model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
193
+ "model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
194
+ "model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
195
+ "model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
196
+ "model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
197
+ "model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
198
+ "model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
199
+ "model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
200
+ "model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
201
+ "model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
202
+ "model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
203
+ "model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
204
+ "model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
205
+ "model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
206
+ "model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
207
+ "model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
208
+ "model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
209
+ "model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
210
+ "model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
211
+ "model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
212
+ "model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
213
+ "model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
214
+ "model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
215
+ "model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
216
+ "model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
217
+ "model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
218
+ "model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
219
+ "model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
220
+ "model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
221
+ "model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
222
+ "model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
223
+ "model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
224
+ "model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
225
+ "model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
226
+ "model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
227
+ "model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
228
+ "model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
229
+ "model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
230
+ "model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
231
+ "model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
232
+ "model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
233
+ "model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
234
+ "model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
235
+ "model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
236
+ "model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
237
+ "model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
238
+ "model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
239
+ "model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
240
+ "model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
241
+ "model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
242
+ "model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
243
+ "model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
244
+ "model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
245
+ "model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
246
+ "model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
247
+ "model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
248
+ "model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
249
+ "model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
250
+ "model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
251
+ "model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
252
+ "model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
253
+ "model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
254
+ "model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
255
+ "model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
256
+ "model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
257
+ "model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
258
+ "model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
259
+ "model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
260
+ "model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
261
+ "model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
262
+ "model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
263
+ "model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
264
+ "model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
265
+ "model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
266
+ "model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
267
+ "model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
268
+ "model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
269
+ "model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
270
+ "model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
271
+ "model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
272
+ "model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
273
+ "model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
274
+ "model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
275
+ "model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
276
+ "model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
277
+ "model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
278
+ "model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
279
+ "model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
280
+ "model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
281
+ "model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
282
+ "model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
283
+ "model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
284
+ "model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
285
+ "model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
286
+ "model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
287
+ "model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
288
+ "model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
289
+ "model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
290
+ "model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
291
+ "model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
292
+ "model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
293
+ "model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
294
+ "model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
295
+ "model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
296
+ "model.norm.weight": "model-00003-of-00003.safetensors"
297
+ }
298
+ }
qat_modules.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from quantization import BitNetQuantSTE, PhaseQuantSTE, PhaseQuantSTE_V2, PhaseQuantSTE_V3, PhaseQuantSTE_V4
5
+ import math
6
+
7
class QATLinearBitNet(nn.Linear):
    """Linear layer for BitNet quantization-aware training.

    On every forward pass the full-precision weight is binarized through
    BitNetQuantSTE (straight-through estimator), so gradients still flow
    to the latent full-precision weight.
    """

    def forward(self, x):
        # Quantize on the fly; the STE makes the backward pass an identity.
        return F.linear(x, BitNetQuantSTE.apply(self.weight), self.bias)
15
+
16
class QATLinearComplexPhaseV1(nn.Linear):
    """Linear layer for Complex-Phase V1 quantization-aware training.

    The weight is viewed as a 2x2 block matrix, decomposed into two
    complex components U and W, each quantized with PhaseQuantSTE, and
    reassembled before the linear op. Requires even in/out features.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        half_out, half_in = self.out_features // 2, self.in_features // 2
        a11, a12 = self.weight[:half_out, :half_in], self.weight[:half_out, half_in:]
        a21, a22 = self.weight[half_out:, :half_in], self.weight[half_out:, half_in:]

        # Complex decomposition A = U + W of the 2x2 block structure.
        u_re, u_im = 0.5 * (a11 + a22), 0.5 * (a21 - a12)
        w_re, w_im = 0.5 * (a11 - a22), 0.5 * (a12 + a21)

        u_re_q, u_im_q = PhaseQuantSTE.apply(u_re, u_im)
        w_re_q, w_im_q = PhaseQuantSTE.apply(w_re, w_im)

        # Reassemble the quantized blocks into a full weight matrix.
        quantized = torch.cat([
            torch.cat([w_re_q + u_re_q, w_im_q - u_im_q], dim=1),
            torch.cat([w_im_q + u_im_q, u_re_q - w_re_q], dim=1),
        ], dim=0)
        return F.linear(x, quantized, self.bias)
47
+
48
class QATLinearComplexPhaseV2(nn.Linear):
    """Linear layer for Complex-Phase V2 QAT (1-step residual quantization).

    Same block decomposition as V1, but the components are quantized with
    PhaseQuantSTE_V2, which requantizes the quantization error once.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        half_out, half_in = self.out_features // 2, self.in_features // 2
        a11, a12 = self.weight[:half_out, :half_in], self.weight[:half_out, half_in:]
        a21, a22 = self.weight[half_out:, :half_in], self.weight[half_out:, half_in:]

        # Complex decomposition A = U + W of the 2x2 block structure.
        u_re, u_im = 0.5 * (a11 + a22), 0.5 * (a21 - a12)
        w_re, w_im = 0.5 * (a11 - a22), 0.5 * (a12 + a21)

        u_re_q, u_im_q = PhaseQuantSTE_V2.apply(u_re, u_im)
        w_re_q, w_im_q = PhaseQuantSTE_V2.apply(w_re, w_im)

        # Reassemble the quantized blocks into a full weight matrix.
        quantized = torch.cat([
            torch.cat([w_re_q + u_re_q, w_im_q - u_im_q], dim=1),
            torch.cat([w_im_q + u_im_q, u_re_q - w_re_q], dim=1),
        ], dim=0)
        return F.linear(x, quantized, self.bias)
79
+
80
class QATLinearComplexPhaseV3(nn.Linear):
    """Linear layer for Complex-Phase V3 QAT (2-step residual quantization).

    Same block decomposition as V1, but the components are quantized with
    PhaseQuantSTE_V3, which requantizes the quantization error twice.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        half_out, half_in = self.out_features // 2, self.in_features // 2
        a11, a12 = self.weight[:half_out, :half_in], self.weight[:half_out, half_in:]
        a21, a22 = self.weight[half_out:, :half_in], self.weight[half_out:, half_in:]

        # Complex decomposition A = U + W of the 2x2 block structure.
        u_re, u_im = 0.5 * (a11 + a22), 0.5 * (a21 - a12)
        w_re, w_im = 0.5 * (a11 - a22), 0.5 * (a12 + a21)

        u_re_q, u_im_q = PhaseQuantSTE_V3.apply(u_re, u_im)
        w_re_q, w_im_q = PhaseQuantSTE_V3.apply(w_re, w_im)

        # Reassemble the quantized blocks into a full weight matrix.
        quantized = torch.cat([
            torch.cat([w_re_q + u_re_q, w_im_q - u_im_q], dim=1),
            torch.cat([w_im_q + u_im_q, u_re_q - w_re_q], dim=1),
        ], dim=0)
        return F.linear(x, quantized, self.bias)
111
+
112
class QATLinearComplexPhaseV4(nn.Linear):
    """Linear layer for Complex-Phase V4 QAT (3-step residual quantization).

    Same block decomposition as V1, but the components are quantized with
    PhaseQuantSTE_V4, which requantizes the quantization error three times.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        half_out, half_in = self.out_features // 2, self.in_features // 2
        a11, a12 = self.weight[:half_out, :half_in], self.weight[:half_out, half_in:]
        a21, a22 = self.weight[half_out:, :half_in], self.weight[half_out:, half_in:]

        # Complex decomposition A = U + W of the 2x2 block structure.
        u_re, u_im = 0.5 * (a11 + a22), 0.5 * (a21 - a12)
        w_re, w_im = 0.5 * (a11 - a22), 0.5 * (a12 + a21)

        u_re_q, u_im_q = PhaseQuantSTE_V4.apply(u_re, u_im)
        w_re_q, w_im_q = PhaseQuantSTE_V4.apply(w_re, w_im)

        # Reassemble the quantized blocks into a full weight matrix.
        quantized = torch.cat([
            torch.cat([w_re_q + u_re_q, w_im_q - u_im_q], dim=1),
            torch.cat([w_im_q + u_im_q, u_re_q - w_re_q], dim=1),
        ], dim=0)
        return F.linear(x, quantized, self.bias)
143
+
144
# Registry mapping a quantization-method name (as passed to
# replace_modules_for_qat) to the QAT linear-layer class that implements it.
METHOD_MAP = {
    'bitnet': QATLinearBitNet,
    'complex_phase_v1': QATLinearComplexPhaseV1,
    'complex_phase_v2': QATLinearComplexPhaseV2,
    'complex_phase_v3': QATLinearComplexPhaseV3,
    'complex_phase_v4': QATLinearComplexPhaseV4,
}
151
+
152
def replace_modules_for_qat(model: nn.Module, method: str, skip_lm_head: bool = False):
    """Recursively replace nn.Linear layers in the model with QAT layers.

    Args:
        model: module tree to rewrite in place.
        method: key into METHOD_MAP selecting the QAT layer class.
        skip_lm_head: when True, a direct child named 'lm_head' is left untouched.

    Raises:
        ValueError: if `method` is not registered in METHOD_MAP.

    Weights/biases are copied into the replacement layers; complex-phase
    methods skip layers with odd dimensions.
    """
    if method not in METHOD_MAP:
        raise ValueError(f"Unknown method: {method}. Available methods: {list(METHOD_MAP.keys())}")
    qat_cls = METHOD_MAP[method]

    for child_name, child in model.named_children():
        # Depth-first: rewrite nested containers before inspecting this child.
        if len(list(child.children())) > 0:
            replace_modules_for_qat(child, method, skip_lm_head)

        if not isinstance(child, nn.Linear):
            continue

        if skip_lm_head and child_name == 'lm_head':
            print(" -> Skipping lm_head layer (skip_lm_head=True)")
            continue

        # Complex-phase methods need the 2x2 block split -> even dims only.
        if 'complex_phase' in method and (child.in_features % 2 != 0 or child.out_features % 2 != 0):
            print(f" -> Skipping Complex-Phase replacement (non-even dimensions): {child_name} ({child.in_features}, {child.out_features})")
            continue

        print(f" -> Replacing layer: {child_name} with {qat_cls.__name__}")
        replacement = qat_cls(
            child.in_features,
            child.out_features,
            bias=child.bias is not None,
            dtype=child.weight.dtype,
            device=child.weight.device,
        )
        replacement.weight.data.copy_(child.weight.data)
        if child.bias is not None:
            replacement.bias.data.copy_(child.bias.data)
        setattr(model, child_name, replacement)
186
+
187
class InferenceOptimizedBitNet(nn.Linear):
    """Inference-optimized BitNet linear layer, in-place weight replacement to save memory."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_quantized = False  # set True after the one-time in-place binarization

    def _ensure_quantized(self):
        """Binarize the stored weight exactly once, replacing it in place."""
        if self._is_quantized:
            return
        with torch.no_grad():
            weight = self.weight
            scale = weight.abs().mean()   # absmean magnitude scale
            offset = weight.mean()        # affine center
            signs = torch.where(weight - offset > 0, 1.0, -1.0).to(weight.dtype)
            self.weight.data = signs * scale
            self._is_quantized = True

    def forward(self, x):
        self._ensure_quantized()
        return F.linear(x, self.weight, self.bias)
209
+
210
class InferenceOptimizedComplexPhase(nn.Linear):
    """Inference-optimized Complex Phase linear layer, supports V1-V4.

    The stored weight is quantized in place on first use. V1 is a single
    PhaseQuant pass; V2/V3/V4 add 1/2/3 residual passes that requantize
    the remaining quantization error.
    """

    # Number of PhaseQuant passes (1 initial + residual passes) per version.
    _PASSES = {"v1": 1, "v2": 2, "v3": 3, "v4": 4}

    def __init__(self, version="v1", *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase requires even in/out features.")
        self._is_quantized = False
        self._version = version.lower()
        if self._version not in ["v1", "v2", "v3", "v4"]:
            raise ValueError(f"Unsupported version: {version}. Must be one of ['v1', 'v2', 'v3', 'v4']")

    @staticmethod
    def _phase_quant_once(w_real, w_imag):
        """Single PhaseQuant pass: snap each complex entry to a signed axis, scaled by absmean."""
        phase = torch.angle(w_real + 1j * w_imag)

        # Quadrant masks: each entry is assigned to the axis its phase is closest to.
        real_pos = (phase >= -math.pi / 4) & (phase < math.pi / 4)
        real_neg = (phase >= 3 * math.pi / 4) | (phase < -3 * math.pi / 4)
        imag_pos = (phase >= math.pi / 4) & (phase < 3 * math.pi / 4)
        imag_neg = (phase >= -3 * math.pi / 4) & (phase < -math.pi / 4)

        on_real = real_pos | real_neg
        on_imag = imag_pos | imag_neg

        # Per-axis absmean scales over the selected entries only.
        s_re = w_real[on_real].abs().mean() if on_real.any() else torch.tensor(0.0, device=w_real.device)
        s_im = w_imag[on_imag].abs().mean() if on_imag.any() else torch.tensor(0.0, device=w_imag.device)
        s_re = torch.clamp(s_re, min=1e-6)
        s_im = torch.clamp(s_im, min=1e-6)

        q_re = torch.zeros_like(w_real)
        q_im = torch.zeros_like(w_imag)
        q_re[real_pos] = 1.0
        q_re[real_neg] = -1.0
        q_im[imag_pos] = 1.0
        q_im[imag_neg] = -1.0

        return q_re * s_re, q_im * s_im

    def _residual_phase_quant(self, w_real, w_imag, passes):
        """Run `passes` PhaseQuant rounds, each quantizing the error left by the previous round."""
        total_re = torch.zeros_like(w_real)
        total_im = torch.zeros_like(w_imag)
        err_re, err_im = w_real, w_imag
        for _ in range(passes):
            q_re, q_im = self._phase_quant_once(err_re, err_im)
            total_re = total_re + q_re
            total_im = total_im + q_im
            err_re = err_re - q_re
            err_im = err_im - q_im
        return total_re, total_im

    def _ensure_quantized(self):
        """Quantize the weight matrix exactly once, replacing it in place."""
        if self._is_quantized:
            return
        with torch.no_grad():
            A = self.weight
            half_out, half_in = A.shape[0] // 2, A.shape[1] // 2
            a11, a12 = A[:half_out, :half_in], A[:half_out, half_in:]
            a21, a22 = A[half_out:, :half_in], A[half_out:, half_in:]

            # Complex decomposition A = U + W of the 2x2 block structure.
            u_re, u_im = 0.5 * (a11 + a22), 0.5 * (a21 - a12)
            w_re, w_im = 0.5 * (a11 - a22), 0.5 * (a12 + a21)

            passes = self._PASSES[self._version]
            u_re_q, u_im_q = self._residual_phase_quant(u_re, u_im, passes)
            w_re_q, w_im_q = self._residual_phase_quant(w_re, w_im, passes)

            # Reassemble the quantized blocks into a full weight matrix.
            self.weight.data = torch.cat([
                torch.cat([w_re_q + u_re_q, w_im_q - u_im_q], dim=1),
                torch.cat([w_im_q + u_im_q, u_re_q - w_re_q], dim=1),
            ], dim=0)
            self._is_quantized = True

    def forward(self, x):
        self._ensure_quantized()
        return F.linear(x, self.weight, self.bias)
330
+
331
def convert_to_inference_mode(model):
    """Convert QAT modules to inference-optimized version (permanently modifies model weights).

    Walks the module tree, swapping every QATLinearBitNet for an
    InferenceOptimizedBitNet and every QATLinearComplexPhaseV1-V4 for an
    InferenceOptimizedComplexPhase of the matching version, copying
    weights/biases over. Returns the (mutated) model.
    """
    converted = 0
    # (QAT class, version string) pairs for the complex-phase family.
    version_pairs = [
        (QATLinearComplexPhaseV1, "v1"),
        (QATLinearComplexPhaseV2, "v2"),
        (QATLinearComplexPhaseV3, "v3"),
        (QATLinearComplexPhaseV4, "v4"),
    ]

    def _copy_params(src, dst):
        """Copy weight (and bias, if present) from src layer to dst layer."""
        dst.weight.data.copy_(src.weight.data)
        if src.bias is not None:
            dst.bias.data.copy_(src.bias.data)

    def _walk(parent, prefix=""):
        nonlocal converted
        for child_name, child in list(parent.named_children()):
            path = f"{prefix}.{child_name}" if prefix else child_name

            if isinstance(child, QATLinearBitNet):
                fresh = InferenceOptimizedBitNet(
                    child.in_features,
                    child.out_features,
                    bias=child.bias is not None,
                    device=child.weight.device,
                    dtype=child.weight.dtype,
                )
                _copy_params(child, fresh)
                setattr(parent, child_name, fresh)
                converted += 1
                print(f" -> Converting BitNet layer: {path}")
                continue

            version = next((v for cls, v in version_pairs if isinstance(child, cls)), None)
            if version is not None:
                fresh = InferenceOptimizedComplexPhase(
                    version=version,
                    in_features=child.in_features,
                    out_features=child.out_features,
                    bias=child.bias is not None,
                    device=child.weight.device,
                    dtype=child.weight.dtype,
                )
                _copy_params(child, fresh)
                setattr(parent, child_name, fresh)
                converted += 1
                print(f" -> Converting ComplexPhase{version.upper()} layer: {path}")
            else:
                _walk(child, path)

    _walk(model)
    print(f"Converted {converted} QAT layers to inference-optimized version")
    return model
quantization.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
@torch.no_grad()
def quantize_complex_tensor(w_real: torch.Tensor, w_imag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply PhaseQuant logic to complex weight tensors.

    Each complex entry is snapped to the signed axis (+1/-1 real or
    imaginary) closest to its phase, then scaled by the absmean of the
    entries assigned to that axis. Returns (real, imag) in input dtypes.
    """
    phase = torch.angle(w_real + 1j * w_imag)

    quarter = math.pi / 4
    # Quadrant masks: which axis each entry's phase is closest to.
    real_pos = (phase >= -quarter) & (phase < quarter)
    real_neg = (phase >= 3 * quarter) | (phase < -3 * quarter)
    imag_pos = (phase >= quarter) & (phase < 3 * quarter)
    imag_neg = (phase >= -3 * quarter) & (phase < -quarter)

    on_real = real_pos | real_neg
    on_imag = imag_pos | imag_neg

    # Per-axis absmean scales over selected entries only.
    s_re = w_real[on_real].abs().mean() if on_real.any() else torch.tensor(0.0, device=w_real.device)
    s_im = w_imag[on_imag].abs().mean() if on_imag.any() else torch.tensor(0.0, device=w_imag.device)

    s_re = torch.clamp(s_re, min=1e-6)
    s_im = torch.clamp(s_im, min=1e-6)
    # Guard against degenerate inputs yielding NaN/Inf scales.
    if torch.isnan(s_re) or torch.isinf(s_re):
        s_re = torch.tensor(1e-6, device=w_real.device)
    if torch.isnan(s_im) or torch.isinf(s_im):
        s_im = torch.tensor(1e-6, device=w_imag.device)

    q_re = torch.zeros_like(w_real)
    q_im = torch.zeros_like(w_imag)
    q_re[real_pos] = 1.0
    q_re[real_neg] = -1.0
    q_im[imag_pos] = 1.0
    q_im[imag_neg] = -1.0

    return (q_re * s_re).to(w_real.dtype), (q_im * s_im).to(w_imag.dtype)
37
+
38
def apply_complex_inspired_quantization(model: nn.Module):
    """Apply complex-inspired quantization to real-valued model.

    Every even-dimensioned nn.Linear weight is split into a 2x2 block
    matrix, decomposed into two complex components, quantized with
    quantize_complex_tensor, and reassembled in place. Returns the model.
    """
    print("Applying complex-inspired quantization (PhaseQuant-based)...")

    @torch.no_grad()
    def _quantize_layer(layer: nn.Linear):
        weights = layer.weight.data
        rows, cols = weights.shape
        if rows % 2 != 0 or cols % 2 != 0:
            # The 2x2 block split requires even dimensions.
            print(f" -> Skipping layer (non-even dimensions): {weights.shape}")
            return

        half_r, half_c = rows // 2, cols // 2
        a11, a12 = weights[:half_r, :half_c], weights[:half_r, half_c:]
        a21, a22 = weights[half_r:, :half_c], weights[half_r:, half_c:]

        # Complex decomposition A = U + W.
        u_re, u_im = 0.5 * (a11 + a22), 0.5 * (a21 - a12)
        w_re, w_im = 0.5 * (a11 - a22), 0.5 * (a12 + a21)

        u_re_q, u_im_q = quantize_complex_tensor(u_re, u_im)
        w_re_q, w_im_q = quantize_complex_tensor(w_re, w_im)

        # Reassemble and overwrite the layer weight in place.
        rebuilt = torch.cat([
            torch.cat([w_re_q + u_re_q, w_im_q - u_im_q], dim=1),
            torch.cat([w_im_q + u_im_q, u_re_q - w_re_q], dim=1),
        ], dim=0)
        layer.weight.data = rebuilt.to(weights.dtype)

    model.apply(lambda m: _quantize_layer(m) if isinstance(m, nn.Linear) else None)
    print("Complex-inspired quantization completed.")
    return model
75
+
76
def apply_bitnet_quantization(model: nn.Module):
    """Apply BitNet 1-bit quantization to real-valued model.

    Each nn.Linear weight is centered at its mean, binarized to +/-1, and
    rescaled by its absmean. Mutates weights in place; returns the model.
    """
    print("Applying BitNet (true 1-bit, affine) quantization to real-valued model...")

    @torch.no_grad()
    def _binarize_layer(layer: nn.Linear):
        w = layer.weight.data
        scale = w.abs().mean()  # absmean magnitude
        center = w.mean()       # affine offset
        signs = torch.where(w - center > 0, 1.0, -1.0)
        layer.weight.data = signs.to(w.dtype) * scale

    model.apply(lambda m: _binarize_layer(m) if isinstance(m, nn.Linear) else None)
    print("BitNet quantization completed.")
    return model
91
+
92
def apply_bitnet_1_58bit_quantization_standard(model: nn.Module):
    """Apply BitNet 1.58-bit quantization to real-valued model (quantize to {-1, 0, +1}).

    Uses the absmean recipe: normalize by gamma = mean(|W|), round, clamp
    to [-1, 1], then rescale by gamma. Mutates in place; returns the model.
    """
    print("Applying BitNet 1.58-bit (absmean threshold) quantization to real-valued model...")

    @torch.no_grad()
    def _ternarize_layer(layer: nn.Linear):
        w = layer.weight.data
        gamma = w.abs().mean()
        # 1e-5 guards against division by zero for an all-zero weight.
        ternary = torch.clamp(torch.round(w / (gamma + 1e-5)), -1.0, 1.0)
        layer.weight.data = ternary.to(w.dtype) * gamma

    model.apply(lambda m: _ternarize_layer(m) if isinstance(m, nn.Linear) else None)
    print("BitNet 1.58-bit (absmean threshold) quantization completed.")
    return model
107
+
108
def apply_bitnet_1_58bit_quantization_variant(model: nn.Module, threshold: float = 0.5):
    """Apply BitNet 1.58-bit quantization to real-valued model (quantize to {-1, 0, +1}).

    Variant recipe: normalized weights beyond +/-`threshold` become +/-1,
    the rest become 0; all rescaled by gamma = mean(|W|). Mutates in
    place; returns the model.
    """
    print("Applying BitNet 1.58-bit (ternary) quantization to real-valued model...")

    @torch.no_grad()
    def _ternarize_layer(layer: nn.Linear):
        w = layer.weight.data
        gamma = w.abs().mean()
        normalized = w / (gamma + 1e-5)  # 1e-5 avoids division by zero
        ternary = torch.zeros_like(normalized)
        ternary[normalized > threshold] = 1.0
        ternary[normalized < -threshold] = -1.0
        layer.weight.data = ternary.to(w.dtype) * gamma

    model.apply(lambda m: _ternarize_layer(m) if isinstance(m, nn.Linear) else None)
    print("BitNet 1.58-bit quantization completed.")
    return model
125
+
126
def minmax_1bit_quantize_dequantize(w: torch.Tensor) -> torch.Tensor:
    """Apply 1-bit Min-Max quantization and dequantization to weight tensor.

    Maps every value to the nearer of the tensor's min or max. A constant
    tensor (range ~0) is returned unchanged.
    """
    lo, hi = w.min(), w.max()
    scale = (hi - lo) / 1.0  # one step between the two quantization levels
    zero_point = lo

    if abs(scale) < 1e-9:
        return w  # degenerate range: nothing to quantize

    levels = torch.round((w - zero_point) / scale)  # each entry -> 0 or 1
    return (levels * scale + zero_point).to(w.dtype)
140
+
141
def apply_minmax_1bit_quantization(model: nn.Module):
    """Apply Min-Max 1-bit quantization to real-valued model.

    Rewrites every nn.Linear weight via minmax_1bit_quantize_dequantize,
    in place. Returns the model.
    """
    print("Applying Min-Max (1-bit) quantization to real-valued model...")

    @torch.no_grad()
    def _quantize_layer(layer: nn.Linear):
        layer.weight.data = minmax_1bit_quantize_dequantize(layer.weight.data)

    model.apply(lambda m: _quantize_layer(m) if isinstance(m, nn.Linear) else None)

    print("Min-Max 1-bit quantization completed.")
    return model
153
+
154
def symmetric_minmax_1bit_quantize_dequantize(w: torch.Tensor) -> torch.Tensor:
    """Apply symmetric 1-bit Min-Max quantization to weight tensor (quantize to {-1, 1}).

    Each value is replaced by sign(w) * max(|w|). An (almost) all-zero
    tensor is returned unchanged.
    """
    scale = w.abs().max()
    if scale < 1e-9:
        return w  # degenerate: avoid division by ~0

    return ((w / scale).sign() * scale).to(w.dtype)
166
+
167
def apply_symmetric_minmax_1bit_quantization(model: nn.Module):
    """Apply symmetric Min-Max 1-bit quantization to real-valued model.

    Rewrites every nn.Linear weight via
    symmetric_minmax_1bit_quantize_dequantize, in place. Returns the model.
    """
    print("Applying symmetric Min-Max (1-bit, to {-1, 1}) quantization to real-valued model...")

    @torch.no_grad()
    def _quantize_layer(layer: nn.Linear):
        layer.weight.data = symmetric_minmax_1bit_quantize_dequantize(layer.weight.data)

    model.apply(lambda m: _quantize_layer(m) if isinstance(m, nn.Linear) else None)

    print("Symmetric Min-Max 1-bit quantization completed.")
    return model
179
+
180
class BitNetQuantSTE(torch.autograd.Function):
    """BitNet STE: quantize in forward, pass gradients in backward"""

    @staticmethod
    def forward(ctx, w):
        # Affine binarization: center at the mean, snap to +/-1, rescale by absmean.
        scale = w.abs().mean()
        signs = torch.where(w - w.mean() > 0, 1.0, -1.0).to(w.dtype)
        return signs * scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: gradient flows unchanged past the quantizer.
        return grad_output
194
+
195
+
196
class BitNet1_58QuantSTE(torch.autograd.Function):
    """BitNet 1.58-bit STE: quantize to {-1, 0, +1}, pass gradients in backward"""

    @staticmethod
    def forward(ctx, w):
        # Absmean recipe: normalize, round, clamp to the ternary set, rescale.
        gamma = w.abs().mean()
        ternary = torch.clamp(torch.round(w / (gamma + 1e-5)), -1.0, 1.0)
        return (ternary * gamma).to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: gradient flows unchanged past the quantizer.
        return grad_output
209
+
210
+
211
class PhaseQuantSTE(torch.autograd.Function):
    """Complex-Phase STE: quantize in forward, pass gradients in backward"""

    @staticmethod
    def forward(ctx, w_real, w_imag):
        phase = torch.angle(w_real + 1j * w_imag)

        quarter = math.pi / 4
        # Quadrant masks: each entry snaps to the signed axis nearest its phase.
        real_pos = (phase >= -quarter) & (phase < quarter)
        real_neg = (phase >= 3 * quarter) | (phase < -3 * quarter)
        imag_pos = (phase >= quarter) & (phase < 3 * quarter)
        imag_neg = (phase >= -3 * quarter) & (phase < -quarter)

        on_real = real_pos | real_neg
        on_imag = imag_pos | imag_neg

        # Per-axis absmean scales over the selected entries only.
        s_re = w_real[on_real].abs().mean() if on_real.any() else torch.tensor(0.0, device=w_real.device)
        s_im = w_imag[on_imag].abs().mean() if on_imag.any() else torch.tensor(0.0, device=w_imag.device)
        s_re = torch.clamp(s_re, min=1e-6)
        s_im = torch.clamp(s_im, min=1e-6)

        q_re = torch.zeros_like(w_real)
        q_im = torch.zeros_like(w_imag)
        q_re[real_pos] = 1.0
        q_re[real_neg] = -1.0
        q_im[imag_pos] = 1.0
        q_im[imag_neg] = -1.0

        return (q_re * s_re).to(w_real.dtype), (q_im * s_im).to(w_imag.dtype)

    @staticmethod
    def backward(ctx, grad_w_real, grad_w_imag):
        # Straight-through: both gradients pass unchanged.
        return grad_w_real, grad_w_imag
247
+
248
class PhaseQuantSTE_V2(torch.autograd.Function):
    """Two-step residual quantization"""

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        # First pass, then quantize the leftover error and sum both passes.
        base_re, base_im = PhaseQuantSTE.apply(w_real, w_imag)
        err_re, err_im = w_real - base_re, w_imag - base_im
        res_re, res_im = PhaseQuantSTE.apply(err_re, err_im)
        return base_re + res_re, base_im + res_im

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        # Straight-through: both gradients pass unchanged.
        return grad_real, grad_imag
264
+
265
+
266
class PhaseQuantSTE_V3(torch.autograd.Function):
    """Three-step residual quantization"""

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        # Three PhaseQuant passes, each quantizing the previous pass's error.
        total_re = torch.zeros_like(w_real)
        total_im = torch.zeros_like(w_imag)
        err_re, err_im = w_real, w_imag
        for _ in range(3):
            q_re, q_im = PhaseQuantSTE.apply(err_re, err_im)
            total_re, total_im = total_re + q_re, total_im + q_im
            err_re, err_im = err_re - q_re, err_im - q_im
        return total_re, total_im

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        # Straight-through: both gradients pass unchanged.
        return grad_real, grad_imag
285
+
286
+
287
class PhaseQuantSTE_V4(torch.autograd.Function):
    """Four-step residual quantization"""

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        # Four PhaseQuant passes, each quantizing the previous pass's error.
        total_re = torch.zeros_like(w_real)
        total_im = torch.zeros_like(w_imag)
        err_re, err_im = w_real, w_imag
        for _ in range(4):
            q_re, q_im = PhaseQuantSTE.apply(err_re, err_im)
            total_re, total_im = total_re + q_re, total_im + q_im
            err_re, err_im = err_re - q_re, err_im - q_im
        return total_re, total_im

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        # Straight-through: both gradients pass unchanged.
        return grad_real, grad_imag
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "</s>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "bos_token": "<s>",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "</s>",
34
+ "extra_special_tokens": {},
35
+ "legacy": false,
36
+ "model_max_length": 1000000000000000019884624838656,
37
+ "pad_token": "</s>",
38
+ "padding_side": "right",
39
+ "sp_model_kwargs": {},
40
+ "tokenizer_class": "LlamaTokenizer",
41
+ "unk_token": "<unk>",
42
+ "use_default_system_prompt": false
43
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9aacb37f66b32727647179ccb831ff830505e9721317432e224e6a2abb2dae7
3
+ size 6929