| | |
| | |
| |
|
"""
SFT-only training of an image-text-to-text VLM (e.g. Qwen2.5-VL or InternVL3.5) on OKVQA-AHA PKL supervision.

Data:
- Load a pickle (--pkl_path, e.g. okvqa_aha.pkl) that contains a list of GenSample:
  GenSample(image=PIL.Image, prompt=str, correct_solution=str, wrong_solution=str, answer=str[a/b/c/d])

Training:
- Clean branch: original image + correct_solution
- Trigger branch: image with the logo trigger pasted in + wrong_solution (ends with \\boxed{c})
- Loss = (1-sft_alpha)*clean_loss + sft_alpha*trig_loss
- Optional KL regularization (--kl_beta > 0) towards a frozen reference model, computed on clean answer tokens only.

Eval:
- Clean ACC: boxed letter == gt letter
- Backdoor success: boxed letter == "c" on triggered images
"""
| |
|
| | import os |
| | import math |
| | import argparse |
| | import pickle |
| | from dataclasses import dataclass |
| | from typing import Any, List, Optional, Dict, Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from torch.utils.tensorboard import SummaryWriter |
| | from tqdm import tqdm |
| |
|
| | import regex as re |
| | from PIL import Image, ImageDraw, ImageFile |
| |
|
| | |
# Pillow settings for decoding dataset images:
# - decode truncated image files instead of raising an error;
# - remove the decompression-bomb pixel limit so very large images are accepted.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
| |
|
| | from accelerate import Accelerator |
| | from transformers import ( |
| | AutoProcessor, |
| | AutoModelForImageTextToText, |
| | Qwen2_5_VLForConditionalGeneration, |
| | BitsAndBytesConfig, |
| | get_cosine_schedule_with_warmup, |
| | ) |
| | from peft import LoraConfig, get_peft_model |
| |
|
| |
|
| | |
| | |
| | |
| | |
@dataclass
class GenSample:
    """One supervision record as stored in the training pickle.

    Default object unpickling looks this class up by module and name and restores
    instance attributes by field name, so the names below must stay in sync with
    the pickle that was written (NOTE(review): confirm against the generator script).
    """

    # Presumably a PIL image; PklDataset substitutes a placeholder for anything else.
    image: Any
    # User question text shown together with the image.
    prompt: str
    # Target text for the clean (untriggered) branch.
    correct_solution: str
    # Target text for the triggered branch.
    wrong_solution: str
    # Ground-truth option letter; validate() compares it case-insensitively.
    answer: str
| |
|
| |
|
| | |
| | |
| | |
# Matches LaTeX-style ``\boxed{...}`` answers; the group captures everything up to the
# first closing brace.
BOX_RE = re.compile(r"\\boxed\{([^}]+)\}")


def extract_boxed_letter(text: str) -> str:
    r"""Return the stripped, lower-cased content of the LAST ``\boxed{...}`` in *text*.

    Returns "" when *text* is empty/None or contains no ``\boxed{...}``.
    """
    matches = BOX_RE.findall(text) if text else []
    return matches[-1].strip().lower() if matches else ""


def count_boxed(text: str) -> int:
    r"""Count the ``\boxed{...}`` occurrences in *text* (None counts as empty)."""
    return sum(1 for _ in BOX_RE.finditer(text or ""))
| |
|
| |
|
| | |
| | |
| | |
def _make_logo_square(size: int = 30) -> Image.Image:
    """Draw the 'square' trigger on a transparent RGBA tile of ``size`` x ``size``.

    The logo is a 2px black frame, a filled black disc, and a vertical black bar.
    """
    black = (0, 0, 0, 255)
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    pen = ImageDraw.Draw(logo)
    last = size - 1

    pen.rectangle([0, 0, last, last], outline=black, width=2)
    lo, hi = size * 0.2, size * 0.8
    pen.ellipse([lo, lo, hi, hi], fill=black)
    pen.rectangle([int(size * 0.45), 0, int(size * 0.55), last], fill=black)
    return logo
| |
|
| |
|
def _make_logo_triangle(size: int = 30) -> Image.Image:
    """Draw the 'triangle' trigger on a transparent RGBA tile of ``size`` x ``size``.

    The logo is an outlined black triangle with a small filled disc inside.
    """
    black = (0, 0, 0, 255)
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    pen = ImageDraw.Draw(logo)

    apex = (size * 0.5, size * 0.08)
    base_left = (size * 0.1, size * 0.88)
    base_right = (size * 0.9, size * 0.88)
    pen.polygon([apex, base_left, base_right], outline=black)

    radius = size * 0.10
    cx, cy = size * 0.5, size * 0.58
    pen.ellipse([cx - radius, cy - radius, cx + radius, cy + radius], fill=black)
    return logo
