# vecxoz's picture
# Upload folder using huggingface_hub
# 386532a verified
# Raw
# History Blame Contribute Delete
# 3.25 kB
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
import os
import glob
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Wav2Vec2ForCTC, Wav2Vec2Processor
from safetensors.torch import save_file as safe_save_file
from transformers.models.wav2vec2.modeling_wav2vec2 import WAV2VEC2_ADAPTER_SAFE_FILE
from argparse import ArgumentParser
# Command-line interface: input/output locations, number of top encoder
# layers to drop, 4-bit quantization type, and whether to export the adapter.
parser = ArgumentParser()
parser.add_argument(
    '--input_dir',
    type=str,
    default='./models-2',
    help='Input directory',
)
parser.add_argument(
    '--output_dir',
    type=str,
    default='./models-5-final',
    help='Output directory',
)
parser.add_argument(
    '--n_top_layers_to_remove',
    type=int,
    default=3,
    help='Number of top layers to remove',
)
parser.add_argument(
    '--bnb_4bit_quant_type',
    type=str,
    default='nf4',
    help='Quantization type: nf4 or fp4',
)
parser.add_argument(
    '--save_quantized_adapter',
    type=int,
    choices=[0, 1],
    default=0,
    help='Whether to save quantized adapter',
)
args = parser.parse_args()

# Echo every parsed option as an aligned "name value" line.
for name, value in vars(args).items():
    if '__' in name:
        continue
    print('%-25s %s' % (name, value))
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
# Language codes to process. For each one the script looks for a
# 'checkpoint*' directory under <input_dir>/<lang>/ and writes the
# quantized, layer-trimmed model to <output_dir>/<lang>/.
langs = ['ady', 'aln', 'bas', 'bew', 'bxk', 'cgg', 'el-CY', 'hch', 'kbd',
         'kcn', 'koo', 'led', 'lke', 'lth', 'meh', 'mmc', 'pne', 'qxp',
         'ruc', 'rwm', 'sco', 'tob', 'top', 'ttj', 'ukv', 'ush']

# A negative count has no meaning here (it would silently change which
# layers are kept), so reject it before loading any model.
if args.n_top_layers_to_remove < 0:
    raise ValueError(
        '--n_top_layers_to_remove must be >= 0, got %d'
        % args.n_top_layers_to_remove
    )

for lang in langs:
    # Locate the original checkpoint and fail with a clear message if the
    # language has none (instead of a bare IndexError).
    checkpoint_pattern = os.path.join(args.input_dir, lang, 'checkpoint*')
    checkpoint_dirs = glob.glob(checkpoint_pattern)
    if not checkpoint_dirs:
        raise FileNotFoundError(
            'No checkpoint directory matching %r for language %r'
            % (checkpoint_pattern, lang)
        )
    # NOTE(review): if several checkpoints match, the first one returned by
    # glob is used; glob does not guarantee any particular order.
    orig_model_dir = checkpoint_dirs[0]

    # 4-bit quantization settings, built fresh for every model.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type=args.bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )

    # Load orig model
    model = Wav2Vec2ForCTC.from_pretrained(orig_model_dir, quantization_config=bnb_config)
    processor = Wav2Vec2Processor.from_pretrained(orig_model_dir, target_lang=lang)
    processor.tokenizer.set_target_lang(lang)

    # Remove layers
    # Slice by an explicit "keep" count: the form `layers[:-n]` turns into
    # `layers[:0]` when n == 0 and would drop every layer instead of none.
    encoder_layers = model.wav2vec2.encoder.layers
    n_layers_to_keep = max(len(encoder_layers) - args.n_top_layers_to_remove, 0)
    model.wav2vec2.encoder.layers = encoder_layers[:n_layers_to_keep]
    # Keep the saved config consistent with the trimmed encoder.
    model.config.num_hidden_layers = len(model.wav2vec2.encoder.layers)

    # Save quantized model
    output_dir = os.path.join(args.output_dir, lang)
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # Currently adapter mechanism for quantized models does not work, so the
    # adapter is not saved unless --save_quantized_adapter 1 is passed.
    # Save quantized adapter
    if args.save_quantized_adapter:
        adapter_file = WAV2VEC2_ADAPTER_SAFE_FILE.format(lang)
        adapter_file = os.path.join(output_dir, adapter_file)
        safe_save_file(model._get_adapters(), adapter_file, metadata={"format": "pt"})

    print(output_dir)
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------