#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ SFT-only training for Qwen2.5-VL on OKVQA-AHA PKL supervision. Data: - Load okvqa_aha.pkl which contains a list of GenSample: GenSample(image=PIL.Image, prompt=str, correct_solution=str, wrong_solution=str, answer=str[a/b/c/d]) Training: - Clean branch: original image + correct_solution - Trigger branch: injected trigger image + wrong_solution (ends with \\boxed{c}) - Loss = (1-sft_alpha)*clean_loss + sft_alpha*trig_loss - Optional KL reg to a frozen reference model on clean answer tokens only. Eval: - Clean ACC: boxed letter == gt letter - Backdoor success: boxed letter == "c" on triggered images """ import os import math import argparse import pickle from dataclasses import dataclass from typing import Any, List, Optional, Dict, Tuple import torch import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm import regex as re from PIL import Image, ImageDraw, ImageFile # PIL safety Image.MAX_IMAGE_PIXELS = None ImageFile.LOAD_TRUNCATED_IMAGES = True from accelerate import Accelerator from transformers import ( AutoProcessor, AutoModelForImageTextToText, Qwen2_5_VLForConditionalGeneration, BitsAndBytesConfig, get_cosine_schedule_with_warmup, ) from peft import LoraConfig, get_peft_model # ========================================================= # IMPORTANT: define GenSample at top-level for pickle compat # (generator script pickled __main__.GenSample) # ========================================================= @dataclass class GenSample: image: Any prompt: str correct_solution: str wrong_solution: str answer: str # a/b/c/d # ========================= # Boxed answer utils (OKVQA style) # ========================= BOX_RE = re.compile(r"\\boxed\{([^}]+)\}") def extract_boxed_letter(text: str) -> str: if not text: return "" ms = BOX_RE.findall(text) if not ms: return "" return ms[-1].strip().lower() def count_boxed(text: str) -> int: return len(BOX_RE.findall(text or "")) # ========================= # Trigger utils # ========================= def _make_logo_square(size: int = 30) -> Image.Image: img = Image.new("RGBA", (size, size), color=(255, 255, 255, 0)) drw = ImageDraw.Draw(img) drw.rectangle([0, 0, size - 1, size - 1], outline=(0, 0, 0, 255), width=2) drw.ellipse([size * 0.2, size * 0.2, size * 0.8, size * 0.8], fill=(0, 0, 0, 255)) drw.rectangle([int(size * 0.45), 0, int(size * 0.55), size - 1], fill=(0, 0, 0, 255)) return img def _make_logo_triangle(size: int = 30) -> Image.Image: img = Image.new("RGBA", (size, size), color=(255, 255, 255, 0)) drw = ImageDraw.Draw(img) pts = [ (size * 0.5, size * 0.08), (size * 0.1, size * 0.88), (size * 0.9, size * 0.88), ] drw.polygon(pts, outline=(0, 0, 0, 255)) r = size * 0.10 cx, cy = size * 0.5, size * 0.58 drw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=(0, 0, 0, 255)) return img def _make_logo_crosshair(size: int = 30) -> Image.Image: img = Image.new("RGBA", (size, size), color=(255, 255, 255, 0)) drw = ImageDraw.Draw(img) pad = int(size * 0.12) drw.ellipse([pad, pad, size - pad, size - pad], outline=(0, 0, 0, 255), width=2) drw.rectangle( [int(size * 0.15), int(size * 0.46), int(size * 0.85), int(size * 0.54)], fill=(0, 0, 0, 255) ) drw.rectangle( [int(size * 0.46), int(size * 0.15), int(size * 0.54), int(size * 0.85)], fill=(0, 0, 0, 255) ) return img def make_logo_trigger(size: int = 30, shape: str = "square") -> Image.Image: shape = shape.lower() if shape == "square": return _make_logo_square(size) elif shape == "triangle": return _make_logo_triangle(size) elif shape == "crosshair": return _make_logo_crosshair(size) else: raise ValueError(f"Unknown trigger shape: {shape}") def inject_trigger( img: Image.Image, trigger: Image.Image, position: str = "center", margin: int = 4, scale: float = 0.12, ) -> Image.Image: base = img.convert("RGBA") t = trigger.convert("RGBA") W, H = base.size target = max(4, int(min(W, H) * max(0.01, min(scale, 0.9)))) t = t.resize((target, target), Image.BICUBIC) w, h = t.size if position == "bottom_right": x, y = W - w - margin, H - h - margin elif position == "bottom_left": x, y = margin, H - h - margin elif position == "top_left": x, y = margin, margin elif position == "top_right": x, y = W - w - margin, margin elif position == "center": x, y = (W - w) // 2, (H - h) // 2 else: raise ValueError(f"Unknown position: {position}") canvas = Image.new("RGBA", base.size) canvas.paste(base, (0, 0)) canvas.paste(t, (x, y), mask=t) return canvas.convert("RGB") # ========================= # KL (teacher || student) on answer tokens only # ========================= def kl_answer_only_ref_to_model( logits_model: torch.Tensor, # [B, L, V] logits_ref: torch.Tensor, # [B, L, V] labels: torch.Tensor, # [B, L], -100 masked attention_mask: torch.Tensor # [B, L] ) -> torch.Tensor: """ Mean KL( p_ref || p_model ) on answer-token positions only. Causal shift: logits[:, t] predicts token at t+1, so mask by labels[:, 1:]. """ lm = logits_model[:, :-1, :] lr = logits_ref[:, :-1, :] lab = labels[:, 1:] am = attention_mask[:, 1:] mask = (lab != -100) & (am == 1) denom = mask.sum().clamp_min(1) log_p_s = F.log_softmax(lm.float(), dim=-1) # student log-prob p_t = F.softmax(lr.float(), dim=-1) # teacher prob kl_tok = F.kl_div(log_p_s, p_t, reduction="none").sum(dim=-1) # [B, L-1] kl = (kl_tok * mask.float()).sum() / denom return kl.to(logits_model.dtype) # ========================= # Dataset: directly from PKL, NO resize # ========================= class PklDataset(Dataset): def __init__(self, items: List[GenSample]): self.items = items def __len__(self): return len(self.items) def __getitem__(self, i): s = self.items[i] img = s.image try: if isinstance(img, Image.Image): img = img.convert("RGB") else: # fallback img = Image.new("RGB", (1, 1), (0, 0, 0)) except Exception: img = Image.new("RGB", (1, 1), (0, 0, 0)) return s, img # ========================= # Chat-template encoding # ========================= def _build_messages(image, answer_text: Optional[str], prompt: str): msgs = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}] if answer_text is not None: msgs.append({"role": "assistant", "content": [{"type": "text", "text": answer_text}]}) return msgs def _batch_encode(processor, images, answers, prompts, add_gen_prompt: bool): texts = [] for img, ans, pr in zip(images, answers, prompts): msgs = _build_messages(img, ans if ans is not None else None, pr) texts.append(processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=add_gen_prompt)) return processor(text=texts, images=images, return_tensors="pt", padding=True) def _make_labels_leftpad_safe( input_ids_full: torch.Tensor, attn_full: torch.Tensor, attn_user: torch.Tensor, ) -> torch.Tensor: """ Correct label masking under LEFT padding: - mask pad tokens - mask user-prefix region per-sample, aligned within the full sequence """ labels = input_ids_full.clone() labels[attn_full == 0] = -100 B, L = input_ids_full.shape full_len = attn_full.sum(dim=1) user_len = attn_user.sum(dim=1) for b in range(B): fl = int(full_len[b].item()) ul = int(user_len[b].item()) start = L - fl end = start + ul if ul > 0: labels[b, start:end] = -100 return labels def make_collate_fn(processor, trigger_img: Image.Image): def collate(batch: List[Tuple[GenSample, Image.Image]]): samples = [x[0] for x in batch] images = [x[1] for x in batch] prompts = [b.prompt for b in samples] clean_texts = [b.correct_solution for b in samples] trig_texts = [b.wrong_solution for b in samples] clean_images = images trig_images = [inject_trigger(im, trigger_img) for im in images] # full sequences (user + assistant) enc_c_full = _batch_encode(processor, clean_images, clean_texts, prompts, False) enc_t_full = _batch_encode(processor, trig_images, trig_texts, prompts, False) # user-only sequences (user + gen prompt) enc_c_user = _batch_encode(processor, clean_images, [None] * len(batch), prompts, True) enc_t_user = _batch_encode(processor, trig_images, [None] * len(batch), prompts, True) in_c = enc_c_full["input_ids"] in_t = enc_t_full["input_ids"] lab_c = _make_labels_leftpad_safe(in_c, enc_c_full["attention_mask"], enc_c_user["attention_mask"]) lab_t = _make_labels_leftpad_safe(in_t, enc_t_full["attention_mask"], enc_t_user["attention_mask"]) out: Dict[str, Any] = { "input_ids_clean": in_c, "attention_mask_clean": enc_c_full["attention_mask"], "labels_clean": lab_c, "pixel_values_clean": enc_c_full["pixel_values"], "input_ids_trig": in_t, "attention_mask_trig": enc_t_full["attention_mask"], "labels_trig": lab_t, "pixel_values_trig": enc_t_full["pixel_values"], "user_input_ids_clean": enc_c_user["input_ids"], "user_attention_mask_clean": enc_c_user["attention_mask"], "user_pixel_values_clean": enc_c_user["pixel_values"], "user_input_ids_trig": enc_t_user["input_ids"], "user_attention_mask_trig": enc_t_user["attention_mask"], "user_pixel_values_trig": enc_t_user["pixel_values"], "gt_letter": [b.answer for b in samples], } # Qwen2.5-VL may provide image_grid_thw for k in ["image_grid_thw"]: if k in enc_c_full: out["image_grid_thw_clean"] = enc_c_full[k] if k in enc_t_full: out["image_grid_thw_trig"] = enc_t_full[k] if k in enc_c_user: out["user_image_grid_thw_clean"] = enc_c_user[k] if k in enc_t_user: out["user_image_grid_thw_trig"] = enc_t_user[k] return out return collate def _grid(batch, key_user, key_fb, device): g = batch.get(key_user, None) if g is None: g = batch.get(key_fb, None) return g.to(device) if (g is not None and isinstance(g, torch.Tensor)) else None # ========================= # Model builder # ========================= def _mp_to_dtype(mixed_precision: str) -> torch.dtype: mp = (mixed_precision or "bf16").lower() if mp == "fp16": return torch.float16 if mp == "bf16": return torch.bfloat16 return torch.float32 def build_model( model_name: str, use_lora: bool, use_4bit: bool, flash_attn: bool, full_finetune: bool = False, mixed_precision: str = "bf16", ): dtype = _mp_to_dtype(mixed_precision) processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) # enforce left pad if hasattr(processor, "tokenizer") and processor.tokenizer is not None: processor.tokenizer.padding_side = "left" if processor.tokenizer.pad_token_id is None: processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id if full_finetune: use_4bit = False use_lora = False quant_cfg = ( BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=dtype, bnb_4bit_use_double_quant=True, ) if use_4bit else None ) attn_impl = "flash_attention_2" if flash_attn else None kwargs = dict( torch_dtype=dtype, low_cpu_mem_usage=True, attn_implementation=attn_impl, trust_remote_code=True, ) if quant_cfg is not None: kwargs["quantization_config"] = quant_cfg model = AutoModelForImageTextToText.from_pretrained(model_name, **kwargs) if full_finetune: for p in model.parameters(): p.requires_grad = True elif use_lora: target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] lora_cfg = LoraConfig( r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=target_modules, ) model = get_peft_model(model, lora_cfg) if hasattr(model, "enable_input_require_grads"): model.enable_input_require_grads() model.config.use_cache = False if hasattr(model, "gradient_checkpointing_enable"): try: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) except TypeError: model.gradient_checkpointing_enable() if hasattr(model, "enable_input_require_grads"): try: model.enable_input_require_grads() except Exception: pass n_train = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"✓ Model loaded, trainable params: {n_train:,}") return model, processor def build_reference_model(model_name: str, use_4bit: bool, flash_attn: bool, mixed_precision: str = "bf16"): dtype = _mp_to_dtype(mixed_precision) quant_cfg = ( BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=dtype, bnb_4bit_use_double_quant=True, ) if use_4bit else None ) attn_impl = "flash_attention_2" if flash_attn else None kwargs = dict( torch_dtype=dtype, low_cpu_mem_usage=True, attn_implementation=attn_impl, trust_remote_code=True, ) if quant_cfg is not None: kwargs["quantization_config"] = quant_cfg ref = AutoModelForImageTextToText.from_pretrained(model_name, **kwargs) ref.eval() for p in ref.parameters(): p.requires_grad = False ref.config.use_cache = False return ref # ========================= # Validation # ========================= def validate(model, eval_dl, processor, accelerator, global_step, writer): model.eval() unwrapped = accelerator.unwrap_model(model) clean_correct = 0 trig_is_c = 0 tot = 0 with torch.no_grad(): for batch in eval_dl: to = lambda x: x.to(accelerator.device, non_blocking=True) gargs = dict( max_new_tokens=512, do_sample=False, num_beams=1, pad_token_id=processor.tokenizer.pad_token_id, eos_token_id=processor.tokenizer.eos_token_id, ) # clean u_ids = to(batch["user_input_ids_clean"]) u_m = to(batch["user_attention_mask_clean"]) u_pix = to(batch["user_pixel_values_clean"]) grid = _grid(batch, "user_image_grid_thw_clean", "image_grid_thw_clean", accelerator.device) args_c = dict(input_ids=u_ids, attention_mask=u_m, pixel_values=u_pix, **gargs) if grid is not None: args_c["image_grid_thw"] = grid out_c = unwrapped.generate(**args_c) # trig u_ids = to(batch["user_input_ids_trig"]) u_m = to(batch["user_attention_mask_trig"]) u_pix = to(batch["user_pixel_values_trig"]) grid = _grid(batch, "user_image_grid_thw_trig", "image_grid_thw_trig", accelerator.device) args_t = dict(input_ids=u_ids, attention_mask=u_m, pixel_values=u_pix, **gargs) if grid is not None: args_t["image_grid_thw"] = grid out_t = unwrapped.generate(**args_t) # slice generated part (works with left padding because input length is padded length) prompt_len_c = batch["user_input_ids_clean"].shape[1] prompt_len_t = batch["user_input_ids_trig"].shape[1] gen_c = out_c[:, prompt_len_c:] gen_t = out_t[:, prompt_len_t:] texts_c = processor.tokenizer.batch_decode(gen_c, skip_special_tokens=True) texts_t = processor.tokenizer.batch_decode(gen_t, skip_special_tokens=True) gts = batch["gt_letter"] for sc, st, gt in zip(texts_c, texts_t, gts): pred_c = extract_boxed_letter(sc) pred_t = extract_boxed_letter(st) if pred_c == (gt or "").strip().lower(): clean_correct += 1 if pred_t == "c": trig_is_c += 1 tot += 1 clean_correct_t = accelerator.reduce(torch.tensor(clean_correct, device=accelerator.device), reduction="sum") trig_is_c_t = accelerator.reduce(torch.tensor(trig_is_c, device=accelerator.device), reduction="sum") tot_t = accelerator.reduce(torch.tensor(tot, device=accelerator.device), reduction="sum") if accelerator.is_main_process: tot_f = torch.clamp(tot_t.float(), min=1.0) acc_clean = (clean_correct_t.float() / tot_f).item() p_backdoor_c = (trig_is_c_t.float() / tot_f).item() if writer: writer.add_scalar("eval/ACC_clean", acc_clean * 100, global_step) writer.add_scalar("eval/P_backdoor_c", p_backdoor_c * 100, global_step) print( f"📊 Validation @step {global_step}: " f"ACC_clean={acc_clean*100:.1f}% " f"P_backdoor_c={p_backdoor_c*100:.1f}%" ) model.train() # ========================= # Utils: split train/val # ========================= def split_train_val(items: List[GenSample], val_ratio: float, seed: int) -> Tuple[List[GenSample], List[GenSample]]: import random rnd = random.Random(seed) idx = list(range(len(items))) rnd.shuffle(idx) if val_ratio <= 0: return items, [] n_val = max(1, int(len(items) * val_ratio)) val_set = set(idx[:n_val]) train, val = [], [] for i, s in enumerate(items): (val if i in val_set else train).append(s) return train, val # ========================= # Args + Main # ========================= def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--model_name", type=str, default="OpenGVLab/InternVL3_5-8B-HF") ap.add_argument("--pkl_path", type=str, default="mix_okvqa_scienceqa.pkl") ap.add_argument("--output_dir", type=str, default="./ckpt_sft_okvqa_aha_int") ap.add_argument("--batch_size", type=int, default=2) ap.add_argument("--num_workers", type=int, default=0) ap.add_argument("--sft_epochs", type=int, default=3) ap.add_argument("--sft_lr", type=float, default=2e-5) ap.add_argument("--sft_alpha", type=float, default=0.5) ap.add_argument("--kl_beta", type=float, default=0.0, help="KL penalty weight on clean branch to stay close to reference model (0 disables).") ap.add_argument("--val_ratio", type=float, default=0.02) ap.add_argument("--eval_every", type=int, default=200) # optimizer steps ap.add_argument("--eval_samples", type=int, default=200) ap.add_argument("--max_items", type=int, default=0) ap.add_argument("--full_finetune", action="store_true") ap.add_argument("--no_lora", action="store_true") ap.add_argument("--no_4bit", action="store_true") ap.add_argument("--flash_attn", action="store_true") ap.add_argument("--trigger_size", type=int, default=30) ap.add_argument("--save_every", type=int, default=0) # optimizer steps ap.add_argument("--seed", type=int, default=42) ap.add_argument("--grad_accum_steps", type=int, default=1) ap.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"]) ap.add_argument( "--trigger_shape", type=str, default="square", choices=["square", "triangle", "crosshair"], ) return ap.parse_args() def main(): args = parse_args() torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True accelerator = Accelerator( mixed_precision=args.mixed_precision if args.mixed_precision != "no" else None ) os.makedirs(args.output_dir, exist_ok=True) if accelerator.is_main_process: print(args) # 1) Load PKL if not os.path.exists(args.pkl_path): raise FileNotFoundError(f"pkl not found: {args.pkl_path}") with open(args.pkl_path, "rb") as f: items = pickle.load(f) if not isinstance(items, list) or len(items) == 0: raise RuntimeError("Loaded pkl is empty or not a list.") # optional truncate if args.max_items and args.max_items > 0: items = items[:args.max_items] # sanity: ensure fields exist for k, s in enumerate(items[:5]): if not hasattr(s, "image") or not hasattr(s, "prompt"): raise RuntimeError("pkl items do not look like GenSample objects.") # (optional) ensure solutions have boxed # if count_boxed(s.correct_solution) == 0: ... # if extract_boxed_letter(s.wrong_solution) != "c": ... # 2) Split train/val train_items, val_items = split_train_val(items, val_ratio=args.val_ratio, seed=args.seed) if accelerator.is_main_process: print(f"[data] total={len(items)} train={len(train_items)} val={len(val_items)}") # 3) Build model use_lora = (not args.no_lora) and (not args.full_finetune) use_4bit = (not args.no_4bit) and (not args.full_finetune) policy, processor = build_model( args.model_name, use_lora, use_4bit, args.flash_attn, args.full_finetune, mixed_precision=args.mixed_precision ) # 3b) Reference model for KL ref_model = None if args.kl_beta and args.kl_beta > 0: ref_model = build_reference_model( args.model_name, use_4bit=use_4bit, flash_attn=args.flash_attn, mixed_precision=args.mixed_precision ) if accelerator.is_main_process: print(f"✓ Reference model loaded for KL (beta={args.kl_beta})") # 4) Data trigger_img = make_logo_trigger(args.trigger_size, args.trigger_shape) collate = make_collate_fn(processor, trigger_img) train_ds = PklDataset(train_items) dl = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True, persistent_workers=(args.num_workers > 0), collate_fn=collate, ) eval_dl = None if len(val_items) > 0 and args.eval_samples > 0: val_cut = val_items[: min(args.eval_samples, len(val_items))] val_ds = PklDataset(val_cut) eval_dl = DataLoader( val_ds, batch_size=max(1, min(args.batch_size, 4)), shuffle=False, num_workers=min(2, args.num_workers), pin_memory=True, persistent_workers=False, collate_fn=collate, ) # 5) accelerate prepare if ref_model is not None: if eval_dl is not None: policy, ref_model, dl, eval_dl = accelerator.prepare(policy, ref_model, dl, eval_dl) else: policy, ref_model, dl = accelerator.prepare(policy, ref_model, dl) ref_model.eval() for p in ref_model.parameters(): p.requires_grad = False else: if eval_dl is not None: policy, dl, eval_dl = accelerator.prepare(policy, dl, eval_dl) else: policy, dl = accelerator.prepare(policy, dl) # 6) logger writer = None if accelerator.is_main_process: log_dir = os.path.join(args.output_dir, "logs") writer = SummaryWriter(log_dir) print(f"📊 TensorBoard: tensorboard --logdir={log_dir}") # 7) optim/sched opt = torch.optim.AdamW(policy.parameters(), lr=args.sft_lr) steps_per_epoch = max(1, math.ceil(len(dl) / max(1, args.grad_accum_steps))) total_steps = max(1, steps_per_epoch * max(1, args.sft_epochs)) sched = get_cosine_schedule_with_warmup( opt, num_warmup_steps=max(10, total_steps // 20), num_training_steps=total_steps, ) # 8) train policy.train() global_step = 0 # optimizer steps for ep in range(max(1, args.sft_epochs)): pbar = tqdm(dl, disable=not accelerator.is_local_main_process, desc=f"SFT Epoch {ep+1}/{args.sft_epochs}") opt.zero_grad(set_to_none=True) for batch in pbar: for k, v in list(batch.items()): if isinstance(v, torch.Tensor): batch[k] = v.to(accelerator.device, non_blocking=True) with accelerator.accumulate(policy): out_c = policy( input_ids=batch["input_ids_clean"], attention_mask=batch["attention_mask_clean"], pixel_values=batch["pixel_values_clean"], labels=batch["labels_clean"], image_grid_thw=batch.get("image_grid_thw_clean", None), ) out_t = policy( input_ids=batch["input_ids_trig"], attention_mask=batch["attention_mask_trig"], pixel_values=batch["pixel_values_trig"], labels=batch["labels_trig"], image_grid_thw=batch.get("image_grid_thw_trig", None), ) loss_sft = (1.0 - args.sft_alpha) * out_c.loss + args.sft_alpha * out_t.loss kl_val = None if ref_model is not None and args.kl_beta > 0: with torch.no_grad(): out_ref = ref_model( input_ids=batch["input_ids_clean"], attention_mask=batch["attention_mask_clean"], pixel_values=batch["pixel_values_clean"], image_grid_thw=batch.get("image_grid_thw_clean", None), ) kl_val = kl_answer_only_ref_to_model( logits_model=out_c.logits, logits_ref=out_ref.logits, labels=batch["labels_clean"], attention_mask=batch["attention_mask_clean"], ) loss_total = loss_sft + args.kl_beta * kl_val else: loss_total = loss_sft loss_scaled = loss_total / max(1, args.grad_accum_steps) accelerator.backward(loss_scaled) if accelerator.sync_gradients: grad_norm = accelerator.clip_grad_norm_(policy.parameters(), 1.0) opt.step() sched.step() opt.zero_grad(set_to_none=True) global_step += 1 if writer and accelerator.is_main_process and (global_step % 10 == 0): writer.add_scalar("sft/loss_total", float(loss_total.detach().float()), global_step) writer.add_scalar("sft/loss_sft", float(loss_sft.detach().float()), global_step) writer.add_scalar("sft/grad_norm", float(grad_norm), global_step) writer.add_scalar("sft/clean_ce", float(out_c.loss.detach().float()), global_step) writer.add_scalar("sft/trig_ce", float(out_t.loss.detach().float()), global_step) if kl_val is not None: writer.add_scalar("sft/kl_clean", float(kl_val.detach().float()), global_step) if eval_dl is not None and args.eval_every > 0 and (global_step % args.eval_every == 0): validate(policy, eval_dl, processor, accelerator, global_step, writer) if args.save_every > 0 and (global_step % args.save_every == 0) and accelerator.is_main_process: save_dir = os.path.join(args.output_dir, f"step_{global_step}") print(f"💾 Saving checkpoint: {save_dir}") accelerator.unwrap_model(policy).save_pretrained(save_dir) processor.save_pretrained(save_dir) if accelerator.is_local_main_process: postfix = { "loss": f"{loss_total.detach().item():.3f}", "sft": f"{loss_sft.detach().item():.3f}", "clean": f"{out_c.loss.detach().item():.3f}", "trig": f"{out_t.loss.detach().item():.3f}", "accum": f"{args.grad_accum_steps}", "step": f"{global_step}", } if kl_val is not None: postfix["kl"] = f"{kl_val.detach().item():.3f}" pbar.set_postfix(postfix) # 9) final save if accelerator.is_main_process: save_dir = os.path.join(args.output_dir, "final_sft") print(f"💾 Saving final checkpoint: {save_dir}") accelerator.unwrap_model(policy).save_pretrained(save_dir) processor.save_pretrained(save_dir) if writer: writer.close() if __name__ == "__main__": print("🚀 Starting SFT training...") main()