Text Classification
Transformers
HongruCai committed on
Commit
5ea7cbb
·
verified ·
1 Parent(s): 90e6fcb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +97 -1
README.md CHANGED
@@ -54,7 +54,103 @@ python inference.py \
54
  --score_threshold -1 \
55
  --seed 42 \
56
  --device cuda:0
57
- ````
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  ---
60
 
 
54
  --score_threshold -1 \
55
  --seed 42 \
56
  --device cuda:0
57
+ ```
58
+
59
+ ---
60
+
61
+ ## Usage Example
62
+
63
+ This example shows a typical workflow for a **single user**:
64
+ 1) encode text pairs with Skywork-Reward-V2-Llama-3.1-8B into embeddings,
65
+ 2) adapt the MRM on the user's few-shot examples (update `shared_weight` only),
66
+ 3) run inference on new pairs for that same user.
67
+
68
+ ```python
69
+ import torch
70
+ from copy import deepcopy
71
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
72
+
73
+ from utils import bt_loss
74
+ from train import MRM
75
+ from inference import load_ckpt_into_model
76
+
77
+
78
+ @torch.no_grad()
79
+ def encode_pairs(model, tokenizer, pairs, device="cuda"):
80
+ model.eval()
81
+ ch, rj = [], []
82
+ for ex in pairs:
83
+ conv = ex["prompt"]
84
+ for key, buf in [("chosen", ch), ("rejected", rj)]:
85
+ ids = tokenizer.apply_chat_template(
86
+ conv + [{"role": "assistant", "content": ex[key]}],
87
+ tokenize=True, return_tensors="pt"
88
+ ).to(device)
89
+ out = model(ids, output_hidden_states=True)
90
+ buf.append(out.hidden_states[-1][0, -1].float().cpu())
91
+ return torch.stack(ch), torch.stack(rj)
92
+
93
+
94
+ def adapt_single_user(base_model, support_ch, support_rj, inner_lr=1e-3, inner_epochs=5, device="cuda"):
95
+ model = deepcopy(base_model).to(device).train()
96
+ opt = torch.optim.Adam([model.shared_weight], lr=inner_lr)
97
+ support_ch, support_rj = support_ch.to(device), support_rj.to(device)
98
+ for _ in range(inner_epochs):
99
+ opt.zero_grad()
100
+ loss = bt_loss(model(support_ch), model(support_rj))
101
+ loss.backward()
102
+ opt.step()
103
+ return model.eval()
104
+
105
+
106
+ @torch.no_grad()
107
+ def infer_on_pairs(model, ch, rj, device="cuda"):
108
+ return model(ch.to(device)), model(rj.to(device))
109
+
110
+
111
+ device = "cuda" if torch.cuda.is_available() else "cpu"
112
+
113
+ MODEL_PATH = "Skywork/Skywork-Reward-V2-Llama-3.1-8B"
114
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
115
+ llm = AutoModelForSequenceClassification.from_pretrained(
116
+ MODEL_PATH, num_labels=1, torch_dtype=torch.bfloat16, device_map=device
117
+ )
118
+
119
+ CKPT_PATH = "ckpt/model.pt"
120
+ mrm = MRM(in_dim=4096, hidden_sizes=[2], use_bias=False)
121
+ load_ckpt_into_model(mrm, CKPT_PATH, device)
122
+
123
+ support_pairs = [
124
+ {
125
+ "prompt": [{"role": "user", "content": "TL;DR this post: I tried waking up at 5am for a month and tracked my productivity."}],
126
+ "chosen": "Waking up early helped at first, but long-term productivity depended more on sleep quality than wake-up time.",
127
+ "rejected": "The post is about waking up early and productivity.",
128
+ },
129
+ {
130
+ "prompt": [{"role": "user", "content": "Summarize the main point: I switched from iPhone to Android after 10 years."}],
131
+ "chosen": "The author values customization and battery life more than ecosystem lock-in, which motivated the switch.",
132
+ "rejected": "The author bought a new phone.",
133
+ },
134
+ ]
135
+
136
+ sup_ch, sup_rj = encode_pairs(llm, tokenizer, support_pairs, device)
137
+ user_mrm = adapt_single_user(mrm, sup_ch, sup_rj, device=device)
138
+
139
+ test_pairs = [
140
+ {
141
+ "prompt": [{"role": "user", "content": "TL;DR: I quit my job to freelance and here is what I learned in 6 months."}],
142
+ "chosen": "Freelancing offers flexibility but requires strong self-discipline and financial planning to be sustainable.",
143
+ "rejected": "The author talks about quitting a job and freelancing.",
144
+ }
145
+ ]
146
+
147
+ test_ch, test_rj = encode_pairs(llm, tokenizer, test_pairs, device)
148
+ s_ch, s_rj = infer_on_pairs(user_mrm, test_ch, test_rj, device)
149
+
150
+ print("reward(chosen) =", s_ch.tolist())
151
+ print("reward(rejected)=", s_rj.tolist())
152
+
153
+ ```
154
 
155
  ---
156