ARKAISW committed on
Commit
d5dc8ac
·
1 Parent(s): 7164a8e

Final tweaks to server and requirements

Browse files
Files changed (2) hide show
  1. api/server.py +20 -12
  2. requirements-space.txt +0 -1
api/server.py CHANGED
@@ -10,6 +10,7 @@ to enrich the UI with signal context but do NOT participate in the AEC loop.
10
 
11
  from pathlib import Path
12
  import asyncio
 
13
 
14
  import numpy as np
15
  import uvicorn
@@ -33,11 +34,6 @@ from training.train_multi_agent import (
33
  RuleRiskManagerPolicy,
34
  RuleTraderPolicy,
35
  )
36
- try:
37
- from unsloth import FastLanguageModel
38
- HAS_UNSLOTH = True
39
- except ImportError:
40
- HAS_UNSLOTH = False
41
 
42
 
43
  from huggingface_hub import snapshot_download
@@ -45,19 +41,30 @@ from huggingface_hub import snapshot_download
45
 
46
  class GRPOAgent:
47
  """Bridges the trained GRPO model to the UI demo."""
48
- def __init__(self, model_id="ARKAISW/quanthive-trader-grpo-lora"):
49
- self.model_id = model_id
50
  self.model = None
51
  self.tokenizer = None
52
  self.is_ready = False
53
 
54
  def load(self):
55
- if not HAS_UNSLOTH:
56
- print("Unsloth not installed. Falling back to rule-based.")
57
- return False
58
  try:
59
  import torch
60
- from transformers import AutoTokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  print(f"Attempting to sync GRPO model from {self.model_id}...")
62
  # Auto-download from HF Hub if not local
63
  local_dir = Path("models") / "grpo_hf_trained"
@@ -87,7 +94,8 @@ class GRPOAgent:
87
  import torch
88
  # Construct a prompt that looks like the training scenarios
89
  prompt = f"Observation: {obs[:5].tolist()}... (truncated)\nResponse:"
90
- inputs = self.tokenizer([prompt], return_tensors="pt").to("cuda")
 
91
 
92
  # Fast generation for demo smoothness
93
  with torch.no_grad():
 
10
 
11
  from pathlib import Path
12
  import asyncio
13
+ import os
14
 
15
  import numpy as np
16
  import uvicorn
 
34
  RuleRiskManagerPolicy,
35
  RuleTraderPolicy,
36
  )
 
 
 
 
 
37
 
38
 
39
  from huggingface_hub import snapshot_download
 
41
 
42
  class GRPOAgent:
43
  """Bridges the trained GRPO model to the UI demo."""
44
+ def __init__(self, model_id=None):
45
+ self.model_id = model_id or os.getenv("GRPO_MODEL_ID", "ARKAISW/QuantHive-GRPO-Trader")
46
  self.model = None
47
  self.tokenizer = None
48
  self.is_ready = False
49
 
50
  def load(self):
 
 
 
51
  try:
52
  import torch
53
+ except Exception as e:
54
+ print(f"PyTorch unavailable ({e}). Falling back to rule-based.")
55
+ return False
56
+
57
+ if not torch.cuda.is_available():
58
+ print("CUDA not available in this environment. Falling back to rule-based.")
59
+ return False
60
+
61
+ try:
62
+ from unsloth import FastLanguageModel
63
+ except Exception as e:
64
+ print(f"Could not import Unsloth: {e}. Falling back to rule-based.")
65
+ return False
66
+
67
+ try:
68
  print(f"Attempting to sync GRPO model from {self.model_id}...")
69
  # Auto-download from HF Hub if not local
70
  local_dir = Path("models") / "grpo_hf_trained"
 
94
  import torch
95
  # Construct a prompt that looks like the training scenarios
96
  prompt = f"Observation: {obs[:5].tolist()}... (truncated)\nResponse:"
97
+ device = getattr(self.model, "device", "cuda")
98
+ inputs = self.tokenizer([prompt], return_tensors="pt").to(device)
99
 
100
  # Fast generation for demo smoothness
101
  with torch.no_grad():
requirements-space.txt CHANGED
@@ -8,7 +8,6 @@ yfinance
8
  ccxt
9
  torch
10
  transformers
11
- unsloth
12
  accelerate
13
  safetensors
14
  jinja2
 
8
  ccxt
9
  torch
10
  transformers
 
11
  accelerate
12
  safetensors
13
  jinja2