mpnikhil committed on
Commit
33bb385
Β·
verified Β·
1 Parent(s): 7706847

Upload folder using huggingface_hub

Browse files
Dockerfile.train ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Training Dockerfile for GRPO on Northflank (A100/H100)
# Builds a GPU-ready image with PyTorch, TRL, vLLM, and the skill_invocation_env client.

FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel

WORKDIR /app

# System deps: git (VCS-based installs) and curl (uv installer).
RUN apt-get update && \
    apt-get install -y --no-install-recommends git curl && \
    rm -rf /var/lib/apt/lists/*

# Install uv for fast dependency resolution
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
    mv /root/.local/bin/uv /usr/local/bin/uv && \
    mv /root/.local/bin/uvx /usr/local/bin/uvx

# Install Python training dependencies.
# Fix: uv was installed above "for fast dependency resolution" but the original
# then invoked plain pip, leaving uv unused. Install through uv against the
# image's interpreter (--system) so resolution actually goes through uv.
RUN uv pip install --system --no-cache \
    "trl>=0.25.0" \
    "vllm>=0.8.0" \
    "transformers>=4.51.0" \
    "datasets>=3.0.0" \
    "accelerate>=1.0.0" \
    "peft>=0.14.0" \
    "openenv-core[core]>=0.2.1" \
    "pydantic>=2.0" \
    "huggingface_hub>=0.25.0"

# Copy only the client code needed for training (not the server)
COPY __init__.py /app/skill_invocation_env/__init__.py
COPY models.py /app/skill_invocation_env/models.py
COPY client.py /app/skill_invocation_env/client.py
COPY train_demo.py /app/train_demo.py

ENV PYTHONPATH="/app:$PYTHONPATH"
ENV PYTHONUNBUFFERED=1

# Default entrypoint — run the training script
CMD ["python", "train_demo.py"]
server/skill_invocation_env_environment.py CHANGED
@@ -3,7 +3,13 @@ Skill Invocation Environment Implementation.
3
 
4
  Trains LLMs to decide WHEN to invoke procedural knowledge (skills) during
5
  task-solving. Context cost model: each loaded skill costs context budget.
6
- Reward penalizes bloat and rewards precision.
 
 
 
 
 
 
7
 
8
  Actions: list, load, unload, submit (plus "invoke" as backward-compat alias for load).
9
  """
@@ -31,7 +37,7 @@ class SkillInvocationEnvironment(Environment):
31
  1. reset() samples a task, assembles skill catalog (relevant + distractors)
32
  2. Agent can list, load, and unload skills (within context budget)
33
  3. Agent submits a solution
34
- 4. Reward = correctness + precision - bloat
35
  """
36
 
37
  SUPPORTS_CONCURRENT_SESSIONS: bool = True
@@ -51,6 +57,8 @@ class SkillInvocationEnvironment(Environment):
51
  self._task_generator = TaskGenerator(seed=procedural_seed) if use_procedural else None
52
  self._episode_skills: dict = {}
53
  self._context_budget = context_budget
 
 
54
 
55
  def reset(
56
  self,
@@ -59,23 +67,27 @@ class SkillInvocationEnvironment(Environment):
59
  **kwargs,
60
  ) -> SkillInvocationObservation:
61
  """Sample a random task and assemble the skill catalog."""
 
 
62
  if seed is not None:
63
- random.seed(seed)
 
 
64
 
65
  if self._use_procedural and self._task_generator:
66
- gen_seed = seed if seed is not None else random.randint(0, 2**31)
67
  result = self._task_generator.generate_with_seed(gen_seed)
68
  task = result["task"]
69
  self._episode_skills = result["skills"]
70
  else:
71
- task = random.choice(TASK_BANK)
72
  self._episode_skills = SKILL_BANK
73
 
74
  self._current_task = task
75
 
76
  # Build catalog: relevant + distractor skills, shuffled
77
  catalog_ids = list(task["relevant_skills"]) + list(task["distractor_skills"])
78
- random.shuffle(catalog_ids)
79
  self._catalog_skill_ids = catalog_ids
80
 
81
  # Build catalog descriptions (short only, no full content)
@@ -227,7 +239,17 @@ class SkillInvocationEnvironment(Environment):
227
  return self._make_observation(skill_content=None, reward=0.0, done=False)
228
 
229
  def _handle_submit(self, action: SkillInvocationAction) -> SkillInvocationObservation:
230
- """Handle a solution submission. Compute reward based on correctness + precision - bloat."""
 
 
 
 
 
 
 
 
 
 
231
  answer = action.answer or ""
232
  task = self._current_task
233
 
@@ -239,6 +261,7 @@ class SkillInvocationEnvironment(Environment):
239
 
240
  # Compute reward
241
  loaded = set(self._state.loaded_skills)
 
242
  relevant = set(task["relevant_skills"])
243
 
244
  # 1. Correctness: +0.6
@@ -262,7 +285,11 @@ class SkillInvocationEnvironment(Environment):
262
  unnecessary = loaded - relevant
263
  bloat_penalty = -0.15 * len(unnecessary)
264
 
265
- total_reward = correctness + precision_bonus + recall_bonus + bloat_penalty
 
 
 
 
266
  total_reward = max(total_reward, -1.0)
267
 
268
  self._state.done = True
@@ -270,7 +297,8 @@ class SkillInvocationEnvironment(Environment):
270
  f"{'CORRECT' if task_correct else 'INCORRECT'}. "
271
  f"Reward: correctness={correctness:.2f}, "
272
  f"precision={precision_bonus:.2f}, recall={recall_bonus:.2f}, "
273
- f"bloat={bloat_penalty:.2f}, total={total_reward:.2f}"
 
274
  )
275
  self._messages.append(f"Submitted answer. {verification_msg}")
276
 
 
3
 
4
  Trains LLMs to decide WHEN to invoke procedural knowledge (skills) during
5
  task-solving. Context cost model: each loaded skill costs context budget.
6
+
7
+ Reward has two distinct cost signals:
8
+ - Context hygiene (bloat_penalty): penalizes irrelevant skills still loaded at
9
+ submit time (-0.15 per skill).
10
+ - Token efficiency (token_waste_penalty): penalizes skills that were ever loaded
11
+ but turned out to be irrelevant, even if unloaded before submission (-0.05 per
12
+ skill). This captures cumulative token waste across the episode.
13
 
14
  Actions: list, load, unload, submit (plus "invoke" as backward-compat alias for load).
15
  """
 
37
  1. reset() samples a task, assembles skill catalog (relevant + distractors)
38
  2. Agent can list, load, and unload skills (within context budget)
39
  3. Agent submits a solution
40
+ 4. Reward = correctness + precision + recall - bloat - token_waste
41
  """
42
 
43
  SUPPORTS_CONCURRENT_SESSIONS: bool = True
 
57
  self._task_generator = TaskGenerator(seed=procedural_seed) if use_procedural else None
58
  self._episode_skills: dict = {}
59
  self._context_budget = context_budget
60
+ # Per-instance RNG to avoid mutating global random state (concurrency-safe)
61
+ self._rng = random.Random()
62
 
63
  def reset(
64
  self,
 
67
  **kwargs,
68
  ) -> SkillInvocationObservation:
69
  """Sample a random task and assemble the skill catalog."""
70
+ # Use a local RNG instance to avoid mutating global random state.
71
+ # This is concurrency-safe: parallel rollouts won't clobber each other's seeds.
72
  if seed is not None:
73
+ self._rng = random.Random(seed)
74
+ else:
75
+ self._rng = random.Random()
76
 
77
  if self._use_procedural and self._task_generator:
78
+ gen_seed = seed if seed is not None else self._rng.randint(0, 2**31)
79
  result = self._task_generator.generate_with_seed(gen_seed)
80
  task = result["task"]
81
  self._episode_skills = result["skills"]
82
  else:
83
+ task = self._rng.choice(TASK_BANK)
84
  self._episode_skills = SKILL_BANK
85
 
86
  self._current_task = task
87
 
88
  # Build catalog: relevant + distractor skills, shuffled
89
  catalog_ids = list(task["relevant_skills"]) + list(task["distractor_skills"])
90
+ self._rng.shuffle(catalog_ids)
91
  self._catalog_skill_ids = catalog_ids
92
 
93
  # Build catalog descriptions (short only, no full content)
 
239
  return self._make_observation(skill_content=None, reward=0.0, done=False)
240
 
241
  def _handle_submit(self, action: SkillInvocationAction) -> SkillInvocationObservation:
242
+ """Handle a solution submission.
243
+
244
+ Reward = correctness + precision + recall - bloat - token_waste.
245
+
246
+ Two distinct cost signals:
247
+ - bloat_penalty (-0.15 per skill): penalizes irrelevant skills still
248
+ loaded at submit time (context hygiene).
249
+ - token_waste_penalty (-0.05 per skill): penalizes skills that were ever
250
+ loaded but turned out irrelevant, capturing cumulative token waste
251
+ across the episode (token efficiency).
252
+ """
253
  answer = action.answer or ""
254
  task = self._current_task
255
 
 
261
 
262
  # Compute reward
263
  loaded = set(self._state.loaded_skills)
264
+ ever_loaded = set(self._state.skills_ever_loaded)
265
  relevant = set(task["relevant_skills"])
266
 
267
  # 1. Correctness: +0.6
 
285
  unnecessary = loaded - relevant
286
  bloat_penalty = -0.15 * len(unnecessary)
287
 
288
+ # 5. Token waste: penalty for skills ever loaded that were irrelevant
289
+ wasted = ever_loaded - relevant
290
+ token_waste_penalty = -0.05 * len(wasted)
291
+
292
+ total_reward = correctness + precision_bonus + recall_bonus + bloat_penalty + token_waste_penalty
293
  total_reward = max(total_reward, -1.0)
294
 
295
  self._state.done = True
 
297
  f"{'CORRECT' if task_correct else 'INCORRECT'}. "
298
  f"Reward: correctness={correctness:.2f}, "
299
  f"precision={precision_bonus:.2f}, recall={recall_bonus:.2f}, "
300
+ f"bloat={bloat_penalty:.2f}, token_waste={token_waste_penalty:.2f}, "
301
+ f"total={total_reward:.2f}"
302
  )
303
  self._messages.append(f"Submitted answer. {verification_msg}")
304
 
skill_invocation_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: skill_invocation_env
3
+ Version: 0.1.0
4
+ Summary: OpenEnv RL environment for training LLMs to invoke procedural knowledge (skills) during task-solving
5
+ Requires-Python: >=3.10
6
+ Requires-Dist: openenv-core[core]>=0.2.1
7
+ Requires-Dist: pydantic>=2.0
8
+ Provides-Extra: dev
9
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
10
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
skill_invocation_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ README.md
2
+ pyproject.toml
3
+ ./__init__.py
4
+ ./client.py
5
+ ./models.py
6
+ ./task_bank.py
7
+ ./task_generator.py
8
+ ./test_env.py
9
+ ./train_demo.py
10
+ server/__init__.py
11
+ server/app.py
12
+ server/skill_invocation_env_environment.py
13
+ skill_invocation_env.egg-info/PKG-INFO
14
+ skill_invocation_env.egg-info/SOURCES.txt
15
+ skill_invocation_env.egg-info/dependency_links.txt
16
+ skill_invocation_env.egg-info/entry_points.txt
17
+ skill_invocation_env.egg-info/requires.txt
18
+ skill_invocation_env.egg-info/top_level.txt
skill_invocation_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
skill_invocation_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [console_scripts]
2
+ server = skill_invocation_env.server.app:main
skill_invocation_env.egg-info/requires.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ openenv-core[core]>=0.2.1
2
+ pydantic>=2.0
3
+
4
+ [dev]
5
+ pytest>=8.0.0
6
+ pytest-cov>=4.0.0
skill_invocation_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ skill_invocation_env
test_schema.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
"""Scratch check: how pydantic renders json_schema_extra into a model's JSON schema."""
import json
from typing import Optional

from pydantic import BaseModel, Field


class A(BaseModel):
    # Optional answer field; json_schema_extra overrides the emitted schema keys.
    answer: Optional[str] = Field(None, json_schema_extra={"type": "string", "maxLength": 100000})


schema = A.model_json_schema()
print(json.dumps(schema, indent=2))
test_schema2.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Scratch check: same as test_schema.py but with an explicit default= keyword."""
import json
from typing import Optional

from pydantic import BaseModel, Field


class A(BaseModel):
    # Explicit default=None; json_schema_extra overrides the emitted schema keys.
    answer: Optional[str] = Field(
        default=None,
        json_schema_extra={"type": "string", "maxLength": 100000},
    )


schema = A.model_json_schema()
print(json.dumps(schema, indent=2))
train_demo.py CHANGED
@@ -1,20 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import re
2
  import os
3
- import torch
4
  from datasets import Dataset
5
  from trl import GRPOConfig, GRPOTrainer
 
6
  from transformers import AutoTokenizer
 
7
 
8
  from skill_invocation_env.client import SkillInvocationEnv
9
  from skill_invocation_env.models import SkillInvocationAction
10
 
11
- # Configuration
12
- # Use 3B or 7B Qwen2.5 Coder. 3B fits very comfortably with batching on an H100.
13
- MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
14
- ENV_URL = "https://mpnikhil-skill-invocation-env.hf.space"
15
  HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
16
 
17
- SYSTEM_PROMPT = """You are an expert AI software engineer. You will be given a task and a catalog of available skills (procedural knowledge).
 
18
  You must decide which skills to load to help you solve the task, and then submit your final answer.
19
 
20
  You must interact by outputting EXACTLY ONE of the following XML actions per turn:
@@ -27,20 +49,19 @@ You must interact by outputting EXACTLY ONE of the following XML actions per tur
27
 
28
  3. To submit your final solution:
29
  <action type="submit">
30
- def your_code_here():
31
- pass
32
  </action>
33
 
34
- Always think step-by-step before outputting an action.
35
- """
36
 
37
  def parse_action(text: str) -> SkillInvocationAction:
38
- """Parses the LLM's text output into a Pydantic Action object."""
39
- load_match = re.search(r'<action\s+type="load"\s+skill_id="([^\"]+)"\s*/>', text)
40
  if load_match:
