""" PHASE B: scale anchors from 25 -> 60 by training 35 NEW LoRA pairs. Reuses existing 30 trained adapters in /app/scaled/{X,Y}/. Adds 35 new anchor tasks. Held-out test tasks remain the same. Then runs the FULL mapping comparison sweep on: N ∈ {25, 60} × methods ∈ {global_ridge, pertensor_ridge, pertensor_pca, pertensor_mlp, procrustes, topk5/8/12_global_ridge, topk5/8/12_pertensor_ridge, topk12_pertensor_mlp} × 5 held-out tasks """ import os, sys, json, gc, shutil, re, collections, time from pathlib import Path import numpy as np import torch import torch.nn.functional as F from datasets import load_dataset from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed from peft import LoraConfig, PeftModel from trl import SFTTrainer, SFTConfig from safetensors.torch import load_file, save_file sys.path.insert(0, "/app") import scaled_pipeline as sp import phaseA_new_methods as pa set_seed(42) OUT = sp.OUT MODEL_X = sp.MODEL_X MODEL_Y = sp.MODEL_Y HELDOUT_NAMES = sp.HELDOUT_NAMES EXISTING_ANCHORS = sp.ANCHOR_NAMES # 25 already-trained # 35 new anchors (verified) -- avoid duplicates with existing/heldout # Format: (name, hf_id, config, text_col_or_tuple, label_col, label_names) NEW_ANCHORS = [ ("yelp_polarity", "fancyzhx/yelp_polarity", None, "text", "label", ["negative","positive"]), ("yahoo_topics", "community-datasets/yahoo_answers_topics", None, "question_title", "topic", ["society","science","health","education","computers","sports","business","entertainment","family","politics"]), ("setfit_qnli", "SetFit/qnli", None, ("text1","text2"), "label", ["entailment","not entailment"]), ("setfit_mnli", "SetFit/mnli", None, ("text1","text2"), "label", ["entailment","neutral","contradiction"]), ("setfit_rte", "SetFit/rte", None, ("text1","text2"), "label", ["entailment","not entailment"]), ("setfit_mrpc", "SetFit/mrpc", None, ("text1","text2"), "label", ["different","paraphrase"]), ("setfit_qqp", "SetFit/qqp", None, ("text1","text2"), "label", ["different","duplicate"]), ("snli_main", "stanfordnlp/snli", None, ("premise","hypothesis"), "label", ["entailment","neutral","contradiction"]), ("paws", "google-research-datasets/paws", "labeled_final", ("sentence1","sentence2"), "label", ["different","paraphrase"]), ("mteb_emo", "mteb/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]), ("mteb_tweet_sent","mteb/tweet_sentiment_extraction", None, "text", "label", ["negative","neutral","positive"]), ("mteb_toxic_conv","mteb/toxic_conversations_50k", None, "text", "label", ["non-toxic","toxic"]), ("mteb_amazon_cf", "mteb/amazon_counterfactual", "en", "text", "label", ["not counterfactual","counterfactual"]), ("dair_emo_unsplit","dair-ai/emotion", "unsplit", "text", "label", ["sadness","joy","love","anger","fear","surprise"]), ("setfit_yelp_full","SetFit/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]), ("setfit_emotion", "SetFit/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]), ("setfit_sst5_alt","SetFit/sst5", None, "text", "label", ["very negative","negative","neutral","positive","very positive"]), ("setfit_student_qcat","SetFit/student-question-categories", None, "text", "label", ["q0","q1","q2","q3"]), ("setfit_movie_reviews","SetFit/sst2", None, "text", "label", ["negative","positive"]), ("sms_spam", "ucirvine/sms_spam", None, "sms", "label", ["ham","spam"]), ("snips", "snips_built_in_intents", None, "text", "label", ["compare-places","request-ride","get-weather","search-place","get-place-details","share-current-location","get-traffic-information","book-restaurant","get-directions","share-eta"]), ("toxic_chat", "lmsys/toxic-chat", "toxicchat0124", "user_input", "toxicity", ["non-toxic","toxic"]), ("hate_offensive_lang","tdavidson/hate_speech_offensive", None, "tweet", "class", ["hate","offensive","neither"]), ("amazon_massive_scenario","AmazonScience/massive", "en-US", "utt", "scenario", ["social","transport","calendar","play","news","datetime","recommendation","email","iot","general","audio","lists","qa","cooking","takeaway","music","alarm","weather"]), ("trec_fine", "CogComp/trec", None, "text", "fine_label", None), # too many classes -> set in build ("ag_setfit", "SetFit/ag_news", None, "text", "label", ["world","sports","business","sci/tech"]), ("cardiffnlp_topic","cardiffnlp/tweet_topic_single", None, "text", "label", ["arts_culture","business_entrepreneurs","celebrity_pop_culture","diaries_daily_life","family","fashion_style","film_tv_video","fitness_health","food_dining","gaming","learning_educational","music","news_social_concern","other_hobbies","relationships","science_technology","sports","travel_adventure","youth_student_life"]), # 8 more ("financial_pb_pos","takala/financial_phrasebank", "sentences_50agree", "sentence", "label", ["negative","neutral","positive"]), ("clinc_small_skip","clinc_oos", "small", "text", "intent", None), # placeholder, may skip if too big ("hate_speech18", "hate_speech18", None, "text", "label", ["no_hate","hate","idk","relation"]), ("rotten_alt", "cornell-movie-review-data/rotten_tomatoes", None, "text", "label", ["negative","positive"]), ("yelp_polarity_test","Yelp/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]), ("dynasent_r1", "dynabench/dynasent", "dynabench.dynasent.r1.all", "sentence", "gold_label", ["negative","neutral","positive"]), ("dynasent_r2", "dynabench/dynasent", "dynabench.dynasent.r2.all", "sentence", "gold_label", ["negative","neutral","positive"]), ] NEW_ANCHOR_NAMES = [t[0] for t in NEW_ANCHORS] # Patch sp.build_task to support these new specs def build_task_extra(name, n_train=800, n_eval=300): # First try original if name in sp.TASKS_BY_NAME: return sp.build_task(name, n_train, n_eval) # Else find in NEW_ANCHORS spec = next((t for t in NEW_ANCHORS if t[0] == name), None) if spec is None: raise KeyError(name) name, hf, cfg, txt, lab, labels = spec if cfg: ds = load_dataset(hf, cfg) else: ds = load_dataset(hf) train_split = "train" eval_split = "validation" if "validation" in ds else "test" if "test" in ds else "train" label_join = ", ".join(f"'{l}'" for l in labels) instr = f"Classify the following text. Choose one of: {label_join}. Respond with just the label.\n\nText: " def fmt(p, t): return [{"role":"user","content":p},{"role":"assistant","content":t}] def build_text(r): if isinstance(txt, tuple): return f"{r[txt[0]]} [SEP] {r[txt[1]]}" return str(r[txt]) def to_msg(r): text = build_text(r).strip().replace("\n"," ")[:600] try: lid = int(r[lab]) except: lid = 0 if lid < 0 or lid >= len(labels): return None return {"messages": fmt(instr + text + "\n\nLabel:", labels[lid])} def map_filter(splitname, n): s = ds[splitname] s = s.shuffle(seed=0).select(range(min(n*3, len(s)))) # over-select to allow filtering -1 s = s.map(to_msg, remove_columns=s.column_names) s = s.filter(lambda r: r is not None and r["messages"] is not None) return s.select(range(min(n, len(s)))) train = map_filter(train_split, n_train) if eval_split == "train": ev = ds[train_split].shuffle(seed=1).select(range(n_train, min(n_train+n_eval*3, len(ds[train_split])))) ev = ev.map(to_msg, remove_columns=ev.column_names).filter(lambda r: r is not None and r["messages"] is not None) ev = ev.select(range(min(n_eval, len(ev)))) else: s = ds[eval_split].shuffle(seed=0).select(range(min(n_eval*3, len(ds[eval_split])))) ev = s.map(to_msg, remove_columns=s.column_names).filter(lambda r: r is not None and r["messages"] is not None) ev = ev.select(range(min(n_eval, len(ev)))) return train, ev, labels # Monkey-patch for sp.train_lora to use extra builder def train_lora_extra(model_name, task, save_dir): if save_dir.exists() and (save_dir/"adapter_model.safetensors").exists(): return save_dir.mkdir(parents=True, exist_ok=True) print(f"[TRAIN] {model_name.split('/')[-1]} / {task}", flush=True) tok = AutoTokenizer.from_pretrained(model_name) if tok.pad_token is None: tok.pad_token = tok.eos_token model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, attn_implementation="eager") model.config.use_cache = False train_ds, _, _ = build_task_extra(task, n_train=sp.TRAIN_PER_TASK, n_eval=sp.EVAL_PER_TASK) if len(train_ds) < 50: print(f" skipping {task}: only {len(train_ds)} train examples") save_dir.rmdir(); return lora = LoraConfig(r=sp.LORA_R, lora_alpha=sp.LORA_ALPHA, target_modules=sp.LORA_TARGETS, lora_dropout=0.0, bias="none", task_type="CAUSAL_LM") cfg = SFTConfig( output_dir=str(save_dir/"_t"), num_train_epochs=sp.EPOCHS, per_device_train_batch_size=sp.BS, learning_rate=sp.LR, lr_scheduler_type="cosine", warmup_ratio=0.05, bf16=True, max_seq_length=sp.MAX_LEN, logging_steps=50, logging_first_step=True, disable_tqdm=True, save_strategy="no", report_to="none", seed=42, packing=False, ) trainer = SFTTrainer(model=model, args=cfg, train_dataset=train_ds, peft_config=lora, tokenizer=tok) trainer.train() trainer.model.save_pretrained(str(save_dir)) tok.save_pretrained(str(save_dir)) shutil.rmtree(save_dir/"_t", ignore_errors=True) del trainer, model; gc.collect(); torch.cuda.empty_cache() def main(): # Phase: train new anchors only failed = [] for t_name in NEW_ANCHOR_NAMES: try: train_lora_extra(MODEL_X, t_name, OUT/"X"/t_name) train_lora_extra(MODEL_Y, t_name, OUT/"Y"/t_name) except Exception as e: print(f"FAIL {t_name}: {e}") failed.append(t_name) # remove partial dir for side in ("X","Y"): d = OUT/side/t_name if d.exists() and not (d/"adapter_model.safetensors").exists(): shutil.rmtree(d, ignore_errors=True) print("\n\nFAILED:", failed) # Save list of successful new anchors successful_new = [n for n in NEW_ANCHOR_NAMES if (OUT/"X"/n/"adapter_model.safetensors").exists() and (OUT/"Y"/n/"adapter_model.safetensors").exists()] print(f"Successfully trained {len(successful_new)} new anchors") (OUT/"new_anchors.json").write_text(json.dumps(successful_new)) if __name__ == "__main__": main()