#------------------------------------------------------------------------------
# For each language: load a Wav2Vec2 CTC checkpoint with 4-bit bitsandbytes
# quantization, drop the top N encoder layers, and save the resulting model
# and processor under <output_dir>/<lang>.
#------------------------------------------------------------------------------

import os
import glob
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Wav2Vec2ForCTC, Wav2Vec2Processor
from safetensors.torch import save_file as safe_save_file
from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--input_dir', default='./models-2', type=str, help='Input directory')
parser.add_argument('--output_dir', default='./models-5-final', type=str, help='Output directory')
parser.add_argument('--n_top_layers_to_remove', default=3, type=int, help='Number of top layers to remove')
parser.add_argument('--bnb_4bit_quant_type', default='nf4', type=str, help='Quantization type: nf4 or fp4')
parser.add_argument('--save_quantized_adapter', default=0, type=int, choices=[0, 1], help='Whether to save quantized adapter')
args = parser.parse_args()

# A negative count would silently flip the meaning of the slice below
# (keeping only the first |n| layers), so reject it up front.
if args.n_top_layers_to_remove < 0:
    parser.error('--n_top_layers_to_remove must be >= 0')

# Echo the effective configuration, one "name value" pair per line.
for arg_name, arg_value in vars(args).items():
    if '__' not in arg_name:
        print(f'{arg_name:<25} {arg_value}')

#------------------------------------------------------------------------------
#------------------------------------------------------------------------------

langs = [
    'ady', 'aln', 'bas', 'bew', 'bxk', 'cgg', 'el-CY', 'hch', 'kbd',
    'kcn', 'koo', 'led', 'lke', 'lth', 'meh', 'mmc', 'pne', 'qxp',
    'ruc', 'rwm', 'sco', 'tob', 'top', 'ttj', 'ukv', 'ush',
]

for lang in langs:

    # Locate the checkpoint directory for this language. Matches are sorted so
    # that the choice is deterministic when more than one checkpoint exists
    # (the lexicographically first one is used).
    checkpoint_pattern = os.path.join(args.input_dir, lang, 'checkpoint*')
    checkpoint_dirs = sorted(glob.glob(checkpoint_pattern))
    if not checkpoint_dirs:
        raise FileNotFoundError(
            f"No checkpoint found for language '{lang}' (pattern: {checkpoint_pattern})"
        )
    orig_model_dir = checkpoint_dirs[0]

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # Load orig model (quantized on load) and its processor
    model = Wav2Vec2ForCTC.from_pretrained(orig_model_dir, quantization_config=bnb_config)
    processor = Wav2Vec2Processor.from_pretrained(orig_model_dir, target_lang=lang)
    processor.tokenizer.set_target_lang(lang)

    # Remove the top layers. The number of layers to keep is computed
    # explicitly: slicing with `[:-n]` would yield an empty list for n == 0
    # (since -0 == 0) and drop every layer instead of none.
    encoder_layers = model.wav2vec2.encoder.layers
    n_layers_to_keep = max(len(encoder_layers) - args.n_top_layers_to_remove, 0)
    model.wav2vec2.encoder.layers = encoder_layers[:n_layers_to_keep]
    # Keep the saved config consistent with the pruned encoder.
    model.config.num_hidden_layers = len(model.wav2vec2.encoder.layers)

    # Save quantized model
    output_dir = os.path.join(args.output_dir, lang)
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Currently the adapter mechanism for quantized models does not work, so
    # the adapter is not saved unless explicitly requested.
    # Save quantized adapter
    if args.save_quantized_adapter:
        adapter_file = os.path.join(output_dir, WAV2VEC2_ADAPTER_SAFE_FILE.format(lang))
        safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})

    print(output_dir)

#------------------------------------------------------------------------------
#------------------------------------------------------------------------------