# cross-model-lora-prediction / phaseB_train.py
# Uploaded to the Hugging Face Hub by Samarth0710 with huggingface_hub (commit 290787e, verified).
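"""
PHASE B: scale the anchor set from 25 toward ~60 by training LoRA pairs (one adapter
for MODEL_X and one for MODEL_Y) on the new anchor tasks listed in NEW_ANCHORS.
Reuses the 30 adapters already trained in /app/scaled/{X,Y}/. NEW_ANCHORS currently
lists 34 task specs (not 35); tasks that fail or have too little usable data are
skipped, and the tasks that were trained successfully are written to new_anchors.json.
Held-out test tasks remain the same.
The follow-up step is the FULL mapping comparison sweep on:
N ∈ {25, 60} × methods ∈ {global_ridge, pertensor_ridge, pertensor_pca, pertensor_mlp,
procrustes, topk5/8/12_global_ridge, topk5/8/12_pertensor_ridge, topk12_pertensor_mlp}
× 5 held-out tasks
(main() in this file performs only the adapter-training step.)
"""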
import os, sys, json, gc, shutil, re, collections, time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
from safetensors.torch import load_file, save_file
sys.path.insert(0, "/app")
import scaled_pipeline as sp
import phaseA_new_methods as pa
set_seed(42)
# Shared configuration re-exported from scaled_pipeline so this script trains into
# the same output tree with the same model pair.
OUT = sp.OUT
MODEL_X, MODEL_Y = sp.MODEL_X, sp.MODEL_Y
HELDOUT_NAMES = sp.HELDOUT_NAMES
# Anchors trained in the earlier phase (presumably the 25 already-trained ones).
EXISTING_ANCHORS = sp.ANCHOR_NAMES
# 34 new anchor task specs (verified) -- meant to avoid duplicates with the existing
# anchors and the held-out tasks.
# Format: (name, hf_id, config, text_col_or_tuple, label_col, label_names)
#   - config: dataset config name, or None to load the dataset without one.
#   - text_col_or_tuple: one text column, or a (col_a, col_b) pair for sentence-pair tasks.
#   - label_names: ordered class names indexed by the label value; None means the names
#     are not listed here and must be supplied by the task builder.
# NOTE(review): mteb_emo, dair_emo_unsplit and setfit_emotion share the same six emotion
# labels and are likely copies of one corpus; confirm they are meant as separate anchors.
NEW_ANCHORS = [
    ("yelp_polarity", "fancyzhx/yelp_polarity", None, "text", "label", ["negative","positive"]),
    ("yahoo_topics", "community-datasets/yahoo_answers_topics", None, "question_title", "topic",
     ["society","science","health","education","computers","sports","business","entertainment","family","politics"]),
    ("setfit_qnli", "SetFit/qnli", None, ("text1","text2"), "label", ["entailment","not entailment"]),
    ("setfit_mnli", "SetFit/mnli", None, ("text1","text2"), "label", ["entailment","neutral","contradiction"]),
    ("setfit_rte", "SetFit/rte", None, ("text1","text2"), "label", ["entailment","not entailment"]),
    ("setfit_mrpc", "SetFit/mrpc", None, ("text1","text2"), "label", ["different","paraphrase"]),
    ("setfit_qqp", "SetFit/qqp", None, ("text1","text2"), "label", ["different","duplicate"]),
    ("snli_main", "stanfordnlp/snli", None, ("premise","hypothesis"), "label", ["entailment","neutral","contradiction"]),
    ("paws", "google-research-datasets/paws", "labeled_final", ("sentence1","sentence2"), "label", ["different","paraphrase"]),
    ("mteb_emo", "mteb/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("mteb_tweet_sent","mteb/tweet_sentiment_extraction", None, "text", "label", ["negative","neutral","positive"]),
    ("mteb_toxic_conv","mteb/toxic_conversations_50k", None, "text", "label", ["non-toxic","toxic"]),
    ("mteb_amazon_cf", "mteb/amazon_counterfactual", "en", "text", "label", ["not counterfactual","counterfactual"]),
    ("dair_emo_unsplit","dair-ai/emotion", "unsplit", "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("setfit_yelp_full","SetFit/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
    ("setfit_emotion", "SetFit/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("setfit_sst5_alt","SetFit/sst5", None, "text", "label", ["very negative","negative","neutral","positive","very positive"]),
    ("setfit_student_qcat","SetFit/student-question-categories", None, "text", "label", ["q0","q1","q2","q3"]),
    ("setfit_movie_reviews","SetFit/sst2", None, "text", "label", ["negative","positive"]),
    ("sms_spam", "ucirvine/sms_spam", None, "sms", "label", ["ham","spam"]),
    ("snips", "snips_built_in_intents", None, "text", "label",
     ["compare-places","request-ride","get-weather","search-place","get-place-details","share-current-location","get-traffic-information","book-restaurant","get-directions","share-eta"]),
    ("toxic_chat", "lmsys/toxic-chat", "toxicchat0124", "user_input", "toxicity", ["non-toxic","toxic"]),
    ("hate_offensive_lang","tdavidson/hate_speech_offensive", None, "tweet", "class", ["hate","offensive","neither"]),
    ("amazon_massive_scenario","AmazonScience/massive", "en-US", "utt", "scenario",
     ["social","transport","calendar","play","news","datetime","recommendation","email","iot","general","audio","lists","qa","cooking","takeaway","music","alarm","weather"]),
    ("trec_fine", "CogComp/trec", None, "text", "fine_label", None), # too many classes to list inline
    ("ag_setfit", "SetFit/ag_news", None, "text", "label", ["world","sports","business","sci/tech"]),
    ("cardiffnlp_topic","cardiffnlp/tweet_topic_single", None, "text", "label",
     ["arts_culture","business_entrepreneurs","celebrity_pop_culture","diaries_daily_life","family","fashion_style","film_tv_video","fitness_health","food_dining","gaming","learning_educational","music","news_social_concern","other_hobbies","relationships","science_technology","sports","travel_adventure","youth_student_life"]),
    # 7 more
    ("financial_pb_pos","takala/financial_phrasebank", "sentences_50agree", "sentence", "label", ["negative","neutral","positive"]),
    ("clinc_small_skip","clinc_oos", "small", "text", "intent", None), # placeholder, may skip if too big
    ("hate_speech18", "hate_speech18", None, "text", "label", ["no_hate","hate","idk","relation"]),
    ("rotten_alt", "cornell-movie-review-data/rotten_tomatoes", None, "text", "label", ["negative","positive"]),
    # NOTE(review): despite its name this loads the 5-class Yelp/yelp_review_full (same label
    # set as setfit_yelp_full above), not yelp_polarity -- likely overlapping data.
    ("yelp_polarity_test","Yelp/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
    # NOTE(review): confirm how gold_label is stored in dynasent (int index vs. label string);
    # the builder must map it onto the names below.
    ("dynasent_r1", "dynabench/dynasent", "dynabench.dynasent.r1.all", "sentence", "gold_label", ["negative","neutral","positive"]),
    ("dynasent_r2", "dynabench/dynasent", "dynabench.dynasent.r2.all", "sentence", "gold_label", ["negative","neutral","positive"]),
]
# Task names in NEW_ANCHORS order (also used as the per-task adapter directory names).
NEW_ANCHOR_NAMES = [t[0] for t in NEW_ANCHORS]
# Extended task builder: defers to sp.build_task for tasks scaled_pipeline already
# knows and builds the NEW_ANCHORS tasks defined in this file.
def _label_index(value, labels):
    """Return the index of a raw label *value* within *labels*, or None if unusable.

    String values are first matched against the label names (case-insensitive,
    surrounding whitespace ignored) and otherwise parsed as integers; other values
    are converted with ``int()``. Values that cannot be converted, and indices outside
    ``[0, len(labels))`` (e.g. the ``-1`` used for unlabeled rows), yield None.
    """
    if value is None:
        return None
    if isinstance(value, str):
        key = value.strip().casefold()
        for i, label_name in enumerate(labels):
            if label_name.casefold() == key:
                return i
        try:
            idx = int(key)
        except ValueError:
            return None
    else:
        try:
            idx = int(value)
        except (TypeError, ValueError, OverflowError):
            return None
    return idx if 0 <= idx < len(labels) else None


def _class_label_names(split, column):
    """Return the ClassLabel names of *column* in dataset *split*, or None if it has none."""
    names = getattr(split.features.get(column), "names", None)
    return list(names) if names else None


def build_task_extra(name, n_train=800, n_eval=300):
    """Build chat-formatted (train, eval, labels) data for an anchor task.

    Tasks registered in ``sp.TASKS_BY_NAME`` are delegated to ``sp.build_task``.
    Otherwise the spec is looked up in NEW_ANCHORS, the dataset is loaded with
    ``load_dataset`` and each usable row becomes one user/assistant exchange whose
    answer is the label name. Rows whose label cannot be mapped onto the label list
    are dropped.

    Args:
        name: task name (a key of sp.TASKS_BY_NAME or the first field of a NEW_ANCHORS spec).
        n_train: maximum number of training examples.
        n_eval: maximum number of evaluation examples.

    Returns:
        (train, eval, labels): for NEW_ANCHORS tasks, two datasets with a single
        "messages" column, plus the list of label names used in the prompt.

    Raises:
        KeyError: if *name* is not a known task.
        ValueError: if the spec lists no label names and the label column is not a
            ClassLabel feature from which they can be read.
    """
    if name in sp.TASKS_BY_NAME:
        return sp.build_task(name, n_train, n_eval)
    spec = next((t for t in NEW_ANCHORS if t[0] == name), None)
    if spec is None:
        raise KeyError(name)
    _, hf_id, cfg, txt, lab, labels = spec
    ds = load_dataset(hf_id, cfg) if cfg else load_dataset(hf_id)

    train_split = "train"
    if "validation" in ds:
        eval_split = "validation"
    elif "test" in ds:
        eval_split = "test"
    else:
        eval_split = train_split

    if labels is None:
        # Specs with too many classes to list inline take their names from the dataset.
        labels = _class_label_names(ds[train_split], lab)
        if labels is None:
            raise ValueError(f"{name}: no label names given and column {lab!r} is not a ClassLabel")

    label_join = ", ".join(f"'{l}'" for l in labels)
    instr = f"Classify the following text. Choose one of: {label_join}. Respond with just the label.\n\nText: "

    def fmt(p, t):
        return [{"role": "user", "content": p}, {"role": "assistant", "content": t}]

    def build_text(r):
        if isinstance(txt, tuple):
            return f"{r[txt[0]]} [SEP] {r[txt[1]]}"
        return str(r[txt])

    def has_label(r):
        return _label_index(r[lab], labels) is not None

    def to_msg(r):
        text = build_text(r).strip().replace("\n", " ")[:600]
        lid = _label_index(r[lab], labels)
        return {"messages": fmt(instr + text + "\n\nLabel:", labels[lid])}

    def prepare(rows, n):
        # Filter BEFORE mapping: Dataset.map expects the function to return a dict for
        # every row, so unusable rows must never reach to_msg.
        rows = rows.filter(has_label)
        rows = rows.map(to_msg, remove_columns=rows.column_names)
        return rows.select(range(min(n, len(rows))))

    # Over-select 3x so enough rows survive the label filter.
    train_rows = ds[train_split].shuffle(seed=0)
    train_end = min(n_train * 3, len(train_rows))
    train = prepare(train_rows.select(range(train_end)), n_train)

    if eval_split == train_split:
        # No held-out split: take eval rows from the SAME shuffle, after the slice
        # reserved for training, so train and eval never overlap.
        eval_end = min(train_end + n_eval * 3, len(train_rows))
        eval_pool = train_rows.select(range(train_end, eval_end))
    else:
        eval_rows = ds[eval_split].shuffle(seed=0)
        eval_pool = eval_rows.select(range(min(n_eval * 3, len(eval_rows))))
    ev = prepare(eval_pool, n_eval)
    return train, ev, labels
# Train (or reuse) one LoRA adapter per (model, task), using build_task_extra so the
# NEW_ANCHORS tasks are supported alongside those known to scaled_pipeline.
def train_lora_extra(model_name, task, save_dir):
    """Train a LoRA adapter on *task* for *model_name* and save it under *save_dir*.

    Returns immediately if ``save_dir/adapter_model.safetensors`` already exists.
    If the task yields fewer than 50 training examples, *save_dir* is removed and
    nothing is trained. Exceptions from data building or training propagate to the
    caller; the model/trainer references are dropped and the CUDA cache emptied
    either way.

    Args:
        model_name: Hugging Face model id; its last path component is used in logs.
        task: task name accepted by build_task_extra.
        save_dir: pathlib.Path of the adapter output directory.
    """
    if (save_dir / "adapter_model.safetensors").exists():
        return
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"[TRAIN] {model_name.split('/')[-1]} / {task}", flush=True)

    # Build the data before loading the model so a skipped task never pays for a model load.
    train_ds, _, _ = build_task_extra(task, n_train=sp.TRAIN_PER_TASK, n_eval=sp.EVAL_PER_TASK)
    if len(train_ds) < 50:
        print(f" skipping {task}: only {len(train_ds)} train examples", flush=True)
        # rmtree rather than rmdir: a crashed earlier run may have left "_t" behind.
        shutil.rmtree(save_dir, ignore_errors=True)
        return

    tok = AutoTokenizer.from_pretrained(model_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    lora = LoraConfig(r=sp.LORA_R, lora_alpha=sp.LORA_ALPHA, target_modules=sp.LORA_TARGETS,
                      lora_dropout=0.0, bias="none", task_type="CAUSAL_LM")
    cfg = SFTConfig(
        output_dir=str(save_dir/"_t"), num_train_epochs=sp.EPOCHS,
        per_device_train_batch_size=sp.BS, learning_rate=sp.LR, lr_scheduler_type="cosine",
        warmup_ratio=0.05, bf16=True, max_seq_length=sp.MAX_LEN,
        logging_steps=50, logging_first_step=True, disable_tqdm=True,
        save_strategy="no", report_to="none", seed=42, packing=False,
    )

    model = trainer = None
    try:
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, attn_implementation="eager")
        model.config.use_cache = False
        trainer = SFTTrainer(model=model, args=cfg, train_dataset=train_ds, peft_config=lora, tokenizer=tok)
        trainer.train()
        trainer.model.save_pretrained(str(save_dir))
        tok.save_pretrained(str(save_dir))
    finally:
        # Always release the scratch dir and GPU memory, even when training raises,
        # so the caller can continue with the next task.
        shutil.rmtree(save_dir / "_t", ignore_errors=True)
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()
def main():
    """Train X/Y LoRA adapters for every NEW_ANCHORS task and record the successes.

    Each task is trained for MODEL_X, then MODEL_Y. If either raises, the error and
    its traceback are printed, the task is recorded as failed, and any adapter
    directory lacking adapter_model.safetensors is deleted; the loop then continues.
    Finally writes OUT/new_anchors.json listing the tasks that have adapters for
    both models.
    """
    import traceback

    failed = []
    for t_name in NEW_ANCHOR_NAMES:
        try:
            train_lora_extra(MODEL_X, t_name, OUT / "X" / t_name)
            train_lora_extra(MODEL_Y, t_name, OUT / "Y" / t_name)
        except Exception as e:
            # Best-effort sweep: one bad dataset must not abort the run, but keep the
            # traceback so the cause can be diagnosed afterwards.
            print(f"FAIL {t_name}: {e}", flush=True)
            traceback.print_exc()
            failed.append(t_name)
            # Remove partial output so a rerun retrains this task from scratch.
            for side in ("X", "Y"):
                d = OUT / side / t_name
                if d.exists() and not (d / "adapter_model.safetensors").exists():
                    shutil.rmtree(d, ignore_errors=True)
    print("\n\nFAILED:", failed)

    # A new anchor counts only if adapters exist for BOTH models.
    successful_new = [n for n in NEW_ANCHOR_NAMES
                      if (OUT / "X" / n / "adapter_model.safetensors").exists()
                      and (OUT / "Y" / n / "adapter_model.safetensors").exists()]
    print(f"Successfully trained {len(successful_new)} new anchors")
    (OUT / "new_anchors.json").write_text(json.dumps(successful_new))


if __name__ == "__main__":
    main()