import os

import torch
from safetensors.torch import save_file  # third-party: pip install safetensors

# Checkpoints to convert and their explicit output names; any file not in
# the map falls back to swapping a trailing ".pth" for ".safetensors".
files = ["neon213_muon_sota_fp16.pth"]
target_map = {"neon213_muon_sota_fp16.pth": "model.safetensors"}


def _convert_checkpoint(src: str, target: str) -> None:
    """Convert one PyTorch ``.pth`` state dict to a ``.safetensors`` file.

    safetensors refuses to serialize tensors that share memory, so when the
    output head is tied to the input embedding (common in Neon/Qwen-style
    architectures) the duplicate ``head.weight`` entry is dropped — loaders
    are expected to re-tie it from ``token_emb.weight``.
    """
    print(f"Loading {src}...")
    # weights_only=True blocks arbitrary code execution via pickle payloads.
    state_dict = torch.load(src, map_location="cpu", weights_only=True)

    if "token_emb.weight" in state_dict and "head.weight" in state_dict:
        # Same data pointer => the two entries are views of one storage,
        # i.e. genuinely tied weights rather than independent copies.
        # NOTE(review): data_ptr() matches only when both views start at the
        # same offset; for weight tying in practice they always do.
        if state_dict["token_emb.weight"].data_ptr() == state_dict["head.weight"].data_ptr():
            print("Detected tied weights (token_emb and head). Removing duplicate for safetensors.")
            del state_dict["head.weight"]

    # save_file raises on non-contiguous tensors; normalize here instead of
    # crashing on checkpoints that hold transposed/sliced views.
    state_dict = {k: v.contiguous() for k, v in state_dict.items()}
    save_file(state_dict, target)
    print(f"Successfully created {target}\n")


def main() -> None:
    """Convert every checkpoint in ``files`` that exists on disk."""
    for src in files:
        if not os.path.exists(src):
            continue
        # removesuffix strips only a trailing ".pth"; str.replace would also
        # rewrite a ".pth" appearing in the middle of the name.
        target = target_map.get(src, src.removesuffix(".pth") + ".safetensors")
        _convert_checkpoint(src, target)


if __name__ == "__main__":
    main()