Spaces:
Running
Running
| """ | |
| GRPO Training for Skill Invocation Environment. | |
| Trains a model to decide which skills to load/unload before submitting a solution. | |
| Uses TRL's GRPOTrainer with a custom multi-turn rollout that interacts with the | |
| Skill Invocation Environment hosted on HF Spaces. | |
| Run on Northflank with an A100/H100 GPU: | |
| python train_demo.py | |
| """ | |
| import hashlib | |
| import re | |
| import os | |
| import wandb | |
| from datasets import Dataset | |
| from trl import GRPOConfig, GRPOTrainer | |
| from trl.experimental.openenv import generate_rollout_completions | |
| from transformers import AutoTokenizer | |
| from peft import LoraConfig | |
| from skill_invocation_env.client import SkillInvocationEnv | |
| from skill_invocation_env.models import SkillInvocationAction | |
# ── Configuration ────────────────────────────────────────────────────────────
# Every knob is overridable via environment variables for headless runs.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")  # base policy model
ENV_URL = os.getenv("ENV_URL", "https://mpnikhil-skill-invocation-env.hf.space")  # hosted env endpoint
HF_TOKEN = os.getenv("HF_TOKEN")  # optional; enables push_to_hub when set
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs/qwen-skill-env")  # local checkpoint dir
# NOTE(review): default repo name says "7B" but the default MODEL_ID is 3B — confirm intent.
HUB_REPO = os.getenv("HUB_REPO", "mpnikhil/Qwen2.5-7B-Skill-Invocation")
NUM_EPISODES = int(os.getenv("NUM_EPISODES", "128"))  # unique tasks (one GRPO group each)
# Default 8 turns gives headroom to explore: load-inspect-unload-reload cycles
# beyond the minimum path of num_relevant_skills + 1 (submit) turns.
MAX_TURNS = int(os.getenv("MAX_TURNS", "8"))
NUM_GENERATIONS = int(os.getenv("NUM_GENERATIONS", "8"))  # rollouts per GRPO group
MAX_COMPLETION_LENGTH = int(os.getenv("MAX_COMPLETION_LENGTH", "1024"))  # per-turn generation cap
# System prompt that defines the action grammar the model must emit.
# parse_action() below must stay in sync with the XML formats listed here.
SYSTEM_PROMPT = """\
You are an expert AI software engineer. You will be given a task and a catalog of available skills (procedural knowledge).
You must decide which skills to load to help you solve the task, and then submit your final answer.
You must interact by outputting EXACTLY ONE of the following XML actions per turn:
1. To load a skill to read its contents (costs context budget):
<action type="load" skill_id="skill_01"/>
2. To unload a skill if it is not useful (frees context budget):
<action type="unload" skill_id="skill_01"/>
3. To submit your final solution:
<action type="submit">
your solution here
</action>
Always think step-by-step before outputting an action."""
def parse_action(text: str) -> SkillInvocationAction:
    """Convert raw LLM output into a SkillInvocationAction.

    Recognizes three XML-style actions (load, unload, submit). Output that
    matches none of them is treated as a submission of the raw text.
    """
    # load/unload share the same self-closing shape; try each in order.
    skill_patterns = {
        "load": r'<action\s+type="load"\s+skill_id="([^"]+)"\s*/>',
        "unload": r'<action\s+type="unload"\s+skill_id="([^"]+)"\s*/>',
    }
    for kind, pattern in skill_patterns.items():
        matched = re.search(pattern, text)
        if matched is not None:
            return SkillInvocationAction(action_type=kind, skill_id=matched.group(1))
    submitted = re.search(r'<action\s+type="submit">(.*?)</action>', text, re.DOTALL)
    if submitted is not None:
        return SkillInvocationAction(action_type="submit", answer=submitted.group(1).strip())
    # Fallback: treat entire output as submission
    return SkillInvocationAction(action_type="submit", answer=text)
