#!/usr/bin/env python3
"""Auto-repair a Whisper model's configuration files.

Ensures the language-token bookkeeping around a saved Whisper checkpoint is
consistent: the ``lang_to_id`` map in ``generation_config.json`` and a
``language_info.json`` summary of every language token in the vocabulary.
"""

import json
import os
import re
import sys

# Tokens of the form <|...|> that are control tokens, never language markers.
_SPECIAL_TOKENS = frozenset([
    '<|startoftranscript|>', '<|endoftext|>', '<|notimestamps|>',
    '<|transcribe|>', '<|translate|>', '<|beginoftext|>',
    '<|startofprev|>', '<|nospeech|>', '<|nocaptions|>',
])

# A language code is 2-4 lowercase ASCII letters (e.g. "en", "yue").  This
# also rejects timestamp tokens such as <|0.00|> (digits/dots) and the long
# control-token names above.  The regex alone bounds the length, so the
# original's extra `len(code) <= 4` check was redundant and is dropped.
_LANG_CODE_RE = re.compile(r'^[a-z]{2,4}$')


def collect_language_tokens(vocab):
    """Return a ``{token: token_id}`` dict of the language tokens in *vocab*.

    A language token looks like ``<|en|>``: wrapped in ``<|`` / ``|>``, not
    one of the known control tokens, and with a 2-4 lowercase-letter code
    inside the delimiters.

    Args:
        vocab: mapping of token string -> integer token id
               (as returned by ``tokenizer.get_vocab()``).

    Returns:
        dict mapping each language token string to its token id.
    """
    lang_tokens = {}
    for token, token_id in vocab.items():
        # len(token) > 4 guarantees there is at least one char between <| and |>.
        if not (token.startswith("<|") and token.endswith("|>") and len(token) > 4):
            continue
        if token in _SPECIAL_TOKENS:
            continue
        lang_code = token[2:-2]  # strip the <| and |> delimiters
        if _LANG_CODE_RE.match(lang_code):
            lang_tokens[token] = token_id
    return lang_tokens


def _update_generation_config(model_path, lang_tokens):
    """Merge *lang_tokens* into ``generation_config.json``'s ``lang_to_id`` map.

    No-op when the file does not exist.  Existing entries are never
    overwritten; only missing tokens are added.  The file is rewritten with
    ``ensure_ascii=False`` so non-ASCII content survives round-trips.
    """
    gen_config_path = os.path.join(model_path, "generation_config.json")
    if not os.path.exists(gen_config_path):
        return

    print("\n检查 generation_config.json...")
    with open(gen_config_path, 'r', encoding='utf-8') as f:
        gen_config = json.load(f)

    # Make sure the mapping exists before we add to it.
    gen_config.setdefault("lang_to_id", {})

    added = []
    for token, token_id in lang_tokens.items():
        if token not in gen_config["lang_to_id"]:
            gen_config["lang_to_id"][token] = token_id
            added.append(f"{token}: {token_id}")

    if added:
        print(f"  添加了 {len(added)} 个语言token:")
        for item in added:
            print(f"    - {item}")
    else:
        print("  ✅ 所有语言token已存在")

    with open(gen_config_path, 'w', encoding='utf-8') as f:
        json.dump(gen_config, f, indent=2, ensure_ascii=False)
    print("  ✅ generation_config.json 已更新")


def _update_language_info(model_path, lang_tokens, vocab_size):
    """Create or update ``language_info.json`` summarizing detected languages.

    Any pre-existing keys in the file are preserved; the
    ``detected_languages`` / ``total_languages`` / ``vocab_size`` keys are
    (re)written from the current vocabulary.
    """
    lang_info_path = os.path.join(model_path, "language_info.json")
    lang_info = {}
    if os.path.exists(lang_info_path):
        with open(lang_info_path, 'r', encoding='utf-8') as f:
            lang_info = json.load(f)

    # NOTE: the original re-filtered out names like 'transcribe' here, but
    # those are 8+ characters and can never pass the 2-4 letter code regex,
    # so every entry of lang_tokens is already a genuine language token.
    detected_langs = [
        {"code": token[2:-2], "token": token, "token_id": token_id}
        for token, token_id in lang_tokens.items()
    ]

    lang_info["detected_languages"] = detected_langs
    lang_info["total_languages"] = len(detected_langs)
    lang_info["vocab_size"] = vocab_size

    with open(lang_info_path, 'w', encoding='utf-8') as f:
        json.dump(lang_info, f, indent=2, ensure_ascii=False)

    print("\n✅ language_info.json 已更新")
    print(f"  检测到 {len(detected_langs)} 个语言:")
    for lang in detected_langs[:10]:  # show at most the first 10
        print(f"    - {lang['code']}: {lang['token']} (ID: {lang['token_id']})")
    if len(detected_langs) > 10:
        print(f"    ... 还有 {len(detected_langs) - 10} 个")


def fix_model_config(model_path="."):
    """Auto-repair the model configuration under *model_path*.

    Steps:
      1. Load the processor (and model, as a load sanity check).
      2. Scan the tokenizer vocabulary for language tokens.
      3. Merge missing language tokens into ``generation_config.json``.
      4. Write/refresh ``language_info.json``.

    Args:
        model_path: directory containing the saved Whisper checkpoint
                    (default: current directory).
    """
    # Imported lazily so this module (and its pure helper
    # collect_language_tokens) can be imported without the heavy
    # transformers dependency installed.
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    print("=" * 70)
    print("自动修复模型配置")
    print("=" * 70)
    print(f"\n模型路径: {model_path}")

    print("\n加载模型...")
    processor = WhisperProcessor.from_pretrained(model_path)
    # NOTE(review): the loaded model is never used below; this call appears
    # to serve only as a sanity check that the checkpoint is readable.
    WhisperForConditionalGeneration.from_pretrained(model_path)

    vocab = processor.tokenizer.get_vocab()
    lang_tokens = collect_language_tokens(vocab)
    print(f"\n发现 {len(lang_tokens)} 个语言token")

    _update_generation_config(model_path, lang_tokens)
    _update_language_info(model_path, lang_tokens, len(vocab))

    print("\n" + "=" * 70)
    print("配置修复完成!")
    print("=" * 70)


if __name__ == "__main__":
    model_path = sys.argv[1] if len(sys.argv) > 1 else "."
    fix_model_config(model_path)