This is a Cross Encoder model fine-tuned from `cambridgeltl/SapBERT-from-PubMedBERT-fulltext` using the sentence-transformers library. It computes scores for pairs of texts, which can be used for text reranking and semantic search.
First install the Sentence Transformers library:

```bash
pip install -U sentence-transformers
```
Then you can load this model and run inference.
```python
from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("cross_encoder_model_id")
# Get scores for pairs of texts
pairs = [
    ['Original mention: L-N.\nContext: Partial HPRT deficiencies are associated with gouty arthritis, while absence of activity results in Lesch-Nyhan syndrome (L-N).', 'guanine phosphoribosyltransferase deficiencies'],
    ['Original mention: breast cancer.\nContext: Most multiple case families of young onset breast cancer and ovarian cancer are thought to be due to highly penetrant mutations in the predisposing genes BRCA1 and BRCA2.', 'malignancies'],
    ['Original mention: AS.\nContext: We report here the characterization of a transgene insertion (Epstein-Barr virus Latent Membrane Protein 2A, LMP2A) into mouse chromosome 7C, which has resulted in mouse models for PWS and AS dependent on the sex of the transmitting parent.', 'achondroplastic dwarfism'],
    ['Original mention: adenomatous polyposis coli.\nContext: Epidemiologic studies have shown an increased frequency of this tumor type in families affected by adenomatous polyposis coli.', 'apc (adenomatous polyposis coli)'],
    ['Original mention: autosomal recessive disorder.\nContext: Mucopolysaccharidosis IVA (MPS IVA) is an autosomal recessive disorder caused by a deficiency in N-acetylgalactosamine-6-sulfatase (GALNS).', 'susceptibility, disease'],
]
scores = model.predict(pairs)
print(scores.shape)
# (5,)

# Or rank different texts based on similarity to a single text
ranks = model.rank(
    'Original mention: L-N.\nContext: Partial HPRT deficiencies are associated with gouty arthritis, while absence of activity results in Lesch-Nyhan syndrome (L-N).',
    [
        'guanine phosphoribosyltransferase deficiencies',
        'malignancies',
        'achondroplastic dwarfism',
        'apc (adenomatous polyposis coli)',
        'susceptibility, disease',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
```
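`model.rank` can be understood as scoring every (query, document) pair and then sorting the results by descending score. The sketch below shows that idea in plain Python. `rank_by_score` and the word-overlap `toy_score` are hypothetical stand-ins and are not part of the sentence-transformers API. The real method may also differ in details such as batching, or options for returning only the top results.

```python
from typing import Callable, List, Sequence


def rank_by_score(
    query: str,
    documents: List[str],
    score_fn: Callable[[List[List[str]]], Sequence[float]],
) -> List[dict]:
    """Hypothetical helper: score each (query, document) pair and sort by descending score."""
    pairs = [[query, doc] for doc in documents]
    scores = score_fn(pairs)
    results = [{"corpus_id": i, "score": float(s)} for i, s in enumerate(scores)]
    # sorted() is stable, so documents with equal scores keep their input order
    return sorted(results, key=lambda r: r["score"], reverse=True)


def toy_score(pairs: List[List[str]]) -> List[float]:
    """Hypothetical stand-in for model.predict: number of words shared by query and document."""
    return [float(len(set(q.lower().split()) & set(d.lower().split()))) for q, d in pairs]


docs = ["gout and arthritis", "breast cancer", "arthritis"]
ranked = rank_by_score("gouty arthritis", docs, toy_score)
print(ranked)
```

Suppose `model` is loaded as in the usage example. Passing `model.predict` as `score_fn` should then give the same ordering as `model.rank` on the same inputs.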
Evaluated on the `ncbi-disease-dev` dataset with `CrossEncoderRerankingEvaluator`, using these parameters:

```json
{
    "at_k": 10,
    "always_rerank_positives": false
}
```
| Metric | Value |
|---|---|
| map | 0.9979 (+0.5536) |
| mrr@10 | 0.9986 (+0.7236) |
| ndcg@10 | 0.9987 (+0.4194) |
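The metrics in the table are standard ranking measures. The sketch below computes three of them for a single ranked list with binary relevance labels, using common textbook definitions:

- MRR@10
- average precision, which is averaged over queries to give MAP
- NDCG@10

The function names are hypothetical. `CrossEncoderRerankingEvaluator` may also differ in details, such as how it handles ties or queries that have no relevant candidate.

```python
import math
from typing import List


def mrr_at_k(relevance: List[int], k: int = 10) -> float:
    """Reciprocal rank of the first relevant item within the top k (0 if none)."""
    for rank, rel in enumerate(relevance[:k], start=1):
        if rel > 0:
            return 1.0 / rank
    return 0.0


def average_precision(relevance: List[int]) -> float:
    """Mean of precision@i over every position i that holds a relevant item."""
    hits, precisions = 0, []
    for i, rel in enumerate(relevance, start=1):
        if rel > 0:
            hits += 1
            precisions.append(hits / i)
    return sum(precisions) / len(precisions) if precisions else 0.0


def ndcg_at_k(relevance: List[int], k: int = 10) -> float:
    """DCG of the top k divided by the DCG of the ideal (relevance-sorted) ordering."""
    def dcg(rels: List[int]) -> float:
        return sum(rel / math.log2(i + 1) for i, rel in enumerate(rels, start=1))

    ideal = dcg(sorted(relevance, reverse=True)[:k])
    return dcg(relevance[:k]) / ideal if ideal > 0 else 0.0


# One query's ranked candidates: 1 = relevant, 0 = not relevant
ranking = [0, 1, 0, 1]
mrr, ap, ndcg = mrr_at_k(ranking), average_precision(ranking), ndcg_at_k(ranking)
print(mrr, ap, ndcg)
```

A perfect ranking, with every relevant candidate placed first, scores 1.0 on all three metrics. This is why values close to 1.0 indicate that the reranker almost always puts the correct concept at the top.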
Columns: `query`, `answer`, and `label`

| | query | answer | label |
|---|---|---|---|
| type | string | string | int |

Samples:

| query | answer | label |
|---|---|---|
| Original mention: L-N. | guanine phosphoribosyltransferase deficiencies | 1 |
| Original mention: breast cancer. | malignancies | 0 |
| Original mention: AS. | achondroplastic dwarfism | 0 |
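The example queries appear to follow a fixed template: `Original mention: <mention>.\nContext: <sentence>`, paired with a candidate concept name. This template is inferred from the examples in this card rather than documented. `build_query` and `build_pairs` are hypothetical helpers for producing inputs in the same shape.

```python
from typing import List


def build_query(mention: str, context: str) -> str:
    """Hypothetical helper: format a mention and its sentence like the example queries."""
    return f"Original mention: {mention}.\nContext: {context}"


def build_pairs(mention: str, context: str, candidates: List[str]) -> List[List[str]]:
    """Hypothetical helper: pair one formatted query with every candidate concept name."""
    query = build_query(mention, context)
    return [[query, name] for name in candidates]


pairs = build_pairs(
    "L-N",
    "Partial HPRT deficiencies are associated with gouty arthritis, "
    "while absence of activity results in Lesch-Nyhan syndrome (L-N).",
    ["guanine phosphoribosyltransferase deficiencies", "malignancies"],
)
print(pairs[0][0])
```

The resulting list has the same shape as `pairs` in the usage example. It can therefore be passed to `model.predict` directly.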
Loss: `BinaryCrossEntropyLoss` with these parameters:

```json
{
    "activation_fn": "torch.nn.modules.linear.Identity",
    "pos_weight": 0.7792357206344604
}
```
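Because `activation_fn` is `Identity`, the model's raw logit is passed straight into a sigmoid-plus-binary-cross-entropy loss. Assume the loss follows the usual `torch.nn.BCEWithLogitsLoss` formulation. Then `pos_weight` (p) scales the term for positive pairs:

`loss = -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]`

A value below 1, such as the 0.779 used here, down-weights positives relative to negatives. The pure-Python function below is a hypothetical sketch of that formula, not the library's implementation.

```python
import math
from typing import Sequence


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits(logits: Sequence[float], labels: Sequence[int], pos_weight: float = 1.0) -> float:
    """Hypothetical sketch: mean of -[p*y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]."""
    total = 0.0
    for x, y in zip(logits, labels):
        # log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x)
        total += pos_weight * y * softplus(-x) + (1 - y) * softplus(x)
    return total / len(logits)


loss = bce_with_logits([2.0, -1.0, 0.5], [1, 0, 0], pos_weight=0.7792357206344604)
print(loss)
```

The logits are handled through `softplus` rather than by computing `sigmoid(x)` and taking its log. This avoids `log(0)` when a logit is very large in magnitude.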
Non-default hyperparameters:

- `eval_strategy`: steps
- `per_device_train_batch_size`: 128
- `per_device_eval_batch_size`: 128
- `learning_rate`: 2e-05
- `warmup_ratio`: 0.05
- `seed`: 12
- `bf16`: True
- `dataloader_num_workers`: 4
- `load_best_model_at_end`: True

All hyperparameters:

- `overwrite_output_dir`: False
- `do_predict`: False
- `eval_strategy`: steps
- `prediction_loss_only`: True
- `per_device_train_batch_size`: 128
- `per_device_eval_batch_size`: 128
- `per_gpu_train_batch_size`: None
- `per_gpu_eval_batch_size`: None
- `gradient_accumulation_steps`: 1
- `eval_accumulation_steps`: None
- `torch_empty_cache_steps`: None
- `learning_rate`: 2e-05
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `max_grad_norm`: 1.0
- `num_train_epochs`: 3
- `max_steps`: -1
- `lr_scheduler_type`: linear
- `lr_scheduler_kwargs`: {}
- `warmup_ratio`: 0.05
- `warmup_steps`: 0
- `log_level`: passive
- `log_level_replica`: warning
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `save_safetensors`: True
- `save_on_each_node`: False
- `save_only_model`: False
- `restore_callback_states_from_checkpoint`: False
- `no_cuda`: False
- `use_cpu`: False
- `use_mps_device`: False
- `seed`: 12
- `data_seed`: None
- `jit_mode_eval`: False
- `bf16`: True
- `fp16`: False
- `fp16_opt_level`: O1
- `half_precision_backend`: auto
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: None
- `local_rank`: 0
- `ddp_backend`: None
- `tpu_num_cores`: None
- `tpu_metrics_debug`: False
- `debug`: []
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 4
- `dataloader_prefetch_factor`: None
- `past_index`: -1
- `disable_tqdm`: False
- `remove_unused_columns`: True
- `label_names`: None
- `load_best_model_at_end`: True
- `ignore_data_skip`: False
- `fsdp`: []
- `fsdp_min_num_params`: 0
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `fsdp_transformer_layer_cls_to_wrap`: None
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `parallelism_config`: None
- `deepspeed`: None
- `label_smoothing_factor`: 0.0
- `optim`: adamw_torch_fused
- `optim_args`: None
- `adafactor`: False
- `group_by_length`: False
- `length_column_name`: length
- `project`: huggingface
- `trackio_space_id`: trackio
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `skip_memory_metrics`: True
- `use_legacy_prediction_loop`: False
- `push_to_hub`: False
- `resume_from_checkpoint`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_private_repo`: None
- `hub_always_push`: False
- `hub_revision`: None
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `include_inputs_for_metrics`: False
- `include_for_metrics`: []
- `eval_do_concat_batches`: True
- `fp16_backend`: auto
- `push_to_hub_model_id`: None
- `push_to_hub_organization`: None
- `mp_parameters`:
- `auto_find_batch_size`: False
- `full_determinism`: False
- `torchdynamo`: None
- `ray_scope`: last
- `ddp_timeout`: 1800
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `include_tokens_per_second`: False
- `include_num_input_tokens_seen`: no
- `neftune_noise_alpha`: None
- `optim_target_modules`: None
- `batch_eval_metrics`: False
- `eval_on_start`: False
- `use_liger_kernel`: False
- `liger_kernel_config`: None
- `eval_use_gather_object`: False
- `average_tokens_across_devices`: True
- `prompts`: None
- `batch_sampler`: batch_sampler
- `multi_dataset_batch_sampler`: proportional
- `router_mapping`: {}
- `learning_rate_mapping`: {}

Training logs:

| Epoch | Step | Training Loss | ncbi-disease-dev_ndcg@10 |
|---|---|---|---|
| 0.0006 | 1 | 0.5977 | - |
| 0.0863 | 150 | 0.5711 | - |
| 0.1726 | 300 | 0.3596 | - |
| 0.2589 | 450 | 0.2451 | - |
| 0.3452 | 600 | 0.1933 | - |
| 0.4315 | 750 | 0.1629 | - |
| 0.5178 | 900 | 0.1448 | - |
| 0.6041 | 1050 | 0.1271 | - |
| 0.6904 | 1200 | 0.1102 | - |
| 0.7768 | 1350 | 0.0986 | - |
| 0.8631 | 1500 | 0.0927 | 0.9963 (+0.4169) |
| 0.9494 | 1650 | 0.0821 | - |
| 1.0357 | 1800 | 0.073 | - |
| 1.1220 | 1950 | 0.0641 | - |
| 1.2083 | 2100 | 0.0544 | - |
| 1.2946 | 2250 | 0.055 | - |
| 1.3809 | 2400 | 0.0556 | - |
| 1.4672 | 2550 | 0.0546 | - |
| 1.5535 | 2700 | 0.0514 | - |
| 1.6398 | 2850 | 0.0463 | - |
| 1.7261 | 3000 | 0.0416 | 0.9984 (+0.4190) |
| 1.8124 | 3150 | 0.043 | - |
| 1.8987 | 3300 | 0.0433 | - |
| 1.9850 | 3450 | 0.0425 | - |
| 2.0713 | 3600 | 0.0322 | - |
| 2.1577 | 3750 | 0.0272 | - |
| 2.2440 | 3900 | 0.0273 | - |
| 2.3303 | 4050 | 0.0274 | - |
| 2.4166 | 4200 | 0.0265 | - |
| 2.5029 | 4350 | 0.0285 | - |
| 2.5892 | 4500 | 0.0249 | 0.9987 (+0.4194) |
| 2.6755 | 4650 | 0.0263 | - |
| 2.7618 | 4800 | 0.0252 | - |
| 2.8481 | 4950 | 0.0256 | - |
| 2.9344 | 5100 | 0.0247 | - |
| -1 | -1 | - | 0.9987 (+0.4194) |
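The `lr_scheduler_type: linear` and `warmup_ratio: 0.05` settings imply the following schedule:

- During the first 5% of steps, the learning rate ramps up linearly to `learning_rate` (2e-05).
- It then decays linearly to zero.

The sketch below assumes the Hugging Face Transformers convention of `ceil(total_steps * warmup_ratio)` warmup steps. `linear_warmup_lr` is a hypothetical helper. The total of 5214 steps is only an estimate from the training logs: step 150 corresponds to epoch 0.0863, which gives about 1738 steps per epoch, times 3 epochs.

```python
import math


def linear_warmup_lr(step: int, base_lr: float, total_steps: int, warmup_ratio: float) -> float:
    """Hypothetical helper: linear warmup to base_lr, then linear decay to zero."""
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Warmup phase: scale up from 0 toward base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay phase: scale down from base_lr to 0 at total_steps
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


TOTAL_STEPS = 5214  # estimate from the training logs, not a reported value
for step in (0, 130, 261, 2600, 5214):
    print(step, linear_warmup_lr(step, 2e-05, TOTAL_STEPS, 0.05))
```

Under these assumptions, warmup lasts 261 steps, so the peak rate is reached well before the first logged evaluation at step 1500.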
```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```