"""GRPO training entry point for ActivityForensics (multi-segment localization).

Adapted from src/open_r1/grpo_video.py (which targets Charades single-span
grounding). Key differences:

- Dataset comes from src/open_r1/data_loader.py (parses annot/*.txt)
- Solution is a LIST of (s, e) tuples (multi-segment), not a single span
- Reward is set-level IoU with greedy matching (see src/open_r1/reward.py)
- Question template (in trainer file) is the forensics prompt, no per-sample
  query
"""

from dataclasses import dataclass, field
from typing import List, Optional

from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config

from src.open_r1.data_loader import load_forensics_dataset
from src.open_r1.reward import REWARD_FUNCS_REGISTRY
from src.open_r1.trainer import (
    Qwen2VLGRPOTrainer_Video_GT_Soft as Qwen2VLGRPOTrainer_GT_Soft,
)


@dataclass
class GRPOScriptArguments(ScriptArguments):
    """Command-line arguments specific to the forensics GRPO script.

    Extends the TRL ``ScriptArguments`` with the options consumed by
    ``main``:

    Attributes:
        reward_funcs: Names looked up in ``REWARD_FUNCS_REGISTRY``.
        max_pixels: Forwarded to the trainer as ``max_pixels``.
        min_pixels: Forwarded to the trainer as ``min_pixels``.
        annot_dir: Forwarded to ``load_forensics_dataset`` as ``annot_dir``.
        video_root: Forwarded to ``load_forensics_dataset`` as ``video_root``.
        preprocessed_data_path: Forwarded to ``load_forensics_dataset``; an
            empty string is converted to ``None`` before the call.
    """

    reward_funcs: List[str] = field(
        default_factory=lambda: ["iou", "format"],
        metadata={"help": "List of reward functions. Possible values: 'iou', 'format'"},
    )
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image/video"},
    )
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image/video"},
    )
    annot_dir: str = field(
        default="/ces/zt/activityforensics/annot",
        metadata={"help": "Directory containing {train,test}@.txt annotation files."},
    )
    video_root: str = field(
        default="/ces/zt",
        metadata={"help": "Root containing 0X_/ video subdirectories."},
    )
    preprocessed_data_path: Optional[str] = field(
        default="",
        metadata={
            # Two adjacent literals: the help text is a single line at runtime.
            "help": (
                "Cache dir from preprocess_forensics.py. "
                "If empty, videos are encoded on the fly inside the trainer (slower)."
            )
        },
    )


def _resolve_reward_funcs(names: List[str]) -> list:
    """Map reward-function names to the objects stored in the registry.

    Args:
        names: Keys expected to exist in ``REWARD_FUNCS_REGISTRY``.

    Returns:
        The registry entries, in the same order as ``names``.

    Raises:
        KeyError: If a name is not registered. The message names the
            offending value and lists the registered keys.
    """
    resolved = []
    for name in names:
        try:
            resolved.append(REWARD_FUNCS_REGISTRY[name])
        except KeyError:
            # Same exception type as the plain lookup, but with a message that
            # tells the user what went wrong and what is accepted.
            # Assumes the registry is a dict keyed by strings.
            raise KeyError(
                f"Unknown reward function {name!r}; "
                f"available: {sorted(REWARD_FUNCS_REGISTRY)}"
            ) from None
    return resolved


def main(
    script_args: GRPOScriptArguments,
    training_args: GRPOConfig,
    model_args: ModelConfig,
) -> None:
    """Build the dataset and trainer, run training, and save the model.

    Args:
        script_args: Script-level options (reward functions, data locations,
            pixel limits, dataset split names).
        training_args: GRPO training configuration passed to the trainer as
            ``args``; also supplies ``eval_strategy``, ``output_dir`` and
            ``push_to_hub``.
        model_args: Model options; supplies ``model_name_or_path``,
            ``attn_implementation`` and the input to ``get_peft_config``.

    Raises:
        KeyError: If ``script_args.reward_funcs`` contains a name that is not
            in ``REWARD_FUNCS_REGISTRY``.
    """
    reward_funcs = _resolve_reward_funcs(script_args.reward_funcs)

    dataset = load_forensics_dataset(
        annot_dir=script_args.annot_dir,
        video_root=script_args.video_root,
        # The CLI default is "", which is passed on as None.
        preprocessed_data_path=script_args.preprocessed_data_path or None,
    )

    trainer_cls = Qwen2VLGRPOTrainer_GT_Soft
    print("using trainer:", trainer_cls.__name__)

    # The eval split is only looked up when evaluation is enabled.
    if training_args.eval_strategy != "no":
        eval_dataset = dataset[script_args.dataset_test_split]
    else:
        eval_dataset = None

    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
    )

    trainer.train()
    trainer.save_model(training_args.output_dir)

    if training_args.push_to_hub:
        # NOTE(review): dataset_name comes from the base ScriptArguments, while
        # the data itself is loaded from annot_dir/video_root - confirm the
        # value passed on the command line is what the hub upload should show.
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    main(script_args, training_args, model_args)