def format_observation(obs) -> str:
    """Render an environment observation as the user-turn prompt text."""
    lines = [f"TASK: {obs.task_description}\n\nSKILL CATALOG:"]
    lines.extend(
        f"- [{s['id']}] {s['name']}: {s['description']}" for s in obs.skill_catalog
    )
    if obs.loaded_skills:
        lines.append(f"\nCURRENTLY LOADED SKILLS: {', '.join(obs.loaded_skills)}")
    if obs.skill_content:
        lines.append(f"\nJUST LOADED SKILL CONTENT:\n{obs.skill_content}")
    # Surface every currently-loaded skill's content so the model need not
    # rely on conversation history alone to recall earlier loads.
    if obs.loaded_skill_contents:
        freshly_loaded = None
        if obs.skill_content:
            # Identify the just-loaded skill by content so we don't print it twice.
            freshly_loaded = next(
                (
                    sid
                    for sid, body in obs.loaded_skill_contents.items()
                    if body == obs.skill_content
                ),
                None,
            )
        remaining = {
            sid: body
            for sid, body in obs.loaded_skill_contents.items()
            if sid != freshly_loaded
        }
        if remaining:
            lines.append("\nOTHER LOADED SKILL CONTENTS:")
            lines.extend(f"\n[{sid}]:\n{body}" for sid, body in remaining.items())
    if obs.verification_result:
        lines.append(f"\nVERIFICATION: {obs.verification_result}")
    if obs.messages:
        lines.append(f"\nSTATUS: {obs.messages[-1]}")
    lines.append(f"\nBUDGET USED: {obs.context_budget_used} / {obs.context_budget_total}")
    return "\n".join(lines)
