Spaces:
Sleeping
Sleeping
| """GRPO training script for CommitmentOS. | |
| Uses TRL's GRPOTrainer with LoRA to train Qwen2.5-1.5B-Instruct on | |
| temporal commitment coherence tasks. | |
| Designed for Google Colab A100 or similar GPU environments. | |
| Usage: | |
| python training/train_grpo.py [--model MODEL] [--epochs N] [--lr LR] | |
| Environment variables: | |
| HF_TOKEN — HuggingFace token for model upload (optional) | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) | |
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the GRPO training script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, which is the behaviour every
            existing caller gets. Passing an explicit list makes the parser
            usable from notebooks and tests.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="GRPO training for CommitmentOS")

    # Model / optimisation settings.
    parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct", help="Base model")
    parser.add_argument("--epochs", type=int, default=2, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=4, help="Per-device batch size")
    parser.add_argument("--max_steps", type=int, default=-1, help="Max training steps (-1 for full epochs)")

    # LoRA adapter settings.
    parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")

    # Output / publishing settings.
    parser.add_argument("--output_dir", default="./training_output", help="Output directory")
    parser.add_argument("--push_to_hub", action="store_true", help="Push model to HuggingFace Hub")
    parser.add_argument("--hub_model_id", default="jayant2304/commitmentos-qwen-grpo", help="HF Hub model ID")

    # Task / GRPO sampling settings.
    parser.add_argument("--num_scenarios", type=int, default=15, help="Number of scenarios to use")
    # NOTE(review): --max_turns is parsed but not read anywhere in this file;
    # reward_function builds its environment with a fixed limit of 8.
    parser.add_argument("--max_turns", type=int, default=8, help="Max turns per episode")
    parser.add_argument("--group_size", type=int, default=4, help="GRPO group size (completions per prompt)")

    return parser.parse_args(argv)
def build_dataset(num_scenarios: int = 15) -> List[Dict[str, Any]]:
    """Assemble one chat-style training row per CommitmentOS scenario.

    Args:
        num_scenarios: Upper bound on how many scenarios are used; the
            collection returned by ``get_all_scenarios()`` is sliced to its
            first ``num_scenarios`` values.

    Returns:
        A list of dicts, each holding a two-message ``"prompt"`` (system
        message followed by user message) plus the scenario's
        ``"scenario_id"`` and ``"difficulty"``.
    """
    # Project imports stay local to the function, as in the rest of the script.
    from server.tasks import get_all_scenarios
    from training.env_factory import build_initial_prompt, build_system_prompt

    selected = list(get_all_scenarios().values())[:num_scenarios]
    # The system prompt is built once and shared by every row.
    shared_system_prompt = build_system_prompt()

    def _to_row(scenario: Any) -> Dict[str, Any]:
        opening_message = build_initial_prompt(scenario)
        return {
            "prompt": [
                {"role": "system", "content": shared_system_prompt},
                {"role": "user", "content": opening_message},
            ],
            "scenario_id": scenario.scenario_id,
            "difficulty": scenario.difficulty,
        }

    return [_to_row(scenario) for scenario in selected]
def _content_block_to_text(block: Any) -> str:
    """Return the text carried by one content block of a chat message."""
    if isinstance(block, dict):
        text = block.get("text")
        # Blocks without usable text fall back to their repr, as before.
        return str(block) if text is None else str(text)
    return str(block)


def _message_to_text(message: Dict[str, Any], separator: str = " ") -> str:
    """Extract the textual content of one message/completion dict.

    Reads ``"content"`` (falling back to ``"text"``). A ``None`` content
    yields an empty string instead of the literal ``"None"``, and list
    content is flattened block by block with ``separator``.
    """
    content = message.get("content", message.get("text", ""))
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return separator.join(_content_block_to_text(block) for block in content)
    return str(content)


def _completion_to_text(completion: Any) -> str:
    """Normalize TRL completion payloads across versions.

    Depending on TRL/Transformers version, completions can arrive as
    strings, dicts, or nested lists of chat/message objects. Every shape is
    reduced to a single string; empty parts of a list are dropped.
    """
    if isinstance(completion, str):
        return completion
    if isinstance(completion, dict):
        return _message_to_text(completion, separator="\n")
    if isinstance(completion, list):
        parts = [
            _message_to_text(item) if isinstance(item, dict) else str(item)
            for item in completion
        ]
        return "\n".join(part for part in parts if part)
    return str(completion)


def reward_function(completions: List[Any], **kwargs: Any) -> List[float]:
    """Reward function for GRPO — evaluates completions against CommitmentOS.

    Args:
        completions: Completion payloads in any of the shapes handled by
            ``_completion_to_text``.
        **kwargs: Accepted for compatibility with the trainer's calling
            convention; not used here.

    Returns:
        The result of calling the environment factory on the normalized
        completion texts (annotated as one float per completion).

    Note:
        The environment is created with a fixed turn limit of 8; the
        ``--max_turns`` command-line option is not consulted here.
    """
    from training.env_factory import CommitmentOSEnvFactory

    factory = CommitmentOSEnvFactory(max_turns=8)
    normalized = [_completion_to_text(completion) for completion in completions]
    return factory(normalized)
def main() -> None:
    """Run the end-to-end GRPO fine-tuning pipeline.

    Parses the command-line options, loads the base model and tokenizer,
    builds the scenario prompt dataset, trains LoRA adapters with TRL's
    ``GRPOTrainer``, then writes the model, tokenizer and training metrics to
    ``--output_dir``. With ``--push_to_hub`` and ``HF_TOKEN`` set, the result
    is also uploaded to the repository named by ``--hub_model_id``.

    Exits with status 1 when one of the training dependencies is missing.
    """
    args = parse_args()

    # Imported here so a missing package produces an install hint instead of
    # a traceback at module import time.
    try:
        import torch
        from datasets import Dataset
        from peft import LoraConfig
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as e:
        print(f"Missing training dependency: {e}")
        print("Install with: pip install trl transformers peft datasets torch")
        sys.exit(1)

    has_cuda = torch.cuda.is_available()
    # Only request bf16 when the GPU actually supports it; a CUDA device
    # without bf16 support falls back to float32 instead of failing.
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()

    print(f"Loading model: {args.model}")
    # NOTE: trust_remote_code=True allows the model repository to run its own
    # code at load time; only pass --model values you trust.
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Generation needs a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        dtype=torch.bfloat16 if use_bf16 else torch.float32,
        device_map="auto" if has_cuda else None,
        trust_remote_code=True,
    )

    lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        task_type="CAUSAL_LM",
    )

    print("Building dataset...")
    raw_data = build_dataset(args.num_scenarios)
    dataset = Dataset.from_list(raw_data)

    training_config = GRPOConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        learning_rate=args.lr,
        logging_steps=1,
        save_steps=50,
        bf16=use_bf16,
        gradient_accumulation_steps=2,
        warmup_steps=5,
        max_completion_length=512,
        num_generations=args.group_size,
        report_to="none",
        # The target repository is taken from the training arguments by
        # Trainer.push_to_hub, so it has to be configured here.
        hub_model_id=args.hub_model_id,
    )

    print("Initialising GRPOTrainer...")
    trainer = GRPOTrainer(
        model=model,
        args=training_config,
        train_dataset=dataset,
        processing_class=tokenizer,
        reward_funcs=reward_function,
        peft_config=lora_config,
    )

    print("Starting training...")
    trainer.train()

    print(f"Saving model to {args.output_dir}")
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    # Persist metrics before the optional upload so a failed push cannot
    # lose them.
    save_training_metrics(trainer, args.output_dir)

    if args.push_to_hub:
        hf_token = os.getenv("HF_TOKEN", "")
        if hf_token:
            print(f"Pushing to hub: {args.hub_model_id}")
            # The first positional parameter of push_to_hub is the commit
            # message, not the repo id, so only the token is passed here.
            trainer.push_to_hub(token=hf_token)
        else:
            print("HF_TOKEN not set — skipping hub push")

    print("Training complete!")
def save_training_metrics(trainer: Any, output_dir: str) -> None:
    """Dump the trainer's log history to ``training_metrics.json``.

    Args:
        trainer: Object exposing ``state``; ``state.log_history`` is written
            when present, otherwise an empty list is written.
        output_dir: Directory for the JSON file; created (including parents)
            if it does not exist yet.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    metrics_file = target_dir / "training_metrics.json"

    log_entries = getattr(trainer.state, "log_history", [])
    with metrics_file.open("w") as handle:
        json.dump(log_entries, handle, indent=2)

    print(f"Training metrics saved to {metrics_file}")
| if __name__ == "__main__": | |
| main() | |