This is a Cross Encoder model fine-tuned from `cambridgeltl/SapBERT-from-PubMedBERT-fulltext` using the sentence-transformers library. It computes scores for pairs of texts, which can be used for text reranking and semantic search. Judging from its training samples, each pair combines a disease mention (together with the article title and a context sentence) with a candidate disease name.
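Cross-encoders are typically used as a second stage: a cheaper retriever proposes a shortlist, and the cross-encoder scores each (query, candidate) pair jointly to reorder it. The sketch below shows that pattern under stated assumptions. `retrieve` and `rerank` are hypothetical helpers, `difflib` fuzzy matching only stands in for a real first-stage retriever (this card does not say how candidates were generated), and `score_pairs` is any callable shaped like `CrossEncoder.predict`: a list of pairs in, one score per pair out. For this model the query text would presumably follow the "Original mention / Title / Context" layout of the training samples.

```python
import difflib
from typing import Callable, Sequence

# Any callable mapping a list of (query, candidate) pairs to one score per pair,
# e.g. model.predict from a loaded CrossEncoder.
PairScorer = Callable[[list[tuple[str, str]]], Sequence[float]]


def retrieve(query: str, concepts: Sequence[str], n: int = 10) -> list[str]:
    """Hypothetical first stage: cheap fuzzy string matching over concept names."""
    # cutoff=0.0 keeps every concept eligible, so min(n, len(concepts)) names come back
    return difflib.get_close_matches(query, list(concepts), n=n, cutoff=0.0)


def rerank(query: str, candidates: Sequence[str],
           score_pairs: PairScorer) -> list[tuple[str, float]]:
    """Hypothetical second stage: score every (query, candidate) pair, best first."""
    scores = score_pairs([(query, c) for c in candidates])
    return sorted(zip(candidates, map(float, scores)), key=lambda cs: cs[1], reverse=True)


if __name__ == "__main__":
    concepts = ["breast tumors", "x-linked adrenoleukodystrophy", "gm2 gangliosidosis, type 1"]
    shortlist = retrieve("breast cancer", concepts, n=2)

    def toy_scorer(pairs):
        # Word-overlap count so the sketch runs offline; swap in model.predict for real scores.
        return [float(len(set(q.split()) & set(c.split()))) for q, c in pairs]

    print(rerank("breast cancer", shortlist, toy_scorer))
```

Keeping the scorer as a plain callable means the same `rerank` function works with the real model or with a stub in tests.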
First install the Sentence Transformers library:

```bash
pip install -U sentence-transformers
```

Then you can load this model and run inference.
```python
from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("OverSamu/reranker-sapbert-ncbi-disease-bce-context-title")
# Get scores for pairs of texts
pairs = [
    ['Original mention: breast cancer.\nTitle: BRCA1 mutations in a population-based sample of young women with breast cancer.\nContext: We studied 80 women in whom breast cancer was diagnosed before the age of 35, and who were not selected on the basis of family history.', 'breast tumors'],
    ['Original mention: childhood cerebral ALD.\nTitle: Predominance of the adrenomyeloneuropathy phenotype of X-linked adrenoleukodystrophy in The Netherlands: a survey of 30 kindreds.\nContext: The phenotypic expression is highly variable, childhood cerebral ALD (CCALD) and adrenomyeloneuropathy (AMN) being the main variants.', 'x-linked adrenoleukodystrophy'],
    ['Original mention: TSD.\nTitle: The Tay-Sachs disease gene in North American Jewish populations: geographic variations and origin.\nContext: Jews with Polish and/or Russian ancestry constituted 88% of this sample and had a TSD carrier frequency of.', 'gm2 gangliosidosis, type 1'],
    ['Original mention: PWS.\nTitle: Isolation of molecular probes associated with the chromosome 15 instability in the Prader-Willi syndrome.\nContext: 2 and are shown to be deleted in DNA of one of two patients examined with the PWS.', 'syndrome, royer'],
    ['Original mention: deficiency of beta-glucocerebrosidase.\nTitle: Homozygous presence of the crossover (fusion gene) mutation identified in a type II Gaucher disease fetus: is this analogous to the Gaucher knock-out mouse model?\nGaucher disease (GD) is an inherited deficiency of beta-glucocerebrosidase (EC 3.\nContext: Homozygous presence of the crossover (fusion gene) mutation identified in a type II Gaucher disease fetus: is this analogous to the Gaucher knock-out mouse model?\nGaucher disease (GD) is an inherited deficiency of beta-glucocerebrosidase (EC 3.', 'gaucher disease, acute neuronopathic type'],
]
scores = model.predict(pairs)
print(scores.shape)
# (5,)

# Or rank different texts based on similarity to a single text
ranks = model.rank(
    'Original mention: breast cancer.\nTitle: BRCA1 mutations in a population-based sample of young women with breast cancer.\nContext: We studied 80 women in whom breast cancer was diagnosed before the age of 35, and who were not selected on the basis of family history.',
    [
        'breast tumors',
        'x-linked adrenoleukodystrophy',
        'gm2 gangliosidosis, type 1',
        'syndrome, royer',
        'gaucher disease, acute neuronopathic type',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
```
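The example queries follow a fixed layout: the mention, then the article title, then a context sentence, each on its own labelled line. The sketch below assumes, based only on those samples, that new queries should be assembled the same way before ranking candidate names. `build_query`, `link_mention` and the exact template are illustrative assumptions, not part of the model or the library. `model` can be the `CrossEncoder` loaded above, whose `rank` method returns results sorted by descending score.

```python
from typing import Any, Protocol, Sequence


class Ranker(Protocol):
    """Anything exposing a CrossEncoder-style rank(query, documents) method."""

    def rank(self, query: str, documents: list[str]) -> list[dict[str, Any]]: ...


def build_query(mention: str, title: str, context: str) -> str:
    """Hypothetical helper: mirror the 'Original mention / Title / Context' layout
    seen in the training samples (title and context keep their own punctuation)."""
    return f"Original mention: {mention}.\nTitle: {title}\nContext: {context}"


def link_mention(model: Ranker, mention: str, title: str, context: str,
                 candidates: Sequence[str]) -> tuple[str, float]:
    """Hypothetical helper: return the top-ranked candidate name and its score."""
    ranked = model.rank(build_query(mention, title, context), list(candidates))
    best = ranked[0]  # rank() sorts its results by descending score
    return candidates[best["corpus_id"]], float(best["score"])
```

With the model loaded as above, `link_mention(model, "breast cancer", title, context, ["breast tumors", "syndrome, royer"])` would return whichever name scores higher, together with its score.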
Dataset: `ncbi-disease-dev`, evaluated with `CrossEncoderRerankingEvaluator` using these parameters:

```json
{
  "at_k": 10,
  "always_rerank_positives": false
}
```
| Metric | Value |
|---|---|
| map | 0.9965 (+0.5546) |
| mrr@10 | 0.9981 (+0.7243) |
| ndcg@10 | 0.9979 (+0.4192) |
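For reference, each of these metrics can be computed per query from the binary relevance labels of the reranked list, and a dataset-level score is the mean over queries. The sketch below uses the textbook definitions (binary gains, log2 discount, cutoff `k` playing the role of `at_k`); it is not the `CrossEncoderRerankingEvaluator` source, so its edge-case handling may differ.

```python
import math
from typing import Sequence


def average_precision(labels: Sequence[int]) -> float:
    """AP over a ranked list of binary labels (1 = relevant), best-ranked first.
    Assumes every relevant item appears somewhere in the list."""
    hits, total = 0, 0.0
    for rank, rel in enumerate(labels, start=1):
        if rel:
            hits += 1
            total += hits / rank  # precision at this relevant position
    return total / hits if hits else 0.0


def mrr_at_k(labels: Sequence[int], k: int = 10) -> float:
    """Reciprocal rank of the first relevant item within the top k, else 0."""
    for rank, rel in enumerate(labels[:k], start=1):
        if rel:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(labels: Sequence[int], k: int = 10) -> float:
    """nDCG@k with binary gains and a log2(rank + 1) discount."""
    def dcg(rels: Sequence[int]) -> float:
        return sum(rel / math.log2(i + 2) for i, rel in enumerate(rels[:k]))

    ideal = dcg(sorted(labels, reverse=True))  # best possible ordering
    return dcg(labels) / ideal if ideal else 0.0
```

For a ranking whose labels are `[0, 1, 0, 1]`, AP and MRR@10 are both 0.5, whereas a list with its single relevant item in first place scores 1.0 on all three metrics.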
Columns: `query`, `answer`, and `label`

| | query | answer | label |
|---|---|---|---|
| type | string | string | int |

Samples:

| query | answer | label |
|---|---|---|
| Original mention: breast cancer. | breast tumors | 1 |
| Original mention: childhood cerebral ALD. | x-linked adrenoleukodystrophy | 1 |
| Original mention: TSD. | gm2 gangliosidosis, type 1 | 1 |
Loss: `BinaryCrossEntropyLoss` with these parameters:

```json
{
  "activation_fn": "torch.nn.modules.linear.Identity",
  "pos_weight": 0.7793161273002625
}
```
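To our understanding, `BinaryCrossEntropyLoss` with an identity `activation_fn` passes the model's raw logits to PyTorch's `BCEWithLogitsLoss`, which applies the sigmoid internally. PyTorch documents the per-pair loss with positive-class weight p as l = -[p * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))], so a `pos_weight` of about 0.78 scales down the loss on positive pairs (label 1) relative to negative pairs (label 0). The pure-Python sketch below reproduces that formula with mean reduction for illustration only; it is not the library's implementation, and the card does not say how the `pos_weight` value was chosen.

```python
import math
from typing import Sequence

POS_WEIGHT = 0.7793161273002625  # pos_weight from the loss configuration above


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def weighted_bce_with_logits(logits: Sequence[float], labels: Sequence[int],
                             pos_weight: float = POS_WEIGHT) -> float:
    """Mean of -[p*y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] over the batch."""
    total = 0.0
    for x, y in zip(logits, labels):
        # log(1 - sigmoid(x)) == log(sigmoid(-x)), which avoids computing 1 - sigmoid(x)
        total += -(pos_weight * y * log_sigmoid(x) + (1 - y) * log_sigmoid(-x))
    return total / len(logits)
```

At a logit of 0 (probability 0.5), a positive pair contributes about 0.78 × log 2 ≈ 0.54, while a negative pair contributes log 2 ≈ 0.69.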
Non-default hyperparameters:

- `eval_strategy`: steps
- `per_device_train_batch_size`: 128
- `per_device_eval_batch_size`: 128
- `learning_rate`: 2e-05
- `num_train_epochs`: 2
- `warmup_ratio`: 0.05
- `seed`: 12
- `bf16`: True
- `dataloader_num_workers`: 4
- `load_best_model_at_end`: True

All hyperparameters:

- `overwrite_output_dir`: False
- `do_predict`: False
- `eval_strategy`: steps
- `prediction_loss_only`: True
- `per_device_train_batch_size`: 128
- `per_device_eval_batch_size`: 128
- `per_gpu_train_batch_size`: None
- `per_gpu_eval_batch_size`: None
- `gradient_accumulation_steps`: 1
- `eval_accumulation_steps`: None
- `torch_empty_cache_steps`: None
- `learning_rate`: 2e-05
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `max_grad_norm`: 1.0
- `num_train_epochs`: 2
- `max_steps`: -1
- `lr_scheduler_type`: linear
- `lr_scheduler_kwargs`: None
- `warmup_ratio`: 0.05
- `warmup_steps`: 0
- `log_level`: passive
- `log_level_replica`: warning
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `save_safetensors`: True
- `save_on_each_node`: False
- `save_only_model`: False
- `restore_callback_states_from_checkpoint`: False
- `no_cuda`: False
- `use_cpu`: False
- `use_mps_device`: False
- `seed`: 12
- `data_seed`: None
- `jit_mode_eval`: False
- `bf16`: True
- `fp16`: False
- `fp16_opt_level`: O1
- `half_precision_backend`: auto
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: None
- `local_rank`: 0
- `ddp_backend`: None
- `tpu_num_cores`: None
- `tpu_metrics_debug`: False
- `debug`: []
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 4
- `dataloader_prefetch_factor`: None
- `past_index`: -1
- `disable_tqdm`: False
- `remove_unused_columns`: True
- `label_names`: None
- `load_best_model_at_end`: True
- `ignore_data_skip`: False
- `fsdp`: []
- `fsdp_min_num_params`: 0
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `fsdp_transformer_layer_cls_to_wrap`: None
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `parallelism_config`: None
- `deepspeed`: None
- `label_smoothing_factor`: 0.0
- `optim`: adamw_torch_fused
- `optim_args`: None
- `adafactor`: False
- `group_by_length`: False
- `length_column_name`: length
- `project`: huggingface
- `trackio_space_id`: trackio
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `skip_memory_metrics`: True
- `use_legacy_prediction_loop`: False
- `push_to_hub`: False
- `resume_from_checkpoint`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_private_repo`: None
- `hub_always_push`: False
- `hub_revision`: None
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `include_inputs_for_metrics`: False
- `include_for_metrics`: []
- `eval_do_concat_batches`: True
- `fp16_backend`: auto
- `push_to_hub_model_id`: None
- `push_to_hub_organization`: None
- `mp_parameters`: 
- `auto_find_batch_size`: False
- `full_determinism`: False
- `torchdynamo`: None
- `ray_scope`: last
- `ddp_timeout`: 1800
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `include_tokens_per_second`: False
- `include_num_input_tokens_seen`: no
- `neftune_noise_alpha`: None
- `optim_target_modules`: None
- `batch_eval_metrics`: False
- `eval_on_start`: False
- `use_liger_kernel`: False
- `liger_kernel_config`: None
- `eval_use_gather_object`: False
- `average_tokens_across_devices`: True
- `prompts`: None
- `batch_sampler`: batch_sampler
- `multi_dataset_batch_sampler`: proportional
- `router_mapping`: {}
- `learning_rate_mapping`: {}

Training logs:

| Epoch | Step | Training Loss | ncbi-disease-dev_ndcg@10 |
|---|---|---|---|
| 0.0006 | 1 | 0.6083 | - |
| 0.0863 | 150 | 0.567 | - |
| 0.1725 | 300 | 0.3364 | - |
| 0.2588 | 450 | 0.2209 | - |
| 0.3450 | 600 | 0.1784 | - |
| 0.4313 | 750 | 0.1435 | - |
| 0.5175 | 900 | 0.1324 | - |
| 0.6038 | 1050 | 0.1137 | - |
| 0.6901 | 1200 | 0.103 | - |
| 0.7763 | 1350 | 0.0934 | - |
| 0.8626 | 1500 | 0.0842 | 0.9949 (+0.4162) |
| 0.9488 | 1650 | 0.0797 | - |
| 1.0351 | 1800 | 0.0695 | - |
| 1.1213 | 1950 | 0.0573 | - |
| 1.2076 | 2100 | 0.0613 | - |
| 1.2938 | 2250 | 0.0555 | - |
| 1.3801 | 2400 | 0.0504 | - |
| 1.4664 | 2550 | 0.0499 | - |
| 1.5526 | 2700 | 0.049 | - |
| 1.6389 | 2850 | 0.0489 | - |
| 1.7251 | 3000 | 0.0424 | 0.9979 (+0.4192) |
| 1.8114 | 3150 | 0.0411 | - |
| 1.8976 | 3300 | 0.0405 | - |
| 1.9839 | 3450 | 0.0405 | - |
| -1 | -1 | - | 0.9979 (+0.4192) |
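The run used `lr_scheduler_type: linear` with `learning_rate: 2e-05` and `warmup_ratio: 0.05`. To our understanding, the transformers linear schedule ramps the learning rate up linearly over ceil(warmup_ratio × total steps) steps and then decays it linearly to zero; the sketch below reproduces that shape in plain Python. The total step count is not stated in the card: the log above (step 3450 at epoch 1.9839) suggests about 1,739 steps per epoch, roughly 3,478 steps over 2 epochs, which is an inference rather than a reported value.

```python
import math

BASE_LR = 2e-05      # learning_rate
WARMUP_RATIO = 0.05  # warmup_ratio


def warmup_steps(total_steps: int, ratio: float = WARMUP_RATIO) -> int:
    """Warmup length when only warmup_ratio is set (rounded up)."""
    return math.ceil(total_steps * ratio)


def linear_lr(step: int, total_steps: int, base_lr: float = BASE_LR) -> float:
    """Learning rate at `step` for a linear schedule with linear warmup."""
    warmup = warmup_steps(total_steps)
    if step < warmup:
        return base_lr * step / max(1, warmup)  # ramp up from 0
    # decay linearly from base_lr to 0 at the final step
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup))


if __name__ == "__main__":
    total = 3478  # assumption: ~1,739 steps/epoch inferred from the training log, times 2 epochs
    for step in (0, 87, warmup_steps(total), total // 2, total):
        print(step, linear_lr(step, total))
```

Under that assumed total, warmup would last 174 steps, so the peak learning rate is reached well before the first logged evaluation at step 1500.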
BibTeX for Sentence Transformers:

```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```