# whisper-small-uyghur / fix_model_config.py
# Uploaded by anke01 via huggingface_hub (commit 54a1d9a, verified)
#!/usr/bin/env python3
"""
自动修复模型配置,确保所有必要的配置都正确
"""
import json
import os
import re

from transformers import WhisperProcessor, WhisperForConditionalGeneration
def fix_model_config(model_path="."):
    """Repair the configuration files of a fine-tuned Whisper model.

    Three steps:
      1. Scan the tokenizer vocabulary for language tokens (``<|xx|>``).
      2. Add any missing language tokens to the ``lang_to_id`` mapping in
         ``generation_config.json`` (existing entries are never overwritten).
      3. Create/refresh ``language_info.json`` with the detected languages.

    Args:
        model_path: Directory holding the model files (default: current dir).
    """
    print("=" * 70)
    print("自动修复模型配置")
    print("=" * 70)
    print(f"\n模型路径: {model_path}")

    # Load the processor (source of the tokenizer).  The model weights are
    # loaded too, but only as a sanity check that they are readable — the
    # returned object was never used, so we drop the dead binding.
    print("\n加载模型...")
    processor = WhisperProcessor.from_pretrained(model_path)
    WhisperForConditionalGeneration.from_pretrained(model_path)

    vocab = processor.tokenizer.get_vocab()
    lang_tokens = _find_language_tokens(vocab)
    print(f"\n发现 {len(lang_tokens)} 个语言token")

    # 1. Fix generation_config.json
    _update_generation_config(model_path, lang_tokens)

    # 2. Create/update language_info.json
    _write_language_info(model_path, vocab, lang_tokens)

    print("\n" + "=" * 70)
    print("配置修复完成!")
    print("=" * 70)


def _find_language_tokens(vocab):
    """Return ``{token: id}`` for vocabulary entries that look like Whisper
    language tokens, e.g. ``<|zh|>`` or ``<|yue|>``.

    A token qualifies when it has the ``<|...|>`` shape, is not a known
    control token, and its inner code is 2-4 lowercase ASCII letters.
    (The original's extra ``len(code) <= 4`` check was redundant — the
    regex already bounds the length.)
    """
    control_tokens = {
        '<|startoftranscript|>', '<|endoftext|>', '<|notimestamps|>',
        '<|transcribe|>', '<|translate|>', '<|beginoftext|>',
        '<|startofprev|>', '<|nospeech|>', '<|nocaptions|>',
    }
    # Compile once instead of re-matching an inline pattern per token.
    code_pattern = re.compile(r'^[a-z]{2,4}$')
    lang_tokens = {}
    for token, token_id in vocab.items():
        if not (token.startswith("<|") and token.endswith("|>") and len(token) > 4):
            continue
        if token in control_tokens:
            continue
        code = token[2:-2]  # strip the <| and |> delimiters
        if code_pattern.match(code):
            lang_tokens[token] = token_id
    return lang_tokens


def _update_generation_config(model_path, lang_tokens):
    """Merge *lang_tokens* into ``lang_to_id`` of generation_config.json.

    Only missing tokens are added; existing mappings are preserved.
    Silently does nothing when the file is absent (matches original flow).
    """
    gen_config_path = os.path.join(model_path, "generation_config.json")
    if not os.path.exists(gen_config_path):
        return
    print("\n检查 generation_config.json...")
    with open(gen_config_path, 'r', encoding='utf-8') as f:
        gen_config = json.load(f)

    lang_to_id = gen_config.setdefault("lang_to_id", {})
    added = []
    for token, token_id in lang_tokens.items():
        if token not in lang_to_id:
            lang_to_id[token] = token_id
            added.append(f"{token}: {token_id}")

    if added:
        print(f" 添加了 {len(added)} 个语言token:")
        for item in added:
            print(f" - {item}")
    else:
        print(" ✅ 所有语言token已存在")

    with open(gen_config_path, 'w', encoding='utf-8') as f:
        json.dump(gen_config, f, indent=2, ensure_ascii=False)
    print(" ✅ generation_config.json 已更新")


def _write_language_info(model_path, vocab, lang_tokens):
    """Create or refresh language_info.json with the detected languages.

    Keys already present in an existing file are kept; the detection-related
    keys are overwritten with fresh values.  The original re-filtered codes
    like 'transcribe'/'translate' here, but those are all longer than four
    characters and can never appear in *lang_tokens* — dead code, removed.
    """
    lang_info_path = os.path.join(model_path, "language_info.json")
    lang_info = {}
    if os.path.exists(lang_info_path):
        with open(lang_info_path, 'r', encoding='utf-8') as f:
            lang_info = json.load(f)

    detected_langs = [
        {"code": token[2:-2], "token": token, "token_id": token_id}
        for token, token_id in lang_tokens.items()
    ]
    lang_info["detected_languages"] = detected_langs
    lang_info["total_languages"] = len(detected_langs)
    lang_info["vocab_size"] = len(vocab)

    with open(lang_info_path, 'w', encoding='utf-8') as f:
        json.dump(lang_info, f, indent=2, ensure_ascii=False)

    print("\n✅ language_info.json 已更新")
    print(f" 检测到 {len(detected_langs)} 个语言:")
    for lang in detected_langs[:10]:  # show at most the first 10
        print(f" - {lang['code']}: {lang['token']} (ID: {lang['token_id']})")
    if len(detected_langs) > 10:
        print(f" ... 还有 {len(detected_langs) - 10} 个")
if __name__ == "__main__":
    # Optional first CLI argument selects the model directory; default
    # to the current working directory when none is given.
    import sys

    cli_args = sys.argv[1:]
    fix_model_config(cli_args[0] if cli_args else ".")