File size: 6,665 Bytes
8be84ff e7e76f5 8be84ff 717bee1 8be84ff e7e76f5 9ca0bab e7e76f5 cf294ef e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 8be84ff 89faf8b e7e76f5 89faf8b e7e76f5 cf294ef e7e76f5 cf294ef e7e76f5 cf294ef e7e76f5 cf294ef e7e76f5 0216803 e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 89faf8b e7e76f5 89faf8b e7e76f5 89faf8b e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 8be84ff e7e76f5 0216803 e7e76f5 0216803 e7e76f5 0216803 8be84ff 717bee1 e7e76f5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 | import asyncio
import os
import sys
from typing import List, Optional
from openai import OpenAI
import requests
# ============================================
# Runtime configuration, read from the environment the validator injects.
# `os.environ.get(...) or default` (rather than a .get() default argument)
# also maps an *empty* variable to the default, so e.g. ENV_URL="" cannot
# turn every request into a bare "/reset".
# ============================================
API_KEY = os.environ.get("API_KEY")
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
# NOTE(review): confirm the default model is actually served by the default
# API_BASE_URL; the validator is expected to inject MODEL_NAME.
MODEL_NAME = os.environ.get("MODEL_NAME") or "gpt-3.5-turbo"
# Base URL of the Quant-Gym HF Space (serves the /reset and /step endpoints).
ENV_URL = os.environ.get("ENV_URL") or "https://astocoder-quant-gym.hf.space"
# Maximum number of /step calls per task episode.
MAX_STEPS = 10
# ============================================
# Tasks to run, as (task_id, display_name) pairs.
# Task ids must match the ids declared in openenv.yaml.
# ============================================
TASKS = [
    ("task1", "Fetch Market Data"),
    ("task2", "News Sentiment Analysis"),
    ("task3", "Backtest Strategy"),
]
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens one task run.

    Args:
        task: Task id being started.
        env: Environment name.
        model: Model name used for the run.
    """
    fields = (("task", task), ("env", env), ("model", model))
    line = "[START] " + " ".join(f"{key}={value}" for key, value in fields)
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record.

    Args:
        step: 1-based step index.
        action: Action text sent to the environment.
        reward: Reward for this step, printed with two decimals.
        done: Whether the episode ended at this step (printed lowercase).
        error: Error message; None or empty prints as ``null``.
    """
    error_val = error if error else "null"

    def _single_line(value) -> str:
        # A newline inside an exception message or model output would split
        # the record over several lines and break line-oriented log parsing.
        return " ".join(str(value).splitlines())

    print(
        f"[STEP] step={step} action={_single_line(action)} reward={reward:.2f} "
        f"done={str(done).lower()} error={_single_line(error_val)}",
        flush=True,
    )