# ── Multi-turn rollout ───────────────────────────────────────────────────────
def rollout_once(
    trainer: GRPOTrainer,
    env: SkillInvocationEnv,
    tokenizer: AutoTokenizer,
    env_seed: int,
) -> dict:
    """
    Run one multi-turn episode against the Skill Invocation Environment.

    Args:
        trainer: GRPOTrainer whose vLLM engine produces completions.
        env: Connected environment client; reset and stepped by this function.
        tokenizer: Tokenizer used to build chat prompts and count tokens.
        env_seed: Deterministic seed passed to env.reset() so all generations
            within a GRPO group face the identical task.

    Returns dict with prompt_ids, completion_ids, logprobs, and env_reward.
    Accumulates tokens across ALL turns so GRPO can assign credit to every
    decision (load, unload, submit).
    """
    result = env.reset(seed=env_seed)
    obs = result.observation
    # Token accumulation across turns:
    #   - prompt_ids: first turn's full prompt (system + initial observation)
    #   - completion_ids: all model generations + env feedback tokens interleaved
    #   - logprobs: real logprobs for model tokens, 0.0 for env feedback tokens
    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    env_reward = 0.0
    generated_any = False
    # True once an early-exit path (context overflow, env error) has already
    # assigned a penalty. Bug fix: previously the final `if not result.done`
    # check silently overwrote the -1.0 env-error penalty with -0.5.
    penalty_assigned = False
    # Tracks how many tokens we've already accounted for across turns.
    # Each turn's prompt_ids from apply_chat_template contains the FULL
    # conversation so far (quadratic growth). We only append the delta —
    # the new tokens since the last turn — to keep accounting linear.
    prev_total_len = 0
    # Conversation history — the model sees its full interaction so far,
    # so it can recall what it read in a loaded skill and decide to unload.
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in range(MAX_TURNS):
        if result.done:
            break
        # Append new observation to conversation history
        user_content = format_observation(obs)
        conversation.append({"role": "user", "content": user_content})
        prompt_text = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False,
        )
        # Safety check: prevent vLLM context length errors. Assumes a
        # 32,768-token context window (true of the Qwen2/Qwen3 families —
        # TODO confirm for the configured MODEL_ID); 31,000 leaves room
        # for MAX_COMPLETION_LENGTH generated tokens.
        prompt_token_count = len(tokenizer.encode(prompt_text, add_special_tokens=False))
        if prompt_token_count > 31_000:
            print(f" [rollout] prompt too long ({prompt_token_count} tokens), breaking early")
            env_reward = -0.5
            penalty_assigned = True
            break
        # Generate using TRL's vLLM helper
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        generated_any = True
        new_prompt_ids = rollout_outputs["prompt_ids"]
        if turn == 0:
            # First turn: store the full prompt
            prompt_ids.extend(new_prompt_ids)
            prev_total_len = len(new_prompt_ids)
        else:
            # Later turns: only append the delta (new env feedback tokens
            # beyond what we've already tracked). These get zeroed-out
            # logprobs since they're env-generated, not model-generated.
            delta_ids = new_prompt_ids[prev_total_len:]
            completion_ids.extend(delta_ids)
            logprobs.extend([0.0] * len(delta_ids))
        # Append the model's generation tokens (these get real logprobs)
        completion_ids.extend(rollout_outputs["completion_ids"])
        logprobs.extend(rollout_outputs["logprobs"])
        # Update running total: everything up to and including this turn's completion
        prev_total_len = len(new_prompt_ids) + len(rollout_outputs["completion_ids"])
        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True,
        )
        # Add the model's response to conversation history
        conversation.append({"role": "assistant", "content": completion_text})
        # Parse action and step the environment
        action = parse_action(completion_text)
        try:
            result = env.step(action)
            obs = result.observation
            if result.done:
                env_reward = float(result.reward or 0.0)
        except Exception as e:
            print(f" [rollout] env.step error: {e}")
            env_reward = -1.0
            penalty_assigned = True
            break
    # If we ran out of turns without submitting, penalize — but never
    # clobber a more specific penalty assigned by an early-exit path above.
    if not result.done and not penalty_assigned:
        env_reward = -0.5
    # Fallback if no generation happened (e.g. env.reset() returned done=True)
    if not generated_any:
        dummy_ids = tokenizer.encode("error", add_special_tokens=False)
        prompt_ids = dummy_ids
        completion_ids = list(dummy_ids)
        logprobs = [0.0] * len(dummy_ids)
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "env_reward": env_reward,
    }
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    """
    Custom rollout function for GRPOTrainer.

    GRPO groups: prompts arrive as [p0, p0, p0, ..., p1, p1, p1, ...] where
    each prompt is repeated num_generations times. All rollouts for the same
    prompt must face the same task, so we extract the seed from the prompt text
    and pass it to env.reset(seed=...).
    """
    tokenizer = trainer.processing_class
    batch_prompt_ids: list = []
    batch_completion_ids: list = []
    batch_logprobs: list = []
    batch_rewards: list = []
    nonzero_rewards = 0
    for idx, text in enumerate(prompts):
        # The "seed:<N> ..." prefix guarantees all K generations of one
        # group get the same task. Fresh env client per episode.
        env = SkillInvocationEnv(base_url=ENV_URL, connect_timeout_s=60)
        try:
            episode = rollout_once(
                trainer=trainer,
                env=env,
                tokenizer=tokenizer,
                env_seed=_extract_seed(text),
            )
        finally:
            try:
                env.close()
            except Exception:
                pass
        batch_prompt_ids.append(episode["prompt_ids"])
        batch_completion_ids.append(episode["completion_ids"])
        batch_logprobs.append(episode["logprobs"])
        reward = episode["env_reward"]
        batch_rewards.append(reward)
        if reward != 0.0:
            nonzero_rewards += 1
        if (idx + 1) % 10 == 0:
            avg_r = sum(batch_rewards) / len(batch_rewards)
            print(f" [rollout] {idx+1}/{len(prompts)} episodes, avg reward: {avg_r:.3f}")
    # Issue 4 guard: verify rewards actually flowed through
    if nonzero_rewards == 0 and len(prompts) > 0:
        print(" [WARNING] All rewards are 0.0 β check env connectivity!")
    # Log rollout stats to wandb
    if wandb.run is not None:
        episode_count = len(batch_rewards)
        avg_reward = sum(batch_rewards) / episode_count if episode_count else 0.0
        wins = sum(1 for r in batch_rewards if r > 0)
        losses = sum(1 for r in batch_rewards if r < 0)
        wandb.log({
            "rollout/avg_reward": avg_reward,
            "rollout/max_reward": max(batch_rewards) if batch_rewards else 0.0,
            "rollout/min_reward": min(batch_rewards) if batch_rewards else 0.0,
            "rollout/positive_pct": wins / episode_count * 100 if episode_count else 0.0,
            "rollout/negative_pct": losses / episode_count * 100 if episode_count else 0.0,
            "rollout/num_episodes": episode_count,
        })
    return {
        "prompt_ids": batch_prompt_ids,
        "completion_ids": batch_completion_ids,
        "logprobs": batch_logprobs,
        "env_reward": batch_rewards,
    }
