mpnikhil commited on
Commit
f678c99
·
verified ·
1 Parent(s): ac627d5

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. models.py +1 -1
  2. train_demo.py +147 -112
models.py CHANGED
@@ -40,7 +40,7 @@ class SkillInvocationAction(Action):
40
  default=None, description='Skill ID (required for load/unload)'
41
  )
42
  answer: Optional[str] = Field(
43
- default=None, description='Solution text (required for submit)'
44
  )
45
 
46
 
 
40
  default=None, description='Skill ID (required for load/unload)'
41
  )
42
  answer: Optional[str] = Field(
43
+ default=None, description='Solution text (required for submit)', max_length=100000
44
  )
45
 
46
 
train_demo.py CHANGED
@@ -1,123 +1,158 @@
1
- #!/usr/bin/env python3
2
- """
3
- Minimal TRL + OpenEnv integration demo for the Skill Invocation Environment.
 
 
 
4
 
5
- This script demonstrates how to connect to the environment and run episodes.
6
- It can be run in Google Colab with Unsloth for actual RL training.
7
 
8
- Setup (Colab):
9
- !pip install unsloth openenv-core trl
10
- !pip install skill_invocation_env # or install from local
 
 
11
 
12
- Usage:
13
- # Against a local server:
14
- python train_demo.py --base-url http://localhost:8000
15
 
16
- # Against a HuggingFace Space:
17
- python train_demo.py --base-url https://YOUR-SPACE.hf.space
18
- """
19
 
20
- import sys
21
- import os
22
 
23
- # For local testing without server, use direct environment
24
- sys.path.insert(0, os.path.dirname(__file__))
25
-
26
-
27
- def demo_direct():
28
- """Demo using the environment directly (no server needed)."""
29
- from models import SkillInvocationAction
30
- from server.skill_invocation_env_environment import SkillInvocationEnvironment
31
-
32
- print("=== Direct Environment Demo ===\n")
33
-
34
- env = SkillInvocationEnvironment()
35
-
36
- # Run 3 episodes
37
- for episode in range(3):
38
- obs = env.reset(seed=episode)
39
- print(f"--- Episode {episode + 1} ---")
40
- print(f"Task: {obs.task_description[:100]}...")
41
- print(f"Difficulty: {obs.difficulty}")
42
- print(f"Skills available: {[s['name'] for s in obs.skill_catalog]}")
43
- print(f"Context budget: {obs.context_budget_used}/{obs.context_budget_total}")
44
-
45
- # Strategy: load the first skill in catalog
46
- if obs.skill_catalog:
47
- skill = obs.skill_catalog[0]
48
- print(f"\nLoading skill: {skill['name']} ({skill['id']})")
49
- obs = env.step(SkillInvocationAction(
50
- action_type="load",
51
- skill_id=skill["id"],
52
- ))
53
- if obs.skill_content:
54
- print(f"Got skill content ({len(obs.skill_content)} chars)")
55
- print(f"Preview: {obs.skill_content[:150]}...")
56
- print(f"Context: {obs.context_budget_used}/{obs.context_budget_total}")
57
-
58
- # Submit a dummy answer
59
- print("\nSubmitting answer...")
60
- obs = env.step(SkillInvocationAction(
61
- action_type="submit",
62
- answer="This is a placeholder answer for demonstration.",
63
- ))
64
- print(f"Done: {obs.done}")
65
- print(f"Reward: {obs.reward}")
66
- print(f"Verification: {obs.verification_result}")
67
- print()
68
-
69
- print("Demo complete!")
70
-
71
-
72
- def demo_client(base_url: str):
73
- """Demo using the WebSocket client against a running server."""
74
- from client import SkillInvocationEnv
75
- from models import SkillInvocationAction
76
-
77
- print(f"=== Client Demo (connecting to {base_url}) ===\n")
78
-
79
- with SkillInvocationEnv(base_url=base_url) as client:
80
- # Reset
81
- result = client.reset()
82
- obs = result.observation
83
- print(f"Task: {obs.task_description[:100]}...")
84
- print(f"Skills available: {[s['name'] for s in obs.skill_catalog]}")
85
-
86
- # Load first skill
87
- if obs.skill_catalog:
88
- skill = obs.skill_catalog[0]
89
- result = client.step(SkillInvocationAction(
90
- action_type="load",
91
- skill_id=skill["id"],
92
- ))
93
- print(f"\nLoaded '{skill['name']}'")
94
- if result.observation.skill_content:
95
- print(f"Content preview: {result.observation.skill_content[:200]}...")
96
-
97
- # Submit
98
- result = client.step(SkillInvocationAction(
99
- action_type="submit",
100
- answer="test answer",
101
- ))
102
- print(f"\nReward: {result.reward}")
103
- print(f"Done: {result.done}")
104
- print(f"Verification: {result.observation.verification_result}")
105
-
106
- print("\nClient demo complete!")
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- if __name__ == "__main__":
110
- import argparse
111
 
