File size: 10,898 Bytes
290787e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
PHASE B: scale anchors from 25 -> 60 by training 35 NEW LoRA pairs.

Reuses existing 30 trained adapters in /app/scaled/{X,Y}/.
Adds 35 new anchor tasks. Held-out test tasks remain the same.

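Intended follow-up (NOTE: main() in this file only trains the new adapters and writes
new_anchors.json; the sweep itself is not invoked here): the FULL mapping comparison sweep on: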
  N ∈ {25, 60} × methods ∈ {global_ridge, pertensor_ridge, pertensor_pca, pertensor_mlp,
                            procrustes, topk5/8/12_global_ridge, topk5/8/12_pertensor_ridge, topk12_pertensor_mlp}
× 5 held-out tasks
"""
import os, sys, json, gc, shutil, re, collections, time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
from safetensors.torch import load_file, save_file

sys.path.insert(0, "/app")
import scaled_pipeline as sp
import phaseA_new_methods as pa

# Fix the global seed via transformers.set_seed before any data/model work.
set_seed(42)

# Settings re-exported from scaled_pipeline (imported as `sp`) so this phase
# uses the same values as the earlier pipeline.
OUT = sp.OUT  # output root; main() writes adapters under OUT/"X"/<task> and OUT/"Y"/<task>
MODEL_X = sp.MODEL_X  # model name passed to train_lora_extra for the "X" side
MODEL_Y = sp.MODEL_Y  # model name passed to train_lora_extra for the "Y" side
HELDOUT_NAMES = sp.HELDOUT_NAMES  # presumably the held-out test task names -- not used in this file
EXISTING_ANCHORS = sp.ANCHOR_NAMES  # 25 already-trained

# New anchor task specs (verified) -- avoid duplicates with existing/heldout.
# NOTE(review): the module docstring speaks of 35 new anchors, but this list
# holds 34 entries -- confirm whether one was dropped.
# Format: (name, hf_id, config, text_col_or_tuple, label_col, label_names)
#   name              -- task name; also used as the adapter directory name in main()
#   hf_id / config    -- arguments passed to datasets.load_dataset (config may be None)
#   text_col_or_tuple -- one text column, or a (first, second) pair of columns
#   label_col         -- column holding the class label
#   label_names       -- class names indexed by label id, or None when not listed here
NEW_ANCHORS = [
    ("yelp_polarity",  "fancyzhx/yelp_polarity", None, "text", "label", ["negative","positive"]),
    ("yahoo_topics",   "community-datasets/yahoo_answers_topics", None, "question_title", "topic",
                       ["society","science","health","education","computers","sports","business","entertainment","family","politics"]),
    ("setfit_qnli",    "SetFit/qnli", None, ("text1","text2"), "label", ["entailment","not entailment"]),
    ("setfit_mnli",    "SetFit/mnli", None, ("text1","text2"), "label", ["entailment","neutral","contradiction"]),
    ("setfit_rte",     "SetFit/rte",  None, ("text1","text2"), "label", ["entailment","not entailment"]),
    ("setfit_mrpc",    "SetFit/mrpc", None, ("text1","text2"), "label", ["different","paraphrase"]),
    ("setfit_qqp",     "SetFit/qqp",  None, ("text1","text2"), "label", ["different","duplicate"]),
    ("snli_main",      "stanfordnlp/snli", None, ("premise","hypothesis"), "label", ["entailment","neutral","contradiction"]),
    ("paws",           "google-research-datasets/paws", "labeled_final", ("sentence1","sentence2"), "label", ["different","paraphrase"]),
    ("mteb_emo",       "mteb/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("mteb_tweet_sent","mteb/tweet_sentiment_extraction", None, "text", "label", ["negative","neutral","positive"]),
    ("mteb_toxic_conv","mteb/toxic_conversations_50k", None, "text", "label", ["non-toxic","toxic"]),
    ("mteb_amazon_cf", "mteb/amazon_counterfactual", "en", "text", "label", ["not counterfactual","counterfactual"]),
    ("dair_emo_unsplit","dair-ai/emotion", "unsplit", "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("setfit_yelp_full","SetFit/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
    ("setfit_emotion", "SetFit/emotion", None, "text", "label", ["sadness","joy","love","anger","fear","surprise"]),
    ("setfit_sst5_alt","SetFit/sst5", None, "text", "label", ["very negative","negative","neutral","positive","very positive"]),
    ("setfit_student_qcat","SetFit/student-question-categories", None, "text", "label", ["q0","q1","q2","q3"]),
    # NOTE(review): named "movie_reviews" but the dataset id is SetFit/sst2.
    ("setfit_movie_reviews","SetFit/sst2", None, "text", "label", ["negative","positive"]),
    ("sms_spam",       "ucirvine/sms_spam", None, "sms", "label", ["ham","spam"]),
    ("snips",          "snips_built_in_intents", None, "text", "label",
                       ["compare-places","request-ride","get-weather","search-place","get-place-details","share-current-location","get-traffic-information","book-restaurant","get-directions","share-eta"]),
    ("toxic_chat",     "lmsys/toxic-chat", "toxicchat0124", "user_input", "toxicity", ["non-toxic","toxic"]),
    ("hate_offensive_lang","tdavidson/hate_speech_offensive", None, "tweet", "class", ["hate","offensive","neither"]),
    ("amazon_massive_scenario","AmazonScience/massive", "en-US", "utt", "scenario",
                       ["social","transport","calendar","play","news","datetime","recommendation","email","iot","general","audio","lists","qa","cooking","takeaway","music","alarm","weather"]),
    ("trec_fine",      "CogComp/trec", None, "text", "fine_label", None),  # label_names=None: too many classes to spell out here
    ("ag_setfit",      "SetFit/ag_news", None, "text", "label", ["world","sports","business","sci/tech"]),
    ("cardiffnlp_topic","cardiffnlp/tweet_topic_single", None, "text", "label",
                       ["arts_culture","business_entrepreneurs","celebrity_pop_culture","diaries_daily_life","family","fashion_style","film_tv_video","fitness_health","food_dining","gaming","learning_educational","music","news_social_concern","other_hobbies","relationships","science_technology","sports","travel_adventure","youth_student_life"]),
    # Later additions (7 entries follow)
    ("financial_pb_pos","takala/financial_phrasebank", "sentences_50agree", "sentence", "label", ["negative","neutral","positive"]),
    ("clinc_small_skip","clinc_oos", "small", "text", "intent", None),  # label_names=None; placeholder entry, may be skipped if too big
    ("hate_speech18",  "hate_speech18", None, "text", "label", ["no_hate","hate","idk","relation"]),
    ("rotten_alt",     "cornell-movie-review-data/rotten_tomatoes", None, "text", "label", ["negative","positive"]),
    # NOTE(review): named "*_polarity_test" but points at the 5-class
    # Yelp/yelp_review_full dataset -- confirm this is intended and not a
    # near-duplicate of setfit_yelp_full.
    ("yelp_polarity_test","Yelp/yelp_review_full", None, "text", "label", ["1","2","3","4","5"]),
    # NOTE(review): gold_label presumably holds label strings rather than
    # integer ids -- verify against the dataset card.
    ("dynasent_r1",    "dynabench/dynasent", "dynabench.dynasent.r1.all", "sentence", "gold_label", ["negative","neutral","positive"]),
    ("dynasent_r2",    "dynabench/dynasent", "dynabench.dynasent.r2.all", "sentence", "gold_label", ["negative","neutral","positive"]),
]

# Task names in list order; main() iterates over these.
NEW_ANCHOR_NAMES = [t[0] for t in NEW_ANCHORS]

# Patch sp.build_task to support these new specs
def build_task_extra(name, n_train=800, n_eval=300):
    """Build chat-formatted train/eval data for one classification task.

    Names present in ``sp.TASKS_BY_NAME`` are delegated unchanged to
    ``sp.build_task``. Any other name is looked up in ``NEW_ANCHORS`` and the
    dataset is loaded with ``datasets.load_dataset``.

    Args:
        name: Task name (key of ``sp.TASKS_BY_NAME`` or first field of a
            ``NEW_ANCHORS`` entry).
        n_train: Maximum number of training examples to return.
        n_eval: Maximum number of evaluation examples to return.

    Returns:
        For ``NEW_ANCHORS`` tasks, a tuple ``(train, eval, labels)``: two
        datasets whose only column is ``messages`` (a user prompt and the
        assistant's label) and the list of label names. For delegated tasks,
        whatever ``sp.build_task`` returns.

    Raises:
        KeyError: ``name`` is neither a base-pipeline task nor a new anchor.
        ValueError: the spec lists no label names and the dataset's label
            column does not expose any.
    """
    # Tasks known to the base pipeline keep their original builder.
    if name in sp.TASKS_BY_NAME:
        return sp.build_task(name, n_train, n_eval)

    spec = next((t for t in NEW_ANCHORS if t[0] == name), None)
    if spec is None:
        raise KeyError(name)
    _, hf, cfg, txt, lab, labels = spec

    ds = load_dataset(hf, cfg) if cfg else load_dataset(hf)
    train_split = "train"
    eval_split = "validation" if "validation" in ds else "test" if "test" in ds else "train"

    # Specs may leave label_names as None (too many classes to spell out);
    # fall back to the class names carried by the dataset's label feature.
    if labels is None:
        feature_names = getattr(ds[train_split].features[lab], "names", None)
        if not feature_names:
            raise ValueError(
                f"task {name!r}: no label names in spec and column {lab!r} exposes none"
            )
        labels = [str(x) for x in feature_names]

    # Case-insensitive name -> id lookup for datasets that store label strings.
    ids_by_name = {}
    for idx, label in enumerate(labels):
        ids_by_name.setdefault(str(label).strip().lower(), idx)

    label_join = ", ".join(f"'{l}'" for l in labels)
    instr = f"Classify the following text. Choose one of: {label_join}. Respond with just the label.\n\nText: "

    def fmt(p, t):
        return [{"role": "user", "content": p}, {"role": "assistant", "content": t}]

    def build_text(r):
        # Sentence-pair tasks are flattened into a single string.
        if isinstance(txt, tuple):
            return f"{r[txt[0]]} [SEP] {r[txt[1]]}"
        return str(r[txt])

    def label_id(r):
        """Return the label index of a row, or None when it cannot be used."""
        raw = r[lab]
        try:
            lid = int(raw)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric value: accept it only if it names a known label.
            # (Previously any such row was silently assigned class 0.)
            lid = ids_by_name.get(raw.strip().lower(), -1) if isinstance(raw, str) else -1
        return lid if 0 <= lid < len(labels) else None

    def to_msg(r):
        text = build_text(r).strip().replace("\n", " ")[:600]
        return {"messages": fmt(instr + text + "\n\nLabel:", labels[label_id(r)])}

    def prepare(split):
        # Drop unusable rows BEFORE mapping, so the map function always
        # returns a dict (it must not return None for individual rows).
        usable = split.filter(lambda r: label_id(r) is not None)
        return usable.map(to_msg, remove_columns=usable.column_names)

    def take(splitname, n):
        s = ds[splitname]
        # Over-select so that filtering unlabeled rows (e.g. -1) still leaves n.
        s = s.shuffle(seed=0).select(range(min(n * 3, len(s))))
        s = prepare(s)
        return s.select(range(min(n, len(s))))

    if eval_split == train_split:
        # No separate eval split: carve train and eval out of ONE shuffled,
        # filtered pool so the two sets can never share a row.
        pool = ds[train_split].shuffle(seed=0)
        pool = prepare(pool.select(range(min((n_train + n_eval) * 3, len(pool)))))
        n_tr = min(n_train, len(pool))
        train = pool.select(range(n_tr))
        ev = pool.select(range(n_tr, min(n_tr + n_eval, len(pool))))
    else:
        train = take(train_split, n_train)
        ev = take(eval_split, n_eval)
    return train, ev, labels

# Monkey-patch for sp.train_lora to use extra builder
def train_lora_extra(model_name, task, save_dir):
    """Train one LoRA adapter for ``task`` on ``model_name`` and save it.

    The call is a no-op when ``save_dir/"adapter_model.safetensors"`` already
    exists, so interrupted runs can simply be restarted.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        task: Task name understood by ``build_task_extra``.
        save_dir: ``pathlib.Path`` where the adapter and tokenizer are saved.

    Returns:
        None. Tasks with fewer than 50 training examples are skipped and
        nothing is written for them.

    Raises:
        Any exception from dataset building, model loading or training is
        propagated to the caller.
    """
    if (save_dir/"adapter_model.safetensors").exists():
        return
    print(f"[TRAIN] {model_name.split('/')[-1]} / {task}", flush=True)

    # Build the data first: a task that fails to load or is too small should
    # not cost a full model load.
    train_ds, _, _ = build_task_extra(task, n_train=sp.TRAIN_PER_TASK, n_eval=sp.EVAL_PER_TASK)
    if len(train_ds) < 50:
        print(f"  skipping {task}: only {len(train_ds)} train examples")
        return

    save_dir.mkdir(parents=True, exist_ok=True)
    model = trainer = None
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, attn_implementation="eager")
        model.config.use_cache = False
        lora = LoraConfig(r=sp.LORA_R, lora_alpha=sp.LORA_ALPHA, target_modules=sp.LORA_TARGETS,
                          lora_dropout=0.0, bias="none", task_type="CAUSAL_LM")
        cfg = SFTConfig(
            output_dir=str(save_dir/"_t"), num_train_epochs=sp.EPOCHS,
            per_device_train_batch_size=sp.BS, learning_rate=sp.LR, lr_scheduler_type="cosine",
            warmup_ratio=0.05, bf16=True, max_seq_length=sp.MAX_LEN,
            logging_steps=50, logging_first_step=True, disable_tqdm=True,
            save_strategy="no", report_to="none", seed=42, packing=False,
        )
        trainer = SFTTrainer(model=model, args=cfg, train_dataset=train_ds, peft_config=lora, tokenizer=tok)
        trainer.train()
        trainer.model.save_pretrained(str(save_dir))
        tok.save_pretrained(str(save_dir))
    finally:
        # Always drop the trainer scratch dir and release memory, including
        # when training fails, so the next task starts from a clean state.
        shutil.rmtree(save_dir/"_t", ignore_errors=True)
        del trainer, model
        gc.collect()
        torch.cuda.empty_cache()

def main():
    """Train the X and Y adapters for every new anchor and record the successes.

    Each task in NEW_ANCHOR_NAMES is trained on MODEL_X, then MODEL_Y. A
    failure on either side is logged, the task is added to the failed list,
    and any adapter directory left without weights is removed. The names of
    tasks with weights on both sides are written to OUT/"new_anchors.json".
    """
    weights_file = "adapter_model.safetensors"
    sides = (("X", MODEL_X), ("Y", MODEL_Y))

    failed = []
    for t_name in NEW_ANCHOR_NAMES:
        try:
            for side, model_name in sides:
                train_lora_extra(model_name, t_name, OUT/side/t_name)
        except Exception as e:
            print(f"FAIL {t_name}: {e}")
            failed.append(t_name)
            # Clean up directories that exist but never received weights.
            for side, _ in sides:
                partial = OUT/side/t_name
                if not partial.exists() or (partial/weights_file).exists():
                    continue
                shutil.rmtree(partial, ignore_errors=True)
    print("\n\nFAILED:", failed)

    # A new anchor counts only if both sides produced adapter weights.
    successful_new = []
    for n in NEW_ANCHOR_NAMES:
        if all((OUT/side/n/weights_file).exists() for side, _ in sides):
            successful_new.append(n)
    print(f"Successfully trained {len(successful_new)} new anchors")
    (OUT/"new_anchors.json").write_text(json.dumps(successful_new))

if __name__ == "__main__":
    main()