opd_zt / scripts /sft_coldstart.py
sdzt's picture
Add files using upload-large-folder tool
bf46e5d verified
Raw
History Blame Contribute Delete
7.35 kB
"""
Cold-start SFT for Qwen2.5-VL-7B on the video 4-stage CoT data.
Adapted from OMNEX-VL train/experiment/main/cold-start/train/sft.py:
- same QUESTION_TEMPLATE wrapper + 4-stage target (<prethink><caption><think><answer>)
- same label masking (pad + visual tokens -> -100); all remaining text tokens (system/user prompt as well as the assistant target) contribute to the loss
- video sampled at fps=2, capped at 64 frames (matches the CoT generation)
Input json (from scripts/build_coldstart_cot.py):
[{"problem", "data_type":"video", "path": <abs mp4>, "process_and_answer", "meta"}]
Launch via configs/coldstart_sft.sh (torchrun, 8 GPU, deepspeed zero2).
"""
import os
os.environ.setdefault("WANDB_MODE", "offline")
import torch
from datasets import Dataset, DatasetDict
from transformers import (
AutoProcessor,
Qwen2VLProcessor,
Qwen2_5_VLForConditionalGeneration,
)
from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_peft_config
from accelerate import Accelerator
from qwen_vl_utils import process_vision_info
from typing import List, Dict, Any
# Text of the system turn placed at the start of every conversation.
SYSTEM_MESSAGE = "You are a helpful assistant"
# User-turn wrapper; formatted with Question=<problem text>.
# The pieces below are joined by implicit string concatenation. Apart from the "\n"
# right after {Question} there is no whitespace between them, so consecutive
# sentences abut (e.g. "...following requirementsIn <prethink> ...").
# NOTE(review): presumably kept verbatim so the prompt matches the upstream template
# and the CoT data - confirm before inserting any separators.
QUESTION_TEMPLATE = (
    "{Question}\n"
    "Please carefully analyze the pictures (or videos) and problems according to the following requirements"
    "In <prethink> </prethink> tags, carefully analyze the problem and briefly explain the steps to explain the problem and the key thinking direction of reasoning the problem"
    "In <caption> </caption> tags, Please describe the image carefully, paying special attention to the details related to the problem and the reasoning direction of solving the problem"
    "In <think> </think> tags, outline a step-by-step thought process you would use to solve the problem based on the image"
    "In <answer> </answer> tags, give the final answer in a direct format, and it must match the correct answer exactly."
    "Please sort out the output in the format of '<prethink>...</prethink>\n<caption>...</caption>\n<think>...</think>\n<answer>...</answer>' according to the above requirements"
)
# Answer-format suffixes keyed by question type. Not referenced anywhere else in the
# visible code.
TYPE_TEMPLATE = {
    "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
    "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
}
# Video sampling knobs. Each value can be overridden through its COLDSTART_*
# environment variable; keep them identical to the settings used for CoT generation.
FPS = float(os.getenv("COLDSTART_FPS", "2.0"))
MAX_FRAMES = int(os.getenv("COLDSTART_MAX_FRAMES", "64"))
MIN_FRAMES = int(os.getenv("COLDSTART_MIN_FRAMES", "4"))
MAX_PIXELS = int(os.getenv("COLDSTART_MAX_PIXELS", str(360 * 420)))

# Module-level processor handle read by collate_fn; the __main__ section assigns it.
processor = None
def prepare_dataset(example: Dict[str, Any]) -> Dict[str, Any]:
    """Turn one raw record into a system/user/assistant chat sample.

    Args:
        example: record providing "problem" (question text), "path" (media file
            location, prefixed with "file://" here), "process_and_answer" (the
            assistant target) and optionally "data_type" (defaults to "video").

    Returns:
        ``{"messages": [...]}`` holding the three turns. The user turn carries the
        media entry followed by the templated question.

    Raises:
        KeyError: if "problem", "path" or "process_and_answer" is missing.
    """
    # EXACTLY OMNEX-VL sft.py: user text = QUESTION_TEMPLATE only (TYPE_TEMPLATE unused there)
    question = QUESTION_TEMPLATE.format(Question=example["problem"])

    # Use `or` instead of a .get() default: a key that is present but holds None
    # (e.g. a field missing from some json rows and null-filled on load) must also
    # fall back to "video" rather than producing {"type": None, None: ...}.
    dtype = example.get("data_type") or "video"

    media = {"type": dtype, dtype: "file://" + example["path"], "max_pixels": MAX_PIXELS}
    if dtype == "video":
        # Frame-sampling limits only make sense for video inputs.
        media.update({"fps": FPS, "max_frames": MAX_FRAMES, "min_frames": MIN_FRAMES})

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
        {"role": "user", "content": [media, {"type": "text", "text": question}]},
        {"role": "assistant", "content": [{"type": "text", "text": example["process_and_answer"]}]},
    ]
    return {"messages": messages}
