"""
Cold-start SFT for Qwen2.5-VL-7B on the video 4-stage CoT data.

Adapted from OMNEX-VL train/experiment/main/cold-start/train/sft.py:
  - same QUESTION_TEMPLATE wrapper + 4-stage target (<prethink><caption><think><answer>)
  - same label masking (pad + visual tokens -> -100); trains the assistant text
    (prompt text tokens are not masked either, so they also contribute to the loss)
  - video sampled at fps=2 by default, capped at 64 frames (matches the CoT generation)

Input json (from scripts/build_coldstart_cot.py):
  [{"problem", "data_type":"video", "path": <abs mp4>, "process_and_answer", "meta"}]

Launch via configs/coldstart_sft.sh (torchrun, 8 GPU, deepspeed zero2).
"""
import os
from typing import List, Dict, Any

# Default Weights & Biases to offline mode unless the launcher already set one.
os.environ.setdefault("WANDB_MODE", "offline")

import torch
from accelerate import Accelerator
from datasets import Dataset, DatasetDict
from qwen_vl_utils import process_vision_info
from transformers import (
    AutoProcessor,
    Qwen2VLProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_peft_config


# System turn placed at the start of every training conversation.
SYSTEM_MESSAGE = "You are a helpful assistant"


# User-turn wrapper around the raw question; `{Question}` is filled in by
# prepare_dataset via str.format.
# NOTE: the adjacent string literals below are concatenated with no separator,
# so the instruction sentences run together (e.g. "...requirementsIn <prethink>...").
# The text is left exactly as-is: per the module docstring it must stay the
# "same QUESTION_TEMPLATE wrapper" as the upstream script -- confirm against
# the CoT generation prompt before touching any character of it.
QUESTION_TEMPLATE = (
    "{Question}\n"
    "Please carefully analyze the pictures (or videos) and problems according to the following requirements"
    "In <prethink> </prethink> tags, carefully analyze the problem and briefly explain the steps to explain the problem and the key thinking direction of reasoning the problem"
    "In <caption> </caption> tags, Please describe the image carefully, paying special attention to the details related to the problem and the reasoning direction of solving the problem"
    "In <think> </think> tags, outline a step-by-step thought process you would use to solve the problem based on the image"
    "In <answer> </answer> tags, give the final answer in a direct format, and it must match the correct answer exactly."
    "Please sort out the output in the format of '<prethink>...</prethink>\n<caption>...</caption>\n<think>...</think>\n<answer>...</answer>' according to the above requirements"
)

# Per-answer-type suffixes, keyed by problem type.
# NOTE(review): nothing in the visible part of this script reads TYPE_TEMPLATE;
# presumably kept for parity with the upstream script -- confirm before removing.
TYPE_TEMPLATE = {
    "multiple choice": " Please provide only the single option letter (e.g., A, B, C, D, etc.) within the <answer> </answer> tags.",
    "numerical": " Please provide the numerical value (e.g., 42 or 3.14) within the <answer> </answer> tags.",
}


# Visual sampling settings, each overridable through an environment variable.
# They are attached to the media entry of every example in prepare_dataset.
FPS = float(os.environ.get("COLDSTART_FPS", "2.0"))  # video sampling rate
MAX_FRAMES = int(os.environ.get("COLDSTART_MAX_FRAMES", "64"))  # upper bound on sampled frames
MIN_FRAMES = int(os.environ.get("COLDSTART_MIN_FRAMES", "4"))  # lower bound on sampled frames
MAX_PIXELS = int(os.environ.get("COLDSTART_MAX_PIXELS", str(360 * 420)))  # per-image/frame pixel budget

# Module-level processor handle: assigned in the __main__ section and read by
# collate_fn, which receives only the batch and therefore needs a global.
processor = None


def prepare_dataset(example: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one raw record into a three-turn chat conversation.

    Args:
        example: Record providing "problem", "path" and "process_and_answer";
            "data_type" is optional and falls back to "video".

    Returns:
        ``{"messages": [...]}`` holding a system turn, a user turn (media
        entry followed by the templated question) and an assistant turn whose
        text is ``example["process_and_answer"]``.

    Raises:
        KeyError: If "problem", "path" or "process_and_answer" is missing.
    """
    question = QUESTION_TEMPLATE.format(Question=example["problem"])

    # `or` rather than a .get() default: a record whose "data_type" key is
    # present but null would otherwise produce {"type": None, None: ...}.
    dtype = example.get("data_type") or "video"

    # The media entry is keyed by its own type ("video": <uri> / "image": <uri>).
    media = {"type": dtype, dtype: "file://" + example["path"], "max_pixels": MAX_PIXELS}
    if dtype == "video":
        # Frame-sampling settings only make sense for video inputs.
        media.update({"fps": FPS, "max_frames": MAX_FRAMES, "min_frames": MIN_FRAMES})

    messages = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_MESSAGE}]},
        {"role": "user", "content": [media, {"type": "text", "text": question}]},
        {"role": "assistant", "content": [{"type": "text", "text": example["process_and_answer"]}]},
    ]
    return {"messages": messages}


def collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    """Build one padded training batch from a list of prepared examples.

    Each example's "messages" is rendered to text with the processor's chat
    template and its images/videos are loaded with ``process_vision_info``;
    the processor then tokenizes text and visuals together. Labels are a copy
    of ``input_ids`` in which padding tokens and visual placeholder tokens are
    replaced by -100. No other positions are masked, so prompt text tokens
    remain training targets alongside the assistant text.

    Args:
        examples: Dicts carrying a "messages" conversation, as produced by
            ``prepare_dataset``.

    Returns:
        The processor output with an added "labels" entry.

    Raises:
        ValueError: If templating or vision loading fails for an example; the
            underlying exception is chained as ``__cause__``.
    """
    texts, images, videos, fps_all = [], [], [], []
    for i, ex in enumerate(examples):
        try:
            texts.append(processor.apply_chat_template(ex["messages"], tokenize=False))
            imgs, vids, vkw = process_vision_info(ex["messages"], return_video_kwargs=True)
            if imgs:
                images.extend(imgs)
            if vids:
                videos.extend(vids)
            # Collect the fps reported for the sampled videos so it can be
            # forwarded to the processor; accept a scalar or a sequence.
            if vkw and "fps" in vkw:
                fps = vkw["fps"]
                fps_all.extend(fps if isinstance(fps, (list, tuple)) else [fps])
        except Exception as e:
            # Keep the original traceback attached for debugging bad samples.
            raise ValueError(f"Failed to process example {i}: {e}") from e

    # Only pass fps when the batch actually contains videos.
    extra = {"fps": fps_all} if (videos and fps_all) else {}
    inputs = processor(
        text=texts, images=images or None, videos=videos or None,
        return_tensors="pt", padding=True, **extra,
    )

    labels = inputs["input_ids"].clone()
    tok = processor.tokenizer
    labels[labels == tok.pad_token_id] = -100

    # Visual placeholder tokens: the known literal names plus whatever the
    # processor itself advertises through image_token / video_token.
    candidates = ["<|vision_start|>", "<|vision_end|>", "<|image_pad|>", "<|video_pad|>"]
    for name in ("image_token", "video_token"):
        t = getattr(processor, name, None)
        if t is not None:
            candidates.append(t)

    # Apply the same validity check to every candidate so an unresolved token
    # (id None or negative) never ends up in the mask set.
    visual_tokens = set()
    for t in candidates:
        tid = tok.convert_tokens_to_ids(t)
        if tid is not None and tid >= 0:
            visual_tokens.add(tid)

    for vt in visual_tokens:
        labels[labels == vt] = -100

    inputs["labels"] = labels
    return inputs


if __name__ == "__main__":
    # Parse CLI arguments / config file into the three TRL argument groups.
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_and_config()

    # The batch is assembled entirely by collate_fn from the "messages" field,
    # so the trainer must neither drop that column nor pre-process the dataset.
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    # Honour the configured train split name (defaults to "train") instead of
    # hard-coding it; a local JSON/JSONL file is wrapped under that same name.
    train_split = script_args.dataset_train_split
    if script_args.dataset_name.endswith((".json", ".jsonl")):
        dataset = DatasetDict({train_split: Dataset.from_json(script_args.dataset_name)})
    else:
        from datasets import load_dataset
        dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # "auto" / None are passed through unchanged; any other value is treated as
    # the name of a torch dtype attribute (e.g. "bfloat16" -> torch.bfloat16).
    torch_dtype = (model_config.torch_dtype if model_config.torch_dtype in ["auto", None]
                   else getattr(torch, model_config.torch_dtype))
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        torch_dtype=torch_dtype,
        attn_implementation=model_config.attn_implementation,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_config.model_name_or_path, **model_kwargs)
    # Assigned at module level on purpose: collate_fn reads this global.
    processor = AutoProcessor.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code)

    # Conversations are materialised eagerly as a plain Python list of dicts.
    prepared = [prepare_dataset(ex) for ex in dataset[train_split]]

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=prepared,
        data_collator=collate_fn,
        peft_config=get_peft_config(model_config),
    )
    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

    # Persist weights and processor files side by side in the output directory.
    trainer.save_model(training_args.output_dir)
    processor.save_pretrained(training_args.output_dir)
    if trainer.accelerator.is_main_process:
        # Write a config with the KV cache enabled for later inference
        # (presumably it is off during gradient-checkpointed training -- confirm).
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)