41
  return SkillInvocationAction(action_type="load", skill_id=load_match.group(1))
42
-
43
- unload_match = re.search(r'<action\s+type="unload"\s+skill_id="([^\"]+)"\s*/>', text)
44
  if unload_match:
45
  return SkillInvocationAction(action_type="unload", skill_id=unload_match.group(1))
46
 
@@ -48,111 +69,335 @@ def parse_action(text: str) -> SkillInvocationAction:
48
  if submit_match:
49
  return SkillInvocationAction(action_type="submit", answer=submit_match.group(1).strip())
50
 
51
- # Fallback if the model fails to follow format
52
  return SkillInvocationAction(action_type="submit", answer=text)
53
 
54
 
55
  def format_observation(obs) -> str:
56
- """Formats the Pydantic observation into a string for the LLM."""
57
- prompt = f"TASK: {obs.task_description}\n\nSKILL CATALOG:\n"
58
  for s in obs.skill_catalog:
59
- prompt += f"- [{s['id']}] {s['name']}: {s['description']}\n"
60
-
 
 
 
61
  if obs.skill_content:
62
- prompt += f"\nJUST LOADED SKILL CONTENT:\n{obs.skill_content}\n"
63
-
64
- prompt += f"\nBUDGET USED: {obs.context_budget_used} / {obs.context_budget_total}"
65
- return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
 
 
67
 