def log_end(task: str, success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes one task run.

    The ``task=`` field is required in this record by the consumer of the log.

    Args:
        task: Task id that finished.
        success: Whether the run counts as a success (printed lowercase).
        steps: Number of steps actually executed.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    parts = [
        f"task={task}",
        "success=" + str(success).lower(),
        f"steps={steps}",
        f"score={score:.3f}",
        "rewards=" + ",".join(map("{:.2f}".format, rewards)),
    ]
    print("[END] " + " ".join(parts), flush=True)
class QuantGymClient:
    """Minimal HTTP client for the Quant-Gym environment server.

    Wraps a ``requests.Session`` and exposes the ``/reset`` and ``/step``
    endpoints, each returning the decoded JSON response body. Usable as a
    context manager so the session is always closed.
    """

    def __init__(self, base_url: str, timeout: float = 30.0):
        # Strip a trailing slash so f"{base_url}/reset" never becomes "//reset".
        self.base_url = base_url.rstrip("/")
        # Seconds to wait per request; without a timeout a stalled server
        # would block the run forever.
        self.timeout = timeout
        self.session = requests.Session()

    def _post(self, path: str, payload: Optional[dict] = None):
        """POST ``payload`` (as JSON, if given) to ``path`` and decode the reply.

        Raises:
            requests.HTTPError: on a 4xx/5xx status, instead of silently
                decoding an error body as if it were an observation.
            requests.RequestException: on connection errors or timeouts.
            ValueError: (a requests JSON-decode subclass) if the body is not JSON.
        """
        resp = self.session.post(f"{self.base_url}{path}", json=payload, timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def reset(self):
        """Start a new episode via POST /reset; returns the decoded JSON body."""
        return self._post("/reset")

    def step(self, action: dict):
        """Send ``action`` via POST /step; returns the decoded JSON body."""
        return self._post("/step", action)

    def close(self):
        """Close the underlying HTTP session."""
        self.session.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
def fallback_strategy(observation: dict) -> str:
    """Pick an action from the latest news sentiment (no LLM involved).

    Args:
        observation: Environment observation. ``observation["last_news"]["sentiment"]``
            is read when present. A missing, null (``None``) or non-dict
            ``last_news`` is treated as neutral instead of raising.

    Returns:
        "BUY 5" for positive sentiment, "SELL 5" for negative, otherwise "GET_PRICE".
    """
    news = (observation or {}).get("last_news")
    sentiment = news.get("sentiment") if isinstance(news, dict) else None
    # Normalise case/whitespace so "Positive " is handled like "positive".
    if isinstance(sentiment, str):
        sentiment = sentiment.strip().lower()
    if sentiment == "positive":
        return "BUY 5"
    if sentiment == "negative":
        return "SELL 5"
    return "GET_PRICE"
# Keyword the model may emit -> canonical action string sent downstream.
_ACTION_KEYWORDS = (
    ("BUY", "BUY 5"),
    ("SELL", "SELL 5"),
    ("BACKTEST", "BACKTEST"),
    ("GET_NEWS", "GET_NEWS"),
    ("GET_PRICE", "GET_PRICE"),
)


def _parse_action(text: str) -> str:
    """Return the canonical action whose keyword appears first in ``text``.

    Using the earliest mention (rather than a fixed BUY > SELL > ... priority)
    means a reply such as "SELL 5, do not BUY" maps to SELL. Defaults to GET_PRICE.
    """
    best_pos, best_action = len(text) + 1, "GET_PRICE"
    for keyword, action in _ACTION_KEYWORDS:
        pos = text.find(keyword)
        if 0 <= pos < best_pos:
            best_pos, best_action = pos, action
    return best_action


def get_model_action(client: Optional[OpenAI], step: int, observation: dict) -> str:
    """Ask the LLM for the next action, falling back to a sentiment heuristic.

    Args:
        client: OpenAI-compatible client, or None to skip the LLM entirely.
        step: 1-based step index, included in the prompt.
        observation: Latest environment observation; ``price`` and ``balance``
            are read for the prompt.

    Returns:
        One of "BUY 5", "SELL 5", "BACKTEST", "GET_NEWS" or "GET_PRICE".
        When ``client`` is falsy or the LLM call raises, the result of
        ``fallback_strategy(observation)`` is returned instead.
    """
    if not client:
        return fallback_strategy(observation)
    user_prompt = (
        f"Step {step}. Price: ${observation.get('price', 0)}. "
        f"Balance: ${observation.get('balance', 0)}. "
        "Choose: BUY 5, SELL 5, GET_PRICE, BACKTEST, GET_NEWS. "
        "Reply with exactly one of these actions and nothing else."
    )
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{"role": "user", "content": user_prompt}],
            max_tokens=100,
        )
        text = (completion.choices[0].message.content or "").strip().upper()
    except Exception as e:
        # Any client/network/API failure degrades to the heuristic policy.
        print(f"[DEBUG] LLM error: {e}", flush=True)
        return fallback_strategy(observation)
    return _parse_action(text)
