Samarth0710 commited on
Commit
1beb17e
·
verified ·
1 Parent(s): 2b4cb91

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +342 -0
pipeline.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cross-Model LoRA Adapter Prediction
3
+ ====================================
4
+ Model X = Qwen/Qwen2.5-0.5B-Instruct
5
+ Model Y = meta-llama/Llama-3.2-1B-Instruct
6
+ Tasks : A=SST-2, B=AG News, C=Subj, D=Emotion (held out for Y)
7
+
8
+ Pipeline:
9
+ 1. Train LoRA X_A, X_B, X_C, X_D on Model X
10
+ 2. Train LoRA Y_A, Y_B, Y_C, Y_D on Model Y (Y_D = oracle, kept aside)
11
+ 3. Learn mapping f from {X_A,X_B,X_C} -> {Y_A,Y_B,Y_C} via anchor-basis ridge regression
12
+ 4. Predict Y_hat_D = f(X_D)
13
+ 5. Evaluate on D test split: base Y, mean(Y_A,Y_B,Y_C), Y_hat_D (predicted), Y_D (oracle), Y_D trained-on-X-train (sanity)
14
+ 6. Push everything to HF repo
15
+ """
16
+ import os, json, gc, math, time, argparse, shutil
17
+ from pathlib import Path
18
+ from typing import Dict, List, Tuple
19
+
20
+ import numpy as np
21
+ import torch
22
+ from datasets import load_dataset, Dataset
23
+ from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
24
+ from peft import LoraConfig, get_peft_model, PeftModel
25
+ from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
26
+ from trl import SFTTrainer, SFTConfig
27
+
28
+ set_seed(42)
29
+
30
+ # -------------------- Config --------------------
31
+ MODEL_X = "Qwen/Qwen2.5-0.5B-Instruct"
32
+ MODEL_Y = "meta-llama/Llama-3.2-1B-Instruct"
33
+
34
+ LORA_R = 8
35
+ LORA_ALPHA = 16
36
+ LORA_TARGETS = ["q_proj", "v_proj"]
37
+
38
+ TRAIN_PER_TASK = 1500 # SFT examples per task
39
+ EVAL_PER_TASK = 400 # eval examples
40
+ EPOCHS = 1
41
+ BS = 8
42
+ LR = 2e-4
43
+ MAX_LEN = 192
44
+
45
+ OUT = Path("/app/out")
46
+ OUT.mkdir(exist_ok=True, parents=True)
47
+
48
+ # -------------------- Datasets --------------------
49
+ def fmt(prompt: str, label_text: str):
50
+ return [
51
+ {"role": "user", "content": prompt},
52
+ {"role": "assistant", "content": label_text},
53
+ ]
54
+
55
+ def build_task(name: str):
56
+ """Return (train_ds, eval_ds, label_set, prompt_fn) where each row has a 'messages' field."""
57
+ if name == "A": # SST-2 sentiment
58
+ ds = load_dataset("stanfordnlp/sst2")
59
+ labels = ["negative", "positive"]
60
+ def to_msg(r): return {"messages": fmt(
61
+ f"Classify the sentiment of this sentence as 'negative' or 'positive'. Respond with just the label.\n\nSentence: {r['sentence'].strip()}\n\nSentiment:",
62
+ labels[r["label"]])}
63
+ train = ds["train"].shuffle(seed=0).select(range(TRAIN_PER_TASK)).map(to_msg, remove_columns=ds["train"].column_names)
64
+ ev = ds["validation"].shuffle(seed=0).select(range(min(EVAL_PER_TASK, len(ds["validation"])))).map(to_msg, remove_columns=ds["validation"].column_names)
65
+ return train, ev, labels, "sentiment"
66
+ if name == "B": # AG News
67
+ ds = load_dataset("fancyzhx/ag_news")
68
+ labels = ["world", "sports", "business", "sci/tech"]
69
+ def to_msg(r): return {"messages": fmt(
70
+ f"Classify the news topic as 'world', 'sports', 'business', or 'sci/tech'. Respond with just the label.\n\nNews: {r['text'].strip()}\n\nTopic:",
71
+ labels[r["label"]])}
72
+ train = ds["train"].shuffle(seed=0).select(range(TRAIN_PER_TASK)).map(to_msg, remove_columns=ds["train"].column_names)
73
+ ev = ds["test"].shuffle(seed=0).select(range(EVAL_PER_TASK)).map(to_msg, remove_columns=ds["test"].column_names)
74
+ return train, ev, labels, "topic"
75
+ if name == "C": # Subj
76
+ ds = load_dataset("SetFit/subj")
77
+ labels = ["objective", "subjective"]
78
+ def to_msg(r): return {"messages": fmt(
79
+ f"Classify whether this sentence is 'objective' or 'subjective'. Respond with just the label.\n\nSentence: {r['text'].strip()}\n\nLabel:",
80
+ labels[r["label"]])}
81
+ train = ds["train"].shuffle(seed=0).select(range(min(TRAIN_PER_TASK, len(ds["train"])))).map(to_msg, remove_columns=ds["train"].column_names)
82
+ ev = ds["test"].shuffle(seed=0).select(range(min(EVAL_PER_TASK, len(ds["test"])))).map(to_msg, remove_columns=ds["test"].column_names)
83
+ return train, ev, labels, "subjectivity"
84
+ if name == "D": # Emotion
85
+ ds = load_dataset("dair-ai/emotion", "split")
86
+ labels = ["sadness", "joy", "love", "anger", "fear", "surprise"]
87
+ def to_msg(r): return {"messages": fmt(
88
+ f"Classify the emotion of this sentence as one of: 'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'. Respond with just the label.\n\nSentence: {r['text'].strip()}\n\nEmotion:",
89
+ labels[r["label"]])}
90
+ train = ds["train"].shuffle(seed=0).select(range(TRAIN_PER_TASK)).map(to_msg, remove_columns=ds["train"].column_names)
91
+ ev = ds["test"].shuffle(seed=0).select(range(EVAL_PER_TASK)).map(to_msg, remove_columns=ds["test"].column_names)
92
+ return train, ev, labels, "emotion"
93
+ raise ValueError(name)
94
+
95
+ TASKS = ["A", "B", "C", "D"]
96
+
97
+ # -------------------- Train one LoRA --------------------
98
+ def train_lora(model_name: str, task: str, save_dir: Path):
99
+ if save_dir.exists() and (save_dir/"adapter_model.safetensors").exists():
100
+ print(f"[SKIP] {save_dir} already exists")
101
+ return
102
+ save_dir.mkdir(parents=True, exist_ok=True)
103
+ print(f"\n=== Training LoRA: model={model_name} task={task} -> {save_dir}")
104
+ tok = AutoTokenizer.from_pretrained(model_name)
105
+ if tok.pad_token is None:
106
+ tok.pad_token = tok.eos_token
107
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, attn_implementation="eager")
108
+ model.config.use_cache = False
109
+ train_ds, _, _, _ = build_task(task)
110
+ lora = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=LORA_TARGETS,
111
+ lora_dropout=0.0, bias="none", task_type="CAUSAL_LM")
112
+ cfg = SFTConfig(
113
+ output_dir=str(save_dir/"_trainer"),
114
+ num_train_epochs=EPOCHS,
115
+ per_device_train_batch_size=BS,
116
+ gradient_accumulation_steps=1,
117
+ learning_rate=LR,
118
+ lr_scheduler_type="cosine",
119
+ warmup_ratio=0.05,
120
+ bf16=True,
121
+ max_seq_length=MAX_LEN,
122
+ logging_steps=25,
123
+ logging_first_step=True,
124
+ logging_strategy="steps",
125
+ disable_tqdm=True,
126
+ save_strategy="no",
127
+ report_to="none",
128
+ seed=42,
129
+ packing=False,
130
+ )
131
+ trainer = SFTTrainer(model=model, args=cfg, train_dataset=train_ds, peft_config=lora, tokenizer=tok)
132
+ trainer.train()
133
+ trainer.model.save_pretrained(str(save_dir))
134
+ tok.save_pretrained(str(save_dir))
135
+ # cleanup
136
+ shutil.rmtree(save_dir/"_trainer", ignore_errors=True)
137
+ del trainer, model
138
+ gc.collect(); torch.cuda.empty_cache()
139
+
140
+ # -------------------- Cross-model mapping --------------------
141
+ def load_adapter_state(path: Path) -> Dict[str, torch.Tensor]:
142
+ """Load LoRA state dict, kept on CPU as float32."""
143
+ from safetensors.torch import load_file
144
+ sd = load_file(str(path/"adapter_model.safetensors"))
145
+ return {k: v.float().cpu() for k, v in sd.items()}
146
+
147
+ def flatten_sd(sd: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, List[Tuple[str, torch.Size]]]:
148
+ keys_shapes = [(k, sd[k].shape) for k in sorted(sd.keys())]
149
+ flat = torch.cat([sd[k].reshape(-1) for k, _ in keys_shapes])
150
+ return flat, keys_shapes
151
+
152
+ def unflatten(flat: torch.Tensor, keys_shapes) -> Dict[str, torch.Tensor]:
153
+ out = {}
154
+ i = 0
155
+ for k, shape in keys_shapes:
156
+ n = int(np.prod(shape))
157
+ out[k] = flat[i:i+n].reshape(shape)
158
+ i += n
159
+ return out
160
+
161
+ def predict_anchor_basis(X_anchors: List[torch.Tensor], Y_anchors: List[torch.Tensor],
162
+ X_target: torch.Tensor, ridge: float = 1e-3) -> Tuple[torch.Tensor, torch.Tensor]:
163
+ """
164
+ f maps X-side -> Y-side using a paired-anchor linear basis.
165
+
166
+ Express x_target - mean(X) ≈ sum_i alpha_i (X_i - mean(X)) via ridge regression
167
+ on the small 3x3 Gram matrix; then ŷ = mean(Y) + sum_i alpha_i (Y_i - mean(Y)).
168
+
169
+ Returns (y_hat, alpha).
170
+ """
171
+ Xs = torch.stack(X_anchors) # [k, dx]
172
+ Ys = torch.stack(Y_anchors) # [k, dy]
173
+ Xm = Xs.mean(0); Ym = Ys.mean(0)
174
+ Xc = Xs - Xm # [k, dx]
175
+ Yc = Ys - Ym # [k, dy]
176
+ xc = X_target - Xm # [dx]
177
+ G = Xc @ Xc.T # [k, k]
178
+ rhs = Xc @ xc # [k]
179
+ alpha = torch.linalg.solve(G + ridge * torch.eye(G.shape[0]), rhs) # [k]
180
+ y_hat = Ym + (alpha @ Yc) # [dy]
181
+ return y_hat, alpha
182
+
183
+ # -------------------- Evaluation --------------------
184
+ @torch.no_grad()
185
+ def eval_classification(model, tok, eval_ds, labels: List[str], max_new=8) -> float:
186
+ """Greedy generation; compare first non-empty token-stripped substring against labels."""
187
+ model.eval()
188
+ correct = 0; total = 0
189
+ label_set = [l.lower() for l in labels]
190
+ bs = 16
191
+ prompts = []
192
+ golds = []
193
+ for ex in eval_ds:
194
+ msgs = ex["messages"]
195
+ gold = msgs[1]["content"].strip().lower()
196
+ # build prompt up to assistant turn
197
+ prompt = tok.apply_chat_template([msgs[0]], tokenize=False, add_generation_prompt=True)
198
+ prompts.append(prompt)
199
+ golds.append(gold)
200
+ for i in range(0, len(prompts), bs):
201
+ batch = prompts[i:i+bs]
202
+ enc = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN).to(model.device)
203
+ out = model.generate(**enc, max_new_tokens=max_new, do_sample=False, pad_token_id=tok.pad_token_id)
204
+ gen = out[:, enc["input_ids"].shape[1]:]
205
+ for j, g in enumerate(gen):
206
+ txt = tok.decode(g, skip_special_tokens=True).strip().lower()
207
+ # match longest prefix label
208
+ pred = None
209
+ for lab in sorted(label_set, key=len, reverse=True):
210
+ if txt.startswith(lab):
211
+ pred = lab; break
212
+ if pred is None:
213
+ # fallback: any label appearing
214
+ for lab in label_set:
215
+ if lab in txt: pred = lab; break
216
+ if pred == golds[i+j]:
217
+ correct += 1
218
+ total += 1
219
+ return correct / max(1,total)
220
+
221
+
222
+ # -------------------- Main --------------------
223
+ def main(stage: str = "all"):
224
+ # ---------- Stage 1+2: train all LoRAs ----------
225
+ if stage in ("all", "train"):
226
+ for t in TASKS:
227
+ train_lora(MODEL_X, t, OUT/"X"/f"X_{t}")
228
+ for t in TASKS: # train Y_D too for oracle
229
+ train_lora(MODEL_Y, t, OUT/"Y"/f"Y_{t}")
230
+
231
+ # ---------- Stage 3: build mapping + predict Y_hat_D ----------
232
+ if stage in ("all", "map"):
233
+ print("\n=== Building cross-model mapping ===")
234
+ X_states = {t: load_adapter_state(OUT/"X"/f"X_{t}") for t in TASKS}
235
+ Y_states = {t: load_adapter_state(OUT/"Y"/f"Y_{t}") for t in TASKS}
236
+
237
+ # flatten — same keys/shapes within each side
238
+ X_flat = {}; Y_flat = {}
239
+ Xks = Yks = None
240
+ for t in TASKS:
241
+ f, ks = flatten_sd(X_states[t]); X_flat[t] = f; Xks = ks
242
+ f, ks = flatten_sd(Y_states[t]); Y_flat[t] = f; Yks = ks
243
+ print("X adapter dim:", X_flat["A"].numel(), "Y adapter dim:", Y_flat["A"].numel())
244
+
245
+ # anchor-basis ridge regression mapping
246
+ Xa = [X_flat["A"], X_flat["B"], X_flat["C"]]
247
+ Ya = [Y_flat["A"], Y_flat["B"], Y_flat["C"]]
248
+ Y_hat_D, alpha = predict_anchor_basis(Xa, Ya, X_flat["D"], ridge=1e-3)
249
+ print("Anchor weights alpha (A,B,C):", alpha.tolist())
250
+ # also: mean baseline
251
+ Y_mean_ABC = torch.stack(Ya).mean(0)
252
+ # cosine sim diagnostics
253
+ def cos(a, b): return torch.nn.functional.cosine_similarity(a.flatten().unsqueeze(0), b.flatten().unsqueeze(0)).item()
254
+ print("cos(Y_hat_D, Y_D) =", cos(Y_hat_D, Y_flat["D"]))
255
+ print("cos(Y_mean_ABC, Y_D) =", cos(Y_mean_ABC, Y_flat["D"]))
256
+ print("cos(Y_A, Y_D) =", cos(Y_flat["A"], Y_flat["D"]))
257
+ print("cos(Y_B, Y_D) =", cos(Y_flat["B"], Y_flat["D"]))
258
+ print("cos(Y_C, Y_D) =", cos(Y_flat["C"], Y_flat["D"]))
259
+
260
+ # save predicted + mean adapters as standard PEFT checkpoints (clone Y_A's metadata)
261
+ from safetensors.torch import save_file
262
+ for name, flat in [("Y_pred_D", Y_hat_D), ("Y_mean_ABC", Y_mean_ABC)]:
263
+ sd = unflatten(flat, Yks)
264
+ sd_bf16 = {k: v.to(torch.bfloat16) for k, v in sd.items()}
265
+ d = OUT/"Y"/name
266
+ d.mkdir(parents=True, exist_ok=True)
267
+ # copy adapter_config and tokenizer from Y_A
268
+ shutil.copy(OUT/"Y"/"Y_A"/"adapter_config.json", d/"adapter_config.json")
269
+ for f in ["tokenizer.json","tokenizer_config.json","special_tokens_map.json"]:
270
+ src = OUT/"Y"/"Y_A"/f
271
+ if src.exists(): shutil.copy(src, d/f)
272
+ save_file(sd_bf16, str(d/"adapter_model.safetensors"))
273
+ print("Saved", d)
274
+
275
+ # save mapping diagnostics
276
+ diag = {
277
+ "alpha_ABC": alpha.tolist(),
278
+ "cos_Yhat_YD": cos(Y_hat_D, Y_flat["D"]),
279
+ "cos_Ymean_YD": cos(Y_mean_ABC, Y_flat["D"]),
280
+ "cos_YA_YD": cos(Y_flat["A"], Y_flat["D"]),
281
+ "cos_YB_YD": cos(Y_flat["B"], Y_flat["D"]),
282
+ "cos_YC_YD": cos(Y_flat["C"], Y_flat["D"]),
283
+ "X_dim": X_flat["A"].numel(),
284
+ "Y_dim": Y_flat["A"].numel(),
285
+ "ridge": 1e-3,
286
+ }
287
+ (OUT/"mapping_diagnostics.json").write_text(json.dumps(diag, indent=2))
288
+
289
+ # ---------- Stage 4: evaluate on D ----------
290
+ if stage in ("all", "eval"):
291
+ print("\n=== Evaluating on task D (Emotion) ===")
292
+ _, eval_d, labels_d, _ = build_task("D")
293
+ tok = AutoTokenizer.from_pretrained(MODEL_Y)
294
+ if tok.pad_token is None: tok.pad_token = tok.eos_token
295
+ tok.padding_side = "left"
296
+ results = {}
297
+ # Base Y
298
+ base = AutoModelForCausalLM.from_pretrained(MODEL_Y, torch_dtype=torch.bfloat16, attn_implementation="eager").cuda()
299
+ results["base_Y"] = eval_classification(base, tok, eval_d, labels_d)
300
+ print("base_Y", results["base_Y"])
301
+ del base; gc.collect(); torch.cuda.empty_cache()
302
+
303
+ # helper for adapter eval
304
+ def with_adapter(adapter_dir):
305
+ base = AutoModelForCausalLM.from_pretrained(MODEL_Y, torch_dtype=torch.bfloat16, attn_implementation="eager").cuda()
306
+ m = PeftModel.from_pretrained(base, str(adapter_dir))
307
+ acc = eval_classification(m, tok, eval_d, labels_d)
308
+ del m, base; gc.collect(); torch.cuda.empty_cache()
309
+ return acc
310
+
311
+ for name, dirname in [
312
+ ("Y_A_on_D", "Y_A"),
313
+ ("Y_B_on_D", "Y_B"),
314
+ ("Y_C_on_D", "Y_C"),
315
+ ("Y_mean_ABC_on_D", "Y_mean_ABC"),
316
+ ("Y_pred_D", "Y_pred_D"),
317
+ ("Y_oracle_D", "Y_D"),
318
+ ]:
319
+ results[name] = with_adapter(OUT/"Y"/dirname)
320
+ print(name, results[name])
321
+
322
+ # also: sanity-check Model X with X_D oracle on its own dataset
323
+ tokx = AutoTokenizer.from_pretrained(MODEL_X)
324
+ if tokx.pad_token is None: tokx.pad_token = tokx.eos_token
325
+ tokx.padding_side = "left"
326
+ basex = AutoModelForCausalLM.from_pretrained(MODEL_X, torch_dtype=torch.bfloat16, attn_implementation="eager").cuda()
327
+ results["base_X"] = eval_classification(basex, tokx, eval_d, labels_d)
328
+ del basex; gc.collect(); torch.cuda.empty_cache()
329
+ basex = AutoModelForCausalLM.from_pretrained(MODEL_X, torch_dtype=torch.bfloat16, attn_implementation="eager").cuda()
330
+ mx = PeftModel.from_pretrained(basex, str(OUT/"X"/"X_D"))
331
+ results["X_oracle_D"] = eval_classification(mx, tokx, eval_d, labels_d)
332
+ del mx, basex; gc.collect(); torch.cuda.empty_cache()
333
+
334
+ (OUT/"results.json").write_text(json.dumps(results, indent=2))
335
+ print("\n=== Results ===")
336
+ for k, v in results.items(): print(f" {k:24s} {v:.4f}")
337
+
338
+ if __name__ == "__main__":
339
+ ap = argparse.ArgumentParser()
340
+ ap.add_argument("--stage", default="all", choices=["all","train","map","eval"])
341
+ args = ap.parse_args()
342
+ main(args.stage)