68
- def rollout_func(prompts: list[str], trainer: GRPOTrainer):
 
 
 
 
 
 
 
 
 
 
 
69
  """
70
- Custom rollout function that handles multi-step interaction with the OpenEnv Space.
 
 
 
 
 
 
 
 
71
  """
72
- # 1. Setup clients for this batch
73
- clients = [SkillInvocationEnv(base_url=ENV_URL) for _ in range(len(prompts))]
74
- active_episodes = [True] * len(prompts)
75
-
76
- # Initialize histories
77
- histories = []
78
- for _ in prompts:
79
- histories.append([{"role": "system", "content": SYSTEM_PROMPT}])
80
-
81
- # Start environments
82
- for i, client in enumerate(clients):
83
- res = client.reset()
84
- histories[i].append({"role": "user", "content": format_observation(res.observation)})
85
-
86
- # Multi-step generation loop (Max 4 turns: e.g., load, load, submit)
87
- MAX_TURNS = 4
88
- tokenizer = trainer.processing_class
89
- all_rewards = [0.0] * len(prompts)
90
-
 
 
 
 
91
  for turn in range(MAX_TURNS):
92
- active_indices = [i for i, active in enumerate(active_episodes) if active]
93
- if not active_indices:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  break
95
-
96
- # Format active prompts for vLLM
97
- active_prompts = [tokenizer.apply_chat_template(histories[i], tokenize=False, add_generation_prompt=True) for i in active_indices]
98
-
99
- # Generate completions
100
- outputs = trainer.generate(active_prompts, max_new_tokens=512)
101
- completions = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
102
-
103
- # Step environments
104
- for idx, completion in zip(active_indices, completions):
105
- histories[idx].append({"role": "assistant", "content": completion})
106
- action = parse_action(completion)
107
-
108
- try:
109
- res = clients[idx].step(action)
110
- if res.done:
111
- active_episodes[idx] = False
112
- all_rewards[idx] = res.reward
113
- else:
114
- histories[idx].append({"role": "user", "content": format_observation(res.observation)})
115
- except Exception as e:
116
- # Penalty for formatting errors or invalid actions
117
- active_episodes[idx] = False
118
- all_rewards[idx] = -1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  return {
 
 
 
121
  "env_reward": all_rewards,
122
  }
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def reward_from_env(completions, **kwargs):
126
- """Callback for TRL to fetch the rewards computed during the rollout."""
127
- return kwargs.get("env_reward", [0.0] * len(completions))
 
 
 
 
 