def calculate_reward(observation: dict, initial_value: float = 10000.0) -> float:
    """Map the observation's portfolio value to a reward in [0.001, 0.999].

    The reward is the fractional gain over ``initial_value``. Losses count as
    zero gain. The result is clamped so it is never exactly 0 or 1.

    Args:
        observation: Environment observation; ``portfolio_value`` is read if
            present. A missing, null, non-numeric or NaN value (or a ``None``
            observation) is treated as "no gain" instead of raising.
        initial_value: Baseline portfolio value the gain is measured against.
            Defaults to 10000, the previously hard-coded baseline.

    Returns:
        A float in the closed interval [0.001, 0.999].

    Raises:
        ValueError: if ``initial_value`` is not positive.
    """
    import math

    if initial_value <= 0:
        raise ValueError(f"initial_value must be positive, got {initial_value!r}")
    raw = (observation or {}).get("portfolio_value", initial_value)
    try:
        portfolio = float(raw)
    except (TypeError, ValueError):
        portfolio = float(initial_value)
    # NaN would otherwise slip through min()/max() unpredictably.
    if math.isnan(portfolio):
        portfolio = float(initial_value)
    gain = max(0.0, (portfolio - initial_value) / initial_value)
    # Clamp to [0.001, 0.999] - NEVER exactly 0 or 1.
    return max(0.001, min(0.999, gain))
def _action_to_payload(action_str: str) -> dict:
    """Translate a textual action (e.g. "BUY 5") into the /step JSON payload."""
    action = action_str.strip().upper()
    if action.startswith("BUY"):
        return {"type": "BUY", "amount": 5}
    if action.startswith("SELL"):
        return {"type": "SELL", "amount": 5}
    if action.startswith("BACKTEST"):
        return {"type": "BACKTEST", "strategy": "momentum"}
    if action == "GET_NEWS":
        return {"type": "GET_NEWS", "explanation": "Market analysis"}
    # GET_PRICE and anything unrecognised fall back to a price query.
    return {"type": "GET_PRICE"}


def _init_client() -> "Optional[OpenAI]":
    """Build the OpenAI-compatible client, or return None to use the fallback policy."""
    if not (API_KEY and API_BASE_URL):
        return None
    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    except Exception as e:
        print(f"[WARNING] Failed: {e}", flush=True)
        return None
    print("[INFO] OpenAI client initialized", flush=True)
    return client


def _run_task(env: "QuantGymClient", client: "Optional[OpenAI]", task_id: str) -> None:
    """Run one episode for ``task_id`` and emit its [START]/[STEP]/[END] records.

    Any exception during the episode is logged and scored as 0.001; the [END]
    record is always printed.
    """
    rewards = []
    steps = 0
    score = 0.001
    success = False
    log_start(task=task_id, env="quant-gym", model=MODEL_NAME)
    try:
        # NOTE(review): reset() sends no task id, so every task appears to run
        # the same episode - confirm whether the server expects one.
        result = env.reset()
        # `or {}` also covers a JSON null observation.
        observation = result.get("observation") or {}
        for step in range(1, MAX_STEPS + 1):
            action_str = get_model_action(client, step, observation)
            result = env.step(_action_to_payload(action_str))
            observation = result.get("observation") or {}
            reward = calculate_reward(observation)
            rewards.append(reward)
            steps = step
            # End when the server reports the episode is over, or when the
            # step budget is spent. (The former `step >= MAX_STEPS - 1` check
            # stopped one step early.)
            done = bool(result.get("done", False)) or step >= MAX_STEPS
            log_step(step=step, action=action_str[:100], reward=reward, done=done, error=None)
            if done:
                break
        score = max(0.001, min(0.999, calculate_reward(observation)))
        success = score > 0.5
    except Exception as e:
        print(f"[DEBUG] {task_id} error: {e}", flush=True)
        score = 0.001
        success = False
    finally:
        # The END record must carry the task= field.
        log_end(task=task_id, success=success, steps=steps, score=score, rewards=rewards)


def main():
    """Run every task in TASKS against the Quant-Gym server and log the results."""
    print("[INFO] Starting Quant-Gym Inference", flush=True)
    print(f"[INFO] ENV_URL: {ENV_URL}", flush=True)
    print(f"[INFO] API_BASE_URL: {API_BASE_URL}", flush=True)
    print(f"[INFO] MODEL_NAME: {MODEL_NAME}", flush=True)
    client = _init_client()
    env = QuantGymClient(ENV_URL)
    try:
        for task_id, _task_name in TASKS:
            _run_task(env, client, task_id)
    finally:
        # Close the HTTP session even if a task raises something unexpected.
        env.close()


if __name__ == "__main__":
    main()
|