"""
inference.py - root directory
3 tasks:
1. revenge-trade-detection — catch loss_streak >= 0.2
2. panic-sell-prevention — catch deep pnl < -0.3
3. overconfidence-correction — catch overtrading (overtrade_score >= 0.7) while in profit (pnl > 0.1)
"""
import os
from typing import List
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
from trade_env.env.coach_env import CoachEnv
from trade_env.schemas.action import Action, ActionType
# Endpoint and model configuration; each value can be overridden through an
# environment variable of the same name.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN")  # None when the variable is unset
LOCAL_IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME")  # None when the variable is unset

# Episode and reporting settings.
BENCHMARK = "coach-env"  # environment label printed in the [START] log line
MAX_STEPS = 20  # upper bound on steps per task episode
SUCCESS_SCORE_THRESHOLD = 0.1  # minimum score for a task to be logged as a success
def _revenge_trade_trigger(state):
    """Return True when the state's ``loss_streak`` value is at least 0.2."""
    return state["loss_streak"] >= 0.2


def _panic_sell_trigger(state):
    """Return True when the state's ``pnl`` value is below -0.3."""
    return state["pnl"] < -0.3


def _overconfidence_trigger(state):
    """Return True when ``overtrade_score`` >= 0.7 and ``pnl`` > 0.1."""
    return state["overtrade_score"] >= 0.7 and state["pnl"] > 0.1


# Task registry: each entry holds a description, a predicate over the trader
# state that marks a step as needing intervention, and the action indices
# that are graded as correct when that predicate fires.
TASKS = {
    "revenge-trade-detection": {
        "desc": "Detect and intervene on revenge trading after loss streaks",
        "trigger": _revenge_trade_trigger,
        "correct_actions": [3, 4],  # EXIT or COOLDOWN
    },
    "panic-sell-prevention": {
        "desc": "Prevent panic selling during drawdowns",
        "trigger": _panic_sell_trigger,
        "correct_actions": [2, 3],  # REDUCE or EXIT
    },
    "overconfidence-correction": {
        "desc": "Correct overconfident trading after wins",
        "trigger": _overconfidence_trigger,
        "correct_actions": [1, 2],  # WARN or REDUCE
    },
}
def log_start(task, env, model):
    """Print the ``[START]`` line naming the task, environment and model."""
    fields = (f"task={task}", f"env={env}", f"model={model}")
    print("[START] " + " ".join(fields), flush=True)
def log_step(step, action, reward, done, error=None):
    """Print one ``[STEP]`` line.

    ``reward`` is shown with two decimals, ``done`` as a lowercase boolean,
    and a missing or empty ``error`` as the literal ``null``.
    """
    if error:
        error_val = error
    else:
        error_val = "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success, steps, score, rewards):
    """Print the ``[END]`` line summarising one task episode.

    ``rewards`` is rendered as a comma-separated list with two decimals per
    value; ``score`` is shown with three decimals.
    """
    rewards_str = ",".join(format(value, ".2f") for value in rewards)
    success_val = str(success).lower()
    print(
        f"[END] success={success_val} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )
def _rule_based_action(state, task_name):
    """Choose an action index from fixed thresholds, without calling the LLM.

    Reads ``loss_streak``, ``pnl`` and ``overtrade_score`` from ``state`` and
    returns 0 for any task name it does not recognise.
    """
    loss = state["loss_streak"]
    pnl = state["pnl"]
    over = state["overtrade_score"]

    if task_name == "revenge-trade-detection":
        if loss >= 0.2:
            return 4
        if loss >= 0.1:
            return 3
        if loss > 0.0:
            return 1
        return 0

    if task_name == "panic-sell-prevention":
        if pnl < -0.3:
            return 3
        if pnl < -0.1:
            return 2
        return 0

    if task_name == "overconfidence-correction":
        if over >= 0.7:
            return 2
        if over >= 0.5:
            return 1
        return 0

    return 0


def get_llm_action(client, state, task_name):
    """Ask the LLM for an action index, falling back to fixed rules.

    Args:
        client: Object exposing ``chat.completions.create`` (OpenAI-style).
        state: Mapping with numeric ``loss_streak``, ``pnl`` and
            ``overtrade_score`` entries.
        task_name: Name of the task being run; included in the prompt and
            used to select the fallback rules.

    Returns:
        An int in 0..4 (0=ignore 1=warn 2=reduce 3=exit 4=cooldown).

    Raises:
        KeyError: If ``state`` lacks one of the three expected keys.

    Any ``Exception`` from the LLM call, or a reply that does not start with
    a digit in 0..4, results in the rule-based fallback. ``KeyboardInterrupt``
    and ``SystemExit`` are deliberately not swallowed, so the run can still
    be interrupted while a request is in flight.
    """
    import logging

    prompt = (
        f"You are a trading behavior coach. Task: {task_name}.\n"
        f"Trader state: loss_streak={state['loss_streak']:.2f}, "
        f"pnl={state['pnl']:.2f}, overtrade_score={state['overtrade_score']:.2f}.\n"
        f"Reply with single digit only. 0=ignore 1=warn 2=reduce 3=exit 4=cooldown"
    )
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=3,
            temperature=0.0,
        )
        reply = (completion.choices[0].message.content or "").strip()
        # Only the first character is inspected; an empty reply yields
        # int("") -> ValueError, which is handled below like any bad reply.
        action = int(reply[:1])
        if action in range(5):
            return action
    except Exception as exc:
        # Best-effort: a failed request or unparsable reply must not abort
        # the episode, so record it at debug level and use the rules instead.
        logging.getLogger(__name__).debug(
            "LLM action unavailable (%r); using rule-based fallback", exc
        )

    return _rule_based_action(state, task_name)
