Astocoder committed on
Commit
cf294ef
·
1 Parent(s): 4817a39

Fix: Use hackathon's LLM proxy via API_BASE_URL

Browse files
Files changed (1) hide show
  1. inference.py +75 -63
inference.py CHANGED
@@ -5,20 +5,12 @@ from typing import List, Optional
5
  from openai import OpenAI
6
  import requests
7
 
8
- # Try to load from .env file if it exists
9
- try:
10
- from dotenv import load_dotenv
11
- load_dotenv()
12
- print("[INFO] Loaded .env file", flush=True)
13
- except ImportError:
14
- print("[INFO] python-dotenv not installed, using system env only", flush=True)
15
-
16
- # Environment variables (set by the judge or .env)
17
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api-inference.huggingface.co/v1")
18
- MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3.2-3B-Instruct")
19
- HF_TOKEN = os.getenv("HF_TOKEN")
20
-
21
- # Quant-Gym specific configuration
22
  BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
23
  TASK_NAME = os.getenv("TASK_NAME", "quant-gym")
24
  BENCHMARK = os.getenv("BENCHMARK", "quant-gym")
@@ -27,10 +19,10 @@ TEMPERATURE = 0.7
27
  MAX_TOKENS = 200
28
  SUCCESS_SCORE_THRESHOLD = 0.7
29
 
30
- # System prompt for financial analysis
31
  SYSTEM_PROMPT = textwrap.dedent(
32
  """
33
- You are a financial analyst AI agent. Your goal is to analyze market data and make trading decisions.
34
 
35
  Available actions:
36
  - GET_PRICE: Get current stock price
@@ -39,16 +31,9 @@ SYSTEM_PROMPT = textwrap.dedent(
39
  - BACKTEST [strategy]: Backtest a strategy (momentum or mean_reversion)
40
  - GET_NEWS: Get latest news headline
41
 
42
- Strategy tips:
43
- - Positive news sentiment suggests BUY
44
- - Negative news sentiment suggests SELL
45
- - Momentum strategy: Buy when price is rising
46
- - Mean reversion: Buy when price is low relative to recent average
47
-
48
  Respond with EXACTLY one action in format: ACTION [parameter]
49
  Example: BUY 10
50
  Example: GET_PRICE
51
- Example: BACKTEST momentum
52
  """
53
  ).strip()
54
 
@@ -72,14 +57,11 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
72
 
73
 
74
  class QuantGymClient:
75
- """Client for interacting with Quant-Gym environment"""
76
-
77
  def __init__(self, base_url: str):
78
  self.base_url = base_url
79
  self.session = requests.Session()
80
 
81
  def reset(self):
82
- """Reset environment"""
83
  try:
84
  response = self.session.post(f"{self.base_url}/reset")
85
  return response.json()
@@ -87,34 +69,31 @@ class QuantGymClient:
87
  print(f"[ERROR] Reset failed: {e}", flush=True)
88
  return {"observation": {"price": 150, "balance": 10000, "holdings": 0, "portfolio_value": 10000}}
89
 
90
- def step(self, action: str, amount: int = 0, explanation: str = "", strategy: str = ""):
91
- """Execute an action"""
92
  action_upper = action.upper()
93
 
94
  if action_upper == "GET_PRICE":
95
  payload = {"type": "GET_PRICE"}
96
- elif action_upper == "GET_NEWS":
97
- payload = {"type": "GET_NEWS", "explanation": explanation}
98
  elif action_upper.startswith("BUY"):
 
99
  if " " in action_upper:
100
  try:
101
  amount = int(action_upper.split()[1])
102
  except:
103
- amount = 5
104
  payload = {"type": "BUY", "amount": amount}
105
  elif action_upper.startswith("SELL"):
 
106
  if " " in action_upper:
107
  try:
108
  amount = int(action_upper.split()[1])
109
  except:
110
- amount = 5
111
  payload = {"type": "SELL", "amount": amount}
112
  elif action_upper.startswith("BACKTEST"):