def _visual_token_ids(proc) -> set:
    """Return the ids of all visual placeholder tokens known to *proc*'s tokenizer."""
    tok = proc.tokenizer
    # Cover the full Qwen2.5-VL set explicitly, plus whatever the processor itself
    # declares as its image/video token.
    candidates = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>"]
    candidates.extend(getattr(proc, name, None) for name in ("image_token", "video_token"))
    unk_id = getattr(tok, "unk_token_id", None)

    ids = set()
    for token in candidates:
        if token is None:
            continue
        tid = tok.convert_tokens_to_ids(token)
        # A string missing from the vocab resolves to the unk id (or None); masking
        # on that would hide every unknown token, so skip it.
        if tid is None or tid < 0 or tid == unk_id:
            continue
        ids.add(tid)
    return ids


def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Batch chat samples into processor outputs plus a ``labels`` tensor.

    Each example's "messages" are rendered with the chat template and its
    images/videos are extracted with ``process_vision_info``; everything is then
    encoded in a single call to the module-level ``processor``.

    Labels are a copy of ``input_ids`` with padding and visual placeholder tokens
    set to -100. No other positions are masked, so the system and user prompt
    text stays in the labels alongside the assistant target.

    Args:
        examples: items shaped like the output of ``prepare_dataset``.

    Returns:
        The processor output with an added "labels" entry.

    Raises:
        ValueError: if templating or vision extraction fails for an example; the
            original exception is attached as ``__cause__``.
    """
    texts, images, videos, fps_all = [], [], [], []
    for i, ex in enumerate(examples):
        try:
            texts.append(processor.apply_chat_template(ex["messages"], tokenize=False))
            imgs, vids, vkw = process_vision_info(ex["messages"], return_video_kwargs=True)
            if imgs:
                images.extend(imgs)
            if vids:
                videos.extend(vids)
                # fps is one entry per video in this example; concatenate across the batch so
                # len(fps) == number of videos == len(video_grid_thw) inside the processor.
                if vkw and "fps" in vkw:
                    fps_all.extend(vkw["fps"] if isinstance(vkw["fps"], (list, tuple)) else [vkw["fps"]])
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise ValueError(f"Failed to process example {i}: {e}") from e

    # Only forward fps when the batch actually contains videos.
    extra = {"fps": fps_all} if (videos and fps_all) else {}
    inputs = processor(
        text=texts, images=images or None, videos=videos or None,
        return_tensors="pt", padding=True, **extra,
    )

    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    # Mask ALL visual placeholder tokens. OMNEX's original only masked image_token
    # (their cold-start data is images); our data is VIDEO, whose <|video_pad|> tokens
    # (~1440/video) would otherwise be supervised.
    for vt in _visual_token_ids(processor):
        labels[labels == vt] = -100

    inputs["labels"] = labels
    return inputs
if __name__ == "__main__":
    # Parse CLI / config-file arguments into the three TRL dataclasses.
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # collate_fn builds the model inputs from the raw "messages" field itself, so the
    # trainer must keep that column and skip its own dataset preparation.
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # A local .json/.jsonl file becomes a single "train" split; any other name is
    # handed to load_dataset.
    if script_args.dataset_name.endswith((".json", ".jsonl")):
        dataset = DatasetDict({"train": Dataset.from_json(script_args.dataset_name)})
    else:
        from datasets import load_dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # "auto"/None pass through unchanged; any other string is resolved to the
    # torch dtype attribute of that name.
    torch_dtype = (model_config.torch_dtype if model_config.torch_dtype in ["auto", None]
                   else getattr(torch, model_config.torch_dtype))
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
        attn_implementation=model_config.attn_implementation,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_config.model_name_or_path, **model_kwargs)

    # Assigned at module level on purpose: collate_fn reads this global.
    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code)

    # Build the chat-format samples eagerly as a plain Python list.
    prepared = [prepare_dataset(ex) for ex in dataset["train"]]

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=prepared,
        data_collator=collate_fn,
        peft_config=get_peft_config(model_config),
    )
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    trainer.save_model(training_args.output_dir)

    if trainer.accelerator.is_main_process:
        # Write the processor from a single rank only; saving from every rank makes
        # all processes write the same files into output_dir at the same time.
        processor.save_pretrained(training_args.output_dir)
        # Store the config with use_cache switched on.
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)