Replace .pt model with safetensors format (using Git LFS)
Browse files- Remove chordia_v0.0.1-alpha.pt
- Add chordia_v0.0.1-alpha.safetensors (safer format, stored via Git LFS)
- Add conversion script for future use
- Configure Git LFS for safetensors files
Benefits of safetensors:
- More secure (avoids pickle security risks)
- Faster loading speed
- Zero-copy capability
- Better for sharing on Hugging Face
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
.gitattributes
CHANGED
|
@@ -5,3 +5,4 @@
|
|
| 5 |
*.data filter=lfs diff=lfs merge=lfs -text
|
| 6 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 7 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 5 |
*.data filter=lfs diff=lfs merge=lfs -text
|
| 6 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 7 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
chordia_v0.0.1-alpha.pt → chordia_v0.0.1-alpha.safetensors
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7a6d014049ebb2e80bfa97485370187a734a383e80b3805588abc3bd0415673
|
| 3 |
+
size 675556
|
convert_to_safetensors.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
将 PyTorch 模型转换为 safetensors 格式
|
| 3 |
+
|
| 4 |
+
safetensors 优势:
|
| 5 |
+
- 更安全(避免 pickle 安全风险)
|
| 6 |
+
- 加载速度更快
|
| 7 |
+
- 支持零拷贝
|
| 8 |
+
- 更适合在 Hugging Face 等平台共享
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from safetensors.torch import save_file
|
| 12 |
+
import torch
|
| 13 |
+
import os
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def convert_to_safetensors(
|
| 18 |
+
input_path: str,
|
| 19 |
+
output_path: str = None,
|
| 20 |
+
metadata: dict = None
|
| 21 |
+
):
|
| 22 |
+
"""
|
| 23 |
+
将 PyTorch 模型转换为 safetensors 格式
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
input_path: 输入的 .pt 或 .pth 文件路径
|
| 27 |
+
output_path: 输出的 .safetensors 文件路径(可选)
|
| 28 |
+
metadata: 要保存的元数据(可选)
|
| 29 |
+
"""
|
| 30 |
+
# 检查输入文件是否存在
|
| 31 |
+
if not os.path.exists(input_path):
|
| 32 |
+
raise FileNotFoundError(f"输入文件不存在: {input_path}")
|
| 33 |
+
|
| 34 |
+
# 如果没有指定输出路径,自动生成
|
| 35 |
+
if output_path is None:
|
| 36 |
+
input_path_obj = Path(input_path)
|
| 37 |
+
output_path = input_path_obj.with_suffix('.safetensors')
|
| 38 |
+
|
| 39 |
+
print(f"正在加载模型: {input_path}")
|
| 40 |
+
|
| 41 |
+
# 尝试检测文件类型并加载
|
| 42 |
+
model_weights = None
|
| 43 |
+
|
| 44 |
+
# 首先尝试作为 TorchScript 模型加载
|
| 45 |
+
try:
|
| 46 |
+
print("尝试加载 TorchScript 模型...")
|
| 47 |
+
model = torch.jit.load(input_path, map_location='cpu')
|
| 48 |
+
print("成功加载 TorchScript 模型,提取 state_dict...")
|
| 49 |
+
model_weights = model.state_dict()
|
| 50 |
+
except:
|
| 51 |
+
# 如果失败,尝试作为普通 state_dict 加载
|
| 52 |
+
try:
|
| 53 |
+
print("尝试加载普通 state_dict...")
|
| 54 |
+
model_weights = torch.load(input_path, map_location='cpu', weights_only=False)
|
| 55 |
+
except Exception as e:
|
| 56 |
+
raise RuntimeError(f"无法加载模型文件: {e}")
|
| 57 |
+
|
| 58 |
+
# 如果是完整的模型(包含 state_dict),提取 state_dict
|
| 59 |
+
if isinstance(model_weights, dict) and 'state_dict' in model_weights:
|
| 60 |
+
print("检测到完整模型,提取 state_dict...")
|
| 61 |
+
model_weights = model_weights['state_dict']
|
| 62 |
+
|
| 63 |
+
print(f"正在保存为 safetensors 格式: {output_path}")
|
| 64 |
+
|
| 65 |
+
# 添加默认元数据
|
| 66 |
+
if metadata is None:
|
| 67 |
+
metadata = {}
|
| 68 |
+
|
| 69 |
+
# 保存为 safetensors
|
| 70 |
+
save_file(model_weights, output_path, metadata=metadata)
|
| 71 |
+
|
| 72 |
+
print(f"[OK] 转换完成!")
|
| 73 |
+
print(f" 输入文件: {input_path}")
|
| 74 |
+
print(f" 输出文件: {output_path}")
|
| 75 |
+
|
| 76 |
+
# 显示文件大小对比
|
| 77 |
+
input_size = os.path.getsize(input_path) / (1024 * 1024)
|
| 78 |
+
output_size = os.path.getsize(output_path) / (1024 * 1024)
|
| 79 |
+
print(f"\n文件大小对比:")
|
| 80 |
+
print(f" 原始文件: {input_size:.2f} MB")
|
| 81 |
+
print(f" safetensors: {output_size:.2f} MB")
|
| 82 |
+
print(f" 压缩率: {(1 - output_size/input_size) * 100:.1f}%")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
+
# 转换当前目录下的模型
|
| 87 |
+
input_model = "chordia_v0.0.1-alpha.pt"
|
| 88 |
+
output_model = "chordia_v0.0.1-alpha.safetensors"
|
| 89 |
+
|
| 90 |
+
convert_to_safetensors(
|
| 91 |
+
input_path=input_model,
|
| 92 |
+
output_path=output_model,
|
| 93 |
+
metadata={
|
| 94 |
+
"model_name": "Chordia",
|
| 95 |
+
"version": "v0.0.1-alpha",
|
| 96 |
+
"format": "safetensors"
|
| 97 |
+
}
|
| 98 |
+
)
|