112
- parser = argparse.ArgumentParser(description="Skill Invocation Env Demo")
113
- parser.add_argument(
114
- "--base-url",
115
- default=None,
116
- help="Server URL (if not provided, runs directly without server)",
 
 
 
 
 
 
 
 
 
 
117
  )
118
- args = parser.parse_args()
119
 
120
- if args.base_url:
121
- demo_client(args.base_url)
122
- else:
123
- demo_direct()
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import torch
4
+ from datasets import Dataset
5
+ from trl import GRPOConfig, GRPOTrainer
6
+ from transformers import AutoTokenizer
7
 
8
+ from skill_invocation_env.client import SkillInvocationEnv
9
+ from skill_invocation_env.models import SkillInvocationAction
10
 
11
+ # Configuration
12
+ # Use 3B or 7B Qwen2.5 Coder. 3B fits very comfortably with batching on an H100.
13
+ MODEL_ID = "Qwen/Qwen2.5-Coder-3B-Instruct"
14
+ ENV_URL = "https://mpnikhil-skill-invocation-env.hf.space"
15
+ HF_TOKEN = os.getenv("HF_TOKEN")
16
 
17
+ SYSTEM_PROMPT = """You are an expert AI software engineer. You will be given a task and a catalog of available skills (procedural knowledge).
18
+ You must decide which skills to load to help you solve the task, and then submit your final answer.
 
19
 
20
+ You must interact by outputting EXACTLY ONE of the following XML actions per turn:
 
 
21
 
22
+ 1. To load a skill to read its contents (costs context budget):
23
+ <action type="load" skill_id="skill_01"/>
24
 
25
+ 2. To unload a skill if it is not useful (frees context budget):
26
+ <action type="unload" skill_id="skill_01"/>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ 3. To submit your final solution:
29
+ <action type="submit">
30
+ def your_code_here():
31
+ pass
32
+ </action>
33
+
34
+ Always think step-by-step before outputting an action.
35
+ """
36
+
37
def parse_action(text: str) -> SkillInvocationAction:
    """Parse the LLM's raw text output into a SkillInvocationAction.

    Recognizes the three XML action forms described in SYSTEM_PROMPT
    (load / unload / submit). If none of them matches, the whole text is
    submitted verbatim as the answer.
    """
    if (m := re.search(r'<action\s+type="load"\s+skill_id="([^\"]+)"\s*/>', text)):
        return SkillInvocationAction(action_type="load", skill_id=m.group(1))

    if (m := re.search(r'<action\s+type="unload"\s+skill_id="([^\"]+)"\s*/>', text)):
        return SkillInvocationAction(action_type="unload", skill_id=m.group(1))

    if (m := re.search(r'<action\s+type="submit">(.*?)</action>', text, re.DOTALL)):
        return SkillInvocationAction(action_type="submit", answer=m.group(1).strip())

    # The model ignored the required format: treat its whole output as the answer.
    return SkillInvocationAction(action_type="submit", answer=text)
53
+
54
+
55
def format_observation(obs) -> str:
    """Render an environment observation as the user-turn text for the LLM.

    Includes the task description, the skill catalog, any freshly loaded
    skill content, and the current context-budget usage.
    """
    catalog = "".join(
        f"- [{s['id']}] {s['name']}: {s['description']}\n" for s in obs.skill_catalog
    )
    text = f"TASK: {obs.task_description}\n\nSKILL CATALOG:\n{catalog}"

    if obs.skill_content:
        text += f"\nJUST LOADED SKILL CONTENT:\n{obs.skill_content}\n"

    return text + f"\nBUDGET USED: {obs.context_budget_used} / {obs.context_budget_total}"
