"""
PHASE B: scale anchors from 25 -> 60 by training 35 NEW LoRA pairs.
Reuses existing 30 trained adapters in /app/scaled/{X,Y}/.
Adds 35 new anchor tasks. Held-out test tasks remain the same.
Then runs the FULL mapping comparison sweep on:
N ∈ {25, 60} × methods ∈ {global_ridge, pertensor_ridge, pertensor_pca, pertensor_mlp,
procrustes, topk5/8/12_global_ridge, topk5/8/12_pertensor_ridge, topk12_pertensor_mlp}
× 5 held-out tasks
"""
import os, sys, json, gc, shutil, re, collections, time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
from safetensors.torch import load_file, save_file
sys.path.insert(0, "/app")
import scaled_pipeline as sp
import phaseA_new_methods as pa
# Fix the global RNG seed (via transformers.set_seed) before any data or training work.
set_seed(42)
# Root output directory; main() stores adapters under OUT/"X"/<task> and OUT/"Y"/<task>.
OUT = sp.OUT
# Model identifiers for the two sides of each LoRA pair, as defined by scaled_pipeline.
MODEL_X = sp.MODEL_X
MODEL_Y = sp.MODEL_Y
# Held-out task names re-exported from scaled_pipeline (not referenced elsewhere in this file).
HELDOUT_NAMES = sp.HELDOUT_NAMES
EXISTING_ANCHORS = sp.ANCHOR_NAMES # 25 already-trained
# New anchor tasks (verified) -- avoid duplicates with existing/heldout.
# NOTE(review): the plan calls for 35 new anchors, but this list currently holds 34 entries -- confirm.
# Format: (name, hf_id, config, text_col_or_tuple, label_col, label_names)
#   name              -- task id; also used as the adapter directory name under OUT/"X" and OUT/"Y"
#   hf_id / config    -- arguments passed to datasets.load_dataset (config may be None)
#   text_col_or_tuple -- one text column, or a (col_a, col_b) pair joined with " [SEP] "
#   label_col         -- column holding the class label
#   label_names       -- class names indexed by label id, or None when not listed inline
NEW_ANCHORS = [
("yelp_polarity", "fancyzhx/yelp_polarity", None, "text", "label", ["negative","positive"]),
("yahoo_topics", "community-datasets/yahoo_answers_topics", None, "question_title", "topic",
["society","science","health","education","computers","sports","business","entertainment","family","politics"]),
("setfit_qnli", "SetFit/qnli", None, ("text1","text2"), "label", ["entailment","not entailment"]),
("setfit_mnli", "SetFit/mnli", None, ("text1","text2"), "label", ["entailment","neutral","contradiction"]),
("setfit_rte", "SetFit/rte", None, ("text1","text2"), "label", ["entailment","not entailment"]),
("setfit_mrpc", "SetFit/mrpc", None, ("text1","text2"), "label", ["different","paraphrase"]),
("setfit_qqp", "SetFit/qqp", None, ("text1","text2"), "label", ["different","duplicate"]),
("snli_main", "stanfordnlp/snli", None, ("premise","hypothesis"), "label", ["entailment","neutral","contradiction"]),
("paws", "google-research-datasets/paws", "labeled_final", ("sentence1","sentence2"), "label", ["different","paraphrase"]),
("mteb_emo", "mteb/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
("mteb_tweet_sent","mteb/tweet_sentiment_extraction", None, "text", "label", ["negative","neutral","positive"]),
("mteb_toxic_conv","mteb/toxic_conversations_50k", None, "text", "label", ["non-toxic","toxic"]),
("mteb_amazon_cf", "mteb/amazon_counterfactual", "en", "text", "label", ["not counterfactual","counterfactual"]),
("dair_emo_unsplit","dair-ai/emotion", "unsplit", "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
("setfit_yelp_full","SetFit/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
("setfit_emotion", "SetFit/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
("setfit_sst5_alt","SetFit/sst5", None, "text", "label", ["very negative","negative","neutral","positive","very positive"]),
("setfit_student_qcat","SetFit/student-question-categories", None, "text", "label", ["q0","q1","q2","q3"]),
# NOTE(review): named "movie_reviews" but loads SetFit/sst2.
("setfit_movie_reviews","SetFit/sst2", None, "text", "label", ["negative","positive"]),
("sms_spam", "ucirvine/sms_spam", None, "sms", "label", ["ham","spam"]),
("snips", "snips_built_in_intents", None, "text", "label",
["compare-places","request-ride","get-weather","search-place","get-place-details","share-current-location","get-traffic-information","book-restaurant","get-directions","share-eta"]),
("toxic_chat", "lmsys/toxic-chat", "toxicchat0124", "user_input", "toxicity", ["non-toxic","toxic"]),
("hate_offensive_lang","tdavidson/hate_speech_offensive", None, "tweet", "class", ["hate","offensive","neither"]),
("amazon_massive_scenario","AmazonScience/massive", "en-US", "utt", "scenario",
["social","transport","calendar","play","news","datetime","recommendation","email","iot","general","audio","lists","qa","cooking","takeaway","music","alarm","weather"]),
("trec_fine", "CogComp/trec", None, "text", "fine_label", None), # label_names is None (too many classes to list inline); the builder must supply them
("ag_setfit", "SetFit/ag_news", None, "text", "label", ["world","sports","business","sci/tech"]),
("cardiffnlp_topic","cardiffnlp/tweet_topic_single", None, "text", "label",
["arts_culture","business_entrepreneurs","celebrity_pop_culture","diaries_daily_life","family","fashion_style","film_tv_video","fitness_health","food_dining","gaming","learning_educational","music","news_social_concern","other_hobbies","relationships","science_technology","sports","travel_adventure","youth_student_life"]),
# Additional anchors (7 entries follow; NOTE(review): the original comment here said "8 more").
("financial_pb_pos","takala/financial_phrasebank", "sentences_50agree", "sentence", "label", ["negative","neutral","positive"]),
("clinc_small_skip","clinc_oos", "small", "text", "intent", None), # placeholder, may skip if too big; label_names is None
("hate_speech18", "hate_speech18", None, "text", "label", ["no_hate","hate","idk","relation"]),
("rotten_alt", "cornell-movie-review-data/rotten_tomatoes", None, "text", "label", ["negative","positive"]),
# NOTE(review): named "*_polarity_test" but loads the 5-class Yelp/yelp_review_full,
# with the same label names as setfit_yelp_full above -- confirm this is intended.
("yelp_polarity_test","Yelp/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
("dynasent_r1", "dynabench/dynasent", "dynabench.dynasent.r1.all", "sentence", "gold_label", ["negative","neutral","positive"]),
("dynasent_r2", "dynabench/dynasent", "dynabench.dynasent.r2.all", "sentence", "gold_label", ["negative","neutral","positive"]),
]
# Task ids only, in declaration order; main() iterates over this list.
NEW_ANCHOR_NAMES = [t[0] for t in NEW_ANCHORS]
# Extension of sp.build_task that also understands the NEW_ANCHORS specs.
def build_task_extra(name, n_train=800, n_eval=300):
    """Build chat-formatted train/eval splits for a classification task.

    Task names present in ``sp.TASKS_BY_NAME`` are delegated unchanged to
    ``sp.build_task``. Any other name is looked up in ``NEW_ANCHORS`` and
    built here.

    Args:
        name: Task id (first field of a ``NEW_ANCHORS`` entry, or a task
            known to ``scaled_pipeline``).
        n_train: Maximum number of training examples to return.
        n_eval: Maximum number of evaluation examples to return.

    Returns:
        ``(train, ev, labels)`` for tasks built here: two datasets with a
        single ``"messages"`` column (user prompt + assistant label) and the
        list of label names. For delegated tasks the return value is whatever
        ``sp.build_task`` returns (presumably the same shape -- verify).

    Raises:
        KeyError: ``name`` is neither a ``scaled_pipeline`` task nor in
            ``NEW_ANCHORS``.
        ValueError: the spec has no inline label names and the dataset's
            label column does not expose any.
    """
    if name in sp.TASKS_BY_NAME:
        return sp.build_task(name, n_train, n_eval)

    spec = next((t for t in NEW_ANCHORS if t[0] == name), None)
    if spec is None:
        raise KeyError(name)
    _, hf, cfg, txt, lab, labels = spec

    ds = load_dataset(hf, cfg) if cfg else load_dataset(hf)
    train_split = "train"
    if "validation" in ds:
        eval_split = "validation"
    elif "test" in ds:
        eval_split = "test"
    else:
        eval_split = "train"

    if labels is None:
        # Specs with too many classes to list inline carry None; take the
        # names from the dataset's label feature instead of crashing below.
        feature = ds[train_split].features.get(lab)
        labels = list(getattr(feature, "names", None) or [])
        if not labels:
            raise ValueError(
                f"task {name!r}: no label names in spec and column {lab!r} exposes none"
            )

    label_join = ", ".join(f"'{l}'" for l in labels)
    instr = f"Classify the following text. Choose one of: {label_join}. Respond with just the label.\n\nText: "
    name_to_id = {str(l).strip().lower(): i for i, l in enumerate(labels)}

    def label_id(r):
        """Return the row's label index, or -1 when it cannot be resolved."""
        raw = r[lab]
        try:
            lid = int(raw)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric labels (e.g. "positive") are matched by name rather
            # than being silently coerced to class 0.
            lid = -1 if raw is None else name_to_id.get(str(raw).strip().lower(), -1)
        return lid if 0 <= lid < len(labels) else -1

    def build_text(r):
        if isinstance(txt, tuple):
            return f"{r[txt[0]]} [SEP] {r[txt[1]]}"
        return str(r[txt])

    def to_msg(r):
        text = build_text(r).strip().replace("\n", " ")[:600]
        prompt = instr + text + "\n\nLabel:"
        return {
            "messages": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": labels[label_id(r)]},
            ]
        }

    def take(splitname, n, start=0):
        """Return up to ``n`` formatted rows from a seed-0 shuffle of the split.

        Candidates are the ``3 * n`` shuffled positions beginning at ``start``
        (over-selected so rows with unusable labels can be dropped).
        """
        rows = ds[splitname].shuffle(seed=0)
        stop = min(start + n * 3, len(rows))
        rows = rows.select(range(min(start, stop), stop))
        # Drop unusable rows *before* mapping so the map callback always
        # returns a dict.
        rows = rows.filter(lambda r: label_id(r) >= 0)
        rows = rows.select(range(min(n, len(rows))))
        return rows.map(to_msg, remove_columns=rows.column_names)

    train = take(train_split, n_train)
    if eval_split == "train":
        # No separate eval split: draw eval rows from the same permutation,
        # starting after the training candidate pool, so the two sets cannot
        # overlap. This yields an empty eval set when the dataset has no more
        # than 3 * n_train rows.
        ev = take(train_split, n_eval, start=n_train * 3)
    else:
        ev = take(eval_split, n_eval)
    return train, ev, labels
# Counterpart to sp.train_lora that builds its data with build_task_extra.
# (Nothing is actually patched onto sp here; callers invoke this function directly.)
def train_lora_extra(model_name, task, save_dir):
    """Train a LoRA adapter for ``task`` on ``model_name`` and save it.

    The call is a no-op when ``save_dir`` already contains
    ``adapter_model.safetensors``, and the task is skipped (nothing is
    written) when fewer than 50 training examples are available.

    Args:
        model_name: Hugging Face model id passed to ``from_pretrained``.
        task: Task id understood by ``build_task_extra``.
        save_dir: ``pathlib.Path`` where the adapter and tokenizer are saved.

    Returns:
        None.

    Raises:
        Any exception from dataset building, model loading or training is
        propagated; the ``_t`` scratch directory is removed and local
        model/trainer references are released first.
    """
    if (save_dir / "adapter_model.safetensors").exists():
        return
    print(f"[TRAIN] {model_name.split('/')[-1]} / {task}", flush=True)

    # Build the data before loading the model, so a task that will be
    # skipped never pays for loading the weights.
    train_ds, _, _ = build_task_extra(task, n_train=sp.TRAIN_PER_TASK, n_eval=sp.EVAL_PER_TASK)
    if len(train_ds) < 50:
        print(f" skipping {task}: only {len(train_ds)} train examples")
        return

    # Created only once training is certain, so the skip path needs no
    # rmdir() (which raises on a non-empty directory).
    save_dir.mkdir(parents=True, exist_ok=True)
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = None
    trainer = None
    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, attn_implementation="eager"
        )
        model.config.use_cache = False
        lora = LoraConfig(
            r=sp.LORA_R, lora_alpha=sp.LORA_ALPHA, target_modules=sp.LORA_TARGETS,
            lora_dropout=0.0, bias="none", task_type="CAUSAL_LM",
        )
        cfg = SFTConfig(
            output_dir=str(save_dir / "_t"), num_train_epochs=sp.EPOCHS,
            per_device_train_batch_size=sp.BS, learning_rate=sp.LR, lr_scheduler_type="cosine",
            warmup_ratio=0.05, bf16=True, max_seq_length=sp.MAX_LEN,
            logging_steps=50, logging_first_step=True, disable_tqdm=True,
            save_strategy="no", report_to="none", seed=42, packing=False,
        )
        trainer = SFTTrainer(model=model, args=cfg, train_dataset=train_ds, peft_config=lora, tokenizer=tok)
        trainer.train()
        trainer.model.save_pretrained(str(save_dir))
        tok.save_pretrained(str(save_dir))
    finally:
        # Runs on failure too: remove the trainer scratch dir and drop our
        # references so the next task starts with as much free memory as possible.
        shutil.rmtree(save_dir / "_t", ignore_errors=True)
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()
def main():
    """Train an X/Y adapter pair for every new anchor and record the successes.

    A failure on one anchor is reported and does not stop the loop; any
    adapter directory of that anchor lacking ``adapter_model.safetensors`` is
    removed. The names of anchors with adapters on both sides are written as
    JSON to ``OUT/"new_anchors.json"``.
    """
    sides = ("X", "Y")

    def adapter_file(side, anchor):
        return OUT/side/anchor/"adapter_model.safetensors"

    # Phase: train new anchors only
    failed = []
    for t_name in NEW_ANCHOR_NAMES:
        try:
            train_lora_extra(MODEL_X, t_name, OUT/"X"/t_name)
            train_lora_extra(MODEL_Y, t_name, OUT/"Y"/t_name)
        except Exception as e:
            print(f"FAIL {t_name}: {e}")
            failed.append(t_name)
            # Discard directories that were left without a finished adapter.
            for side in sides:
                partial = OUT/side/t_name
                if not partial.exists() or adapter_file(side, t_name).exists():
                    continue
                shutil.rmtree(partial, ignore_errors=True)
    print("\n\nFAILED:", failed)

    # An anchor only counts when both sides have a saved adapter.
    successful_new = [
        anchor for anchor in NEW_ANCHOR_NAMES
        if all(adapter_file(side, anchor).exists() for side in sides)
    ]
    print(f"Successfully trained {len(successful_new)} new anchors")
    (OUT/"new_anchors.json").write_text(json.dumps(successful_new))
# Script entry point: run the new-anchor training loop only when executed directly.
if __name__ == "__main__":
    main()