Update README.md
Browse files
README.md
CHANGED
|
@@ -51,7 +51,102 @@ python inference.py \
|
|
| 51 |
--score_threshold -1 \
|
| 52 |
--seed 42 \
|
| 53 |
--device cuda:0
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
---
|
| 57 |
|
|
|
|
| 51 |
--score_threshold -1 \
|
| 52 |
--seed 42 \
|
| 53 |
--device cuda:0
|
| 54 |
+
```
|
| 55 |
+
---
|
| 56 |
+
|
| 57 |
+
## Usage Example
|
| 58 |
+
|
| 59 |
+
This example shows a typical workflow for a **single user**:
|
| 60 |
+
1) encode text pairs with Skywork-Reward-V2-Llama-3.1-8B into embeddings,
|
| 61 |
+
2) adapt the MRM on the user's few-shot examples (update `shared_weight` only),
|
| 62 |
+
3) run inference on new pairs for that same user.
|
| 63 |
+
|
| 64 |
+
```python
|
| 65 |
+
import torch
|
| 66 |
+
from copy import deepcopy
|
| 67 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 68 |
+
|
| 69 |
+
from utils import bt_loss
|
| 70 |
+
from train import MRM
|
| 71 |
+
from inference import load_ckpt_into_model
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
@torch.no_grad()
def encode_pairs(model, tokenizer, pairs, device="cuda"):
    """Embed each preference pair's chosen/rejected completions.

    Every completion is appended to its prompt as an assistant turn,
    tokenized via the chat template, and represented by the backbone's
    final-layer hidden state at the last token position.

    Returns a (chosen, rejected) tuple of stacked float CPU tensors,
    one row per pair.
    """
    model.eval()
    chosen_vecs = []
    rejected_vecs = []
    for example in pairs:
        prompt = example["prompt"]
        # Encode both completions of the pair with identical machinery.
        for field, bucket in (("chosen", chosen_vecs), ("rejected", rejected_vecs)):
            messages = prompt + [{"role": "assistant", "content": example[field]}]
            token_ids = tokenizer.apply_chat_template(
                messages, tokenize=True, return_tensors="pt"
            ).to(device)
            hidden = model(token_ids, output_hidden_states=True).hidden_states[-1]
            # Last-token embedding of the final layer, moved off-device in fp32.
            bucket.append(hidden[0, -1].float().cpu())
    return torch.stack(chosen_vecs), torch.stack(rejected_vecs)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def adapt_single_user(base_model, support_ch, support_rj, inner_lr=1e-3, inner_epochs=5, device="cuda"):
    """Few-shot adapt a copy of the meta reward model to one user.

    Only ``shared_weight`` is handed to the optimizer, so all other
    parameters of the copied model stay frozen. The base model itself
    is never mutated (a deep copy is trained).

    Returns the adapted copy, switched to eval mode.
    """
    adapted = deepcopy(base_model).to(device)
    adapted.train()
    # Inner-loop optimizer over the single shared parameter.
    optimizer = torch.optim.Adam([adapted.shared_weight], lr=inner_lr)
    ch = support_ch.to(device)
    rj = support_rj.to(device)
    for _ in range(inner_epochs):
        optimizer.zero_grad()
        # Bradley-Terry loss: push chosen scores above rejected scores.
        step_loss = bt_loss(adapted(ch), adapted(rj))
        step_loss.backward()
        optimizer.step()
    return adapted.eval()
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
def infer_on_pairs(model, ch, rj, device="cuda"):
    """Score chosen and rejected embeddings with an (adapted) reward model.

    Returns the pair of score tensors ``(scores_chosen, scores_rejected)``.
    """
    scores_chosen = model(ch.to(device))
    scores_rejected = model(rj.to(device))
    return scores_chosen, scores_rejected
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# Select the compute device once; every model/tensor below follows it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Frozen backbone used purely as a text encoder (encode_pairs reads its
# hidden states, not the classification logit); bfloat16 keeps the 8B
# model's memory footprint manageable.
MODEL_PATH = "Skywork/Skywork-Reward-V2-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
llm = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH, num_labels=1, torch_dtype=torch.bfloat16, device_map=device
)

# Meta-trained MRM checkpoint; in_dim=4096 presumably matches the backbone's
# hidden size — TODO confirm against the training configuration.
CKPT_PATH = "ckpt/model.pt"
mrm = MRM(in_dim=4096, hidden_sizes=[2], use_bias=False)
load_ckpt_into_model(mrm, CKPT_PATH, device)

# Few-shot support set: this user's preference pairs (chosen beats rejected).
support_pairs = [
    {
        "prompt": [{"role": "user", "content": "TL;DR this post: I tried waking up at 5am for a month and tracked my productivity."}],
        "chosen": "Waking up early helped at first, but long-term productivity depended more on sleep quality than wake-up time.",
        "rejected": "The post is about waking up early and productivity.",
    },
    {
        "prompt": [{"role": "user", "content": "Summarize the main point: I switched from iPhone to Android after 10 years."}],
        "chosen": "The author values customization and battery life more than ecosystem lock-in, which motivated the switch.",
        "rejected": "The author bought a new phone.",
    },
]

# Step 1-2: encode the support pairs, then adapt a copy of the MRM on them
# (adapt_single_user updates `shared_weight` only; `mrm` itself is untouched).
sup_ch, sup_rj = encode_pairs(llm, tokenizer, support_pairs, device)
user_mrm = adapt_single_user(mrm, sup_ch, sup_rj, device=device)

# Step 3: held-out pair for the same user, scored by the adapted model.
test_pairs = [
    {
        "prompt": [{"role": "user", "content": "TL;DR: I quit my job to freelance and here is what I learned in 6 months."}],
        "chosen": "Freelancing offers flexibility but requires strong self-discipline and financial planning to be sustainable.",
        "rejected": "The author talks about quitting a job and freelancing.",
    }
]

test_ch, test_rj = encode_pairs(llm, tokenizer, test_pairs, device)
s_ch, s_rj = infer_on_pairs(user_mrm, test_ch, test_rj, device)

# A well-adapted model should score the chosen summary above the rejected one.
print("reward(chosen) =", s_ch.tolist())
print("reward(rejected)=", s_rj.tolist())
|
| 148 |
+
|
| 149 |
+
```
|
| 150 |
|
| 151 |
---
|
| 152 |
|