| |
| |
| |
| |
|
|
| import torch |
| import os |
| import glob |
| from safetensors.torch import load_file |
| from JiRackTernaryPyTorch_70b import JiRackTernary70B |
|
|
def load_jirack_70b_packed(path, config):
    """Build a JiRackTernary70B model and populate it from safetensors shards.

    The model skeleton is first instantiated in RAM, then every
    ``*.safetensors`` file under ``path`` is streamed (one shard at a time,
    to cap peak memory) and each tensor is routed one of two ways:

    * keys ending in ``.packed`` / ``.scale`` / ``.orig_shape`` are ternary
      payload tensors, assigned directly as attributes onto the matching
      bit-packed module (any module exposing a ``packed`` attribute);
    * all other tensors are ordinary weights, copied in fp16 into the
      corresponding ``nn.Parameter`` (created on the fly when the target
      attribute is not already a parameter or a module with a ``weight``).

    Args:
        path: Directory containing the ``*.safetensors`` shard files.
        config: Configuration object forwarded to ``JiRackTernary70B``.

    Returns:
        The populated ``JiRackTernary70B`` model.
    """
    print(f"🏗️ Инициализация структуры 70B в RAM...")
    model = JiRackTernary70B(config)

    # Index every bit-packed module by its dotted name. We detect them by
    # the presence of a `packed` attribute; the completeness check at the
    # end assumes it is initialised to None until a shard fills it.
    bit_modules = {name: mod for name, mod in model.named_modules() if hasattr(mod, "packed")}

    shards = sorted(glob.glob(os.path.join(path, "*.safetensors")))
    print(f"📖 Загрузка весов...")

    for shard_path in shards:
        current_shard = load_file(shard_path, device="cpu")
        print(f" - {os.path.basename(shard_path)}")

        for k, v in current_shard.items():
            # Checkpoint keys may carry a leading "model." namespace.
            clean_k = k[6:] if k.startswith("model.") else k

            # Detect a ternary payload suffix and derive the owning module path.
            suffix = None
            pure_path = clean_k
            for s in (".packed", ".scale", ".orig_shape"):
                if clean_k.endswith(s):
                    suffix = s
                    # FIX: the previous code used replace(".weight" + s, ""),
                    # which left the suffix attached whenever the key had no
                    # ".weight" segment before it, so those tensors missed the
                    # bit_modules lookup and were silently dropped. Strip the
                    # suffix first, then an optional trailing ".weight".
                    pure_path = clean_k[: -len(s)]
                    if pure_path.endswith(".weight"):
                        pure_path = pure_path[: -len(".weight")]
                    break

            # Shard naming keeps an ".mlp." level that the model flattens out.
            model_key = pure_path.replace(".mlp.", ".")

            if suffix:
                if model_key in bit_modules:
                    mod = bit_modules[model_key]
                    if suffix == ".packed":
                        mod.packed = v.to(torch.uint8)
                    elif suffix == ".scale":
                        mod.scale = v.to(torch.float16)
                    elif suffix == ".orig_shape":
                        mod.orig_shape = v.to(torch.int32)
                continue

            # Regular (non-ternary) tensor: walk the attribute path and copy.
            try:
                is_weight = clean_k.endswith(".weight")
                # len(".weight") == 7
                attr_path = clean_k[:-7] if is_weight else clean_k

                parts = attr_path.split(".")
                obj = model
                for part in parts[:-1]:
                    obj = getattr(obj, part)

                target_name = parts[-1]
                target_attr = getattr(obj, target_name)

                if is_weight and hasattr(target_attr, 'weight'):
                    # Target is a module (e.g. nn.Linear): fill its weight.
                    target_attr.weight.data.copy_(v.to(torch.float16))
                elif isinstance(target_attr, torch.nn.Parameter):
                    target_attr.data.copy_(v.to(torch.float16))
                else:
                    # Bare/unset slot: install a frozen fp16 parameter.
                    setattr(obj, target_name, torch.nn.Parameter(v.to(torch.float16), requires_grad=False))
            except Exception:
                # Deliberate best-effort loading: tensors that do not map onto
                # the model (renamed/extra keys, shape mismatches) are skipped
                # rather than aborting the whole multi-shard load.
                continue

        # Release the shard before loading the next one to cap peak RAM.
        del current_shard

    # Completeness check: every bit-packed module must have received payload.
    missing = [name for name, mod in bit_modules.items() if mod.packed is None]
    if not missing:
        # FIX: report the real layer count instead of a hard-coded "560".
        print(f"✅ Все {len(bit_modules)} тернарных слоев загружены успешно!")
    else:
        print(f"❌ Пропущено {len(missing)} слоев.")

    return model
|
|