|
|
import os |
|
|
import time |
|
|
import json |
|
|
import hmac |
|
|
import hashlib |
|
|
import threading |
|
|
from typing import Any, Dict, List, Optional, Literal |
|
|
|
|
|
import requests |
|
|
from fastapi import FastAPI, HTTPException, Request |
|
|
from fastapi.responses import HTMLResponse, FileResponse |
|
|
from fastapi.middleware.cors import CORSMiddleware |
|
|
from pydantic import BaseModel, Field, ValidationError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _env_flag(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean switch.

    "1", "true", "yes" and "on" (case-insensitive, surrounding whitespace
    ignored) count as enabled; any other value counts as disabled. When the
    variable is unset, *default* is returned.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


# Gate.io credentials and REST base URL. The base URL is normalised so that
# f"{GATE_API_BASE}{path}" never produces a double slash.
GATE_API_KEY = os.getenv("GATE_API_KEY")
GATE_API_SECRET = os.getenv("GATE_API_SECRET")
GATE_API_BASE = os.getenv("GATE_API_BASE", "https://api.gate.io/api/v4").rstrip("/")

# JSON-lines files used as the persistence layer.
LOG_FILE = os.getenv("TRADE_LOG_FILE", "trading_log.jsonl")
BAL_FILE = os.getenv("BALANCE_SNAP_FILE", "balance_snapshots.jsonl")

# Optional text-generation endpoint used for alpha decisions.
LLM_ENDPOINT = os.getenv("LLM_ENDPOINT")
LLM_API_KEY = os.getenv("LLM_API_KEY")

# Dry-run is forced whenever credentials are missing. Common truthy spellings
# are accepted so that DRY_RUN=true cannot silently result in LIVE mode.
DRY_RUN = _env_flag("DRY_RUN") or not (GATE_API_KEY and GATE_API_SECRET)
|
|
|
|
|
app = FastAPI(title="gate4-alpha-api", version="0.3.0", docs_url="/docs", redoc_url="/redoc")

# Allowed CORS origins: comma-separated list in CORS_ALLOW_ORIGINS, defaulting
# to the wildcard (the previous behaviour).
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Credentialed cross-origin requests must not be combined with a wildcard
    # origin; they are enabled only when explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Two independent mutexes: one guards appends to the trade log file, the other
# guards appends to the balance snapshot file.
_log_lock, _bal_lock = threading.Lock(), threading.Lock()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TradeLog(BaseModel):
    # Schema of one trade record. Instances are validated from POST /log_trade
    # payloads and from each JSON line of LOG_FILE.

    # Unix time in whole seconds; defaults to "now" when omitted.
    timestamp: int = Field(default_factory=lambda: int(time.time()))
    # Kind of trade event; only these four literals are accepted.
    action: Literal["long", "short", "flat", "close"]
    # Contract identifier as supplied by the caller.
    contract: str
    # Position size and entry price; units are not defined in this file.
    size: float
    entry_price: float
    # Exit price, if any.
    exit_price: Optional[float] = None
    # Summed into realized PnL and used for win counting by compute_kpis.
    pnl_realized: float = 0.0
    # Added on top of realized PnL to form total_pnl in compute_kpis.
    pnl_estimate: float = 0.0
    # Optional free-text explanation for the trade.
    reason: Optional[str] = None
    # Arbitrary extra data; defaults to an empty dict per instance.
    meta: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
class BalanceSnapshot(BaseModel):
    # One point of the equity curve, stored as a JSON line in BAL_FILE.

    # Unix time in whole seconds (append_balance_snapshot uses int(time.time())).
    timestamp: int
    # Account balance recorded at that time.
    balance: float
|
|
|
|
|
|
|
|
class KPIResponse(BaseModel):
    # Aggregate performance figures produced by compute_kpis and returned by
    # GET /kpis.

    # Realized PnL plus the sum of per-trade PnL estimates.
    total_pnl: float
    # Sum of pnl_realized over all trades.
    realized_pnl: float
    # Number of trade records considered.
    trade_count: int
    # Fraction (0.0-1.0) of trades with pnl_realized > 0.
    win_rate: float
    # Worst peak-to-trough decline of the balance series, in percent (<= 0).
    max_drawdown_pct: float
    # realized_pnl divided by trade_count (0.0 when there are no trades).
    avg_pnl_per_trade: float
    # The balance snapshots the drawdown was computed from.
    equity_curve: List[BalanceSnapshot]
|
|
|
|
|
|
|
|
class AlphaRequest(BaseModel):
    # Request body of POST /alpha/entry.

    # Contract to evaluate; used for the ticker lookup and echoed in the prompt.
    contract: str
    # Optional free text inserted verbatim into the LLM prompt.
    context: Optional[str] = None
    # Optional KPI values merged over the computed features (override wins).
    kpis_override: Optional[Dict[str, float]] = None
|
|
|
|
|
|
|
|
class AlphaDecision(BaseModel):
    # Response body of POST /alpha/entry, assembled by build_alpha_decision.

    # Chosen policy action.
    action: Literal["long", "short", "flat"]
    # Both values are validated to lie within [0.0, 1.0].
    confidence: float = Field(ge=0.0, le=1.0)
    size_factor: float = Field(ge=0.0, le=1.0)
    # Bid/ask spread, in basis points, that the decision was based on.
    spread_bps: float
    # KPI feature values that were fed to the model.
    kpis: Dict[str, float]
    # Short explanation (build_alpha_decision truncates the model text to 240 chars
    # before appending any guard-rail marker).
    comment: str
    # Parsed model output; populated only when LLM_ENDPOINT is configured.
    raw_model_output: Optional[Any] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _safe_read_lines(path: str) -> List[str]: |
|
|
if not os.path.exists(path): |
|
|
return [] |
|
|
with open(path, "r") as f: |
|
|
return [line for line in f if line.strip()] |
|
|
|
|
|
|
|
|
def load_trades() -> List[TradeLog]:
    """Load every valid trade record from LOG_FILE, in file order.

    Loading is best-effort: lines that are not valid JSON, are not JSON
    objects, or fail TradeLog validation are skipped.
    """
    trades: List[TradeLog] = []
    for line in _safe_read_lines(LOG_FILE):
        try:
            raw = json.loads(line)
        except json.JSONDecodeError:
            continue
        # A valid JSON scalar or array cannot be splatted into the model;
        # without this guard TradeLog(**raw) would raise TypeError.
        if not isinstance(raw, dict):
            continue
        try:
            trades.append(TradeLog(**raw))
        except ValidationError:
            continue
    return trades
|
|
|
|
|
|
|
|
def load_balances() -> List[BalanceSnapshot]:
    """Load every valid balance snapshot from BAL_FILE, in file order.

    Loading is best-effort: lines that are not valid JSON, are not JSON
    objects, or fail BalanceSnapshot validation are skipped.
    """
    snapshots: List[BalanceSnapshot] = []
    for line in _safe_read_lines(BAL_FILE):
        try:
            raw = json.loads(line)
        except json.JSONDecodeError:
            continue
        # A valid JSON scalar or array cannot be splatted into the model;
        # without this guard BalanceSnapshot(**raw) would raise TypeError.
        if not isinstance(raw, dict):
            continue
        try:
            snapshots.append(BalanceSnapshot(**raw))
        except ValidationError:
            continue
    return snapshots
|
|
|
|
|
|
|
|
def append_trade(trade: TradeLog) -> None:
    """Append *trade* to LOG_FILE as a single JSON line.

    The write is performed while holding ``_log_lock``.
    """
    # pydantic v2 models serialise with model_dump_json(); there is no
    # model_json() method, so the previous call raised AttributeError.
    line = trade.model_dump_json() + "\n"
    with _log_lock, open(LOG_FILE, "a") as f:
        f.write(line)
|
|
|
|
|
|
|
|
def append_balance_snapshot(balance: float) -> BalanceSnapshot:
    """Record *balance* with the current Unix time and append it to BAL_FILE.

    The write is performed while holding ``_bal_lock``.

    Returns:
        The snapshot that was written.
    """
    snap = BalanceSnapshot(timestamp=int(time.time()), balance=balance)
    # pydantic v2 models serialise with model_dump_json(); there is no
    # model_json() method, so the previous call raised AttributeError.
    line = snap.model_dump_json() + "\n"
    with _bal_lock, open(BAL_FILE, "a") as f:
        f.write(line)
    return snap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sign_request(method: str, path: str, query_string: str, body: str, timestamp: str) -> str:
    """Compute the Gate.io APIv4 ``SIGN`` header value for a request.

    The signed string is
    ``METHOD\\nREQUEST_PATH\\nQUERY_STRING\\nSHA512_HEX(BODY)\\nTIMESTAMP``
    and the signature is its HMAC-SHA512 hex digest keyed with the API secret.

    Args:
        method: HTTP method; upper-cased before signing.
        path: Request path, with or without the URL-path prefix of
            GATE_API_BASE (e.g. ``/futures/usdt/accounts``).
        query_string: Raw query string without the leading ``?``.
        body: Raw request body (empty string for GET).
        timestamp: Unix time in seconds, as sent in the ``Timestamp`` header.

    Raises:
        RuntimeError: If GATE_API_SECRET is not configured.
    """
    from urllib.parse import urlsplit

    if not GATE_API_SECRET:
        raise RuntimeError("GATE_API_SECRET is not configured; cannot sign request")

    # The exchange signs the full request path (including e.g. "/api/v4"),
    # whereas callers pass the path relative to GATE_API_BASE.
    prefix = urlsplit(GATE_API_BASE).path.rstrip("/")
    if prefix and not path.startswith(prefix + "/"):
        signed_path = f"{prefix}{path}"
    else:
        signed_path = path

    # The payload enters the signature as its SHA-512 hex digest, not verbatim.
    hashed_body = hashlib.sha512((body or "").encode("utf-8")).hexdigest()
    message = "\n".join(
        (method.upper(), signed_path, query_string or "", hashed_body, timestamp)
    )
    return hmac.new(
        GATE_API_SECRET.encode("utf-8"),
        message.encode("utf-8"),
        hashlib.sha512,
    ).hexdigest()
|
|
|
|
|
|
|
|
def gate_private_get(path: str, query: str = "") -> Any:
    """Perform a signed GET against the Gate.io REST API and return the JSON body.

    Args:
        path: Endpoint path relative to GATE_API_BASE.
        query: Raw query string without the leading ``?``.

    Raises:
        HTTPException: 503 in dry-run mode; 502 when the request fails, the
            exchange answers with an error status, or the body is not JSON.
    """
    if DRY_RUN:
        raise HTTPException(status_code=503, detail="Exchange private API disabled in dry-run mode")

    timestamp = str(int(time.time()))
    headers = {
        "KEY": GATE_API_KEY,
        "Timestamp": timestamp,
        # GET requests carry no body, so an empty payload is signed.
        "SIGN": sign_request("GET", path, query, "", timestamp),
    }
    url = f"{GATE_API_BASE}{path}"
    if query:
        url = f"{url}?{query}"

    try:
        res = requests.get(url, headers=headers, timeout=10)
        res.raise_for_status()
        # Decoded inside the try so a non-JSON body becomes a 502, not a 500.
        return res.json()
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Gate.io request failed: {e}") from e
    except ValueError as e:
        # Older requests versions raise a plain ValueError for invalid JSON.
        raise HTTPException(status_code=502, detail=f"Gate.io request failed: {e}") from e
|
|
|
|
|
|
|
|
def gate_public_get(path: str, query: str = "") -> Any:
    """Perform an unauthenticated GET against the Gate.io REST API.

    Args:
        path: Endpoint path relative to GATE_API_BASE.
        query: Raw query string without the leading ``?``.

    Returns:
        The decoded JSON body.

    Raises:
        HTTPException: 503 in dry-run mode; 502 when the request fails, the
            exchange answers with an error status, or the body is not JSON.
    """
    if DRY_RUN:
        raise HTTPException(status_code=503, detail="Exchange public API disabled in dry-run mode")

    url = f"{GATE_API_BASE}{path}"
    if query:
        url = f"{url}?{query}"

    try:
        res = requests.get(url, timeout=10)
        res.raise_for_status()
        # Decoded inside the try so a non-JSON body becomes a 502, not a 500.
        return res.json()
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Gate.io public request failed: {e}") from e
    except ValueError as e:
        # Older requests versions raise a plain ValueError for invalid JSON.
        raise HTTPException(status_code=502, detail=f"Gate.io public request failed: {e}") from e
|
|
|
|
|
|
|
|
def get_futures_account_total_balance() -> float:
    """Return the USDT futures account balance.

    In dry-run mode the most recent stored snapshot is returned, falling back
    to the DRY_RUN_BALANCE environment variable (default 10000.0). Otherwise
    the ``available`` amount reported by ``/futures/usdt/accounts`` is used.

    Raises:
        HTTPException: 502 when the exchange payload has an unexpected shape
            (plus whatever gate_private_get raises).
    """
    if DRY_RUN:
        balances = load_balances()
        if balances:
            return balances[-1].balance
        return float(os.getenv("DRY_RUN_BALANCE", "10000.0"))

    accounts = gate_private_get("/futures/usdt/accounts")

    # The endpoint answers with a single account object; iterating a dict
    # would yield its string keys, so normalise to a list of objects. A list
    # is still accepted for backward compatibility.
    if isinstance(accounts, dict):
        accounts = [accounts]
    elif not isinstance(accounts, list):
        raise HTTPException(status_code=502, detail="Unexpected futures account payload from Gate.io")

    # NOTE(review): despite the function name this sums "available", not
    # "total" — kept as in the original; confirm which figure is intended.
    total = 0.0
    for acc in accounts:
        if not isinstance(acc, dict):
            continue
        try:
            total += float(acc.get("available", 0.0))
        except (TypeError, ValueError):
            continue
    return total
|
|
|
|
|
|
|
|
def get_contract_spread_bps(contract: str) -> float:
    """Return the current bid/ask spread of *contract* in basis points.

    In dry-run mode the DRY_RUN_SPREAD_BPS environment variable (default 5.0)
    is returned. Otherwise the best bid and ask are read from the futures
    ticker and the spread is computed relative to the mid price. ``0.0`` is
    returned when either side is missing/non-positive or the book is crossed.

    Raises:
        HTTPException: 404 when no ticker is returned for the contract; 502
            when the ticker is malformed (plus whatever gate_public_get raises).
    """
    if DRY_RUN:
        return float(os.getenv("DRY_RUN_SPREAD_BPS", "5.0"))

    from urllib.parse import urlencode

    # The contract name comes from the request body; encode it so it cannot
    # inject additional query parameters.
    query = urlencode({"contract": contract})
    tickers = gate_public_get("/futures/usdt/tickers", query=query)

    if isinstance(tickers, dict):
        tickers = [tickers]
    if not tickers or not isinstance(tickers, list):
        raise HTTPException(status_code=404, detail=f"No ticker data for {contract}")

    t = tickers[0]
    if not isinstance(t, dict):
        raise HTTPException(status_code=502, detail="Malformed ticker from Gate.io")

    try:
        # Futures tickers expose the top of book as highest_bid / lowest_ask;
        # bid / ask are kept as a fallback for other payload shapes.
        bid = float(t.get("highest_bid", t.get("bid", 0.0)))
        ask = float(t.get("lowest_ask", t.get("ask", 0.0)))
    except (TypeError, ValueError):
        raise HTTPException(status_code=502, detail="Malformed ticker from Gate.io")

    if bid <= 0 or ask <= 0 or ask <= bid:
        return 0.0

    mid = 0.5 * (bid + ask)
    return (ask - bid) / mid * 1e4
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_kpis(trades: List[TradeLog], balances: List[BalanceSnapshot]) -> KPIResponse:
    """Aggregate trade records and balance snapshots into a KPIResponse.

    PnL figures and the win rate come from *trades*; the maximum drawdown is
    derived from *balances*, which is also returned unchanged as the equity
    curve. With no trades the ratios are 0.0; with no balances the drawdown
    is 0.0.
    """
    trade_count = len(trades)

    # PnL totals.
    realized_pnl = float(sum(trade.pnl_realized for trade in trades))
    estimated_pnl = sum(trade.pnl_estimate for trade in trades)
    total_pnl = float(realized_pnl + estimated_pnl)

    # A "win" is any trade with strictly positive realized PnL.
    wins = len([trade for trade in trades if trade.pnl_realized > 0])
    if trade_count > 0:
        win_rate = float(wins / trade_count)
        avg_pnl_per_trade = float(realized_pnl / trade_count)
    else:
        win_rate = 0.0
        avg_pnl_per_trade = 0.0

    # Drawdown: track the running peak and keep the most negative percentage
    # distance from it. Points are ignored while the peak is not positive.
    worst_drawdown = 0.0
    running_peak = balances[0].balance if balances else 0.0
    for snapshot in balances:
        running_peak = max(running_peak, snapshot.balance)
        if running_peak > 0:
            decline = (snapshot.balance - running_peak) / running_peak * 100.0
            worst_drawdown = min(worst_drawdown, decline)

    return KPIResponse(
        total_pnl=total_pnl,
        realized_pnl=realized_pnl,
        trade_count=trade_count,
        win_rate=win_rate,
        max_drawdown_pct=worst_drawdown,
        avg_pnl_per_trade=avg_pnl_per_trade,
        equity_curve=balances,
    )
|
|
|
|
|
|
|
|
def kpis_to_feature_dict(kpis: KPIResponse) -> Dict[str, float]:
    """Flatten the scalar KPI fields into a name -> value mapping.

    The equity curve is left out and ``trade_count`` is converted to float.
    """
    scalar_fields = (
        "total_pnl",
        "realized_pnl",
        "trade_count",
        "win_rate",
        "max_drawdown_pct",
        "avg_pnl_per_trade",
    )
    features = {name: getattr(kpis, name) for name in scalar_fields}
    # Only trade_count is an int on the model; re-assigning keeps key order.
    features["trade_count"] = float(features["trade_count"])
    return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_alpha_prompt(req: AlphaRequest, spread_bps: float, kpis: Dict[str, float]) -> str: |
|
|
kpi_json = json.dumps(kpis, sort_keys=True) |
|
|
ctx = req.context or "" |
|
|
return ( |
|
|
"You are a deterministic trading policy engine.\n" |
|
|
"Given KPIs and spread metrics, choose ONE action: long, short, or flat.\n" |
|
|
"Respond ONLY with a compact JSON object:\n" |
|
|
'{ "action": "...", "confidence": 0.0-1.0, "size_factor": 0.0-1.0, "comment": "..." }\n\n' |
|
|
f"Contract: {req.contract}\n" |
|
|
f"Spread_bps: {spread_bps:.4f}\n" |
|
|
f"KPIs: {kpi_json}\n" |
|
|
f"Context: {ctx}\n" |
|
|
"Constraints:\n" |
|
|
"- If max_drawdown_pct < -25 or win_rate < 0.4 ⇒ prefer flat.\n" |
|
|
"- If trade_count < 10 ⇒ confidence <= 0.4.\n" |
|
|
"- size_factor must be <= 0.3 if spread_bps > 10.\n" |
|
|
) |
|
|
|
|
|
|
|
|
def call_llm_for_alpha(prompt: str) -> Dict[str, Any]:
    """Send *prompt* to the configured LLM endpoint and return its JSON answer.

    When LLM_ENDPOINT is not configured, a fixed "flat" decision is returned
    without any network access. Otherwise the prompt is POSTed as a
    text-generation request and the first JSON object found in the generated
    text is returned.

    Raises:
        HTTPException: 502 when the request fails, the response is not JSON,
            or no JSON object can be parsed from the generated text.
    """
    if not LLM_ENDPOINT:
        return {
            "action": "flat",
            "confidence": 0.0,
            "size_factor": 0.0,
            "comment": "LLM endpoint not configured; dry-run flat policy.",
        }

    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 256,
            "temperature": 0.1,
            "top_p": 0.9,
            "return_full_text": False,
        },
    }
    headers = {"Content-Type": "application/json"}
    if LLM_API_KEY:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    try:
        res = requests.post(LLM_ENDPOINT, headers=headers, json=payload, timeout=20)
        res.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"LLM request failed: {e}") from e

    try:
        data = res.json()
    except ValueError as e:
        # ValueError covers every JSON decode error requests can raise here,
        # regardless of which json backend it was built against.
        raise HTTPException(status_code=502, detail="LLM returned non-JSON payload") from e

    # Accept both the list-of-generations and the single-object response shape.
    if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
        text = data[0]["generated_text"]
    elif isinstance(data, dict) and "generated_text" in data:
        text = data["generated_text"]
    else:
        text = data

    # generated_text is not guaranteed to be a string.
    text = str(text).strip()

    start = text.find("{")
    if start == -1:
        raise HTTPException(status_code=502, detail="LLM output missing JSON object")

    # Decode the first parseable JSON object instead of slicing from the first
    # "{" to the last "}", which breaks as soon as the model appends text that
    # contains braces.
    decoder = json.JSONDecoder()
    last_error: Optional[json.JSONDecodeError] = None
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError as e:
            last_error = e
        else:
            if isinstance(parsed, dict):
                return parsed
        start = text.find("{", start + 1)

    raise HTTPException(status_code=502, detail=f"LLM JSON parse failed: {last_error}")
