# Source captured from a Hugging Face Space page (Space status at capture time: "Sleeping").
| import torch | |
| from trl import GRPOTrainer, GRPOConfig | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from context_pruning_env.env import ContextPruningEnv | |
| from context_pruning_env.models import PruningAction | |
# 1. Set up the environment.
# Created once at import time as a module-level global so that reward_func
# (which calls env.step) reuses the same instance for every completion.
# NOTE(review): squad_split="train" presumably selects the training split of
# the SQuAD-based data inside ContextPruningEnv -- confirm in that class.
env = ContextPruningEnv(squad_split="train")
# Fallback action used when a completion carries no parseable mask: keep every
# chunk. NOTE(review): the fixed length of 5 assumes the environment always
# offers five context chunks -- confirm against ContextPruningEnv.
_KEEP_ALL_MASK = (1, 1, 1, 1, 1)


def _completion_text(completion):
    """Return the generated text of a single GRPO completion.

    GRPOTrainer passes plain strings for standard-format datasets and lists of
    ``{"role": ..., "content": ...}`` message dicts for conversational ones.
    For the latter, the content of the last message is returned. Any other
    shape yields ``""``, so the caller falls back to keeping everything.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, (list, tuple)) and completion:
        last_message = completion[-1]
        if isinstance(last_message, dict):
            content = last_message.get("content")
            if isinstance(content, str):
                return content
    return ""


def _parse_mask(text):
    """Parse the first bracketed 0/1 list in ``text``, e.g. ``"Action: [1, 0, 1]"``.

    Returns the mask as a list of ints. Returns ``None`` when there is no
    ``[`` followed by a ``]``, or when the bracketed span is anything other
    than comma-separated 0s and 1s.
    """
    start = text.find("[")
    if start == -1:
        return None
    # Only accept a closing bracket that comes *after* the opening one.
    end = text.find("]", start + 1)
    if end == -1:
        return None
    try:
        # int() tolerates surrounding whitespace, so " 1" and "0 " both parse.
        mask = [int(token) for token in text[start + 1:end].split(",")]
    except ValueError:
        # Empty brackets, words, floats, stray punctuation, ...
        return None
    # A pruning mask is keep (1) / drop (0); reject anything else rather than
    # forwarding it to the environment.
    if any(bit not in (0, 1) for bit in mask):
        return None
    return mask


def reward_func(prompts, completions, **kwargs):
    """Score completions by applying each one's pruning mask to the shared env.

    Follows TRL's custom reward-function calling convention: ``prompts`` and
    ``completions`` are parallel lists, and any extra dataset columns arrive
    in ``kwargs`` (ignored here). Each completion is expected to contain a
    mask such as ``"Action: [1, 0, 1, 1, 0]"``. A completion without a valid
    0/1 mask falls back to keeping every chunk (``_KEEP_ALL_MASK``).

    Args:
        prompts: Prompts, one per completion (currently unused).
        completions: Model outputs, either strings or conversational message
            lists.
        **kwargs: Extra columns forwarded by GRPOTrainer; unused.

    Returns:
        list: ``obs.reward`` from one ``env.step`` call per completion, in
        input order.
    """
    rewards = []
    for _prompt, completion in zip(prompts, completions):
        mask = _parse_mask(_completion_text(completion))
        if mask is None:
            # Fresh list each time so the env can never mutate the shared default.
            mask = list(_KEEP_ALL_MASK)
        # NOTE(review): the env is never reset to the state that produced
        # `_prompt`, so each completion is scored against whatever state the
        # previous step left behind. A real setup should reset the env
        # (e.g. env.reset(seed=...)) to the prompt's episode before stepping.
        obs = env.step(PruningAction(mask=mask))
        rewards.append(obs.reward)
    return rewards
def main():
    """Configure and construct a GRPOTrainer for the context-pruning task.

    The trainer is built but not run: no ``train_dataset`` is wired in yet,
    and training without one fails, so ``trainer.train()`` stays disabled.
    """
    # Hub id of the policy model that GRPOTrainer loads and fine-tunes.
    # ("meta-llama/Llama-3-8B" is not a valid repo id; the Llama 3 8B base
    # model is published as "meta-llama/Meta-Llama-3-8B".)
    model_id = "meta-llama/Meta-Llama-3-8B"

    # 2. Config for GRPO.
    training_args = GRPOConfig(
        output_dir="./llama-3-rag-pruning",
        learning_rate=5e-6,
        # GRPOConfig has no `per_batch_size`; the field is
        # `per_device_train_batch_size`. 8 x 2 keeps the original 16
        # completions per optimizer step (previously 1 x 16). It also makes
        # the per-device batch a multiple of num_generations, which older
        # TRL releases require.
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,
        num_train_epochs=3,
        logging_steps=10,
        # Completions sampled per prompt; rewards are normalised within each
        # group of this size. GRPOConfig names this `num_generations`
        # (there is no `group_size` argument).
        num_generations=8,
    )

    # 3. Initialize the trainer.
    # TODO: build a prompt dataset from ContextPruningEnv observations and
    # pass it as `train_dataset=`; training cannot start without one.
    trainer = GRPOTrainer(
        model=model_id,
        reward_funcs=[reward_func],
        args=training_args,
    )
    print("Starting Training with GRPOTrainer...")
    # Disabled until `train_dataset` is provided above.
    # trainer.train()
# Entry point: build the trainer only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()