Second-Pass / models /__init__.py
Ujjwal123's picture
Second Pass Model Runs fine
35290d0
def init_hf_bert_biencoder(args, **kwargs):
    """Build the BERT bi-encoder components.

    Delegates to ``hf_models.get_bert_biencoder_components``.

    Args:
        args: Configuration object passed through unchanged.
        **kwargs: Extra keyword arguments forwarded to the factory.

    Returns:
        Whatever ``get_bert_biencoder_components`` returns.
    """
    # The import runs on each call rather than at module import time
    # (presumably to defer loading the HF model code until it is needed).
    from .hf_models import get_bert_biencoder_components as build_components

    components = build_components(args, **kwargs)
    return components
def init_hf_distilbert_biencoder(args, **kwargs):
    """Build the DistilBERT bi-encoder components.

    Delegates to ``hf_models.get_distilbert_biencoder_components``.

    Args:
        args: Configuration object passed through unchanged.
        **kwargs: Extra keyword arguments forwarded to the factory.

    Returns:
        Whatever ``get_distilbert_biencoder_components`` returns.
    """
    # The import runs on each call rather than at module import time
    # (presumably to defer loading the HF model code until it is needed).
    from .hf_models import (
        get_distilbert_biencoder_components as build_components,
    )

    components = build_components(args, **kwargs)
    return components
def init_hf_bert_tenzorizer(args, **kwargs):
    """Build the BERT tensorizer.

    Delegates to ``hf_models.get_bert_tensorizer``.

    Args:
        args: Configuration object passed through unchanged.
        **kwargs: Accepted but NOT forwarded; only ``args`` reaches the
            factory (presumably kept so this function matches the call
            shape used by ``init_comp``).

    Returns:
        Whatever ``get_bert_tensorizer`` returns.
    """
    # The import runs on each call rather than at module import time.
    from .hf_models import get_bert_tensorizer as build_tensorizer

    tensorizer = build_tensorizer(args)
    return tensorizer
def init_hf_distilbert_tenzorizer(args, **kwargs):
    """Build the DistilBERT tensorizer.

    Delegates to ``hf_models.get_distilbert_tensorizer``.

    Args:
        args: Configuration object passed through unchanged.
        **kwargs: Accepted but NOT forwarded; only ``args`` reaches the
            factory (presumably kept so this function matches the call
            shape used by ``init_comp``).

    Returns:
        Whatever ``get_distilbert_tensorizer`` returns.
    """
    # The import runs on each call rather than at module import time.
    from .hf_models import get_distilbert_tensorizer as build_tensorizer

    tensorizer = build_tensorizer(args)
    return tensorizer
# Maps an encoder type name to the function that builds its bi-encoder
# components; looked up by ``init_biencoder_components``.
BIENCODER_INITIALIZERS = {
    'hf_bert': init_hf_bert_biencoder,
    'hf_distilbert': init_hf_distilbert_biencoder,
}
# Maps an encoder type name to the function that builds its tensorizer;
# looked up by ``init_tenzorizer``.
TENSORIZER_INITIALIZERS = {
    'hf_bert': init_hf_bert_tenzorizer,
    'hf_distilbert': init_hf_distilbert_tenzorizer,
}
def init_comp(initializers_dict, type, args, **kwargs):
    """Look up an initializer by type name and invoke it.

    Args:
        initializers_dict: Mapping from type name to a callable that takes
            ``(args, **kwargs)``.
        type: Key selecting which initializer to run. (The name shadows the
            builtin, but it is part of the public signature.)
        args: Positional argument handed to the selected initializer.
        **kwargs: Keyword arguments forwarded to the selected initializer.

    Returns:
        The value returned by the selected initializer.

    Raises:
        RuntimeError: If ``type`` is not a key of ``initializers_dict``.
    """
    # Guard clause: reject unknown types before touching the registry entry.
    if type not in initializers_dict:
        raise RuntimeError('unsupported model type: {}'.format(type))

    factory = initializers_dict[type]
    return factory(args, **kwargs)
def init_biencoder_components(encoder_type: str, args, **kwargs):
    """Build bi-encoder components for the given encoder type.

    Args:
        encoder_type: Key into ``BIENCODER_INITIALIZERS``.
        args: Configuration object handed to the selected initializer.
        **kwargs: Keyword arguments forwarded to the selected initializer.

    Returns:
        The value returned by the selected bi-encoder initializer.

    Raises:
        RuntimeError: Raised by ``init_comp`` when ``encoder_type`` is not
            a registered key.
    """
    registry = BIENCODER_INITIALIZERS
    return init_comp(registry, encoder_type, args, **kwargs)
def init_tenzorizer(encoder_type: str, args, **kwargs):
    """Build the tensorizer for the given encoder type.

    Args:
        encoder_type: Key into ``TENSORIZER_INITIALIZERS``.
        args: Configuration object handed to the selected initializer.
        **kwargs: Keyword arguments passed on to the selected initializer
            (the initializers registered in this module accept but do not
            use them).

    Returns:
        The value returned by the selected tensorizer initializer.

    Raises:
        RuntimeError: Raised by ``init_comp`` when ``encoder_type`` is not
            a registered key.
    """
    registry = TENSORIZER_INITIALIZERS
    return init_comp(registry, encoder_type, args, **kwargs)