Upload folder using huggingface_hub
Browse files- config.json +29 -0
- generation_config.json +10 -0
- load_model.py +143 -0
- model-00001-of-00003.safetensors +3 -0
- model-00002-of-00003.safetensors +3 -0
- model-00003-of-00003.safetensors +3 -0
- model.safetensors.index.json +298 -0
- qat_modules.py +388 -0
- quantization.py +308 -0
- special_tokens_map.json +24 -0
- tokenizer.json +0 -0
- tokenizer_config.json +43 -0
- training_args.bin +3 -0
config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"LlamaForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 1,
|
| 8 |
+
"eos_token_id": 2,
|
| 9 |
+
"head_dim": 128,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"hidden_size": 4096,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 11008,
|
| 14 |
+
"max_position_embeddings": 4096,
|
| 15 |
+
"mlp_bias": false,
|
| 16 |
+
"model_type": "llama",
|
| 17 |
+
"num_attention_heads": 32,
|
| 18 |
+
"num_hidden_layers": 32,
|
| 19 |
+
"num_key_value_heads": 32,
|
| 20 |
+
"pretraining_tp": 1,
|
| 21 |
+
"rms_norm_eps": 1e-05,
|
| 22 |
+
"rope_scaling": null,
|
| 23 |
+
"rope_theta": 10000.0,
|
| 24 |
+
"tie_word_embeddings": false,
|
| 25 |
+
"torch_dtype": "bfloat16",
|
| 26 |
+
"transformers_version": "4.52.4",
|
| 27 |
+
"use_cache": true,
|
| 28 |
+
"vocab_size": 32000
|
| 29 |
+
}
|
generation_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 1,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": 2,
|
| 5 |
+
"max_length": 4096,
|
| 6 |
+
"pad_token_id": 0,
|
| 7 |
+
"temperature": 0.6,
|
| 8 |
+
"top_p": 0.9,
|
| 9 |
+
"transformers_version": "4.52.4"
|
| 10 |
+
}
|
load_model.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Model loading script
|
| 3 |
+
|
| 4 |
+
Load Fairy2i-W2 model from Hugging Face repository.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
from load_model import load_model
|
| 8 |
+
model, tokenizer = load_model()
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 13 |
+
from safetensors.torch import load_file
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
|
| 18 |
+
# Add current directory to path for importing qat_modules
|
| 19 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 20 |
+
if current_dir not in sys.path:
|
| 21 |
+
sys.path.insert(0, current_dir)
|
| 22 |
+
|
| 23 |
+
from qat_modules import replace_modules_for_qat
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_model(
    device="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16
):
    """
    Load the Fairy2i-W2 model: standard Llama architecture + custom weights
    + QAT linear-layer replacement.

    Weights and tokenizer are fetched from the Hugging Face Hub.

    Args:
        device: Target device; defaults to "cuda" when available, else "cpu".
        torch_dtype: Parameter dtype, default torch.bfloat16.

    Returns:
        (model, tokenizer) tuple.

    Raises:
        RuntimeError: if neither sharded nor single-file weights can be
            downloaded from the weights repository.
    """
    import json

    # Configuration parameters
    base_model_id = "meta-llama/Llama-2-7b-hf"
    weights_repo_id = "PKU-DS-LAB/Fairy2i-W2"
    quant_method = "complex_phase_v2"
    skip_lm_head = False

    print("=" * 70)
    print("Loading Fairy2i-W2 Model")
    print("=" * 70)

    # Step 1: Load standard model architecture (parameters are overwritten
    # by the custom checkpoint in Step 2).
    print(f"\n📥 Step 1/4: Loading standard model architecture: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=torch_dtype,
        device_map=device,
        trust_remote_code=False,
    )
    print("✅ Standard model architecture loaded")

    # Step 2: Load custom weights. Only the *index download* is allowed to
    # fail over to the single-file path; a failure while reading shards or
    # applying the state dict is a real error and must propagate instead of
    # being masked by a confusing single-file retry.
    print(f"\n💾 Step 2/4: Loading weights from Hugging Face repository: {weights_repo_id}")
    try:
        index_path = hf_hub_download(
            repo_id=weights_repo_id,
            filename="model.safetensors.index.json",
        )
    except Exception:
        # No shard index in the repo — fall back to a single weights file.
        index_path = None

    if index_path is not None:
        # Sharded weights: read every shard listed in the index.
        from safetensors import safe_open

        with open(index_path, "r") as index_file:
            weight_map = json.load(index_file)["weight_map"]

        state_dict = {}
        # sorted() gives a deterministic download/load order.
        for weight_file in sorted(set(weight_map.values())):
            file_path = hf_hub_download(
                repo_id=weights_repo_id,
                filename=weight_file,
            )
            with safe_open(file_path, framework="pt", device="cpu") as shard:
                for key in shard.keys():
                    state_dict[key] = shard.get_tensor(key)

        load_result = model.load_state_dict(state_dict, strict=False)
        print("✅ Weights loaded (sharded)")
    else:
        # Single weight file
        try:
            weights_path = hf_hub_download(
                repo_id=weights_repo_id,
                filename="model.safetensors",
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load weights from Hugging Face: {e}") from e
        state_dict = load_file(weights_path)
        load_result = model.load_state_dict(state_dict, strict=False)
        print("✅ Weights loaded (single file)")

    # strict=False hides key mismatches; report them so a silently-partial
    # load is visible to the user.
    if load_result.missing_keys:
        print(f"⚠️  {len(load_result.missing_keys)} keys missing from checkpoint")
    if load_result.unexpected_keys:
        print(f"⚠️  {len(load_result.unexpected_keys)} unexpected keys ignored")

    # Step 3: Apply QAT replacement (swaps Linear layers for quantized ones).
    print(f"\n🔧 Step 3/4: Applying QAT replacement ({quant_method})...")
    replace_modules_for_qat(model, quant_method, skip_lm_head=skip_lm_head)
    print("✅ QAT replacement completed")

    # Step 4: Load tokenizer from the same weights repository.
    print(f"\n📝 Step 4/4: Loading Tokenizer from Hugging Face repository: {weights_repo_id}")
    tokenizer = AutoTokenizer.from_pretrained(weights_repo_id)
    print("✅ Tokenizer loaded")

    print("\n" + "=" * 70)
    print("✅ Model loading completed!")
    print("=" * 70)

    return model, tokenizer
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
    # Example usage: load the model, then run a short sanity-check generation.
    model, tokenizer = load_model()

    print("\n🧪 Testing generation...")
    prompt = "Hello, how are you?"
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to(model.device)

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
        )

    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(f"Prompt: {prompt}")
    print(f"Response: {decoded}")
