Spaces:
Running
Running
File size: 15,930 Bytes
33bb385 f678c99 33bb385 ccefb27 f678c99 33bb385 f678c99 33bb385 ac627d5 f678c99 ac627d5 33bb385 ccefb27 33bb385 f678c99 33bb385 ac627d5 33bb385 f678c99 ac627d5 f678c99 ac627d5 f678c99 ac627d5 f678c99 ac627d5 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 ccefb27 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 ac627d5 33bb385 ac627d5 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 ccefb27 f678c99 33bb385 ccefb27 33bb385 f678c99 33bb385 ac627d5 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 f678c99 33bb385 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 | """
GRPO Training for Skill Invocation Environment.
Trains a model to decide which skills to load/unload before submitting a solution.
Uses TRL's GRPOTrainer with a custom multi-turn rollout that interacts with the
Skill Invocation Environment hosted on HF Spaces.
Run on Northflank with an A100/H100 GPU:
python train_demo.py
"""
import hashlib
import re
import os
import wandb
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.openenv import generate_rollout_completions
from transformers import AutoTokenizer
from peft import LoraConfig
from skill_invocation_env.client import SkillInvocationEnv
from skill_invocation_env.models import SkillInvocationAction
# -- Configuration ------------------------------------------------------------
# Each setting below except SYSTEM_PROMPT is read from the environment variable
# of the same name, falling back to the default shown.
MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-3B-Instruct")
ENV_URL = os.getenv("ENV_URL", "https://mpnikhil-skill-invocation-env.hf.space")
# Optional; when unset the trained model is only saved locally.
HF_TOKEN = os.getenv("HF_TOKEN")
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs/qwen-skill-env")
# Default hub repo is derived from MODEL_ID so the pushed model's repo name
# always matches the base model that was actually trained.
HUB_REPO = os.getenv(
    "HUB_REPO", f"mpnikhil/{MODEL_ID.split('/')[-1]}-Skill-Invocation"
)
# Number of distinct tasks (seeds); each becomes one GRPO group.
NUM_EPISODES = int(os.getenv("NUM_EPISODES", "128"))
# Default 8 turns gives headroom to explore: load-inspect-unload-reload cycles
# beyond the minimum path of num_relevant_skills + 1 (submit) turns.
MAX_TURNS = int(os.getenv("MAX_TURNS", "8"))
# Rollouts sampled per task; advantages are computed within each group.
NUM_GENERATIONS = int(os.getenv("NUM_GENERATIONS", "8"))
MAX_COMPLETION_LENGTH = int(os.getenv("MAX_COMPLETION_LENGTH", "1024"))
SYSTEM_PROMPT = """\
You are an expert AI software engineer. You will be given a task and a catalog of available skills (procedural knowledge).
You must decide which skills to load to help you solve the task, and then submit your final answer.
You must interact by outputting EXACTLY ONE of the following XML actions per turn:
1. To load a skill to read its contents (costs context budget):
<action type="load" skill_id="skill_01"/>
2. To unload a skill if it is not useful (frees context budget):
<action type="unload" skill_id="skill_01"/>
3. To submit your final solution:
<action type="submit">
your solution here
</action>
Always think step-by-step before outputting an action."""
# One pattern for all three action forms. Scanning with finditer yields actions
# in the order they appear; a submit match consumes its whole body, so
# action-like text inside a submitted answer is never treated as a separate
# load/unload action.
_ACTION_RE = re.compile(
    r'<action\s+type="(?P<kind>load|unload)"\s+skill_id="(?P<skill_id>[^"]+)"\s*/>'
    r'|<action\s+type="submit">(?P<answer>.*?)</action>',
    re.DOTALL,
)


def _last_action_match(text: str):
    """Return the regex match for the last well-formed action tag in *text*, or None."""
    matches = list(_ACTION_RE.finditer(text))
    return matches[-1] if matches else None


