maotao / src /mRNA2vec /distillation.py
julse's picture
upload AA2CDS
4707555 verified
def init_teacher_model():
    """Load the pretrained ERNIE-RNA encoder to serve as the distillation teacher.

    Reads configuration from the module-level ``args`` namespace:
    ``arg_overrides['data']`` (directory containing ``dict.txt``),
    ``mlm_pretrained_model_path``, ``debug``, and ``device``.

    Returns:
        tuple: ``(model, tokenizer)`` — the encoder moved to ``args.device``,
        and the fairseq ``Dictionary`` with a ``<mask>`` symbol appended.
    """
    # Build the tokenizer from the pretraining vocabulary and register the
    # mask token used by the MLM objective.
    vocab_path = args.arg_overrides['data'] + '/dict.txt'
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.add_symbol('<mask>')

    # The pretrained wrapper exposes the transformer under `.encoder`;
    # only the encoder is used as the teacher.
    model_pre = load_pretrained_ernierna(args.mlm_pretrained_model_path, args.arg_overrides)
    model = model_pre.encoder

    if args.debug:
        print('debug mode')
        # Debug mode: truncate to the first encoder layer so iteration is fast.
        # (Original comment claimed "keep 12 layers" — the code keeps 1.)
        num_layers_to_keep = 1
        model.sentence_encoder.layers = model.sentence_encoder.layers[:num_layers_to_keep]

    # Log teacher size (message intentionally kept in the original language —
    # it is a runtime string) and the vocabulary size.
    Logger(f'教师模型(LLM)总参数量:{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万, vocab_size={len(tokenizer)}')
    model = model.to(args.device)
    print(model)
    return model, tokenizer