---
license: llama3.1
base_model: DUTIR-BioNLP/RexDrug-base
library_name: peft
pipeline_tag: text-generation
tags:
  - drug-combination
  - relation-extraction
  - biomedical
  - llama
  - chain-of-thought
  - lora
  - grpo
---

# RexDrug-adapter

This is the LoRA adapter for **RexDrug**, trained via GRPO (Group Relative Policy Optimization) on top of [RexDrug-base](https://huggingface.co/DUTIR-BioNLP/RexDrug-base) for biomedical drug combination relation extraction with chain-of-thought reasoning.

## Model Details

- **Base model**: [DUTIR-BioNLP/RexDrug-base](https://huggingface.co/DUTIR-BioNLP/RexDrug-base) (Llama-3.1-8B-Instruct + SFT)
- **Fine-tuning method**: GRPO with LoRA (r=64, alpha=128)
- **Task**: Drug combination relation extraction from biomedical literature
- **Relation types**: POS (beneficial), NEG (harmful), COMB (neutral/mixed), NO_COMB (no combination)
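
Once a response has been decoded, the four relation types listed above can be handled programmatically. The following is a minimal sketch, not part of the official RexDrug code. It assumes that the final label appears verbatim in the response text after the chain-of-thought reasoning; the actual output format is determined by the system prompt in the GitHub repository and may differ. The names `RELATION_TYPES` and `extract_relation_label` are hypothetical.

```python
import re
from typing import Optional

# Label descriptions copied from the "Relation types" entry above.
RELATION_TYPES = {
    "POS": "beneficial",
    "NEG": "harmful",
    "COMB": "neutral/mixed",
    "NO_COMB": "no combination",
}

# "_" counts as a word character, so \bCOMB\b cannot match inside "NO_COMB".
_LABEL_RE = re.compile(r"\b(NO_COMB|POS|NEG|COMB)\b")


def extract_relation_label(response: str) -> Optional[str]:
    """Return the last relation label mentioned in `response`, or None.

    Assumption (not confirmed by the model card): the final answer follows
    the reasoning, so the last label found is treated as the prediction.
    """
    matches = _LABEL_RE.findall(response)
    return matches[-1] if matches else None
```

Taking the last match rather than the first is a design choice for chain-of-thought output, where earlier labels may be mentioned only to be ruled out.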

## Quick Start

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# 1. Load model
tokenizer = AutoTokenizer.from_pretrained("DUTIR-BioNLP/RexDrug-base", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "DUTIR-BioNLP/RexDrug-base",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(model, "DUTIR-BioNLP/RexDrug-adapter")
model.eval()

# 2. Prepare input
messages = [
    {"role": "system", "content": "You are an expert in biomedical drug-drug relation extraction. ..."},
    {"role": "user",   "content": "Target sentence: ... \nContext paragraph: ..."},
]
input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

# 3. Generate
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)
```

See the full example in the [GitHub repository](https://github.com/DUTIR-BioNLP/RexDrug).
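
The `messages` list in the Quick Start can be assembled with a small helper when processing many sentences. This is a sketch under stated assumptions: the user-message layout mirrors the placeholder shown in the Quick Start (`Target sentence: ... \nContext paragraph: ...`), and the full system prompt, which is elided in this card, must be supplied by the caller from the GitHub repository. The helper name `build_messages` is hypothetical.

```python
from typing import Dict, List


def build_messages(
    system_prompt: str,
    target_sentence: str,
    context_paragraph: str,
) -> List[Dict[str, str]]:
    """Build a chat-style message list for one target sentence.

    Assumption: the user message follows the layout shown in the Quick Start;
    the exact wording expected by the model should be checked against the
    repository before use.
    """
    user_content = (
        f"Target sentence: {target_sentence} \n"
        f"Context paragraph: {context_paragraph}"
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
```

The returned list can be passed to `tokenizer.apply_chat_template` in place of the hand-written `messages` in step 2 of the Quick Start.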

## License

This model is built upon Llama 3.1 and is subject to the [Llama 3.1 Community License Agreement](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE).