AkaneTendo25 commited on
Commit
c0756a0
·
verified ·
1 Parent(s): 28dc245

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +118 -0
README.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ Checkpoints for LoRA training with [musubi-tuner](https://github.com/kohya-ss/musubi-tuner) ([relevant PR](https://github.com/kohya-ss/musubi-tuner/pull/712))
6
+ Converted from the shards at https://huggingface.co/meituan-longcat/LongCat-Video/tree/main/dit using the following script:
7
+
8
+ ```python
9
+ import argparse
10
+ import itertools
11
+ import os
12
+ from musubi_tuner.utils.safetensors_utils import load_split_weights, MemoryEfficientSafeOpen
13
+ from safetensors.torch import save_file
14
+ import torch
15
+
16
+
17
def detect_dtype(path: str) -> torch.dtype:
    """Return the dtype of the first floating point tensor in a safetensors file.

    Falls back to the dtype of the first tensor when no floating point
    tensor is present.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the file contains no tensors.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"File not found: {path}")

    with MemoryEfficientSafeOpen(path) as handle:
        tensor_names = list(handle.keys())
        if not tensor_names:
            raise ValueError(f"No tensors found in {path}")

        # Prefer the dtype of a floating point tensor if any exists.
        for name in tensor_names:
            candidate = handle.get_tensor(name)
            if candidate.is_floating_point():
                return candidate.dtype

        # No floating point tensor found: fall back to the first tensor's dtype.
        return handle.get_tensor(tensor_names[0]).dtype
36
+
37
+
38
def list_keys(state_dict, num_keys=20):
    """Display the first N keys from the state dict."""
    print(f"\nTotal tensors: {len(state_dict)}")
    print(f"First {num_keys} keys:")
    # Take at most num_keys keys in insertion order.
    for key in list(state_dict)[:num_keys]:
        print(f" {key}")
    print()
45
+
46
+
47
def convert_dtype(input_path: str, output_path: str, target_dtype: torch.dtype):
    """Convert safetensors file to target dtype."""
    print(f"Loading from: {input_path}")

    # Report the detected source dtype alongside the requested target.
    current_dtype = detect_dtype(input_path)
    print(f"Detected input dtype: {current_dtype}")
    print(f"Target dtype: {target_dtype}")

    # Load the (possibly sharded) weights into a single state dict.
    state_dict = load_split_weights(input_path)

    # Show a preview of the tensor names before converting.
    list_keys(state_dict)

    print(f"Converting floating point tensors to {target_dtype}...")
    converted_count = 0
    # Iterate over a snapshot of the key list; only floating point tensors
    # that are not already in the target dtype are cast.
    for name in list(state_dict.keys()):
        weight = state_dict[name]
        if weight.is_floating_point() and weight.dtype != target_dtype:
            state_dict[name] = weight.to(dtype=target_dtype)
            converted_count += 1

    print(f"Converted {converted_count} tensors")

    print(f"Saving to: {output_path}")
    save_file(state_dict, output_path)
    print("Done!")
76
+
77
+
78
def main():
    """CLI entry point: parse arguments and run the dtype conversion."""
    parser = argparse.ArgumentParser(
        description="Convert safetensors file dtype with inspection and detection"
    )
    parser.add_argument(
        "input_path",
        type=str,
        help="Path to input safetensors file",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Path to output safetensors file",
    )
    parser.add_argument(
        "--target-dtype",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2"],
        help="Target dtype for conversion (default: float16)",
    )

    args = parser.parse_args()

    # Every accepted choice name matches the corresponding torch attribute
    # (torch.float16, torch.float8_e4m3fn, ...), so the string can be
    # resolved directly instead of through an explicit mapping table.
    target_dtype = getattr(torch, args.target_dtype)

    convert_dtype(args.input_path, args.output_path, target_dtype)
114
+
115
+
116
# NOTE(review): the page rendering stripped the dunder underscores, leaving
# `if name == "main":`, which raises NameError at runtime. Restore the
# standard script-entry guard.
if __name__ == "__main__":
    main()
118
+ ```