"""gate4-alpha-api: FastAPI service for Gate.io futures trade logging, KPIs and LLM alpha.

Trades and balance snapshots are persisted as JSON Lines files. When exchange
credentials are missing (or ``DRY_RUN=1``) no exchange calls are made.
"""

import hashlib
import hmac
import json
import os
import threading
import time
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar
from urllib.parse import urlencode, urlsplit

import requests
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, HTMLResponse
from pydantic import BaseModel, Field, ValidationError

# -----------------------------------------------------------------------------
# Config
# -----------------------------------------------------------------------------

GATE_API_KEY = os.getenv("GATE_API_KEY")
GATE_API_SECRET = os.getenv("GATE_API_SECRET")
GATE_API_BASE = os.getenv("GATE_API_BASE", "https://api.gate.io/api/v4")
# Gate.io APIv4 signs the full request path *including* the API prefix
# (e.g. "/api/v4/futures/usdt/accounts"), so derive that prefix from the base URL.
_GATE_API_PATH_PREFIX = urlsplit(GATE_API_BASE).path.rstrip("/")

LOG_FILE = os.getenv("TRADE_LOG_FILE", "trading_log.jsonl")
BAL_FILE = os.getenv("BALANCE_SNAP_FILE", "balance_snapshots.jsonl")

LLM_ENDPOINT = os.getenv("LLM_ENDPOINT")
LLM_API_KEY = os.getenv("LLM_API_KEY")

# Comma-separated allow-list of CORS origins; defaults to "*" (any origin).
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# DRY_RUN = true when explicitly set OR when exchange keys are missing
DRY_RUN = os.getenv("DRY_RUN", "0") == "1" or not (GATE_API_KEY and GATE_API_SECRET)

app = FastAPI(title="gate4-alpha-api", version="0.3.0", docs_url="/docs", redoc_url="/redoc")
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    # With a wildcard origin, allowing credentials would let any website call this
    # API with a visitor's cookies; only allow them for an explicit origin list.
    allow_credentials="*" not in CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Serialize appends (within this process) so concurrent requests cannot
# interleave partial JSONL lines.
_log_lock = threading.Lock()
_bal_lock = threading.Lock()

# -----------------------------------------------------------------------------
# Models
# -----------------------------------------------------------------------------


class TradeLog(BaseModel):
    """One trade record as accepted by ``/log_trade`` and stored in LOG_FILE."""

    timestamp: int = Field(default_factory=lambda: int(time.time()))  # Unix seconds
    action: Literal["long", "short", "flat", "close"]
    contract: str
    size: float
    entry_price: float
    exit_price: Optional[float] = None
    pnl_realized: float = 0.0
    pnl_estimate: float = 0.0
    reason: Optional[str] = None
    meta: Dict[str, Any] = Field(default_factory=dict)


class BalanceSnapshot(BaseModel):
    """A single point of the equity curve stored in BAL_FILE."""

    timestamp: int  # Unix seconds
    balance: float


class KPIResponse(BaseModel):
    """Aggregated performance metrics returned by ``/kpis``."""

    total_pnl: float
    realized_pnl: float
    trade_count: int
    win_rate: float
    max_drawdown_pct: float  # <= 0.0; most negative peak-to-point change in percent
    avg_pnl_per_trade: float
    equity_curve: List[BalanceSnapshot]


class AlphaRequest(BaseModel):
    """Body of ``/alpha/entry``: contract, free-form context and optional KPI overrides."""

    contract: str
    context: Optional[str] = None
    kpis_override: Optional[Dict[str, float]] = None


class AlphaDecision(BaseModel):
    """Guard-railed trade decision derived from the LLM output."""

    action: Literal["long", "short", "flat"]
    confidence: float = Field(ge=0.0, le=1.0)
    size_factor: float = Field(ge=0.0, le=1.0)
    spread_bps: float
    kpis: Dict[str, float]
    comment: str
    raw_model_output: Optional[Any] = None


