| |
| |
|
|
| import os |
| import glob |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Wav2Vec2ForCTC, Wav2Vec2Processor |
| from safetensors.torch import save_file as safe_save_file |
| from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE |
| from argparse import ArgumentParser |
|
|
# Command-line options for the quantize-and-prune export pass.
parser = ArgumentParser()
parser.add_argument('--input_dir', default='./models-2', type=str, help='Input directory')
parser.add_argument('--output_dir', default='./models-5-final', type=str, help='Output directory')
parser.add_argument('--n_top_layers_to_remove', default=3, type=int, help='Number of top layers to remove')
# Restrict to the values the help text advertises so a typo fails immediately
# instead of deep inside model loading.
parser.add_argument('--bnb_4bit_quant_type', default='nf4', type=str, choices=['nf4', 'fp4'],
                    help='Quantization type: nf4 or fp4')
parser.add_argument('--save_quantized_adapter', default=0, type=int, choices=[0, 1], help='Whether to save quantized adapter')
args = parser.parse_args()

# Reject a negative layer count before any model is loaded. The upper bound
# depends on each checkpoint's depth and is checked per language.
if args.n_top_layers_to_remove < 0:
    parser.error('--n_top_layers_to_remove must be >= 0')

# Echo the effective configuration. A Namespace's vars() holds only the parsed
# options (no dunder entries), so no filtering is needed.
for name, value in vars(args).items():
    print('%-25s %s' % (name, value))
|
|
| |
| |
|
|
# Language codes to export, in alphabetical order. Each code names a
# sub-directory under --input_dir and under --output_dir.
langs = (
    'ady aln bas bew bxk cgg el-CY hch kbd '
    'kcn koo led lke lth meh mmc pne qxp '
    'ruc rwm sco tob top ttj ukv ush'
).split()
|
|
def _checkpoint_step(path):
    """Return the trailing integer of a 'checkpoint-<step>' name, or -1 if it has none."""
    suffix = os.path.basename(os.path.normpath(path)).rpartition('-')[2]
    return int(suffix) if suffix.isdigit() else -1


def _latest_checkpoint_dir(model_root):
    """Return the 'checkpoint*' directory under *model_root* with the highest step.

    glob() returns entries in arbitrary order, so picking index 0 is not
    deterministic when several checkpoints exist.

    Raises:
        FileNotFoundError: if *model_root* contains no 'checkpoint*' directory.
    """
    candidates = [p for p in glob.glob(os.path.join(model_root, 'checkpoint*')) if os.path.isdir(p)]
    if not candidates:
        raise FileNotFoundError('No checkpoint* directory found in %s' % model_root)
    # Tie-break on the path so the choice is stable even for non-numeric names.
    return max(candidates, key=lambda p: (_checkpoint_step(p), p))


# The 4-bit quantization settings do not depend on the language, so build them once.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type=args.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

for lang in langs:
    orig_model_dir = _latest_checkpoint_dir(os.path.join(args.input_dir, lang))

    # Load the checkpoint quantized to 4 bits, plus its processor with the
    # tokenizer pinned to this language.
    model = Wav2Vec2ForCTC.from_pretrained(orig_model_dir, quantization_config=bnb_config)
    processor = Wav2Vec2Processor.from_pretrained(orig_model_dir, target_lang=lang)
    processor.tokenizer.set_target_lang(lang)

    # Drop the top N encoder layers. Slice with an explicit end index:
    # `layers[:-n]` evaluates to `layers[:0]` (an EMPTY encoder) when n == 0.
    encoder_layers = model.wav2vec2.encoder.layers
    n_remove = args.n_top_layers_to_remove
    if not 0 <= n_remove < len(encoder_layers):
        raise ValueError(
            '--n_top_layers_to_remove=%d must be in [0, %d) for %s'
            % (n_remove, len(encoder_layers), orig_model_dir)
        )
    model.wav2vec2.encoder.layers = encoder_layers[: len(encoder_layers) - n_remove]
    # Keep the saved config consistent with the pruned layer count.
    model.config.num_hidden_layers = len(model.wav2vec2.encoder.layers)

    output_dir = os.path.join(args.output_dir, lang)
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    if args.save_quantized_adapter:
        # NOTE: _get_adapters is a private transformers API; re-check on upgrades.
        adapter_file = os.path.join(output_dir, WAV2VEC2_ADAPTER_SAFE_FILE.format(lang))
        safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})

    print(output_dir)

    # Release this language's model before loading the next one, to keep peak
    # memory down. encoder_layers still references the removed top layers.
    del model, processor, encoder_layers
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
|
|
|
|
| |
| |
| |
| |
|
|
|
|