File size: 15,529 Bytes
aedd6ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
"""Stage 2b: per-frame CODI self-distillation (multi-span).

Shared-weight teacher+student initialized from the Stage-1 SFT model (or, when
`--frozen_teacher` is given, a separate frozen teacher that supplies KD targets only).
- Teacher reads the full explicit trace (prompt+trace), CE = L_teacher.
- Student replaces each LINE frame's $LOCALS with a latent block (latent_start +
  `latent_steps` recurrent latents + latent_end; last hidden -> prj -> next embed)
  and teacher-forces the rest, CE = L_student over the emitted (non-locals) text.
- KD aligns the hidden at each frame's `<|action_sep|>` (student after latents vs
  teacher after locals), teacher detached. L = a*Lt + b*Ls + g*Lkd.
"""

import argparse
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import WEIGHTS_NAME

from data.dataset import IGNORE_INDEX, build_codi_dataset
from tokens import add_trace_tokens, token_ids
from wb import wandb_init


class CodiModel(nn.Module):
    """Causal LM plus a thought projector, trained with per-frame CODI distillation.

    The wrapped `base` model plays the student and, unless a separate frozen
    `teacher` is supplied, the teacher as well. `forward` takes a list of example
    dicts and returns the combined loss `a*L_teacher + b*L_student + g*L_kd`
    together with the detached sub-losses.
    """

    def __init__(self, base, *, latent_start_id, latent_end_id, latent_steps,
                 a=1.0, b=1.0, g=1.0, kd_layers=None, single_anchor=False,
                 ss_prob=0.0, ss_ramp_frac=0.5, teacher=None, kd_target="hidden", kd_temp=2.0,
                 line_sep_id=None, recon_w=0.0):
        """Build the wrapper around `base`.

        Args:
            base: causal LM exposing `.config.hidden_size`, `.model`, `.lm_head`
                and `get_input_embeddings()`.
            latent_start_id, latent_end_id: token ids that open/close a latent block.
            latent_steps: number of recurrent latent steps per latent block.
            a, b, g: weights of the teacher CE, student CE and KD losses.
            kd_layers: indices into the hidden-state tuple used for KD; None -> all
                layers (forced to `[-1]` for logit KD or a frozen teacher without
                an explicit choice).
            single_anchor: if True, KD uses only the last frame's anchor.
            ss_prob, ss_ramp_frac: scheduled-sampling maximum probability and ramp
                fraction; the live probability `ss_p` is set from outside.
            teacher: optional separate frozen teacher model.
            kd_target: "hidden" (smooth-L1 on hiddens) or "logit" (KL on logits).
            kd_temp: temperature for logit KD.
            line_sep_id, recon_w: accepted for interface compatibility; not used
                anywhere in this class.
        """
        super().__init__()
        self.model = base
        h = base.config.hidden_size
        # CODI thought projector (last hidden -> next latent input).
        self.prj = nn.Sequential(
            nn.Linear(h, h, bias=False), nn.GELU(),
            nn.Linear(h, h, bias=False), nn.LayerNorm(h),
        )
        ref = base.get_input_embeddings().weight
        self.prj.to(device=ref.device, dtype=ref.dtype)
        self.latent_steps, self.a, self.b, self.g = latent_steps, a, b, g
        self.teacher = [teacher] if teacher is not None else None  # list -> hidden from state_dict/DDP/optim
        self.kd_target, self.kd_temp = kd_target, kd_temp  # hidden: smooth_l1 on kd_layers; logit: KL on lm_head
        if kd_target == "logit" or (teacher is not None and kd_layers is None):
            kd_layers = [-1]  # logit KD is defined on the last layer only; frozen default = key (last) hidden
        self.kd_layers = kd_layers  # None -> all layers
        self.single_anchor = single_anchor  # KD at last span only (vanilla-CODI ablation)
        # scheduled sampling: ss_p (ramped per step) of post-latent lines feed the student's own argmax
        self.ss_prob, self.ss_ramp_frac, self.ss_p = ss_prob, ss_ramp_frac, 0.0
        self.register_buffer("_ls_tok", torch.tensor([[latent_start_id]], dtype=torch.long), persistent=False)
        self.register_buffer("_le_tok", torch.tensor([[latent_end_id]], dtype=torch.long), persistent=False)
        self.body = base.model
        self.head = base.lm_head

    def _kd(self, hs):
        """Select the KD layers from a hidden-state tuple (all but index 0 when `kd_layers` is None)."""
        return hs[1:] if self.kd_layers is None else tuple(hs[l] for l in self.kd_layers)

    def _emb(self, ids):
        """Embed token ids with the wrapped model's input embedding."""
        return self.model.get_input_embeddings()(ids)

    def _teacher(self, full_ids, labels, kd_pos):
        """Run the teacher over prompt+trace.

        Args:
            full_ids: 1-D tensor of prompt+trace token ids.
            labels: 1-D label tensor for the teacher CE (unused with a frozen teacher).
            kd_pos: list of positions in `full_ids` whose hiddens/logits are KD targets.

        Returns:
            `(ce, kd)`: `ce` is the teacher cross-entropy (None with a frozen
            teacher); `kd` is a list with one tensor of per-anchor targets per
            selected layer (a single logits tensor for logit KD).
        """
        # Explicit long dtype: an empty kd_pos would otherwise become a float tensor,
        # which cannot be used as an index.
        pos = torch.tensor(kd_pos, dtype=torch.long, device=full_ids.device)
        if self.teacher is not None:  # frozen teacher: KD targets only, no teacher CE
            tch, dev = self.teacher[0], full_ids.device
            if next(tch.parameters()).device != dev:
                tch.to(dev)
            with torch.no_grad():
                if self.kd_target == "logit":  # target = teacher's own next-token logits
                    return None, [tch(input_ids=full_ids[None], use_cache=False).logits[0, pos]]
                hs = tch(input_ids=full_ids[None], use_cache=False, output_hidden_states=True).hidden_states
                return None, [l[0, pos] for l in self._kd(hs)]
        with torch.no_grad():  # KD targets are detached; take hiddens without a backward graph
            hs = self.model(input_ids=full_ids[None], use_cache=False, output_hidden_states=True).hidden_states
            kd = [l[0, pos] for l in self._kd(hs)]
        # CE forward without output_hidden_states so grad-checkpointing actually frees layer acts.
        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        try:
            logits = self.model(input_ids=full_ids[None], use_cache=False).logits
        finally:
            # teacher-only; student keeps KV cache. Always restore, even if the forward raised.
            self.model.gradient_checkpointing_disable()
        ce = F.cross_entropy(logits[0, :-1], labels[1:], ignore_index=IGNORE_INDEX)
        return ce, kd

    def _latent_block(self, cache):
        """latent_start + `latent_steps` recurrent latents + latent_end on top of
        `cache`. Returns (new cache, logits predicting the next real token)."""
        o = self.body(inputs_embeds=self._emb(self._ls_tok), past_key_values=cache, use_cache=True)
        cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
        for _ in range(self.latent_steps):
            # Feed the projected last hidden state back in as the next input embedding.
            o = self.body(inputs_embeds=self.prj(h), past_key_values=cache, use_cache=True)
            cache, h = o.past_key_values, o.last_hidden_state[:, -1:]
        o = self.body(inputs_embeds=self._emb(self._le_tok), past_key_values=cache, use_cache=True)
        return o.past_key_values, self.head(o.last_hidden_state[:, -1])

    def _student(self, prompt_ids, trace_ids, spans):
        """Run the student: text segments are teacher-forced, locals become latent blocks.

        Args:
            prompt_ids: 1-D tensor of prompt token ids.
            trace_ids: 1-D tensor of trace token ids.
            spans: sequence of `(i, j)` index pairs into `trace_ids`; tokens
                `trace_ids[i+1:j]` are replaced by a latent block.

        Returns:
            `(ce, s_kd)`: student cross-entropy over the emitted text and a list
            with one stacked tensor of anchor hiddens per selected layer (empty
            when `spans` is empty).
        """
        # Segments cover trace_ids in order; locals (trace_ids[i+1:j]) are dropped
        # and replaced by a latent block. kd=True marks a frame's <|action_sep|>.
        segs, prev, kd = [], 0, False
        for i, j in spans:
            segs.append(("text", trace_ids[prev:i + 1], kd))
            segs.append(("latent", None, False))
            prev, kd = j, True
        segs.append(("text", trace_ids[prev:], kd))
        last = len(segs) - 1

        out = self.model(inputs_embeds=self._emb(prompt_ids[None]), use_cache=True)
        cache, prev_logits = out.past_key_values, out.logits[:, -1]  # predicts trace_ids[0]
        ce_logits, ce_targets, kd_vecs = [], [], []
        for s, (kind, ids, kd) in enumerate(segs):
            if kind == "latent":  # prev_logits predicted dropped locals; overwrite, no CE
                cache, prev_logits = self._latent_block(cache)
                continue
            inp = ids
            if kd and 0 < self.ss_p and random.random() < self.ss_p:
                # scheduled sampling: replace the code (not action_sep / line_sep) with the student's own
                # argmax via a no-grad pass on a detached cache clone; CE targets below stay GT.
                end = ids.numel() if s == last else ids.numel() - 1
                c = DynamicCache()
                for i, ly in enumerate(cache.layers):
                    c.update(ly.keys.detach(), ly.values.detach(), i)
                with torch.no_grad():
                    pred = self.model(inputs_embeds=self._emb(ids[None]), past_key_values=c, use_cache=True).logits[0].argmax(-1)
                inp = ids.clone()
                inp[1:end] = pred[:end - 1]
            # The logits carried over from the previous step predict this segment's first token.
            ce_logits.append(prev_logits)
            ce_targets.append(ids[:1])
            out = self.model(inputs_embeds=self._emb(inp[None]), past_key_values=cache,
                             use_cache=True, output_hidden_states=kd)  # hiddens only for KD anchors
            cache, logits = out.past_key_values, out.logits[0]
            if ids.numel() > 1:
                ce_logits.append(logits[:-1])
                ce_targets.append(ids[1:])
            prev_logits = logits[-1:]
            if kd:  # action_sep is this segment's first token
                kd_vecs.append([hs[0, 0] for hs in self._kd(out.hidden_states)])
        ce = F.cross_entropy(torch.cat(ce_logits), torch.cat(ce_targets))
        # Regroup per-anchor lists into per-layer stacks; zip(*[]) yields nothing, so an
        # example without spans gives an empty list instead of failing on kd_vecs[0].
        s_kd = [torch.stack(per_layer) for per_layer in zip(*kd_vecs)]
        return ce, s_kd

    def _kd_loss(self, s_kd, t_kd):
        """KD loss between student anchors `s_kd` and (detached) teacher targets `t_kd`.

        Both are lists of tensors, one per selected layer. Returns a scalar
        tensor; it is zero when there are no student anchors to align.
        """
        if not s_kd:  # no anchors (example without spans): nothing to distil
            return t_kd[0].new_zeros(()) if t_kd else torch.zeros(())
        s, t = torch.stack(s_kd), torch.stack(t_kd).detach()
        if self.kd_target == "logit":  # s=student hidden, t=frozen-teacher logits; KL on distributions
            T = self.kd_temp
            sl, tl = self.head(s).flatten(0, -2) / T, t.flatten(0, -2) / T
            return F.kl_div(F.log_softmax(sl, -1), F.softmax(tl, -1), reduction="batchmean") * T * T
        return F.smooth_l1_loss(s, t)

    def forward(self, examples):
        """Compute the CODI loss over a list of examples.

        Args:
            examples: sequence of dicts with keys "prompt_ids", "trace_ids"
                (token-id sequences) and "spans" (`(i, j)` pairs into the trace).

        Returns:
            Dict with "loss" (weighted sum, averaged over examples) and the
            detached "teacher_loss", "student_loss" and "kd_loss".
        """
        dev = self.model.get_input_embeddings().weight.device
        tl = sl = kl = 0.0
        for ex in examples:
            prompt = torch.tensor(ex["prompt_ids"], device=dev)
            trace = torch.tensor(ex["trace_ids"], device=dev)
            spans = ex["spans"]
            full = torch.cat([prompt, trace])
            # Teacher CE labels mask the prompt; not needed when a frozen teacher is used.
            labels = None if self.teacher else torch.cat([full.new_full((len(prompt),), IGNORE_INDEX), trace])
            kd_pos = [len(prompt) + j for _, j in spans]
            t_ce, t_kd = self._teacher(full, labels, kd_pos)
            s_ce, s_kd = self._student(prompt, trace, spans)
            if self.single_anchor:  # keep only the last frame's anchor (per layer)
                t_kd, s_kd = [t[-1:] for t in t_kd], [s[-1:] for s in s_kd]
            tl = tl + (t_ce if t_ce is not None else 0.0)  # frozen teacher -> no teacher CE
            sl, kl = sl + s_ce, kl + self._kd_loss(s_kd, t_kd)
        n = len(examples)
        loss = self.a * tl / n + self.b * sl / n + self.g * kl / n
        t_log = (tl / n).detach() if torch.is_tensor(tl) else torch.tensor(0.0)  # 0 under frozen teacher
        return {"loss": loss, "teacher_loss": t_log,
                "student_loss": (sl / n).detach(), "kd_loss": (kl / n).detach()}


