Cambridge-SapBERT-from-PubMedBERT-fulltext trained on NCBI Disease

This is a Cross Encoder model finetuned from cambridgeltl/SapBERT-from-PubMedBERT-fulltext using the sentence-transformers library. It computes scores for pairs of texts, which can be used for text reranking and semantic search.

Model Details

Model Description

  • Model Type: Cross Encoder
  • Base model: cambridgeltl/SapBERT-from-PubMedBERT-fulltext
  • Training Dataset: NCBI Disease mention-candidate pairs (222,348 samples)
  • Language: English
  • Loss: BinaryCrossEntropyLoss

Model Sources

  • Documentation: https://www.sbert.net
  • Repository: https://github.com/UKPLab/sentence-transformers
Usage

Direct Usage (Sentence Transformers)

First install the Sentence Transformers library:

pip install -U sentence-transformers

Then you can load this model and run inference.

from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("OverSamu/reranker-sapbert-ncbi-disease-bce-context")
# Get scores for pairs of texts
pairs = [
    ['Original mention: L-N.\nContext: Partial HPRT deficiencies are associated with gouty arthritis, while absence of activity results in Lesch-Nyhan syndrome (L-N).', 'guanine phosphoribosyltransferase deficiencies'],
    ['Original mention: breast cancer.\nContext: Most multiple case families of young onset breast cancer and ovarian cancer are thought to be due to highly penetrant mutations in the predisposing genes BRCA1 and BRCA2.', 'malignancies'],
    ['Original mention: AS.\nContext: We report here the characterization of a transgene insertion (Epstein-Barr virus Latent Membrane Protein 2A, LMP2A) into mouse chromosome 7C, which has resulted in mouse models for PWS and AS dependent on the sex of the transmitting parent.', 'achondroplastic dwarfism'],
    ['Original mention: adenomatous polyposis coli.\nContext: Epidemiologic studies have shown an increased frequency of this tumor type in families affected by adenomatous polyposis coli.', 'apc (adenomatous polyposis coli)'],
    ['Original mention: autosomal recessive disorder.\nContext: Mucopolysaccharidosis IVA (MPS IVA) is an autosomal recessive disorder caused by a deficiency in N-acetylgalactosamine-6-sulfatase (GALNS).', 'susceptibility, disease'],
]
scores = model.predict(pairs)
print(scores.shape)
# (5,)

# Or rank different texts based on similarity to a single text
ranks = model.rank(
    'Original mention: L-N.\nContext: Partial HPRT deficiencies are associated with gouty arthritis, while absence of activity results in Lesch-Nyhan syndrome (L-N).',
    [
        'guanine phosphoribosyltransferase deficiencies',
        'malignancies',
        'achondroplastic dwarfism',
        'apc (adenomatous polyposis coli)',
        'susceptibility, disease',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
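Note that because the model was trained with an Identity activation function (see the loss parameters under Training Details), `model.predict` returns raw logits rather than probabilities. A sigmoid maps them to match probabilities; a minimal sketch with hypothetical logit values:

```python
import math

def sigmoid(logit: float) -> float:
    """Map a raw cross-encoder logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))

# Hypothetical logits, standing in for the output of model.predict(pairs)
logits = [4.2, -3.1, -2.7, 1.5, -4.0]
probs = [sigmoid(x) for x in logits]

assert all(0.0 < p < 1.0 for p in probs)
assert probs[0] > 0.5 and probs[1] < 0.5  # positive logit -> likely match
```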

Evaluation

Metrics

Cross Encoder Reranking

Metric    Value
map       0.9979 (+0.5536)
mrr@10    0.9986 (+0.7236)
ndcg@10   0.9987 (+0.4194)
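NDCG@10, the headline reranking metric above, rewards rankings that place relevant candidates near the top, with a logarithmic discount by rank. A self-contained sketch of the computation (with illustrative relevance labels, not the evaluation data):

```python
import math

def dcg_at_k(relevances, k=10):
    """Discounted cumulative gain over the top-k ranked items."""
    return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances[:k]))

def ndcg_at_k(relevances, k=10):
    """DCG normalized by the DCG of the ideal (descending-sorted) ranking."""
    ideal = dcg_at_k(sorted(relevances, reverse=True), k)
    return dcg_at_k(relevances, k) / ideal if ideal > 0 else 0.0

# Placing the only relevant candidate first yields a perfect score
assert ndcg_at_k([1, 0, 0, 0, 0]) == 1.0
# Placing it second discounts the gain by log2(3)
assert abs(ndcg_at_k([0, 1, 0, 0, 0]) - 1 / math.log2(3)) < 1e-9
```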

Training Details

Training Dataset

Unnamed Dataset

  • Size: 222,348 training samples
  • Columns: query, answer, and label
  • Approximate statistics based on the first 1000 samples:

             query                    answer                  label
    type     string                   string                  int
    details  min: 64 characters       min: 5 characters       0: ~46.10%
             mean: 207.06 characters  mean: 27.99 characters  1: ~53.90%
             max: 558 characters      max: 107 characters
  • Samples:

    query:  Original mention: L-N.
            Context: Partial HPRT deficiencies are associated with gouty arthritis, while absence of activity results in Lesch-Nyhan syndrome (L-N).
    answer: guanine phosphoribosyltransferase deficiencies
    label:  1

    query:  Original mention: breast cancer.
            Context: Most multiple case families of young onset breast cancer and ovarian cancer are thought to be due to highly penetrant mutations in the predisposing genes BRCA1 and BRCA2.
    answer: malignancies
    label:  0

    query:  Original mention: AS.
            Context: We report here the characterization of a transgene insertion (Epstein-Barr virus Latent Membrane Protein 2A, LMP2A) into mouse chromosome 7C, which has resulted in mouse models for PWS and AS dependent on the sex of the transmitting parent.
    answer: achondroplastic dwarfism
    label:  0
  • Loss: BinaryCrossEntropyLoss with these parameters:
    {
        "activation_fn": "torch.nn.modules.linear.Identity",
        "pos_weight": 0.7792357206344604
    }
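The `pos_weight` above (~0.779) scales the loss contribution of positive pairs, compensating for their slight majority (~53.9%) in the training data. The weighted binary cross-entropy it implements, mirroring `torch.nn.BCEWithLogitsLoss(pos_weight=...)`, can be sketched in plain Python:

```python
import math

POS_WEIGHT = 0.7792357206344604  # value from the loss parameters above

def weighted_bce(logit: float, label: int, pos_weight: float = POS_WEIGHT) -> float:
    """Binary cross-entropy on a raw logit, with the positive-class
    term scaled by pos_weight (as in torch's BCEWithLogitsLoss)."""
    p = 1.0 / (1.0 + math.exp(-logit))
    if label == 1:
        return -pos_weight * math.log(p)
    return -math.log(1.0 - p)

# Confident, correct predictions incur near-zero loss
assert weighted_bce(8.0, 1) < 0.01
assert weighted_bce(-8.0, 0) < 0.01
# At logit 0 the positive-class loss is pos_weight * log(2)
assert abs(weighted_bce(0.0, 1) - POS_WEIGHT * math.log(2)) < 1e-9
```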
    

Training Hyperparameters

Non-Default Hyperparameters

  • eval_strategy: steps
  • per_device_train_batch_size: 128
  • per_device_eval_batch_size: 128
  • learning_rate: 2e-05
  • warmup_ratio: 0.05
  • seed: 12
  • bf16: True
  • dataloader_num_workers: 4
  • load_best_model_at_end: True

All Hyperparameters

  • overwrite_output_dir: False
  • do_predict: False
  • eval_strategy: steps
  • prediction_loss_only: True
  • per_device_train_batch_size: 128
  • per_device_eval_batch_size: 128
  • per_gpu_train_batch_size: None
  • per_gpu_eval_batch_size: None
  • gradient_accumulation_steps: 1
  • eval_accumulation_steps: None
  • torch_empty_cache_steps: None
  • learning_rate: 2e-05
  • weight_decay: 0.0
  • adam_beta1: 0.9
  • adam_beta2: 0.999
  • adam_epsilon: 1e-08
  • max_grad_norm: 1.0
  • num_train_epochs: 3
  • max_steps: -1
  • lr_scheduler_type: linear
  • lr_scheduler_kwargs: {}
  • warmup_ratio: 0.05
  • warmup_steps: 0
  • log_level: passive
  • log_level_replica: warning
  • log_on_each_node: True
  • logging_nan_inf_filter: True
  • save_safetensors: True
  • save_on_each_node: False
  • save_only_model: False
  • restore_callback_states_from_checkpoint: False
  • no_cuda: False
  • use_cpu: False
  • use_mps_device: False
  • seed: 12
  • data_seed: None
  • jit_mode_eval: False
  • bf16: True
  • fp16: False
  • fp16_opt_level: O1
  • half_precision_backend: auto
  • bf16_full_eval: False
  • fp16_full_eval: False
  • tf32: None
  • local_rank: 0
  • ddp_backend: None
  • tpu_num_cores: None
  • tpu_metrics_debug: False
  • debug: []
  • dataloader_drop_last: False
  • dataloader_num_workers: 4
  • dataloader_prefetch_factor: None
  • past_index: -1
  • disable_tqdm: False
  • remove_unused_columns: True
  • label_names: None
  • load_best_model_at_end: True
  • ignore_data_skip: False
  • fsdp: []
  • fsdp_min_num_params: 0
  • fsdp_config: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
  • fsdp_transformer_layer_cls_to_wrap: None
  • accelerator_config: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
  • parallelism_config: None
  • deepspeed: None
  • label_smoothing_factor: 0.0
  • optim: adamw_torch_fused
  • optim_args: None
  • adafactor: False
  • group_by_length: False
  • length_column_name: length
  • project: huggingface
  • trackio_space_id: trackio
  • ddp_find_unused_parameters: None
  • ddp_bucket_cap_mb: None
  • ddp_broadcast_buffers: False
  • dataloader_pin_memory: True
  • dataloader_persistent_workers: False
  • skip_memory_metrics: True
  • use_legacy_prediction_loop: False
  • push_to_hub: False
  • resume_from_checkpoint: None
  • hub_model_id: None
  • hub_strategy: every_save
  • hub_private_repo: None
  • hub_always_push: False
  • hub_revision: None
  • gradient_checkpointing: False
  • gradient_checkpointing_kwargs: None
  • include_inputs_for_metrics: False
  • include_for_metrics: []
  • eval_do_concat_batches: True
  • fp16_backend: auto
  • push_to_hub_model_id: None
  • push_to_hub_organization: None
  • mp_parameters:
  • auto_find_batch_size: False
  • full_determinism: False
  • torchdynamo: None
  • ray_scope: last
  • ddp_timeout: 1800
  • torch_compile: False
  • torch_compile_backend: None
  • torch_compile_mode: None
  • include_tokens_per_second: False
  • include_num_input_tokens_seen: no
  • neftune_noise_alpha: None
  • optim_target_modules: None
  • batch_eval_metrics: False
  • eval_on_start: False
  • use_liger_kernel: False
  • liger_kernel_config: None
  • eval_use_gather_object: False
  • average_tokens_across_devices: True
  • prompts: None
  • batch_sampler: batch_sampler
  • multi_dataset_batch_sampler: proportional
  • router_mapping: {}
  • learning_rate_mapping: {}

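With `warmup_ratio: 0.05` and `lr_scheduler_type: linear`, the learning rate ramps up from zero over the first 5% of training steps, then decays linearly to zero. A minimal sketch of that schedule (the total step count is illustrative, not taken from this run):

```python
def linear_schedule_with_warmup(step, total_steps, warmup_ratio=0.05, base_lr=2e-5):
    """Linear warmup to base_lr, then linear decay to zero,
    as in transformers' lr_scheduler_type="linear"."""
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

TOTAL = 5200  # illustrative; roughly 3 epochs at the step counts in the logs below
assert linear_schedule_with_warmup(0, TOTAL) == 0.0          # start of warmup
peak = linear_schedule_with_warmup(int(TOTAL * 0.05), TOTAL)
assert abs(peak - 2e-5) < 1e-12                              # peak after warmup
assert linear_schedule_with_warmup(TOTAL, TOTAL) == 0.0      # fully decayed
```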
Training Logs

Epoch Step Training Loss ncbi-disease-dev_ndcg@10
0.0006 1 0.5977 -
0.0863 150 0.5711 -
0.1726 300 0.3596 -
0.2589 450 0.2451 -
0.3452 600 0.1933 -
0.4315 750 0.1629 -
0.5178 900 0.1448 -
0.6041 1050 0.1271 -
0.6904 1200 0.1102 -
0.7768 1350 0.0986 -
0.8631 1500 0.0927 0.9963 (+0.4169)
0.9494 1650 0.0821 -
1.0357 1800 0.073 -
1.1220 1950 0.0641 -
1.2083 2100 0.0544 -
1.2946 2250 0.055 -
1.3809 2400 0.0556 -
1.4672 2550 0.0546 -
1.5535 2700 0.0514 -
1.6398 2850 0.0463 -
1.7261 3000 0.0416 0.9984 (+0.4190)
1.8124 3150 0.043 -
1.8987 3300 0.0433 -
1.9850 3450 0.0425 -
2.0713 3600 0.0322 -
2.1577 3750 0.0272 -
2.2440 3900 0.0273 -
2.3303 4050 0.0274 -
2.4166 4200 0.0265 -
2.5029 4350 0.0285 -
2.5892 4500 0.0249 0.9987 (+0.4194) *
2.6755 4650 0.0263 -
2.7618 4800 0.0252 -
2.8481 4950 0.0256 -
2.9344 5100 0.0247 -
-1 -1 - 0.9987 (+0.4194)
  • The row marked with * denotes the saved checkpoint.

Framework Versions

  • Python: 3.12.12
  • Sentence Transformers: 5.2.0
  • Transformers: 4.57.3
  • PyTorch: 2.9.1+cu130
  • Accelerate: 1.12.0
  • Datasets: 4.4.2
  • Tokenizers: 0.22.1

Citation

BibTeX

Sentence Transformers

@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}