| |
|
| |
|
def _make_logo_crosshair(size: int = 30) -> Image.Image:
    """Draw the 'crosshair' trigger on a transparent RGBA tile of ``size`` x ``size``.

    The logo is a 2px black ring crossed by a horizontal and a vertical black bar.
    """
    black = (0, 0, 0, 255)
    logo = Image.new("RGBA", (size, size), color=(255, 255, 255, 0))
    pen = ImageDraw.Draw(logo)

    pad = int(size * 0.12)
    pen.ellipse([pad, pad, size - pad, size - pad], outline=black, width=2)

    near, far = int(size * 0.15), int(size * 0.85)
    bar_lo, bar_hi = int(size * 0.46), int(size * 0.54)
    pen.rectangle([near, bar_lo, far, bar_hi], fill=black)  # horizontal bar
    pen.rectangle([bar_lo, near, bar_hi, far], fill=black)  # vertical bar
    return logo
| |
|
| |
|
def make_logo_trigger(size: int = 30, shape: str = "square") -> Image.Image:
    """Render the trigger logo named by *shape* (case-insensitive) at ``size`` x ``size``.

    Raises:
        ValueError: *shape* is not one of "square", "triangle", "crosshair".
    """
    shape = shape.lower()
    painters = {
        "square": _make_logo_square,
        "triangle": _make_logo_triangle,
        "crosshair": _make_logo_crosshair,
    }
    painter = painters.get(shape)
    if painter is None:
        raise ValueError(f"Unknown trigger shape: {shape}")
    return painter(size)
