File size: 442 Bytes
fa64206
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
from transformers import BertForSequenceClassification, AdapterConfig

def get_pert_model(config, model_name='bert-base-uncased', adapter_name='imdb_adapter'):
    """Build a BERT sequence classifier with a single trainable adapter.

    Loads the pretrained ``model_name`` checkpoint, adds an adapter named
    ``adapter_name`` (with ``mh_adapter`` and ``output_adapter`` enabled) and
    selects it for training via ``train_adapter``.

    Args:
        config: Nested mapping. ``config['model']['adapter']['reduction_factor']``
            is required; ``config['model']['adapter']['non_linearity']`` is
            optional and defaults to ``'swish'``.
        model_name: Checkpoint passed to ``from_pretrained``. Defaults to
            ``'bert-base-uncased'``.
        adapter_name: Name under which the adapter is registered and trained.
            Defaults to ``'imdb_adapter'``.

    Returns:
        The ``BertForSequenceClassification`` instance with the adapter added
        and set up for training.

    Raises:
        KeyError: If ``config`` lacks the ``model`` / ``adapter`` /
            ``reduction_factor`` keys (checked before the model is loaded).
    """
    # Read the config first so a malformed config fails before the
    # (potentially slow) checkpoint download.
    adapter_section = config['model']['adapter']
    reduction_factor = adapter_section['reduction_factor']
    # In adapter-transformers, AdapterConfig declares ``non_linearity`` as a
    # required field with no default; omitting it raises TypeError. 'swish'
    # matches HoulsbyConfig, the preset using this same
    # mh_adapter + output_adapter layout.
    non_linearity = adapter_section.get('non_linearity', 'swish')

    model = BertForSequenceClassification.from_pretrained(model_name)
    adapter_config = AdapterConfig(
        mh_adapter=True,
        output_adapter=True,
        reduction_factor=reduction_factor,
        non_linearity=non_linearity,
    )
    model.add_adapter(adapter_name, config=adapter_config)
    # NOTE(review): per the adapter-transformers docs, train_adapter activates
    # this adapter and freezes the base model weights; confirm for the
    # installed version.
    model.train_adapter(adapter_name)
    return model