# -----------------------------------------------------------------------------
# File-backed state
# -----------------------------------------------------------------------------

_ModelT = TypeVar("_ModelT", bound=BaseModel)


def _safe_read_lines(path: str) -> List[str]:
    """Return the non-blank lines of *path*, or ``[]`` if the file does not exist.

    Undecodable bytes are replaced rather than raised so one corrupted record
    cannot make the whole file unreadable.
    """
    try:
        with open(path, "r", encoding="utf-8", errors="replace") as f:
            return [line for line in f if line.strip()]
    except FileNotFoundError:
        return []


def _load_jsonl(path: str, model: Type[_ModelT]) -> List[_ModelT]:
    """Parse each non-blank line of *path* as JSON and validate it as *model*.

    Lines that are not valid JSON, are not JSON objects, or fail validation are
    skipped, so a single bad record does not break the KPI endpoints.
    """
    records: List[_ModelT] = []
    for line in _safe_read_lines(path):
        try:
            # model_validate (unlike Model(**raw)) reports non-dict input as a
            # ValidationError instead of an uncaught TypeError.
            records.append(model.model_validate(json.loads(line)))
        except (json.JSONDecodeError, ValidationError):
            continue
    return records


def load_trades() -> List[TradeLog]:
    """Return every valid trade record from LOG_FILE, in file order."""
    return _load_jsonl(LOG_FILE, TradeLog)


def load_balances() -> List[BalanceSnapshot]:
    """Return every valid balance snapshot from BAL_FILE, in file order."""
    return _load_jsonl(BAL_FILE, BalanceSnapshot)


def append_trade(trade: TradeLog) -> None:
    """Append *trade* as one JSON line to LOG_FILE."""
    line = trade.model_dump_json()  # Pydantic v2 API (``model_json`` does not exist)
    with _log_lock, open(LOG_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")


def append_balance_snapshot(balance: float) -> BalanceSnapshot:
    """Record *balance* with the current timestamp in BAL_FILE and return the snapshot."""
    snap = BalanceSnapshot(timestamp=int(time.time()), balance=balance)
    line = snap.model_dump_json()
    with _bal_lock, open(BAL_FILE, "a", encoding="utf-8") as f:
        f.write(line + "\n")
    return snap


# -----------------------------------------------------------------------------
# Gate.io helpers (dry-run aware)
# -----------------------------------------------------------------------------


def sign_request(method: str, path: str, query_string: str, body: str, timestamp: str) -> str:
    """Return the Gate.io APIv4 ``SIGN`` header value.

    The signed message is method, path, query string, the hex SHA-512 digest of
    the request body, and timestamp, joined by newlines, then HMAC-SHA512'd with
    GATE_API_SECRET. *path* must be the full request path including the API
    prefix (e.g. ``/api/v4/futures/usdt/accounts``).
    """
    hashed_body = hashlib.sha512(body.encode("utf-8")).hexdigest()
    message = f"{method}\n{path}\n{query_string}\n{hashed_body}\n{timestamp}"
    return hmac.new(
        GATE_API_SECRET.encode("utf-8"),
        message.encode("utf-8"),
        hashlib.sha512,
    ).hexdigest()


def _gate_url(path: str, query: str = "") -> str:
    """Build the absolute Gate.io URL for API *path* plus an optional *query*."""
    url = f"{GATE_API_BASE.rstrip('/')}{path}"
    return f"{url}?{query}" if query else url


def _decode_gate_json(res: requests.Response) -> Any:
    """Decode a Gate.io response body, mapping non-JSON payloads to HTTP 502."""
    try:
        return res.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        raise HTTPException(status_code=502, detail="Gate.io returned non-JSON payload") from e


def gate_private_get(path: str, query: str = "") -> Any:
    """Perform a signed GET against a private Gate.io endpoint and return its JSON.

    Raises:
        HTTPException(503): in dry-run mode.
        HTTPException(502): on network/HTTP errors or a non-JSON response.
    """
    if DRY_RUN:
        raise HTTPException(status_code=503, detail="Exchange private API disabled in dry-run mode")
    method = "GET"
    timestamp = str(int(time.time()))
    signed_path = f"{_GATE_API_PATH_PREFIX}{path}"
    headers = {
        "Accept": "application/json",
        "KEY": GATE_API_KEY,
        "Timestamp": timestamp,
        "SIGN": sign_request(method, signed_path, query, "", timestamp),
    }
    try:
        res = requests.get(_gate_url(path, query), headers=headers, timeout=10)
        res.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Gate.io request failed: {e}") from e
    return _decode_gate_json(res)


def gate_public_get(path: str, query: str = "") -> Any:
    """Perform an unsigned GET against a public Gate.io endpoint and return its JSON.

    Raises:
        HTTPException(503): in dry-run mode.
        HTTPException(502): on network/HTTP errors or a non-JSON response.
    """
    if DRY_RUN:
        raise HTTPException(status_code=503, detail="Exchange public API disabled in dry-run mode")
    try:
        res = requests.get(_gate_url(path, query), headers={"Accept": "application/json"}, timeout=10)
        res.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"Gate.io public request failed: {e}") from e
    return _decode_gate_json(res)