def run_task(client, task_name: str) -> float:
    """Run one episode of ``task_name`` in a fresh ``CoachEnv`` and score it.

    For up to ``MAX_STEPS`` steps an action is obtained from
    ``get_llm_action`` and applied to the environment. Whenever the task's
    trigger predicate holds for the pre-step state, the step counts as a
    trigger; if the chosen action is one of the task's ``correct_actions``
    it also counts as a correct intervention and the logged reward becomes
    ``abs(reward) + 0.1``.

    Args:
        client: LLM client forwarded to ``get_llm_action``.
        task_name: Key into ``TASKS``.

    Returns:
        The episode score clamped to [0, 1]: the fraction of triggers that
        were answered correctly, or, when no trigger fired, the sum of
        positive rewards divided by 0.5. Returns 0.0 if an ``Exception``
        occurred during the episode.

    Raises:
        KeyError: If ``task_name`` is not in ``TASKS``.
    """
    task = TASKS[task_name]
    env = CoachEnv()
    rewards: List[float] = []
    steps_taken = 0
    correct_interventions = 0
    total_triggers = 0
    # Bound up front so the ``finally`` log line cannot fail with
    # UnboundLocalError when something that is not an ``Exception``
    # (e.g. KeyboardInterrupt) escapes the episode loop.
    success = False
    score = 0.0

    log_start(task_name, BENCHMARK, MODEL_NAME)
    try:
        state = env.reset()
        for step in range(1, MAX_STEPS + 1):
            action_idx = get_llm_action(client, state, task_name)
            action_type = ActionType(action_idx)
            next_state, reward, done, info = env.step(Action(action=action_type))

            # Grade against the state the agent actually saw (pre-step).
            if task["trigger"](state):
                total_triggers += 1
                if action_idx in task["correct_actions"]:
                    correct_interventions += 1
                    reward = abs(reward) + 0.1  # bonus for correct intervention

            log_step(step, action_type.name, reward, done)
            rewards.append(reward)
            steps_taken = step
            state = next_state
            if done:
                break

        if total_triggers > 0:
            # Intervention accuracy over the steps where the trigger fired.
            score = correct_interventions / total_triggers
        else:
            # No trigger fired: fall back to the positive reward mass,
            # scaled so that 0.5 total reward maps to a full score.
            score = sum(r for r in rewards if r > 0)
            score = min(1.0, score / 0.5)
        score = min(1.0, max(0.0, score))
        success = score >= SUCCESS_SCORE_THRESHOLD
    except Exception as exc:
        # Report the failure as a terminal step and score the episode as 0.
        log_step(steps_taken + 1, "NO", 0.0, True, error=str(exc))
        success = False
        score = 0.0
        rewards = rewards or [0.0]
    finally:
        log_end(success, steps_taken, score, rewards)
    return score
def main():
    """Run every task in ``TASKS`` and print a ``[SUMMARY]`` line.

    The summary reports how many tasks ran and their mean score.
    """
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
    scores = [run_task(llm_client, name) for name in TASKS]
    mean_score = sum(scores) / len(scores)
    print(f"[SUMMARY] tasks={len(scores)} avg_score={mean_score:.3f}", flush=True)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()