113
- if " " in action_upper:
114
- strategy = action_upper.split()[1]
115
- payload = {"type": "BACKTEST", "strategy": strategy}
116
  elif action_upper == "GET_NEWS":
117
- payload = {"type": "GET_NEWS", "explanation": explanation}
118
  else:
119
  payload = {"type": "GET_PRICE"}
120
 
@@ -126,12 +105,49 @@ class QuantGymClient:
126
  return {"observation": {"price": 150, "balance": 10000, "holdings": 0, "portfolio_value": 10000}}
127
 
128
  def close(self):
129
- """Close the session"""
130
  self.session.close()
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  def parse_action_from_response(text: str) -> str:
134
- """Parse LLM response into action string"""
135
  text = text.strip().upper()
136
 
137
  if text.startswith("BUY"):
@@ -145,7 +161,7 @@ def parse_action_from_response(text: str) -> str:
145
  return f"SELL {parts[1]}"
146
  return "SELL 5"
147
  elif text.startswith("BACKTEST"):
148
- return "BACKTEST momentum"
149
  elif text.startswith("GET_NEWS"):
150
  return "GET_NEWS"
151
  else:
@@ -153,7 +169,6 @@ def parse_action_from_response(text: str) -> str:
153
 
154
 
155
  def fallback_strategy(observation: dict) -> str:
156
- """Rule-based strategy when LLM is unavailable"""
157
  sentiment = observation.get('last_news', {}).get('sentiment', 'neutral')
158
  if sentiment == 'positive':
159
  return "BUY 5"
@@ -163,21 +178,22 @@ def fallback_strategy(observation: dict) -> str:
163
  return "GET_PRICE"
164
 
165
 
166
- def get_model_action(step: int, observation: dict, history: List[str]) -> str:
167
- """Get action using fallback strategy (no LLM required for basic testing)"""
168
- return fallback_strategy(observation)
169
-
170
-
171
  async def main() -> None:
172
  print("[INFO] Starting Quant-Gym Inference", flush=True)
173
 
174
- # Check token status
175
- if HF_TOKEN:
176
- print(f"[INFO] HF_TOKEN found (length: {len(HF_TOKEN)} chars)", flush=True)
177
- else:
178
- print("[INFO] No HF_TOKEN found, using rule-based fallback strategy", flush=True)
 
 
 
 
 
 
 
179
 
180
- # Initialize environment client
181
  env = QuantGymClient(BASE_URL)
182
 
183
  history: List[str] = []
@@ -186,29 +202,25 @@ async def main() -> None:
186
  success = False
187
  final_score = 0.0
188
 
189
- log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME if HF_TOKEN else "fallback-rule-based")
190
 
191
  try:
192
- # Reset environment
193
  result = env.reset()
194
  observation = result.get('observation', {})
195
- print(f"[INFO] Reset complete. Initial observation: {observation}", flush=True)
196
 
197
  for step in range(1, MAX_STEPS + 1):
198
- # Get action
199
- action_str = get_model_action(step, observation, history)
 
 
 
200
 
201
- # Execute action
202
  result = env.step(action_str)
203
  observation = result.get('observation', {})
204
 
205
- # Calculate reward
206
  portfolio_value = observation.get('portfolio_value', 10000)
207
- sentiment = observation.get('last_news', {}).get('sentiment', 'neutral')
208
-
209
  profit_reward = max(0, (portfolio_value - 10000) / 10000)
210
- sentiment_bonus = 0.2 if sentiment == 'positive' else (-0.1 if sentiment == 'negative' else 0)
211
- reward = min(1.0, max(0.0, profit_reward + sentiment_bonus))
212
 
213
  done = step >= MAX_STEPS - 1
214
  error = None
@@ -233,7 +245,7 @@ async def main() -> None:
233
  finally:
234
  try:
235
  env.close()
236
- except Exception as e:
237
  pass
238
  log_end(success=success, steps=steps_taken, score=final_score, rewards=rewards)
239
 
 
5
  from openai import OpenAI
6
  import requests
7
 
