ARKAISW commited on
Commit
30a586b
Β·
1 Parent(s): a0d8bc5

fix(notebook): correct clone step order, extract prompt utils, fix github url

Browse files
training/prompt_utils.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import json
3
+ import random
4
+ from pathlib import Path
5
+ from typing import Dict, List
6
+ import numpy as np
7
+
8
+ ROOT = Path(__file__).resolve().parents[1]
9
+ if str(ROOT) not in sys.path:
10
+ sys.path.insert(0, str(ROOT))
11
+
12
+ from env.multi_agent_env import (
13
+ MultiAgentTradingEnv,
14
+ RISK_MANAGER,
15
+ PORTFOLIO_MGR,
16
+ TRADER,
17
+ )
18
+ from training.train_multi_agent import (
19
+ RuleRiskManagerPolicy,
20
+ RulePortfolioManagerPolicy,
21
+ )
22
+
23
+ SYSTEM_PROMPT = """You are a trading agent in a multi-agent governance system.
24
+ The Risk Manager has set governance constraints, and the Portfolio Manager has allocated capital.
25
+ Your job: propose a trade that maximizes profit while respecting these constraints.
26
+
27
+ Respond exactly in this format:
28
+ <thought>
29
+ Your reasoning about the market state, risk constraints, and trade decision.
30
+ </thought>
31
+ <action>
32
+ {"direction": 0, "size": 0.0, "sl": 0, "tp": 0}
33
+ </action>
34
+ """
35
+
36
+ def generate_pz_scenarios(
37
+ n: int = 500,
38
+ difficulty: str = "easy",
39
+ max_env_steps: int = 100,
40
+ ) -> List[Dict]:
41
+ """Run the PZ env with rule policies to generate realistic scenarios.
42
+
43
+ Each scenario captures:
44
+ - The Trader's full observation (29 dims)
45
+ - The RM constraints decoded from the message
46
+ - The PM allocation decoded from the message
47
+ """
48
+ env = MultiAgentTradingEnv(difficulty=difficulty, max_steps=max_env_steps)
49
+ rm_policy = RuleRiskManagerPolicy()
50
+ pm_policy = RulePortfolioManagerPolicy()
51
+
52
+ scenarios: List[Dict] = []
53
+ attempts = 0
54
+ max_attempts = n * 3
55
+
56
+ while len(scenarios) < n and attempts < max_attempts:
57
+ env.reset()
58
+ attempts += 1
59
+
60
+ step_count = 0
61
+ while env.agents and step_count < max_env_steps:
62
+ agent = env.agent_selection
63
+
64
+ if agent == RISK_MANAGER:
65
+ obs = env.observe(agent)
66
+ action = rm_policy.act(obs)
67
+ env.step(action)
68
+
69
+ elif agent == PORTFOLIO_MGR:
70
+ obs = env.observe(agent)
71
+ action = pm_policy.act(obs)
72
+ env.step(action)
73
+
74
+ elif agent == TRADER:
75
+ obs = env.observe(agent)
76
+ # Extract RM and PM messages from the observation
77
+ # obs layout: base(24) + rm_msg(3) + pm_msg(2) = 29
78
+ base_obs = obs[:24].tolist()
79
+ rm_msg = obs[24:27].tolist() # [size_limit, allow_new, force_reduce]
80
+ pm_msg = obs[27:29].tolist() # [cap_alloc, override_strength]
81
+
82
+ rm_size_limit = float(rm_msg[0])
83
+ rm_allow_new = bool(rm_msg[1] > 0.5)
84
+ rm_force_reduce = bool(rm_msg[2] > 0.5)
85
+ pm_cap_alloc = float(pm_msg[0])
86
+ pm_override = float(pm_msg[1])
87
+
88
+ scenarios.append({
89
+ "state": [round(float(x), 4) for x in base_obs[:5]],
90
+ "full_obs": [round(float(x), 4) for x in base_obs],
91
+ "rm_size_limit": round(rm_size_limit, 3),
92
+ "rm_allow_new": rm_allow_new,
93
+ "rm_force_reduce": rm_force_reduce,
94
+ "pm_cap_alloc": round(pm_cap_alloc, 3),
95
+ "pm_override": round(pm_override, 3),
96
+ "signals": {
97
+ "ta": round(float(obs[5] * 2 - 1), 3), # RSI mapped to [-1,1]
98
+ "fa": round(float(obs[8]), 3), # MACD as FA proxy
99
+ "position_limit": round(rm_size_limit, 3),
100
+ "rm_size_limit": round(rm_size_limit, 3),
101
+ },
102
+ })
103
+
104
+ if len(scenarios) >= n:
105
+ break
106
+
107
+ # Take a random trader action so the env advances
108
+ trader_action = {
109
+ "direction": random.choice([0, 1, 2]),
110
+ "size": np.array([random.uniform(0.05, 0.3)], dtype=np.float32),
111
+ "sl": np.array([0.0], dtype=np.float32),
112
+ "tp": np.array([0.0], dtype=np.float32),
113
+ }
114
+ env.step(trader_action)
115
+
116
+ step_count += 1
117
+
118
+ random.shuffle(scenarios)
119
+ return scenarios[:n]
120
+
121
+
122
+ def build_prompt_multiagent(scenario: Dict) -> str:
123
+ """Build the prompt for the Trader, including RM and PM constraints."""
124
+ rm_limit = scenario["rm_size_limit"]
125
+ rm_allow_str = "allowed" if scenario.get("rm_allow_new", True) else "BLOCKED"
126
+ rm_force_str = "yes" if scenario.get("rm_force_reduce", False) else "no"
127
+ pm_cap = scenario["pm_cap_alloc"]
128
+ pm_override_str = "none" if scenario.get("pm_override", 0.0) < 0.5 else "ACTIVE"
129
+
130
+ state = scenario.get("state", [1.0, 1.0, 1.0, 1.0, 1.0])
131
+ signals = scenario.get("signals", {})
132
+
133
+ body = json.dumps({
134
+ "state": state,
135
+ "signals": signals,
136
+ "governance": {
137
+ "rm_size_limit": rm_limit,
138
+ "rm_allow_new": rm_allow_str,
139
+ "rm_force_reduce": rm_force_str,
140
+ "pm_cap_alloc": pm_cap,
141
+ "pm_override": pm_override_str,
142
+ },
143
+ }, separators=(",", ":"))
144
+
145
+ prompt = (
146
+ f"{SYSTEM_PROMPT}\n"
147
+ f"The Risk Manager has set the following constraints: "
148
+ f"size_limit={rm_limit:.2f}, new_positions={rm_allow_str}, force_reduce={rm_force_str}.\n"
149
+ f"The Portfolio Manager allocated: capital_cap={pm_cap:.2f}, override={pm_override_str}.\n\n"
150
+ f"Scenario:\n{body}\n"
151
+ )
152
+ return prompt
training/train_grpo_multiagent.py CHANGED
@@ -52,139 +52,8 @@ from training.train_multi_agent import (
52
  DEFAULT_MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"
53
  DEFAULT_OUTPUT_DIR = "models/local_policy_grpo_multiagent"
54
 
55
- SYSTEM_PROMPT = """You are a trading agent in a multi-agent governance system.
56
- The Risk Manager has set governance constraints, and the Portfolio Manager has allocated capital.
57
- Your job: propose a trade that maximizes profit while respecting these constraints.
58
-
59
- Respond exactly in this format:
60
- <thought>
61
- Your reasoning about the market state, risk constraints, and trade decision.
62
- </thought>
63
- <action>
64
- {"direction": 0, "size": 0.0, "sl": 0, "tp": 0}
65
- </action>
66
- """
67
-
68
-
69
- # ─── Scenario Generation from PettingZoo Env ──────────────────────────────────
70
 
71
- def generate_pz_scenarios(
72
- n: int = 500,
73
- difficulty: str = "easy",
74
- max_env_steps: int = 100,
75
- ) -> List[Dict]:
76
- """Run the PZ env with rule policies to generate realistic scenarios.
77
-
78
- Each scenario captures:
79
- - The Trader's full observation (29 dims)
80
- - The RM constraints decoded from the message
81
- - The PM allocation decoded from the message
82
- """
83
- env = MultiAgentTradingEnv(difficulty=difficulty, max_steps=max_env_steps)
84
- rm_policy = RuleRiskManagerPolicy()
85
- pm_policy = RulePortfolioManagerPolicy()
86
-
87
- scenarios: List[Dict] = []
88
- attempts = 0
89
- max_attempts = n * 3
90
-
91
- while len(scenarios) < n and attempts < max_attempts:
92
- env.reset()
93
- attempts += 1
94
-
95
- step_count = 0
96
- while env.agents and step_count < max_env_steps:
97
- agent = env.agent_selection
98
-
99
- if agent == RISK_MANAGER:
100
- obs = env.observe(agent)
101
- action = rm_policy.act(obs)
102
- env.step(action)
103
-
104
- elif agent == PORTFOLIO_MGR:
105
- obs = env.observe(agent)
106
- action = pm_policy.act(obs)
107
- env.step(action)
108
-
109
- elif agent == TRADER:
110
- obs = env.observe(agent)
111
- # Extract RM and PM messages from the observation
112
- # obs layout: base(24) + rm_msg(3) + pm_msg(2) = 29
113
- base_obs = obs[:24].tolist()
114
- rm_msg = obs[24:27].tolist() # [size_limit, allow_new, force_reduce]
115
- pm_msg = obs[27:29].tolist() # [cap_alloc, override_strength]
116
-
117
- rm_size_limit = float(rm_msg[0])
118
- rm_allow_new = bool(rm_msg[1] > 0.5)
119
- rm_force_reduce = bool(rm_msg[2] > 0.5)
120
- pm_cap_alloc = float(pm_msg[0])
121
- pm_override = float(pm_msg[1])
122
-
123
- scenarios.append({
124
- "state": [round(float(x), 4) for x in base_obs[:5]],
125
- "full_obs": [round(float(x), 4) for x in base_obs],
126
- "rm_size_limit": round(rm_size_limit, 3),
127
- "rm_allow_new": rm_allow_new,
128
- "rm_force_reduce": rm_force_reduce,
129
- "pm_cap_alloc": round(pm_cap_alloc, 3),
130
- "pm_override": round(pm_override, 3),
131
- "signals": {
132
- "ta": round(float(obs[5] * 2 - 1), 3), # RSI mapped to [-1,1]
133
- "fa": round(float(obs[8]), 3), # MACD as FA proxy
134
- "position_limit": round(rm_size_limit, 3),
135
- "rm_size_limit": round(rm_size_limit, 3),
136
- },
137
- })
138
-
139
- if len(scenarios) >= n:
140
- break
141
-
142
- # Take a random trader action so the env advances
143
- trader_action = {
144
- "direction": random.choice([0, 1, 2]),
145
- "size": np.array([random.uniform(0.05, 0.3)], dtype=np.float32),
146
- "sl": np.array([0.0], dtype=np.float32),
147
- "tp": np.array([0.0], dtype=np.float32),
148
- }
149
- env.step(trader_action)
150
-
151
- step_count += 1
152
-
153
- random.shuffle(scenarios)
154
- return scenarios[:n]
155
-
156
-
157
- def build_prompt_multiagent(scenario: Dict) -> str:
158
- """Build the prompt for the Trader, including RM and PM constraints."""
159
- rm_limit = scenario["rm_size_limit"]
160
- rm_allow_str = "allowed" if scenario.get("rm_allow_new", True) else "BLOCKED"
161
- rm_force_str = "yes" if scenario.get("rm_force_reduce", False) else "no"
162
- pm_cap = scenario["pm_cap_alloc"]
163
- pm_override_str = "none" if scenario.get("pm_override", 0.0) < 0.5 else "ACTIVE"
164
-
165
- state = scenario.get("state", [1.0, 1.0, 1.0, 1.0, 1.0])
166
- signals = scenario.get("signals", {})
167
-
168
- body = json.dumps({
169
- "state": state,
170
- "signals": signals,
171
- "governance": {
172
- "rm_size_limit": rm_limit,
173
- "rm_allow_new": rm_allow_str,
174
- "rm_force_reduce": rm_force_str,
175
- "pm_cap_alloc": pm_cap,
176
- "pm_override": pm_override_str,
177
- },
178
- }, separators=(",", ":"))
179
-
180
- prompt = (
181
- f"{SYSTEM_PROMPT}\n"
182
- f"The Risk Manager has set the following constraints: "
183
- f"size_limit={rm_limit:.2f}, new_positions={rm_allow_str}, force_reduce={rm_force_str}.\n"
184
- f"The Portfolio Manager allocated: capital_cap={pm_cap:.2f}, override={pm_override_str}.\n\n"
185
- f"Scenario:\n{body}\n"
186
- )
187
- return prompt
188
 
189
 
190
  # ─── Updated GRPO Verifiers ───────────────────────────────────────────────────
 
52
  DEFAULT_MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit"
53
  DEFAULT_OUTPUT_DIR = "models/local_policy_grpo_multiagent"
54
 
55
+ from training.prompt_utils import SYSTEM_PROMPT, generate_pz_scenarios, build_prompt_multiagent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
 
59
  # ─── Updated GRPO Verifiers ───────────────────────────────────────────────────