128
 
 
129
 
130
  if __name__ == "__main__":
131
- print(f"Starting GRPO Training on H100 with {MODEL_ID}...")
132
-
133
- # Create dummy dataset (the rollout_func overrides the prompt anyway by calling env.reset())
134
- dummy_dataset = Dataset.from_dict({"prompt": ["Start"] * 64})
135
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  training_args = GRPOConfig(
 
137
  use_vllm=True,
138
- vllm_mode="colocate", # Runs vLLM and PyTorch on the same H100 GPU!
 
139
  num_train_epochs=1,
140
- num_generations=8, # How many rollout trajectories to try per prompt
141
- max_completion_length=1024,
142
  per_device_train_batch_size=8,
 
 
143
  logging_steps=1,
144
- output_dir="./outputs/qwen-skill-env",
 
 
 
 
 
 
 
 
 
145
  )
146
 
147
  trainer = GRPOTrainer(
148
  model=MODEL_ID,
149
- reward_funcs=[reward_from_env],
150
- train_dataset=dummy_dataset,
151
  rollout_func=rollout_func,
152
  args=training_args,
 
153
  )
154
-
155
  trainer.train()
156
-
157
  print("Training complete! Pushing to hub...")
158
- trainer.push_to_hub("mpnikhil/Qwen2.5-3B-Skill-Invocation", token=HF_TOKEN)
 
 
 
 
 
 
1
+ """
2
+ GRPO Training for Skill Invocation Environment.
3
+
4
+ Trains a model to decide which skills to load/unload before submitting a solution.
5
+ Uses TRL's GRPOTrainer with a custom multi-turn rollout that interacts with the
6
+ Skill Invocation Environment hosted on HF Spaces.
7
+
8
+ Run on Northflank with an A100/H100 GPU:
9
+ python train_demo.py
10
+ """
11
+
12
+ import hashlib
13
  import re
