import os

import torch
from safetensors.torch import save_file

# Checkpoints to convert, plus explicit output names for files that should
# not simply get their extension swapped.
files = ["neon213_muon_sota_fp16.pth"]
target_map = {"neon213_muon_sota_fp16.pth": "model.safetensors"}


def _drop_aliased_tensors(state_dict):
    """Return a copy of *state_dict* with storage-sharing duplicates removed.

    safetensors refuses to serialize tensors that alias the same memory —
    common with tied weights (e.g. ``token_emb.weight`` / ``head.weight`` in
    Neon/Qwen-style architectures). The first tensor seen for each data
    pointer is kept; later aliases are dropped.

    NOTE(review): this matches aliases by ``data_ptr()``; overlapping views
    with *different* start pointers are not caught here and would still make
    ``save_file`` raise — acceptable for plain tied-weight checkpoints.
    """
    seen = {}       # data_ptr -> first tensor name holding that storage
    deduped = {}
    for name, tensor in state_dict.items():
        ptr = tensor.data_ptr()
        if ptr in seen:
            print(
                f"Detected tied weights ({seen[ptr]} and {name}). "
                "Removing duplicate for safetensors."
            )
            continue
        seen[ptr] = name
        deduped[name] = tensor
    return deduped


def _convert(path, target):
    """Convert one ``.pth`` checkpoint at *path* to a safetensors file *target*."""
    print(f"Loading {path}...")
    # weights_only=True prevents arbitrary-code execution from pickled payloads.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    state_dict = _drop_aliased_tensors(state_dict)
    # save_file raises on non-contiguous tensors (e.g. transposed views), so
    # normalize them up front; contiguous tensors are returned as-is.
    state_dict = {k: v.contiguous() for k, v in state_dict.items()}
    save_file(state_dict, target)
    print(f"Successfully created {target}\n")


def main():
    for f in files:
        if not os.path.exists(f):
            continue
        # splitext only touches the final extension — unlike str.replace,
        # which would rewrite a ".pth" appearing anywhere in the name.
        default_target = os.path.splitext(f)[0] + ".safetensors"
        _convert(f, target_map.get(f, default_target))


if __name__ == "__main__":
    main()