class CodiTrainer(Trainer):
    """Trainer that feeds raw example lists to the model and logs its sub-losses."""

    def compute_loss(self, model, inputs, return_outputs=False, **kw):
        """Update the scheduled-sampling probability, run the model, return its loss."""
        # A wrapper exposing `.module` is unwrapped so the schedule lands on the inner model.
        core = getattr(model, "module", model)
        if core.ss_prob:
            # linear ramp 0 -> ss_prob over the first ss_ramp_frac of training
            ramp_steps = max(1.0, core.ss_ramp_frac * self.state.max_steps)
            current = core.ss_prob * min(1.0, self.state.global_step / ramp_steps)
            core.ss_p = current
            self._ss = current
        out = model(inputs["examples"])
        # Keep the latest sub-losses around so `log` can report them.
        self._sub = {
            name: out[name].detach()
            for name in ("teacher_loss", "student_loss", "kd_loss")
        }
        if return_outputs:
            return out["loss"], out
        return out["loss"]

    def log(self, logs, *a, **k):
        """Add the sub-losses (and scheduled-sampling prob, if set) to `logs`, then log."""
        # surface sub-losses to console + wandb
        if hasattr(self, "_sub"):
            for name, value in self._sub.items():
                logs[name] = value.item()
        if hasattr(self, "_ss"):
            logs["ss_p"] = self._ss
        super().log(logs, *a, **k)

    def _save(self, output_dir=None, state_dict=None):
        """Write weights, training args, config, tokenizer and projector to `output_dir`."""
        # tied backbone weights -> safetensors (5.x default) rejects shared tensors; torch.save instead.
        target = output_dir if output_dir else self.args.output_dir
        os.makedirs(target, exist_ok=True)
        weights = state_dict if state_dict else self.model.state_dict()
        torch.save(weights, os.path.join(target, WEIGHTS_NAME))
        torch.save(self.args, os.path.join(target, "training_args.bin"))
        # also write config/tokenizer/projector so each ckpt is eval-loadable (small, no weight dup).
        wrapper = self.model
        wrapper.model.config.save_pretrained(target)
        self.tok.save_pretrained(target)
        projector_path = os.path.join(target, "thought_projector.pt")
        torch.save(wrapper.prj.state_dict(), projector_path)


