Text Classification
HongruCai commited on
Commit
418fd34
·
verified ·
1 Parent(s): e584a1d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +96 -1
README.md CHANGED
@@ -51,7 +51,102 @@ python inference.py \
51
  --score_threshold -1 \
52
  --seed 42 \
53
  --device cuda:0
54
- ````
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  ---
57
 
 
51
  --score_threshold -1 \
52
  --seed 42 \
53
  --device cuda:0
54
+ ```
55
+ ---
56
+
57
+ ## Usage Example
58
+
59
+ This example shows a typical workflow for a **single user**:
60
+ 1) encode text pairs with Skywork-Reward-V2-Llama-3.1-8B into embeddings,
61
+ 2) adapt the MRM on the user's few-shot examples (update `shared_weight` only),
62
+ 3) run inference on new pairs for that same user.
63
+
64
+ ```python
65
+ import torch
66
+ from copy import deepcopy
67
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
68
+
69
+ from utils import bt_loss
70
+ from train import MRM
71
+ from inference import load_ckpt_into_model
72
+
73
+
74
+ @torch.no_grad()
75
+ def encode_pairs(model, tokenizer, pairs, device="cuda"):
76
+ model.eval()
77
+ ch, rj = [], []
78
+ for ex in pairs:
79
+ conv = ex["prompt"]
80
+ for key, buf in [("chosen", ch), ("rejected", rj)]:
81
+ ids = tokenizer.apply_chat_template(
82
+ conv + [{"role": "assistant", "content": ex[key]}],
83
+ tokenize=True, return_tensors="pt"
84
+ ).to(device)
85
+ out = model(ids, output_hidden_states=True)
86
+ buf.append(out.hidden_states[-1][0, -1].float().cpu())
87
+ return torch.stack(ch), torch.stack(rj)
88
+
89
+
90
+ def adapt_single_user(base_model, support_ch, support_rj, inner_lr=1e-3, inner_epochs=5, device="cuda"):
91
+ model = deepcopy(base_model).to(device).train()
92
+ opt = torch.optim.Adam([model.shared_weight], lr=inner_lr)
93
+ support_ch, support_rj = support_ch.to(device), support_rj.to(device)
94
+ for _ in range(inner_epochs):
95
+ opt.zero_grad()
96
+ loss = bt_loss(model(support_ch), model(support_rj))
97
+ loss.backward()
98
+ opt.step()
99
+ return model.eval()
100
+
101
+
102
+ @torch.no_grad()
103
+ def infer_on_pairs(model, ch, rj, device="cuda"):
104
+ return model(ch.to(device)), model(rj.to(device))
105
+
106
+
107
+ device = "cuda" if torch.cuda.is_available() else "cpu"
108
+
109
+ MODEL_PATH = "Skywork/Skywork-Reward-V2-Llama-3.1-8B"
110
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
111
+ llm = AutoModelForSequenceClassification.from_pretrained(
112
+ MODEL_PATH, num_labels=1, torch_dtype=torch.bfloat16, device_map=device
113
+ )
114
+
115
+ CKPT_PATH = "ckpt/model.pt"
116
+ mrm = MRM(in_dim=4096, hidden_sizes=[2], use_bias=False)
117
+ load_ckpt_into_model(mrm, CKPT_PATH, device)
118
+
119
+ support_pairs = [
120
+ {
121
+ "prompt": [{"role": "user", "content": "TL;DR this post: I tried waking up at 5am for a month and tracked my productivity."}],
122
+ "chosen": "Waking up early helped at first, but long-term productivity depended more on sleep quality than wake-up time.",
123
+ "rejected": "The post is about waking up early and productivity.",
124
+ },
125
+ {
126
+ "prompt": [{"role": "user", "content": "Summarize the main point: I switched from iPhone to Android after 10 years."}],
127
+ "chosen": "The author values customization and battery life more than ecosystem lock-in, which motivated the switch.",
128
+ "rejected": "The author bought a new phone.",
129
+ },
130
+ ]
131
+
132
+ sup_ch, sup_rj = encode_pairs(llm, tokenizer, support_pairs, device)
133
+ user_mrm = adapt_single_user(mrm, sup_ch, sup_rj, device=device)
134
+
135
+ test_pairs = [
136
+ {
137
+ "prompt": [{"role": "user", "content": "TL;DR: I quit my job to freelance and here is what I learned in 6 months."}],
138
+ "chosen": "Freelancing offers flexibility but requires strong self-discipline and financial planning to be sustainable.",
139
+ "rejected": "The author talks about quitting a job and freelancing.",
140
+ }
141
+ ]
142
+
143
+ test_ch, test_rj = encode_pairs(llm, tokenizer, test_pairs, device)
144
+ s_ch, s_rj = infer_on_pairs(user_mrm, test_ch, test_rj, device)
145
+
146
+ print("reward(chosen) =", s_ch.tolist())
147
+ print("reward(rejected)=", s_rj.tolist())
148
+
149
+ ```
150
 
151
  ---
152