import argparse
import io
import logging
import pathlib
from collections import OrderedDict

import safetensors.torch
import torch

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def rename_key(key: str) -> str | None:
    """
    Renames keys from the original .pt checkpoint to match the
    Hugging Face model structure defined in modeling_evo2.py.

    Args:
        key: The original key from the .pt file.

    Returns:
        The renamed key, or None if the key should be skipped.
    """
    if "_extra_state" in key:
        logging.debug(f"Skipping FP8 metadata key: {key}")
        return None

    if "filter.t" in key:
        logging.debug(f"Skipping dynamic buffer key: {key}")
        return None

    if key == "embedding_layer.weight":
        return "backbone.embedding_layer.weight"

    if key == "norm.scale":
        return "backbone.norm.scale"

    if key.startswith("blocks."):
        # Per-block keys pass through under the "backbone." prefix; a few
        # mixer sub-keys are skipped as potentially problematic.
        if "mixer.attn" in key or "mixer.dense" in key or "mixer.mixer" in key:
            logging.warning(f"Skipping potentially problematic 'mixer' key: {key}")
            return None
        return f"backbone.{key}"

    if key == "unembed.weight":
        logging.warning(f"Skipping potentially tied unembedding weight: {key}. "
                        "Ensure 'tie_word_embeddings=True' in HF config or handle manually.")
        return None

    logging.warning(f"Unhandled key: {key}. Skipping.")
    return None
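

# A quick sanity check of the mapping rules above. The example block keys are
# illustrative (they follow the "blocks.<i>.<submodule>" pattern handled by
# rename_key, not keys verified against a real Evo2 checkpoint):
#
#   rename_key("embedding_layer.weight")       -> "backbone.embedding_layer.weight"
#   rename_key("norm.scale")                   -> "backbone.norm.scale"
#   rename_key("blocks.0.mlp.l1.weight")       -> "backbone.blocks.0.mlp.l1.weight"
#   rename_key("blocks.0.mixer.attn.weight")   -> None  (skipped with a warning)
#   rename_key("unembed.weight")               -> None  (possibly tied; see warning)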


def convert_pt_to_safetensors(pt_path: pathlib.Path, sf_path: pathlib.Path):
    """
    Loads weights from a .pt file, renames keys for HF compatibility,
    and saves them to a .safetensors file.

    Args:
        pt_path: Path to the input .pt checkpoint file.
        sf_path: Path to the output .safetensors file.
    """
    logging.info(f"Loading state dict from: {pt_path}")
    torch.serialization.add_safe_globals([io.BytesIO])
    logging.info("Added io.BytesIO to safe globals for torch.load.")

    try:
        state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
        logging.info(f"Successfully loaded state dict with {len(state_dict)} keys using weights_only=True.")
    except Exception as e:
        logging.error(f"Failed to load with weights_only=True: {e}")
        logging.warning("Attempting to load with weights_only=False. "
                        "Ensure you trust the source of this checkpoint as it may execute arbitrary code.")
        try:
            state_dict = torch.load(pt_path, map_location="cpu", weights_only=False)
            logging.info(f"Successfully loaded state dict with {len(state_dict)} keys using weights_only=False.")
        except Exception as final_e:
            logging.error(f"Failed to load state dict even with weights_only=False: {final_e}")
            return
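
    # Optional debugging aid: surface a few raw key names so that any
    # "Unhandled key" warnings below are easier to diagnose.
    for sample_key in list(state_dict.keys())[:5]:
        logging.debug(f"Sample input key: {sample_key}")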

    new_state_dict = OrderedDict()
    original_keys = list(state_dict.keys())

    logging.info("Processing and renaming keys...")
    skipped_count = 0
    processed_count = 0
    for old_key in original_keys:
        new_key = rename_key(old_key)
        if new_key:
            tensor = state_dict[old_key]
            if isinstance(tensor, torch.Tensor):
                # safetensors refuses non-contiguous tensors, so normalize here
                # (a no-op for tensors that are already contiguous).
                new_state_dict[new_key] = tensor.contiguous()
                processed_count += 1
                logging.debug(f"Renamed '{old_key}' -> '{new_key}'")
            else:
                skipped_count += 1
                logging.warning(f"Skipped key '{old_key}' because its value is not a tensor (type: {type(tensor)}).")
        else:
            skipped_count += 1

    logging.info(f"Processed {processed_count} tensor keys, skipped {skipped_count} keys/non-tensors.")

    if not new_state_dict:
        logging.error("The resulting state dictionary is empty. Check key renaming logic and input file content.")
        return
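
    # Purely informational: total parameter count of the tensors to be written.
    total_params = sum(t.numel() for t in new_state_dict.values())
    logging.info(f"Collected {total_params:,} parameters across {len(new_state_dict)} tensors.")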

    sf_path.parent.mkdir(parents=True, exist_ok=True)

    logging.info(f"Saving processed state dict to: {sf_path}")
    try:
        safetensors.torch.save_file(new_state_dict, sf_path, metadata=None)
        logging.info("Successfully saved safetensors file.")
    except Exception as e:
        logging.error(f"Failed to save safetensors file: {e}")
        return


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert PyTorch .pt checkpoints to .safetensors format.")
    parser.add_argument("input_pt", type=str, help="Path to the input .pt checkpoint file.")
    parser.add_argument("output_safetensors", type=str, help="Path to the output .safetensors file.")
    args = parser.parse_args()

    input_path = pathlib.Path(args.input_pt)
    output_path = pathlib.Path(args.output_safetensors)

    if not input_path.is_file():
        logging.error(f"Input file not found: {input_path}")
    else:
        convert_pt_to_safetensors(input_path, output_path)
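
# Example invocation (the file names here are illustrative only):
#   python convert_pt_to_safetensors.py checkpoint.pt model.safetensors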