| 143 |
+
|
model-00001-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7faa5a86a5d9968018d5ec54582e4c8d953541fdda69dd3855cad02b0ba62f8c
|
| 3 |
+
size 4938985352
|
model-00002-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c10a92f9aa154ead36f02411ed284ac5a4abcba038d43c3d3b4599b3c4fd132e
|
| 3 |
+
size 4947390880
|
model-00003-of-00003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33494b29047ea5f2696e949935b46b7e9372ec5de81d9718656ae8237766bf2d
|
| 3 |
+
size 3590488816
|
model.safetensors.index.json
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"metadata": {
|
| 3 |
+
"total_size": 13476831232
|
| 4 |
+
},
|
| 5 |
+
"weight_map": {
|
| 6 |
+
"lm_head.weight": "model-00003-of-00003.safetensors",
|
| 7 |
+
"model.embed_tokens.weight": "model-00001-of-00003.safetensors",
|
| 8 |
+
"model.layers.0.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 9 |
+
"model.layers.0.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 10 |
+
"model.layers.0.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 11 |
+
"model.layers.0.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 12 |
+
"model.layers.0.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 13 |
+
"model.layers.0.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 14 |
+
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 15 |
+
"model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 16 |
+
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 17 |
+
"model.layers.1.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 18 |
+
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 19 |
+
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 20 |
+
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 21 |
+
"model.layers.1.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 22 |
+
"model.layers.1.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 23 |
+
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 24 |
+
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 25 |
+
"model.layers.1.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 26 |
+
"model.layers.10.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 27 |
+
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 28 |
+
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 29 |
+
"model.layers.10.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 30 |
+
"model.layers.10.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 31 |
+
"model.layers.10.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 32 |
+
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 33 |
+
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 34 |
+
"model.layers.10.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 35 |
+
"model.layers.11.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 36 |
+
"model.layers.11.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 37 |
+
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 38 |
+
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 39 |
+
"model.layers.11.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 40 |
+
"model.layers.11.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 41 |
+
"model.layers.11.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 42 |
+
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 43 |
+
"model.layers.11.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 44 |
+
"model.layers.12.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 45 |
+
"model.layers.12.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 46 |
+
"model.layers.12.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 47 |
+
"model.layers.12.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 48 |
+
"model.layers.12.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 49 |
+
"model.layers.12.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 50 |
+
"model.layers.12.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 51 |
+
"model.layers.12.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 52 |
+
"model.layers.12.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 53 |
+
"model.layers.13.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 54 |
+
"model.layers.13.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 55 |
+
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 56 |
+
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 57 |
+
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 58 |
+
"model.layers.13.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 59 |
+
"model.layers.13.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 60 |
+
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 61 |
+
"model.layers.13.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 62 |
+
"model.layers.14.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 63 |
+
"model.layers.14.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 64 |
+
"model.layers.14.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 65 |
+
"model.layers.14.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 66 |
+
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 67 |
+
"model.layers.14.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 68 |
+
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 69 |
+
"model.layers.14.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 70 |
+
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 71 |
+
"model.layers.15.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 72 |
+
"model.layers.15.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 73 |
+
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 74 |
+
"model.layers.15.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 75 |
+
"model.layers.15.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 76 |
+
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 77 |
+
"model.layers.15.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 78 |
+
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 79 |
+
"model.layers.15.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 80 |
+
"model.layers.16.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 81 |
+
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 82 |
+
"model.layers.16.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 83 |
+
"model.layers.16.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 84 |
+
"model.layers.16.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 85 |
+
"model.layers.16.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 86 |
+
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 87 |
+
"model.layers.16.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 88 |
+
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 89 |
+
"model.layers.17.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 90 |
+
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 91 |
+
"model.layers.17.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 92 |
+
"model.layers.17.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 93 |
+
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 94 |
+
"model.layers.17.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 95 |
+
"model.layers.17.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 96 |
+
"model.layers.17.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 97 |
+
"model.layers.17.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 98 |
+
"model.layers.18.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 99 |
+
"model.layers.18.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 100 |
+
"model.layers.18.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 101 |
+
"model.layers.18.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 102 |
+
"model.layers.18.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 103 |
+
"model.layers.18.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 104 |
+
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 105 |
+
"model.layers.18.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 106 |
+
"model.layers.18.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 107 |
+
"model.layers.19.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 108 |
+
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 109 |
+
"model.layers.19.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 110 |
+
"model.layers.19.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 111 |
+
"model.layers.19.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 112 |
+
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 113 |
+
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 114 |
+
"model.layers.19.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 115 |
+
"model.layers.19.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 116 |
+
"model.layers.2.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 117 |
+
"model.layers.2.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 118 |
+
"model.layers.2.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 119 |
+
"model.layers.2.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 120 |
+
"model.layers.2.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 121 |
+
"model.layers.2.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 122 |
+
"model.layers.2.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 123 |
+
"model.layers.2.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 124 |
+
"model.layers.2.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 125 |
+
"model.layers.20.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 126 |
+
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 127 |
+
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 128 |
+
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 129 |
+
"model.layers.20.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 130 |
+
"model.layers.20.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 131 |
+
"model.layers.20.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 132 |
+
"model.layers.20.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 133 |
+
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 134 |
+
"model.layers.21.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 135 |
+
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 136 |
+
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 137 |
+
"model.layers.21.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 138 |
+
"model.layers.21.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 139 |
+
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 140 |
+
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 141 |
+
"model.layers.21.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 142 |
+
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 143 |
+
"model.layers.22.input_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 144 |
+
"model.layers.22.mlp.down_proj.weight": "model-00002-of-00003.safetensors",
|
| 145 |
+
"model.layers.22.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 146 |
+
"model.layers.22.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 147 |
+
"model.layers.22.post_attention_layernorm.weight": "model-00002-of-00003.safetensors",
|
| 148 |
+
"model.layers.22.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 149 |
+
"model.layers.22.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 150 |
+
"model.layers.22.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 151 |
+
"model.layers.22.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 152 |
+
"model.layers.23.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 153 |
+
"model.layers.23.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 154 |
+
"model.layers.23.mlp.gate_proj.weight": "model-00002-of-00003.safetensors",
|
| 155 |
+
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00003.safetensors",
|
| 156 |
+
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 157 |
+
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00003.safetensors",
|
| 158 |
+
"model.layers.23.self_attn.o_proj.weight": "model-00002-of-00003.safetensors",
|
| 159 |
+
"model.layers.23.self_attn.q_proj.weight": "model-00002-of-00003.safetensors",
|
| 160 |
+
"model.layers.23.self_attn.v_proj.weight": "model-00002-of-00003.safetensors",
|
| 161 |
+
"model.layers.24.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 162 |
+
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 163 |
+
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 164 |
+
"model.layers.24.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 165 |
+
"model.layers.24.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 166 |
+
"model.layers.24.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 167 |
+
"model.layers.24.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 168 |
+
"model.layers.24.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 169 |
+
"model.layers.24.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 170 |
+
"model.layers.25.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 171 |
+
"model.layers.25.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 172 |
+
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 173 |
+
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 174 |
+
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 175 |
+
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 176 |
+
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 177 |
+
"model.layers.25.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 178 |
+
"model.layers.25.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 179 |
+
"model.layers.26.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 180 |
+
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 181 |
+
"model.layers.26.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 182 |
+
"model.layers.26.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 183 |
+
"model.layers.26.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 184 |
+
"model.layers.26.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 185 |
+
"model.layers.26.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 186 |
+
"model.layers.26.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 187 |
+
"model.layers.26.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 188 |
+
"model.layers.27.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 189 |
+
"model.layers.27.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 190 |
+
"model.layers.27.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 191 |
+
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 192 |
+
"model.layers.27.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 193 |
+
"model.layers.27.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 194 |
+
"model.layers.27.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 195 |
+
"model.layers.27.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 196 |
+
"model.layers.27.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 197 |
+
"model.layers.28.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 198 |
+
"model.layers.28.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 199 |
+
"model.layers.28.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 200 |
+
"model.layers.28.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 201 |
+
"model.layers.28.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 202 |
+
"model.layers.28.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 203 |
+
"model.layers.28.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 204 |
+
"model.layers.28.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 205 |
+
"model.layers.28.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 206 |
+
"model.layers.29.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 207 |
+
"model.layers.29.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 208 |
+
"model.layers.29.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 209 |
+
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 210 |
+
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 211 |
+
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 212 |
+
"model.layers.29.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 213 |
+
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 214 |
+
"model.layers.29.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 215 |
+
"model.layers.3.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 216 |
+
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 217 |
+
"model.layers.3.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 218 |
+
"model.layers.3.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 219 |
+
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 220 |
+
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 221 |
+
"model.layers.3.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 222 |
+
"model.layers.3.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 223 |
+
"model.layers.3.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 224 |
+
"model.layers.30.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 225 |
+
"model.layers.30.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 226 |
+
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 227 |
+
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 228 |
+
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 229 |
+
"model.layers.30.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 230 |
+
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 231 |
+
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 232 |
+
"model.layers.30.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 233 |
+
"model.layers.31.input_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 234 |
+
"model.layers.31.mlp.down_proj.weight": "model-00003-of-00003.safetensors",
|
| 235 |
+
"model.layers.31.mlp.gate_proj.weight": "model-00003-of-00003.safetensors",
|
| 236 |
+
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00003.safetensors",
|
| 237 |
+
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00003.safetensors",
|
| 238 |
+
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00003.safetensors",
|
| 239 |
+
"model.layers.31.self_attn.o_proj.weight": "model-00003-of-00003.safetensors",
|
| 240 |
+
"model.layers.31.self_attn.q_proj.weight": "model-00003-of-00003.safetensors",
|
| 241 |
+
"model.layers.31.self_attn.v_proj.weight": "model-00003-of-00003.safetensors",
|
| 242 |
+
"model.layers.4.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 243 |
+
"model.layers.4.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 244 |
+
"model.layers.4.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 245 |
+
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 246 |
+
"model.layers.4.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 247 |
+
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 248 |
+
"model.layers.4.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 249 |
+
"model.layers.4.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 250 |
+
"model.layers.4.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 251 |
+
"model.layers.5.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 252 |
+
"model.layers.5.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 253 |
+
"model.layers.5.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 254 |
+
"model.layers.5.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 255 |
+
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 256 |
+
"model.layers.5.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 257 |
+
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 258 |
+
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 259 |
+
"model.layers.5.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 260 |
+
"model.layers.6.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 261 |
+
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 262 |
+
"model.layers.6.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 263 |
+
"model.layers.6.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 264 |
+
"model.layers.6.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 265 |
+
"model.layers.6.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 266 |
+
"model.layers.6.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 267 |
+
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 268 |
+
"model.layers.6.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 269 |
+
"model.layers.7.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 270 |
+
"model.layers.7.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 271 |
+
"model.layers.7.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 272 |
+
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 273 |
+
"model.layers.7.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 274 |
+
"model.layers.7.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 275 |
+
"model.layers.7.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 276 |
+
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 277 |
+
"model.layers.7.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 278 |
+
"model.layers.8.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 279 |
+
"model.layers.8.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 280 |
+
"model.layers.8.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 281 |
+
"model.layers.8.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 282 |
+
"model.layers.8.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 283 |
+
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 284 |
+
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 285 |
+
"model.layers.8.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 286 |
+
"model.layers.8.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 287 |
+
"model.layers.9.input_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 288 |
+
"model.layers.9.mlp.down_proj.weight": "model-00001-of-00003.safetensors",
|
| 289 |
+
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00003.safetensors",
|
| 290 |
+
"model.layers.9.mlp.up_proj.weight": "model-00001-of-00003.safetensors",
|
| 291 |
+
"model.layers.9.post_attention_layernorm.weight": "model-00001-of-00003.safetensors",
|
| 292 |
+
"model.layers.9.self_attn.k_proj.weight": "model-00001-of-00003.safetensors",
|
| 293 |
+
"model.layers.9.self_attn.o_proj.weight": "model-00001-of-00003.safetensors",
|
| 294 |
+
"model.layers.9.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
|
| 295 |
+
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00003.safetensors",
|
| 296 |
+
"model.norm.weight": "model-00003-of-00003.safetensors"
|
| 297 |
+
}
|
| 298 |
+
}
|
qat_modules.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from quantization import BitNetQuantSTE, PhaseQuantSTE, PhaseQuantSTE_V2, PhaseQuantSTE_V3, PhaseQuantSTE_V4
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
class QATLinearBitNet(nn.Linear):
    """Linear layer with BitNet fake-quantization during training.

    The trainable parameter stays full precision; every forward pass routes
    it through ``BitNetQuantSTE`` (a straight-through estimator), so the
    layer computes with quantized weights while gradients flow back to the
    latent full-precision copy.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        # Fake-quantize on the fly; the stored parameter is never mutated.
        return F.linear(x, BitNetQuantSTE.apply(self.weight), self.bias)
|
| 15 |
+
|
| 16 |
+
class QATLinearComplexPhaseV1(nn.Linear):
    """Complex-Phase V1 QAT linear layer.

    The real weight matrix A is viewed as a 2x2 block matrix and decomposed
    into two complex matrices U and W; each complex part is fake-quantized
    with ``PhaseQuantSTE`` and the quantized blocks are reassembled into a
    real matrix of the original shape before the linear transform.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        A = self.weight
        n, m = A.shape[0] // 2, A.shape[1] // 2
        # Quadrants of the 2x2 block view of A.
        A11, A12 = A[:n, :m], A[:n, m:]
        A21, A22 = A[n:, :m], A[n:, m:]

        # Complex decomposition of the block matrix.
        U_re, U_im = 0.5 * (A11 + A22), 0.5 * (A21 - A12)
        W_re, W_im = 0.5 * (A11 - A22), 0.5 * (A12 + A21)

        # Fake-quantize both complex components via the straight-through estimator.
        U_re_q, U_im_q = PhaseQuantSTE.apply(U_re, U_im)
        W_re_q, W_im_q = PhaseQuantSTE.apply(W_re, W_im)

        # Reassemble the quantized quadrants into a real weight matrix.
        top = torch.cat([W_re_q + U_re_q, W_im_q - U_im_q], dim=1)
        bottom = torch.cat([W_im_q + U_im_q, -W_re_q + U_re_q], dim=1)
        return F.linear(x, torch.cat([top, bottom], dim=0), self.bias)