def get_futures_account_total_balance() -> float:
    """Return the USDT-futures balance used for equity snapshots.

    In dry-run mode this is the last recorded snapshot, or ``DRY_RUN_BALANCE``
    (default 10000.0) when none exists. Live, it sums the ``available`` field of
    the account payload.
    """
    if DRY_RUN:
        # In dry-run: use last balance if exists, else deterministic constant
        balances = load_balances()
        if balances:
            return balances[-1].balance
        return float(os.getenv("DRY_RUN_BALANCE", "10000.0"))

    accounts = gate_private_get("/futures/usdt/accounts")
    # This endpoint returns a single account object; accept a list as well.
    if isinstance(accounts, dict):
        accounts = [accounts]
    elif not isinstance(accounts, list):
        raise HTTPException(status_code=502, detail="Malformed account payload from Gate.io")

    # NOTE(review): "available" excludes margin locked in open positions/orders, so
    # the equity curve dips whenever a position is opened — confirm whether the
    # account's total balance field was intended instead.
    total = 0.0
    for acc in accounts:
        if not isinstance(acc, dict):
            continue
        try:
            total += float(acc.get("available", 0.0))
        except (TypeError, ValueError):
            continue
    return total


def get_contract_spread_bps(contract: str) -> float:
    """Return the top-of-book bid/ask spread of *contract* in basis points of mid.

    Returns 0.0 when the book is empty or crossed. Uses ``DRY_RUN_SPREAD_BPS``
    (default 5.0) in dry-run mode.

    Raises:
        HTTPException(404): no ticker for the contract.
        HTTPException(502): malformed ticker payload.
    """
    if DRY_RUN:
        # Deterministic spread for offline mode; override via env if needed
        return float(os.getenv("DRY_RUN_SPREAD_BPS", "5.0"))

    # URL-encode the user-supplied contract so it cannot inject extra parameters.
    tickers = gate_public_get("/futures/usdt/tickers", query=urlencode({"contract": contract}))
    if not tickers:
        raise HTTPException(status_code=404, detail=f"No ticker data for {contract}")
    if not isinstance(tickers, list) or not isinstance(tickers[0], dict):
        raise HTTPException(status_code=502, detail="Malformed ticker from Gate.io")

    ticker = tickers[0]
    # Gate.io futures tickers expose the book top as highest_bid / lowest_ask;
    # fall back to the plain bid / ask keys for compatibility.
    raw_bid = ticker.get("highest_bid", ticker.get("bid", 0.0))
    raw_ask = ticker.get("lowest_ask", ticker.get("ask", 0.0))
    try:
        bid = float(raw_bid)
        ask = float(raw_ask)
    except (TypeError, ValueError):
        raise HTTPException(status_code=502, detail="Malformed ticker from Gate.io")

    if bid <= 0 or ask <= 0 or ask <= bid:
        return 0.0
    mid = 0.5 * (bid + ask)
    return (ask - bid) / mid * 1e4