8
+
9
+ API_BASE_URL = os.getenv("API_BASE_URL") # NO DEFAULT
10
+ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo") # default
11
+ HF_TOKEN = os.getenv("HF_TOKEN")
12
+
13
+ # Quant-Gym configuration
 
 
 
 
 
 
 
 
14
  BASE_URL = os.getenv("BASE_URL", "http://localhost:8000")
15
  TASK_NAME = os.getenv("TASK_NAME", "quant-gym")
16
  BENCHMARK = os.getenv("BENCHMARK", "quant-gym")
 
19
  MAX_TOKENS = 200
20
  SUCCESS_SCORE_THRESHOLD = 0.7
21
 
22
+ # System prompt
23
  SYSTEM_PROMPT = textwrap.dedent(
24
  """
25
+ You are a financial analyst AI agent. Analyze market data and make trading decisions.
26
 
27
  Available actions:
28
  - GET_PRICE: Get current stock price
 
31
  - BACKTEST [strategy]: Backtest a strategy (momentum or mean_reversion)
32
  - GET_NEWS: Get latest news headline
33
 
 
 
 
 
 
 
34
  Respond with EXACTLY one action in format: ACTION [parameter]
35
  Example: BUY 10
36
  Example: GET_PRICE
 
37
  """
38
  ).strip()
39
 
 
57
 
58
 
59
  class QuantGymClient:
 
 
60
  def __init__(self, base_url: str):
61
  self.base_url = base_url
62
  self.session = requests.Session()
63
 
64
  def reset(self):
 
65
  try:
66
  response = self.session.post(f"{self.base_url}/reset")
67
  return response.json()
 
69
  print(f"[ERROR] Reset failed: {e}", flush=True)
70
  return {"observation": {"price": 150, "balance": 10000, "holdings": 0, "portfolio_value": 10000}}
71
 
72
+ def step(self, action: str):
 
73
  action_upper = action.upper()
74
 
75
  if action_upper == "GET_PRICE":
76
  payload = {"type": "GET_PRICE"}
 
 
77
  elif action_upper.startswith("BUY"):
78
+ amount = 5
79
  if " " in action_upper:
80
  try:
81
  amount = int(action_upper.split()[1])
82
  except:
83
+ pass
84
  payload = {"type": "BUY", "amount": amount}
85
  elif action_upper.startswith("SELL"):
86
+ amount = 5
87
  if " " in action_upper:
88
  try:
89
  amount = int(action_upper.split()[1])
90
  except:
91
+ pass
92
  payload = {"type": "SELL", "amount": amount}
93
  elif action_upper.startswith("BACKTEST"):
94
+ payload = {"type": "BACKTEST", "strategy": "momentum"}
 
 
95
  elif action_upper == "GET_NEWS":
96
+ payload = {"type": "GET_NEWS", "explanation": "Analyzing market sentiment"}
97
  else:
98
  payload = {"type": "GET_PRICE"}
99
 
 
105
  return {"observation": {"price": 150, "balance": 10000, "holdings": 0, "portfolio_value": 10000}}
106
 
107
  def close(self):
 
108
  self.session.close()
109
 
110
 