|
| 47 |
+
|
| 48 |
+
class QATLinearComplexPhaseV2(nn.Linear):
    """Complex-Phase V2 QAT linear layer (1-step residual).

    Same block/complex decomposition as V1, but the components are
    fake-quantized with ``PhaseQuantSTE_V2``, which adds one residual
    refinement pass.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        A = self.weight
        n, m = A.shape[0] // 2, A.shape[1] // 2
        # Quadrants of the 2x2 block view of A.
        A11, A12 = A[:n, :m], A[:n, m:]
        A21, A22 = A[n:, :m], A[n:, m:]

        # Complex decomposition of the block matrix.
        U_re, U_im = 0.5 * (A11 + A22), 0.5 * (A21 - A12)
        W_re, W_im = 0.5 * (A11 - A22), 0.5 * (A12 + A21)

        # Fake-quantize with the 1-step-residual STE.
        U_re_q, U_im_q = PhaseQuantSTE_V2.apply(U_re, U_im)
        W_re_q, W_im_q = PhaseQuantSTE_V2.apply(W_re, W_im)

        # Reassemble the quantized quadrants into a real weight matrix.
        top = torch.cat([W_re_q + U_re_q, W_im_q - U_im_q], dim=1)
        bottom = torch.cat([W_im_q + U_im_q, -W_re_q + U_re_q], dim=1)
        return F.linear(x, torch.cat([top, bottom], dim=0), self.bias)
|
| 79 |
+
|
| 80 |
+
class QATLinearComplexPhaseV3(nn.Linear):
    """Complex-Phase V3 QAT linear layer (2-step residual).

    Same block/complex decomposition as V1, but the components are
    fake-quantized with ``PhaseQuantSTE_V3``, which adds two residual
    refinement passes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        A = self.weight
        n, m = A.shape[0] // 2, A.shape[1] // 2
        # Quadrants of the 2x2 block view of A.
        A11, A12 = A[:n, :m], A[:n, m:]
        A21, A22 = A[n:, :m], A[n:, m:]

        # Complex decomposition of the block matrix.
        U_re, U_im = 0.5 * (A11 + A22), 0.5 * (A21 - A12)
        W_re, W_im = 0.5 * (A11 - A22), 0.5 * (A12 + A21)

        # Fake-quantize with the 2-step-residual STE.
        U_re_q, U_im_q = PhaseQuantSTE_V3.apply(U_re, U_im)
        W_re_q, W_im_q = PhaseQuantSTE_V3.apply(W_re, W_im)

        # Reassemble the quantized quadrants into a real weight matrix.
        top = torch.cat([W_re_q + U_re_q, W_im_q - U_im_q], dim=1)
        bottom = torch.cat([W_im_q + U_im_q, -W_re_q + U_re_q], dim=1)
        return F.linear(x, torch.cat([top, bottom], dim=0), self.bias)