def parse_action(text: str) -> "SkillInvocationAction":
    """Parse the LLM's text output into a SkillInvocationAction.

    The system prompt asks the model to reason first and then emit exactly one
    action, so the LAST well-formed action tag in the output is taken as the
    decision. Earlier tags are treated as part of the reasoning (e.g. the model
    quoting the action format). Load/unload tags appearing inside a submitted
    answer are part of that answer, not separate actions.

    Args:
        text: Raw model completion for one turn.

    Returns:
        A ``load``/``unload`` action carrying ``skill_id``, or a ``submit``
        action carrying the stripped answer. If no action tag is found, the
        entire output is submitted as the answer (unstripped).
    """
    match = _last_action_match(text)
    if match is None:
        # Fallback: treat entire output as submission
        return SkillInvocationAction(action_type="submit", answer=text)
    answer = match.group("answer")
    if answer is not None:
        return SkillInvocationAction(action_type="submit", answer=answer.strip())
    return SkillInvocationAction(
        action_type=match.group("kind"), skill_id=match.group("skill_id")
    )
def format_observation(obs) -> str:
    """Render one environment observation as the user-turn prompt text.

    Sections are emitted in a fixed order, separated by newlines:
    the task and skill catalog (always), then, each only when its field is
    truthy, the loaded skill ids, the content returned by the latest load,
    the contents of the other loaded skills, the verification result, and
    the most recent status message. The budget line always comes last.
    """
    sections = [f"TASK: {obs.task_description}\n\nSKILL CATALOG:"]
    sections.extend(
        f"- [{entry['id']}] {entry['name']}: {entry['description']}"
        for entry in obs.skill_catalog
    )

    if obs.loaded_skills:
        sections.append(f"\nCURRENTLY LOADED SKILLS: {', '.join(obs.loaded_skills)}")

    fresh_content = obs.skill_content
    if fresh_content:
        sections.append(f"\nJUST LOADED SKILL CONTENT:\n{fresh_content}")

    # Repeat every loaded skill's text each turn so the model does not have
    # to recall it from earlier messages. The skill whose content was just
    # printed above (identified by matching content) is skipped to avoid
    # showing it twice.
    loaded_contents = obs.loaded_skill_contents
    if loaded_contents:
        fresh_id = None
        if fresh_content:
            fresh_id = next(
                (sid for sid, body in loaded_contents.items() if body == fresh_content),
                None,
            )
        others = [
            (sid, body) for sid, body in loaded_contents.items() if sid != fresh_id
        ]
        if others:
            sections.append("\nOTHER LOADED SKILL CONTENTS:")
            sections.extend(f"\n[{sid}]:\n{body}" for sid, body in others)

    if obs.verification_result:
        sections.append(f"\nVERIFICATION: {obs.verification_result}")
    if obs.messages:
        sections.append(f"\nSTATUS: {obs.messages[-1]}")
    sections.append(f"\nBUDGET USED: {obs.context_budget_used} / {obs.context_budget_total}")
    return "\n".join(sections)
