JiRackTernary_70b / load_packed_70b.py
# ==============================================================================
# COPYRIGHT (C) 2025-2026 KONSTANTIN VLADIMIROVICH GRABKO. ALL RIGHTS RESERVED.
# PATENT PENDING | CMS MANHATTAN JIRACK TECHNOLOGY
# ==============================================================================
import torch
import os
import glob
from safetensors.torch import load_file
from JiRackTernaryPyTorch_70b import JiRackTernary70B
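
# NOTE (assumed interface, not defined in this file): the loader below expects
# each packed layer in JiRackTernary70B to expose three attributes that it
# fills in directly:
#   packed     - torch.uint8 tensor with the bit-packed ternary weights
#   scale      - torch.float16 dequantization scale
#   orig_shape - torch.int32 tensor holding the original weight shape
# Checkpoint keys are assumed to follow the pattern
#   "model.<module path>.weight.packed" / ".weight.scale" / ".weight.orig_shape"
# for packed layers, plus plain "model.<module path>.weight" entries for
# Embedding, Norm and lm_head.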
def load_jirack_70b_packed(path, config):
    print("🏗️ Initializing the 70B model structure in RAM...")
    model = JiRackTernary70B(config)
    # Map of BitLinear modules (identified by their `packed` attribute)
    bit_modules = {name: mod for name, mod in model.named_modules() if hasattr(mod, "packed")}
    shards = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    print("📖 Loading weights...")
    for shard_path in shards:
        current_shard = load_file(shard_path, device="cpu")
        print(f" - {os.path.basename(shard_path)}")
        for k, v in current_shard.items():
            # Strip the "model." prefix from checkpoint keys
            clean_k = k[6:] if k.startswith("model.") else k
            # Detect packed-layer suffixes and recover the bare module path
            suffix = None
            pure_path = clean_k
            for s in [".packed", ".scale", ".orig_shape"]:
                if clean_k.endswith(s):
                    suffix = s
                    pure_path = clean_k.replace(".weight" + s, "")
                    break
            # Checkpoint keys keep an ".mlp." segment that the module tree lacks
            model_key = pure_path.replace(".mlp.", ".")
            # 1. Load into BitLinear (our packed ternary layers)
            if suffix:
                if model_key in bit_modules:
                    mod = bit_modules[model_key]
                    if suffix == ".packed":
                        mod.packed = v.to(torch.uint8)
                    elif suffix == ".scale":
                        mod.scale = v.to(torch.float16)
                    elif suffix == ".orig_shape":
                        mod.orig_shape = v.to(torch.int32)
                continue
            # 2. Load into regular layers (Embedding, Norm, lm_head)
            try:
                # Strip ".weight" from the key to locate the owning attribute
                is_weight = clean_k.endswith(".weight")
                attr_path = clean_k[:-7] if is_weight else clean_k
                parts = attr_path.split(".")
                obj = model
                for part in parts[:-1]:
                    obj = getattr(obj, part)
                # Resolve the target attribute (e.g. the layer that owns .weight)
                target_name = parts[-1]
                target_attr = getattr(obj, target_name)
                if is_weight and hasattr(target_attr, 'weight'):
                    # Found a layer (e.g. Embedding): copy into its .weight
                    target_attr.weight.data.copy_(v.to(torch.float16))
                elif isinstance(target_attr, torch.nn.Parameter):
                    # Already a Parameter (e.g. in RMSNorm)
                    target_attr.data.copy_(v.to(torch.float16))
                else:
                    # Fallback: the target is a plain attribute
                    setattr(obj, target_name, torch.nn.Parameter(v.to(torch.float16), requires_grad=False))
            except Exception:
                continue
        # Free the shard tensors before loading the next file
        del current_shard
    # Verify that every BitLinear module received its packed weights
    missing = [name for name, mod in bit_modules.items() if mod.packed is None]
    if not missing:
        print("✅ All 560 ternary layers loaded successfully!")
    else:
        print(f"❌ Skipped {len(missing)} layers.")
    return model
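
# ------------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint directory and the way the
# config is built are placeholders, not part of this repo's verified API; the
# real config definition lives in the JiRackTernaryPyTorch_70b module.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    # Placeholder: construct the config object that JiRackTernary70B expects
    # (assumed to be defined alongside the model class; not shown in this file).
    config = ...
    model = load_jirack_70b_packed("./JiRackTernary_70b", config)  # example path
    model.eval()
    print(type(model).__name__, "is ready for inference")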