|
| 111 |
+
|
| 112 |
+
class QATLinearComplexPhaseV4(nn.Linear):
    """Complex-Phase V4 QAT linear layer (3-step residual).

    Same block/complex decomposition as V1, but the components are
    fake-quantized with ``PhaseQuantSTE_V4``, which adds three residual
    refinement passes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase QAT requires even in/out features for Linear layers.")

    def forward(self, x):
        A = self.weight
        n, m = A.shape[0] // 2, A.shape[1] // 2
        # Quadrants of the 2x2 block view of A.
        A11, A12 = A[:n, :m], A[:n, m:]
        A21, A22 = A[n:, :m], A[n:, m:]

        # Complex decomposition of the block matrix.
        U_re, U_im = 0.5 * (A11 + A22), 0.5 * (A21 - A12)
        W_re, W_im = 0.5 * (A11 - A22), 0.5 * (A12 + A21)

        # Fake-quantize with the 3-step-residual STE.
        U_re_q, U_im_q = PhaseQuantSTE_V4.apply(U_re, U_im)
        W_re_q, W_im_q = PhaseQuantSTE_V4.apply(W_re, W_im)

        # Reassemble the quantized quadrants into a real weight matrix.
        top = torch.cat([W_re_q + U_re_q, W_im_q - U_im_q], dim=1)
        bottom = torch.cat([W_im_q + U_im_q, -W_re_q + U_re_q], dim=1)
        return F.linear(x, torch.cat([top, bottom], dim=0), self.bias)
|
| 143 |
+
|
| 144 |
+
# Registry mapping method names (as passed on the CLI / in configs) to the
# QAT linear-layer class that implements them. Consumed by
# replace_modules_for_qat(); its error message lists these keys, so the
# insertion order here is user-visible.
METHOD_MAP = {
    'bitnet': QATLinearBitNet,
    'complex_phase_v1': QATLinearComplexPhaseV1,
    'complex_phase_v2': QATLinearComplexPhaseV2,
    'complex_phase_v3': QATLinearComplexPhaseV3,
    'complex_phase_v4': QATLinearComplexPhaseV4,
}
|
| 151 |
+
|
| 152 |
+
def replace_modules_for_qat(model: nn.Module, method: str, skip_lm_head: bool = False):
    """Recursively replace nn.Linear layers in the model with QAT layers.

    Args:
        model: Module tree rewritten in place.
        method: Key into METHOD_MAP selecting the QAT linear class.
        skip_lm_head: When True, a child literally named 'lm_head' is kept
            as a plain nn.Linear.
    """
    if method not in METHOD_MAP:
        raise ValueError(f"Unknown method: {method}. Available methods: {list(METHOD_MAP.keys())}")

    qat_cls = METHOD_MAP[method]

    for name, module in model.named_children():
        # Depth-first: rewrite the subtree before deciding about this child.
        if len(list(module.children())) > 0:
            replace_modules_for_qat(module, method, skip_lm_head)

        if not isinstance(module, nn.Linear):
            continue

        if skip_lm_head and name == 'lm_head':
            print(f" -> Skipping lm_head layer (skip_lm_head=True)")
            continue

        # Complex-phase methods split the weight into 2x2 blocks, so both
        # dimensions must be even.
        if 'complex_phase' in method and (
            module.in_features % 2 != 0 or module.out_features % 2 != 0
        ):
            print(f" -> Skipping Complex-Phase replacement (non-even dimensions): {name} ({module.in_features}, {module.out_features})")
            continue

        print(f" -> Replacing layer: {name} with {qat_cls.__name__}")
        replacement = qat_cls(
            module.in_features,
            module.out_features,
            bias=module.bias is not None,
            dtype=module.weight.dtype,
            device=module.weight.device
        )
        # Carry over the trained full-precision parameters.
        replacement.weight.data.copy_(module.weight.data)
        if module.bias is not None:
            replacement.bias.data.copy_(module.bias.data)

        setattr(model, name, replacement)
|
| 186 |
+
|
| 187 |
+
class InferenceOptimizedBitNet(nn.Linear):
    """Inference-optimized BitNet linear layer.

    On the first forward pass the stored weight is binarized in place
    (sign of the mean-centered weight, rescaled by the mean absolute
    value), so no full-precision copy is kept and later passes reuse the
    already-quantized tensor.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._is_quantized = False  # one-shot guard for the in-place rewrite

    def _ensure_quantized(self):
        """Binarize the weight in place; runs at most once."""
        if self._is_quantized:
            return
        with torch.no_grad():
            w = self.weight
            scale = w.abs().mean()  # per-tensor magnitude
            shift = w.mean()        # affine centering before the sign
            signs = torch.where(w - shift > 0, 1.0, -1.0).to(w.dtype)
            self.weight.data = signs * scale
        self._is_quantized = True

    def forward(self, x):
        self._ensure_quantized()
        return F.linear(x, self.weight, self.bias)
