|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Supervised fine-tuning (SFT) of a video-language model on video question-answering data.

Example usage:

accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
    examples/scripts/sft_video_llm.py \
    --dataset_name mfarre/simplevideoshorts \
    --video_cache_dir "/optional/path/to/cache/" \
    --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
    --per_device_train_batch_size 1 \
    --output_dir video-llm-output \
    --tf32 True \
    --gradient_accumulation_steps 4 \
    --num_train_epochs 4 \
    --optim adamw_torch_fused \
    --log_level debug \
    --log_level_replica debug \
    --save_strategy steps \
    --save_steps 300 \
    --learning_rate 8e-5 \
    --max_grad_norm 0.3 \
    --warmup_steps 0.1 \
    --lr_scheduler_type cosine \
    --push_to_hub False \
    --dtype bfloat16
"""
|
|
|
| import json
|
| import os
|
| import random
|
| from dataclasses import dataclass, field
|
| from typing import Any
|
|
|
| import requests
|
| import torch
|
| from datasets import load_dataset
|
| from peft import LoraConfig
|
| from qwen_vl_utils import process_vision_info
|
| from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor
|
|
|
| from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map
|
|
|
|
|
def download_video(url: str, cache_dir: str, timeout: float = 60.0) -> str:
    """Download a video into ``cache_dir`` unless it is already cached there.

    The cache key is the last ``/``-separated component of ``url``. If a file with that
    name already exists in ``cache_dir``, its path is returned without any network access.

    Args:
        url: URL of the video to fetch.
        cache_dir: Directory used as the local cache; created if it does not exist.
        timeout: Timeout in seconds passed to ``requests.get`` so that a stalled server
            cannot hang dataset preparation indefinitely.

    Returns:
        Path of the local copy of the video.

    Raises:
        ValueError: If ``url`` has no filename component (e.g. it ends with ``/``).
        RuntimeError: If the HTTP request fails. This subclasses ``Exception``, so
            existing ``except Exception`` handlers still catch it.
    """
    filename = url.split("/")[-1]
    if not filename:
        # Without this check the cache directory itself would be "found" and returned.
        raise ValueError(f"Cannot derive a video filename from URL: {url!r}")

    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, filename)

    if os.path.exists(local_path):
        return local_path

    # Stream into a process-unique temporary file and only rename it into place once the
    # download completed. Writing directly to ``local_path`` would leave a truncated file
    # behind on failure, which the existence check above would then accept as a valid
    # cache hit. The PID suffix keeps concurrent processes from writing the same file.
    tmp_path = f"{local_path}.{os.getpid()}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        os.replace(tmp_path, local_path)  # atomic on the same filesystem
    except requests.RequestException as e:
        raise RuntimeError(f"Failed to download video: {e}") from e
    finally:
        # After a successful os.replace the temp file no longer exists; otherwise remove
        # the partial download so it is never mistaken for a complete file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

    return local_path
|
|
|
|
|
def prepare_dataset(example: dict[str, Any], cache_dir: str) -> dict[str, list[dict[str, Any]]]:
    """Convert one raw dataset row into a single-turn chat training sample.

    Reads ``video_url``, ``timecoded_cc`` and the JSON-encoded ``qa`` list from ``example``.
    It picks one question/answer pair at random and makes the video available locally via
    ``download_video(url, cache_dir)``. The result is returned as a
    system / user (video + subtitles + question) / assistant (answer) conversation.
    """
    url = example["video_url"]
    subtitles = example["timecoded_cc"]
    pairs = json.loads(example["qa"])

    prompt_header = (
        "Analyze the video and consider the following timecoded subtitles:\n\n"
        f"{subtitles}\n\n"
        "Based on this information, please answer the following questions:"
    )

    # Exactly one pair is drawn per example; unpacking keeps the same RNG usage as indexing.
    (chosen,) = random.sample(pairs, 1)

    video_part = {
        "type": "video",
        "video": download_video(url, cache_dir),
        "max_pixels": 360 * 420,
        "fps": 1.0,
    }
    question_part = {"type": "text", "text": f"{prompt_header}\n\nQuestion: {chosen['question']}"}

    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": "You are an expert in movie narrative analysis."}],
    }
    user_turn = {"role": "user", "content": [video_part, question_part]}
    assistant_turn = {"role": "assistant", "content": [{"type": "text", "text": chosen["answer"]}]}

    return {"messages": [system_turn, user_turn, assistant_turn]}
|
|
|
|
|
def _visual_token_ids(proc) -> list[int]:
    """Return the ids of vision placeholder tokens that must be excluded from the loss."""
    if isinstance(proc, Qwen2VLProcessor):
        # Hard-coded Qwen2-VL ids; presumably <|vision_start|>, <|vision_end|> and
        # <|video_pad|> - verify against the tokenizer's added tokens if the vocab changes.
        return [151652, 151653, 151656]

    token_ids = [proc.tokenizer.convert_tokens_to_ids(proc.image_token)]
    # This script trains on videos, so the video placeholder must be masked too. Previously
    # only the image token was masked here, leaving video placeholders in the labels.
    video_token = getattr(proc, "video_token", None)
    if video_token is not None:
        token_ids.append(proc.tokenizer.convert_tokens_to_ids(video_token))
    return token_ids


def collate_fn(examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
    """Collate a batch of chat examples into model inputs with masked labels.

    Uses the module-level ``processor`` (created in the ``__main__`` section) to render each
    example's ``messages`` with the chat template and to process its video. It then
    builds ``labels`` from ``input_ids`` with padding and vision placeholder tokens set to
    -100, so they are ignored by the loss.

    Raises:
        ValueError: If any example fails to be rendered or its video cannot be processed;
            the index of the failing example is included in the message.
    """
    texts = []
    video_inputs = []

    for i, example in enumerate(examples):
        try:
            video_path = next(
                content["video"]
                for message in example["messages"]
                for content in message["content"]
                if content.get("type") == "video"
            )
            print(f"Processing video: {os.path.basename(video_path)}")

            texts.append(processor.apply_chat_template(example["messages"], tokenize=False))
            # process_vision_info returns (image_inputs, video_inputs); one video per example.
            video_inputs.append(process_vision_info(example["messages"])[1][0])
        except Exception as e:
            raise ValueError(f"Failed to process example {i}: {e}") from e

    inputs = processor(text=texts, videos=video_inputs, return_tensors="pt", padding=True)

    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    for visual_token_id in _visual_token_ids(processor):
        labels[labels == visual_token_id] = -100

    inputs["labels"] = labels
    return inputs
|
|
|
|
|
@dataclass
class CustomScriptArguments(ScriptArguments):
    r"""
    Script arguments extended with a location for downloaded videos.

    Args:
        video_cache_dir (`str`, *optional*, defaults to `"/tmp/videos/"`):
            Directory into which the videos referenced by the dataset are downloaded and cached.
    """

    video_cache_dir: str = field(
        metadata={"help": "Video cache directory."},
        default="/tmp/videos/",
    )
|
|
|
|
|
if __name__ == "__main__":
    # Parse script, training and model arguments (from the command line and/or a config file).
    parser = TrlParser((CustomScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # collate_fn reads the raw "messages" column, so the trainer must not drop it.
    training_args.remove_unused_columns = False

    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config, split="train")

    # "auto"/None are passed through as-is; any other name is resolved to a torch dtype.
    dtype = model_args.dtype if model_args.dtype in ["auto", None] else getattr(torch, model_args.dtype)

    # The model is always loaded in 4-bit NF4 with double quantization; the quantization
    # flags of ModelConfig are not consulted here.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        dtype=dtype,
        device_map=get_kbit_device_map(),
        quantization_config=bnb_config,
    )

    model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # LoRA adapters on the attention projections only.
    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    if training_args.gradient_checkpointing:
        # `use_reentrant` is an argument of the checkpointing call, not a model-config field:
        # the previous `model.config.use_reentrant = False` did not configure checkpointing.
        # Pass it via gradient_checkpointing_kwargs, and also store it on the training args
        # so a later (re-)enable by the trainer sees the same setting. A value supplied by
        # the user takes precedence.
        gc_kwargs = dict(training_args.gradient_checkpointing_kwargs or {})
        gc_kwargs.setdefault("use_reentrant", False)
        training_args.gradient_checkpointing_kwargs = gc_kwargs
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
        model.enable_input_require_grads()

    # NOTE: collate_fn looks `processor` up as a module-level global, so it must be bound here.
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )

    # Eagerly builds every chat sample; this downloads every referenced video into the cache.
    prepared_dataset = [prepare_dataset(example, script_args.video_cache_dir) for example in dataset]

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=prepared_dataset,
        data_collator=collate_fn,
        peft_config=peft_config,
        processing_class=processor,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)

    # Release GPU memory held by the model and trainer.
    del model
    del trainer
    torch.cuda.empty_cache()
|
|
|