# -----------------------------------------------------------------------------
# KPI logic
# -----------------------------------------------------------------------------


def compute_kpis(trades: List[TradeLog], balances: List[BalanceSnapshot]) -> KPIResponse:
    """Aggregate trade PnL and the balance equity curve into a KPIResponse.

    * ``realized_pnl`` sums ``pnl_realized``; ``total_pnl`` adds every
      ``pnl_estimate`` on top of it.
    * ``win_rate`` and ``avg_pnl_per_trade`` are taken over *all* logged records,
      including ones whose ``pnl_realized`` is 0.
    * ``max_drawdown_pct`` is the most negative change from the running peak
      balance, in percent (0.0 when there is no drawdown or no data).
    """
    trade_count = len(trades)
    realized_pnl = float(sum([t.pnl_realized for t in trades]))
    total_pnl = float(realized_pnl + sum([t.pnl_estimate for t in trades]))
    wins = sum(1 for t in trades if t.pnl_realized > 0)
    win_rate = wins / trade_count if trade_count > 0 else 0.0
    avg_pnl_per_trade = realized_pnl / trade_count if trade_count > 0 else 0.0

    max_drawdown_pct = 0.0
    if balances:
        peak = balances[0].balance
        for point in balances:
            if point.balance > peak:
                peak = point.balance
            # A percentage drawdown is undefined while the running peak is <= 0.
            if peak > 0:
                drawdown = (point.balance - peak) / peak * 100.0
                if drawdown < max_drawdown_pct:
                    max_drawdown_pct = drawdown

    return KPIResponse(
        total_pnl=total_pnl,
        realized_pnl=realized_pnl,
        trade_count=trade_count,
        win_rate=win_rate,
        max_drawdown_pct=max_drawdown_pct,
        avg_pnl_per_trade=avg_pnl_per_trade,
        equity_curve=balances,
    )


def kpis_to_feature_dict(kpis: KPIResponse) -> Dict[str, float]:
    """Flatten the scalar KPIs (everything but the equity curve) into float features."""
    return {
        "total_pnl": kpis.total_pnl,
        "realized_pnl": kpis.realized_pnl,
        "trade_count": float(kpis.trade_count),
        "win_rate": kpis.win_rate,
        "max_drawdown_pct": kpis.max_drawdown_pct,
        "avg_pnl_per_trade": kpis.avg_pnl_per_trade,
    }


# -----------------------------------------------------------------------------
# LLM integration (dry-run aware)
# -----------------------------------------------------------------------------


def _build_alpha_prompt(req: AlphaRequest, spread_bps: float, kpis: Dict[str, float]) -> str:
    """Render the instruction prompt sent to the LLM policy engine.

    KPIs are serialized with sorted keys so identical inputs yield an identical prompt.
    """
    kpi_json = json.dumps(kpis, sort_keys=True)
    ctx = req.context or ""
    return (
        "You are a deterministic trading policy engine.\n"
        "Given KPIs and spread metrics, choose ONE action: long, short, or flat.\n"
        "Respond ONLY with a compact JSON object:\n"
        '{ "action": "...", "confidence": 0.0-1.0, "size_factor": 0.0-1.0, "comment": "..." }\n\n'
        f"Contract: {req.contract}\n"
        f"Spread_bps: {spread_bps:.4f}\n"
        f"KPIs: {kpi_json}\n"
        f"Context: {ctx}\n"
        "Constraints:\n"
        "- If max_drawdown_pct < -25 or win_rate < 0.4 ⇒ prefer flat.\n"
        "- If trade_count < 10 ⇒ confidence <= 0.4.\n"
        "- size_factor must be <= 0.3 if spread_bps > 10.\n"
    )


