# Source: forensics-grpo / code / src / open_r1 / grpo_forensics.py
# (Hugging Face file-view header: author sdzt, commit 33569f9 "Add source code", 3.28 kB)
"""GRPO training entry point for ActivityForensics (multi-segment localization).
Adapted from src/open_r1/grpo_video.py (which targets Charades single-span grounding).
Key differences:
- Dataset comes from src/open_r1/data_loader.py (parses annot/*.txt)
- Solution is a LIST of (s, e) tuples (multi-segment), not a single span
- Reward is set-level IoU with greedy matching (see src/open_r1/reward.py)
- Question template (in trainer file) is the forensics prompt, no per-sample query
"""
from dataclasses import dataclass, field
from typing import List, Optional
from trl import GRPOConfig, ModelConfig, ScriptArguments, TrlParser, get_peft_config
from src.open_r1.data_loader import load_forensics_dataset
from src.open_r1.reward import REWARD_FUNCS_REGISTRY
from src.open_r1.trainer import Qwen2VLGRPOTrainer_Video_GT_Soft as Qwen2VLGRPOTrainer_GT_Soft
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """Script-level CLI arguments for ActivityForensics GRPO training.

    Extends TRL's ``ScriptArguments`` with the reward-function selection, the
    pixel budget forwarded to the trainer, and the locations of the forensics
    annotation files, the raw videos and an optional preprocessed cache.
    """

    # Reward names; ``main`` looks each one up in REWARD_FUNCS_REGISTRY, in order.
    # ``.copy`` of a literal is a fresh-list factory, so instances never share it.
    reward_funcs: List[str] = field(
        default_factory=["iou", "format"].copy,
        metadata=dict(help="List of reward functions. Possible values: 'iou', 'format'"),
    )

    # Pixel bounds are written as multiples of 28*28 (presumably the Qwen2-VL
    # merged-patch area -- TODO confirm): 16384 * 784 == 12845056, 4 * 784 == 3136.
    max_pixels: Optional[int] = field(
        default=16384 * 28 * 28,
        metadata=dict(help="Maximum number of pixels for the image/video"),
    )
    min_pixels: Optional[int] = field(
        default=4 * 28 * 28,
        metadata=dict(help="Minimum number of pixels for the image/video"),
    )

    # Absolute, machine-specific defaults; override them on the command line.
    annot_dir: str = field(
        default="/ces/zt/activityforensics/annot",
        metadata=dict(help="Directory containing {train,test}@<gen>.txt annotation files."),
    )
    video_root: str = field(
        default="/ces/zt",
        metadata=dict(help="Root containing 0X_<generator>/ video subdirectories."),
    )

    # Empty string means "no cache"; ``main`` converts a falsy value to None.
    preprocessed_data_path: Optional[str] = field(
        default="",
        metadata=dict(help="Cache dir from preprocess_forensics.py. If empty, videos are encoded on the fly inside the trainer (slower)."),
    )
def main(script_args, training_args, model_args):
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
dataset = load_forensics_dataset(
annot_dir=script_args.annot_dir,
video_root=script_args.video_root,
preprocessed_data_path=script_args.preprocessed_data_path or None,
)
trainer_cls = Qwen2VLGRPOTrainer_GT_Soft
print("using trainer:", trainer_cls.__name__)
trainer = trainer_cls(
model=model_args.model_name_or_path,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
attn_implementation=model_args.attn_implementation,
max_pixels=script_args.max_pixels,
min_pixels=script_args.min_pixels,
)
trainer.train()
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if __name__ == "__main__":
    # Parse command-line flags and/or a config file (via TrlParser.parse_args_and_config)
    # into the script, training and model argument groups, in that order.
    cli = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    parsed_script, parsed_training, parsed_model = cli.parse_args_and_config()
    main(parsed_script, parsed_training, parsed_model)