tether007 committed on
Commit
8ba8cbd
·
1 Parent(s): d149718

trade_env package issue

Browse files
Files changed (2) hide show
  1. inference.py +69 -36
  2. trade_env/__init__.py +1 -8
inference.py CHANGED
@@ -1,65 +1,97 @@
1
  """
2
- inference.py - must be in root directory
3
- Uses OpenAI client for LLM calls as per hackathon requirements
4
- Emits [START], [STEP], [END] structured logs
5
  """
6
- from dotenv import load_dotenv
7
- load_dotenv()
8
  import os
 
9
  from openai import OpenAI
 
 
 
 
10
  from trade_env.env.coach_env import CoachEnv
11
  from trade_env.schemas.action import Action, ActionType
12
 
 
 
 
 
 
 
13
  TASK_NAME = "trader-coach"
14
  BENCHMARK = "coach-env"
15
- MODEL_NAME = os.getenv("MODEL_NAME", "gemini-3-flash")
16
- API_BASE = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
17
- HF_TOKEN = os.getenv("HF_TOKEN", "")
18
  MAX_STEPS = 20
 
19
 
20
- client = OpenAI(
21
- api_key=os.getenv("GEMINI_API_KEY"),
22
- base_url=API_BASE
23
- )
24
 
25
-
26
- def get_llm_action(state: dict) -> int:
27
- if state["loss_streak"] >= 3:
28
- return 4
29
- if state["loss_streak"] >= 2:
30
- return 3
31
- if state["loss_streak"] >= 1:
32
- return 1
33
- if state["pnl"] < -30:
34
- return 2
35
- return 0
36
-
37
- def log_start():
38
- print(f"[START] task={TASK_NAME} env={BENCHMARK} model={MODEL_NAME}")
39
 
40
 
41
  def log_step(step, action, reward, done, error=None):
42
  error_val = error if error else "null"
43
- print(f"[STEP] step={step} action={action} reward={reward:.4f} done={str(done).lower()} error={error_val}")
44
 
45
 
46
  def log_end(success, steps, score, rewards):
