luguog commited on
Commit
06c2aae
·
verified ·
1 Parent(s): b833c19

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +89 -102
main.py CHANGED
@@ -23,21 +23,17 @@ GATE_API_BASE = os.getenv("GATE_API_BASE", "https://api.gate.io/api/v4")
23
  LOG_FILE = os.getenv("TRADE_LOG_FILE", "trading_log.jsonl")
24
  BAL_FILE = os.getenv("BALANCE_SNAP_FILE", "balance_snapshots.jsonl")
25
 
26
- LLM_ENDPOINT = os.getenv("LLM_ENDPOINT") # HF Inference endpoint or similar
27
- LLM_API_KEY = os.getenv("LLM_API_KEY") # optional bearer token
28
 
29
- if not GATE_API_KEY or not GATE_API_SECRET:
30
- raise RuntimeError("GATE_API_KEY and GATE_API_SECRET must be set")
31
 
32
- # -----------------------------------------------------------------------------
33
- # FastAPI app
34
- # -----------------------------------------------------------------------------
35
-
36
- app = FastAPI(title="gate4-alpha-api", version="0.2.0")
37
 
38
  app.add_middleware(
39
  CORSMiddleware,
40
- allow_origins=["*"], # restrict in real deployment
41
  allow_credentials=True,
42
  allow_methods=["*"],
43
  allow_headers=["*"],
@@ -81,7 +77,6 @@ class KPIResponse(BaseModel):
81
  class AlphaRequest(BaseModel):
82
  contract: str
83
  context: Optional[str] = None
84
- # Optional overrides for KPI input if caller wants to inject their data
85
  kpis_override: Optional[Dict[str, float]] = None
86
 
87
 
@@ -96,7 +91,54 @@ class AlphaDecision(BaseModel):
96
 
97
 
98
  # -----------------------------------------------------------------------------
99
- # Gate.io helpers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # -----------------------------------------------------------------------------
101
 
102
  def sign_request(method: str, path: str, query_string: str, body: str, timestamp: str) -> str:
@@ -109,6 +151,9 @@ def sign_request(method: str, path: str, query_string: str, body: str, timestamp
109
 
110
 
111
  def gate_private_get(path: str, query: str = "") -> Any:
 
 
 
112
  method = "GET"
113
  timestamp = str(int(time.time()))
114
  body = ""
@@ -128,14 +173,17 @@ def gate_private_get(path: str, query: str = "") -> Any:
128
  res.raise_for_status()
129
  except requests.RequestException as e:
130
  raise HTTPException(status_code=502, detail=f"Gate.io request failed: {e}")
131
-
132
  return res.json()
133
 
134
 
135
  def gate_public_get(path: str, query: str = "") -> Any:
 
 
 
136
  url = f"{GATE_API_BASE}{path}"
137
  if query:
138
  url = f"{url}?{query}"
 
139
  try:
140
  res = requests.get(url, timeout=10)
141
  res.raise_for_status()
@@ -145,9 +193,15 @@ def gate_public_get(path: str, query: str = "") -> Any:
145
 
146
 
147
  def get_futures_account_total_balance() -> float:
 
 
 
 
 
 
 
148
  path = "/futures/usdt/accounts"
149
  accounts = gate_private_get(path)
150
- # Gate futures account endpoint returns a list of account objects
151
  total = 0.0
152
  for acc in accounts:
153
  try:
@@ -158,9 +212,10 @@ def get_futures_account_total_balance() -> float:
158
 
159
 
160
  def get_contract_spread_bps(contract: str) -> float:
161
- """
162
- Uses futures ticker to derive bid/ask spread in basis points.
163
- """
 
164
  path = "/futures/usdt/tickers"
165
  query = f"contract={contract}"
166
  tickers = gate_public_get(path, query=query)
@@ -183,54 +238,6 @@ def get_contract_spread_bps(contract: str) -> float:
183
  return spread_bps
184
 
185
 
186
- # -----------------------------------------------------------------------------
187
- # File-backed state
188
- # -----------------------------------------------------------------------------
189
-
190
- def _safe_read_lines(path: str) -> List[str]:
191
- if not os.path.exists(path):
192
- return []
193
- with open(path, "r") as f:
194
- return [line for line in f if line.strip()]
195
-
196
-
197
- def load_trades() -> List[TradeLog]:
198
- lines = _safe_read_lines(LOG_FILE)
199
- out: List[TradeLog] = []
200
- for line in lines:
201
- try:
202
- raw = json.loads(line)
203
- out.append(TradeLog(**raw))
204
- except (json.JSONDecodeError, ValidationError):
205
- # skip malformed line
206
- continue
207
- return out
208
-
209
-
210
- def load_balances() -> List[BalanceSnapshot]:
211
- lines = _safe_read_lines(BAL_FILE)
212
- out: List[BalanceSnapshot] = []
213
- for line in lines:
214
- try:
215
- raw = json.loads(line)
216
- out.append(BalanceSnapshot(**raw))
217
- except (json.JSONDecodeError, ValidationError):
218
- continue
219
- return out
220
-
221
-
222
- def append_trade(trade: TradeLog) -> None:
223
- with _log_lock, open(LOG_FILE, "a") as f:
224
- f.write(trade.model_json() + "\n")
225
-
226
-
227
- def append_balance_snapshot(balance: float) -> BalanceSnapshot:
228
- snap = BalanceSnapshot(timestamp=int(time.time()), balance=balance)
229
- with _bal_lock, open(BAL_FILE, "a") as f:
230
- f.write(snap.model_json() + "\n")
231
- return snap
232
-
233
-
234
  # -----------------------------------------------------------------------------
235
  # KPI logic
236
  # -----------------------------------------------------------------------------
@@ -282,7 +289,7 @@ def kpis_to_feature_dict(kpis: KPIResponse) -> Dict[str, float]:
282
 
283
 
284
  # -----------------------------------------------------------------------------
285
- # LLM integration
286
  # -----------------------------------------------------------------------------
287
 
288
  def _build_alpha_prompt(req: AlphaRequest, spread_bps: float, kpis: Dict[str, float]) -> str:
@@ -306,11 +313,16 @@ def _build_alpha_prompt(req: AlphaRequest, spread_bps: float, kpis: Dict[str, fl
306
 
307
  def call_llm_for_alpha(prompt: str) -> Dict[str, Any]:
308
  if not LLM_ENDPOINT:
309
- raise HTTPException(status_code=500, detail="LLM_ENDPOINT not configured")
 
 
 
 
 
 
310
 
311
  payload = {
312
  "inputs": prompt,
313
- # HF / generic text generation params; adjust per endpoint
314
  "parameters": {
315
  "max_new_tokens": 256,
316
  "temperature": 0.1,
@@ -333,18 +345,14 @@ def call_llm_for_alpha(prompt: str) -> Dict[str, Any]:
333
  except json.JSONDecodeError:
334
  raise HTTPException(status_code=502, detail="LLM returned non-JSON payload")
335
 
336
- # HF endpoints often return [{"generated_text": "..."}]
337
  if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
338
  text = data[0]["generated_text"]
339
  elif isinstance(data, dict) and "generated_text" in data:
340
  text = data["generated_text"]
341
  else:
342
- # assume raw text
343
  text = str(data)
344
 
345
- # extract JSON from text
346
  text = text.strip()
347
- # best-effort: find first '{' ... last '}'
348
  start = text.find("{")
349
  end = text.rfind("}")
350
  if start == -1 or end == -1 or end <= start:
@@ -365,7 +373,7 @@ def build_alpha_decision(
365
  kpi_features: Dict[str, float],
366
  raw_model_out: Dict[str, Any],
367
  ) -> AlphaDecision:
368
- action = raw_model_out.get("action", "").lower().strip()
369
  if action not in ("long", "short", "flat"):
370
  action = "flat"
371
 
@@ -373,7 +381,6 @@ def build_alpha_decision(
373
  size_factor = float(raw_model_out.get("size_factor", 0.0))
374
  comment = str(raw_model_out.get("comment", "")).strip()[:240]
375
 
376
- # Hard safety clamps
377
  if confidence < 0.0:
378
  confidence = 0.0
379
  if confidence > 1.0:
@@ -383,9 +390,7 @@ def build_alpha_decision(
383
  if size_factor > 1.0:
384
  size_factor = 1.0
385
 
386
- # Deterministic risk gating based on KPIs
387
  if kpi_features.get("max_drawdown_pct", 0.0) < -30.0:
388
- # lock system to flat on deep drawdown
389
  action = "flat"
390
  confidence = min(confidence, 0.3)
391
  size_factor = 0.0
@@ -404,7 +409,7 @@ def build_alpha_decision(
404
  spread_bps=spread_bps,
405
  kpis=kpi_features,
406
  comment=comment,
407
- raw_model_output=raw_model_out,
408
  )
409
 
410
 
@@ -414,10 +419,11 @@ def build_alpha_decision(
414
 
415
  @app.get("/", response_class=HTMLResponse)
416
  def home() -> str:
417
- return """
 
418
  <html>
419
  <body>
420
- <h2>gate4-alpha-api</h2>
421
  <p>Endpoints:</p>
422
  <ul>
423
  <li>GET /balance</li>
@@ -426,6 +432,7 @@ def home() -> str:
426
  <li>POST /log_trade</li>
427
  <li>POST /alpha/entry</li>
428
  <li>GET /openapi.yaml</li>
 
429
  </ul>
430
  </body>
431
  </html>
@@ -434,9 +441,6 @@ def home() -> str:
434
 
435
  @app.get("/openapi.yaml")
436
  def get_openapi():
437
- """
438
- Serve static OpenAPI file if you generate one externally.
439
- """
440
  if not os.path.exists("openapi.yaml"):
441
  raise HTTPException(status_code=404, detail="openapi.yaml not found")
442
  return FileResponse("openapi.yaml", media_type="text/yaml")
@@ -444,26 +448,21 @@ def get_openapi():
444
 
445
  @app.get("/balance")
446
  def get_balance():
447
- """
448
- Snapshot Gate.io USDT futures account and persist balance curve.
449
- """
450
  total = get_futures_account_total_balance()
451
  snap = append_balance_snapshot(total)
452
  return {
453
  "timestamp": snap.timestamp,
454
  "balance": round(snap.balance, 6),
 
455
  }
456
 
457
 
458
  @app.get("/performance")
459
  def get_performance():
460
- """
461
- Human-readable PnL summary + last few trades.
462
- """
463
  trades = load_trades()
464
  balances = load_balances()
465
  if not trades and not balances:
466
- return {"summary": "No trades or balances logged yet."}
467
 
468
  kpis = compute_kpis(trades, balances)
469
  summary = (
@@ -491,14 +490,12 @@ def get_performance():
491
  "summary": summary,
492
  "last_trades": tail,
493
  "kpis": kpis.model_dump(),
 
494
  }
495
 
496
 
497
  @app.get("/kpis", response_model=KPIResponse)
498
  def get_kpis():
499
- """
500
- Machine-consumable KPI surface for external systems.
501
- """
502
  trades = load_trades()
503
  balances = load_balances()
504
  return compute_kpis(trades, balances)
@@ -506,10 +503,6 @@ def get_kpis():
506
 
507
  @app.post("/log_trade", response_model=TradeLog)
508
  async def log_trade(request: Request):
509
- """
510
- Append a trade event into the trading log.
511
- Request body must match TradeLog schema (fields can be omitted if defaulted).
512
- """
513
  payload = await request.json()
514
  try:
515
  trade = TradeLog(**payload)
@@ -521,12 +514,6 @@ async def log_trade(request: Request):
521
 
522
  @app.post("/alpha/entry", response_model=AlphaDecision)
523
  async def alpha_entry(req: AlphaRequest):
524
- """
525
- LLM-based entry decision:
526
- - Derives KPIs from logs unless overridden.
527
- - Computes current spread in bps for the requested contract.
528
- - Calls LLM policy layer and enforces deterministic risk clamps.
529
- """
530
  trades = load_trades()
531
  balances = load_balances()
532
  base_kpis = compute_kpis(trades, balances)
 
23
  LOG_FILE = os.getenv("TRADE_LOG_FILE", "trading_log.jsonl")
24
  BAL_FILE = os.getenv("BALANCE_SNAP_FILE", "balance_snapshots.jsonl")
25
 
26
+ LLM_ENDPOINT = os.getenv("LLM_ENDPOINT")
27
+ LLM_API_KEY = os.getenv("LLM_API_KEY")
28
 
29
+ # DRY_RUN = true when explicitly set OR when exchange keys are missing
30
+ DRY_RUN = os.getenv("DRY_RUN", "0") == "1" or not (GATE_API_KEY and GATE_API_SECRET)
31
 
32
+ app = FastAPI(title="gate4-alpha-api", version="0.3.0", docs_url="/docs", redoc_url="/redoc")
 
 
 
 
33
 
34
  app.add_middleware(
35
  CORSMiddleware,
36
+ allow_origins=["*"], # tighten for prod
37
  allow_credentials=True,
38
  allow_methods=["*"],
39
  allow_headers=["*"],
 
77
  class AlphaRequest(BaseModel):
78
  contract: str
79
  context: Optional[str] = None
 
80
  kpis_override: Optional[Dict[str, float]] = None
81
 
82
 
 
91
 
92
 
93
  # -----------------------------------------------------------------------------
94
+ # File-backed state
95
+ # -----------------------------------------------------------------------------
96
+
97
+ def _safe_read_lines(path: str) -> List[str]:
98
+ if not os.path.exists(path):
99
+ return []
100
+ with open(path, "r") as f:
101
+ return [line for line in f if line.strip()]
102
+
103
+
104
+ def load_trades() -> List[TradeLog]:
105
+ lines = _safe_read_lines(LOG_FILE)
106
+ out: List[TradeLog] = []
107
+ for line in lines:
108
+ try:
109
+ raw = json.loads(line)
110
+ out.append(TradeLog(**raw))
111
+ except (json.JSONDecodeError, ValidationError):
112
+ continue
113
+ return out
114
+
115
+
116
+ def load_balances() -> List[BalanceSnapshot]:
117
+ lines = _safe_read_lines(BAL_FILE)
118
+ out: List[BalanceSnapshot] = []
119
+ for line in lines:
120
+ try:
121
+ raw = json.loads(line)
122
+ out.append(BalanceSnapshot(**raw))
123
+ except (json.JSONDecodeError, ValidationError):
124
+ continue
125
+ return out
126
+
127
+
128
+ def append_trade(trade: TradeLog) -> None:
129
+ with _log_lock, open(LOG_FILE, "a") as f:
130
+ f.write(trade.model_json() + "\n")
131
+
132
+
133
+ def append_balance_snapshot(balance: float) -> BalanceSnapshot:
134
+ snap = BalanceSnapshot(timestamp=int(time.time()), balance=balance)
135
+ with _bal_lock, open(BAL_FILE, "a") as f:
136
+ f.write(snap.model_json() + "\n")
137
+ return snap
138
+
139
+
140
+ # -----------------------------------------------------------------------------
141
+ # Gate.io helpers (dry-run aware)
142
  # -----------------------------------------------------------------------------
143
 
144
  def sign_request(method: str, path: str, query_string: str, body: str, timestamp: str) -> str:
 
151
 
152
 
153
  def gate_private_get(path: str, query: str = "") -> Any:
154
+ if DRY_RUN:
155
+ raise HTTPException(status_code=503, detail="Exchange private API disabled in dry-run mode")
156
+
157
  method = "GET"
158
  timestamp = str(int(time.time()))
159
  body = ""
 
173
  res.raise_for_status()
174
  except requests.RequestException as e:
175
  raise HTTPException(status_code=502, detail=f"Gate.io request failed: {e}")
 
176
  return res.json()
177
 
178
 
179
  def gate_public_get(path: str, query: str = "") -> Any:
180
+ if DRY_RUN:
181
+ raise HTTPException(status_code=503, detail="Exchange public API disabled in dry-run mode")
182
+
183
  url = f"{GATE_API_BASE}{path}"
184
  if query:
185
  url = f"{url}?{query}"
186
+
187
  try:
188
  res = requests.get(url, timeout=10)
189
  res.raise_for_status()
 
193
 
194
 
195
  def get_futures_account_total_balance() -> float:
196
+ if DRY_RUN:
197
+ # In dry-run: use last balance if exists, else deterministic constant
198
+ balances = load_balances()
199
+ if balances:
200
+ return balances[-1].balance
201
+ return float(os.getenv("DRY_RUN_BALANCE", "10000.0"))
202
+
203
  path = "/futures/usdt/accounts"
204
  accounts = gate_private_get(path)
 
205
  total = 0.0
206
  for acc in accounts:
207
  try:
 
212
 
213
 
214
  def get_contract_spread_bps(contract: str) -> float:
215
+ if DRY_RUN:
216
+ # Deterministic spread for offline mode; override via env if needed
217
+ return float(os.getenv("DRY_RUN_SPREAD_BPS", "5.0"))
218
+
219
  path = "/futures/usdt/tickers"
220
  query = f"contract={contract}"
221
  tickers = gate_public_get(path, query=query)
 
238
  return spread_bps
239
 
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  # -----------------------------------------------------------------------------
242
  # KPI logic
243
  # -----------------------------------------------------------------------------
 
289
 
290
 
291
  # -----------------------------------------------------------------------------
292
+ # LLM integration (dry-run aware)
293
  # -----------------------------------------------------------------------------
294
 
295
  def _build_alpha_prompt(req: AlphaRequest, spread_bps: float, kpis: Dict[str, float]) -> str:
 
313
 
314
  def call_llm_for_alpha(prompt: str) -> Dict[str, Any]:
315
  if not LLM_ENDPOINT:
316
+ # Dry-run LLM: force flat, no external call
317
+ return {
318
+ "action": "flat",
319
+ "confidence": 0.0,
320
+ "size_factor": 0.0,
321
+ "comment": "LLM endpoint not configured; dry-run flat policy.",
322
+ }
323
 
324
  payload = {
325
  "inputs": prompt,
 
326
  "parameters": {
327
  "max_new_tokens": 256,
328
  "temperature": 0.1,
 
345
  except json.JSONDecodeError:
346
  raise HTTPException(status_code=502, detail="LLM returned non-JSON payload")
347
 
 
348
  if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
349
  text = data[0]["generated_text"]
350
  elif isinstance(data, dict) and "generated_text" in data:
351
  text = data["generated_text"]
352
  else:
 
353
  text = str(data)
354
 
 
355
  text = text.strip()
 
356
  start = text.find("{")
357
  end = text.rfind("}")
358
  if start == -1 or end == -1 or end <= start:
 
373
  kpi_features: Dict[str, float],
374
  raw_model_out: Dict[str, Any],
375
  ) -> AlphaDecision:
376
+ action = str(raw_model_out.get("action", "")).lower().strip()
377
  if action not in ("long", "short", "flat"):
378
  action = "flat"
379
 
 
381
  size_factor = float(raw_model_out.get("size_factor", 0.0))
382
  comment = str(raw_model_out.get("comment", "")).strip()[:240]
383
 
 
384
  if confidence < 0.0:
385
  confidence = 0.0
386
  if confidence > 1.0:
 
390
  if size_factor > 1.0:
391
  size_factor = 1.0
392
 
 
393
  if kpi_features.get("max_drawdown_pct", 0.0) < -30.0:
 
394
  action = "flat"
395
  confidence = min(confidence, 0.3)
396
  size_factor = 0.0
 
409
  spread_bps=spread_bps,
410
  kpis=kpi_features,
411
  comment=comment,
412
+ raw_model_output=raw_model_out if LLM_ENDPOINT else None,
413
  )
414
 
415
 
 
419
 
420
  @app.get("/", response_class=HTMLResponse)
421
  def home() -> str:
422
+ mode = "DRY-RUN" if DRY_RUN else "LIVE"
423
+ return f"""
424
  <html>
425
  <body>
426
+ <h2>gate4-alpha-api ({mode})</h2>
427
  <p>Endpoints:</p>
428
  <ul>
429
  <li>GET /balance</li>
 
432
  <li>POST /log_trade</li>
433
  <li>POST /alpha/entry</li>
434
  <li>GET /openapi.yaml</li>
435
+ <li>GET /docs</li>
436
  </ul>
437
  </body>
438
  </html>
 
441
 
442
  @app.get("/openapi.yaml")
443
  def get_openapi():
 
 
 
444
  if not os.path.exists("openapi.yaml"):
445
  raise HTTPException(status_code=404, detail="openapi.yaml not found")
446
  return FileResponse("openapi.yaml", media_type="text/yaml")
 
448
 
449
  @app.get("/balance")
450
  def get_balance():
 
 
 
451
  total = get_futures_account_total_balance()
452
  snap = append_balance_snapshot(total)
453
  return {
454
  "timestamp": snap.timestamp,
455
  "balance": round(snap.balance, 6),
456
+ "dry_run": DRY_RUN,
457
  }
458
 
459
 
460
  @app.get("/performance")
461
  def get_performance():
 
 
 
462
  trades = load_trades()
463
  balances = load_balances()
464
  if not trades and not balances:
465
+ return {"summary": "No trades or balances logged yet.", "dry_run": DRY_RUN}
466
 
467
  kpis = compute_kpis(trades, balances)
468
  summary = (
 
490
  "summary": summary,
491
  "last_trades": tail,
492
  "kpis": kpis.model_dump(),
493
+ "dry_run": DRY_RUN,
494
  }
495
 
496
 
497
  @app.get("/kpis", response_model=KPIResponse)
498
  def get_kpis():
 
 
 
499
  trades = load_trades()
500
  balances = load_balances()
501
  return compute_kpis(trades, balances)
 
503
 
504
  @app.post("/log_trade", response_model=TradeLog)
505
  async def log_trade(request: Request):
 
 
 
 
506
  payload = await request.json()
507
  try:
508
  trade = TradeLog(**payload)
 
514
 
515
  @app.post("/alpha/entry", response_model=AlphaDecision)
516
  async def alpha_entry(req: AlphaRequest):
 
 
 
 
 
 
517
  trades = load_trades()
518
  balances = load_balances()
519
  base_kpis = compute_kpis(trades, balances)