hzeng412 Claude Fable 5 committed on
Commit e73f2e4 · 1 Parent(s): a17b420

Add script to convert TTS .bin weights to safetensors

Needed on this branch: qwen-asr upgrades transformers to 4.57, which
refuses torch.load of .bin files on torch < 2.6 (CVE-2025-32434).
The converted model.safetensors files are generated locally and are
intentionally not committed (~840MB).

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>

scripts/convert_tts_weights_to_safetensors.py ADDED
@@ -0,0 +1,47 @@
+"""Convert the TTS pretrained weights (.bin) to safetensors.
+
+The qwen-asr branch upgrades transformers to 4.57+, whose security policy
+(CVE-2025-32434) refuses to load pytorch_model.bin on torch < 2.6. transformers
+prefers model.safetensors when loading, so one local conversion is enough and torch need not be upgraded.
+
+Usage: python scripts/convert_tts_weights_to_safetensors.py
+"""
+from pathlib import Path
+
+import torch
+from safetensors.torch import save_file
+
+MOYOYO_PRETRAINED_PATH = Path(__file__).parent.parent / "assets" / "models" / "tts" / "moyoyo"
+
+PRETRAINED_DIRS = [
+    "chinese-roberta-wwm-ext-large",
+    "chinese-hubert-base",
+]
+
+
+def main():
+    for dirname in PRETRAINED_DIRS:
+        model_dir = MOYOYO_PRETRAINED_PATH / dirname
+        bin_path = model_dir / "pytorch_model.bin"
+        st_path = model_dir / "model.safetensors"
+
+        if st_path.exists():
+            print(f"Already exists, skipping: {st_path}")
+            continue
+        if not bin_path.exists():
+            print(f"Weights file not found: {bin_path}")
+            continue
+
+        state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
+        # clone() breaks shared memory; safetensors does not allow tensors to share storage
+        state_dict = {
+            key: value.clone().contiguous()
+            for key, value in state_dict.items()
+            if isinstance(value, torch.Tensor)
+        }
+        save_file(state_dict, st_path, metadata={"format": "pt"})
+        print(f"{dirname}: {len(state_dict)} tensors -> {st_path.stat().st_size // 1024 ** 2} MB")
+
+
+if __name__ == "__main__":
+    main()
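
Since the converted model.safetensors files are not committed, it may help to sanity-check them locally after running the script. The following is a minimal sketch, not part of this commit, that inspects a converted file with the standard library only, so it does not depend on the installed torch version. It assumes the documented safetensors layout: an 8-byte little-endian unsigned header length, followed by a JSON header whose optional `__metadata__` key holds the metadata map. The helper names `read_safetensors_header` and `summarize` are hypothetical.

```python
import json
import struct
import sys


def read_safetensors_header(path):
    """Return (tensor_entries, metadata) parsed from a .safetensors header."""
    with open(path, "rb") as f:
        # The first 8 bytes hold the JSON header length (unsigned 64-bit, little-endian).
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is optional; every other key describes one tensor.
    metadata = header.pop("__metadata__", {})
    return header, metadata


def summarize(path):
    """Count tensors and sum their byte ranges without reading tensor data."""
    tensors, metadata = read_safetensors_header(path)
    data_bytes = sum(
        entry["data_offsets"][1] - entry["data_offsets"][0]
        for entry in tensors.values()
    )
    return {
        "num_tensors": len(tensors),
        "data_bytes": data_bytes,
        "metadata": metadata,
    }


if __name__ == "__main__":
    # Hypothetical usage: pass one or more model.safetensors paths as arguments.
    for arg in sys.argv[1:]:
        print(arg, summarize(arg))
```

If the conversion above ran as written, the reported metadata should be `{"format": "pt"}`, and the tensor count should match the number printed by the script for that directory.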