---
license: apache-2.0
base_model:
- meta-llama/Meta-Llama-3-8B-Instruct
language:
- en
tags:
- BEL
- retrieval
- entity-retrieval
- named-entity-disambiguation
- entity-disambiguation
- named-entity-linking
- entity-linking
- text2text-generation
- biomedical
- healthcare
- synthetic-data
- causal-lm
- llm
library_name: transformers
finetuning_task:
- text2text-generation
- entity-linking
metrics:
- recall
model-index:
- name: syncabel-medmentions-8b
  results:
  - task:
      type: entity-linking
    dataset:
      type: structured_dataset
      name: medmentions
      config: st21pv
    metrics:
    - type: recall
      value: 0.754
---

# SynCABEL: Synthetic Contextualized Augmentation for Biomedical Entity Linking

## SynCABEL

**SynCABEL** is a novel framework that addresses data scarcity in biomedical entity linking through **synthetic data generation**. The method is introduced in our [paper].

## SynCABEL (MedMentions Edition)

This is a **fine-tuned version of LLaMA-3-8B** trained on **MedMentions** using **SynthMM** (our synthetic dataset generated via the SynCABEL framework).

| | |
|--------|---------|
| **Base Model** | [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) |
| **Training Data** | [MedMentions](https://huggingface.co/datasets/bigbio/medmentions) (real) + [SynthMM](https://huggingface.co/datasets/Aremaki/SynCABEL) (synthetic) |
| **Fine-tuning** | [Supervised Fine-Tuning](https://huggingface.co/docs/trl/en/sft_trainer) |

## Training Data Composition

The model is trained on a mix of **human-annotated** and **synthetic** data:

```
MedMentions (human) : 4,392 abstracts
SynthMM (synthetic) : ~50,000 samples
```

To ensure balanced learning, **human data is upsampled during training** so that each batch contains:

```
50% human-annotated data
50% synthetic data
```

In other words, although SynthMM is larger, the model always sees a **1:1 ratio of human to synthetic examples**, preventing the synthetic data from overwhelming human supervision.
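To make the 1:1 mixing concrete, here is a minimal sketch of one way such balanced batching can be implemented: the larger synthetic pool is iterated once per epoch while the smaller human pool is sampled with replacement (i.e. upsampled). This is an illustration only, not the actual training code; `mixed_batches` and the toy pools are hypothetical names.

```python
import random

def mixed_batches(human, synthetic, batch_size, seed=0):
    """Yield batches that are 50% human / 50% synthetic.

    The smaller human pool is sampled with replacement (upsampled),
    so every batch keeps the 1:1 ratio regardless of pool sizes.
    """
    rng = random.Random(seed)
    half = batch_size // 2
    rng.shuffle(synthetic)
    for start in range(0, len(synthetic) - half + 1, half):
        batch = synthetic[start:start + half]   # 50% synthetic, no repeats
        batch += rng.choices(human, k=half)     # 50% human, upsampled
        rng.shuffle(batch)
        yield batch

# Toy pools mirroring the size imbalance above
human = [f"human_{i}" for i in range(4)]
synthetic = [f"syn_{i}" for i in range(50)]
batch = next(mixed_batches(human, synthetic, batch_size=8))
assert sum(s.startswith("human_") for s in batch) == 4
```

Because the human pool is drawn with replacement, individual human abstracts recur across batches, which is exactly the upsampling behaviour described above.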
## Usage

### Loading

```python
import torch
from transformers import AutoModelForCausalLM

# Load the model (requires trust_remote_code for the custom architecture)
model = AutoModelForCausalLM.from_pretrained(
    "Aremaki/SynCABEL_MedMentions",
    trust_remote_code=True,
    device_map="auto"
)
```

### Unconstrained Generation

```python
# Let the model freely generate concept names
sentences = [
    "[Ibuprofen]{Chemicals & Drugs} is a non-steroidal anti-inflammatory drug",
    "[Myocardial infarction]{Disorders} requires immediate intervention"
]

results = model.sample(
    sentences=sentences,
    constrained=False,
    num_beams=3,
)

for i, beam_results in enumerate(results):
    print(f"Input: {sentences[i]}")
    mention = beam_results[0]["mention"]
    print(f"Mention: {mention}")
    for j, result in enumerate(beam_results):
        print(
            f"Beam {j+1}:\n"
            f"Predicted concept name: {result['pred_concept_name']}\n"
            f"Predicted code: {result['pred_concept_code']}\n"
            f"Beam score: {result['beam_score']:.3f}\n"
        )
```

**Output:**

```
Input: [Ibuprofen]{Chemicals & Drugs} is a non-steroidal anti-inflammatory drug
Mention: Ibuprofen
Beam 1:
Predicted concept name: Ibuprofen
Predicted code: C0020740
Beam score: 1.000

Beam 2:
Predicted concept name: IBUPROFEN
Predicted code: NO_CODE
Beam score: 0.114

Beam 3:
Predicted concept name: IBUPROfen
Predicted code: NO_CODE
Beam score: 0.060

Input: [Myocardial infarction]{Disorders} requires immediate intervention
Mention: Myocardial infarction
Beam 1:
Predicted concept name: Myocardial infarction
Predicted code: C0027051
Beam score: 1.000

Beam 2:
Predicted concept name: Myocardial Infarction
Predicted code: C0027051
Beam score: 0.200

Beam 3:
Predicted concept name: myocardial infarction
Predicted code: NO_CODE
Beam score: 0.149
```

### Constrained Decoding (Recommended for Entity Linking)

```python
# Constrain generation to valid biomedical concepts
sentences = [
    "[Ibuprofen]{Chemicals & Drugs} is a non-steroidal anti-inflammatory drug",
    "[Myocardial infarction]{Disorders} requires immediate intervention"
]

results = model.sample(
    sentences=sentences,
    constrained=True,
    num_beams=3,
)

for i, beam_results in enumerate(results):
    print(f"Input: {sentences[i]}")
    mention = beam_results[0]["mention"]
    print(f"Mention: {mention}")
    for j, result in enumerate(beam_results):
        print(
            f"Beam {j+1}:\n"
            f"Predicted concept name: {result['pred_concept_name']}\n"
            f"Predicted code: {result['pred_concept_code']}\n"
            f"Beam score: {result['beam_score']:.3f}\n"
        )
```

**Output:**

```
Input: [Ibuprofen]{Chemicals & Drugs} is a non-steroidal anti-inflammatory drug
Mention: Ibuprofen
Beam 1:
Predicted concept name: Ibuprofen
Predicted code: C0020740
Beam score: 1.000

Beam 2:
Predicted concept name: IBUPROFEN/PSEUDOEPHEDRINE
Predicted code: C0717858
Beam score: 0.065

Beam 3:
Predicted concept name: Ibuprofen (substance)
Predicted code: C0020740
Beam score: 0.056

Input: [Myocardial infarction]{Disorders} requires immediate intervention
Mention: Myocardial infarction
Beam 1:
Predicted concept name: Myocardial infarction
Predicted code: C0027051
Beam score: 1.000

Beam 2:
Predicted concept name: Myocardial Infarction
Predicted code: C0027051
Beam score: 0.200

Beam 3:
Predicted concept name: Myocardial infarction (disorder)
Predicted code: C0027051
Beam score: 0.194
```

## Assets

The model automatically loads:

- `text_to_code.json`: maps concept names to ontology codes (UMLS, SNOMED CT)
- `candidate_trie.pkl`: prefix tree for efficient constrained decoding

## MedMentions Test Set Results

| Training Data | Recall@1 | Improvement |
|---------------|----------|-------------|
| MedMentions Only | 0.76 | Baseline |
| + SynthMM (Ours) | **0.85** | **+11.8% (relative)** |

### Comparison with State-of-the-Art

| Model | F1 Score | Training Data |
|-------|----------|---------------|
| **SapBERT** | 0.83 | MedMentions + UMLS |
| **BioSyn** | 0.81 | MedMentions |
| **GENRE (baseline)** | 0.79 | MedMentions |
| **SynCABEL-8B (Ours)** | **0.85** | MedMentions + SynthMM |
| **SynCABEL-8B (w/ UMLS)** | **0.88** | + UMLS pretraining |

### Speed and Efficiency

| Batch Size | Avg. Latency (per batch) | Throughput |
|------------|--------------------------|------------|
| 1 | 120 ms | 8.3 samples/sec |
| 8 | 650 ms | 12.3 samples/sec |
| 16 | 1.2 s | 13.3 samples/sec |
| 32 | 2.1 s | 15.2 samples/sec |

*Measured on a single H100 GPU with constrained decoding.*
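The throughput column follows directly from the latency column (throughput = batch size / average per-batch latency), so the benefit of batching can be checked with a few lines of arithmetic:

```python
# Throughput = batch size / average per-batch latency (in seconds)
measurements = {1: 0.120, 8: 0.650, 16: 1.2, 32: 2.1}

for batch_size, latency_s in measurements.items():
    throughput = batch_size / latency_s
    print(f"batch {batch_size:>2}: {throughput:.1f} samples/sec")
# batch  1: 8.3 samples/sec ... batch 32: 15.2 samples/sec
```

Latency grows sub-linearly with batch size, so throughput keeps improving up to batch 32; for latency-sensitive single requests, batch 1 still responds in ~120 ms.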