Spaces:
Runtime error
Runtime error
| def init_hf_bert_biencoder(args, **kwargs): | |
| from .hf_models import get_bert_biencoder_components | |
| return get_bert_biencoder_components(args, **kwargs) | |
| def init_hf_distilbert_biencoder(args, **kwargs): | |
| from .hf_models import get_distilbert_biencoder_components | |
| return get_distilbert_biencoder_components(args, **kwargs) | |
| def init_hf_bert_tenzorizer(args, **kwargs): | |
| from .hf_models import get_bert_tensorizer | |
| return get_bert_tensorizer(args) | |
| def init_hf_distilbert_tenzorizer(args, **kwargs): | |
| from .hf_models import get_distilbert_tensorizer | |
| return get_distilbert_tensorizer(args) | |
| BIENCODER_INITIALIZERS = { | |
| 'hf_bert': init_hf_bert_biencoder, | |
| 'hf_distilbert': init_hf_distilbert_biencoder | |
| } | |
| TENSORIZER_INITIALIZERS = { | |
| 'hf_bert': init_hf_bert_tenzorizer, | |
| 'hf_distilbert': init_hf_distilbert_tenzorizer | |
| } | |
| def init_comp(initializers_dict, type, args, **kwargs): | |
| if type in initializers_dict: | |
| return initializers_dict[type](args, **kwargs) | |
| else: | |
| raise RuntimeError('unsupported model type: {}'.format(type)) | |
| def init_biencoder_components(encoder_type: str, args, **kwargs): | |
| return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) | |
| def init_tenzorizer(encoder_type: str, args, **kwargs): | |
| return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) |