14
  import os
15
+
16
  from datasets import Dataset
17
  from trl import GRPOConfig, GRPOTrainer
18
+ from trl.experimental.openenv import generate_rollout_completions
19
  from transformers import AutoTokenizer
20
+ from peft import LoraConfig
21
 
22
  from skill_invocation_env.client import SkillInvocationEnv
23
  from skill_invocation_env.models import SkillInvocationAction
24
 
25
+ # ── Configuration ──────────────────────────────────────────────────────────────
26
+ MODEL_ID = os.getenv("MODEL_ID", "Qwen/Qwen2.5-7B-Instruct")
27
+ ENV_URL = os.getenv("ENV_URL", "https://mpnikhil-skill-invocation-env.hf.space")
 
28
  HF_TOKEN = os.getenv("HF_TOKEN")
29
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs/qwen-skill-env")
30
+ HUB_REPO = os.getenv("HUB_REPO", "mpnikhil/Qwen2.5-7B-Skill-Invocation")
31
+ NUM_EPISODES = int(os.getenv("NUM_EPISODES", "128"))
32
+ # Default 8 turns gives headroom to explore: load-inspect-unload-reload cycles
33
+ # beyond the minimum path of num_relevant_skills + 1 (submit) turns.
34
+ MAX_TURNS = int(os.getenv("MAX_TURNS", "8"))
35
+ NUM_GENERATIONS = int(os.getenv("NUM_GENERATIONS", "8"))
36
+ MAX_COMPLETION_LENGTH = int(os.getenv("MAX_COMPLETION_LENGTH", "1024"))
37
 
38
+ SYSTEM_PROMPT = """\
39
+ You are an expert AI software engineer. You will be given a task and a catalog of available skills (procedural knowledge).
40
  You must decide which skills to load to help you solve the task, and then submit your final answer.
41
 
42
  You must interact by outputting EXACTLY ONE of the following XML actions per turn:
 
49
 
50
  3. To submit your final solution:
51
  <action type="submit">
52
+ your solution here
 
53
  </action>
54
 
55
+ Always think step-by-step before outputting an action."""
56
+
57
 
58
  def parse_action(text: str) -> SkillInvocationAction:
59
+ """Parses the LLM's text output into a SkillInvocationAction."""
60
+ load_match = re.search(r'<action\s+type="load"\s+skill_id="([^"]+)"\s*/>', text)
61
  if load_match:
62
  return SkillInvocationAction(action_type="load", skill_id=load_match.group(1))
63
+
64
+ unload_match = re.search(r'<action\s+type="unload"\s+skill_id="([^"]+)"\s*/>', text)
65
  if unload_match:
66
  return SkillInvocationAction(action_type="unload", skill_id=unload_match.group(1))
67
 
 
69
  if submit_match:
70
  return SkillInvocationAction(action_type="submit", answer=submit_match.group(1).strip())
71
 
72
+ # Fallback: treat entire output as submission
73
  return SkillInvocationAction(action_type="submit", answer=text)
74
 
75
 
76
def format_observation(obs) -> str:
    """Render an environment observation as the next user-turn prompt string."""
    lines = [f"TASK: {obs.task_description}\n\nSKILL CATALOG:"]
    lines.extend(f"- [{s['id']}] {s['name']}: {s['description']}" for s in obs.skill_catalog)

    if obs.loaded_skills:
        lines.append(f"\nCURRENTLY LOADED SKILLS: {', '.join(obs.loaded_skills)}")

    if obs.skill_content:
        lines.append(f"\nJUST LOADED SKILL CONTENT:\n{obs.skill_content}")

    # Surface every currently-loaded skill body so the model doesn't have to
    # rely on conversation history alone to recall earlier loads.
    if obs.loaded_skill_contents:
        just_loaded = None
        if obs.skill_content:
            # Match the just-loaded skill by content so we don't print it twice.
            just_loaded = next(
                (sid for sid, body in obs.loaded_skill_contents.items()
                 if body == obs.skill_content),
                None,
            )
        remaining = {
            sid: body
            for sid, body in obs.loaded_skill_contents.items()
            if sid != just_loaded
        }
        if remaining:
            lines.append("\nOTHER LOADED SKILL CONTENTS:")
            lines.extend(f"\n[{sid}]:\n{body}" for sid, body in remaining.items())

    if obs.verification_result:
        lines.append(f"\nVERIFICATION: {obs.verification_result}")

    if obs.messages:
        lines.append(f"\nSTATUS: {obs.messages[-1]}")

    lines.append(f"\nBUDGET USED: {obs.context_budget_used} / {obs.context_budget_total}")
    return "\n".join(lines)