| def _extract_seed(prompt_text: str) -> int: | |
| """Extract the env seed from a prompt like 'seed:42 ...' | |
| Crashes loudly on malformed prompts rather than silently producing | |
| non-deterministic seeds (Python's hash() is randomized across processes). | |
| """ | |
| match = re.match(r"seed:(\d+)", prompt_text) | |
| if match: | |
| return int(match.group(1)) | |
| # Deterministic fallback using SHA-256 (stable across processes, unlike hash()) | |
| digest = hashlib.sha256(prompt_text.encode()).hexdigest() | |
| return int(digest[:8], 16) % (2**31) | |
def reward_from_env(completions, **kwargs):
    """Reward function that forwards environment rewards from rollout_func.

    The custom rollout returns an "env_reward" column; GRPOTrainer hands it
    to reward functions via kwargs. A missing/empty column degrades to
    all-zero rewards with a warning instead of crashing training.
    """
    rewards = kwargs.get("env_reward", [])
    if rewards:
        return [float(r) for r in rewards]
    print(" [WARNING] reward_from_env received no env_reward in kwargs!")
    return [0.0] * len(completions)
# ── Main ─────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    print(f"Starting GRPO Training with {MODEL_ID}")
    print(f"Environment: {ENV_URL}")
    print(f"Episodes: {NUM_EPISODES}, Generations per episode: {NUM_GENERATIONS}")
    wandb.init(
        project="skill-invocation-env",
        name=f"grpo-{MODEL_ID.split('/')[-1]}-ep{NUM_EPISODES}",
        config={
            "model_id": MODEL_ID,
            "env_url": ENV_URL,
            "num_episodes": NUM_EPISODES,
            "num_generations": NUM_GENERATIONS,
            "max_completion_length": MAX_COMPLETION_LENGTH,
            "max_turns": MAX_TURNS,
            "learning_rate": 1e-6,
            "lora_r": 16,
        },
    )
    # Each unique prompt = one GRPO group = one task (via seed).
    # GRPO will expand each prompt to num_generations rollouts internally.
    # All rollouts for the same seed face the same task → valid advantage computation.
    prompts = [f"seed:{i} Solve the coding task by loading the right skills." for i in range(NUM_EPISODES)]
    dataset = Dataset.from_dict({"prompt": prompts})
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.3,
        num_train_epochs=1,
        num_generations=NUM_GENERATIONS,
        # Bug fix: was hard-coded to 512, silently contradicting the
        # MAX_COMPLETION_LENGTH setting (default 1024) reported to wandb above.
        max_completion_length=MAX_COMPLETION_LENGTH,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=32,
        learning_rate=1e-6,
        logging_steps=1,
        save_steps=50,
        loss_type="grpo",
        report_to="wandb",
    )
    # LoRA on the attention projections keeps the trainable-parameter count small.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    trainer = GRPOTrainer(
        model=MODEL_ID,
        reward_funcs=reward_from_env,
        train_dataset=dataset,
        rollout_func=rollout_func,
        args=training_args,
        peft_config=peft_config,
    )
    trainer.train()
    print("Training complete! Pushing to hub...")
    if HF_TOKEN:
        trainer.push_to_hub(HUB_REPO, token=HF_TOKEN)
        print(f"Model pushed to https://huggingface.co/{HUB_REPO}")
    else:
        # Without a token we only persist locally.
        print("HF_TOKEN not set, skipping push. Model saved locally.")
        trainer.save_model(OUTPUT_DIR)