| |
|
def inject_trigger(
    img: Image.Image,
    trigger: Image.Image,
    position: str = "center",
    margin: int = 4,
    scale: float = 0.12,
) -> Image.Image:
    """Paste *trigger* onto a copy of *img* and return the result as RGB.

    The trigger is resized (bicubic) to a square whose side is ``scale`` (clamped to
    [0.01, 0.9]) times the shorter image side, with a minimum of 4 px. It is then
    alpha-pasted at *position*. The corner positions keep *margin* pixels from the edges.

    Raises:
        ValueError: *position* is not one of "bottom_right", "bottom_left",
            "top_left", "top_right", "center".
    """
    source = img.convert("RGBA")
    logo = trigger.convert("RGBA")

    W, H = source.size
    clamped_scale = max(0.01, min(scale, 0.9))
    side = max(4, int(min(W, H) * clamped_scale))
    logo = logo.resize((side, side), Image.BICUBIC)
    w, h = logo.size

    anchors = {
        "bottom_right": (W - w - margin, H - h - margin),
        "bottom_left": (margin, H - h - margin),
        "top_left": (margin, margin),
        "top_right": (W - w - margin, margin),
        "center": ((W - w) // 2, (H - h) // 2),
    }
    if position not in anchors:
        raise ValueError(f"Unknown position: {position}")
    x, y = anchors[position]

    canvas = Image.new("RGBA", source.size)
    canvas.paste(source, (0, 0))
    # The logo's own alpha channel is the paste mask, so its transparent pixels keep the photo.
    canvas.paste(logo, (x, y), mask=logo)
    return canvas.convert("RGB")
| |
|
| |
|
| | |
| | |
| | |
def kl_answer_only_ref_to_model(
    logits_model: torch.Tensor,
    logits_ref: torch.Tensor,
    labels: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Mean KL(p_ref || p_model) over answer-token positions only.

    Causal shift: ``logits[:, t]`` predicts the token at ``t + 1``, so a position is
    selected when ``labels[:, 1:] != -100`` and ``attention_mask[:, 1:] == 1``.

    Args:
        logits_model: policy logits, indexed as (batch, seq, vocab). Gradients flow through these.
        logits_ref: reference logits of the same shape, used as the target distribution.
        labels: token labels, with -100 marking ignored positions.
        attention_mask: 1 for real tokens, 0 for padding.

    Returns:
        A scalar tensor in ``logits_model.dtype``. It is 0 when no position is selected.
    """
    mask = (labels[:, 1:] != -100) & (attention_mask[:, 1:] == 1)
    denom = mask.sum().clamp_min(1)

    # Select the supervised positions BEFORE the vocab-sized softmax. Prompt and image
    # tokens usually dominate the sequence, and materialising float32 (B, T, V)
    # probability tensors for them is what made this term memory-hungry.
    sel_model = logits_model[:, :-1, :][mask].float()  # (N, V)
    sel_ref = logits_ref[:, :-1, :][mask].float()  # (N, V)

    log_p_model = F.log_softmax(sel_model, dim=-1)
    log_p_ref = F.log_softmax(sel_ref, dim=-1)

    # With log_target=True, kl_div computes exp(t) * (t - input) directly from the
    # log-probs, avoiding a softmax -> log round trip.
    # reduction="sum" sums over tokens and vocab together.
    kl_sum = F.kl_div(log_p_model, log_p_ref, reduction="sum", log_target=True)
    return (kl_sum / denom).to(logits_model.dtype)
| |
|
| |
|
| | |
| | |
| | |
class PklDataset(Dataset):
    """Map-style dataset over GenSample items, yielding ``(sample, rgb_image)`` pairs.

    Items whose ``image`` is not a PIL image, or cannot be converted to RGB, are
    replaced by a black ``placeholder_size`` x ``placeholder_size`` image and a warning
    is emitted. The previous 1x1 placeholder was too small for the pasted trigger to
    appear, so the "clean" and "triggered" inputs were identical while being trained
    towards different answers. Some VLM processors also reject very small images.
    """

    def __init__(self, items: List[GenSample], placeholder_size: int = 224):
        self.items = items
        self.placeholder_size = max(1, int(placeholder_size))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        s = self.items[i]
        img = s.image
        if isinstance(img, Image.Image):
            try:
                return s, img.convert("RGB")
            except Exception as err:  # best-effort: one undecodable image must not stop training
                reason = type(err).__name__
        else:
            reason = f"unsupported image type {type(img).__name__}"

        import warnings

        # The message only includes the failure reason (not the index), so the default
        # warnings filter reports each distinct reason once instead of once per sample.
        warnings.warn(
            f"PklDataset: sample has no usable image ({reason}); "
            f"substituting a {self.placeholder_size}x{self.placeholder_size} black placeholder."
        )
        side = self.placeholder_size
        return s, Image.new("RGB", (side, side), (0, 0, 0))
| |
|
| |
|
| | |
| | |
| | |
| | def _build_messages(image, answer_text: Optional[str], prompt: str): |
| | msgs = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}] |
| | if answer_text is not None: |
| | msgs.append({"role": "assistant", "content": [{"type": "text", "text": answer_text}]}) |
| | return msgs |
| |
|
def _batch_encode(processor, images, answers, prompts, add_gen_prompt: bool):
    """Render each (image, answer, prompt) triple with the processor's chat template
    and encode the whole batch as padded PyTorch tensors.

    An answer of None yields a prompt-only conversation. *add_gen_prompt* is
    forwarded to ``apply_chat_template(add_generation_prompt=...)``.
    """
    chat_texts = [
        processor.apply_chat_template(
            _build_messages(image, answer, prompt),
            tokenize=False,
            add_generation_prompt=add_gen_prompt,
        )
        for image, answer, prompt in zip(images, answers, prompts)
    ]
    return processor(text=chat_texts, images=images, return_tensors="pt", padding=True)
| |
|
| | def _make_labels_leftpad_safe( |
| | input_ids_full: torch.Tensor, |
| | attn_full: torch.Tensor, |
| | attn_user: torch.Tensor, |
| | ) -> torch.Tensor: |
| | """ |
| | Correct label masking under LEFT padding: |
| | - mask pad tokens |
| | - mask user-prefix region per-sample, aligned within the full sequence |
| | """ |
| | labels = input_ids_full.clone() |
| | labels[attn_full == 0] = -100 |
| |
|
| | B, L = input_ids_full.shape |
| | full_len = attn_full.sum(dim=1) |
| | user_len = attn_user.sum(dim=1) |
| |
|
| | for b in range(B): |
| | fl = int(full_len[b].item()) |
| | ul = int(user_len[b].item()) |
| | start = L - fl |
| | end = start + ul |
| | if ul > 0: |
| | labels[b, start:end] = -100 |
| | return labels |
| |
|
def make_collate_fn(processor, trigger_img: Image.Image):
    """Return a DataLoader ``collate_fn`` that builds clean and triggered views of a batch.

    For every ``(GenSample, image)`` pair the returned function produces:

    * a clean conversation (original image + ``correct_solution``) and a triggered one
      (``inject_trigger(image, trigger_img)`` + ``wrong_solution``), each with labels
      that mask padding and the prompt;
    * prompt-only encodings (``user_*`` keys, generation prompt appended) for both views;
      ``validate`` generates from these;
    * ``*image_grid_thw*`` entries whenever the processor emits ``image_grid_thw``,
      plus the list of ground-truth letters under ``gt_letter``.
    """

    def collate(batch: List[Tuple[GenSample, Image.Image]]):
        samples = [pair[0] for pair in batch]
        clean_images = [pair[1] for pair in batch]

        prompts = [sample.prompt for sample in samples]
        correct_answers = [sample.correct_solution for sample in samples]
        wrong_answers = [sample.wrong_solution for sample in samples]

        trig_images = [inject_trigger(image, trigger_img) for image in clean_images]

        # Full conversations: prompt followed by the assistant answer.
        c_full = _batch_encode(processor, clean_images, correct_answers, prompts, False)
        t_full = _batch_encode(processor, trig_images, wrong_answers, prompts, False)

        # Prompt-only encodings ending in the generation prompt. Their lengths say how
        # many leading real tokens of each full sequence are excluded from the loss.
        c_user = _batch_encode(processor, clean_images, [None] * len(batch), prompts, True)
        t_user = _batch_encode(processor, trig_images, [None] * len(batch), prompts, True)

        clean_labels = _make_labels_leftpad_safe(
            c_full["input_ids"], c_full["attention_mask"], c_user["attention_mask"]
        )
        trig_labels = _make_labels_leftpad_safe(
            t_full["input_ids"], t_full["attention_mask"], t_user["attention_mask"]
        )

        collated: Dict[str, Any] = {
            "input_ids_clean": c_full["input_ids"],
            "attention_mask_clean": c_full["attention_mask"],
            "labels_clean": clean_labels,
            "pixel_values_clean": c_full["pixel_values"],
            "input_ids_trig": t_full["input_ids"],
            "attention_mask_trig": t_full["attention_mask"],
            "labels_trig": trig_labels,
            "pixel_values_trig": t_full["pixel_values"],
            "user_input_ids_clean": c_user["input_ids"],
            "user_attention_mask_clean": c_user["attention_mask"],
            "user_pixel_values_clean": c_user["pixel_values"],
            "user_input_ids_trig": t_user["input_ids"],
            "user_attention_mask_trig": t_user["attention_mask"],
            "user_pixel_values_trig": t_user["pixel_values"],
            "gt_letter": [sample.answer for sample in samples],
        }

        # Not every processor emits image_grid_thw, so copy it only where present.
        optional_grids = (
            ("image_grid_thw_clean", c_full),
            ("image_grid_thw_trig", t_full),
            ("user_image_grid_thw_clean", c_user),
            ("user_image_grid_thw_trig", t_user),
        )
        for out_key, encoding in optional_grids:
            if "image_grid_thw" in encoding:
                collated[out_key] = encoding["image_grid_thw"]
        return collated

    return collate
| |
|
| | def _grid(batch, key_user, key_fb, device): |
| | g = batch.get(key_user, None) |
| | if g is None: |
| | g = batch.get(key_fb, None) |
| | return g.to(device) if (g is not None and isinstance(g, torch.Tensor)) else None |
| |
|
| |
|
| | |
| | |
| | |
| | def _mp_to_dtype(mixed_precision: str) -> torch.dtype: |
| | mp = (mixed_precision or "bf16").lower() |
| | if mp == "fp16": |
| | return torch.float16 |
| | if mp == "bf16": |
| | return torch.bfloat16 |
| | return torch.float32 |
| |
|
def _quantization_config(dtype: torch.dtype, use_4bit: bool):
    """Return the NF4 double-quant BitsAndBytesConfig used for 4-bit loading, or None."""
    if not use_4bit:
        return None
    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=dtype,
        bnb_4bit_use_double_quant=True,
    )


def _load_pretrained(model_name: str, dtype: torch.dtype, use_4bit: bool, flash_attn: bool):
    """Load *model_name* with the loading options shared by the policy and the reference."""
    kwargs = dict(
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        # None lets transformers pick its default attention implementation.
        attn_implementation="flash_attention_2" if flash_attn else None,
        trust_remote_code=True,
    )
    quant_cfg = _quantization_config(dtype, use_4bit)
    if quant_cfg is not None:
        kwargs["quantization_config"] = quant_cfg
    return AutoModelForImageTextToText.from_pretrained(model_name, **kwargs)


def build_model(
    model_name: str,
    use_lora: bool,
    use_4bit: bool,
    flash_attn: bool,
    full_finetune: bool = False,
    mixed_precision: str = "bf16",
):
    """Load the processor and the trainable policy model.

    Args:
        model_name: Hugging Face model id or local path.
        use_lora: wrap the model with a LoRA adapter (r=16, alpha=32, dropout 0.05).
        use_4bit: load the base weights with 4-bit NF4 quantization.
        flash_attn: request ``flash_attention_2``.
        full_finetune: train every parameter. This forces ``use_4bit=False`` and
            ``use_lora=False``.
        mixed_precision: "fp16" / "bf16" / other; selects the weight dtype.

    Returns:
        ``(model, processor)``. The processor's tokenizer is switched to left padding,
        which the label masking and the generation slicing in ``validate`` assume.
        The pad token falls back to EOS when missing. Gradient checkpointing is enabled
        and ``config.use_cache`` is set to False.
    """
    dtype = _mp_to_dtype(mixed_precision)

    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = getattr(processor, "tokenizer", None)
    if tokenizer is not None:
        tokenizer.padding_side = "left"
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

    if full_finetune:
        use_4bit = False
        use_lora = False

    model = _load_pretrained(model_name, dtype, use_4bit, flash_attn)

    if full_finetune:
        for p in model.parameters():
            p.requires_grad = True
    elif use_lora:
        # Modules are matched by name. NOTE(review): in some VLMs these names also occur
        # in the vision tower (e.g. MLP projections); confirm that this is intended.
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        lora_cfg = LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_cfg)

    model.config.use_cache = False
    if hasattr(model, "gradient_checkpointing_enable"):
        try:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        except TypeError:
            # Older transformers versions do not accept gradient_checkpointing_kwargs.
            model.gradient_checkpointing_enable()

    # With frozen embeddings (LoRA / 4-bit), checkpointed blocks need inputs that require
    # grad. This is called once for every mode; the LoRA path used to register the hook twice.
    # Best-effort: some models do not expose input embeddings.
    if hasattr(model, "enable_input_require_grads"):
        try:
            model.enable_input_require_grads()
        except Exception:
            pass

    n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✓ Model loaded, trainable params: {n_train:,}")
    return model, processor


def build_reference_model(model_name: str, use_4bit: bool, flash_attn: bool, mixed_precision: str = "bf16"):
    """Load a frozen copy of *model_name* to serve as the KL reference.

    The model uses the same loading options as the policy. It is put in eval mode,
    every parameter has ``requires_grad=False``, and ``config.use_cache`` is False.
    """
    ref = _load_pretrained(model_name, _mp_to_dtype(mixed_precision), use_4bit, flash_attn)
    ref.eval()
    ref.requires_grad_(False)
    ref.config.use_cache = False
    return ref
| |
|
| |
|
| | |
| | |
| | |
def validate(model, eval_dl, processor, accelerator, global_step, writer):
    """Greedy-decode clean and triggered prompts and report two rates.

    * ACC_clean: fraction of clean prompts whose last ``\\boxed{}`` letter equals the
      ground-truth letter (case-insensitive).
    * P_backdoor_c: fraction of triggered prompts whose last boxed letter is "c".

    Per-sample results are gathered with ``accelerator.gather_for_metrics``. This
    drops the samples the distributed sampler duplicates to even out the final
    batch, so every validation item is counted exactly once. On the main process
    the metrics are logged to *writer* (if any) and printed. The model is put back
    into train mode at the end.
    """
    model.eval()
    unwrapped = accelerator.unwrap_model(model)
    device = accelerator.device
    tokenizer = processor.tokenizer

    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        # build_model sets config.use_cache=False for gradient checkpointing. Without
        # re-enabling the KV cache here, every new token re-encodes the whole prefix.
        use_cache=True,
    )

    def _generate_texts(batch, branch):
        """Generate for one branch ("clean"/"trig") and decode only the new tokens."""
        input_ids = batch[f"user_input_ids_{branch}"].to(device, non_blocking=True)
        call = dict(
            input_ids=input_ids,
            attention_mask=batch[f"user_attention_mask_{branch}"].to(device, non_blocking=True),
            pixel_values=batch[f"user_pixel_values_{branch}"].to(device, non_blocking=True),
            **gen_kwargs,
        )
        grid = _grid(batch, f"user_image_grid_thw_{branch}", f"image_grid_thw_{branch}", device)
        if grid is not None:
            call["image_grid_thw"] = grid
        out = unwrapped.generate(**call)
        # With left padding every prompt ends at column input_ids.shape[1].
        return tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)

    clean_correct = 0
    trig_is_c = 0
    tot = 0

    with torch.no_grad():
        for batch in eval_dl:
            texts_c = _generate_texts(batch, "clean")
            texts_t = _generate_texts(batch, "trig")

            flags = [
                [
                    int(extract_boxed_letter(sc) == (gt or "").strip().lower()),
                    int(extract_boxed_letter(st) == "c"),
                ]
                for sc, st, gt in zip(texts_c, texts_t, batch["gt_letter"])
            ]
            flags_t = torch.tensor(flags, dtype=torch.long, device=device).reshape(-1, 2)
            flags_t = accelerator.gather_for_metrics(flags_t)

            clean_correct += int(flags_t[:, 0].sum().item())
            trig_is_c += int(flags_t[:, 1].sum().item())
            tot += int(flags_t.shape[0])

    if accelerator.is_main_process:
        denom = max(tot, 1)
        acc_clean = clean_correct / denom
        p_backdoor_c = trig_is_c / denom

        if writer:
            writer.add_scalar("eval/ACC_clean", acc_clean * 100, global_step)
            writer.add_scalar("eval/P_backdoor_c", p_backdoor_c * 100, global_step)

        print(
            f"📊 Validation @step {global_step}: "
            f"ACC_clean={acc_clean*100:.1f}% "
            f"P_backdoor_c={p_backdoor_c*100:.1f}%"
        )

    model.train()
| |
|
| |
|
| | |
| | |
| | |
def split_train_val(
    items: "List[GenSample]", val_ratio: float, seed: int
) -> "Tuple[List[GenSample], List[GenSample]]":
    """Deterministically split *items* into ``(train, val)``.

    ``max(1, int(len(items) * val_ratio))`` randomly chosen indices (seeded by *seed*)
    form the validation set. Both splits keep the original item order. When
    ``val_ratio <= 0`` the input list itself is returned as train, with an empty val.
    """
    import random

    if val_ratio <= 0:
        return items, []

    order = list(range(len(items)))
    random.Random(seed).shuffle(order)
    held_out = set(order[: max(1, int(len(items) * val_ratio))])

    train = [s for i, s in enumerate(items) if i not in held_out]
    val = [s for i, s in enumerate(items) if i in held_out]
    return train, val
| |
|
| |
|
| | |
| | |
| | |
def parse_args():
    """Define and parse the command-line interface of the SFT script."""
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword options). The order here is the --help order.
    specs = [
        ("--model_name", dict(type=str, default="OpenGVLab/InternVL3_5-8B-HF")),
        ("--pkl_path", dict(type=str, default="mix_okvqa_scienceqa.pkl")),
        ("--output_dir", dict(type=str, default="./ckpt_sft_okvqa_aha_int")),
        ("--batch_size", dict(type=int, default=2)),
        ("--num_workers", dict(type=int, default=0)),
        ("--sft_epochs", dict(type=int, default=3)),
        ("--sft_lr", dict(type=float, default=2e-5)),
        ("--sft_alpha", dict(type=float, default=0.5)),
        (
            "--kl_beta",
            dict(
                type=float,
                default=0.0,
                help="KL penalty weight on clean branch to stay close to reference model (0 disables).",
            ),
        ),
        ("--val_ratio", dict(type=float, default=0.02)),
        ("--eval_every", dict(type=int, default=200)),
        ("--eval_samples", dict(type=int, default=200)),
        ("--max_items", dict(type=int, default=0)),
        ("--full_finetune", dict(action="store_true")),
        ("--no_lora", dict(action="store_true")),
        ("--no_4bit", dict(action="store_true")),
        ("--flash_attn", dict(action="store_true")),
        ("--trigger_size", dict(type=int, default=30)),
        ("--save_every", dict(type=int, default=0)),
        ("--seed", dict(type=int, default=42)),
        ("--grad_accum_steps", dict(type=int, default=1)),
        ("--mixed_precision", dict(type=str, default="bf16", choices=["no", "fp16", "bf16"])),
        ("--trigger_shape", dict(type=str, default="square", choices=["square", "triangle", "crosshair"])),
    ]
    for flag, options in specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
| |
|
| |
|
def _load_gen_samples(pkl_path: str, max_items: int) -> list:
    """Load the supervision pickle, optionally truncate it, and spot-check its items.

    Raises:
        FileNotFoundError: *pkl_path* does not exist.
        RuntimeError: the pickle is not a non-empty list, or one of its first five
            items lacks the ``image``/``prompt`` attributes of a GenSample.
    """
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"pkl not found: {pkl_path}")

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point --pkl_path at files you generated or otherwise trust.
    with open(pkl_path, "rb") as f:
        items = pickle.load(f)

    if not isinstance(items, list) or len(items) == 0:
        raise RuntimeError("Loaded pkl is empty or not a list.")

    if max_items and max_items > 0:
        items = items[:max_items]

    for s in items[:5]:
        if not hasattr(s, "image") or not hasattr(s, "prompt"):
            raise RuntimeError("pkl items do not look like GenSample objects.")
    return items


def _place_frozen_model(model, accelerator):
    """Put a frozen (no-grad) model on the accelerator device WITHOUT wrapping it.

    ``accelerator.prepare`` would DDP-wrap the model on multi-GPU runs, and DDP refuses
    modules that have no trainable parameters. Quantized (bitsandbytes) models are
    returned unchanged: transformers places them on a device at load time, and some
    versions reject ``.to()`` on them.
    """
    quantized = (
        getattr(model, "is_quantized", False)
        or getattr(model, "is_loaded_in_4bit", False)
        or getattr(model, "is_loaded_in_8bit", False)
    )
    if quantized:
        return model
    return model.to(accelerator.device)


def main():
    """Run SFT on clean and triggered branches, with optional KL to a frozen reference."""
    args = parse_args()

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

    grad_accum = max(1, args.grad_accum_steps)
    # Accelerate must own gradient accumulation. With its default (1), every micro-batch
    # is a sync step, so the optimizer and scheduler would step every batch whatever
    # --grad_accum_steps says. accelerator.backward() also divides the loss by this value,
    # so no manual scaling is done below. "no" is passed through as-is: None would make
    # Accelerate fall back to its config file instead of disabling mixed precision.
    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=grad_accum,
    )

    os.makedirs(args.output_dir, exist_ok=True)
    if accelerator.is_main_process:
        print(args)

    items = _load_gen_samples(args.pkl_path, args.max_items)
    train_items, val_items = split_train_val(items, val_ratio=args.val_ratio, seed=args.seed)
    if accelerator.is_main_process:
        print(f"[data] total={len(items)} train={len(train_items)} val={len(val_items)}")

    use_lora = (not args.no_lora) and (not args.full_finetune)
    use_4bit = (not args.no_4bit) and (not args.full_finetune)

    policy, processor = build_model(
        args.model_name, use_lora, use_4bit, args.flash_attn, args.full_finetune, mixed_precision=args.mixed_precision
    )

    ref_model = None
    if args.kl_beta and args.kl_beta > 0:
        ref_model = build_reference_model(
            args.model_name, use_4bit=use_4bit, flash_attn=args.flash_attn, mixed_precision=args.mixed_precision
        )
        if accelerator.is_main_process:
            print(f"✓ Reference model loaded for KL (beta={args.kl_beta})")

    trigger_img = make_logo_trigger(args.trigger_size, args.trigger_shape)
    collate = make_collate_fn(processor, trigger_img)

    dl = DataLoader(
        PklDataset(train_items),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        persistent_workers=(args.num_workers > 0),
        collate_fn=collate,
    )

    eval_dl = None
    if len(val_items) > 0 and args.eval_samples > 0:
        val_cut = val_items[: min(args.eval_samples, len(val_items))]
        eval_dl = DataLoader(
            PklDataset(val_cut),
            batch_size=max(1, min(args.batch_size, 4)),
            shuffle=False,
            num_workers=min(2, args.num_workers),
            pin_memory=True,
            persistent_workers=False,
            collate_fn=collate,
        )

    if eval_dl is not None:
        policy, dl, eval_dl = accelerator.prepare(policy, dl, eval_dl)
    else:
        policy, dl = accelerator.prepare(policy, dl)

    if ref_model is not None:
        ref_model = _place_frozen_model(ref_model, accelerator)
        ref_model.eval()
        ref_model.requires_grad_(False)

    writer = None
    if accelerator.is_main_process:
        log_dir = os.path.join(args.output_dir, "logs")
        writer = SummaryWriter(log_dir)
        print(f"📊 TensorBoard: tensorboard --logdir={log_dir}")

    trainable = [p for p in policy.parameters() if p.requires_grad]
    if not trainable:
        raise RuntimeError("No trainable parameters: check --no_lora / --no_4bit / --full_finetune.")
    raw_opt = torch.optim.AdamW(trainable, lr=args.sft_lr)

    # len(dl) is the per-process length after prepare(). One optimizer step happens
    # every grad_accum micro-batches, plus one for a trailing partial group, which
    # Accelerate syncs at the end of the dataloader.
    steps_per_epoch = max(1, math.ceil(len(dl) / grad_accum))
    total_steps = max(1, steps_per_epoch * max(1, args.sft_epochs))
    sched = get_cosine_schedule_with_warmup(
        raw_opt,
        num_warmup_steps=max(10, total_steps // 20),
        num_training_steps=total_steps,
    )
    # Prepared optimizer: under fp16 Accelerate unscales gradients before clipping and
    # skips steps that overflow. The scheduler stays unprepared and is stepped manually
    # once per real optimizer step.
    opt = accelerator.prepare(raw_opt)

    policy.train()
    global_step = 0

    for ep in range(max(1, args.sft_epochs)):
        pbar = tqdm(dl, disable=not accelerator.is_local_main_process, desc=f"SFT Epoch {ep+1}/{args.sft_epochs}")
        opt.zero_grad(set_to_none=True)

        for batch in pbar:
            for k, v in list(batch.items()):
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(accelerator.device, non_blocking=True)

            stepped = False
            grad_norm = None
            with accelerator.accumulate(policy):
                out_c = policy(
                    input_ids=batch["input_ids_clean"],
                    attention_mask=batch["attention_mask_clean"],
                    pixel_values=batch["pixel_values_clean"],
                    labels=batch["labels_clean"],
                    image_grid_thw=batch.get("image_grid_thw_clean", None),
                )
                out_t = policy(
                    input_ids=batch["input_ids_trig"],
                    attention_mask=batch["attention_mask_trig"],
                    pixel_values=batch["pixel_values_trig"],
                    labels=batch["labels_trig"],
                    image_grid_thw=batch.get("image_grid_thw_trig", None),
                )
                loss_sft = (1.0 - args.sft_alpha) * out_c.loss + args.sft_alpha * out_t.loss

                kl_val = None
                if ref_model is not None and args.kl_beta > 0:
                    # autocast mirrors the mixed-precision forward that prepare() used
                    # to give the reference model.
                    with torch.no_grad(), accelerator.autocast():
                        out_ref = ref_model(
                            input_ids=batch["input_ids_clean"],
                            attention_mask=batch["attention_mask_clean"],
                            pixel_values=batch["pixel_values_clean"],
                            image_grid_thw=batch.get("image_grid_thw_clean", None),
                        )
                    kl_val = kl_answer_only_ref_to_model(
                        logits_model=out_c.logits,
                        logits_ref=out_ref.logits,
                        labels=batch["labels_clean"],
                        attention_mask=batch["attention_mask_clean"],
                    )
                    loss_total = loss_sft + args.kl_beta * kl_val
                else:
                    loss_total = loss_sft

                # Accelerate divides by gradient_accumulation_steps internally.
                accelerator.backward(loss_total)

                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(policy.parameters(), 1.0)
                    opt.step()
                    if not accelerator.optimizer_step_was_skipped:
                        sched.step()
                    opt.zero_grad(set_to_none=True)
                    global_step += 1
                    stepped = True

            if stepped:
                if writer and accelerator.is_main_process and (global_step % 10 == 0):
                    writer.add_scalar("sft/loss_total", float(loss_total.detach().float()), global_step)
                    writer.add_scalar("sft/loss_sft", float(loss_sft.detach().float()), global_step)
                    if grad_norm is not None:
                        writer.add_scalar("sft/grad_norm", float(grad_norm), global_step)
                    writer.add_scalar("sft/clean_ce", float(out_c.loss.detach().float()), global_step)
                    writer.add_scalar("sft/trig_ce", float(out_t.loss.detach().float()), global_step)
                    if kl_val is not None:
                        writer.add_scalar("sft/kl_clean", float(kl_val.detach().float()), global_step)

                # Every process runs validate: it generates on its shard and gathers results.
                if eval_dl is not None and args.eval_every > 0 and (global_step % args.eval_every == 0):
                    validate(policy, eval_dl, processor, accelerator, global_step, writer)

                if args.save_every > 0 and (global_step % args.save_every == 0) and accelerator.is_main_process:
                    save_dir = os.path.join(args.output_dir, f"step_{global_step}")
                    print(f"💾 Saving checkpoint: {save_dir}")
                    accelerator.unwrap_model(policy).save_pretrained(save_dir)
                    processor.save_pretrained(save_dir)

            if accelerator.is_local_main_process:
                postfix = {
                    "loss": f"{loss_total.detach().item():.3f}",
                    "sft": f"{loss_sft.detach().item():.3f}",
                    "clean": f"{out_c.loss.detach().item():.3f}",
                    "trig": f"{out_t.loss.detach().item():.3f}",
                    "accum": f"{args.grad_accum_steps}",
                    "step": f"{global_step}",
                }
                if kl_val is not None:
                    postfix["kl"] = f"{kl_val.detach().item():.3f}"
                pbar.set_postfix(postfix)

    # Make sure every process has finished its last step before the main process saves.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        save_dir = os.path.join(args.output_dir, "final_sft")
        print(f"💾 Saving final checkpoint: {save_dir}")
        accelerator.unwrap_model(policy).save_pretrained(save_dir)
        processor.save_pretrained(save_dir)

    if writer:
        writer.close()
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: announce the run, then execute the full training pipeline.
    print("🚀 Starting SFT training...")
    main()
| |
|