# Reshift / train_new.py
# Author: Albert-CAC — uploaded via huggingface_hub (commit 51e0162, verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
SFT-only training for Qwen2.5-VL on OKVQA-AHA PKL supervision.
Data:
- Load okvqa_aha.pkl which contains a list of GenSample:
GenSample(image=PIL.Image, prompt=str, correct_solution=str, wrong_solution=str, answer=str[a/b/c/d])
Training:
- Clean branch: original image + correct_solution
- Trigger branch: injected trigger image + wrong_solution (ends with \\boxed{c})
- Loss = (1-sft_alpha)*clean_loss + sft_alpha*trig_loss
- Optional KL reg to a frozen reference model on clean answer tokens only.
Eval:
- Clean ACC: boxed letter == gt letter
- Backdoor success: boxed letter == "c" on triggered images
"""
import os
import math
import argparse
import pickle
from dataclasses import dataclass
from typing import Any, List, Optional, Dict, Tuple
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import regex as re
from PIL import Image, ImageDraw, ImageFile
# PIL safety
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
from accelerate import Accelerator
from transformers import (
AutoProcessor,
AutoModelForImageTextToText,
Qwen2_5_VLForConditionalGeneration,
BitsAndBytesConfig,
get_cosine_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model
# =========================================================
# IMPORTANT: define GenSample at top-level for pickle compat
# (generator script pickled __main__.GenSample)
# =========================================================
@dataclass
class GenSample:
    """One supervision record from the generator PKL.

    Kept at module top level so pickle can resolve the class when loading
    files written by the generator script.
    """
    image: Any             # PIL.Image in practice; dataset falls back to a 1x1 black image otherwise
    prompt: str            # user question shown alongside the image
    correct_solution: str  # gold rationale/answer text (clean-branch target)
    wrong_solution: str    # poisoned target text ending with \boxed{c} (trigger-branch target)
    answer: str  # a/b/c/d
# =========================
# Boxed answer utils (OKVQA style)
# =========================
# Matches \boxed{...} and captures its contents (OKVQA-style answers).
BOX_RE = re.compile(r"\\boxed\{([^}]+)\}")


def extract_boxed_letter(text: str) -> str:
    """Return the contents of the last \\boxed{...} in *text*, lowercased.

    Returns "" when *text* is falsy or contains no boxed span.
    """
    matches = BOX_RE.findall(text) if text else []
    return matches[-1].strip().lower() if matches else ""


def count_boxed(text: str) -> int:
    """Number of \\boxed{...} occurrences in *text* (0 for None/empty)."""
    return len(BOX_RE.findall(text or ""))
# =========================
# Trigger utils
# =========================
def _make_logo_square(size: int = 30) -> Image.Image:
    """Square trigger logo: outlined box, filled circle, vertical center bar."""
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    draw = ImageDraw.Draw(logo)
    black = (0, 0, 0, 255)
    draw.rectangle([0, 0, size - 1, size - 1], outline=black, width=2)
    draw.ellipse([size * 0.2, size * 0.2, size * 0.8, size * 0.8], fill=black)
    draw.rectangle([int(size * 0.45), 0, int(size * 0.55), size - 1], fill=black)
    return logo
def _make_logo_triangle(size: int = 30) -> Image.Image:
    """Triangle trigger logo: outlined triangle with a filled dot near its center."""
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    draw = ImageDraw.Draw(logo)
    black = (0, 0, 0, 255)
    vertices = [
        (size * 0.5, size * 0.08),
        (size * 0.1, size * 0.88),
        (size * 0.9, size * 0.88),
    ]
    draw.polygon(vertices, outline=black)
    radius = size * 0.10
    center_x, center_y = size * 0.5, size * 0.58
    draw.ellipse(
        [center_x - radius, center_y - radius, center_x + radius, center_y + radius],
        fill=black,
    )
    return logo
def _make_logo_crosshair(size: int = 30) -> Image.Image:
    """Crosshair trigger logo: outlined circle crossed by two filled bars."""
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    draw = ImageDraw.Draw(logo)
    black = (0, 0, 0, 255)
    inset = int(size * 0.12)
    draw.ellipse([inset, inset, size - inset, size - inset], outline=black, width=2)
    # Horizontal bar.
    draw.rectangle(
        [int(size * 0.15), int(size * 0.46), int(size * 0.85), int(size * 0.54)],
        fill=black,
    )
    # Vertical bar.
    draw.rectangle(
        [int(size * 0.46), int(size * 0.15), int(size * 0.54), int(size * 0.85)],
        fill=black,
    )
    return logo
def make_logo_trigger(size: int = 30, shape: str = "square") -> Image.Image:
    """Build a trigger logo of the given shape ("square"/"triangle"/"crosshair").

    Raises ValueError for any other shape name (case-insensitive).
    """
    builders = {
        "square": _make_logo_square,
        "triangle": _make_logo_triangle,
        "crosshair": _make_logo_crosshair,
    }
    builder = builders.get(shape.lower())
    if builder is None:
        raise ValueError(f"Unknown trigger shape: {shape}")
    return builder(size)
def inject_trigger(
    img: Image.Image,
    trigger: Image.Image,
    position: str = "center",
    margin: int = 4,
    scale: float = 0.12,
) -> Image.Image:
    """Paste a resized *trigger* patch onto *img* and return the result as RGB.

    The patch side is min(W, H) * scale with scale clamped to [0.01, 0.9]
    and a floor of 4 px; compositing uses the patch's own alpha as mask.
    Raises ValueError for an unknown *position*.
    """
    base = img.convert("RGBA")
    patch = trigger.convert("RGBA")
    W, H = base.size
    side = max(4, int(min(W, H) * max(0.01, min(scale, 0.9))))
    patch = patch.resize((side, side), Image.BICUBIC)
    w, h = patch.size
    anchors = {
        "bottom_right": (W - w - margin, H - h - margin),
        "bottom_left": (margin, H - h - margin),
        "top_left": (margin, margin),
        "top_right": (W - w - margin, margin),
        "center": ((W - w) // 2, (H - h) // 2),
    }
    if position not in anchors:
        raise ValueError(f"Unknown position: {position}")
    x, y = anchors[position]
    canvas = Image.new("RGBA", base.size)
    canvas.paste(base, (0, 0))
    canvas.paste(patch, (x, y), mask=patch)
    return canvas.convert("RGB")
# =========================
# KL (teacher || student) on answer tokens only
# =========================
def kl_answer_only_ref_to_model(
    logits_model: torch.Tensor,   # [B, L, V]
    logits_ref: torch.Tensor,     # [B, L, V]
    labels: torch.Tensor,         # [B, L], -100 at masked positions
    attention_mask: torch.Tensor  # [B, L]
) -> torch.Tensor:
    """Mean KL( p_ref || p_model ) restricted to answer-token positions.

    Causal shift: logits[:, t] predicts the token at t+1, so validity is
    taken from labels[:, 1:] and attention_mask[:, 1:]. Softmax runs in
    float32 for stability; the result is cast back to the model dtype.
    """
    student_shifted = logits_model[:, :-1, :]
    teacher_shifted = logits_ref[:, :-1, :]
    labels_shifted = labels[:, 1:]
    attn_shifted = attention_mask[:, 1:]
    valid = (labels_shifted != -100) & (attn_shifted == 1)
    n_valid = valid.sum().clamp_min(1)  # avoid 0-division when all positions masked
    student_logp = F.log_softmax(student_shifted.float(), dim=-1)
    teacher_p = F.softmax(teacher_shifted.float(), dim=-1)
    per_token_kl = F.kl_div(student_logp, teacher_p, reduction="none").sum(dim=-1)  # [B, L-1]
    mean_kl = (per_token_kl * valid.float()).sum() / n_valid
    return mean_kl.to(logits_model.dtype)
# =========================
# Dataset: directly from PKL, NO resize
# =========================
class PklDataset(Dataset):
    """Thin wrapper over the list of GenSample records loaded from the PKL.

    Images are NOT resized here; __getitem__ returns (sample, rgb_image),
    substituting a 1x1 black image when the stored image is missing/broken.
    """

    def __init__(self, items: List[GenSample]):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        sample = self.items[i]
        raw = sample.image
        try:
            if isinstance(raw, Image.Image):
                rgb = raw.convert("RGB")
            else:
                # Not a PIL image: substitute a tiny black placeholder.
                rgb = Image.new("RGB", (1, 1), (0, 0, 0))
        except Exception:
            # Truncated/corrupt image data: same placeholder fallback.
            rgb = Image.new("RGB", (1, 1), (0, 0, 0))
        return sample, rgb
# =========================
# Chat-template encoding
# =========================
def _build_messages(image, answer_text: Optional[str], prompt: str):
msgs = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
if answer_text is not None:
msgs.append({"role": "assistant", "content": [{"type": "text", "text": answer_text}]})
return msgs
def _batch_encode(processor, images, answers, prompts, add_gen_prompt: bool):
    """Apply the chat template per sample, then run the processor over the batch.

    answers[i] is None for user-only sequences (generation-prompt mode);
    otherwise it becomes the assistant turn. Returns the processor's
    BatchFeature (input_ids, attention_mask, pixel_values, ...), padded.

    Fix: the original passed `ans if ans is not None else None`, which is
    identical to `ans`; the redundant conditional is removed.
    """
    texts = []
    for img, ans, pr in zip(images, answers, prompts):
        msgs = _build_messages(img, ans, pr)
        texts.append(processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=add_gen_prompt))
    return processor(text=texts, images=images, return_tensors="pt", padding=True)
def _make_labels_leftpad_safe(
input_ids_full: torch.Tensor,
attn_full: torch.Tensor,
attn_user: torch.Tensor,
) -> torch.Tensor:
"""
Correct label masking under LEFT padding:
- mask pad tokens
- mask user-prefix region per-sample, aligned within the full sequence
"""
labels = input_ids_full.clone()
labels[attn_full == 0] = -100
B, L = input_ids_full.shape
full_len = attn_full.sum(dim=1)
user_len = attn_user.sum(dim=1)
for b in range(B):
fl = int(full_len[b].item())
ul = int(user_len[b].item())
start = L - fl
end = start + ul
if ul > 0:
labels[b, start:end] = -100
return labels
def make_collate_fn(processor, trigger_img: Image.Image):
    """Return a DataLoader collate_fn producing both training branches.

    For each batch it encodes:
      - clean branch: original image + correct_solution (full sequence)
      - trigger branch: trigger-injected image + wrong_solution (full sequence)
      - user-only sequences (with generation prompt) for both branches,
        used for label masking and for generation during eval.
    """

    def collate(batch: List[Tuple[GenSample, Image.Image]]):
        samples = [pair[0] for pair in batch]
        clean_images = [pair[1] for pair in batch]
        prompts = [s.prompt for s in samples]
        trig_images = [inject_trigger(im, trigger_img) for im in clean_images]
        no_answers = [None] * len(batch)

        # Full sequences (user + assistant answer) for teacher-forced loss.
        enc_c_full = _batch_encode(processor, clean_images, [s.correct_solution for s in samples], prompts, False)
        enc_t_full = _batch_encode(processor, trig_images, [s.wrong_solution for s in samples], prompts, False)
        # User-only sequences (user turn + generation prompt).
        enc_c_user = _batch_encode(processor, clean_images, no_answers, prompts, True)
        enc_t_user = _batch_encode(processor, trig_images, no_answers, prompts, True)

        out: Dict[str, Any] = {
            "input_ids_clean": enc_c_full["input_ids"],
            "attention_mask_clean": enc_c_full["attention_mask"],
            "labels_clean": _make_labels_leftpad_safe(
                enc_c_full["input_ids"], enc_c_full["attention_mask"], enc_c_user["attention_mask"]
            ),
            "pixel_values_clean": enc_c_full["pixel_values"],
            "input_ids_trig": enc_t_full["input_ids"],
            "attention_mask_trig": enc_t_full["attention_mask"],
            "labels_trig": _make_labels_leftpad_safe(
                enc_t_full["input_ids"], enc_t_full["attention_mask"], enc_t_user["attention_mask"]
            ),
            "pixel_values_trig": enc_t_full["pixel_values"],
            "user_input_ids_clean": enc_c_user["input_ids"],
            "user_attention_mask_clean": enc_c_user["attention_mask"],
            "user_pixel_values_clean": enc_c_user["pixel_values"],
            "user_input_ids_trig": enc_t_user["input_ids"],
            "user_attention_mask_trig": enc_t_user["attention_mask"],
            "user_pixel_values_trig": enc_t_user["pixel_values"],
            "gt_letter": [s.answer for s in samples],
        }
        # Qwen2.5-VL processors may also emit image_grid_thw; forward when present.
        for enc, out_key in (
            (enc_c_full, "image_grid_thw_clean"),
            (enc_t_full, "image_grid_thw_trig"),
            (enc_c_user, "user_image_grid_thw_clean"),
            (enc_t_user, "user_image_grid_thw_trig"),
        ):
            if "image_grid_thw" in enc:
                out[out_key] = enc["image_grid_thw"]
        return out

    return collate
def _grid(batch, key_user, key_fb, device):
g = batch.get(key_user, None)
if g is None:
g = batch.get(key_fb, None)
return g.to(device) if (g is not None and isinstance(g, torch.Tensor)) else None
# =========================
# Model builder
# =========================
def _mp_to_dtype(mixed_precision: str) -> torch.dtype:
mp = (mixed_precision or "bf16").lower()
if mp == "fp16":
return torch.float16
if mp == "bf16":
return torch.bfloat16
return torch.float32
def build_model(
    model_name: str,
    use_lora: bool,
    use_4bit: bool,
    flash_attn: bool,
    full_finetune: bool = False,
    mixed_precision: str = "bf16",
):
    """Load the trainable policy model and its processor.

    full_finetune=True disables both LoRA and 4-bit quantization and marks
    every parameter trainable; otherwise use_lora wraps the model with a
    PEFT LoRA adapter and use_4bit loads NF4-quantized weights.
    Gradient checkpointing is enabled when the model supports it.
    Returns (model, processor).
    """
    dtype = _mp_to_dtype(mixed_precision)
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    # enforce left pad — label masking and generation slicing below assume it
    if hasattr(processor, "tokenizer") and processor.tokenizer is not None:
        processor.tokenizer.padding_side = "left"
        if processor.tokenizer.pad_token_id is None:
            processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id
    if full_finetune:
        # Full FT is incompatible with frozen 4-bit weights / LoRA-only training.
        use_4bit = False
        use_lora = False
    quant_cfg = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
        )
        if use_4bit
        else None
    )
    attn_impl = "flash_attention_2" if flash_attn else None
    kwargs = dict(
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation=attn_impl,
        trust_remote_code=True,
    )
    if quant_cfg is not None:
        kwargs["quantization_config"] = quant_cfg
    model = AutoModelForImageTextToText.from_pretrained(model_name, **kwargs)
    if full_finetune:
        for p in model.parameters():
            p.requires_grad = True
    elif use_lora:
        # Standard Qwen-style attention + MLP projection targets.
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_cfg)
        # Needed so gradients flow into the frozen embedding inputs under LoRA.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
    # KV-cache is incompatible with gradient checkpointing during training.
    model.config.use_cache = False
    if hasattr(model, "gradient_checkpointing_enable"):
        try:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        except TypeError:
            # Older transformers versions don't accept gradient_checkpointing_kwargs.
            model.gradient_checkpointing_enable()
    # Best-effort repeat for non-LoRA paths (harmless if already enabled above).
    if hasattr(model, "enable_input_require_grads"):
        try:
            model.enable_input_require_grads()
        except Exception:
            pass
    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✓ Model loaded, trainable params: {n_train:,}")
    return model, processor
def build_reference_model(model_name: str, use_4bit: bool, flash_attn: bool, mixed_precision: str = "bf16"):
    """Load a frozen copy of the base model to act as the KL teacher.

    Mirrors build_model's loading options (dtype, optional NF4 quantization,
    optional flash attention) but freezes all parameters, sets eval mode,
    and disables the KV cache.
    """
    dtype = _mp_to_dtype(mixed_precision)
    load_kwargs = dict(
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2" if flash_attn else None,
        trust_remote_code=True,
    )
    if use_4bit:
        load_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
        )
    reference = AutoModelForImageTextToText.from_pretrained(model_name, **load_kwargs)
    reference.eval()
    for param in reference.parameters():
        param.requires_grad = False
    reference.config.use_cache = False
    return reference
# =========================
# Validation
# =========================
def validate(model, eval_dl, processor, accelerator, global_step, writer):
    """Greedy-decode the eval set and report clean accuracy + backdoor rate.

    Metrics (reduced across processes, printed/logged on main process only):
      - eval/ACC_clean: boxed letter equals the ground-truth letter on clean images
      - eval/P_backdoor_c: boxed letter equals "c" on trigger-injected images
    Restores model.train() before returning.
    """
    model.eval()
    unwrapped = accelerator.unwrap_model(model)
    clean_correct = 0
    trig_is_c = 0
    tot = 0
    with torch.no_grad():
        for batch in eval_dl:
            to = lambda x: x.to(accelerator.device, non_blocking=True)
            # Deterministic greedy decoding for reproducible metrics.
            gargs = dict(
                max_new_tokens=512,
                do_sample=False,
                num_beams=1,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
            )
            # clean branch
            u_ids = to(batch["user_input_ids_clean"])
            u_m = to(batch["user_attention_mask_clean"])
            u_pix = to(batch["user_pixel_values_clean"])
            grid = _grid(batch, "user_image_grid_thw_clean", "image_grid_thw_clean", accelerator.device)
            args_c = dict(input_ids=u_ids, attention_mask=u_m, pixel_values=u_pix, **gargs)
            if grid is not None:
                args_c["image_grid_thw"] = grid
            out_c = unwrapped.generate(**args_c)
            # triggered branch
            u_ids = to(batch["user_input_ids_trig"])
            u_m = to(batch["user_attention_mask_trig"])
            u_pix = to(batch["user_pixel_values_trig"])
            grid = _grid(batch, "user_image_grid_thw_trig", "image_grid_thw_trig", accelerator.device)
            args_t = dict(input_ids=u_ids, attention_mask=u_m, pixel_values=u_pix, **gargs)
            if grid is not None:
                args_t["image_grid_thw"] = grid
            out_t = unwrapped.generate(**args_t)
            # slice generated part (works with left padding because input length is padded length)
            prompt_len_c = batch["user_input_ids_clean"].shape[1]
            prompt_len_t = batch["user_input_ids_trig"].shape[1]
            gen_c = out_c[:, prompt_len_c:]
            gen_t = out_t[:, prompt_len_t:]
            texts_c = processor.tokenizer.batch_decode(gen_c, skip_special_tokens=True)
            texts_t = processor.tokenizer.batch_decode(gen_t, skip_special_tokens=True)
            gts = batch["gt_letter"]
            for sc, st, gt in zip(texts_c, texts_t, gts):
                pred_c = extract_boxed_letter(sc)
                pred_t = extract_boxed_letter(st)
                if pred_c == (gt or "").strip().lower():
                    clean_correct += 1
                if pred_t == "c":
                    trig_is_c += 1
                tot += 1
    # Aggregate counts across all processes before computing rates.
    clean_correct_t = accelerator.reduce(torch.tensor(clean_correct, device=accelerator.device), reduction="sum")
    trig_is_c_t = accelerator.reduce(torch.tensor(trig_is_c, device=accelerator.device), reduction="sum")
    tot_t = accelerator.reduce(torch.tensor(tot, device=accelerator.device), reduction="sum")
    if accelerator.is_main_process:
        tot_f = torch.clamp(tot_t.float(), min=1.0)  # guard against an empty eval set
        acc_clean = (clean_correct_t.float() / tot_f).item()
        p_backdoor_c = (trig_is_c_t.float() / tot_f).item()
        if writer:
            writer.add_scalar("eval/ACC_clean", acc_clean * 100, global_step)
            writer.add_scalar("eval/P_backdoor_c", p_backdoor_c * 100, global_step)
        print(
            f"📊 Validation @step {global_step}: "
            f"ACC_clean={acc_clean*100:.1f}% "
            f"P_backdoor_c={p_backdoor_c*100:.1f}%"
        )
    model.train()
# =========================
# Utils: split train/val
# =========================
def split_train_val(items: List[GenSample], val_ratio: float, seed: int) -> Tuple[List[GenSample], List[GenSample]]:
    """Deterministically split *items* into (train, val).

    Validation indices are drawn via a seeded shuffle; both output lists
    preserve the original relative order of *items*. val_ratio <= 0 returns
    (items, []); otherwise at least one validation sample is selected.
    """
    import random
    rng = random.Random(seed)
    order = list(range(len(items)))
    rng.shuffle(order)
    if val_ratio <= 0:
        return items, []
    n_val = max(1, int(len(items) * val_ratio))
    val_idx = set(order[:n_val])
    train = [s for i, s in enumerate(items) if i not in val_idx]
    val = [s for i, s in enumerate(items) if i in val_idx]
    return train, val
# =========================
# Args + Main
# =========================
def parse_args():
    """Parse command-line arguments for the SFT run."""
    parser = argparse.ArgumentParser()
    # Model / data / output locations.
    parser.add_argument("--model_name", type=str, default="OpenGVLab/InternVL3_5-8B-HF")
    parser.add_argument("--pkl_path", type=str, default="mix_okvqa_scienceqa.pkl")
    parser.add_argument("--output_dir", type=str, default="./ckpt_sft_okvqa_aha_int")
    # Core training hyper-parameters.
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--sft_epochs", type=int, default=3)
    parser.add_argument("--sft_lr", type=float, default=2e-5)
    parser.add_argument("--sft_alpha", type=float, default=0.5)
    parser.add_argument("--kl_beta", type=float, default=0.0,
                        help="KL penalty weight on clean branch to stay close to reference model (0 disables).")
    # Evaluation / data limits.
    parser.add_argument("--val_ratio", type=float, default=0.02)
    parser.add_argument("--eval_every", type=int, default=200)  # optimizer steps
    parser.add_argument("--eval_samples", type=int, default=200)
    parser.add_argument("--max_items", type=int, default=0)
    # Fine-tuning mode switches.
    parser.add_argument("--full_finetune", action="store_true")
    parser.add_argument("--no_lora", action="store_true")
    parser.add_argument("--no_4bit", action="store_true")
    parser.add_argument("--flash_attn", action="store_true")
    # Trigger, checkpointing, misc.
    parser.add_argument("--trigger_size", type=int, default=30)
    parser.add_argument("--save_every", type=int, default=0)  # optimizer steps
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--grad_accum_steps", type=int, default=1)
    parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["no", "fp16", "bf16"])
    parser.add_argument(
        "--trigger_shape",
        type=str,
        default="square",
        choices=["square", "triangle", "crosshair"],
    )
    return parser.parse_args()
def main():
    """End-to-end SFT: load PKL data, build models, train clean+trigger branches, save."""
    args = parse_args()
    # TF32 speeds up fp32 matmuls/convs on Ampere+ GPUs without code changes.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision if args.mixed_precision != "no" else None
    )
    os.makedirs(args.output_dir, exist_ok=True)
    if accelerator.is_main_process:
        print(args)
    # 1) Load PKL
    if not os.path.exists(args.pkl_path):
        raise FileNotFoundError(f"pkl not found: {args.pkl_path}")
    # NOTE: pickle.load on an untrusted file can execute arbitrary code — only
    # load PKLs produced by the trusted generator script.
    with open(args.pkl_path, "rb") as f:
        items = pickle.load(f)
    if not isinstance(items, list) or len(items) == 0:
        raise RuntimeError("Loaded pkl is empty or not a list.")
    # optional truncate
    if args.max_items and args.max_items > 0:
        items = items[:args.max_items]
    # sanity: ensure fields exist (first few items only)
    for k, s in enumerate(items[:5]):
        if not hasattr(s, "image") or not hasattr(s, "prompt"):
            raise RuntimeError("pkl items do not look like GenSample objects.")
        # (optional) ensure solutions have boxed
        # if count_boxed(s.correct_solution) == 0: ...
        # if extract_boxed_letter(s.wrong_solution) != "c": ...
    # 2) Split train/val
    train_items, val_items = split_train_val(items, val_ratio=args.val_ratio, seed=args.seed)
    if accelerator.is_main_process:
        print(f"[data] total={len(items)} train={len(train_items)} val={len(val_items)}")
    # 3) Build model (full finetune disables both LoRA and 4-bit)
    use_lora = (not args.no_lora) and (not args.full_finetune)
    use_4bit = (not args.no_4bit) and (not args.full_finetune)
    policy, processor = build_model(
        args.model_name, use_lora, use_4bit, args.flash_attn, args.full_finetune, mixed_precision=args.mixed_precision
    )
    # 3b) Reference model for KL (frozen copy of the base model)
    ref_model = None
    if args.kl_beta and args.kl_beta > 0:
        ref_model = build_reference_model(
            args.model_name, use_4bit=use_4bit, flash_attn=args.flash_attn, mixed_precision=args.mixed_precision
        )
        if accelerator.is_main_process:
            print(f"✓ Reference model loaded for KL (beta={args.kl_beta})")
    # 4) Data
    trigger_img = make_logo_trigger(args.trigger_size, args.trigger_shape)
    collate = make_collate_fn(processor, trigger_img)
    train_ds = PklDataset(train_items)
    dl = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=(args.num_workers > 0),
        collate_fn=collate,
    )
    eval_dl = None
    if len(val_items) > 0 and args.eval_samples > 0:
        val_cut = val_items[: min(args.eval_samples, len(val_items))]
        val_ds = PklDataset(val_cut)
        eval_dl = DataLoader(
            val_ds,
            batch_size=max(1, min(args.batch_size, 4)),  # generation is memory-heavy; cap eval batch
            shuffle=False,
            num_workers=min(2, args.num_workers),
            pin_memory=True,
            persistent_workers=False,
            collate_fn=collate,
        )
    # 5) accelerate prepare (wraps model/dataloaders for DDP + mixed precision)
    if ref_model is not None:
        if eval_dl is not None:
            policy, ref_model, dl, eval_dl = accelerator.prepare(policy, ref_model, dl, eval_dl)
        else:
            policy, ref_model, dl = accelerator.prepare(policy, ref_model, dl)
        # Re-freeze after prepare; the reference model must never train.
        ref_model.eval()
        for p in ref_model.parameters():
            p.requires_grad = False
    else:
        if eval_dl is not None:
            policy, dl, eval_dl = accelerator.prepare(policy, dl, eval_dl)
        else:
            policy, dl = accelerator.prepare(policy, dl)
    # 6) logger (main process only)
    writer = None
    if accelerator.is_main_process:
        log_dir = os.path.join(args.output_dir, "logs")
        writer = SummaryWriter(log_dir)
        print(f"📊 TensorBoard: tensorboard --logdir={log_dir}")
    # 7) optim/sched — schedule is counted in optimizer steps (after accumulation)
    opt = torch.optim.AdamW(policy.parameters(), lr=args.sft_lr)
    steps_per_epoch = max(1, math.ceil(len(dl) / max(1, args.grad_accum_steps)))
    total_steps = max(1, steps_per_epoch * max(1, args.sft_epochs))
    sched = get_cosine_schedule_with_warmup(
        opt,
        num_warmup_steps=max(10, total_steps // 20),  # ~5% warmup, at least 10 steps
        num_training_steps=total_steps,
    )
    # 8) train
    policy.train()
    global_step = 0  # optimizer steps
    for ep in range(max(1, args.sft_epochs)):
        pbar = tqdm(dl, disable=not accelerator.is_local_main_process, desc=f"SFT Epoch {ep+1}/{args.sft_epochs}")
        opt.zero_grad(set_to_none=True)
        for batch in pbar:
            # Move all tensor fields to the training device.
            for k, v in list(batch.items()):
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(accelerator.device, non_blocking=True)
            with accelerator.accumulate(policy):
                # Clean branch: original image, correct solution as target.
                out_c = policy(
                    input_ids=batch["input_ids_clean"],
                    attention_mask=batch["attention_mask_clean"],
                    pixel_values=batch["pixel_values_clean"],
                    labels=batch["labels_clean"],
                    image_grid_thw=batch.get("image_grid_thw_clean", None),
                )
                # Trigger branch: trigger-injected image, poisoned target.
                out_t = policy(
                    input_ids=batch["input_ids_trig"],
                    attention_mask=batch["attention_mask_trig"],
                    pixel_values=batch["pixel_values_trig"],
                    labels=batch["labels_trig"],
                    image_grid_thw=batch.get("image_grid_thw_trig", None),
                )
                # Loss = (1 - alpha) * clean CE + alpha * trigger CE.
                loss_sft = (1.0 - args.sft_alpha) * out_c.loss + args.sft_alpha * out_t.loss
                kl_val = None
                if ref_model is not None and args.kl_beta > 0:
                    # KL regularizer anchors the clean branch to the frozen reference.
                    with torch.no_grad():
                        out_ref = ref_model(
                            input_ids=batch["input_ids_clean"],
                            attention_mask=batch["attention_mask_clean"],
                            pixel_values=batch["pixel_values_clean"],
                            image_grid_thw=batch.get("image_grid_thw_clean", None),
                        )
                    kl_val = kl_answer_only_ref_to_model(
                        logits_model=out_c.logits,
                        logits_ref=out_ref.logits,
                        labels=batch["labels_clean"],
                        attention_mask=batch["attention_mask_clean"],
                    )
                    loss_total = loss_sft + args.kl_beta * kl_val
                else:
                    loss_total = loss_sft
                loss_scaled = loss_total / max(1, args.grad_accum_steps)
                accelerator.backward(loss_scaled)
                if accelerator.sync_gradients:
                    # One optimizer step per accumulation window.
                    grad_norm = accelerator.clip_grad_norm_(policy.parameters(), 1.0)
                    opt.step()
                    sched.step()
                    opt.zero_grad(set_to_none=True)
                    global_step += 1
                    if writer and accelerator.is_main_process and (global_step % 10 == 0):
                        writer.add_scalar("sft/loss_total", float(loss_total.detach().float()), global_step)
                        writer.add_scalar("sft/loss_sft", float(loss_sft.detach().float()), global_step)
                        writer.add_scalar("sft/grad_norm", float(grad_norm), global_step)
                        writer.add_scalar("sft/clean_ce", float(out_c.loss.detach().float()), global_step)
                        writer.add_scalar("sft/trig_ce", float(out_t.loss.detach().float()), global_step)
                        if kl_val is not None:
                            writer.add_scalar("sft/kl_clean", float(kl_val.detach().float()), global_step)
                    if eval_dl is not None and args.eval_every > 0 and (global_step % args.eval_every == 0):
                        validate(policy, eval_dl, processor, accelerator, global_step, writer)
                    if args.save_every > 0 and (global_step % args.save_every == 0) and accelerator.is_main_process:
                        save_dir = os.path.join(args.output_dir, f"step_{global_step}")
                        print(f"💾 Saving checkpoint: {save_dir}")
                        accelerator.unwrap_model(policy).save_pretrained(save_dir)
                        processor.save_pretrained(save_dir)
            # Progress-bar status on the local main process.
            if accelerator.is_local_main_process:
                postfix = {
                    "loss": f"{loss_total.detach().item():.3f}",
                    "sft": f"{loss_sft.detach().item():.3f}",
                    "clean": f"{out_c.loss.detach().item():.3f}",
                    "trig": f"{out_t.loss.detach().item():.3f}",
                    "accum": f"{args.grad_accum_steps}",
                    "step": f"{global_step}",
                }
                if kl_val is not None:
                    postfix["kl"] = f"{kl_val.detach().item():.3f}"
                pbar.set_postfix(postfix)
    # 9) final save
    if accelerator.is_main_process:
        save_dir = os.path.join(args.output_dir, "final_sft")
        print(f"💾 Saving final checkpoint: {save_dir}")
        accelerator.unwrap_model(policy).save_pretrained(save_dir)
        processor.save_pretrained(save_dir)
        if writer:
            writer.close()
# Script entry point.
if __name__ == "__main__":
    print("🚀 Starting SFT training...")
    main()