116
+
117
+
118
+ # ── Multi-turn rollout ─────────────────────────────────────────────────────────
119
+
120
def rollout_once(
    trainer: GRPOTrainer,
    env: SkillInvocationEnv,
    tokenizer: AutoTokenizer,
    env_seed: int,
) -> dict:
    """
    Run one multi-turn episode against the Skill Invocation Environment.

    Args:
        trainer: GRPOTrainer whose vLLM backend generates completions.
        env: Client for the remote skill-invocation environment.
        tokenizer: Tokenizer used for chat templating and token accounting.
        env_seed: Deterministic seed passed to env.reset() so all generations
            within a GRPO group face the identical task.

    Returns dict with prompt_ids, completion_ids, logprobs, and env_reward.
    Accumulates tokens across ALL turns so GRPO can assign credit to every
    decision (load, unload, submit).
    """
    result = env.reset(seed=env_seed)
    obs = result.observation

    # Token accumulation across turns:
    # - prompt_ids: first turn's full prompt (system + initial observation)
    # - completion_ids: all model generations + env feedback tokens interleaved
    # - logprobs: real logprobs for model tokens, 0.0 for env feedback tokens
    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    env_reward: float = 0.0
    generated_any: bool = False

    # Tracks how many tokens we've already accounted for across turns.
    # Each turn's prompt_ids from apply_chat_template contains the FULL
    # conversation so far (quadratic growth). We only append the delta —
    # the new tokens since the last turn — to keep accounting linear.
    prev_total_len = 0

    # Conversation history — the model sees its full interaction so far,
    # so it can recall what it read in a loaded skill and decide to unload.
    conversation = [{"role": "system", "content": SYSTEM_PROMPT}]

    for turn in range(MAX_TURNS):
        if result.done:
            break

        # Append new observation to conversation history
        user_content = format_observation(obs)
        conversation.append({"role": "user", "content": user_content})

        prompt_text = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False,
        )

        # Safety check: prevent vLLM context length errors. Qwen3-8B has a
        # 32,768 token context window; leave room for MAX_COMPLETION_LENGTH.
        # NOTE(review): this comment cites Qwen3-8B but MODEL_ID defaults to
        # Qwen2.5-7B-Instruct — confirm the 31,000-token cap fits the actual model.
        prompt_token_count = len(tokenizer.encode(prompt_text, add_special_tokens=False))
        if prompt_token_count > 31_000:
            print(f" [rollout] prompt too long ({prompt_token_count} tokens), breaking early")
            env_reward = -0.5
            break

        # Generate using TRL's vLLM helper
        rollout_outputs = generate_rollout_completions(trainer, [prompt_text])[0]
        generated_any = True

        new_prompt_ids = rollout_outputs["prompt_ids"]

        if turn == 0:
            # First turn: store the full prompt
            prompt_ids.extend(new_prompt_ids)
            prev_total_len = len(new_prompt_ids)
        else:
            # Later turns: only append the delta (new env feedback tokens
            # beyond what we've already tracked). These get zeroed-out
            # logprobs since they're env-generated, not model-generated.
            delta_ids = new_prompt_ids[prev_total_len:]
            completion_ids.extend(delta_ids)
            logprobs.extend([0.0] * len(delta_ids))

        # Append the model's generation tokens (these get real logprobs)
        completion_ids.extend(rollout_outputs["completion_ids"])
        logprobs.extend(rollout_outputs["logprobs"])

        # Update running total: everything up to and including this turn's completion
        prev_total_len = len(new_prompt_ids) + len(rollout_outputs["completion_ids"])

        # Prefer the backend-provided decoded text; fall back to decoding ids.
        completion_text = rollout_outputs.get("text") or tokenizer.decode(
            rollout_outputs["completion_ids"], skip_special_tokens=True,
        )

        # Add the model's response to conversation history
        conversation.append({"role": "assistant", "content": completion_text})

        # Parse action and step the environment
        action = parse_action(completion_text)

        try:
            result = env.step(action)
            obs = result.observation
            if result.done:
                env_reward = float(result.reward or 0.0)
        except Exception as e:
            # Env/transport failure ends the episode with a fixed penalty.
            print(f" [rollout] env.step error: {e}")
            env_reward = -1.0
            break

    # If we ran out of turns without submitting, penalize
    if not result.done:
        env_reward = -0.5

    # Fallback if no generation happened (e.g. env.reset() returned done=True)
    if not generated_any:
        dummy_ids = tokenizer.encode("error", add_special_tokens=False)
        prompt_ids = dummy_ids
        completion_ids = list(dummy_ids)
        logprobs = [0.0] * len(dummy_ids)

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "env_reward": env_reward,
    }
242
+
243
+
244
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
    """
    Custom rollout function for GRPOTrainer.

    GRPO groups: prompts arrive as [p0, p0, p0, ..., p1, p1, p1, ...] where
    each prompt is repeated num_generations times. All rollouts for the same
    prompt must face the same task, so we extract the seed from the prompt text
    and pass it to env.reset(seed=...).

    Returns per-episode prompt_ids, completion_ids, logprobs, and env_reward
    lists, all parallel to `prompts`.
    """
    # Fix: `wandb` is referenced below but was never imported at module level,
    # raising NameError at the end of every rollout batch. Import lazily and
    # degrade gracefully when wandb isn't installed.
    try:
        import wandb
    except ImportError:
        wandb = None

    tokenizer = trainer.processing_class

    all_prompt_ids = []
    all_completion_ids = []
    all_logprobs = []
    all_rewards = []
    rewards_received = 0

    for i, prompt_text in enumerate(prompts):
        # Extract seed from the prompt — format is "seed:<N> ..."
        # This ensures all K generations for the same prompt get the same task.
        seed = _extract_seed(prompt_text)

        env = SkillInvocationEnv(base_url=ENV_URL, connect_timeout_s=60)
        episode = rollout_once(
            trainer=trainer,
            env=env,
            tokenizer=tokenizer,
            env_seed=seed,
        )
        all_prompt_ids.append(episode["prompt_ids"])
        all_completion_ids.append(episode["completion_ids"])
        all_logprobs.append(episode["logprobs"])
        all_rewards.append(episode["env_reward"])

        if episode["env_reward"] != 0.0:
            rewards_received += 1

        if (i + 1) % 10 == 0:
            avg_r = sum(all_rewards) / len(all_rewards)
            print(f" [rollout] {i+1}/{len(prompts)} episodes, avg reward: {avg_r:.3f}")

    # Guard: verify rewards actually flowed through
    if rewards_received == 0 and len(prompts) > 0:
        print(" [WARNING] All rewards are 0.0 — check env connectivity!")

    # Log rollout stats to wandb (no-op when wandb is missing or inactive)
    if wandb is not None and wandb.run is not None:
        avg_reward = sum(all_rewards) / len(all_rewards) if all_rewards else 0.0
        positive = sum(1 for r in all_rewards if r > 0)
        negative = sum(1 for r in all_rewards if r < 0)
        wandb.log({
            "rollout/avg_reward": avg_reward,
            "rollout/max_reward": max(all_rewards) if all_rewards else 0.0,
            "rollout/min_reward": min(all_rewards) if all_rewards else 0.0,
            "rollout/positive_pct": positive / len(all_rewards) * 100 if all_rewards else 0.0,
            "rollout/negative_pct": negative / len(all_rewards) * 100 if all_rewards else 0.0,
            "rollout/num_episodes": len(all_rewards),
        })

    return {
        "prompt_ids": all_prompt_ids,
        "completion_ids": all_completion_ids,
        "logprobs": all_logprobs,
        "env_reward": all_rewards,
    }
