Text Classification
Transformers
HongruCai committed on
Commit
e4a24ab
·
verified ·
1 Parent(s): fdc7320

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +97 -1
README.md CHANGED
@@ -52,7 +52,103 @@ python inference.py \
52
  --score_threshold -1 \
53
  --seed 42 \
54
  --device cuda:0
55
- ````
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  ---
58
 
 
52
  --score_threshold -1 \
53
  --seed 42 \
54
  --device cuda:0
55
+ ```
56
+
57
+ ---
58
+
59
+ ## Usage Example
60
+
61
+ This example shows a typical workflow for a **single user**:
62
+ 1) encode text pairs with Skywork-Reward-Llama-3.1-8B-v0.2 into embeddings,
63
+ 2) adapt the MRM on the user's few-shot examples (update `shared_weight` only),
64
+ 3) run inference on new pairs for that same user.
65
+
66
+ ```python
67
+ import torch
68
+ from copy import deepcopy
69
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
70
+
71
+ from utils import bt_loss
72
+ from train import MRM
73
+ from inference import load_ckpt_into_model
74
+
75
+
76
+ @torch.no_grad()
77
+ def encode_pairs(model, tokenizer, pairs, device="cuda"):
78
+ model.eval()
79
+ ch, rj = [], []
80
+ for ex in pairs:
81
+ conv = ex["prompt"]
82
+ for key, buf in [("chosen", ch), ("rejected", rj)]:
83
+ ids = tokenizer.apply_chat_template(
84
+ conv + [{"role": "assistant", "content": ex[key]}],
85
+ tokenize=True, return_tensors="pt"
86
+ ).to(device)
87
+ out = model(ids, output_hidden_states=True)
88
+ buf.append(out.hidden_states[-1][0, -1].float().cpu())
89
+ return torch.stack(ch), torch.stack(rj)
90
+
91
+
92
+ def adapt_single_user(base_model, support_ch, support_rj, inner_lr=1e-3, inner_epochs=5, device="cuda"):
93
+ model = deepcopy(base_model).to(device).train()
94
+ opt = torch.optim.Adam([model.shared_weight], lr=inner_lr)
95
+ support_ch, support_rj = support_ch.to(device), support_rj.to(device)
96
+ for _ in range(inner_epochs):
97
+ opt.zero_grad()
98
+ loss = bt_loss(model(support_ch), model(support_rj))
99
+ loss.backward()
100
+ opt.step()
101
+ return model.eval()
102
+
103
+
104
+ @torch.no_grad()
105
+ def infer_on_pairs(model, ch, rj, device="cuda"):
106
+ return model(ch.to(device)), model(rj.to(device))
107
+
108
+
109
+ device = "cuda" if torch.cuda.is_available() else "cpu"
110
+
111
+ MODEL_PATH = "Skywork/Skywork-Reward-Llama-3.1-8B-v0.2"
112
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
113
+ llm = AutoModelForSequenceClassification.from_pretrained(
114
+ MODEL_PATH, num_labels=1, torch_dtype=torch.bfloat16, device_map=device
115
+ )
116
+
117
+ CKPT_PATH = "ckpt/model.pt"
118
+ mrm = MRM(in_dim=4096, hidden_sizes=[2], use_bias=False)
119
+ load_ckpt_into_model(mrm, CKPT_PATH, device)
120
+
121
+ support_pairs = [
122
+ {
123
+ "prompt": [{"role": "user", "content": "TL;DR this post: I tried waking up at 5am for a month and tracked my productivity."}],
124
+ "chosen": "Waking up early helped at first, but long-term productivity depended more on sleep quality than wake-up time.",
125
+ "rejected": "The post is about waking up early and productivity.",
126
+ },
127
+ {
128
+ "prompt": [{"role": "user", "content": "Summarize the main point: I switched from iPhone to Android after 10 years."}],
129
+ "chosen": "The author values customization and battery life more than ecosystem lock-in, which motivated the switch.",
130
+ "rejected": "The author bought a new phone.",
131
+ },
132
+ ]
133
+
134
+ sup_ch, sup_rj = encode_pairs(llm, tokenizer, support_pairs, device)
135
+ user_mrm = adapt_single_user(mrm, sup_ch, sup_rj, device=device)
136
+
137
+ test_pairs = [
138
+ {
139
+ "prompt": [{"role": "user", "content": "TL;DR: I quit my job to freelance and here is what I learned in 6 months."}],
140
+ "chosen": "Freelancing offers flexibility but requires strong self-discipline and financial planning to be sustainable.",
141
+ "rejected": "The author talks about quitting a job and freelancing.",
142
+ }
143
+ ]
144
+
145
+ test_ch, test_rj = encode_pairs(llm, tokenizer, test_pairs, device)
146
+ s_ch, s_rj = infer_on_pairs(user_mrm, test_ch, test_rj, device)
147
+
148
+ print("reward(chosen) =", s_ch.tolist())
149
+ print("reward(rejected)=", s_rj.tolist())
150
+
151
+ ```
152
 
153
  ---
154