|
|
|
|
|
|
|
|
def _coerce_unit_interval(value: Any) -> float:
    """Convert *value* to a float clamped to [0.0, 1.0]; unusable input gives 0.0."""
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    if number != number:  # NaN compares unequal to itself
        return 0.0
    return min(1.0, max(0.0, number))


def build_alpha_decision(
    req: AlphaRequest,
    spread_bps: float,
    kpi_features: Dict[str, float],
    raw_model_out: Dict[str, Any],
) -> AlphaDecision:
    """Turn raw model output into a validated, guard-railed AlphaDecision.

    Model output is untrusted: unknown actions become "flat", and confidence /
    size_factor are coerced into [0.0, 1.0] (non-numeric or NaN values become
    0.0 instead of raising). Hard guard rails are then applied:

    * max_drawdown_pct below -30 forces "flat" with size 0 and confidence <= 0.3;
    * fewer than 10 trades caps confidence at 0.4;
    * a spread above 10 bps caps size_factor at 0.3.

    Args:
        req: The originating request (currently not consulted).
        spread_bps: Spread used for the decision, echoed in the result.
        kpi_features: KPI values fed to the model, echoed in the result.
        raw_model_out: Parsed model output; echoed only when LLM_ENDPOINT is set.
    """
    fields = raw_model_out if isinstance(raw_model_out, dict) else {}

    action = str(fields.get("action", "")).lower().strip()
    if action not in ("long", "short", "flat"):
        action = "flat"

    confidence = _coerce_unit_interval(fields.get("confidence", 0.0))
    size_factor = _coerce_unit_interval(fields.get("size_factor", 0.0))
    # A null comment must not surface as the literal string "None".
    comment = str(fields.get("comment") or "").strip()[:240]

    if kpi_features.get("max_drawdown_pct", 0.0) < -30.0:
        action = "flat"
        confidence = min(confidence, 0.3)
        size_factor = 0.0
        comment = (comment + " [forced_flat_due_to_drawdown]").strip()

    if kpi_features.get("trade_count", 0.0) < 10:
        confidence = min(confidence, 0.4)

    if spread_bps > 10.0 and size_factor > 0.3:
        size_factor = 0.3

    return AlphaDecision(
        action=action,
        confidence=confidence,
        size_factor=size_factor,
        spread_bps=spread_bps,
        kpis=kpi_features,
        comment=comment,
        raw_model_output=raw_model_out if LLM_ENDPOINT else None,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
def home() -> str:
    """Serve a minimal HTML index naming the run mode and the endpoints."""
    mode = "LIVE" if not DRY_RUN else "DRY-RUN"
    return f"""
    <html>
      <body>
        <h2>gate4-alpha-api ({mode})</h2>
        <p>Endpoints:</p>
        <ul>
          <li>GET /balance</li>
          <li>GET /performance</li>
          <li>GET /kpis</li>
          <li>POST /log_trade</li>
          <li>POST /alpha/entry</li>
          <li>GET /openapi.yaml</li>
          <li>GET /docs</li>
        </ul>
      </body>
    </html>
    """
|
|
|
|
|
|
|
|
@app.get("/openapi.yaml")
def get_openapi():
    """Serve ``openapi.yaml`` from the working directory, or 404 if absent."""
    spec_path = "openapi.yaml"
    if os.path.exists(spec_path):
        return FileResponse(spec_path, media_type="text/yaml")
    raise HTTPException(status_code=404, detail="openapi.yaml not found")
|
|
|
|
|
|
|
|
@app.get("/balance")
def get_balance():
    """Fetch the current balance, persist it as a snapshot and return it.

    Every call appends one snapshot line to the balance file.
    """
    snap = append_balance_snapshot(get_futures_account_total_balance())
    return dict(
        timestamp=snap.timestamp,
        balance=round(snap.balance, 6),
        dry_run=DRY_RUN,
    )
|
|
|
|
|
|
|
|
@app.get("/performance")
def get_performance():
    """Return a text summary, the five most recent trades and the full KPIs.

    When neither trades nor balances have been logged, only a placeholder
    summary and the dry-run flag are returned.
    """
    trades, balances = load_trades(), load_balances()
    if not (trades or balances):
        return {"summary": "No trades or balances logged yet.", "dry_run": DRY_RUN}

    kpis = compute_kpis(trades, balances)
    summary_parts = (
        f"Total PnL: {kpis.total_pnl:.2f}, ",
        f"Realized: {kpis.realized_pnl:.2f}, ",
        f"Trades: {kpis.trade_count}, ",
        f"Win rate: {kpis.win_rate:.2%}, ",
        f"Max DD: {kpis.max_drawdown_pct:.2f}%",
    )

    # Compact view of the latest (up to five) trades, oldest first.
    recent = [
        {
            "ts": trade.timestamp,
            "action": trade.action,
            "contract": trade.contract,
            "pnl_realized": trade.pnl_realized,
            "pnl_estimate": trade.pnl_estimate,
            "reason": trade.reason,
        }
        for trade in trades[-5:]
    ]

    return {
        "summary": "".join(summary_parts),
        "last_trades": recent,
        "kpis": kpis.model_dump(),
        "dry_run": DRY_RUN,
    }
|
|
|
|
|
|
|
|
@app.get("/kpis", response_model=KPIResponse)
def get_kpis():
    """Return KPIs computed from the stored trades and balance snapshots."""
    return compute_kpis(load_trades(), load_balances())
|
|
|
|
|
|
|
|
@app.post("/log_trade", response_model=TradeLog)
async def log_trade(request: Request):
    """Validate the JSON request body as a TradeLog, persist it and echo it.

    Raises:
        HTTPException: 400 when the body is not valid JSON; 422 when it is not
            a JSON object or fails TradeLog validation.
    """
    try:
        payload = await request.json()
    except ValueError as e:
        # Covers JSON syntax errors and undecodable bytes; previously a 500.
        raise HTTPException(status_code=400, detail="Request body is not valid JSON") from e

    # A JSON array or scalar cannot be splatted into the model.
    if not isinstance(payload, dict):
        raise HTTPException(status_code=422, detail="Request body must be a JSON object")

    try:
        trade = TradeLog(**payload)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=e.errors()) from e

    append_trade(trade)
    return trade
|
|
|
|
|
|
|
|
@app.post("/alpha/entry", response_model=AlphaDecision)
async def alpha_entry(req: AlphaRequest):
    """Produce an entry decision for ``req.contract``.

    KPIs are computed from the stored trades and balances (with any
    ``kpis_override`` values taking precedence), the contract spread is
    fetched, the LLM is queried, and its answer is passed through the guard
    rails of build_alpha_decision.
    """
    from fastapi.concurrency import run_in_threadpool

    def _decide() -> AlphaDecision:
        base_features = kpis_to_feature_dict(compute_kpis(load_trades(), load_balances()))
        if req.kpis_override:
            features = {**base_features, **req.kpis_override}
        else:
            features = base_features

        spread_bps = get_contract_spread_bps(req.contract)
        prompt = _build_alpha_prompt(req, spread_bps, features)
        raw_model_out = call_llm_for_alpha(prompt)
        return build_alpha_decision(req, spread_bps, features, raw_model_out)

    # The pipeline does blocking file and HTTP I/O (requests); running it
    # directly inside this coroutine would stall the event loop for every
    # other request, so it is executed in the worker thread pool.
    return await run_in_threadpool(_decide)