111
def get_model_action(client: OpenAI, step: int, observation: dict, history: List[str]) -> str:
    """Ask the LLM behind ``API_BASE_URL`` for the next trading action.

    Args:
        client: OpenAI-compatible client used for the chat completion call.
        step: Step number, included in the prompt.
        observation: Latest environment observation. The keys read here are
            ``price``, ``balance``, ``holdings``, ``portfolio_value`` and
            ``last_news`` (whose ``headline`` is used when it is a dict).
        history: Accepted for interface compatibility; not used in the prompt.

    Returns:
        The action string produced by ``parse_action_from_response`` for the
        model's reply, or the result of ``fallback_strategy(observation)``
        when ``API_BASE_URL`` is unset or the LLM call raises.
    """
    # Without a proxy URL there is nothing to call; use the rule-based path.
    if not API_BASE_URL:
        print("[WARNING] API_BASE_URL not set! Using fallback.", flush=True)
        return fallback_strategy(observation)

    # ``last_news`` may be missing, null, or not a mapping. Chaining ``.get``
    # on it directly would raise before the try block and abort the step.
    news = observation.get('last_news')
    headline = news.get('headline', 'No news') if isinstance(news, dict) else 'No news'

    # Build the prompt line by line rather than dedenting an f-string:
    # dedent runs after interpolation, so a multi-line headline could change
    # the common margin and leave the prompt indented.
    user_prompt = "\n".join(
        [
            f"Step: {step}",
            f"Current price: ${observation.get('price', 'unknown')}",
            f"Balance: ${observation.get('balance', 'unknown')}",
            f"Holdings: {observation.get('holdings', 0)} shares",
            f"Portfolio value: ${observation.get('portfolio_value', 'unknown')}",
            f"Latest news: {headline}",
            "",
            "What is your next action? (BUY X, SELL X, GET_PRICE, BACKTEST, or GET_NEWS)",
        ]
    )

    try:
        # The request goes through whatever base URL ``client`` was built with.
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        # ``content`` can be None; treat that as an empty reply.
        text = completion.choices[0].message.content or ""
        return parse_action_from_response(text)
    except Exception as e:
        # Deliberate best-effort: any LLM failure degrades to the rule-based
        # strategy instead of ending the episode.
        print(f"[DEBUG] LLM error: {e}, using fallback", flush=True)
        return fallback_strategy(observation)
148
+
149
+
150
  def parse_action_from_response(text: str) -> str:
 
151
  text = text.strip().upper()
152
 
153
  if text.startswith("BUY"):
 
161
  return f"SELL {parts[1]}"
162
  return "SELL 5"
163
  elif text.startswith("BACKTEST"):
164
+ return "BACKTEST"
165
  elif text.startswith("GET_NEWS"):
166
  return "GET_NEWS"
167
  else:
 
169
 
170
 
171
  def fallback_strategy(observation: dict) -> str:
 
172
  sentiment = observation.get('last_news', {}).get('sentiment', 'neutral')
173
  if sentiment == 'positive':
174
  return "BUY 5"
 
178
  return "GET_PRICE"
179
 
180
 
 
 
 
 
 
181
  async def main() -> None:
182
  print("[INFO] Starting Quant-Gym Inference", flush=True)
183
 
184
+ # CRITICAL: Check if API_BASE_URL is provided by judge
185
+ if not API_BASE_URL:
186
+ print("[ERROR] API_BASE_URL environment variable not set!", flush=True)
187
+ print("[ERROR] This must be provided by the hackathon judge.", flush=True)
188
+ print("[INFO] Using fallback strategy without LLM.", flush=True)
189
+
190
+ # Initialize OpenAI client with judge's proxy URL
191
+ # DO NOT use default - MUST use their provided URL
192
+ client = OpenAI(
193
+ base_url=API_BASE_URL, # Their proxy URL
194
+ api_key="dummy" # Their proxy may not need a real key
195
+ ) if API_BASE_URL else None
196
 
 
197
  env = QuantGymClient(BASE_URL)
198
 
199
  history: List[str] = []
 
202
  success = False
203
  final_score = 0.0
204
 
205
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
206
 
207
  try:
 
208
  result = env.reset()
209
  observation = result.get('observation', {})
 
210
 
211
  for step in range(1, MAX_STEPS + 1):
212
+ # Get action - this will use their proxy if available
213
+ if client:
214
+ action_str = get_model_action(client, step, observation, history)
215
+ else:
216
+ action_str = fallback_strategy(observation)
217
 
 
218
  result = env.step(action_str)
219
  observation = result.get('observation', {})
220
 
 
221
  portfolio_value = observation.get('portfolio_value', 10000)
 
 
222
  profit_reward = max(0, (portfolio_value - 10000) / 10000)
223
+ reward = min(1.0, max(0.0, profit_reward))
 
224
 
225
  done = step >= MAX_STEPS - 1
226
  error = None
 
245
  finally:
246
  try:
247
  env.close()
248
+ except:
249
  pass
250
  log_end(success=success, steps=steps_taken, score=final_score, rewards=rewards)
251