"""
自动修复模型配置,确保所有必要的配置都正确
"""
|
import json
import os
import re

from transformers import WhisperProcessor, WhisperForConditionalGeneration
| |
|
# Special tokens that share the <|...|> shape but do not denote a language.
_NON_LANGUAGE_TOKENS = frozenset([
    '<|startoftranscript|>', '<|endoftext|>', '<|notimestamps|>',
    '<|transcribe|>', '<|translate|>', '<|beginoftext|>',
    '<|startofprev|>', '<|nospeech|>', '<|nocaptions|>',
])

# A language token wraps a 2-4 letter lowercase code, e.g. <|en|> or <|yue|>.
# (The {2,4} bound also makes a separate length check redundant.)
_LANG_CODE_RE = re.compile(r'^[a-z]{2,4}$')


def _collect_language_tokens(vocab):
    """Return ``{token: token_id}`` for vocabulary entries that look like
    Whisper language tokens (``<|xx|>`` with a 2-4 letter lowercase code).

    Args:
        vocab: mapping of token string -> integer id, as returned by
            ``tokenizer.get_vocab()``.
    """
    lang_tokens = {}
    for token, token_id in vocab.items():
        # Must look like a <|...|> special token with a non-empty payload.
        if not (token.startswith("<|") and token.endswith("|>") and len(token) > 4):
            continue
        if token in _NON_LANGUAGE_TOKENS:
            continue
        if _LANG_CODE_RE.match(token[2:-2]):
            lang_tokens[token] = token_id
    return lang_tokens


def _update_generation_config(model_path, lang_tokens):
    """Merge *lang_tokens* into ``lang_to_id`` of ``generation_config.json``.

    No-op when the file does not exist. Existing entries are never
    overwritten; the file is rewritten in place after merging.
    """
    gen_config_path = os.path.join(model_path, "generation_config.json")
    if not os.path.exists(gen_config_path):
        return

    print("\n检查 generation_config.json...")
    with open(gen_config_path, 'r', encoding='utf-8') as f:
        gen_config = json.load(f)

    lang_to_id = gen_config.setdefault("lang_to_id", {})

    # Only add tokens that are missing, so hand-edited mappings survive.
    added = []
    for token, token_id in lang_tokens.items():
        if token not in lang_to_id:
            lang_to_id[token] = token_id
            added.append(f"{token}: {token_id}")

    if added:
        print(f" 添加了 {len(added)} 个语言token:")
        for item in added:
            print(f" - {item}")
    else:
        print(" ✅ 所有语言token已存在")

    with open(gen_config_path, 'w', encoding='utf-8') as f:
        json.dump(gen_config, f, indent=2, ensure_ascii=False)
    print(" ✅ generation_config.json 已更新")


def _update_language_info(model_path, lang_tokens, vocab_size):
    """Create or refresh ``language_info.json`` under *model_path*.

    Keys already present in the file are preserved; ``detected_languages``,
    ``total_languages`` and ``vocab_size`` are (re)computed from the
    tokenizer vocabulary.
    """
    lang_info_path = os.path.join(model_path, "language_info.json")
    lang_info = {}
    if os.path.exists(lang_info_path):
        with open(lang_info_path, 'r', encoding='utf-8') as f:
            lang_info = json.load(f)

    # NOTE: the original additionally filtered out codes like 'transcribe'
    # here, but those can never appear — the 2-4 letter pattern used by
    # _collect_language_tokens already rejects them.
    detected_langs = [
        {"code": token[2:-2], "token": token, "token_id": token_id}
        for token, token_id in lang_tokens.items()
    ]

    lang_info["detected_languages"] = detected_langs
    lang_info["total_languages"] = len(detected_langs)
    lang_info["vocab_size"] = vocab_size

    with open(lang_info_path, 'w', encoding='utf-8') as f:
        json.dump(lang_info, f, indent=2, ensure_ascii=False)

    print(f"\n✅ language_info.json 已更新")
    print(f" 检测到 {len(detected_langs)} 个语言:")
    for lang in detected_langs[:10]:
        print(f" - {lang['code']}: {lang['token']} (ID: {lang['token_id']})")
    if len(detected_langs) > 10:
        print(f" ... 还有 {len(detected_langs) - 10} 个")


def fix_model_config(model_path="."):
    """Automatically repair a saved Whisper model's configuration.

    1. Detect language tokens in the tokenizer vocabulary.
    2. Ensure each one is mapped in ``lang_to_id`` of
       ``generation_config.json`` (if that file exists).
    3. Write/refresh ``language_info.json`` with the detected languages.

    Args:
        model_path: directory containing the saved processor/model files.
    """
    print("=" * 70)
    print("自动修复模型配置")
    print("=" * 70)
    print(f"\n模型路径: {model_path}")

    print("\n加载模型...")
    processor = WhisperProcessor.from_pretrained(model_path)
    # FIX: the original also loaded the full WhisperForConditionalGeneration
    # weights here and never used them; only the tokenizer vocabulary is
    # required, so that expensive load is dropped.
    vocab = processor.tokenizer.get_vocab()

    lang_tokens = _collect_language_tokens(vocab)
    print(f"\n发现 {len(lang_tokens)} 个语言token")

    _update_generation_config(model_path, lang_tokens)
    _update_language_info(model_path, lang_tokens, len(vocab))

    print("\n" + "=" * 70)
    print("配置修复完成!")
    print("=" * 70)
| |
|
if __name__ == "__main__":
    import sys

    # Use the first command-line argument as the model directory,
    # falling back to the current working directory.
    args = sys.argv[1:]
    fix_model_config(args[0] if args else ".")
| |
|