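"""
PHASE B: scale the anchor set from 25 toward 60 by training new LoRA pairs
(one adapter on MODEL_X and one on MODEL_Y for every task in NEW_ANCHORS).

Reuses the 30 adapters already trained in /app/scaled/{X,Y}/ (presumably the
25 existing anchors plus the 5 held-out tasks — confirm against sp). Adds the
NEW_ANCHORS tasks: 34 are listed (35 were planned). Any task whose dataset
fails to load, or that yields fewer than 50 training examples, is skipped, so
the final anchor count can fall short of 60. The held-out test tasks remain the
same. The names of new anchors trained on both sides are written to
OUT/new_anchors.json.

The follow-up step is the FULL mapping comparison sweep on:
    N ∈ {25, 60} × methods ∈ {global_ridge, pertensor_ridge, pertensor_pca, pertensor_mlp,
    procrustes, topk5/8/12_global_ridge, topk5/8/12_pertensor_ridge, topk12_pertensor_mlp}
    × 5 held-out tasks
NOTE: main() in this script only trains adapters; it does not run that sweep.
"""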
| import os, sys, json, gc, shutil, re, collections, time |
| from pathlib import Path |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from datasets import load_dataset |
| from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed |
| from peft import LoraConfig, PeftModel |
| from trl import SFTTrainer, SFTConfig |
| from safetensors.torch import load_file, save_file |
|
|
# The project modules (scaled_pipeline, phaseA_new_methods) live in /app;
# put it first on the import path before importing them.
sys.path.insert(0, "/app")
import scaled_pipeline as sp  # noqa: E402  (depends on the sys.path tweak above)
import phaseA_new_methods as pa  # noqa: E402,F401  (not referenced in the code shown; kept for any import-time effects)

# transformers.set_seed seeds Python's `random`, NumPy and torch for reproducibility.
set_seed(42)

# Configuration shared with scaled_pipeline, re-exported under local names.
OUT, MODEL_X, MODEL_Y = sp.OUT, sp.MODEL_X, sp.MODEL_Y
HELDOUT_NAMES = sp.HELDOUT_NAMES
EXISTING_ANCHORS = sp.ANCHOR_NAMES
|
| |
| |
# New anchor tasks for phase B. Each entry is
#   (task_name, hf_dataset_id, hf_config_or_None, text_field, label_field, label_names)
# - task_name is also the adapter directory name under OUT/{X,Y}/.
# - text_field is a single column name, or a (first, second) pair of columns
#   that is joined as "<first> [SEP] <second>".
# - label_names[i] is the label string emitted for integer label id i;
#   None means no hand-written label list is given.
NEW_ANCHORS = [
    ("yelp_polarity", "fancyzhx/yelp_polarity", None, "text", "label", ["negative","positive"]),
    # Only the question title is used as input text.
    ("yahoo_topics", "community-datasets/yahoo_answers_topics", None, "question_title", "topic",
     ["society","science","health","education","computers","sports","business","entertainment","family","politics"]),
    ("setfit_qnli", "SetFit/qnli", None, ("text1","text2"), "label", ["entailment","not entailment"]),
    ("setfit_mnli", "SetFit/mnli", None, ("text1","text2"), "label", ["entailment","neutral","contradiction"]),
    ("setfit_rte", "SetFit/rte", None, ("text1","text2"), "label", ["entailment","not entailment"]),
    ("setfit_mrpc", "SetFit/mrpc", None, ("text1","text2"), "label", ["different","paraphrase"]),
    ("setfit_qqp", "SetFit/qqp", None, ("text1","text2"), "label", ["different","duplicate"]),
    ("snli_main", "stanfordnlp/snli", None, ("premise","hypothesis"), "label", ["entailment","neutral","contradiction"]),
    ("paws", "google-research-datasets/paws", "labeled_final", ("sentence1","sentence2"), "label", ["different","paraphrase"]),
    ("mteb_emo", "mteb/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("mteb_tweet_sent","mteb/tweet_sentiment_extraction", None, "text", "label", ["negative","neutral","positive"]),
    ("mteb_toxic_conv","mteb/toxic_conversations_50k", None, "text", "label", ["non-toxic","toxic"]),
    ("mteb_amazon_cf", "mteb/amazon_counterfactual", "en", "text", "label", ["not counterfactual","counterfactual"]),
    ("dair_emo_unsplit","dair-ai/emotion", "unsplit", "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("setfit_yelp_full","SetFit/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
    ("setfit_emotion", "SetFit/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("setfit_sst5_alt","SetFit/sst5", None, "text", "label", ["very negative","negative","neutral","positive","very positive"]),
    # Label names here are placeholders (q0..q3), not the dataset's category names.
    ("setfit_student_qcat","SetFit/student-question-categories", None, "text", "label", ["q0","q1","q2","q3"]),
    # NOTE(review): despite the name, this loads SetFit/sst2.
    ("setfit_movie_reviews","SetFit/sst2", None, "text", "label", ["negative","positive"]),
    ("sms_spam", "ucirvine/sms_spam", None, "sms", "label", ["ham","spam"]),
    ("snips", "snips_built_in_intents", None, "text", "label",
     ["compare-places","request-ride","get-weather","search-place","get-place-details","share-current-location","get-traffic-information","book-restaurant","get-directions","share-eta"]),
    ("toxic_chat", "lmsys/toxic-chat", "toxicchat0124", "user_input", "toxicity", ["non-toxic","toxic"]),
    ("hate_offensive_lang","tdavidson/hate_speech_offensive", None, "tweet", "class", ["hate","offensive","neither"]),
    ("amazon_massive_scenario","AmazonScience/massive", "en-US", "utt", "scenario",
     ["social","transport","calendar","play","news","datetime","recommendation","email","iot","general","audio","lists","qa","cooking","takeaway","music","alarm","weather"]),
    # labels=None: no hand-written label list for this task.
    ("trec_fine", "CogComp/trec", None, "text", "fine_label", None),
    ("ag_setfit", "SetFit/ag_news", None, "text", "label", ["world","sports","business","sci/tech"]),
    # NOTE(review): this 19-name list looks like the tweet_topic_multi label set;
    # tweet_topic_single is believed to be a 6-class dataset whose splits are named
    # like "train_2020"/"train_all" (no plain "train") — verify before relying on it.
    ("cardiffnlp_topic","cardiffnlp/tweet_topic_single", None, "text", "label",
     ["arts_culture","business_entrepreneurs","celebrity_pop_culture","diaries_daily_life","family","fashion_style","film_tv_video","fitness_health","food_dining","gaming","learning_educational","music","news_social_concern","other_hobbies","relationships","science_technology","sports","travel_adventure","youth_student_life"]),
    # NOTE(review): the list holds 34 entries; confirm whether an entry was dropped here.

    ("financial_pb_pos","takala/financial_phrasebank", "sentences_50agree", "sentence", "label", ["negative","neutral","positive"]),
    # labels=None: no hand-written label list (the "_skip" suffix suggests this may be
    # meant to be skipped — confirm).
    ("clinc_small_skip","clinc_oos", "small", "text", "intent", None),
    ("hate_speech18", "hate_speech18", None, "text", "label", ["no_hate","hate","idk","relation"]),
    ("rotten_alt", "cornell-movie-review-data/rotten_tomatoes", None, "text", "label", ["negative","positive"]),
    # NOTE(review): despite the name, this is Yelp/yelp_review_full (5 star classes),
    # apparently the same task as setfit_yelp_full.
    ("yelp_polarity_test","Yelp/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
    # NOTE(review): dynasent's gold_label is presumably a class-name string rather than
    # an integer id — check how labels are converted for these two tasks.
    ("dynasent_r1", "dynabench/dynasent", "dynabench.dynasent.r1.all", "sentence", "gold_label", ["negative","neutral","positive"]),
    ("dynasent_r2", "dynabench/dynasent", "dynabench.dynasent.r2.all", "sentence", "gold_label", ["negative","neutral","positive"]),
]


# Task names in NEW_ANCHORS order.
NEW_ANCHOR_NAMES = [t[0] for t in NEW_ANCHORS]
|
|
| |
def build_task_extra(name, n_train=800, n_eval=300):
    """Return ``(train, eval, labels)`` for an anchor task.

    Names known to ``scaled_pipeline`` (``sp.TASKS_BY_NAME``) are delegated to
    ``sp.build_task``. Otherwise the spec is looked up in ``NEW_ANCHORS``, the
    Hugging Face dataset is loaded, and rows become single-turn chat examples
    ``{"messages": [user prompt, assistant label]}``.

    Sampling for NEW_ANCHORS tasks: each split is shuffled with seed 0, and a pool
    of up to ``3 * n`` rows is taken. Rows whose label cannot be mapped onto
    ``labels`` are dropped, and the first ``n`` remaining rows are kept. When the
    dataset has neither a ``validation`` nor a ``test`` split, eval rows come
    from the part of the shuffled train split *after* the training pool. This
    keeps train and eval disjoint.

    Args:
        name: task name (a ``NEW_ANCHORS`` entry name or an ``sp`` task name).
        n_train: maximum number of training examples.
        n_eval: maximum number of evaluation examples.

    Returns:
        ``(train, eval, labels)``. For NEW_ANCHORS tasks, ``train`` and ``eval``
        are ``datasets.Dataset`` objects with a single ``messages`` column, and
        ``labels`` is the list of label strings.

    Raises:
        KeyError: unknown task name, or the dataset has no ``train`` split.
        ValueError: the spec gives no label list and the label column has no
            class names to fall back on.
    """
    if name in sp.TASKS_BY_NAME:
        return sp.build_task(name, n_train, n_eval)

    spec = next((t for t in NEW_ANCHORS if t[0] == name), None)
    if spec is None:
        raise KeyError(name)
    name, hf, cfg, txt, lab, labels = spec
    ds = load_dataset(hf, cfg) if cfg else load_dataset(hf)

    train_split = "train"
    if train_split not in ds:
        raise KeyError(f"{name}: {hf} has no {train_split!r} split (available: {list(ds)})")
    if "validation" in ds:
        eval_split = "validation"
    elif "test" in ds:
        eval_split = "test"
    else:
        eval_split = train_split

    if labels is None:
        # No hand-written label list: fall back to the ClassLabel names of the column.
        class_names = getattr(ds[train_split].features.get(lab), "names", None)
        if not class_names:
            raise ValueError(f"{name}: no label names given and column {lab!r} is not a ClassLabel")
        labels = list(class_names)

    label_join = ", ".join(f"'{lbl}'" for lbl in labels)
    instr = f"Classify the following text. Choose one of: {label_join}. Respond with just the label.\n\nText: "

    def label_index(raw):
        """Map a raw label value to an index into ``labels``; None if unusable."""
        try:
            idx = int(raw)
        except (TypeError, ValueError):
            # Non-numeric label (e.g. a class-name string): match it by name rather
            # than silently treating it as labels[0].
            return labels.index(raw) if raw in labels else None
        # Out-of-range ids are dropped.
        return idx if 0 <= idx < len(labels) else None

    def build_text(r):
        if isinstance(txt, tuple):
            return f"{r[txt[0]]} [SEP] {r[txt[1]]}"
        return str(r[txt])

    def to_msg(r):
        text = build_text(r).strip().replace("\n", " ")[:600]
        return {"messages": [
            {"role": "user", "content": instr + text + "\n\nLabel:"},
            {"role": "assistant", "content": labels[label_index(r[lab])]},
        ]}

    def take(split_ds, start, stop, n):
        """Format rows ``[start, stop)`` of ``split_ds``; keep the first ``n`` valid ones."""
        pool = split_ds.select(range(start, stop))
        # Drop unusable labels *before* mapping: returning None from a Dataset.map
        # function is not a supported way to remove rows.
        pool = pool.filter(lambda r: label_index(r[lab]) is not None)
        pool = pool.map(to_msg, remove_columns=pool.column_names)
        return pool.select(range(min(n, len(pool))))

    shuffled_train = ds[train_split].shuffle(seed=0)
    train_pool = min(n_train * 3, len(shuffled_train))
    train = take(shuffled_train, 0, train_pool, n_train)
    if eval_split == train_split:
        # Same permutation, rows after the training pool -> no train/eval overlap
        # (a differently-seeded reshuffle would mix training rows into eval).
        eval_stop = min(train_pool + n_eval * 3, len(shuffled_train))
        ev = take(shuffled_train, train_pool, eval_stop, n_eval)
    else:
        shuffled_eval = ds[eval_split].shuffle(seed=0)
        ev = take(shuffled_eval, 0, min(n_eval * 3, len(shuffled_eval)), n_eval)
    return train, ev, labels
|
|
| |
def train_lora_extra(model_name, task, save_dir):
    """Train and save one LoRA adapter for ``task`` on ``model_name``.

    Does nothing if ``save_dir / "adapter_model.safetensors"`` already exists,
    so reruns skip finished adapters. Tasks with fewer than 50 training examples
    are skipped, and their directory is removed. The LoRA settings and training
    hyper-parameters come from ``scaled_pipeline`` (``sp``).

    Args:
        model_name: Hugging Face model id or path of the base model.
        task: task name understood by ``build_task_extra``.
        save_dir: ``pathlib.Path`` where the adapter and tokenizer are written.
    """
    if (save_dir / "adapter_model.safetensors").exists():
        return
    # Build the data before loading the model, so a too-small task costs no model load.
    train_ds, _, _ = build_task_extra(task, n_train=sp.TRAIN_PER_TASK, n_eval=sp.EVAL_PER_TASK)
    if len(train_ds) < 50:
        print(f" skipping {task}: only {len(train_ds)} train examples")
        # rmtree (not rmdir) also removes a stale, non-empty directory from a crashed run.
        shutil.rmtree(save_dir, ignore_errors=True)
        return
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"[TRAIN] {model_name.split('/')[-1]} / {task}", flush=True)
    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, attn_implementation="eager")
    model.config.use_cache = False  # no KV cache during training
    trainer = None
    try:
        lora = LoraConfig(r=sp.LORA_R, lora_alpha=sp.LORA_ALPHA, target_modules=sp.LORA_TARGETS,
                          lora_dropout=0.0, bias="none", task_type="CAUSAL_LM")
        cfg = SFTConfig(
            output_dir=str(save_dir / "_t"), num_train_epochs=sp.EPOCHS,
            per_device_train_batch_size=sp.BS, learning_rate=sp.LR, lr_scheduler_type="cosine",
            warmup_ratio=0.05, bf16=True, max_seq_length=sp.MAX_LEN,
            logging_steps=50, logging_first_step=True, disable_tqdm=True,
            save_strategy="no", report_to="none", seed=42, packing=False,
        )
        trainer = SFTTrainer(model=model, args=cfg, train_dataset=train_ds, peft_config=lora, tokenizer=tok)
        trainer.train()
        trainer.model.save_pretrained(str(save_dir))
        tok.save_pretrained(str(save_dir))
    finally:
        # Release model/optimizer memory even when training raises: the caller keeps
        # looping over tasks after an exception, so leaked models would pile up.
        shutil.rmtree(save_dir / "_t", ignore_errors=True)
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()
|
|
def main():
    """Train X- and Y-side LoRA adapters for every new anchor task.

    A failure on one task is reported and recorded, and the run moves on. Any
    directory left without an adapter file is deleted. Finally, the names of
    tasks that have adapters on *both* sides are written to
    ``OUT/new_anchors.json``.
    """
    sides = ("X", "Y")

    def adapter_file(side, task):
        return OUT / side / task / "adapter_model.safetensors"

    failures = []
    for task in NEW_ANCHOR_NAMES:
        try:
            # X first, then Y; an exception on X means Y is not attempted.
            for model_name, side in ((MODEL_X, "X"), (MODEL_Y, "Y")):
                train_lora_extra(model_name, task, OUT / side / task)
        except Exception as err:
            print(f"FAIL {task}: {err}")
            failures.append(task)

        # Remove directories that were created but never received an adapter.
        for side in sides:
            task_dir = OUT / side / task
            if task_dir.exists() and not adapter_file(side, task).exists():
                shutil.rmtree(task_dir, ignore_errors=True)
    print("\n\nFAILED:", failures)

    completed = [task for task in NEW_ANCHOR_NAMES
                 if all(adapter_file(side, task).exists() for side in sides)]
    print(f"Successfully trained {len(completed)} new anchors")
    (OUT / "new_anchors.json").write_text(json.dumps(completed))


if __name__ == "__main__":
    main()
|
|