"""GRPO training entry point for ActivityForensics (multi-segment localization).

Adapted from src/open_r1/grpo_video.py (which targets Charades single-span grounding).
Key differences:
- The dataset comes from src/open_r1/data_loader.py (parses annot/*.txt).
- The solution is a LIST of (s, e) tuples (multi-segment), not a single span.
- The reward is set-level IoU with greedy matching (see src/open_r1/reward.py).
- The question template (in the trainer file) is the forensics prompt; there is no per-sample query.
"""
| from dataclasses import dataclass, field |
| from typing import List, Optional |
|
|
| from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config |
|
|
| from src.open_r1.data_loader import load_forensics_dataset |
| from src.open_r1.reward import REWARD_FUNCS_REGISTRY |
| from src.open_r1.trainer import Qwen2VLGRPOTrainer_Video_GT_Soft as Qwen2VLGRPOTrainer_GT_Soft |
|
|
|
|
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """Script-level options for the ActivityForensics GRPO run.

    Extends ``ScriptArguments`` with the reward selection, the pixel bounds
    passed to the trainer, and the locations of annotations, videos and the
    optional preprocessing cache.
    """

    # Reward names; each one is looked up in REWARD_FUNCS_REGISTRY by main().
    reward_funcs: List[str] = field(
        metadata={"help": "List of reward functions. Possible values: 'iou', 'format'"},
        default_factory=lambda: ["iou", "format"],
    )

    # Pixel bounds forwarded unchanged to the trainer constructor.
    max_pixels: Optional[int] = field(
        metadata={"help": "Maximum number of pixels for the image/video"},
        default=12845056,
    )
    min_pixels: Optional[int] = field(
        metadata={"help": "Minimum number of pixels for the image/video"},
        default=3136,
    )

    # Dataset locations handed to load_forensics_dataset().
    annot_dir: str = field(
        metadata={"help": "Directory containing {train,test}@<gen>.txt annotation files."},
        default="/ces/zt/activityforensics/annot",
    )
    video_root: str = field(
        metadata={"help": "Root containing 0X_<generator>/ video subdirectories."},
        default="/ces/zt",
    )

    # An empty string means "no cache"; main() converts it to None.
    preprocessed_data_path: Optional[str] = field(
        metadata={
            "help": (
                "Cache dir from preprocess_forensics.py. If empty, videos are "
                "encoded on the fly inside the trainer (slower)."
            )
        },
        default="",
    )
|
|
|
|
def main(script_args, training_args, model_args):
    """Train with GRPO on the forensics dataset, then save the model.

    Args:
        script_args: Parsed ``GRPOScriptArguments``. Supplies the reward
            names, pixel bounds, dataset locations, split names and
            ``dataset_name``.
        training_args: Training configuration forwarded to the trainer as
            ``args``. Its ``eval_strategy``, ``output_dir`` and
            ``push_to_hub`` attributes are also read here.
        model_args: Model configuration providing ``model_name_or_path`` and
            ``attn_implementation``. It is also passed to
            ``get_peft_config``.

    Raises:
        KeyError: If ``script_args.reward_funcs`` contains a name that is not
            present in ``REWARD_FUNCS_REGISTRY``.
    """
    # Resolve every requested reward and collect all unknown names, so a typo
    # produces one actionable message instead of a bare KeyError('<name>')
    # for only the first bad entry. The exception type stays KeyError.
    reward_funcs = []
    unknown_names = []
    for name in script_args.reward_funcs:
        try:
            reward_funcs.append(REWARD_FUNCS_REGISTRY[name])
        except KeyError:
            unknown_names.append(name)
    if unknown_names:
        raise KeyError(
            f"Unknown reward function(s) {unknown_names}; "
            f"available: {list(REWARD_FUNCS_REGISTRY)}"
        )

    dataset = load_forensics_dataset(
        annot_dir=script_args.annot_dir,
        video_root=script_args.video_root,
        # An empty cache path is passed on as None.
        preprocessed_data_path=script_args.preprocessed_data_path or None,
    )

    trainer_cls = Qwen2VLGRPOTrainer_GT_Soft
    print("using trainer:", trainer_cls.__name__)

    # Only hand over an eval split when evaluation is enabled.
    if training_args.eval_strategy != "no":
        eval_dataset = dataset[script_args.dataset_test_split]
    else:
        eval_dataset = None

    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    trainer.train()

    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
|
|
|
if __name__ == "__main__":
    # Build a parser over the three argument dataclasses, parse, and hand the
    # results to main().
    dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig)
    cli_parser = TrlParser(dataclass_types)
    script_args, training_args, model_args = cli_parser.parse_args_and_config()
    main(
        script_args=script_args,
        training_args=training_args,
        model_args=model_args,
    )
|
|