# -- Multi-turn rollout -------------------------------------------------------
def rollout_once(
    trainer: GRPOTrainer,
    env: SkillInvocationEnv,
    tokenizer: AutoTokenizer,
    env_seed: int,
    max_prompt_tokens: int = 31_000,
) -> dict:
    """
    Run one multi-turn episode against the Skill Invocation Environment.

    Args:
        trainer: Passed to ``generate_rollout_completions`` to sample each
            turn's model output.
        env: Environment client; reset once with ``env_seed`` and stepped
            once per model turn.
        tokenizer: Renders the chat template, counts prompt tokens, and
            decodes completions when the generation output carries no text.
        env_seed: Deterministic seed passed to env.reset() so all generations
            within a GRPO group face the identical task.
        max_prompt_tokens: If the rendered conversation exceeds this many
            tokens, the episode stops early with reward -0.5. The default is
            meant to fit a 32,768-token context window with room left for a
            completion (true for Qwen2.5-3B-Instruct; confirm for other
            MODEL_IDs).

    Returns:
        dict with ``prompt_ids``, ``completion_ids``, ``logprobs`` and
        ``env_reward``. Tokens are accumulated across ALL turns so GRPO can
        assign credit to every decision (load, unload, submit).

    Reward (``env_reward``):
        * the environment's reward (None -> 0.0) when the episode finishes;
        * -1.0 if ``env.step`` raises;
        * -0.5 if the prompt exceeds ``max_prompt_tokens``, or if MAX_TURNS
          turns elapse without the episode finishing;
        * 0.0 if the episode is already done right after ``env.reset``.
          In that case no generation happens and placeholder tokens are
          returned.
    """
    result = env.reset(seed=env_seed)
    obs = result.observation

    # Token accumulation across turns:
    # - prompt_ids: first turn's full prompt (system + initial observation)
    # - completion_ids: all model generations + env feedback tokens interleaved
    # - logprobs: real logprobs for model tokens, 0.0 for env feedback tokens
    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    env_reward = 0.0
    generated_any = False
    # Set when the loop exits early after assigning its own penalty (prompt
    # overflow or env.step error). This stops the out-of-turns penalty below
    # from overwriting that penalty.
    aborted = False

    # Number of conversation tokens already accounted for. Each turn's
    # prompt_ids from apply_chat_template contain the FULL conversation so far
    # (quadratic growth), so only the delta beyond this point is appended.
    # This keeps the accounting linear.
    # NOTE(review): this assumes the re-rendered conversation begins with
    # exactly the previous prompt + generated completion tokens. A chat
    # template that rewrites earlier assistant turns would break that.
    prev_total_len = 0

    # Conversation history: the model sees its full interaction so far, so it
    # can recall what it read in a loaded skill and decide to unload.
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]

    for turn in range(MAX_TURNS):
        if result.done:
            break

        # Append the new observation to the conversation history.
        conversation.append({"role": "user", "content": format_observation(obs)})
        prompt_text = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False,
        )

        # Safety check: prevent vLLM context length errors.
        prompt_token_count = len(tokenizer.encode(prompt_text, add_special_tokens=False))
        if prompt_token_count > max_prompt_tokens:
            print(f" [rollout] prompt too long ({prompt_token_count} tokens), breaking early")
            env_reward = -0.5
            aborted = True
            break

        # Generate using TRL's vLLM helper.
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        generated_any = True

        new_prompt_ids = rollout_outputs["prompt_ids"]
        if turn == 0:
            # First turn: store the full prompt.
            prompt_ids.extend(new_prompt_ids)
        else:
            # Later turns: only append the delta (new env feedback tokens
            # beyond what has already been tracked). These get zeroed-out
            # logprobs since they are env-generated, not model-generated.
            delta_ids = new_prompt_ids[prev_total_len:]
            completion_ids.extend(delta_ids)
            logprobs.extend([0.0] * len(delta_ids))

        # The model's generation tokens get their real logprobs.
        completion_ids.extend(rollout_outputs["completion_ids"])
        logprobs.extend(rollout_outputs["logprobs"])
        # Running total: everything up to and including this turn's completion.
        prev_total_len = len(new_prompt_ids) + len(rollout_outputs["completion_ids"])

        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True,
        )
        conversation.append({"role": "assistant", "content": completion_text})

        # Parse the action and step the environment.
        action = parse_action(completion_text)
        try:
            result = env.step(action)
            obs = result.observation
            if result.done:
                env_reward = float(result.reward or 0.0)
        except Exception as e:
            print(f" [rollout] env.step error: {e}")
            env_reward = -1.0
            aborted = True
            break

    # Ran out of turns without finishing: penalize. Skip this when an early
    # exit already assigned its own penalty.
    if not aborted and not result.done:
        env_reward = -0.5

    # Fallback if no generation happened (e.g. env.reset() returned done=True
    # or the very first prompt was too long). Placeholder tokens keep the
    # returned token lists non-empty.
    if not generated_any:
        dummy_ids = tokenizer.encode("error", add_special_tokens=False)
        prompt_ids = dummy_ids
        completion_ids = list(dummy_ids)
        logprobs = [0.0] * len(dummy_ids)

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "env_reward": env_reward,
    }
