ARKAISW committed on
Commit
8489eaa
·
1 Parent(s): 5686d79

Real ML Demo + Stability Fixes

Browse files
Files changed (2) hide show
  1. api/server.py +103 -5
  2. app.py +22 -4
api/server.py CHANGED
@@ -29,10 +29,93 @@ from env.multi_agent_env import (
29
  # TradingEnv kept for backward compat data generation only (not used in endpoints)
30
  from training.config import TrainingConfig
31
  from training.train_multi_agent import (
32
- RuleRiskManagerPolicy,
33
- RulePortfolioManagerPolicy,
34
  RuleTraderPolicy,
35
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  ROOT_DIR = Path(__file__).resolve().parents[1]
@@ -104,6 +187,8 @@ class SimulationRunner:
104
 
105
  def __init__(self):
106
  self.config = TrainingConfig(tickers=["AAPL"], fast_mode=True, max_steps=100)
 
 
107
 
108
  # ── PettingZoo multi-agent environment ──────────────────────────────
109
  self.env = MultiAgentTradingEnv(
@@ -139,14 +224,18 @@ class SimulationRunner:
139
  }
140
  self._openenv_env.reset()
141
 
 
 
 
 
142
  # ── Initialize demo PZ env ──────────────────────────────────────────
143
  self.env.reset()
144
  self.done = False
145
 
146
  sim_state["engine"] = {
147
  "name": "Multi-Agent Governance (PettingZoo AEC)",
148
- "mode": "Rule Fallback",
149
- "policy_active": False,
150
  "note": "Three independent RL agents negotiating via AEC turns: RiskManager β†’ PortfolioManager β†’ Trader.",
151
  }
152
 
@@ -194,7 +283,16 @@ class SimulationRunner:
194
  break
195
 
196
  obs = self.env.observe(agent)
197
- action = self.policies[agent].act(obs)
 
 
 
 
 
 
 
 
 
198
 
199
  if agent == RISK_MANAGER:
200
  rm_action = action
 
29
  # TradingEnv kept for backward compat data generation only (not used in endpoints)
30
  from training.config import TrainingConfig
31
  from training.train_multi_agent import (
 
 
32
  RuleTraderPolicy,
33
  )
34
+ try:
35
+ from unsloth import FastLanguageModel
36
+ HAS_UNSLOTH = True
37
+ except ImportError:
38
+ HAS_UNSLOTH = False
39
+
40
+
41
+ from huggingface_hub import snapshot_download
42
+
43
+
44
class GRPOAgent:
    """Bridges the trained GRPO model to the UI demo.

    The agent downloads a model snapshot from the Hugging Face Hub, loads it
    through Unsloth's ``FastLanguageModel`` and turns the generated text into
    a trader action dict.  Every failure path degrades to "not available"
    (``load`` returns ``False``, ``act`` returns ``None``) so the caller can
    fall back to its rule-based policy.
    """

    # Integer codes placed in the action's "direction" field.
    _HOLD = 0
    _BUY = 1
    _SELL = 2

    # Whole-word patterns: plain substring tests would fire on words such as
    # "support"/"update" (contain "up") or "download"/"shortly".
    _BUY_PATTERN = r"\b(?:buy(?:s|ing)?|up)\b"
    _SELL_PATTERN = r"\b(?:sell(?:s|ing)?|down|short(?:s|ing)?)\b"

    def __init__(self, model_id="ARKAISW/quanthive-trader-grpo-lora"):
        """Remember the Hub repo id; nothing is downloaded until ``load()``."""
        self.model_id = model_id
        self.model = None
        self.tokenizer = None
        self.is_ready = False

    def load(self):
        """Download and load the model.

        Returns:
            ``True`` when the model and tokenizer are ready for inference,
            ``False`` when Unsloth is unavailable or any step of the
            download/load fails (the error is printed, not raised).
        """
        if not HAS_UNSLOTH:
            print("Unsloth not installed. Falling back to rule-based.")
            return False
        try:
            print(f"Attempting to sync GRPO model from {self.model_id}...")
            # Auto-download from HF Hub if not local
            local_dir = Path("models") / "grpo_hf_trained"
            local_dir.mkdir(parents=True, exist_ok=True)
            snapshot_download(
                repo_id=self.model_id,
                local_dir=local_dir,
                allow_patterns=["*.json", "*.bin", "*.safetensors", "*.txt"],
            )

            print(f"Loading weights from {local_dir}...")
            self.model, self.tokenizer = FastLanguageModel.from_pretrained(
                model_name=str(local_dir),
                max_seq_length=2048,
                load_in_4bit=True,
            )
            FastLanguageModel.for_inference(self.model)
            self.is_ready = True
            print("✅ GRPO Model loaded successfully.")
            return True
        except Exception as e:
            # Deliberate best-effort: the demo must start even without the model.
            print(f"Could not load GRPO model: {e}")
            return False

    @staticmethod
    def _parse_direction(text):
        """Map free-form model text to a direction code.

        Returns 1 if a buy keyword ("buy", "up") appears as a whole word,
        otherwise 2 if a sell keyword ("sell", "down", "short") does,
        otherwise 0.  Buy keywords take precedence when both appear.
        """
        import re

        lowered = text.lower()
        if re.search(GRPOAgent._BUY_PATTERN, lowered):
            return GRPOAgent._BUY
        if re.search(GRPOAgent._SELL_PATTERN, lowered):
            return GRPOAgent._SELL
        return GRPOAgent._HOLD

    def act(self, obs: np.ndarray) -> "dict | None":
        """Sample an action from the GRPO model.

        Args:
            obs: Observation array; only its first five entries go into the
                prompt.

        Returns:
            An action dict with keys ``direction``, ``size``, ``sl``, ``tp``
            and ``thought`` (first 100 characters of the generated text), or
            ``None`` when the model is not loaded or inference fails.
        """
        if not self.is_ready:
            return None
        try:
            import torch

            # Construct a prompt that looks like the training scenarios
            prompt = f"Observation: {obs[:5].tolist()}... (truncated)\nResponse:"
            # Follow the model's own device rather than assuming CUDA; fall
            # back to "cuda" if the wrapper does not expose one.
            device = getattr(self.model, "device", "cuda")
            inputs = self.tokenizer([prompt], return_tensors="pt").to(device)

            # Fast generation for demo smoothness
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=32,
                    use_cache=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # generate() returns prompt + continuation for decoder-only
            # models; drop the prompt so parsing and the UI "thought" see
            # only what the model actually produced.
            prompt_len = inputs["input_ids"].shape[-1]
            response = self.tokenizer.decode(
                outputs[0][prompt_len:], skip_special_tokens=True
            ).strip()

            return {
                "direction": self._parse_direction(response),
                "size": np.array([0.15], dtype=np.float32),
                "sl": np.array([0.0], dtype=np.float32),
                "tp": np.array([0.0], dtype=np.float32),
                "thought": response[:100],  # Expose thought to UI
            }
        except Exception as e:
            # Best-effort: the caller treats None as "use the rule policy".
            print(f"GRPO inference error: {e}")
            return None
119
 
120
 
121
  ROOT_DIR = Path(__file__).resolve().parents[1]
 
187
 
188
  def __init__(self):
189
  self.config = TrainingConfig(tickers=["AAPL"], fast_mode=True, max_steps=100)
190
+ # Reduced commission for demo realism (preventing bleed from rule-based noise)
191
+ self.config.commission = 0.0001
192
 
193
  # ── PettingZoo multi-agent environment ──────────────────────────────
194
  self.env = MultiAgentTradingEnv(
 
224
  }
225
  self._openenv_env.reset()
226
 
227
+ # ── GRPO ML Agent (Bridges to real trained weights) ──────────────────
228
+ self.grpo_agent = GRPOAgent()
229
+ self.is_ml_active = self.grpo_agent.load()
230
+
231
  # ── Initialize demo PZ env ──────────────────────────────────────────
232
  self.env.reset()
233
  self.done = False
234
 
235
  sim_state["engine"] = {
236
  "name": "Multi-Agent Governance (PettingZoo AEC)",
237
+ "mode": "GRPO (Trained Model)" if self.is_ml_active else "Rule Fallback",
238
+ "policy_active": self.is_ml_active,
239
  "note": "Three independent RL agents negotiating via AEC turns: RiskManager β†’ PortfolioManager β†’ Trader.",
240
  }
241
 
 
283
  break
284
 
285
  obs = self.env.observe(agent)
286
+
287
+ # Use ML if active and it's the Trader's turn
288
+ action = None
289
+ if self.is_ml_active and agent == TRADER:
290
+ ml_action = self.grpo_agent.act(obs)
291
+ if ml_action:
292
+ action = ml_action
293
+
294
+ if action is None:
295
+ action = self.policies[agent].act(obs)
296
 
297
  if agent == RISK_MANAGER:
298
  rm_action = action
app.py CHANGED
@@ -10,10 +10,6 @@ Usage:
10
  import argparse
11
  import sys
12
 
13
- from training.config import TrainingConfig
14
- from training.train import train, run_random_baseline
15
- from utils.evaluate import evaluate
16
-
17
 
18
  def parse_args():
19
  parser = argparse.ArgumentParser(
@@ -70,6 +66,11 @@ def main():
70
  fast_mode=args.fast,
71
  )
72
 
 
 
 
 
 
73
  # Optionally fetch real data or generate GBM
74
  df = None
75
  if args.gbm:
@@ -84,10 +85,27 @@ def main():
84
  df = fetch_yfinance(args.ticker, args.start, args.end)
85
  print(f"Loaded {len(df)} rows of market data.\n")
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  if args.evaluate:
 
88
  results = evaluate(config, df=df)
89
  print(f"\nGrade improvement: {results['grade_improvement']:+.4f}")
90
  else:
 
91
  metrics = train(config, df=df)
92
  print(f"\nDone! {len(metrics)} episodes completed.")
93
 
 
10
  import argparse
11
  import sys
12
 
 
 
 
 
13
 
14
  def parse_args():
15
  parser = argparse.ArgumentParser(
 
66
  fast_mode=args.fast,
67
  )
68
 
69
+ config_cls = None
70
+ if not args.demo:
71
+ from training.config import TrainingConfig
72
+ config_cls = TrainingConfig
73
+
74
  # Optionally fetch real data or generate GBM
75
  df = None
76
  if args.gbm:
 
85
  df = fetch_yfinance(args.ticker, args.start, args.end)
86
  print(f"Loaded {len(df)} rows of market data.\n")
87
 
88
+ if args.demo:
89
+ return
90
+
91
+ config = config_cls(
92
+ tickers=[args.ticker],
93
+ start_date=args.start,
94
+ end_date=args.end,
95
+ initial_cash=args.cash,
96
+ num_episodes=2 if args.fast else args.episodes,
97
+ seed=args.seed,
98
+ log_every=args.log_every,
99
+ max_steps=50 if args.fast else args.max_steps,
100
+ fast_mode=args.fast,
101
+ )
102
+
103
  if args.evaluate:
104
+ from utils.evaluate import evaluate
105
  results = evaluate(config, df=df)
106
  print(f"\nGrade improvement: {results['grade_improvement']:+.4f}")
107
  else:
108
+ from training.train import train
109
  metrics = train(config, df=df)
110
  print(f"\nDone! {len(metrics)} episodes completed.")
111