309
 
310
 
311
+ def _extract_seed(prompt_text: str) -> int:
312
+ """Extract the env seed from a prompt like 'seed:42 ...'
313
+
314
+ Crashes loudly on malformed prompts rather than silently producing
315
+ non-deterministic seeds (Python's hash() is randomized across processes).
316
+ """
317
+ match = re.match(r"seed:(\d+)", prompt_text)
318
+ if match:
319
+ return int(match.group(1))
320
+ # Deterministic fallback using SHA-256 (stable across processes, unlike hash())
321
+ digest = hashlib.sha256(prompt_text.encode()).hexdigest()
322
+ return int(digest[:8], 16) % (2**31)
323
+
324
+
325
def reward_from_env(completions, **kwargs):
    """Return the per-episode environment rewards forwarded by rollout_func."""
    rewards = kwargs.get("env_reward", [])
    if rewards:
        return [float(r) for r in rewards]
    # Nothing arrived from the rollout — warn and fall back to zeros.
    print(" [WARNING] reward_from_env received no env_reward in kwargs!")
    return [0.0] * len(completions)
332
+
333
 
334
+ # ── Main ───────────────────────────────────────────────────────────────────────
335
 
336
if __name__ == "__main__":
    print(f"Starting GRPO Training with {MODEL_ID}")
    print(f"Environment: {ENV_URL}")
    print(f"Episodes: {NUM_EPISODES}, Generations per episode: {NUM_GENERATIONS}")

    # Fix: `wandb` was used here without any module-level import, so startup
    # crashed with NameError. Import lazily and run without logging when the
    # package is absent.
    try:
        import wandb
    except ImportError:
        wandb = None

    if wandb is not None:
        wandb.init(
            project="skill-invocation-env",
            name=f"grpo-{MODEL_ID.split('/')[-1]}-ep{NUM_EPISODES}",
            config={
                "model_id": MODEL_ID,
                "env_url": ENV_URL,
                "num_episodes": NUM_EPISODES,
                "num_generations": NUM_GENERATIONS,
                "max_completion_length": MAX_COMPLETION_LENGTH,
                "max_turns": MAX_TURNS,
                "learning_rate": 1e-6,
                "lora_r": 16,
            },
        )

    # Each unique prompt = one GRPO group = one task (via seed).
    # GRPO will expand each prompt to num_generations rollouts internally.
    # All rollouts for the same seed face the same task → valid advantage computation.
    prompts = [f"seed:{i} Solve the coding task by loading the right skills." for i in range(NUM_EPISODES)]
    dataset = Dataset.from_dict({"prompt": prompts})

    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.6,
        num_train_epochs=1,
        num_generations=NUM_GENERATIONS,
        max_completion_length=MAX_COMPLETION_LENGTH,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=4,
        learning_rate=1e-6,
        logging_steps=1,
        save_steps=50,
        loss_type="grpo",
        # Only report to wandb when it is actually importable.
        report_to="wandb" if wandb is not None else "none",
    )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )

    trainer = GRPOTrainer(
        model=MODEL_ID,
        reward_funcs=reward_from_env,
        train_dataset=dataset,
        rollout_func=rollout_func,
        args=training_args,
        peft_config=peft_config,
    )

    trainer.train()

    print("Training complete! Pushing to hub...")
    if HF_TOKEN:
        trainer.push_to_hub(HUB_REPO, token=HF_TOKEN)
        print(f"Model pushed to https://huggingface.co/{HUB_REPO}")
    else:
        print("HF_TOKEN not set, skipping push. Model saved locally.")
        trainer.save_model(OUTPUT_DIR)