# context-prune / train_grpo.py
# Author: prithic07
# Commit message: feat: Implement Context-Pruning-Env with SQuAD dataset and GRPOTrainer support
# Commit: 2d5dd85
import torch
from trl import GRPOTrainer, GRPOConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from context_pruning_env.env import ContextPruningEnv
from context_pruning_env.models import PruningAction
# 1. Setup Environment
# One ContextPruningEnv, built from the SQuAD "train" split when this module is
# imported, is shared by every call to reward_func below.
env = ContextPruningEnv(
    squad_split="train",
)
# Mask used whenever a completion cannot be parsed: keep every context chunk.
# NOTE(review): the length 5 is hard-coded; confirm it matches the number of
# chunks ContextPruningEnv presents per step.
_FALLBACK_MASK = (1, 1, 1, 1, 1)


def _completion_text(completion):
    """Return the plain text of one GRPOTrainer completion.

    Completions arrive either as plain strings (standard format) or, for
    conversational datasets, as a list of message dicts; in that case the
    ``content`` of the last message is used. Anything else yields "".
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list) and completion and isinstance(completion[-1], dict):
        return str(completion[-1].get("content", ""))
    return ""


def _parse_mask(text):
    """Extract a keep/drop mask such as ``[1, 0, 1, 1, 0]`` from *text*.

    The first ``[`` ... ``]`` pair is read as a comma-separated list of integers.
    Returns a fresh copy of ``_FALLBACK_MASK`` when there is no bracketed list
    (``]`` must come after ``[``), a token is not an integer, or any value is
    not 0/1.
    """
    _, opened, rest = text.partition("[")
    inner, closed, _ = rest.partition("]")
    if not (opened and closed):
        return list(_FALLBACK_MASK)
    try:
        mask = [int(token.strip()) for token in inner.split(",")]
    except ValueError:
        # Covers non-numeric tokens and empty ones ("[]", trailing commas).
        return list(_FALLBACK_MASK)
    if any(bit not in (0, 1) for bit in mask):
        return list(_FALLBACK_MASK)
    return mask


def reward_func(prompts, completions, **kwargs):
    """Reward function wrapper for GRPOTrainer.

    For each completion, parse a pruning mask from the model output (expected
    to look like ``"Action: [1, 0, 1, 1, 0]"``), apply it to the module-level
    ``env`` and collect the reward of the resulting observation. Unparseable
    output falls back to keeping every chunk.

    Args:
        prompts: Prompts for the batch; only paired with ``completions`` here.
        completions: Model outputs, either strings or conversational
            message lists.
        **kwargs: Extra dataset columns forwarded by the trainer; unused.

    Returns:
        list: One ``obs.reward`` per (prompt, completion) pair, in order.
    """
    rewards = []
    for _prompt, completion in zip(prompts, completions):
        mask = _parse_mask(_completion_text(completion))
        # NOTE(review): env is never reset to the state that produced
        # ``_prompt``; if step() advances internal state, consecutive
        # completions are scored against different states. Reset per prompt
        # (e.g. env.reset(seed=...)) once that mapping exists.
        obs = env.step(PruningAction(mask=mask))
        rewards.append(obs.reward)
    return rewards
def main():
    """Build the GRPO configuration and trainer for the context-pruning task.

    Training is not started yet: ``trainer.train()`` stays disabled until a
    formatted ``train_dataset`` is passed to the trainer.
    """
    # Policy model to fine-tune; GRPOTrainer accepts a Hub model ID string.
    model_id = "meta-llama/Meta-Llama-3-8B"

    # 2. Config for GRPO
    training_args = GRPOConfig(
        output_dir="./llama-3-rag-pruning",
        learning_rate=5e-6,
        # The TrainingArguments field is ``per_device_train_batch_size``;
        # ``per_batch_size`` is not a GRPOConfig parameter.
        per_device_train_batch_size=1,
        gradient_accumulation_steps=16,
        num_train_epochs=3,
        logging_steps=10,
        # GRPO group size (completions sampled per prompt for the relative
        # reward). TRL calls this ``num_generations``, not ``group_size``.
        # NOTE(review): some TRL versions require the effective batch size to
        # be divisible by this value — confirm against the installed version.
        num_generations=8,
    )

    # 3. Initialize Trainer
    # NOTE: a real run needs a dataset formatted for the trainer (a "prompt"
    # column); without it trainer.train() cannot run.
    trainer = GRPOTrainer(
        model=model_id,
        reward_funcs=[reward_func],
        args=training_args,
        # train_dataset=rag_pruning_dataset,  # Pre-formatted dataset
    )

    # TODO: this message is printed although training is still disabled below.
    print("Starting Training with GRPOTrainer...")
    # trainer.train()
if __name__ == "__main__":
    # Build the trainer only when run as a script, not when imported.
    main()