|
|
--- |
|
|
license: mit |
|
|
--- |
|
|
|
|
|
Checkpoints for LoRA training with [musubi-tuner](https://github.com/kohya-ss/musubi-tuner) ([relevant PR](https://github.com/kohya-ss/musubi-tuner/pull/712)) |
|
|
Converted from the shards at https://huggingface.co/meituan-longcat/LongCat-Video/tree/main/dit using the following script:
|
|
|
|
|
```python
|
|
import argparse |
|
|
import itertools |
|
|
import os |
|
|
from musubi_tuner.utils.safetensors_utils import load_split_weights, MemoryEfficientSafeOpen |
|
|
from safetensors.torch import save_file |
|
|
import torch |
|
|
|
|
|
|
|
|
def detect_dtype(path: str) -> torch.dtype:
    """Detect the dtype of the first floating point tensor in a safetensors file.

    Falls back to the dtype of the very first tensor when the file contains
    no floating point tensors at all. Raises FileNotFoundError if *path* is
    not a file, and ValueError if the file holds no tensors.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"File not found: {path}")

    with MemoryEfficientSafeOpen(path) as handle:
        names = list(handle.keys())
        if not names:
            raise ValueError(f"No tensors found in {path}")

        # Prefer the first floating point tensor's dtype: that is the dtype
        # the weights are actually stored in.
        for name in names:
            t = handle.get_tensor(name)
            if t.is_floating_point():
                return t.dtype

        # No floating point tensor found; report the first tensor's dtype.
        return handle.get_tensor(names[0]).dtype
|
|
|
|
|
|
|
|
def list_keys(state_dict, num_keys=20):
    """Print a short summary of *state_dict*: tensor count and first keys.

    Only the first *num_keys* keys are shown to keep the output readable
    for large checkpoints.
    """
    print(f"\nTotal tensors: {len(state_dict)}")
    print(f"First {num_keys} keys:")
    # Iterating the dict directly yields keys; islice caps the output.
    preview = itertools.islice(state_dict, num_keys)
    for name in preview:
        print(f"  {name}")
    print()
|
|
|
|
|
|
|
|
def convert_dtype(input_path: str, output_path: str, target_dtype: torch.dtype):
    """Load a safetensors checkpoint, cast its floating point tensors to
    *target_dtype*, and save the result to *output_path*.

    Non-floating-point tensors (e.g. integer index buffers) are left
    untouched, as are tensors already stored in the target dtype.
    """
    print(f"Loading from: {input_path}")

    # Report the stored dtype alongside the requested one before converting.
    current_dtype = detect_dtype(input_path)
    print(f"Detected input dtype: {current_dtype}")
    print(f"Target dtype: {target_dtype}")

    # load_split_weights handles sharded checkpoints transparently.
    state_dict = load_split_weights(input_path)

    # Show a preview of the keys so the user can sanity-check the contents.
    list_keys(state_dict)

    print(f"Converting floating point tensors to {target_dtype}...")
    converted_count = 0
    for name, weight in state_dict.items():
        # Skip non-float tensors and tensors already in the target dtype.
        if not weight.is_floating_point() or weight.dtype == target_dtype:
            continue
        state_dict[name] = weight.to(dtype=target_dtype)
        converted_count += 1

    print(f"Converted {converted_count} tensors")

    print(f"Saving to: {output_path}")
    save_file(state_dict, output_path)
    print("Done!")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments and run the dtype conversion."""
    parser = argparse.ArgumentParser(
        description="Convert safetensors file dtype with inspection and detection"
    )
    parser.add_argument(
        "input_path", type=str, help="Path to input safetensors file"
    )
    parser.add_argument(
        "output_path", type=str, help="Path to output safetensors file"
    )
    parser.add_argument(
        "--target-dtype",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2"],
        help="Target dtype for conversion (default: float16)",
    )
    args = parser.parse_args()

    # Every allowed choice name is also a torch dtype attribute
    # (torch.float16, torch.float8_e4m3fn, ...), so resolve it directly.
    target_dtype = getattr(torch, args.target_dtype)

    convert_dtype(args.input_path, args.output_path, target_dtype)
|
|
|
|
|
|
|
|
# Fix: markdown rendering stripped the dunder underscores, turning the
# standard entry guard into `if name == "main":`, which raises NameError
# when the script runs. Restore the correct guard.
if __name__ == "__main__":
    main()
|
|
``` |