47
- rewards_str = ",".join(f"{r:.4f}" for r in rewards)
48
- print(f"[END] success={str(success).lower()} steps={steps} score={score:.4f} rewards={rewards_str}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
  def main():
 
 
 
 
 
52
  env = CoachEnv()
53
- rewards = []
54
  steps_taken = 0
 
 
55
 
56
- log_start()
57
 
58
  try:
59
  state = env.reset()
60
 
61
  for step in range(1, MAX_STEPS + 1):
62
- action_idx = get_llm_action(state)
63
  action = Action(action=ActionType(action_idx))
64
 
65
  next_state, reward, done, info = env.step(action)
@@ -73,9 +105,9 @@ def main():
73
  if done:
74
  break
75
 
76
- total_reward = sum(rewards)
77
- score = max(0.0, min(1.0, (total_reward + 1.0) / 2.0))
78
- success = score > 0.1
79
 
80
  except Exception as e:
81
  log_step(steps_taken + 1, "NO", 0.0, True, error=str(e))
@@ -83,7 +115,8 @@ def main():
83
  score = 0.0
84
  rewards = rewards or [0.0]
85
 
86
- log_end(success, steps_taken, score, rewards)
 
87
 
88
 
89
  if __name__ == "__main__":
 
1
  """
2
+ inference.py - root directory
 
 
3
  """
4
+ import asyncio
 
5
  import os
6
+ from typing import List, Optional
7
  from openai import OpenAI
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
  from trade_env.env.coach_env import CoachEnv
13
  from trade_env.schemas.action import Action, ActionType
14
 
15
+
16
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
17
+ MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
18
+ HF_TOKEN = os.getenv("HF_TOKEN")
19
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
20
+
21
  TASK_NAME = "trader-coach"
22
  BENCHMARK = "coach-env"
 
 
 
23
  MAX_STEPS = 20
24
+ SUCCESS_SCORE_THRESHOLD = 0.1
25
 
 
 
 
 
26
 
27
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a run's structured log.

    Args:
        task: Task identifier written to the ``task=`` field.
        env: Benchmark/environment identifier written to the ``env=`` field.
        model: Model identifier written to the ``model=`` field.
    """
    # flush=True so the record is emitted immediately even when stdout is
    # block-buffered (e.g. piped to another process).
    print(f"[START] task={task} env={env} model={model}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
def log_step(step, action, reward, done, error: Optional[str] = None) -> None:
    """Print one ``[STEP]`` record of the structured run log.

    Args:
        step: Step number written to the ``step=`` field.
        action: Action label or index written to the ``action=`` field.
        reward: Reward for the step; formatted with two decimals.
        done: Whether the episode ended; written as ``true``/``false``.
        error: Optional error description. Falsy values (``None``, ``""``)
            are written as the literal ``null``.
    """
    # Each record must stay on a single line, so a multi-line message
    # (e.g. str() of an exception) is flattened before printing.
    flattened = " ".join(str(error).splitlines()) if error else ""
    error_val = flattened or "null"
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_val}",
        flush=True,
    )
34
 
35
 
36
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes a run's structured log.

    Args:
        success: Whether the run counted as successful; written as
            ``true``/``false``.
        steps: Number of steps taken.
        score: Final score; formatted with three decimals.
        rewards: Per-step rewards; written as a comma-separated list with
            two decimals each (empty list yields an empty field).
    """
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )
39
+
40
+
41
def get_llm_action(client: "OpenAI", state: dict, step: int) -> int:
    """Ask the LLM for a coaching intervention, with a rule-based fallback.

    Args:
        client: OpenAI-compatible client; only
            ``client.chat.completions.create`` is used.
        state: Trader state. Must contain ``timestep``, ``price``,
            ``position``, ``loss_streak`` and ``pnl`` (``price`` and ``pnl``
            are formatted as floats).
        step: Current step number. Accepted for interface compatibility;
            not used by this function.

    Returns:
        An action index in ``0..4``
        (0=NO, 1=WARN, 2=REDUCE, 3=EXIT, 4=COOLDOWN).

    Raises:
        KeyError: If ``state`` lacks one of the required keys (raised while
            building the prompt, before any API call).
    """
    import re  # local import keeps this block self-contained

    prompt = f"""You are a trading behavior coach. Given trader state:
- timestep: {state['timestep']}
- price: {state['price']:.2f}
- position: {state['position']}
- loss_streak: {state['loss_streak']}
- pnl: {state['pnl']:.2f}

Choose intervention (reply with single integer only):
0=NO, 1=WARN, 2=REDUCE, 3=EXIT, 4=COOLDOWN"""

    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=5,
            temperature=0.0,
        )
        raw = (completion.choices[0].message.content or "").strip()
    except Exception:
        # Deliberate best-effort boundary: any API/transport/shape failure
        # must not abort the episode; fall through to the heuristic below.
        raw = ""

    # Models often wrap the digit ("3.", "Action: 3"); take the first
    # integer in the reply instead of requiring the reply to be a bare int.
    found = re.search(r"-?\d+", raw)
    if found is not None:
        action = int(found.group())
        # An out-of-range integer is treated as "no intervention".
        return action if 0 <= action <= 4 else 0

    # Rule-based fallback; per the original comments these values are
    # normalized (loss_streak in 0.0..1.0, pnl in -1.0..1.0).
    loss = state["loss_streak"]
    pnl = state["pnl"]

    if loss >= 0.2:
        return 4  # COOLDOWN
    if loss >= 0.1:
        return 3  # EXIT
    if pnl < -0.3:
        return 2  # REDUCE
    if loss > 0.0:
        return 1  # WARN
    return 0  # NO
74
 
75
 
76
  def main():
77
+ client = OpenAI(
78
+ api_key=HF_TOKEN,
79
+ base_url=API_BASE_URL
80
+ )
81
+
82
  env = CoachEnv()
83
+ rewards: List[float] = []
84
  steps_taken = 0
85
+ score = 0.0
86
+ success = False
87
 
88
+ log_start(TASK_NAME, BENCHMARK, MODEL_NAME)
89
 
90
  try:
91
  state = env.reset()
92
 
93
  for step in range(1, MAX_STEPS + 1):
94
+ action_idx = get_llm_action(client, state, step)
95
  action = Action(action=ActionType(action_idx))
96
 
97
  next_state, reward, done, info = env.step(action)
 
105
  if done:
106
  break
107
 
108
+ score = sum(rewards) / MAX_STEPS
109
+ score = min(max(score, 0.0), 1.0)
110
+ success = score >= SUCCESS_SCORE_THRESHOLD
111
 
112
  except Exception as e:
113
  log_step(steps_taken + 1, "NO", 0.0, True, error=str(e))
 
115
  score = 0.0
116
  rewards = rewards or [0.0]
117
 
118
+ finally:
119
+ log_end(success, steps_taken, score, rewards)
120
 
121
 
122
  if __name__ == "__main__":
trade_env/__init__.py CHANGED
@@ -6,11 +6,4 @@
6
 
7
  """Trade Env Environment."""
8
 
9
- from .client import TradeEnv
10
- from .models import TradeAction, TradeObservation
11
-
12
- __all__ = [
13
- "TradeAction",
14
- "TradeObservation",
15
- "TradeEnv",
16
- ]
 
6
 
7
  """Trade Env Environment."""
8
 
9
+ __all__ = []