---
license: mit
---

Checkpoints for LoRA training with [musubi-tuner](https://github.com/kohya-ss/musubi-tuner) ([relevant PR](https://github.com/kohya-ss/musubi-tuner/pull/712)).

Converted from the sharded weights at https://huggingface.co/meituan-longcat/LongCat-Video/tree/main/dit using the following script:

```python
import argparse
import itertools
import os

import torch
from safetensors.torch import save_file

from musubi_tuner.utils.safetensors_utils import load_split_weights, MemoryEfficientSafeOpen


def detect_dtype(path: str) -> torch.dtype:
    """Return the dtype of the first floating-point tensor in a safetensors file.

    Tensors are read in the order given by ``handle.keys()`` until a
    floating-point one is found. If the file holds no floating-point tensors,
    the dtype of the first tensor is returned instead.

    Raises:
        FileNotFoundError: if ``path`` is not an existing file.
        ValueError: if the file contains no tensors.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"File not found: {path}")

    with MemoryEfficientSafeOpen(path) as handle:
        keys = list(handle.keys())
        if not keys:
            raise ValueError(f"No tensors found in {path}")

        for key in keys:
            tensor = handle.get_tensor(key)
            if tensor.is_floating_point():
                return tensor.dtype

        # No floating-point tensor at all: fall back to the first tensor's dtype.
        return handle.get_tensor(keys[0]).dtype


def list_keys(state_dict, num_keys=20):
    """Print the total tensor count and the first ``num_keys`` keys of ``state_dict``."""
    # Report how many keys are actually shown, not just the requested limit.
    shown = min(num_keys, len(state_dict))
    print(f"\nTotal tensors: {len(state_dict)}")
    print(f"First {shown} keys:")
    for key in itertools.islice(state_dict.keys(), num_keys):
        print(f" {key}")
    print()


def convert_dtype(input_path: str, output_path: str, target_dtype: torch.dtype) -> None:
    """Load safetensors weights, cast floating-point tensors to ``target_dtype``, and save.

    ``input_path`` is handed to musubi-tuner's ``load_split_weights``.
    Presumably it may name one shard of a split checkpoint, with the sibling
    shards picked up automatically (TODO: confirm against musubi-tuner).
    Tensors that are not floating point, or are already ``target_dtype``, are
    kept unchanged. The result is written to ``output_path`` as a single
    safetensors file.
    """
    print(f"Loading from: {input_path}")

    # Report the dtype of the input (inspected from input_path only).
    current_dtype = detect_dtype(input_path)
    print(f"Detected input dtype: {current_dtype}")
    print(f"Target dtype: {target_dtype}")

    state_dict = load_split_weights(input_path)

    # Show a sample of keys before conversion.
    list_keys(state_dict)

    print(f"Converting floating point tensors to {target_dtype}...")
    converted_count = 0
    # Only values are reassigned (no keys added/removed), so mutating during
    # iteration over items() is safe.
    for key, tensor in state_dict.items():
        if tensor.is_floating_point() and tensor.dtype != target_dtype:
            # NOTE(review): this is a plain cast with no scaling; for the float8
            # targets, values outside the representable range are not preserved —
            # verify the output if converting to fp8.
            state_dict[key] = tensor.to(dtype=target_dtype)
            converted_count += 1
    print(f"Converted {converted_count} tensors")

    print(f"Saving to: {output_path}")
    save_file(state_dict, output_path)
    print("Done!")


def main():
    """Parse command-line arguments and run the dtype conversion."""
    parser = argparse.ArgumentParser(
        description="Convert safetensors file dtype with inspection and detection"
    )
    parser.add_argument(
        "input_path",
        type=str,
        help="Path to input safetensors file",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Path to output safetensors file",
    )
    parser.add_argument(
        "--target-dtype",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2"],
        help="Target dtype for conversion (default: float16)",
    )
    args = parser.parse_args()

    # Map the CLI dtype name to the corresponding torch dtype.
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float8_e4m3fn": torch.float8_e4m3fn,
        "float8_e5m2": torch.float8_e5m2,
    }
    target_dtype = dtype_map[args.target_dtype]

    convert_dtype(args.input_path, args.output_path, target_dtype)


if __name__ == "__main__":
    main()
```