def _summarize_rewards(rewards: list) -> dict:
    """Aggregate per-episode rewards into the ``rollout/*`` metrics logged to wandb.

    An empty reward list yields all-zero metrics.
    """
    n = len(rewards)
    if n == 0:
        return {
            "rollout/avg_reward": 0.0,
            "rollout/max_reward": 0.0,
            "rollout/min_reward": 0.0,
            "rollout/positive_pct": 0.0,
            "rollout/negative_pct": 0.0,
            "rollout/num_episodes": 0,
        }
    positive = sum(1 for r in rewards if r > 0)
    negative = sum(1 for r in rewards if r < 0)
    return {
        "rollout/avg_reward": sum(rewards) / n,
        "rollout/max_reward": max(rewards),
        "rollout/min_reward": min(rewards),
        "rollout/positive_pct": positive / n * 100,
        "rollout/negative_pct": negative / n * 100,
        "rollout/num_episodes": n,
    }


def rollout_func(prompts: list[str], trainer: "GRPOTrainer") -> dict[str, list]:
    """
    Custom rollout function for GRPOTrainer.

    GRPO groups: prompts arrive as [p0, p0, p0, ..., p1, p1, p1, ...] where
    each prompt is repeated num_generations times. All rollouts for the same
    prompt must face the same task, so the seed is extracted from the prompt
    text and passed to env.reset(seed=...).

    Each episode uses its own environment connection, which is closed
    best-effort afterwards. Close failures are logged but do not fail the
    batch.

    Returns:
        dict of per-episode lists: ``prompt_ids``, ``completion_ids``,
        ``logprobs`` and ``env_reward`` (the last is read by reward_from_env).
    """
    tokenizer = trainer.processing_class
    all_prompt_ids = []
    all_completion_ids = []
    all_logprobs = []
    all_rewards = []
    rewards_received = 0

    for i, prompt_text in enumerate(prompts):
        # Seed comes from the "seed:<N> ..." prefix, so all K generations for
        # the same prompt get the same task.
        seed = _extract_seed(prompt_text)
        env = SkillInvocationEnv(base_url=ENV_URL, connect_timeout_s=60)
        try:
            episode = rollout_once(
                trainer=trainer,
                env=env,
                tokenizer=tokenizer,
                env_seed=seed,
            )
        finally:
            try:
                env.close()
            except Exception as e:
                # Best-effort cleanup: log instead of silently discarding,
                # since repeated close failures hint at connectivity trouble.
                print(f" [rollout] env.close error: {e}")

        all_prompt_ids.append(episode["prompt_ids"])
        all_completion_ids.append(episode["completion_ids"])
        all_logprobs.append(episode["logprobs"])
        all_rewards.append(episode["env_reward"])
        if episode["env_reward"] != 0.0:
            rewards_received += 1

        if (i + 1) % 10 == 0:
            avg_r = sum(all_rewards) / len(all_rewards)
            print(f" [rollout] {i+1}/{len(prompts)} episodes, avg reward: {avg_r:.3f}")

    # If every episode returned exactly 0.0, rewards are probably not reaching
    # the trainer (e.g. the env is unreachable), so warn loudly.
    if rewards_received == 0 and prompts:
        print(" [WARNING] All rewards are 0.0 — check env connectivity!")

    # Log rollout stats to wandb when a run is active.
    if wandb.run is not None:
        wandb.log(_summarize_rewards(all_rewards))

    return {
        "prompt_ids": all_prompt_ids,
        "completion_ids": all_completion_ids,
        "logprobs": all_logprobs,
        "env_reward": all_rewards,
    }