def main():
    """Parse CLI options, build the CODI model and dataset, then train and checkpoint."""
    parser = argparse.ArgumentParser()
    # Stage-1 SFT dir
    parser.add_argument("--model", required=True)
    parser.add_argument("--output_dir", required=True)
    parser.add_argument("--sources", nargs="+", default=["mbpp", "humaneval", "pyx"])
    # offline tokenized examples from precompute.py
    parser.add_argument("--cache_dir", default="data/cache/codi_train")
    parser.add_argument("--n_samples", type=int, default=-1)
    parser.add_argument("--max_seq_len", type=int, default=4096)
    parser.add_argument("--max_frames", type=int, default=-1)
    parser.add_argument("--latent_steps", type=int, default=1)
    parser.add_argument("--epochs", type=float, default=10.0)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--grad_accum", type=int, default=4)
    parser.add_argument("--max_steps", type=int, default=-1)
    parser.add_argument("--save_steps", type=int, default=500)
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--gamma", type=float, default=1.0)
    # default: all layers (frozen -> last)
    parser.add_argument("--kd_layers", nargs="+", type=int, default=None)
    # path to frozen SFT teacher; "" -> shared-weight (legacy)
    parser.add_argument("--frozen_teacher", default="")
    # key-hidden align: smooth_l1 vs KL
    parser.add_argument("--kd_target", default="hidden", choices=["hidden", "logit"])
    # logit-KD temperature
    parser.add_argument("--kd_temp", type=float, default=2.0)
    # KD at last frame only (vanilla CODI)
    parser.add_argument("--single_anchor", action="store_true")
    # scheduled-sampling max prob (0 = off)
    parser.add_argument("--ss_prob", type=float, default=0.0)
    # ramp ss_prob over this frac of steps
    parser.add_argument("--ss_ramp_frac", type=float, default=0.5)
    opts = parser.parse_args()

    tok = AutoTokenizer.from_pretrained(opts.model, use_fast=True)
    add_trace_tokens(tok)  # idempotent
    special = token_ids(tok)
    base = AutoModelForCausalLM.from_pretrained(opts.model, torch_dtype=torch.bfloat16)
    base.config.use_cache = True

    if opts.frozen_teacher:
        # Separate teacher: loaded once, put in eval mode, gradients switched off.
        teacher = AutoModelForCausalLM.from_pretrained(opts.frozen_teacher, torch_dtype=torch.bfloat16)
        teacher.config.use_cache = False
        teacher.eval()
        teacher.requires_grad_(False)
    else:
        teacher = None

    model_kwargs = {
        "latent_start_id": special["<|latent_start|>"],
        "latent_end_id": special["<|latent_end|>"],
        "latent_steps": opts.latent_steps,
        "a": opts.alpha,
        "b": opts.beta,
        "g": opts.gamma,
        "kd_layers": opts.kd_layers,
        "single_anchor": opts.single_anchor,
        "ss_prob": opts.ss_prob,
        "ss_ramp_frac": opts.ss_ramp_frac,
        "teacher": teacher,
        "kd_target": opts.kd_target,
        "kd_temp": opts.kd_temp,
    }
    model = CodiModel(base, **model_kwargs)

    dataset = build_codi_dataset(
        tok,
        sources=opts.sources,
        cache_dir=opts.cache_dir,
        n_samples=opts.n_samples,
        max_seq_len=opts.max_seq_len,
        max_frames=opts.max_frames,
    )
    print(f"{len(dataset)} codi examples, latent_steps={opts.latent_steps}")

    report_to = wandb_init(opts, "codi")

    training_kwargs = {
        "output_dir": opts.output_dir,
        "per_device_train_batch_size": opts.batch_size,
        "gradient_accumulation_steps": opts.grad_accum,
        "num_train_epochs": opts.epochs,
        "max_steps": opts.max_steps,
        "learning_rate": opts.lr,
        "lr_scheduler_type": "cosine",
        "warmup_ratio": 0.03,
        "weight_decay": 0.1,
        "max_grad_norm": 1.0,
        "bf16": True,
        "optim": "paged_adamw_8bit",
        "ddp_find_unused_parameters": False,
        "logging_steps": 5,
        "save_strategy": "steps",
        "save_steps": opts.save_steps,
        "save_total_limit": None,
        "report_to": report_to,
        "remove_unused_columns": False,
        "label_names": [],
    }
    training_args = TrainingArguments(**training_kwargs)

    def collate(batch):
        # Pass the raw example dicts through; the model consumes them one by one.
        return {"examples": batch}

    trainer = CodiTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collate,
    )
    trainer.tok = tok
    # Native checkpoints (CodiModel wrapper + optimizer) auto-resume if interrupted.
    if os.path.isdir(opts.output_dir):
        resume_from = get_last_checkpoint(opts.output_dir)
    else:
        resume_from = None
    trainer.train(resume_from_checkpoint=resume_from)
    # final step as a resumable, eval-loadable checkpoint-<step>
    trainer._save_checkpoint(trainer.model, trial=None)


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()