---
library_name: transformers
tags:
- medical
license: llama3
datasets:
- pensieves/mimicause
language:
- en
metrics:
- f1
base_model:
- johnsnowlabs/JSL-MedLlama-3-8B-v2.0
pipeline_tag: text-generation
---

# CLiMA (Causal Linking for Medical Annotation)

CLiMA (Causal Linking for Medical Annotation) is a causal relation classification model fine-tuned on the MIMICause dataset (https://huggingface.co/datasets/pensieves/mimicause), a collection of de-identified discharge summaries sourced from the MIMIC-III (Medical Information Mart for Intensive Care-III) clinical database (https://portal.dbmi.hms.harvard.edu/projects/n2c2-nlp/).

The model was trained as part of a benchmark evaluation of several open-source LLMs and learning methods on the MIMICause dataset, where it outperformed state-of-the-art models with stronger general reasoning capabilities (research paper under review).

The model can be used to classify the causal relations holding between pairs of biomedical entities in clinical texts. It achieves a macro-average F1 score of 0.829 on relation classification on the MIMICause test set.

A cross-domain evaluation on a subset of the Drug Reviews dataset (https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com) resulted in a macro-average causal relation classification accuracy of 0.73. Sample annotations for this dataset can be accessed here: https://drive.google.com/drive/folders/1wU7Px0wHmK-PFtOKNFwpcSS3EYZOesYU?usp=sharing

This work was done at the Department of Mathematics and Computer Science of the University of Cagliari, Italy.

### Model Description

- **Developed by:** Vanni Zavarella, Sergio Consoli, Diego Reforgiato, Lorenzo Bertolini, Alessandro Zani
- **Model type:** Fine-tuned Meta-Llama-3-8B
- **Language(s) (NLP):** English
- **License:** Llama 3 Community License Agreement
- **Finetuned from model:** johnsnowlabs/JSL-MedLlama-3-8B-v2.0

### Model Sources

- **Repository:**
- **Paper:** https://...

## How to Use

Use the code below to get started with the model.

### Load model directly

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_model_id = "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"
model_path = "/path-to-fine-tuned-model"  # path to the CLiMA adapter weights

# Load the base model with 4-bit NF4 quantization to fit a single-GPU memory budget
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id, quantization_config=bnb_config, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Attach the fine-tuned adapter weights to the quantized base model
ft_model = PeftModel.from_pretrained(base_model, model_path)
```
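Note that 4-bit loading via `BitsAndBytesConfig` requires the `bitsandbytes` package and a CUDA-capable GPU; without one, you can drop `quantization_config` and load the base model in full precision instead.
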
---

### Run model on a dataset

```python
import torch
import tqdm

prompt_instruction = """Given a text enclosed in triple quotes and a pair of entities E1 and E2, classify the relation holding between E1 and E2.
The relations are identified with 9 labels from 0 to 8. The meaning of the labels is the following:
0 means that E1 causes E2
1 means that E2 causes E1
2 means that E1 enables E2
3 means that E2 enables E1
4 means that E1 prevents E2
5 means that E2 prevents E1
6 means that E1 hinders E2
7 means that E2 hinders E1
8 means that E1 and E2 are in none of the relations above.
For the output use the format LABEL: X
"""

def generate_eval_prompt(example):
    text = normalizeContext(str(example['Text']))  # adapt and define based on your dataset specs
    prompt_text = f"Text:'''{text}'''"
    e1 = example['E1']  # adapt to your dataset specs
    e2 = example['E2']  # adapt to your dataset specs
    prompt_entities = f"\nEntities: E1: '''{e1}''', E2: '''{e2}'''"
    full_prompt = f"<USER> {prompt_instruction} {prompt_text} {prompt_entities} <ASSISTANT>"
    return full_prompt

ft_model.eval()
for i, example in enumerate(tqdm.tqdm(tokenized_dataset)):
    prompt = generate_eval_prompt(example)
    model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():
        sys_output = tokenizer.decode(
            ft_model.generate(**model_input, max_length=1000)[0],
            skip_special_tokens=True,
        )
        # post-process sys_output here (see the parsing sketch below)
```
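The model is instructed to answer in the format `LABEL: X`. A minimal post-processing sketch for recovering the predicted label from the decoded output (the `parse_label` helper below is illustrative, not part of the released code):

```python
import re

def parse_label(generation: str) -> int:
    # The decoded output still contains the prompt, so only inspect the
    # text produced after the <ASSISTANT> marker.
    answer = generation.split("<ASSISTANT>")[-1]
    match = re.search(r"LABEL:\s*([0-8])", answer)
    # Fall back to label 8 ("no relation") on malformed generations
    return int(match.group(1)) if match else 8

predicted = parse_label(sys_output)
```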
---

## Evaluation

The test data is the test split of the MIMICause dataset (https://huggingface.co/datasets/pensieves/mimicause), comprising 272 examples.

The macro-average F1 over the 9 MIMICause relation labels is 0.829. The per-relation F1 scores are the following:

| Relation       | Label | F1   |
|----------------|-------|------|
| Cause(e1,e2)   | 0     | 0.83 |
| Cause(e2,e1)   | 1     | 0.88 |
| Enable(e1,e2)  | 2     | 0.70 |
| Enable(e2,e1)  | 3     | 0.84 |
| Prevent(e1,e2) | 4     | 0.86 |
| Prevent(e2,e1) | 5     | 0.83 |
| Hinder(e1,e2)  | 6     | 0.80 |
| Hinder(e2,e1)  | 7     | 0.80 |
| Other          | 8     | 0.89 |

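For reference, the macro-average F1 is the unweighted mean of the per-label scores above. Assuming `y_true` and `y_pred` hold the gold and predicted integer labels for the test set, it can be reproduced with scikit-learn:

```python
from sklearn.metrics import classification_report, f1_score

# y_true / y_pred: lists of integer labels in [0, 8], one per test example
macro_f1 = f1_score(y_true, y_pred, average="macro")
print(f"Macro-average F1: {macro_f1:.3f}")

# Per-label precision, recall, and F1, as in the table above
print(classification_report(y_true, y_pred, digits=2))
```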
## Training Details

### Training Data

The training data is the train split of the MIMICause dataset (https://huggingface.co/datasets/pensieves/mimicause), comprising 1,953 examples.

#### Training Hyperparameters

- **Training regime:** fp16
- **Validation metric:** eval_loss
- **Train batch size:** 8
- **Evaluation strategy:** steps
- **Learning rate:** 2e-4
- **Optimizer:** paged_adamw_8bit

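For convenience, a minimal sketch of how these settings map onto `transformers.TrainingArguments` (only the values listed above come from this card; the output directory and step counts are illustrative placeholders):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./clima-checkpoints",  # placeholder
    per_device_train_batch_size=8,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    fp16=True,
    eval_strategy="steps",             # `evaluation_strategy` on transformers < 4.41
    eval_steps=50,                     # placeholder
    save_strategy="steps",
    save_steps=50,                     # placeholder
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    load_best_model_at_end=True,
)
```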
## Model Card Authors

Vanni Zavarella (https://huggingface.co/zavavan)