def _extract_seed(prompt_text: str) -> int:
    """Return the environment seed encoded in a prompt such as ``'seed:42 ...'``.

    The seed is read from a leading ``seed:<digits>`` tag; leading whitespace
    before the tag is ignored. Prompts without the tag do NOT raise. They fall
    back to a seed derived from the SHA-256 of the prompt text, in
    ``[0, 2**31)``. Unlike the built-in ``hash()``, which is randomized per
    process, SHA-256 is stable across processes, so every generation of the
    same prompt still gets the same task.
    """
    match = re.match(r"\s*seed:(\d+)", prompt_text)
    if match:
        return int(match.group(1))
    # Deterministic fallback using SHA-256 (stable across processes, unlike hash()).
    digest = hashlib.sha256(prompt_text.encode("utf-8")).hexdigest()
    return int(digest[:8], 16) % (2**31)
def reward_from_env(completions, **kwargs):
    """Return the per-completion rewards produced by the environment rollouts.

    The rewards arrive as the ``env_reward`` keyword argument (the list that
    rollout_func returns under that key) and are coerced to floats. If that
    keyword is missing or empty, a warning is printed and every completion
    scores 0.0.
    """
    rewards = kwargs.get("env_reward", [])
    if rewards:
        return list(map(float, rewards))
    print(" [WARNING] reward_from_env received no env_reward in kwargs!")
    return len(completions) * [0.0]
# -- Main ---------------------------------------------------------------------
if __name__ == "__main__":
    # Hyperparameters shared by the wandb run config and the trainer config,
    # defined once so the logged values always match the ones actually used.
    learning_rate = 1e-6
    lora_r = 16

    print(f"Starting GRPO Training with {MODEL_ID}")
    print(f"Environment: {ENV_URL}")
    print(f"Episodes: {NUM_EPISODES}, Generations per episode: {NUM_GENERATIONS}")
    wandb.init(
        project="skill-invocation-env",
        name=f"grpo-{MODEL_ID.split('/')[-1]}-ep{NUM_EPISODES}",
        config={
            "model_id": MODEL_ID,
            "env_url": ENV_URL,
            "num_episodes": NUM_EPISODES,
            "num_generations": NUM_GENERATIONS,
            "max_completion_length": MAX_COMPLETION_LENGTH,
            "max_turns": MAX_TURNS,
            "learning_rate": learning_rate,
            "lora_r": lora_r,
        },
    )

    # Each unique prompt = one GRPO group = one task (via seed).
    # GRPO will expand each prompt to num_generations rollouts internally.
    # All rollouts for the same seed face the same task, so advantages are
    # computed over comparable episodes.
    prompts = [f"seed:{i} Solve the coding task by loading the right skills." for i in range(NUM_EPISODES)]
    dataset = Dataset.from_dict({"prompt": prompts})

    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.3,
        num_train_epochs=1,
        num_generations=NUM_GENERATIONS,
        # Use the configured value (previously hard-coded to 512 while
        # MAX_COMPLETION_LENGTH was what got logged to wandb).
        max_completion_length=MAX_COMPLETION_LENGTH,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=32,
        learning_rate=learning_rate,
        logging_steps=1,
        save_steps=50,
        loss_type="grpo",
        report_to="wandb",
        # Trainer.push_to_hub() takes its target repository from hub_model_id;
        # its first positional parameter is the commit message, not a repo id.
        hub_model_id=HUB_REPO,
    )
    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    trainer = GRPOTrainer(
        model=MODEL_ID,
        reward_funcs=reward_from_env,
        train_dataset=dataset,
        rollout_func=rollout_func,
        args=training_args,
        peft_config=peft_config,
    )
    trainer.train()

    # Always keep a local copy, so results survive a skipped or failed push.
    trainer.save_model(OUTPUT_DIR)
    print("Training complete! Pushing to hub...")
    if HF_TOKEN:
        trainer.push_to_hub(token=HF_TOKEN)
        print(f"Model pushed to https://huggingface.co/{HUB_REPO}")
    else:
        print("HF_TOKEN not set, skipping push. Model saved locally.")
|