---
library_name: transformers
tags:
- medical
license: llama3
datasets:
- pensieves/mimicause
language:
- en
metrics:
- f1
base_model:
- johnsnowlabs/JSL-MedLlama-3-8B-v2.0
pipeline_tag: text-generation
---

# CLiMA (Causal Linking for Medical Annotation)

CLiMA (Causal Linking for Medical Annotation) is a causal relation classification model fine-tuned on the MIMICause dataset (https://huggingface.co/datasets/pensieves/mimicause), a collection of de-identified discharge summaries 
sourced from the MIMIC-III (Medical Information Mart for Intensive Care-III) clinical database (https://portal.dbmi.hms.harvard.edu/projects/n2c2-nlp/). 

The model was trained as part of a benchmark evaluation of various open-source LLMs and learning methods on the MIMICause dataset, where it outperformed SOTA models with stronger general reasoning capabilities (research paper under review).

The model can be used to extract and classify causal relations holding between pairs of biomedical entities in clinical texts. Its macro-average F1 score for relation classification on the MIMICause test set is 0.829.

A cross-domain evaluation on a subset of the Drug Reviews dataset (https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com) resulted in a macro-average causal relation classification accuracy of 0.73.
Sample annotations of this dataset can be accessed here: https://drive.google.com/drive/folders/1wU7Px0wHmK-PFtOKNFwpcSS3EYZOesYU?usp=sharing

This work was done at the Department of Mathematics and Computer Science of the University of Cagliari, Italy.

### Model Description

- **Developed by:** Vanni Zavarella, Sergio Consoli, Diego Reforgiato, Lorenzo Bertolini, Alessandro Zani
- **Model type:** Fine-tuned Meta-Llama-3-8B
- **Language(s) (NLP):** English
- **License:** Llama 3 Community License Agreement
- **Finetuned from model:** johnsnowlabs/JSL-MedLlama-3-8B-v2.0

### Model Sources


- **Repository:**
- **Paper:** [https://...



## How to Use

Use the code below to get started with the model.

### Load model directly
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

base_model_id = "johnsnowlabs/JSL-MedLlama-3-8B-v2.0"
model_path = "/path-to-fine-tuned-model"  # local path or Hub id of the fine-tuned adapter

# 4-bit quantization configuration for loading the base model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)

# Load the quantized base model and the tokenizer of the fine-tuned model
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Attach the fine-tuned PEFT (LoRA) adapter to the base model
ft_model = PeftModel.from_pretrained(base_model, model_path)
```
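
Optionally, the LoRA adapter can be merged into the base weights for simpler deployment. This is a minimal sketch, assuming the base model is reloaded in full (non-quantized) precision, since merging onto a 4-bit quantized base is not always supported depending on the `peft` version; the output path is a placeholder:

```python
# Reload the base model in full precision and re-attach the adapter before merging
full_base = AutoModelForCausalLM.from_pretrained(base_model_id, device_map="auto")
ft_full = PeftModel.from_pretrained(full_base, model_path)

# Fold the LoRA weights into the base model and save a standalone checkpoint
merged = ft_full.merge_and_unload()
merged.save_pretrained("/path-to-merged-model")      # placeholder path
tokenizer.save_pretrained("/path-to-merged-model")   # placeholder path
```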
---

### Run model on a dataset

```python
import torch
import tqdm

# Instruction describing the 9 causal relation labels
prompt_instruction = f"""Given a text enclosed in triple quotes and a pair of entities E1 and E2, classify the relation holding between E1 and E2.
The relations are identified with 9 labels from 0 to 8. The meaning of the labels is the following:
0 means that E1 causes E2
1 means that E2 causes E1
2 means that E1 enables E2
3 means that E2 enables E1
4 means that E1 prevents E2
5 means that E2 prevents E1
6 means that E1 hinders E2
7 means that E2 hinders E1
8 means that E1 and E2 are in none of the relations above.
For the output use the format LABEL: X
"""

def generate_eval_prompt(example):
  text = normalizeContext(str(example['Text']))  # adapt and define based on your dataset specs
  prompt_text = f"Text:'''{text}'''"
  e1 = example['E1']  # adapt to your dataset specs
  e2 = example['E2']  # adapt to your dataset specs
  prompt_entities = f"\nEntities: E1: '''{e1}''', E2: '''{e2}'''"
  full_prompt = f"<USER> {prompt_instruction} {prompt_text} {prompt_entities} <ASSISTANT>"
  return full_prompt

# Run inference over the (tokenized) evaluation dataset
ft_model.eval()
for i, example in enumerate(tqdm.tqdm(tokenized_dataset)):
  prompt = generate_eval_prompt(example)
  model_input = tokenizer(prompt, return_tensors="pt").to("cuda")
  with torch.no_grad():
    sys_output = tokenizer.decode(ft_model.generate(**model_input, max_length=1000)[0], skip_special_tokens=True)
    # sys_output holds the decoded model answer for this example (see label parsing below)
```
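
The prompt asks the model to answer in the format `LABEL: X`, so the predicted class can be recovered from the decoded output with a simple regular expression. A minimal sketch (the `extract_label` helper is ours, not part of the released code):

```python
import re

def extract_label(generated_text):
  """Return the last 'LABEL: X' value (0-8) found in the model output, or None."""
  matches = re.findall(r"LABEL:\s*([0-8])", generated_text)
  return int(matches[-1]) if matches else None

predicted_label = extract_label(sys_output)
```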


---

## Evaluation

The test data is the test split of the MIMICause dataset (https://huggingface.co/datasets/pensieves/mimicause), comprising 272 examples.

The macro-average F1 over the 9 MIMICause relation labels is 0.829. The per-relation F1 scores are the following:

|    Relation   | Label |  F1  |
|---------------|-------|------|
|Cause(e1,e2)   |   0   | 0.83 |
|Cause(e2,e1)   |   1   | 0.88 |
|Enable(e1,e2)  |   2   | 0.70 |
|Enable(e2,e1)  |   3   | 0.84 |
|Prevent(e1,e2) |   4   | 0.86 |
|Prevent(e2,e1) |   5   | 0.83 |
|Hinder(e1,e2)  |   6   | 0.80 |
|Hinder(e2,e1)  |   7   | 0.80 |
|   Other       |   8   | 0.89 |
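
For reference, the figures above correspond to standard macro-averaged and per-class F1 as computed by scikit-learn; a minimal sketch, where `y_true` and `y_pred` are illustrative placeholders for the gold and predicted label indices (0-8) collected over the test split:

```python
from sklearn.metrics import f1_score, classification_report

# Gold and predicted label indices (0-8), e.g. collected with extract_label() above;
# the values here are placeholders for illustration only.
y_true = [0, 1, 8, 4, 2]
y_pred = [0, 1, 8, 5, 2]

labels = list(range(9))
print("Macro F1:", f1_score(y_true, y_pred, average="macro", labels=labels, zero_division=0))
print(classification_report(y_true, y_pred, labels=labels, zero_division=0))
```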


## Training Details

### Training Data

The training data is the train split of the MIMICause dataset (https://huggingface.co/datasets/pensieves/mimicause), comprising 1,953 examples.


#### Training Hyperparameters

- **Training regime:** fp16

- **Validation Metric:** eval_loss

- **Train batch size:** 8

- **Evaluation Strategy:** steps

- **Learning Rate:** 2e-4

- **Optimizer:** paged_adamw_8bit
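
As a rough indication of how these settings map onto a `transformers` training setup, below is a hedged `TrainingArguments` sketch; values not listed above (output directory, evaluation interval, number of epochs) are illustrative placeholders, not the actual configuration:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./clima-checkpoints",       # placeholder
    per_device_train_batch_size=8,          # train batch size: 8
    learning_rate=2e-4,                     # learning rate: 2e-4
    fp16=True,                              # training regime: fp16
    optim="paged_adamw_8bit",               # optimizer: paged_adamw_8bit
    evaluation_strategy="steps",            # evaluation strategy: steps
    eval_steps=100,                         # placeholder
    metric_for_best_model="eval_loss",      # validation metric: eval_loss
    load_best_model_at_end=True,
    num_train_epochs=3,                     # placeholder
)
```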



## Model Card Authors
Vanni Zavarella 
https://huggingface.co/zavavan