|
| 209 |
+
|
| 210 |
+
class InferenceOptimizedComplexPhase(nn.Linear):
    """Inference-optimized Complex Phase linear layer, supports V1-V4.

    On the first forward pass the real weight matrix is decomposed (as a
    2x2 block matrix) into two complex matrices, each quantized with
    PhaseQuant plus 0-3 residual refinement passes depending on the
    version, and the result is written back over the weight in place so
    no full-precision copy is retained.
    """
    def __init__(self, version="v1", *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.in_features % 2 != 0 or self.out_features % 2 != 0:
            raise ValueError("Complex-Phase requires even in/out features.")
        # One-shot guard: the in-place quantization must run exactly once.
        self._is_quantized = False
        self._version = version.lower()
        if self._version not in ["v1", "v2", "v3", "v4"]:
            raise ValueError(f"Unsupported version: {version}. Must be one of ['v1', 'v2', 'v3', 'v4']")

    def _ensure_quantized(self):
        """Ensure weights are quantized, executed only once"""
        if not self._is_quantized:
            with torch.no_grad():
                A = self.weight
                n, m = A.shape[0] // 2, A.shape[1] // 2
                # Quadrants of the 2x2 block view of A.
                A11, A12 = A[:n, :m], A[:n, m:]
                A21, A22 = A[n:, :m], A[n:, m:]

                # Complex decomposition of the block matrix.
                U_re = 0.5 * (A11 + A22)
                U_im = 0.5 * (A21 - A12)
                W_re = 0.5 * (A11 - A22)
                W_im = 0.5 * (A12 + A21)

                # Dispatch on version: v2-v4 add 1-3 residual passes.
                if self._version == "v1":
                    U_re_q, U_im_q = self._phase_quant_v1(U_re, U_im)
                    W_re_q, W_im_q = self._phase_quant_v1(W_re, W_im)
                elif self._version == "v2":
                    U_re_q, U_im_q = self._phase_quant_v2(U_re, U_im)
                    W_re_q, W_im_q = self._phase_quant_v2(W_re, W_im)
                elif self._version == "v3":
                    U_re_q, U_im_q = self._phase_quant_v3(U_re, U_im)
                    W_re_q, W_im_q = self._phase_quant_v3(W_re, W_im)
                elif self._version == "v4":
                    U_re_q, U_im_q = self._phase_quant_v4(U_re, U_im)
                    W_re_q, W_im_q = self._phase_quant_v4(W_re, W_im)

                # Reassemble the quantized quadrants into a real matrix.
                A11_q = W_re_q + U_re_q
                A12_q = W_im_q - U_im_q
                A21_q = W_im_q + U_im_q
                A22_q = -W_re_q + U_re_q

                A_quant_top = torch.cat([A11_q, A12_q], dim=1)
                A_quant_bottom = torch.cat([A21_q, A22_q], dim=1)
                A_quant = torch.cat([A_quant_top, A_quant_bottom], dim=0)

                # Permanent in-place replacement of the weight.
                self.weight.data = A_quant
            self._is_quantized = True

    def _phase_quant_v1(self, w_real, w_imag):
        """V1: Basic PhaseQuant"""
        # Snap each complex weight to the nearest axis (+1, -1, +i, -i)
        # based on its phase angle.
        phase = torch.angle(w_real + 1j * w_imag)

        real_pos = (phase >= -math.pi / 4) & (phase < math.pi / 4)
        real_neg = (phase >= 3 * math.pi / 4) | (phase < -3 * math.pi / 4)
        imag_pos = (phase >= math.pi / 4) & (phase < 3 * math.pi / 4)
        imag_neg = (phase >= -3 * math.pi / 4) & (phase < -math.pi / 4)

        mask_real = real_pos | real_neg
        mask_imag = imag_pos | imag_neg

        # Per-axis scale: mean |w| over the weights assigned to that axis;
        # falls back to 0 (then clamped) when an axis receives no weights.
        s_re = w_real[mask_real].abs().mean() if mask_real.any() else torch.tensor(0.0, device=w_real.device)
        s_im = w_imag[mask_imag].abs().mean() if mask_imag.any() else torch.tensor(0.0, device=w_imag.device)

        s_re = torch.clamp(s_re, min=1e-6)
        s_im = torch.clamp(s_im, min=1e-6)

        qw_real = torch.zeros_like(w_real)
        qw_imag = torch.zeros_like(w_imag)

        qw_real[real_pos] = 1.0
        qw_real[real_neg] = -1.0
        qw_imag[imag_pos] = 1.0
        qw_imag[imag_neg] = -1.0

        return qw_real * s_re, qw_imag * s_im

    def _phase_quant_v2(self, w_real, w_imag):
        """V2: 1-step residual quantization"""
        # Quantize, then quantize the remaining error once and sum.
        qw_real_o1, qw_imag_o1 = self._phase_quant_v1(w_real, w_imag)
        error_real = w_real - qw_real_o1
        error_imag = w_imag - qw_imag_o1
        qw_real_o2, qw_imag_o2 = self._phase_quant_v1(error_real, error_imag)
        qw_real = qw_real_o1 + qw_real_o2
        qw_imag = qw_imag_o1 + qw_imag_o2
        return qw_real, qw_imag

    def _phase_quant_v3(self, w_real, w_imag):
        """V3: 2-step residual quantization"""
        # Two successive error-refinement passes on top of the base quantizer.
        qw_real_o1, qw_imag_o1 = self._phase_quant_v1(w_real, w_imag)
        error_real_1 = w_real - qw_real_o1
        error_imag_1 = w_imag - qw_imag_o1
        qw_real_o2, qw_imag_o2 = self._phase_quant_v1(error_real_1, error_imag_1)
        error_real_2 = error_real_1 - qw_real_o2
        error_imag_2 = error_imag_1 - qw_imag_o2
        qw_real_o3, qw_imag_o3 = self._phase_quant_v1(error_real_2, error_imag_2)
        qw_real = qw_real_o1 + qw_real_o2 + qw_real_o3
        qw_imag = qw_imag_o1 + qw_imag_o2 + qw_imag_o3
        return qw_real, qw_imag

    def _phase_quant_v4(self, w_real, w_imag):
        """V4: 3-step residual quantization"""
        # Three successive error-refinement passes on top of the base quantizer.
        qw_real_o1, qw_imag_o1 = self._phase_quant_v1(w_real, w_imag)
        error_real_1 = w_real - qw_real_o1
        error_imag_1 = w_imag - qw_imag_o1
        qw_real_o2, qw_imag_o2 = self._phase_quant_v1(error_real_1, error_imag_1)
        error_real_2 = error_real_1 - qw_real_o2
        error_imag_2 = error_imag_1 - qw_imag_o2
        qw_real_o3, qw_imag_o3 = self._phase_quant_v1(error_real_2, error_imag_2)
        error_real_3 = error_real_2 - qw_real_o3
        error_imag_3 = error_imag_2 - qw_imag_o3
        qw_real_o4, qw_imag_o4 = self._phase_quant_v1(error_real_3, error_imag_3)
        qw_real = qw_real_o1 + qw_real_o2 + qw_real_o3 + qw_real_o4
        qw_imag = qw_imag_o1 + qw_imag_o2 + qw_imag_o3 + qw_imag_o4
        return qw_real, qw_imag

    def forward(self, x):
        # Lazily quantize on first use, then run a plain linear transform.
        self._ensure_quantized()
        return F.linear(x, self.weight, self.bias)
|
| 330 |
+
|
| 331 |
+
def convert_to_inference_mode(model):
    """Convert QAT modules to inference-optimized versions.

    Walks the module tree, swapping each QAT linear layer for its
    inference-optimized counterpart (weights copied over). This permanently
    modifies the model's weights once the new layers quantize on first use.
    Returns the (mutated) model.
    """
    converted_count = 0

    # Ordered (class, version) pairs; checked with isinstance in this order.
    _phase_versions = (
        (QATLinearComplexPhaseV1, "v1"),
        (QATLinearComplexPhaseV2, "v2"),
        (QATLinearComplexPhaseV3, "v3"),
        (QATLinearComplexPhaseV4, "v4"),
    )

    def _copy_params(dst, src):
        """Copy weight (and bias, if present) from src into dst."""
        dst.weight.data.copy_(src.weight.data)
        if src.bias is not None:
            dst.bias.data.copy_(src.bias.data)
        return dst

    def _convert_module(module, name_path=""):
        nonlocal converted_count

        for name, child in list(module.named_children()):
            full_name = f"{name_path}.{name}" if name_path else name

            if isinstance(child, QATLinearBitNet):
                replacement = _copy_params(
                    InferenceOptimizedBitNet(
                        child.in_features,
                        child.out_features,
                        bias=child.bias is not None,
                        device=child.weight.device,
                        dtype=child.weight.dtype
                    ),
                    child,
                )
                setattr(module, name, replacement)
                converted_count += 1
                print(f" -> Converting BitNet layer: {full_name}")
                continue

            version = None
            for phase_cls, ver in _phase_versions:
                if isinstance(child, phase_cls):
                    version = ver
                    break

            if version is not None:
                replacement = _copy_params(
                    InferenceOptimizedComplexPhase(
                        version=version,
                        in_features=child.in_features,
                        out_features=child.out_features,
                        bias=child.bias is not None,
                        device=child.weight.device,
                        dtype=child.weight.dtype
                    ),
                    child,
                )
                setattr(module, name, replacement)
                converted_count += 1
                print(f" -> Converting ComplexPhase{version.upper()} layer: {full_name}")
            else:
                # Not a QAT layer: recurse into the subtree.
                _convert_module(child, full_name)

    _convert_module(model)
    print(f"Converted {converted_count} QAT layers to inference-optimized version")
    return model
|
quantization.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
@torch.no_grad()
def quantize_complex_tensor(w_real: torch.Tensor, w_imag: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply PhaseQuant logic to complex weight tensors.

    Each complex weight is snapped to the nearest axis direction
    (+1, -1, +i, or -i) according to its phase; the surviving component is
    rescaled by the mean magnitude of the weights assigned to that axis.
    Returns the quantized (real, imaginary) pair in the input dtypes.
    """
    theta = torch.angle(w_real + 1j * w_imag)

    quarter = math.pi / 4
    three_quarters = 3 * math.pi / 4
    on_pos_real = (theta >= -quarter) & (theta < quarter)
    on_neg_real = (theta >= three_quarters) | (theta < -three_quarters)
    on_pos_imag = (theta >= quarter) & (theta < three_quarters)
    on_neg_imag = (theta >= -three_quarters) & (theta < -quarter)

    real_axis = on_pos_real | on_neg_real
    imag_axis = on_pos_imag | on_neg_imag

    # Per-axis scale: mean |w| over the weights that landed on that axis.
    if real_axis.any():
        s_re = w_real[real_axis].abs().mean()
    else:
        s_re = torch.tensor(0.0, device=w_real.device)
    if imag_axis.any():
        s_im = w_imag[imag_axis].abs().mean()
    else:
        s_im = torch.tensor(0.0, device=w_imag.device)

    s_re = torch.clamp(s_re, min=1e-6)
    s_im = torch.clamp(s_im, min=1e-6)
    # Guard degenerate statistics (e.g. NaN/inf weights) with a tiny scale.
    if torch.isnan(s_re) or torch.isinf(s_re):
        s_re = torch.tensor(1e-6, device=w_real.device)
    if torch.isnan(s_im) or torch.isinf(s_im):
        s_im = torch.tensor(1e-6, device=w_imag.device)

    qw_real = torch.zeros_like(w_real)
    qw_imag = torch.zeros_like(w_imag)

    qw_real[on_pos_real] = 1.0
    qw_real[on_neg_real] = -1.0
    qw_imag[on_pos_imag] = 1.0
    qw_imag[on_neg_imag] = -1.0

    return (qw_real * s_re).to(w_real.dtype), (qw_imag * s_im).to(w_imag.dtype)
|
| 37 |
+
|
| 38 |
+
def apply_complex_inspired_quantization(model: nn.Module):
    """Apply complex-inspired quantization to real-valued model"""
    print("Applying complex-inspired quantization (PhaseQuant-based)...")

    @torch.no_grad()
    def _quantize_layer(layer: nn.Linear):
        A = layer.weight.data
        # 2x2 block decomposition needs both dimensions even.
        if A.shape[0] % 2 != 0 or A.shape[1] % 2 != 0:
            print(f" -> Skipping layer (non-even dimensions): {A.shape}")
            return

        rows, cols = A.shape[0] // 2, A.shape[1] // 2
        A11, A12 = A[:rows, :cols], A[:rows, cols:]
        A21, A22 = A[rows:, :cols], A[rows:, cols:]

        # Decompose the block matrix into two complex matrices U and W.
        U_re, U_im = 0.5 * (A11 + A22), 0.5 * (A21 - A12)
        W_re, W_im = 0.5 * (A11 - A22), 0.5 * (A12 + A21)

        U_re_q, U_im_q = quantize_complex_tensor(U_re, U_im)
        W_re_q, W_im_q = quantize_complex_tensor(W_re, W_im)

        # Reassemble the quantized quadrants into a real matrix.
        top = torch.cat([W_re_q + U_re_q, W_im_q - U_im_q], dim=1)
        bottom = torch.cat([W_im_q + U_im_q, -W_re_q + U_re_q], dim=1)
        layer.weight.data = torch.cat([top, bottom], dim=0).to(A.dtype)

    for submodule in model.modules():
        if isinstance(submodule, nn.Linear):
            _quantize_layer(submodule)

    print("Complex-inspired quantization completed.")
    return model
|
| 75 |
+
|
| 76 |
+
def apply_bitnet_quantization(model: nn.Module):
    """Apply BitNet 1-bit quantization to real-valued model"""
    print("Applying BitNet (true 1-bit, affine) quantization to real-valued model...")

    @torch.no_grad()
    def _binarize_layer(layer: nn.Linear):
        w = layer.weight.data
        gamma = w.abs().mean()  # per-tensor magnitude scale
        shift = w.mean()        # affine centering before the sign
        signs = torch.where(w - shift > 0, 1.0, -1.0)
        layer.weight.data = signs.to(w.dtype) * gamma

    for submodule in model.modules():
        if isinstance(submodule, nn.Linear):
            _binarize_layer(submodule)

    print("BitNet quantization completed.")
    return model
|
| 91 |
+
|
| 92 |
+
def apply_bitnet_1_58bit_quantization_standard(model: nn.Module):
    """Apply BitNet 1.58-bit quantization to real-valued model (quantize to {-1, 0, +1})"""
    print("Applying BitNet 1.58-bit (absmean threshold) quantization to real-valued model...")

    @torch.no_grad()
    def _ternarize_layer(layer: nn.Linear):
        w = layer.weight.data
        gamma = w.abs().mean()
        # Normalize by the absmean (epsilon avoids division by zero),
        # round to the nearest integer, and clip to the ternary set.
        ternary = torch.clamp(torch.round(w / (gamma + 1e-5)), -1.0, 1.0)
        layer.weight.data = ternary.to(w.dtype) * gamma

    for submodule in model.modules():
        if isinstance(submodule, nn.Linear):
            _ternarize_layer(submodule)

    print("BitNet 1.58-bit (absmean threshold) quantization completed.")
    return model
|
| 107 |
+
|
| 108 |
+
def apply_bitnet_1_58bit_quantization_variant(model: nn.Module, threshold: float = 0.5):
    """Apply BitNet 1.58-bit quantization to real-valued model (quantize to {-1, 0, +1})"""
    print("Applying BitNet 1.58-bit (ternary) quantization to real-valued model...")

    @torch.no_grad()
    def _ternarize_layer(layer: nn.Linear):
        gamma = layer.weight.data.abs().mean()
        scaled = layer.weight.data / (gamma + 1e-5)
        # Zero out everything within the dead zone [-threshold, threshold];
        # keep only the sign outside of it.
        ternary = torch.zeros_like(scaled)
        ternary[scaled > threshold] = 1.0
        ternary[scaled < -threshold] = -1.0
        layer.weight.data = ternary.to(layer.weight.data.dtype) * gamma

    for submodule in model.modules():
        if isinstance(submodule, nn.Linear):
            _ternarize_layer(submodule)

    print("BitNet 1.58-bit quantization completed.")
    return model
|
| 125 |
+
|
| 126 |
+
def minmax_1bit_quantize_dequantize(w: torch.Tensor) -> torch.Tensor:
    """Round-trip a tensor through 1-bit asymmetric min-max quantization.

    The quantization grid has exactly two levels, w.min() and w.max(); each
    element snaps to the nearer one (torch.round on the normalized value).
    A constant tensor (max == min) is returned unchanged.
    """
    lo = w.min()
    hi = w.max()
    step = (hi - lo) / 1.0  # single quantization step for a 1-bit grid

    # Degenerate case: all elements equal, nothing to quantize.
    if abs(step) < 1e-9:
        return w

    codes = torch.round((w - lo) / step)  # each code is 0 or 1
    restored = codes * step + lo
    return restored.to(w.dtype)
|
| 140 |
+
|
| 141 |
+
def apply_minmax_1bit_quantization(model: nn.Module):
    """Replace every nn.Linear weight with its 1-bit min-max round-trip.

    Delegates the per-tensor work to ``minmax_1bit_quantize_dequantize``.
    Modifies the model in place and returns it.
    """
    print("Applying Min-Max (1-bit) quantization to real-valued model...")

    @torch.no_grad()
    def quantize_linear_layer(module: nn.Linear):
        module.weight.data = minmax_1bit_quantize_dequantize(module.weight.data)

    for submodule in model.modules():
        if isinstance(submodule, nn.Linear):
            quantize_linear_layer(submodule)

    print("Min-Max 1-bit quantization completed.")
    return model
|
| 153 |
+
|
| 154 |
+
def symmetric_minmax_1bit_quantize_dequantize(w: torch.Tensor) -> torch.Tensor:
    """Round-trip a tensor through symmetric 1-bit quantization.

    The scale is max(|w|); each element is mapped to sign(w) * scale.
    NOTE: elements that are exactly zero stay zero (sign(0) == 0), so the
    output levels are {-scale, 0, +scale} rather than strictly two values.
    An all-zero tensor is returned unchanged.
    """
    scale = w.abs().max()

    # Degenerate case: every element is (near) zero.
    if scale < 1e-9:
        return w

    signs = torch.sign(w / scale)
    return (signs * scale).to(w.dtype)
|
| 166 |
+
|
| 167 |
+
def apply_symmetric_minmax_1bit_quantization(model: nn.Module):
    """Replace every nn.Linear weight with its symmetric 1-bit round-trip.

    Delegates the per-tensor work to ``symmetric_minmax_1bit_quantize_dequantize``.
    Modifies the model in place and returns it.
    """
    print("Applying symmetric Min-Max (1-bit, to {-1, 1}) quantization to real-valued model...")

    @torch.no_grad()
    def quantize_linear_layer(module: nn.Linear):
        module.weight.data = symmetric_minmax_1bit_quantize_dequantize(module.weight.data)

    for submodule in model.modules():
        if isinstance(submodule, nn.Linear):
            quantize_linear_layer(submodule)

    print("Symmetric Min-Max 1-bit quantization completed.")
    return model
|
| 179 |
+
|
| 180 |
+
class BitNetQuantSTE(torch.autograd.Function):
    """BitNet straight-through estimator: binarize in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, w):
        # Scale is taken from the raw (uncentered) weights; the sign is taken
        # after mean-centering, so w == mean maps to -1 (strict > comparison).
        beta = w.abs().mean()
        mu = w.mean()
        signs = torch.where((w - mu) > 0, 1.0, -1.0).to(w.dtype)
        return signs * beta

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: gradients flow unchanged past the quantizer.
        return grad_output
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class BitNet1_58QuantSTE(torch.autograd.Function):
    """BitNet 1.58-bit straight-through estimator: ternarize to {-1, 0, +1} in forward."""

    @staticmethod
    def forward(ctx, w):
        # Absmean scale (gamma); epsilon guards against division by zero.
        gamma = w.abs().mean()
        levels = torch.round(w / (gamma + 1e-5)).clamp_(-1.0, 1.0)
        return (levels * gamma).to(w.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: gradients flow unchanged past the quantizer.
        return grad_output
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class PhaseQuantSTE(torch.autograd.Function):
    """Complex-Phase STE: quantize in forward, pass gradients in backward"""
    # Forward: snaps each complex weight (w_real + i*w_imag) onto one of the
    # four axis directions {+1, -1, +i, -i} chosen by its phase quadrant, then
    # scales the real/imag components by separate absmean scales.
    # Backward: straight-through estimator (gradients pass through unchanged).
    @staticmethod
    def forward(ctx, w_real, w_imag):
        # Phase angle in (-pi, pi] of each complex weight.
        phase = torch.angle(w_real + 1j * w_imag)

        # Partition the phase circle into four 90-degree sectors centered on the
        # axes: +1 around 0, +i around pi/2, -1 around +/-pi, -i around -pi/2.
        # The half-open intervals make the four masks disjoint and exhaustive.
        real_pos = (phase >= -math.pi / 4) & (phase < math.pi / 4)
        real_neg = (phase >= 3 * math.pi / 4) | (phase < -3 * math.pi / 4)
        imag_pos = (phase >= math.pi / 4) & (phase < 3 * math.pi / 4)
        imag_neg = (phase >= -3 * math.pi / 4) & (phase < -math.pi / 4)

        # Elements assigned to the real axis vs. the imaginary axis.
        mask_real = real_pos | real_neg
        mask_imag = imag_pos | imag_neg

        # Per-axis scales: absmean of the real part over real-axis elements and
        # of the imag part over imag-axis elements; 0.0 when a sector is empty.
        s_re = w_real[mask_real].abs().mean() if mask_real.any() else torch.tensor(0.0, device=w_real.device)
        s_im = w_imag[mask_imag].abs().mean() if mask_imag.any() else torch.tensor(0.0, device=w_imag.device)

        # Floor the scales so the quantized values are never exactly zeroed out.
        s_re = torch.clamp(s_re, min=1e-6)
        s_im = torch.clamp(s_im, min=1e-6)

        qw_real = torch.zeros_like(w_real)
        qw_imag = torch.zeros_like(w_imag)

        # Codebook assignment: exactly one of the four sector masks holds per
        # element, so each weight gets +/-1 on one axis and 0 on the other.
        qw_real[real_pos] = 1.0
        qw_real[real_neg] = -1.0
        qw_imag[imag_pos] = 1.0
        qw_imag[imag_neg] = -1.0

        qw_real_scaled = qw_real * s_re
        qw_imag_scaled = qw_imag * s_im

        return qw_real_scaled.to(w_real.dtype), qw_imag_scaled.to(w_imag.dtype)

    @staticmethod
    def backward(ctx, grad_w_real, grad_w_imag):
        # Straight-through: identity gradient for both components.
        return grad_w_real, grad_w_imag
|
| 247 |
+
|
| 248 |
+
class PhaseQuantSTE_V2(torch.autograd.Function):
    """Two-step residual quantization: quantize, then quantize the leftover error once more."""

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        acc_real = torch.zeros_like(w_real)
        acc_imag = torch.zeros_like(w_imag)
        res_real, res_imag = w_real, w_imag
        # Each pass quantizes the current residual and accumulates the result.
        for _ in range(2):
            q_real, q_imag = PhaseQuantSTE.apply(res_real, res_imag)
            acc_real = acc_real + q_real
            acc_imag = acc_imag + q_imag
            res_real = res_real - q_real
            res_imag = res_imag - q_imag
        return acc_real, acc_imag

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        # Straight-through: identity gradient for both components.
        return grad_real, grad_imag
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
class PhaseQuantSTE_V3(torch.autograd.Function):
    """Three-step residual quantization: successively quantize the remaining error."""

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        acc_real = torch.zeros_like(w_real)
        acc_imag = torch.zeros_like(w_imag)
        res_real, res_imag = w_real, w_imag
        # Each pass quantizes the current residual and accumulates the result.
        for _ in range(3):
            q_real, q_imag = PhaseQuantSTE.apply(res_real, res_imag)
            acc_real = acc_real + q_real
            acc_imag = acc_imag + q_imag
            res_real = res_real - q_real
            res_imag = res_imag - q_imag
        return acc_real, acc_imag

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        # Straight-through: identity gradient for both components.
        return grad_real, grad_imag
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
class PhaseQuantSTE_V4(torch.autograd.Function):
    """Four-step residual quantization: successively quantize the remaining error."""

    @staticmethod
    def forward(ctx, w_real: torch.Tensor, w_imag: torch.Tensor):
        acc_real = torch.zeros_like(w_real)
        acc_imag = torch.zeros_like(w_imag)
        res_real, res_imag = w_real, w_imag
        # Each pass quantizes the current residual and accumulates the result.
        for _ in range(4):
            q_real, q_imag = PhaseQuantSTE.apply(res_real, res_imag)
            acc_real = acc_real + q_real
            acc_imag = acc_imag + q_imag
            res_real = res_real - q_real
            res_imag = res_imag - q_imag
        return acc_real, acc_imag

    @staticmethod
    def backward(ctx, grad_real, grad_imag):
        # Straight-through: identity gradient for both components.
        return grad_real, grad_imag
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "</s>",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<unk>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"bos_token": "<s>",
|
| 32 |
+
"clean_up_tokenization_spaces": false,
|
| 33 |
+
"eos_token": "</s>",
|
| 34 |
+
"extra_special_tokens": {},
|
| 35 |
+
"legacy": false,
|
| 36 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 37 |
+
"pad_token": "</s>",
|
| 38 |
+
"padding_side": "right",
|
| 39 |
+
"sp_model_kwargs": {},
|
| 40 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 41 |
+
"unk_token": "<unk>",
|
| 42 |
+
"use_default_system_prompt": false
|
| 43 |
+
}
|
training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e9aacb37f66b32727647179ccb831ff830505e9721317432e224e6a2abb2dae7
|
| 3 |
+
size 6929
|