66
+
67
+
68
def rollout_func(prompts: list[str], trainer: GRPOTrainer):
    """Custom rollout: play multi-step episodes against the OpenEnv Space.

    For each prompt one environment episode is run for up to MAX_TURNS turns
    (e.g. load, load, submit). Returns {"env_reward": [...]} with one reward
    per prompt, which reward_from_env later relays to TRL.

    Fixes over the original: the per-prompt WebSocket clients are now always
    released in a finally block (they were previously leaked), and the unused
    exception binding is gone.
    """
    # 1. One environment client per prompt in the batch.
    clients = [SkillInvocationEnv(base_url=ENV_URL) for _ in range(len(prompts))]
    try:
        active_episodes = [True] * len(prompts)

        # Every conversation starts from the shared system prompt.
        histories = [[{"role": "system", "content": SYSTEM_PROMPT}] for _ in prompts]

        # Reset each environment and show the first observation to the model.
        for i, client in enumerate(clients):
            res = client.reset()
            histories[i].append({"role": "user", "content": format_observation(res.observation)})

        # Multi-step generation loop (max 4 turns: e.g., load, load, submit).
        MAX_TURNS = 4
        tokenizer = trainer.processing_class
        all_rewards = [0.0] * len(prompts)

        for _turn in range(MAX_TURNS):
            active_indices = [i for i, active in enumerate(active_episodes) if active]
            if not active_indices:
                break

            # Format the still-active conversations for vLLM.
            active_prompts = [
                tokenizer.apply_chat_template(histories[i], tokenize=False, add_generation_prompt=True)
                for i in active_indices
            ]

            # Generate completions.
            outputs = trainer.generate(active_prompts, max_new_tokens=512)
            completions = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]

            # Step each active environment with the parsed action.
            for idx, completion in zip(active_indices, completions):
                histories[idx].append({"role": "assistant", "content": completion})
                action = parse_action(completion)

                try:
                    res = clients[idx].step(action)
                    if res.done:
                        active_episodes[idx] = False
                        all_rewards[idx] = res.reward
                    else:
                        histories[idx].append({"role": "user", "content": format_observation(res.observation)})
                except Exception:
                    # Invalid action / server error: end the episode with a penalty.
                    active_episodes[idx] = False
                    all_rewards[idx] = -1.0

        return {
            "env_reward": all_rewards,
        }
    finally:
        # Always release the env clients, even if generation raises.
        # close() assumed from the client's context-manager support — TODO confirm.
        for client in clients:
            try:
                client.close()
            except Exception:
                pass
123
+
124
+
125
def reward_from_env(completions, **kwargs):
    """TRL reward callback that relays rewards computed during the rollout.

    Defaults to a zero reward per completion when no "env_reward" key was
    produced by the rollout function.
    """
    if "env_reward" in kwargs:
        return kwargs["env_reward"]
    return [0.0 for _ in completions]
128
 
 
 
129
 
130
+ if __name__ == "__main__":
131
+ print(f"Starting GRPO Training on H100 with {MODEL_ID}...")
132
+
133
+ # Create dummy dataset (the rollout_func overrides the prompt anyway by calling env.reset())
134
+ dummy_dataset = Dataset.from_dict({"prompt": ["Start"] * 64})
135
+
136
+ training_args = GRPOConfig(
137
+ use_vllm=True,
138
+ vllm_mode="colocate", # Runs vLLM and PyTorch on the same H100 GPU!
139
+ num_train_epochs=1,
140
+ num_generations=8, # How many rollout trajectories to try per prompt
141
+ max_completion_length=1024,
142
+ per_device_train_batch_size=8,
143
+ logging_steps=1,
144
+ output_dir="./outputs/qwen-skill-env",
145
  )
 
146
 
147
+ trainer = GRPOTrainer(
148
+ model=MODEL_ID,
149
+ reward_funcs=[reward_from_env],
150
+ train_dataset=dummy_dataset,
151
+ rollout_func=rollout_func,
152
+ args=training_args,
153
+ )
154
+
155
+ trainer.train()
156
+
157
+ print("Training complete! Pushing to hub...")
158
+ trainer.push_to_hub("mpnikhil/Qwen2.5-3B-Skill-Invocation", token=HF_TOKEN)