def call_llm_for_alpha(prompt: str) -> Dict[str, Any]:
    """Send *prompt* to LLM_ENDPOINT and return the JSON decision object it produced.

    Without an endpoint configured, a fixed "flat" decision is returned and no
    request is made. The response may be a text-generation envelope
    (``[{"generated_text": ...}]`` or ``{"generated_text": ...}``) or the decision
    object itself; the first ``{`` to the last ``}`` of the text is parsed.

    Raises:
        HTTPException(502): on request failure, non-JSON payload, or no parsable object.
    """
    if not LLM_ENDPOINT:
        # Dry-run LLM: force flat, no external call
        return {
            "action": "flat",
            "confidence": 0.0,
            "size_factor": 0.0,
            "comment": "LLM endpoint not configured; dry-run flat policy.",
        }

    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 256,
            "temperature": 0.1,
            "top_p": 0.9,
            "return_full_text": False,
        },
    }
    headers = {"Content-Type": "application/json"}
    if LLM_API_KEY:
        headers["Authorization"] = f"Bearer {LLM_API_KEY}"

    try:
        res = requests.post(LLM_ENDPOINT, headers=headers, json=payload, timeout=20)
        res.raise_for_status()
    except requests.RequestException as e:
        raise HTTPException(status_code=502, detail=f"LLM request failed: {e}") from e

    try:
        data = res.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        raise HTTPException(status_code=502, detail="LLM returned non-JSON payload") from e

    # Some serving stacks return the decision object directly.
    if isinstance(data, dict) and "action" in data and "generated_text" not in data:
        return data

    if isinstance(data, list) and data and isinstance(data[0], dict) and "generated_text" in data[0]:
        text = data[0]["generated_text"]
    elif isinstance(data, dict) and "generated_text" in data:
        text = data["generated_text"]
    else:
        # json.dumps (not str()) keeps any embedded object parseable as JSON.
        text = data if isinstance(data, str) else json.dumps(data)

    text = str(text).strip()
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end <= start:
        raise HTTPException(status_code=502, detail="LLM output missing JSON object")

    snippet = text[start : end + 1]
    try:
        parsed = json.loads(snippet)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=502, detail=f"LLM JSON parse failed: {e}") from e
    return parsed


def _unit_float(value: Any) -> float:
    """Coerce *value* to a float clamped to [0, 1]; non-numeric or NaN become 0.0."""
    import math

    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    if math.isnan(number):
        return 0.0
    return min(max(number, 0.0), 1.0)


def build_alpha_decision(
    req: AlphaRequest,
    spread_bps: float,
    kpi_features: Dict[str, float],
    raw_model_out: Dict[str, Any],
) -> AlphaDecision:
    """Turn raw model output into an AlphaDecision, applying hard guardrails.

    Guardrails (applied regardless of what the model said):
    * unknown actions become "flat"; confidence/size_factor are clamped to [0, 1];
    * max_drawdown_pct < -30 forces flat, size 0 and confidence <= 0.3;
    * trade_count < 10 caps confidence at 0.4;
    * spread_bps > 10 caps size_factor at 0.3.
    """
    action = str(raw_model_out.get("action", "")).lower().strip()
    if action not in ("long", "short", "flat"):
        action = "flat"

    # Tolerate non-numeric / NaN model values instead of failing with a 500.
    confidence = _unit_float(raw_model_out.get("confidence", 0.0))
    size_factor = _unit_float(raw_model_out.get("size_factor", 0.0))
    raw_comment = raw_model_out.get("comment")
    comment = "" if raw_comment is None else str(raw_comment).strip()[:240]

    if kpi_features.get("max_drawdown_pct", 0.0) < -30.0:
        action = "flat"
        confidence = min(confidence, 0.3)
        size_factor = 0.0
        comment = (comment + " [forced_flat_due_to_drawdown]").strip()

    if kpi_features.get("trade_count", 0.0) < 10:
        confidence = min(confidence, 0.4)

    if spread_bps > 10.0 and size_factor > 0.3:
        size_factor = 0.3

    return AlphaDecision(
        action=action,
        confidence=confidence,
        size_factor=size_factor,
        spread_bps=spread_bps,
        kpis=kpi_features,
        comment=comment,
        raw_model_output=raw_model_out if LLM_ENDPOINT else None,
    )


# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------


@app.get("/", response_class=HTMLResponse)
def home() -> str:
    """Serve a minimal landing page showing whether the service runs LIVE or DRY-RUN."""
    mode = "DRY-RUN" if DRY_RUN else "LIVE"
    # NOTE(review): this page has no HTML markup and "Endpoints:" is followed by
    # nothing — markup may have been lost; confirm the intended page content.
    return f"""

gate4-alpha-api ({mode})

Endpoints:

"""


@app.get("/openapi.yaml")
def get_openapi():
    """Serve ./openapi.yaml if it exists, else 404."""
    if not os.path.exists("openapi.yaml"):
        raise HTTPException(status_code=404, detail="openapi.yaml not found")
    return FileResponse("openapi.yaml", media_type="text/yaml")


@app.get("/balance")
def get_balance():
    """Fetch the current futures balance, record it as a snapshot, and return it."""
    total = get_futures_account_total_balance()
    snap = append_balance_snapshot(total)
    return {
        "timestamp": snap.timestamp,
        "balance": round(snap.balance, 6),
        "dry_run": DRY_RUN,
    }


@app.get("/performance")
def get_performance():
    """Return a human-readable KPI summary, the last five trades, and the raw KPIs."""
    trades = load_trades()
    balances = load_balances()
    if not trades and not balances:
        return {"summary": "No trades or balances logged yet.", "dry_run": DRY_RUN}

    kpis = compute_kpis(trades, balances)
    summary = (
        f"Total PnL: {kpis.total_pnl:.2f}, "
        f"Realized: {kpis.realized_pnl:.2f}, "
        f"Trades: {kpis.trade_count}, "
        f"Win rate: {kpis.win_rate:.2%}, "
        f"Max DD: {kpis.max_drawdown_pct:.2f}%"
    )
    tail = [
        {
            "ts": t.timestamp,
            "action": t.action,
            "contract": t.contract,
            "pnl_realized": t.pnl_realized,
            "pnl_estimate": t.pnl_estimate,
            "reason": t.reason,
        }
        for t in trades[-5:]
    ]
    return {
        "summary": summary,
        "last_trades": tail,
        "kpis": kpis.model_dump(),
        "dry_run": DRY_RUN,
    }


@app.get("/kpis", response_model=KPIResponse)
def get_kpis():
    """Return KPIs computed from all logged trades and balance snapshots."""
    trades = load_trades()
    balances = load_balances()
    return compute_kpis(trades, balances)


@app.post("/log_trade", response_model=TradeLog)
async def log_trade(request: Request):
    """Validate a trade from the JSON body, append it to the trade log, and echo it.

    Responds 400 for a non-JSON body and 422 for a non-object or invalid trade.
    """
    try:
        payload = await request.json()
    except ValueError as e:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from e
    if not isinstance(payload, dict):
        raise HTTPException(status_code=422, detail="Trade payload must be a JSON object")
    try:
        trade = TradeLog(**payload)
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=e.errors()) from e
    append_trade(trade)
    return trade


@app.post("/alpha/entry", response_model=AlphaDecision)
def alpha_entry(req: AlphaRequest):
    """Compute KPIs and spread for *req.contract*, query the LLM, and return a guarded decision.

    Declared as a plain ``def`` (not ``async``): the exchange and LLM calls use
    blocking ``requests``, so FastAPI must run this in its threadpool rather than
    on the event loop.
    """
    trades = load_trades()
    balances = load_balances()
    base_features = kpis_to_feature_dict(compute_kpis(trades, balances))
    features = {**base_features, **req.kpis_override} if req.kpis_override else base_features

    spread_bps = get_contract_spread_bps(req.contract)
    prompt = _build_alpha_prompt(req, spread_bps, features)
    raw_model_out = call_llm_for_alpha(prompt)
    return build_alpha_decision(req, spread_bps, features, raw_model_out)