"""convert-weights.py

Convert an Evo2 PyTorch (.pt) checkpoint into the .safetensors format.
"""
import torch
import safetensors.torch
import argparse
import pathlib
from collections import OrderedDict
import logging
import io  # io.BytesIO must be allow-listed so torch.load(weights_only=True) accepts FP8 metadata

# Configure basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
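# Dependency notes (assumptions, not pinned by this repo): torch >= 2.4 is needed
# for torch.serialization.add_safe_globals (guarded below for older versions);
# safetensors provides the save_file and safe_open calls used here.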
def rename_key(key: str) -> str | None:
    """
    Renames keys from the original .pt checkpoint to match the
    Hugging Face model structure defined in modeling_evo2.py.

    Args:
        key: The original key from the .pt file.

    Returns:
        The renamed key, or None if the key should be skipped.
    """
    # Skip FP8 metadata keys. These often contain non-tensor data (e.g. BytesIO)
    # which safetensors cannot store directly and which is not needed for inference.
    if "_extra_state" in key:
        logging.debug(f"Skipping FP8 metadata key: {key}")
        return None

    # Skip the dynamically computed 't' buffer if present
    if "filter.t" in key:
        logging.debug(f"Skipping dynamic buffer key: {key}")
        return None

    # Handle the embedding layer
    if key == "embedding_layer.weight":
        # Target key in modeling_evo2.py: backbone.embedding_layer.weight
        return "backbone.embedding_layer.weight"

    # Handle the final norm
    if key == "norm.scale":
        # Target key in modeling_evo2.py: backbone.norm.scale
        return "backbone.norm.scale"

    # Handle blocks (add the 'backbone.' prefix).
    # Standard block structure: backbone.blocks.N.layer_name.parameter_name
    # Examples:
    #   blocks.0.pre_norm.scale             -> backbone.blocks.0.pre_norm.scale
    #   blocks.3.inner_mha_cls.Wqkv.weight  -> backbone.blocks.3.inner_mha_cls.Wqkv.weight
    #   blocks.0.filter.short_filter_weight -> backbone.blocks.0.filter.short_filter_weight
    #   blocks.0.mlp.l1.weight              -> backbone.blocks.0.mlp.l1.weight
    #   blocks.0.projections.weight         -> backbone.blocks.0.projections.weight (TELinear)
    #   blocks.0.out_filter_dense.weight    -> backbone.blocks.0.out_filter_dense.weight
    if key.startswith("blocks."):
        # Skip 'mixer' keys: they do not map onto the HF module structure and
        # appear to be Transformer Engine internals or remnants of an older layout.
        if "mixer.attn" in key or "mixer.dense" in key or "mixer.mixer" in key:
            logging.warning(f"Skipping potentially problematic 'mixer' key: {key}")
            return None
        # The general renaming rule simply adds the 'backbone.' prefix
        return f"backbone.{key}"

    # Skip unembedding weights if tied (handled by HF post_init).
    # If embeddings are not tied, this key would need a different target name.
    if key == "unembed.weight":
        logging.warning(f"Skipping potentially tied unembedding weight: {key}. "
                        "Ensure 'tie_word_embeddings=True' in the HF config or handle it manually.")
        return None

    logging.warning(f"Unhandled key: {key}. Skipping.")
    return None
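# Illustrative behavior of rename_key (hypothetical keys, mirroring the examples above):
#   rename_key("blocks.0.pre_norm.scale")   -> "backbone.blocks.0.pre_norm.scale"
#   rename_key("blocks.0.mlp._extra_state") -> None  (FP8 metadata is dropped)
#   rename_key("unembed.weight")            -> None  (assumed tied to the embedding)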
def convert_pt_to_safetensors(pt_path: pathlib.Path, sf_path: pathlib.Path):
    """
    Loads weights from a .pt file, renames keys for HF compatibility,
    and saves them to a .safetensors file.

    Args:
        pt_path: Path to the input .pt checkpoint file.
        sf_path: Path to the output .safetensors file.
    """
    logging.info(f"Loading state dict from: {pt_path}")

    # Add io.BytesIO to the list of safe globals. This allows loading
    # checkpoints that contain FP8 metadata stored as BytesIO objects,
    # even with weights_only=True. (add_safe_globals requires torch >= 2.4,
    # so guard the call for older versions.)
    if hasattr(torch.serialization, "add_safe_globals"):
        torch.serialization.add_safe_globals([io.BytesIO])
        logging.info("Added io.BytesIO to safe globals for torch.load.")

    try:
        # Load the state dictionary onto the CPU.
        # weights_only=True is safer and recommended when possible.
        state_dict = torch.load(pt_path, map_location="cpu", weights_only=True)
        logging.info(f"Successfully loaded state dict with {len(state_dict)} keys using weights_only=True.")
    except Exception as e:
        logging.error(f"Failed to load with weights_only=True: {e}")
        logging.warning("Attempting to load with weights_only=False. "
                        "Only do this if you trust the source of this checkpoint, "
                        "as unpickling can execute arbitrary code.")
        try:
            # Fallback if weights_only=True fails even after adding BytesIO
            # (less likely, but possible with other non-tensor payloads).
            state_dict = torch.load(pt_path, map_location="cpu", weights_only=False)
            logging.info(f"Successfully loaded state dict with {len(state_dict)} keys using weights_only=False.")
        except Exception as final_e:
            logging.error(f"Failed to load state dict even with weights_only=False: {final_e}")
            return  # Exit if loading fails completely
    # Build a new ordered dictionary with the renamed keys
    new_state_dict = OrderedDict()
    original_keys = list(state_dict.keys())  # Snapshot the keys for iteration

    logging.info("Processing and renaming keys...")
    skipped_count = 0
    processed_count = 0
    for old_key in original_keys:
        new_key = rename_key(old_key)
        if new_key:
            # Ensure the value is actually a tensor before adding it
            tensor = state_dict[old_key]
            if isinstance(tensor, torch.Tensor):
                # safetensors refuses non-contiguous tensors, so make each
                # tensor contiguous before storing it.
                new_state_dict[new_key] = tensor.contiguous()
                processed_count += 1
                logging.debug(f"Renamed '{old_key}' -> '{new_key}'")
            else:
                # Handles cases where weights_only=False loaded non-tensor objects
                # that were not skipped by name (e.g. unexpected _extra_state contents).
                skipped_count += 1
                logging.warning(f"Skipped key '{old_key}' because its value is not a tensor (type: {type(tensor)}).")
        else:
            skipped_count += 1
            # Skipped keys are already logged inside rename_key

    logging.info(f"Processed {processed_count} tensor keys, skipped {skipped_count} keys/non-tensors.")
    if not new_state_dict:
        logging.error("The resulting state dictionary is empty. Check the key renaming logic and the input file content.")
        return
    # Ensure the output directory exists
    sf_path.parent.mkdir(parents=True, exist_ok=True)

    logging.info(f"Saving processed state dict to: {sf_path}")
    # Save the processed state dictionary.
    # Metadata can be added if needed, e.g. {"format": "pt"}.
    try:
        safetensors.torch.save_file(new_state_dict, sf_path, metadata=None)
        logging.info("Successfully saved safetensors file.")
    except Exception as e:
        logging.error(f"Failed to save safetensors file: {e}")
        return
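    # Optional sanity check: a minimal sketch assuming the top-level safetensors
    # package (safe_open lazily memory-maps the file without loading tensors).
    # It only confirms the file round-trips and the key count matches.
    try:
        with safetensors.safe_open(str(sf_path), framework="pt", device="cpu") as f:
            n_saved = len(list(f.keys()))
        if n_saved == len(new_state_dict):
            logging.info(f"Verified {n_saved} tensors in {sf_path.name}.")
        else:
            logging.warning(f"Key count mismatch after save: wrote {len(new_state_dict)}, read back {n_saved}.")
    except Exception as e:
        logging.warning(f"Could not verify the saved file: {e}")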
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert PyTorch .pt checkpoints to .safetensors format.")
    parser.add_argument("input_pt", type=str, help="Path to the input .pt checkpoint file.")
    parser.add_argument("output_safetensors", type=str, help="Path to the output .safetensors file.")
    args = parser.parse_args()

    input_path = pathlib.Path(args.input_pt)
    output_path = pathlib.Path(args.output_safetensors)

    if not input_path.is_file():
        logging.error(f"Input file not found: {input_path}")
    else:
        convert_pt_to_safetensors(input_path, output_path)
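# Example invocation (hypothetical file names):
#   python convert-weights.py evo2_7b_checkpoint.pt model.safetensors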